diff --git a/packages/toolkit/src/createReducer.ts b/packages/toolkit/src/createReducer.ts index 5662b11fb0..6a25d51d27 100644 --- a/packages/toolkit/src/createReducer.ts +++ b/packages/toolkit/src/createReducer.ts @@ -66,6 +66,12 @@ export type CaseReducers = { [T in keyof AS]: AS[T] extends Action ? CaseReducer : void } +type NotFunction = T extends Function ? never : T + +export type ReducerWithInitialState> = Reducer & { + getInitialState: () => S +} + /** * A utility function that allows defining a reducer as a mapping from action * type to *case reducer* functions that handle these action types. The @@ -130,10 +136,10 @@ createReducer( ``` * @public */ -export function createReducer( - initialState: S, +export function createReducer>( + initialState: S | (() => S), builderCallback: (builder: ActionReducerMapBuilder) => void -): Reducer +): ReducerWithInitialState /** * A utility function that allows defining a reducer as a mapping from action @@ -180,31 +186,35 @@ const counterReducer = createReducer(0, { * @public */ export function createReducer< - S, + S extends NotFunction, CR extends CaseReducers = CaseReducers >( - initialState: S, + initialState: S | (() => S), actionsMap: CR, actionMatchers?: ActionMatcherDescriptionCollection, defaultCaseReducer?: CaseReducer -): Reducer +): ReducerWithInitialState -export function createReducer( - initialState: S, +export function createReducer>( + initialState: S | (() => S), mapOrBuilderCallback: | CaseReducers | ((builder: ActionReducerMapBuilder) => void), actionMatchers: ReadonlyActionMatcherDescriptionCollection = [], defaultCaseReducer?: CaseReducer -): Reducer { +): ReducerWithInitialState { let [actionsMap, finalActionMatchers, finalDefaultCaseReducer] = typeof mapOrBuilderCallback === 'function' ? executeReducerBuilderCallback(mapOrBuilderCallback) : [mapOrBuilderCallback, actionMatchers, defaultCaseReducer] - const frozenInitialState = createNextState(initialState, () => {}) + const getInitialState = (): S => { + const stateToUse = + typeof initialState === 'function' ? initialState() : initialState + return createNextState(stateToUse, () => {}) as S + } - return function (state = frozenInitialState, action): S { + function reducer(state = getInitialState(), action: any): S { let caseReducers = [ actionsMap[action.type], ...finalActionMatchers @@ -257,4 +267,8 @@ export function createReducer( return previousState }, state) } + + reducer.getInitialState = getInitialState + + return reducer as ReducerWithInitialState }