Skip to content

Commit

Permalink
fix(vanilla): should update derived atoms during write (#2111)
Browse files Browse the repository at this point in the history
* add failing test

* revert #2086 change

* fix #2086 with a different approach
  • Loading branch information
dai-shi authored Sep 7, 2023
1 parent f7d7f74 commit ae91b58
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 21 deletions.
39 changes: 18 additions & 21 deletions src/vanilla/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -462,15 +462,20 @@ export const createStore = () => {
}
}

const recomputeDependents = (updatedAtoms: Set<AnyAtom>): void => {
if (!updatedAtoms.size) {
return
}
const recomputeDependents = (atom: AnyAtom): void => {
const dependencyMap = new Map<AnyAtom, Set<AnyAtom>>()
const dirtyMap = new WeakMap<AnyAtom, number>()
const getDependents = (a: AnyAtom): Dependents => {
const dependents = new Set(mountedMap.get(a)?.t)
pendingMap.forEach((_, pendingAtom) => {
if (getAtomState(pendingAtom)?.d.has(a)) {
dependents.add(pendingAtom)
}
})
return dependents
}
const loop1 = (a: AnyAtom) => {
const mounted = mountedMap.get(a)
mounted?.t.forEach((dependent) => {
getDependents(a).forEach((dependent) => {
if (dependent !== a) {
dependencyMap.set(
dependent,
Expand All @@ -481,10 +486,9 @@ export const createStore = () => {
}
})
}
updatedAtoms.forEach(loop1)
loop1(atom)
const loop2 = (a: AnyAtom) => {
const mounted = mountedMap.get(a)
mounted?.t.forEach((dependent) => {
getDependents(a).forEach((dependent) => {
if (dependent !== a) {
let dirtyCount = dirtyMap.get(dependent)
if (dirtyCount) {
Expand All @@ -507,12 +511,10 @@ export const createStore = () => {
}
})
}
updatedAtoms.forEach(loop2)
updatedAtoms.clear()
loop2(atom)
}

const writeAtomState = <Value, Args extends unknown[], Result>(
updatedAtoms: Set<AnyAtom>,
atom: WritableAtom<Value, Args, Result>,
...args: Args
): Result => {
Expand All @@ -531,13 +533,12 @@ export const createStore = () => {
const prevAtomState = getAtomState(a)
const nextAtomState = setAtomValueOrPromise(a, args[0] as V)
if (!prevAtomState || !isEqualAtomValue(prevAtomState, nextAtomState)) {
updatedAtoms.add(a)
recomputeDependents(a)
}
} else {
r = writeAtomState(updatedAtoms, a as AnyWritableAtom, ...args) as R
r = writeAtomState(a as AnyWritableAtom, ...args) as R
}
if (!isSync) {
recomputeDependents(updatedAtoms)
const flushed = flushPending()
if (import.meta.env?.MODE !== 'production') {
storeListenersRev2.forEach((l) =>
Expand All @@ -556,9 +557,7 @@ export const createStore = () => {
atom: WritableAtom<Value, Args, Result>,
...args: Args
): Result => {
const updatedAtoms = new Set<AnyAtom>()
const result = writeAtomState(updatedAtoms, atom, ...args)
recomputeDependents(updatedAtoms)
const result = writeAtomState(atom, ...args)
const flushed = flushPending()
if (import.meta.env?.MODE !== 'production') {
storeListenersRev2.forEach((l) =>
Expand Down Expand Up @@ -761,14 +760,12 @@ export const createStore = () => {
dev_get_atom_state: (a: AnyAtom) => atomStateMap.get(a),
dev_get_mounted: (a: AnyAtom) => mountedMap.get(a),
dev_restore_atoms: (values: Iterable<readonly [AnyAtom, AnyValue]>) => {
const updatedAtoms = new Set<AnyAtom>()
for (const [atom, valueOrPromise] of values) {
if (hasInitialValue(atom)) {
setAtomValueOrPromise(atom, valueOrPromise)
updatedAtoms.add(atom)
recomputeDependents(atom)
}
}
recomputeDependents(updatedAtoms)
const flushed = flushPending()
storeListenersRev2.forEach((l) =>
l({ type: 'restore', flushed: flushed as Set<AnyAtom> })
Expand Down
20 changes: 20 additions & 0 deletions tests/vanilla/store.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,23 @@ it("should recompute dependents' state after onMount (#2098)", async () => {
expect(store.get(derivedAtom)).toBe(false)
expect(store.get(finalAtom)).toBe(false)
})

it('should update derived atoms during write (#2107)', async () => {
const store = createStore()

const baseCountAtom = atom(1)
const countAtom = atom(
(get) => get(baseCountAtom),
(get, set, newValue: number) => {
set(baseCountAtom, newValue)
if (get(countAtom) !== newValue) {
throw new Error('mismatch')
}
}
)

store.sub(countAtom, () => {})
expect(store.get(countAtom)).toBe(1)
store.set(countAtom, 2)
expect(store.get(countAtom)).toBe(2)
})

1 comment on commit ae91b58

@vercel
Copy link

@vercel vercel bot commented on ae91b58 Sep 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.