Skip to content

Commit

Permalink
Merge pull request #907 from concord-consortium/186033205-fix-count-c…
Browse files Browse the repository at this point in the history
…aching

Fix count() caching
  • Loading branch information
pjanik authored Sep 28, 2023
2 parents ff9024d + 56cf4dc commit 76e4d33
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 21 deletions.
60 changes: 43 additions & 17 deletions v3/src/models/data/formula-fn-registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,16 @@ const evaluateNode = (node: MathNode, scope?: FormulaMathJsScope) => {
return node.compile().evaluate(scope)
}

// Every aggregate function can be cached in the same way. Also, each aggregate function needs to be evaluated
// within `withAggregateContext` method, so that the scope can be properly set up.
/**
 * Wraps the raw evaluator of an aggregate function so that it is always invoked through
 * `scope.withAggregateContext`. The returned evaluator accepts the same `(args, mathjs, scope)` arguments
 * and yields whatever `withAggregateContext` returns for the wrapped call.
 */
const evaluateRawWithAggregateContext = (
  fn: (args: MathNode[], mathjs: any, scope: FormulaMathJsScope) => FValue | FValue[]
) =>
  (args: MathNode[], mathjs: any, scope: FormulaMathJsScope) =>
    scope.withAggregateContext(() => fn(args, mathjs, scope))

// Almost every aggregate function can be cached in the same way.
const cachedAggregateFnFactory =
(fnName: string, fn: (args: MathNode[], mathjs: any, scope: FormulaMathJsScope) => FValue | FValue[]) => {
return (args: MathNode[], mathjs: any, scope: FormulaMathJsScope) => {
Expand All @@ -21,10 +29,7 @@ const cachedAggregateFnFactory =
if (cachedValue !== undefined) {
return cachedValue
}
let result
scope.withAggregateContext(() => {
result = fn(args, mathjs, scope)
})
const result = fn(args, mathjs, scope)
scope.setCached(cacheKey, result)
return result
}
Expand Down Expand Up @@ -221,17 +226,33 @@ export const fnRegistry = {
// count(expression, filterExpression)
count: {
isAggregate: true,
cachedEvaluateFactory: cachedAggregateFnFactory,
// Note that count is an atypical aggregate function that cannot use the typical caching. When count() is called
// without arguments, the default caching method would calculate an incorrect cache key. Hence, caching is
// implemented directly in the function body.
cachedEvaluateFactory: undefined,
evaluateRaw: (args: MathNode[], mathjs: any, scope: FormulaMathJsScope) => {
const [ expression, filter ] = args
if (!expression) {
// Special case - count() without arguments returns number of children cases.
// Special case: count() without arguments returns the number of child cases. Note that this cannot be cached,
// as there is no argument and getCaseAggregateGroupId() would be calculated incorrectly. But that's not
// a problem, as scope.getCaseChildrenCount() returns the result in O(1) time anyway.
return scope.getCaseChildrenCount()
}
let expressionValues = evaluateNode(expression, scope)

const cacheKey = `count(${args.toString()})-${scope.getCaseAggregateGroupId()}`
const cachedValue = scope.getCached(cacheKey)
if (cachedValue !== undefined) {
return cachedValue
}

const filterValues = filter && evaluateNode(filter, scope)
expressionValues = expressionValues.filter((v: any, i: number) => v !== "" && (filter ? !!filterValues[i] : true))
return expressionValues.length
const validExpressionValues = evaluateNode(expression, scope).filter((v: FValue, i: number) =>
v !== "" && (filter ? isValueTruthy(filterValues[i]) : true)
)
const result = validExpressionValues.length

scope.setCached(cacheKey, result)
return result
}
},

Expand All @@ -254,10 +275,9 @@ export const fnRegistry = {
if (!cachedData || casePointer >= cachedData.resultCasePointer) {
// We need to look for a new next value when there's no cached data (e.g. first case being processed) or when
// we already passed the index of cached result.
const numOfCases = scope.getNumberOfCases()
let expressionValue
if (filter) {
let filterValue
const numOfCases = scope.getNumberOfCases()
let expressionValue, filterValue
let currentGroup = caseGroupId
// Keep looking for truthy filter value as long as cases are in the same group and we didn't reach the end.
while (!isValueTruthy(filterValue) && casePointer < numOfCases && currentGroup === caseGroupId) {
Expand All @@ -280,12 +300,12 @@ export const fnRegistry = {
} else {
// When there's no filter, simply get the next expression value (within the same case group).
casePointer = scope.getCasePointer() + 1
scope.withCustomCasePointer(() => {
result = scope.withCustomCasePointer(() => {
if (scope.getCaseGroupId() === caseGroupId) {
expressionValue = evaluateNode(expression, scope)
return evaluateNode(expression, scope)
}
return undefined
}, casePointer)
result = expressionValue
}

scope.setCached(cacheKey, { result, resultCasePointer: casePointer })
Expand Down Expand Up @@ -360,7 +380,13 @@ export const typedFnRegistry: CODAPMathjsFunctionRegistry = fnRegistry
Object.keys(typedFnRegistry).forEach((key) => {
const fn = typedFnRegistry[key]
let evaluateRaw = fn.evaluateRaw
if (fn.isAggregate && !fn.evaluateRaw) {
throw new Error("Aggregate functions need to provide evaluateRaw")
}
if (evaluateRaw) {
if (fn.isAggregate) {
evaluateRaw = evaluateRawWithAggregateContext(evaluateRaw)
}
if (fn.cachedEvaluateFactory) {
// Use cachedEvaluateFactory if it's defined. Currently it's defined only for aggregate functions.
evaluateRaw = fn.cachedEvaluateFactory(key, evaluateRaw)
Expand Down
10 changes: 6 additions & 4 deletions v3/src/models/data/formula-mathjs-scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,18 +172,20 @@ export class FormulaMathJsScope {
// with... methods could be replaced by a more elegant approach of creating a sub-scope with modified properties,
// but it would require re-initialization of the data storage. Since this could happen multiple times for each
// evaluated case, it could be a performance hit. So, for now, with... methods seem like a reasonable compromise.
withCustomCasePointer(callback: () => void, casePointer: number) {
withCustomCasePointer(callback: () => any, casePointer: number) {
const originalCasePointer = this.casePointer
this.casePointer = casePointer
callback()
const result = callback()
this.casePointer = originalCasePointer
return result
}

withAggregateContext(callback: () => void) {
withAggregateContext(callback: () => any) {
const originalIsAggregate = this.isAggregate
this.isAggregate = true
callback()
const result = callback()
this.isAggregate = originalIsAggregate
return result
}

getCaseChildrenCount() {
Expand Down

0 comments on commit 76e4d33

Please sign in to comment.