diff --git a/packages/malloy/src/lang/ast/expressions/case.ts b/packages/malloy/src/lang/ast/expressions/case.ts new file mode 100644 index 000000000..985982c93 --- /dev/null +++ b/packages/malloy/src/lang/ast/expressions/case.ts @@ -0,0 +1,123 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import {ExprValue} from '../types/expr-value'; +import {ExpressionDef} from '../types/expression-def'; +import {FieldSpace} from '../types/field-space'; +import {MalloyElement} from '../types/malloy-element'; +import {FT} from '../fragtype-utils'; +import { + CaseExpr, + EvalSpace, + ExpressionType, + maxExpressionType, + mergeEvalSpaces, +} from '../../../model'; + +interface Choice { + then: ExprValue; + when: ExprValue; +} + +function typeCoalesce(ev1: ExprValue | undefined, ev2: ExprValue): ExprValue { + return ev1 === undefined || + ev1.dataType === 'null' || + ev1.dataType === 'error' + ? ev2 + : ev1; +} + +export class Case extends ExpressionDef { + elementType = 'case'; + constructor( + readonly choices: CaseWhen[], + readonly elseValue?: ExpressionDef + ) { + super({choices}); + this.has({elseValue}); + } + + getExpression(fs: FieldSpace): ExprValue { + const caseValue: CaseExpr = { + node: 'case', + kids: { + caseWhen: [], + caseThen: [], + caseElse: null, + }, + }; + const choiceValues: Choice[] = []; + for (const c of this.choices) { + const when = c.when.getExpression(fs); + const then = c.then.getExpression(fs); + choiceValues.push({when, then}); + } + let returnType: ExprValue | undefined; + let expressionType: ExpressionType = 'scalar'; + let evalSpace: EvalSpace = 'constant'; + for (const aChoice of choiceValues) { + if (!FT.typeEq(aChoice.when, FT.boolT)) { + return this.loggedErrorExpr('case-when-must-be-boolean', { + whenType: aChoice.when.dataType, + }); + } + if (returnType && !FT.typeEq(returnType, aChoice.then, true)) { + return this.loggedErrorExpr('case-then-type-does-not-match', { + thenType: aChoice.then.dataType, + returnType: returnType.dataType, + }); + } + returnType = typeCoalesce(returnType, aChoice.then); + expressionType = maxExpressionType( + expressionType, + maxExpressionType( + aChoice.then.expressionType, + aChoice.when.expressionType + ) + ); + evalSpace = mergeEvalSpaces( + evalSpace, + aChoice.then.evalSpace, + aChoice.when.evalSpace + ); + caseValue.kids.caseWhen.push(aChoice.when.value); + caseValue.kids.caseThen.push(aChoice.then.value); + } + if (this.elseValue) { + const elseValue = this.elseValue.getExpression(fs); + if (returnType && !FT.typeEq(returnType, elseValue, true)) { + return this.loggedErrorExpr('case-else-type-does-not-match', { + elseType: elseValue.dataType, + returnType: returnType.dataType, + }); + } + returnType = typeCoalesce(returnType, elseValue); + expressionType = maxExpressionType( + expressionType, + elseValue.expressionType + ); + evalSpace = mergeEvalSpaces(evalSpace, elseValue.evalSpace); + caseValue.kids.caseElse = elseValue.value; + } + return { + value: caseValue, + dataType: returnType?.dataType ?? 'null', + expressionType, + evalSpace, + }; + } +} + +export class CaseWhen extends MalloyElement { + elementType = 'caseWhen'; + constructor( + readonly when: ExpressionDef, + readonly then: ExpressionDef + ) { + super({when, then}); + } +} diff --git a/packages/malloy/src/lang/ast/expressions/pick-when.ts b/packages/malloy/src/lang/ast/expressions/pick-when.ts index 873f91fdc..27eb57c77 100644 --- a/packages/malloy/src/lang/ast/expressions/pick-when.ts +++ b/packages/malloy/src/lang/ast/expressions/pick-when.ts @@ -22,11 +22,11 @@ */ import { + CaseExpr, EvalSpace, ExpressionType, maxExpressionType, mergeEvalSpaces, - PickExpr, } from '../../../model/malloy_types'; import {FT} from '../fragtype-utils'; @@ -78,12 +78,12 @@ export class Pick extends ExpressionDef { } apply(fs: FieldSpace, op: string, expr: ExpressionDef): ExprValue { - const caseValue: PickExpr = { - node: 'pick', + const caseValue: CaseExpr = { + node: 'case', kids: { - pickWhen: [], - pickThen: [], - pickElse: {node: 'error', message: 'pick statement not complete'}, + caseWhen: [], + caseThen: [], + caseElse: null, }, }; let returnType: ExprValue | undefined; @@ -110,8 +110,8 @@ export class Pick extends ExpressionDef { }); } returnType = typeCoalesce(returnType, thenExpr); - caseValue.kids.pickWhen.push(whenExpr.value); - caseValue.kids.pickThen.push(thenExpr.value); + caseValue.kids.caseWhen.push(whenExpr.value); + caseValue.kids.caseThen.push(thenExpr.value); } const elsePart = this.elsePick || expr; const elseVal = elsePart.getExpression(fs); @@ -129,7 +129,7 @@ export class Pick extends ExpressionDef { }); } } - caseValue.kids.pickElse = elseVal.value; + caseValue.kids.caseElse = elseVal.value; return { dataType: returnType.dataType, expressionType: maxExpressionType( @@ -142,12 +142,12 @@ export class Pick extends ExpressionDef { } getExpression(fs: FieldSpace): ExprValue { - const pick: PickExpr = { - node: 'pick', + const pick: CaseExpr = { + node: 'case', kids: { - pickWhen: [], - pickThen: [], - pickElse: {node: 'error', message: 'pick statement not complete'}, + caseWhen: [], + caseThen: [], + caseElse: null, }, }; if (this.elsePick === undefined) { @@ -165,8 +165,8 @@ export class Pick extends ExpressionDef { 'pick with no value can only be used with apply' ); } - const pickWhen = c.when.requestExpression(fs); - if (pickWhen === undefined) { + const caseWhen = c.when.requestExpression(fs); + if (caseWhen === undefined) { this.loggedErrorExpr( 'pick-illegal-partial', 'pick with partial when can only be used with apply' @@ -205,8 +205,8 @@ export class Pick extends ExpressionDef { aChoice.pick.evalSpace, aChoice.when.evalSpace ); - pick.kids.pickWhen.push(aChoice.when.value); - pick.kids.pickThen.push(aChoice.pick.value); + pick.kids.caseWhen.push(aChoice.when.value); + pick.kids.caseThen.push(aChoice.pick.value); } const defVal = this.elsePick.getExpression(fs); anyExpressionType = maxExpressionType( @@ -221,7 +221,7 @@ export class Pick extends ExpressionDef { returnType: returnType.dataType, }); } - pick.kids.pickElse = defVal.value; + pick.kids.caseElse = defVal.value; return { dataType: returnType.dataType, expressionType: anyExpressionType, @@ -232,7 +232,7 @@ export class Pick extends ExpressionDef { } export class PickWhen extends MalloyElement { - elementType = 'pickWhen'; + elementType = 'caseWhen'; constructor( readonly pick: ExpressionDef | undefined, readonly when: ExpressionDef diff --git a/packages/malloy/src/lang/ast/index.ts b/packages/malloy/src/lang/ast/index.ts index 8870d19cf..09292df23 100644 --- a/packages/malloy/src/lang/ast/index.ts +++ b/packages/malloy/src/lang/ast/index.ts @@ -76,6 +76,7 @@ export * from './expressions/time-literal'; export * from './expressions/partial-compare'; export * from './expressions/partition_by'; export * from './expressions/pick-when'; +export * from './expressions/case'; export * from './expressions/expr-record-literal'; export * from './expressions/range'; export * from './expressions/time-frame'; diff --git a/packages/malloy/src/lang/grammar/MalloyParser.g4 b/packages/malloy/src/lang/grammar/MalloyParser.g4 index 0e5358383..ec4d24353 100644 --- a/packages/malloy/src/lang/grammar/MalloyParser.g4 +++ b/packages/malloy/src/lang/grammar/MalloyParser.g4 @@ -569,6 +569,7 @@ fieldExpr | ((id (EXCLAM malloyType?)?) | timeframe) OPAREN ( argumentList? ) CPAREN # exprFunc | pickStatement # exprPick + | caseStatement # exprCase | ungroup OPAREN fieldExpr (COMMA fieldName)* CPAREN # exprUngroup ; @@ -585,6 +586,14 @@ pick : PICK (pickValue=fieldExpr)? WHEN pickWhen=partialAllowedFieldExpr ; +caseStatement + : CASE (caseWhen)+ (ELSE caseElse=fieldExpr)? END + ; + +caseWhen + : WHEN condition=fieldExpr THEN result=fieldExpr + ; + recordKey: id; recordElement : fieldPath # recordRef @@ -644,4 +653,4 @@ connectionName: string; experimentalStatementForTesting // this only exists to enable tests for the experimental compiler flag : SEMI SEMI OBRACK string CBRACK - ; \ No newline at end of file + ; diff --git a/packages/malloy/src/lang/malloy-to-ast.ts b/packages/malloy/src/lang/malloy-to-ast.ts index 6cc1eacf2..9fa8a7c4d 100644 --- a/packages/malloy/src/lang/malloy-to-ast.ts +++ b/packages/malloy/src/lang/malloy-to-ast.ts @@ -1619,6 +1619,28 @@ export class MalloyToAST ); } + visitCaseStatement(pcx: parse.CaseStatementContext): ast.Case { + const whenCxs = pcx.caseWhen(); + const whens = whenCxs.map(whenCx => { + return new ast.CaseWhen( + this.getFieldExpr(whenCx._condition), + this.getFieldExpr(whenCx._result) + ); + }); + const elseCx = pcx._caseElse; + const theElse = elseCx ? this.getFieldExpr(elseCx) : undefined; + this.warnWithReplacement( + 'sql-case', + 'Use a `pick` statement instead of `case`', + this.parseInfo.rangeFromContext(pcx), + `${[ + ...whenCxs.map(whenCx => `pick ${whenCx._result.text} when ${whenCx._condition.text}`), + elseCx ? `else ${elseCx.text}` : 'else null' + ].join(' ')}` + ); + return new ast.Case(whens, theElse); + } + visitPickStatement(pcx: parse.PickStatementContext): ast.Pick { const picks = pcx.pick().map(pwCx => { let pickExpr: ast.ExpressionDef | undefined; diff --git a/packages/malloy/src/lang/parse-log.ts b/packages/malloy/src/lang/parse-log.ts index c293dbba7..9ed398846 100644 --- a/packages/malloy/src/lang/parse-log.ts +++ b/packages/malloy/src/lang/parse-log.ts @@ -104,7 +104,7 @@ type MessageParameterTypes = { 'pick-missing-else': {}; 'pick-missing-value': {}; 'pick-illegal-partial': {}; - 'pick-when-must-be-boolean': {whenType: string}; + 'pick-when-must-be-boolean': {whenType: FieldValueType}; 'experiment-not-enabled': {experimentId: string}; 'experimental-dialect-not-enabled': {dialect: string}; 'sql-native-not-allowed-in-expression': { @@ -350,6 +350,16 @@ type MessageParameterTypes = { 'sql-is-not-null': string; 'sql-is-null': string; 'illegal-record-property-type': string; + 'sql-case': string; + 'case-then-type-does-not-match': { + thenType: FieldValueType; + returnType: FieldValueType; + }; + 'case-else-type-does-not-match': { + elseType: FieldValueType; + returnType: FieldValueType; + }; + 'case-when-must-be-boolean': {whenType: FieldValueType}; }; export const MESSAGE_FORMATTERS: PartialErrorCodeMessageMap = { @@ -390,6 +400,12 @@ export const MESSAGE_FORMATTERS: PartialErrorCodeMessageMap = { 'syntax-error': e => e.message, 'internal-translator-error': e => `Internal Translator Error: ${e.message}`, 'invalid-timezone': e => `Invalid timezone: ${e.timezone}`, + 'case-then-type-does-not-match': e => + `Case then type ${e.thenType} does not match return type ${e.returnType}`, + 'case-else-type-does-not-match': e => + `Case else type ${e.elseType} does not match return type ${e.returnType}`, + 'case-when-must-be-boolean': e => + `Case when expression must be boolean, not ${e.whenType}`, }; export type MessageCode = keyof MessageParameterTypes; diff --git a/packages/malloy/src/lang/test/expressions.spec.ts b/packages/malloy/src/lang/test/expressions.spec.ts index 01c01e305..a30bebe64 100644 --- a/packages/malloy/src/lang/test/expressions.spec.ts +++ b/packages/malloy/src/lang/test/expressions.spec.ts @@ -34,6 +34,7 @@ import { error, errorMessage, warningMessage, + warning, } from './test-translator'; import './parse-expects'; @@ -979,6 +980,145 @@ describe('expressions', () => { }); }); + describe('case statements', () => { + test('full', () => { + expect(expr` + case + when ai = 42 then 'the answer' + when ai = 54 then 'the questionable answer' + else 'random' + end + `).toLog(warning('sql-case')); + }); + test('no else', () => { + expect(expr` + case + when ai = 42 then 'the answer' + when ai = 54 then 'the questionable answer' + end + `).toLog(warning('sql-case')); + }); + test('wrong then type', () => { + expect(expr` + case + when ai = 42 then 'the answer' + when ai = 54 then 7 + end + `).toLog( + warning('sql-case'), + error('case-then-type-does-not-match', { + thenType: 'number', + returnType: 'string', + }) + ); + }); + test('wrong else type', () => { + expect(expr` + case + when ai = 42 then 'the answer' + else @2020 + end + `).toLog( + warning('sql-case'), + error('case-else-type-does-not-match', { + elseType: 'date', + returnType: 'string', + }) + ); + }); + test('null then type okay second', () => { + expect(expr` + case + when ai = 42 then 'the answer' + when ai = 54 then null + end + `).toLog(warning('sql-case')); + }); + test('null then type okay first', () => { + expect(expr` + case + when ai = 54 then null + when ai = 42 then 'the answer' + end + `).toLog(warning('sql-case')); + }); + test('null else type okay', () => { + expect(expr` + case + when ai = 42 then 'the answer' + else null + end + `).toLog(warning('sql-case')); + }); + test('null then type before else okay', () => { + expect(expr` + case + when ai = 42 then null + else 'not the answer' + end + `).toLog(warning('sql-case')); + }); + test('non boolean when', () => { + expect(expr` + case when ai then null end + `).toLog(warning('sql-case'), error('case-when-must-be-boolean')); + }); + test('type of null then second', () => { + expect(` + case + when ai = 42 then 'the answer' + when ai = 54 then null + end + `).toReturnType('string'); + }); + test('type of null then first', () => { + expect(` + case + when ai = 54 then null + when ai = 42 then 'the answer' + end + `).toReturnType('string'); + }); + test('type of null else', () => { + expect(` + case + when ai = 42 then 'the answer' + else null + end + `).toReturnType('string'); + }); + test('type of null then type before else', () => { + expect(` + case + when ai = 42 then null + else 'not the answer' + end + `).toReturnType('string'); + }); + test.skip('replacement for full case', () => { + const e = expr`case + when ai = 42 then 'the answer' + when ai = 54 then 'the questionable answer' + else 'random' + end`; + e.translator.translate(); + expect(e.translator.logger.getLog()[0].replacement).toBe(`pick 'the answer' when ai = 42 pick 'the questionable answer' when ai = 54 else 'random'`); + }); + test.skip('replacement for case with no else', () => { + const e = expr`case + when ai = 42 then 'the answer' + when ai = 54 then 'the questionable answer' + end`; + e.translator.translate(); + expect(e.translator.logger.getLog()[0].replacement).toBe(`pick 'the answer' when ai = 42 pick 'the questionable answer' when ai = 54 else null`); + }); + test('interaction with pick', () => { + expect(expr` + pick case when true then 'hooray' end when true else null + `).toLog(warning('sql-case')); + }); + }); + describe('pick statements', () => { test('full', () => { expect(expr` diff --git a/packages/malloy/src/lang/test/parse-expects.ts b/packages/malloy/src/lang/test/parse-expects.ts index 30ad489bd..1a4c3d027 100644 --- a/packages/malloy/src/lang/test/parse-expects.ts +++ b/packages/malloy/src/lang/test/parse-expects.ts @@ -96,11 +96,11 @@ function rangeToStr(loc?: DocumentRange): string { return 'undefined'; } -function ensureNoProblems(trans: MalloyTranslator) { +function ensureNoProblems(trans: MalloyTranslator, warningsOkay = false) { if (trans.logger === undefined) { throw new Error('JESTERY BROKEN, CANT FIND ERORR LOG'); } - if (!trans.logger.empty()) { + if (warningsOkay ? trans.logger.hasErrors() : !trans.logger.empty()) { return { message: () => `Translation problems:\n${trans.prettyErrors()}`, pass: false, @@ -186,8 +186,8 @@ function xlator(ts: TestSource) { return ts.translator || new TestTranslator(ts.code); } -function xlated(tt: TestTranslator) { - const errorCheck = ensureNoProblems(tt); +function xlated(tt: TestTranslator, warningsOkay = false) { + const errorCheck = ensureNoProblems(tt, warningsOkay); if (!errorCheck.pass) { return errorCheck; } @@ -265,7 +265,7 @@ expect.extend({ toReturnType: function (exprText: string, returnType: string) { const exprModel = new BetaExpression(exprText); exprModel.compile(); - const ok = xlated(exprModel); + const ok = xlated(exprModel, true); if (!ok.pass) { return ok; } @@ -273,7 +273,7 @@ expect.extend({ const pass = d.dataType === returnType; const msg = `Expression type ${d.dataType} ${ pass ? '=' : '!=' - } $[returnType`; + } ${returnType}`; return {pass, message: () => msg}; }, toLog: function (s: TestSource, ...msgs: ProblemSpec[]) { diff --git a/packages/malloy/src/model/malloy_query.ts b/packages/malloy/src/model/malloy_query.ts index 271c9fa52..910f8a4aa 100644 --- a/packages/malloy/src/model/malloy_query.ts +++ b/packages/malloy/src/model/malloy_query.ts @@ -76,7 +76,6 @@ import { UngroupNode, SourceReferenceNode, TimeTruncExpr, - PickExpr, SpreadExpr, FilteredExpr, SourceDef, @@ -104,6 +103,7 @@ import { QueryToMaterialize, PrepareResultOptions, RepeatedRecordFieldDef, + CaseExpr, } from './malloy_types'; import {Connection} from '../connection/types'; @@ -1151,14 +1151,17 @@ class QueryField extends QueryNode { return retExpr; } - generatePickSQL(pf: PickExpr): string { + generateCaseSQL(pf: CaseExpr): string { const caseStmt = ['CASE']; - for (let i = 0; i < pf.kids.pickWhen.length; i += 1) { + for (let i = 0; i < pf.kids.caseWhen.length; i += 1) { caseStmt.push( - `WHEN ${pf.kids.pickWhen[i].sql} THEN ${pf.kids.pickThen[i].sql}` + `WHEN ${pf.kids.caseWhen[i].sql} THEN ${pf.kids.caseThen[i].sql}` ); } - caseStmt.push(`ELSE ${pf.kids.pickElse.sql}`, 'END'); + if (pf.kids.caseElse !== null) { + caseStmt.push(`ELSE ${pf.kids.caseElse.sql}`); + } + caseStmt.push('END'); return caseStmt.join(' '); } @@ -1191,6 +1194,7 @@ class QueryField extends QueryNode { expr = {...exprToTranslate}; const oldKids = exprToTranslate.kids; for (const [name, kidExpr] of Object.entries(oldKids)) { + if (kidExpr === null) continue; if (Array.isArray(kidExpr)) { expr.kids[name] = kidExpr.map(e => { return {...e, sql: subExpr(this, e)}; @@ -1318,8 +1322,8 @@ class QueryField extends QueryNode { return expr.node; case 'null': return 'NULL'; - case 'pick': - return this.generatePickSQL(expr); + case 'case': + return this.generateCaseSQL(expr); case '': return ''; case 'filterCondition': diff --git a/packages/malloy/src/model/malloy_types.ts b/packages/malloy/src/model/malloy_types.ts index 93726cfca..a1228dd0c 100644 --- a/packages/malloy/src/model/malloy_types.ts +++ b/packages/malloy/src/model/malloy_types.ts @@ -43,7 +43,7 @@ export interface ExprOptionalE extends ExprLeaf { } export interface ExprWithKids extends ExprLeaf { - kids: Record; + kids: Record; } export type AnyExpr = ExprE | ExprOptionalE | ExprWithKids | ExprLeaf; @@ -91,7 +91,7 @@ export type Expr = | FunctionOrderBy | GenericSQLExpr | NullNode - | PickExpr + | CaseExpr | ArrayEachExpr | ErrorNode; @@ -320,9 +320,9 @@ export interface NullNode extends ExprLeaf { node: 'null'; } -export interface PickExpr extends ExprWithKids { - node: 'pick'; - kids: {pickWhen: Expr[]; pickThen: Expr[]; pickElse: Expr}; +export interface CaseExpr extends ExprWithKids { + node: 'case'; + kids: {caseWhen: Expr[]; caseThen: Expr[]; caseElse: Expr | null}; } export interface ArrayEachExpr extends ExprLeaf { diff --git a/packages/malloy/src/model/utils.ts b/packages/malloy/src/model/utils.ts index d25895bba..5c195d843 100644 --- a/packages/malloy/src/model/utils.ts +++ b/packages/malloy/src/model/utils.ts @@ -131,7 +131,7 @@ export function* exprKids(eNode: Expr): IterableIterator { for (const kidEnt of Object.values(eNode.kids)) { if (Array.isArray(kidEnt)) { yield* kidEnt; - } else { + } else if (kidEnt !== null) { yield kidEnt; } } @@ -162,7 +162,7 @@ export function exprMap(eNode: Expr, mapFunc: (e: Expr) => Expr): Expr { for (const [name, kidEnt] of Object.entries(eNode.kids)) { if (Array.isArray(kidEnt)) { parentNode.kids[name] = kidEnt.map(kidEl => mapFunc(kidEl)); - } else { + } else if (kidEnt !== null) { parentNode.kids[name] = mapFunc(kidEnt); } } diff --git a/test/src/databases/all/expr.spec.ts b/test/src/databases/all/expr.spec.ts index 8cda4aefe..f3d82a2e5 100644 --- a/test/src/databases/all/expr.spec.ts +++ b/test/src/databases/all/expr.spec.ts @@ -410,6 +410,21 @@ describe.each(runtimes.runtimeList)('%s', (databaseName, runtime) => { `).malloyResultMatches(expressionModel, {a: 312}); }); + it('case expressions', async () => { + await expect(` + run: aircraft_models -> { + where: manufacturer ? 'BOEING' | 'CESSNA' + group_by: + other is case when manufacturer = 'BOEING' then 'BOEING' else 'OTHER' end + group_by: + nully is case when manufacturer = 'BOEING' then 'BOEING' end + } + `).malloyResultMatches(expressionModel, [ + {other: 'BOEING', nully: 'BOEING'}, + {other: 'OTHER', nully: null}, + ]); + }); + test.when(runtime.dialect.supportsSafeCast)('sql safe cast', async () => { await expect(` run: ${databaseName}.sql('SELECT 1 as one') -> { select: