diff --git a/src/lib/utilities/focus-trap.ts b/src/lib/utilities/focus-trap.ts index d1dfd3d8e4..c81362301a 100644 --- a/src/lib/utilities/focus-trap.ts +++ b/src/lib/utilities/focus-trap.ts @@ -12,6 +12,7 @@ export const getFocusableElements = (node: HTMLElement) => export const focusTrap = (node: HTMLElement, enabled: boolean) => { let firstFocusable: HTMLElement; let lastFocusable: HTMLElement; + let previouslyFocused: HTMLElement | null = null; const onKeydown = (event: KeyboardEvent) => { if (event.key === 'Tab') { @@ -27,6 +28,11 @@ export const focusTrap = (node: HTMLElement, enabled: boolean) => { } }; + const removeListeners = () => { + firstFocusable?.removeEventListener('keydown', onKeydown); + lastFocusable?.removeEventListener('keydown', onKeydown); + }; + const setFocus = (fromObserver: boolean = false) => { if (enabled === false) return; @@ -34,15 +40,26 @@ export const focusTrap = (node: HTMLElement, enabled: boolean) => { firstFocusable = focusable[0]; lastFocusable = focusable[focusable.length - 1]; - if (!fromObserver) firstFocusable?.focus(); + if (!fromObserver) { + if ( + previouslyFocused === null && + document.activeElement instanceof HTMLElement + ) { + previouslyFocused = document.activeElement; + } + firstFocusable?.focus(); + } firstFocusable?.addEventListener('keydown', onKeydown); lastFocusable?.addEventListener('keydown', onKeydown); }; const cleanUp = () => { - firstFocusable?.removeEventListener('keydown', onKeydown); - lastFocusable?.removeEventListener('keydown', onKeydown); + removeListeners(); + if (previouslyFocused && document.body.contains(previouslyFocused)) { + previouslyFocused.focus(); + } + previouslyFocused = null; }; const onChange = ( @@ -50,7 +67,7 @@ export const focusTrap = (node: HTMLElement, enabled: boolean) => { observer: MutationObserver, ) => { if (mutationRecords.length) { - cleanUp(); + removeListeners(); setFocus(true); } return observer;