diff --git a/src/app/content/components/ContentWarning.spec.tsx b/src/app/content/components/ContentWarning.spec.tsx
index 80b59e3a01..09b45cd1d1 100644
--- a/src/app/content/components/ContentWarning.spec.tsx
+++ b/src/app/content/components/ContentWarning.spec.tsx
@@ -35,13 +35,14 @@ describe('ContentWarning', () => {
renderToDom();
expect(services.osWebLoader.getBookFromId).toBeCalledWith(dummyBook.id);
-
await act(() => new Promise((resolve) => setTimeout(resolve, 1)));
const root = document?.body;
const b = root?.querySelector('button');
expect(b).toBeTruthy();
+ // Exercises the when-focus-is-already-in-the-modal branch
+ b!.focus();
act(() => ReactTestUtils.Simulate.click(b!));
expect(root?.querySelector('button')).toBeFalsy();
});
diff --git a/src/app/content/components/ContentWarning.tsx b/src/app/content/components/ContentWarning.tsx
index 282561f3e5..bdf4fe7051 100644
--- a/src/app/content/components/ContentWarning.tsx
+++ b/src/app/content/components/ContentWarning.tsx
@@ -9,6 +9,7 @@ import Modal from './Modal';
import theme from '../../theme';
import Cookies from 'js-cookie';
import { useTrapTabNavigation } from '../../reactUtils';
+import { assertDocument } from '../../utils';
// tslint:disable-next-line
const WarningDiv = styled.div`
@@ -37,7 +38,23 @@ function WarningDivWithTrap({
}) {
const ref = React.useRef(null);
- React.useEffect(() => ref.current?.focus(), []);
+ // Demand focus
+ React.useEffect(
+ () => {
+ const document = assertDocument();
+ const grabFocus = () => {
+ if (!ref.current?.contains(document.activeElement)) {
+ ref.current?.focus();
+ }
+ };
+
+ grabFocus();
+ document.body.addEventListener('focusin', grabFocus);
+
+ return () => document.body.removeEventListener('focusin', grabFocus);
+ },
+ []
+ );
useTrapTabNavigation(ref);
diff --git a/src/app/reactUtils.ts b/src/app/reactUtils.ts
index d79ae2af6a..c66c004960 100644
--- a/src/app/reactUtils.ts
+++ b/src/app/reactUtils.ts
@@ -99,11 +99,10 @@ export function createTrapTab(...elements: HTMLElement[]) {
return;
}
const trapTab = createTrapTab(el);
- const document = assertDocument();
- document.body.addEventListener('keydown', trapTab, true);
+ el.addEventListener('keydown', trapTab, true);
- return () => document.body.removeEventListener('keydown', trapTab, true);
+ return () => el.removeEventListener('keydown', trapTab, true);
},
[ref, otherDep]
);