diff --git a/src/app/content/components/ContentWarning.spec.tsx b/src/app/content/components/ContentWarning.spec.tsx index 80b59e3a01..09b45cd1d1 100644 --- a/src/app/content/components/ContentWarning.spec.tsx +++ b/src/app/content/components/ContentWarning.spec.tsx @@ -35,13 +35,14 @@ describe('ContentWarning', () => { renderToDom(); expect(services.osWebLoader.getBookFromId).toBeCalledWith(dummyBook.id); - await act(() => new Promise((resolve) => setTimeout(resolve, 1))); const root = document?.body; const b = root?.querySelector('button'); expect(b).toBeTruthy(); + // Exercises the when-focus-is-already-in-the-modal branch + b!.focus(); act(() => ReactTestUtils.Simulate.click(b!)); expect(root?.querySelector('button')).toBeFalsy(); }); diff --git a/src/app/content/components/ContentWarning.tsx b/src/app/content/components/ContentWarning.tsx index 282561f3e5..bdf4fe7051 100644 --- a/src/app/content/components/ContentWarning.tsx +++ b/src/app/content/components/ContentWarning.tsx @@ -9,6 +9,7 @@ import Modal from './Modal'; import theme from '../../theme'; import Cookies from 'js-cookie'; import { useTrapTabNavigation } from '../../reactUtils'; +import { assertDocument } from '../../utils'; // tslint:disable-next-line const WarningDiv = styled.div` @@ -37,7 +38,23 @@ function WarningDivWithTrap({ }) { const ref = React.useRef(null); - React.useEffect(() => ref.current?.focus(), []); + // Demand focus + React.useEffect( + () => { + const document = assertDocument(); + const grabFocus = () => { + if (!ref.current?.contains(document.activeElement)) { + ref.current?.focus(); + } + }; + + grabFocus(); + document.body.addEventListener('focusin', grabFocus); + + return () => document.body.removeEventListener('focusin', grabFocus); + }, + [] + ); useTrapTabNavigation(ref); diff --git a/src/app/reactUtils.ts b/src/app/reactUtils.ts index d79ae2af6a..c66c004960 100644 --- a/src/app/reactUtils.ts +++ b/src/app/reactUtils.ts @@ -99,11 +99,10 @@ export function createTrapTab(...elements: HTMLElement[]) { return; } const trapTab = createTrapTab(el); - const document = assertDocument(); - document.body.addEventListener('keydown', trapTab, true); + el.addEventListener('keydown', trapTab, true); - return () => document.body.removeEventListener('keydown', trapTab, true); + return () => el.removeEventListener('keydown', trapTab, true); }, [ref, otherDep] );