From 789d671c9cf2d56eb76501138de0cfd826875d46 Mon Sep 17 00:00:00 2001 From: AdamW Date: Fri, 1 Nov 2024 17:32:09 +0200 Subject: [PATCH] fix(utils): create a stable reference to atomWithDefault's fallback function --- src/vanilla/utils/unwrap.ts | 35 ++++++---- .../vanilla-utils/atomWithDefault.test.tsx | 67 ++++++++++++++++++- tests/vanilla/utils/unwrap.test.ts | 12 ++++ 3 files changed, 99 insertions(+), 15 deletions(-) diff --git a/src/vanilla/utils/unwrap.ts b/src/vanilla/utils/unwrap.ts index 8166b0b452..bd4df6ded2 100644 --- a/src/vanilla/utils/unwrap.ts +++ b/src/vanilla/utils/unwrap.ts @@ -12,6 +12,9 @@ const memo2 = (create: () => T, dep1: object, dep2: object): T => { const isPromise = (x: unknown): x is Promise => x instanceof Promise const defaultFallback = () => undefined +const fallbackCache = getCached(() => new WeakMap(), cache1, defaultFallback) +const getStableFallback = (fn: (prev?: T) => U, key: object): typeof fn => + getCached(() => fn, fallbackCache, key) export function unwrap( anAtom: WritableAtom, @@ -35,6 +38,11 @@ export function unwrap( anAtom: WritableAtom | Atom, fallback: (prev?: Awaited) => PendingValue = defaultFallback as never, ) { + const stableFallback = + fallback === defaultFallback + ? fallback + : getStableFallback(fallback, anAtom) + return memo2( () => { type PromiseAndValue = { readonly p?: Promise } & ( @@ -60,17 +68,16 @@ export function unwrap( return { v: promise as Awaited } } if (promise !== prev?.p) { - promise - .then( - (v) => { - promiseResultCache.set(promise, v as Awaited) - setSelf() - }, - (e) => { - promiseErrorCache.set(promise, e) - setSelf() - } - ) + promise.then( + (v) => { + promiseResultCache.set(promise, v as Awaited) + setSelf() + }, + (e) => { + promiseErrorCache.set(promise, e) + setSelf() + }, + ) } if (promiseErrorCache.has(promise)) { throw promiseErrorCache.get(promise) @@ -82,9 +89,9 @@ export function unwrap( } } if (prev && 'v' in prev) { - return { p: promise, f: fallback(prev.v), v: prev.v } + return { p: promise, f: stableFallback(prev.v), v: prev.v } } - return { p: promise, f: fallback() } + return { p: promise, f: stableFallback() } }, (_get, set) => { set(refreshAtom, (c) => c + 1) @@ -111,6 +118,6 @@ export function unwrap( ) }, anAtom, - fallback, + stableFallback, ) } diff --git a/tests/react/vanilla-utils/atomWithDefault.test.tsx b/tests/react/vanilla-utils/atomWithDefault.test.tsx index 9c5c1dd0f2..396876040e 100644 --- a/tests/react/vanilla-utils/atomWithDefault.test.tsx +++ b/tests/react/vanilla-utils/atomWithDefault.test.tsx @@ -4,7 +4,7 @@ import userEvent from '@testing-library/user-event' import { expect, it } from 'vitest' import { useAtom } from 'jotai/react' import { atom } from 'jotai/vanilla' -import { RESET, atomWithDefault } from 'jotai/vanilla/utils' +import { RESET, atomWithDefault, unwrap } from 'jotai/vanilla/utils' it('simple sync get default', async () => { const count1Atom = atom(1) @@ -228,3 +228,68 @@ it('can be set synchronously by passing value', async () => { expect(screen.getByText('count: 10')).toBeDefined() }) + +it('derive default from an unwrapped atom', async () => { + let resolve = () => {} + const anAsyncAtom = atom(async () => { + await new Promise((r) => (resolve = r)) + return 1 + }) + const defaultWithUnwrap = atomWithDefault((get) => get(unwrap(anAsyncAtom))) + + const Component = () => { + const [value] = useAtom(defaultWithUnwrap) + + if (value === undefined) { + return
loading
+ } + + return ( + <> +
value: {value}
+ + ) + } + + const { findByText } = render( + + + , + ) + + await findByText('loading') + resolve() + + await findByText('value: 1') +}) + +it('derive default from an unwrapped atom (explicit fallback)', async () => { + let resolve = () => {} + const anAsyncAtom = atom(async () => { + await new Promise((r) => (resolve = r)) + return 1 + }) + const defaultWithUnwrap = atomWithDefault((get) => + get(unwrap(anAsyncAtom, () => undefined)), + ) + + const Component = () => { + const [value] = useAtom(defaultWithUnwrap) + return ( + <> +
value: {value}
+ + ) + } + + const { findByText } = render( + + + , + ) + + await findByText('value:') + resolve() + + await findByText('value: 1') +}) diff --git a/tests/vanilla/utils/unwrap.test.ts b/tests/vanilla/utils/unwrap.test.ts index be294dde54..fa57a6f0b3 100644 --- a/tests/vanilla/utils/unwrap.test.ts +++ b/tests/vanilla/utils/unwrap.test.ts @@ -149,4 +149,16 @@ describe('unwrap', () => { expect(store.get(syncAtom)).toEqual('concrete') }) + + it('should get a fulfilled value after the promise resolves (explicit fallback function)', async () => { + const store = createStore() + const asyncAtom = atom(Promise.resolve('concrete')) + const syncAtom = unwrap(asyncAtom, (prev) => prev ?? 'fallback') + + expect(store.get(syncAtom)).toEqual('fallback') + + await store.get(asyncAtom) + + expect(store.get(syncAtom)).toEqual('concrete') + }) })