From 8e5238ecfcacf2489899426aa46c19aa36c196ae Mon Sep 17 00:00:00 2001 From: Bilal Shafi Date: Thu, 1 Aug 2024 22:31:50 +0500 Subject: [PATCH] Introduce selectors with arguments --- .../GridCellCheckboxRenderer.tsx | 9 +- .../src/hooks/features/rowSelection/utils.ts | 134 +++++++------ .../src/hooks/utils/useGridSelectorV8.ts | 83 ++++++++ .../x-data-grid/src/utils/createSelectorV8.ts | 181 ++++++++++++++++++ 4 files changed, 335 insertions(+), 72 deletions(-) create mode 100644 packages/x-data-grid/src/hooks/utils/useGridSelectorV8.ts create mode 100644 packages/x-data-grid/src/utils/createSelectorV8.ts diff --git a/packages/x-data-grid/src/components/columnSelection/GridCellCheckboxRenderer.tsx b/packages/x-data-grid/src/components/columnSelection/GridCellCheckboxRenderer.tsx index 3d8284edde959..5bbb0cf4322e5 100644 --- a/packages/x-data-grid/src/components/columnSelection/GridCellCheckboxRenderer.tsx +++ b/packages/x-data-grid/src/components/columnSelection/GridCellCheckboxRenderer.tsx @@ -10,8 +10,8 @@ import { useGridRootProps } from '../../hooks/utils/useGridRootProps'; import { getDataGridUtilityClass } from '../../constants/gridClasses'; import type { DataGridProcessedProps } from '../../models/props/DataGridProps'; import type { GridRowSelectionCheckboxParams } from '../../models/params/gridRowSelectionCheckboxParams'; -import { useGridSelector } from '../../hooks/utils/useGridSelector'; -import { getGridSomeChildrenSelectedSelector } from '../../hooks/features/rowSelection/utils'; +import { useGridSelectorV8 } from '../../hooks/utils/useGridSelectorV8'; +import { gridSomeChildrenSelectedSelector } from '../../hooks/features/rowSelection/utils'; type OwnerState = { classes: DataGridProcessedProps['classes'] }; @@ -52,8 +52,9 @@ const GridCellCheckboxForwardRef = React.forwardRef(null); - const someChildrenSelectedSelector = getGridSomeChildrenSelectedSelector(id); - const someChildrenSelected = useGridSelector(apiRef, someChildrenSelectedSelector); + const someChildrenSelected = useGridSelectorV8(apiRef, gridSomeChildrenSelectedSelector, { + groupId: id, + }); const rippleRef = React.useRef(null); const handleRef = useForkRef(checkboxElement, ref); diff --git a/packages/x-data-grid/src/hooks/features/rowSelection/utils.ts b/packages/x-data-grid/src/hooks/features/rowSelection/utils.ts index ed72f8abf1d90..ab8260bf71a71 100644 --- a/packages/x-data-grid/src/hooks/features/rowSelection/utils.ts +++ b/packages/x-data-grid/src/hooks/features/rowSelection/utils.ts @@ -2,75 +2,72 @@ import type { DataGridProcessedProps } from '../../../models/props/DataGridProps import { GridSignature } from '../../utils/useGridApiEventHandler'; import { GRID_ROOT_GROUP_ID } from '../rows/gridRowsUtils'; import type { GridGroupNode, GridRowId, GridRowTreeConfig } from '../../../models/gridRows'; -import type { - GridPrivateApiCommunity, - GridApiCommunity, -} from '../../../models/api/gridApiCommunity'; +import type { GridPrivateApiCommunity } from '../../../models/api/gridApiCommunity'; import { gridFilteredRowsLookupSelector } from '../filter/gridFilterSelector'; import { gridSortedRowIdsSelector } from '../sorting/gridSortingSelector'; import { selectedIdsLookupSelector } from './gridRowSelectionSelector'; import { gridRowTreeSelector } from '../rows/gridRowsSelector'; -import { createSelector } from '../../../utils/createSelector'; - -function getGridRowGroupSelectableChildrenSelector( - apiRef: React.MutableRefObject, - groupId: GridRowId, -) { - return createSelector( - gridRowTreeSelector, - gridSortedRowIdsSelector, - gridFilteredRowsLookupSelector, - (rowTree, sortedRowIds, filteredRowsLookup) => { - const groupNode = rowTree[groupId]; - if (!groupNode || groupNode.type !== 'group') { - return []; - } - - const children: GridRowId[] = []; - - const startIndex = sortedRowIds.findIndex((id) => id === groupId) + 1; - for ( - let index = startIndex; - index < sortedRowIds.length && rowTree[sortedRowIds[index]]?.depth > groupNode.depth; - index += 1 - ) { - const id = sortedRowIds[index]; - if (filteredRowsLookup[id] !== false && apiRef.current.isRowSelectable(id)) { - children.push(id); - } - } +import { createSelectorV8 } from '../../../utils/createSelectorV8'; + +export const gridRowGroupSelectableChildrenSelector = createSelectorV8( + gridRowTreeSelector, + gridSortedRowIdsSelector, + gridFilteredRowsLookupSelector, + (rowTree, sortedRowIds, filteredRowsLookup, args) => { + const children = new Set(); + const { groupId, apiRef } = args; + if (groupId === undefined || apiRef === undefined) { return children; - }, - ); -} + } + const groupNode = rowTree[groupId]; + if (!groupNode || groupNode.type !== 'group') { + return children; + } -export function getGridSomeChildrenSelectedSelector(groupId: GridRowId) { - return createSelector( - gridRowTreeSelector, - gridSortedRowIdsSelector, - gridFilteredRowsLookupSelector, - selectedIdsLookupSelector, - (rowTree, sortedRowIds, filteredRowsLookup, rowSelectionLookup) => { - const groupNode = rowTree[groupId]; - if (!groupNode || groupNode.type !== 'group') { - return false; + const startIndex = sortedRowIds.findIndex((id) => id === groupId) + 1; + for ( + let index = startIndex; + index < sortedRowIds.length && rowTree[sortedRowIds[index]]?.depth > groupNode.depth; + index += 1 + ) { + const id = sortedRowIds[index]; + if (filteredRowsLookup[id] !== false && apiRef.current.isRowSelectable(id)) { + children.add(id); } + } + return children; + }, +); + +export const gridSomeChildrenSelectedSelector = createSelectorV8( + gridRowTreeSelector, + gridSortedRowIdsSelector, + gridFilteredRowsLookupSelector, + selectedIdsLookupSelector, + (rowTree, sortedRowIds, filteredRowsLookup, rowSelectionLookup, args) => { + const groupId = args.groupId; + if (groupId === undefined) { + return false; + } + const groupNode = rowTree[groupId]; + if (!groupNode || groupNode.type !== 'group') { + return false; + } - const startIndex = sortedRowIds.findIndex((id) => id === groupId) + 1; - for ( - let index = startIndex; - index < sortedRowIds.length && rowTree[sortedRowIds[index]]?.depth > groupNode.depth; - index += 1 - ) { - const id = sortedRowIds[index]; - if (filteredRowsLookup[id] !== false && rowSelectionLookup[id] !== undefined) { - return true; - } + const startIndex = sortedRowIds.findIndex((id) => id === groupId) + 1; + for ( + let index = startIndex; + index < sortedRowIds.length && rowTree[sortedRowIds[index]]?.depth > groupNode.depth; + index += 1 + ) { + const id = sortedRowIds[index]; + if (filteredRowsLookup[id] !== false && rowSelectionLookup[id] !== undefined) { + return true; } - return false; - }, - ); -} + } + return false; + }, +); export function isMultipleRowSelectionEnabled( props: Pick< @@ -149,9 +146,11 @@ export const findRowsToSelect = ( const rowNode = apiRef.current.getRowNode(selectedRow); if (rowNode?.type === 'group') { - const rowGroupChildrenSelector = getGridRowGroupSelectableChildrenSelector(apiRef, selectedRow); - const children = rowGroupChildrenSelector(apiRef); - return rowsToSelect.concat(children); + const children = gridRowGroupSelectableChildrenSelector(apiRef, { + groupId: selectedRow, + apiRef, + }); + return rowsToSelect.concat(Array.from(children)); } return rowsToSelect; }; @@ -173,12 +172,11 @@ export const findRowsToDeselect = ( const rowNode = apiRef.current.getRowNode(deselectedRow); if (rowNode?.type === 'group') { - const rowGroupChildrenSelector = getGridRowGroupSelectableChildrenSelector( + const children = gridRowGroupSelectableChildrenSelector(apiRef, { + groupId: deselectedRow, apiRef, - deselectedRow, - ); - const children = rowGroupChildrenSelector(apiRef); - return rowsToDeselect.concat(children); + }); + return rowsToDeselect.concat(Array.from(children)); } return rowsToDeselect; }; diff --git a/packages/x-data-grid/src/hooks/utils/useGridSelectorV8.ts b/packages/x-data-grid/src/hooks/utils/useGridSelectorV8.ts new file mode 100644 index 0000000000000..7b7b8a7ee81ef --- /dev/null +++ b/packages/x-data-grid/src/hooks/utils/useGridSelectorV8.ts @@ -0,0 +1,83 @@ +import * as React from 'react'; +import type { GridApiCommon } from '../../models/api/gridApiCommon'; +import { OutputSelectorV8 } from '../../utils/createSelectorV8'; +import { useLazyRef } from './useLazyRef'; +import { useOnMount } from './useOnMount'; +import { warnOnce } from '../../internals/utils/warning'; +import type { GridCoreApi } from '../../models/api/gridCoreApi'; +import { fastObjectShallowCompare } from '../../utils/fastObjectShallowCompare'; + +function isOutputSelector( + selector: any, +): selector is OutputSelectorV8 { + return selector.acceptsApiRef; +} + +function applySelectorV8( + apiRef: React.MutableRefObject, + selector: ((state: Api['state']) => T) | OutputSelectorV8, + args: Args, + instanceId: GridCoreApi['instanceId'], +) { + if (isOutputSelector(selector)) { + return selector(apiRef, args); + } + return selector(apiRef.current.state, instanceId); +} + +const defaultCompare = Object.is; +export const objectShallowCompare = fastObjectShallowCompare; + +const createRefs = () => ({ state: null, equals: null, selector: null }) as any; + +export const useGridSelectorV8 = ( + apiRef: React.MutableRefObject, + selector: ((state: Api['state']) => T) | OutputSelectorV8, + args: Args = {} as Args, + equals: (a: T, b: T) => boolean = defaultCompare, +) => { + if (process.env.NODE_ENV !== 'production') { + if (!apiRef.current.state) { + warnOnce([ + 'MUI X: `useGridSelector` has been called before the initialization of the state.', + 'This hook can only be used inside the context of the grid.', + ]); + } + } + + const refs = useLazyRef< + { + state: T; + equals: typeof equals; + selector: typeof selector; + }, + never + >(createRefs); + const didInit = refs.current.selector !== null; + + const [state, setState] = React.useState( + // We don't use an initialization function to avoid allocations + (didInit ? null : applySelectorV8(apiRef, selector, args, apiRef.current.instanceId)) as T, + ); + + refs.current.state = state; + refs.current.equals = equals; + refs.current.selector = selector; + + useOnMount(() => { + return apiRef.current.store.subscribe(() => { + const newState = applySelectorV8( + apiRef, + refs.current.selector, + args, + apiRef.current.instanceId, + ) as T; + if (!refs.current.equals(refs.current.state, newState)) { + refs.current.state = newState; + setState(newState); + } + }); + }); + + return state; +}; diff --git a/packages/x-data-grid/src/utils/createSelectorV8.ts b/packages/x-data-grid/src/utils/createSelectorV8.ts new file mode 100644 index 0000000000000..b64c12c3c3547 --- /dev/null +++ b/packages/x-data-grid/src/utils/createSelectorV8.ts @@ -0,0 +1,181 @@ +import * as React from 'react'; +import { createSelector as reselectCreateSelector, Selector, SelectorResultArray } from 'reselect'; +import type { GridCoreApi } from '../models/api/gridCoreApi'; +import { warnOnce } from '../internals/utils/warning'; + +type CacheKey = { id: number }; + +export interface OutputSelectorV8 { + ( + apiRef: React.MutableRefObject<{ state: State; instanceId: GridCoreApi['instanceId'] }>, + args: Args, + ): Result; + (state: State, instanceId: GridCoreApi['instanceId']): Result; + acceptsApiRef: boolean; +} + +type StateFromSelector = T extends (first: infer F, ...args: any[]) => any + ? F extends { state: infer F2 } + ? F2 + : F + : never; + +type StateFromSelectorList = Selectors extends [ + f: infer F, + ...other: infer R, +] + ? StateFromSelector extends StateFromSelectorList + ? StateFromSelector + : StateFromSelectorList + : {}; + +type SelectorResultArrayWithAdditionalArgs>> = [ + ...SelectorResultArray, + Record, +]; + +type SelectorArgs>, Result> = + // Input selectors as a separate array + | [ + selectors: [...Selectors], + combiner: (...args: SelectorResultArrayWithAdditionalArgs) => Result, + ] + // Input selectors as separate inline arguments + | [...Selectors, (...args: SelectorResultArrayWithAdditionalArgs) => Result]; + +type CreateSelectorFunction = >, Args, Result>( + ...items: SelectorArgs +) => OutputSelectorV8, Args, Result>; + +const cache = new WeakMap>(); + +function checkIsAPIRef(value: any) { + return 'current' in value && 'instanceId' in value.current; +} + +const DEFAULT_INSTANCE_ID = { id: 'default' }; + +export const createSelectorV8 = (( + a: Function, + b: Function, + c?: Function, + d?: Function, + e?: Function, + f?: Function, + ...other: any[] +) => { + if (other.length > 0) { + throw new Error('Unsupported number of selectors'); + } + + let selector: any; + + if (a && b && c && d && e && f) { + selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => { + const isAPIRef = checkIsAPIRef(stateOrApiRef); + const instanceId = + instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID); + const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef; + const va = a(state, args, instanceId); + const vb = b(state, args, instanceId); + const vc = c(state, args, instanceId); + const vd = d(state, args, instanceId); + const ve = e(state, args, instanceId); + return f(va, vb, vc, vd, ve, args); + }; + } else if (a && b && c && d && e) { + selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => { + const isAPIRef = checkIsAPIRef(stateOrApiRef); + const instanceId = + instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID); + const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef; + const va = a(state, args, instanceId); + const vb = b(state, args, instanceId); + const vc = c(state, args, instanceId); + const vd = d(state, args, instanceId); + return e(va, vb, vc, vd, args); + }; + } else if (a && b && c && d) { + selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => { + const isAPIRef = checkIsAPIRef(stateOrApiRef); + const instanceId = + instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID); + const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef; + const va = a(state, args, instanceId); + const vb = b(state, args, instanceId); + const vc = c(state, args, instanceId); + return d(va, vb, vc, args); + }; + } else if (a && b && c) { + selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => { + const isAPIRef = checkIsAPIRef(stateOrApiRef); + const instanceId = + instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID); + const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef; + const va = a(state, args, instanceId); + const vb = b(state, args, instanceId); + return c(va, vb, args); + }; + } else if (a && b) { + selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => { + const isAPIRef = checkIsAPIRef(stateOrApiRef); + const instanceId = + instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID); + const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef; + const va = a(state, args, instanceId); + return b(va, args); + }; + } else { + throw new Error('Missing arguments'); + } + + // We use this property to detect if the selector was created with createSelector + // or it's only a simple function the receives the state and returns part of it. + selector.acceptsApiRef = true; + + return selector; +}) as unknown as CreateSelectorFunction; + +export const createSelectorMemoizedV8: CreateSelectorFunction = (...args: any) => { + const selector = (stateOrApiRef: any, selectorArgs: any, instanceId?: any) => { + const isAPIRef = checkIsAPIRef(stateOrApiRef); + const cacheKey = isAPIRef + ? stateOrApiRef.current.instanceId + : (instanceId ?? DEFAULT_INSTANCE_ID); + const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef; + + if (process.env.NODE_ENV !== 'production') { + if (cacheKey.id === 'default') { + warnOnce([ + 'MUI X: A selector was called without passing the instance ID, which may impact the performance of the grid.', + 'To fix, call it with `apiRef`, for example `mySelector(apiRef)`, or pass the instance ID explicitly, for example `mySelector(state, apiRef.current.instanceId)`.', + ]); + } + } + + const cacheArgsInit = cache.get(cacheKey); + const cacheArgs = cacheArgsInit ?? new Map(); + const cacheFn = cacheArgs?.get(args); + + if (cacheArgs && cacheFn) { + // We pass the cache key because the called selector might have as + // dependency another selector created with this `createSelector`. + return cacheFn(state, selectorArgs, cacheKey); + } + + const fn = reselectCreateSelector(...args); + + if (!cacheArgsInit) { + cache.set(cacheKey, cacheArgs); + } + cacheArgs.set(args, fn); + + return fn(state, selectorArgs, cacheKey); + }; + + // We use this property to detect if the selector was created with createSelector + // or it's only a simple function the receives the state and returns part of it. + selector.acceptsApiRef = true; + + return selector; +};