diff --git a/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/index.ts b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/index.ts index 9d7218a6f77cb..76ffd993bb9ab 100644 --- a/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/index.ts +++ b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/index.ts @@ -7,12 +7,14 @@ import type { ESQLAstQueryExpression } from '@kbn/esql-ast'; import type { QueryCorrection } from './types'; -import { applyTimespanLiteralsCorrections } from './timespan_literals'; +import { correctTimespanLiterals } from './timespan_literals'; +import { correctLikeWildcards } from './like'; export type { QueryCorrection } from './types'; export const correctAll = (query: ESQLAstQueryExpression): QueryCorrection[] => { const corrections: QueryCorrection[] = []; - corrections.push(...applyTimespanLiteralsCorrections(query)); + corrections.push(...correctTimespanLiterals(query)); + corrections.push(...correctLikeWildcards(query)); return corrections; }; diff --git a/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/like.test.ts b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/like.test.ts new file mode 100644 index 0000000000000..81779188c553b --- /dev/null +++ b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/like.test.ts @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { parse, BasicPrettyPrinter } from '@kbn/esql-ast'; +import { correctLikeWildcards } from './like'; + +describe('correctLikeWildcards', () => { + it('replaces badly used "_" wildcard', () => { + const query = 'FROM logs | WHERE message LIKE "ba_"'; + const { root } = parse(query); + correctLikeWildcards(root); + + const output = BasicPrettyPrinter.print(root); + expect(output).toEqual('FROM logs | WHERE message LIKE "ba?"'); + }); + + it('replaces badly used "%" wildcard', () => { + const query = 'FROM logs | WHERE message LIKE "b%"'; + const { root } = parse(query); + correctLikeWildcards(root); + + const output = BasicPrettyPrinter.print(root); + expect(output).toEqual('FROM logs | WHERE message LIKE "b*"'); + }); + + it('replaces multiple bad wildcards', () => { + const query = 'FROM logs | WHERE message LIKE "a__t%"'; + const { root } = parse(query); + correctLikeWildcards(root); + + const output = BasicPrettyPrinter.print(root); + expect(output).toEqual('FROM logs | WHERE message LIKE "a??t*"'); + }); + + it('replaces bad wildcards in multiple commands and functions', () => { + const query = + 'FROM logs | WHERE message LIKE "a%" AND TO_UPPER(level) LIKE "err%" | WHERE foo LIKE "ba_"'; + const { root } = parse(query); + correctLikeWildcards(root); + + const output = BasicPrettyPrinter.print(root); + expect(output).toEqual( + 'FROM logs | WHERE message LIKE "a*" AND TO_UPPER(level) LIKE "err*" | WHERE foo LIKE "ba?"' + ); + }); + + it('does not replace escaped characters', () => { + const query = 'FROM logs | WHERE message LIKE "ba\\\\_"'; + const { root } = parse(query); + correctLikeWildcards(root); + + const output = BasicPrettyPrinter.print(root); + expect(output).toEqual('FROM logs | WHERE message LIKE "ba\\\\_"'); + }); +}); diff --git a/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/like.ts b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/like.ts new file mode 100644 index 0000000000000..be61bd216284b --- /dev/null +++ b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/like.ts @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Walker, type ESQLAstQueryExpression } from '@kbn/esql-ast'; +import { isLikeOperatorNode, isStringLiteralNode } from '../typeguards'; +import type { ESQLLikeOperator, ESQLStringLiteral } from '../types'; +import type { QueryCorrection } from './types'; + +/** + * Correct wrong LIKE wildcard mistakes. + * The LLM can make mistake and use SQL wildcards for LIKE operators. + * + * E.g. + * `column LIKE "ba_"` => `column LIKE "ba?"` + * `column LIKE "ba%"` => `column LIKE "ba*"` + */ +export const correctLikeWildcards = (query: ESQLAstQueryExpression): QueryCorrection[] => { + const corrections: QueryCorrection[] = []; + + Walker.walk(query, { + visitFunction: (node) => { + if (isLikeOperatorNode(node)) { + corrections.push(...checkLikeNode(node)); + } + }, + }); + + return corrections; +}; + +function checkLikeNode(node: ESQLLikeOperator): QueryCorrection[] { + if (node.args.length !== 2 || !isStringLiteralNode(node.args[1])) { + return []; + } + const likeExpression = node.args[1] as ESQLStringLiteral; + + const initialValue = likeExpression.value; + + likeExpression.value = likeExpression.value + .replaceAll(/(? { +describe('correctTimespanLiterals', () => { describe('with DATE_TRUNC', () => { it('replaces a timespan with a proper timespan literal', () => { const query = 'FROM logs | EVAL truncated = DATE_TRUNC("1 year", date)'; const { root } = parse(query); - applyTimespanLiteralsCorrections(root); + correctTimespanLiterals(root); const output = BasicPrettyPrinter.print(root); @@ -27,7 +27,7 @@ describe('getTimespanLiteralsCorrections', () => { const query = 'FROM logs | EVAL truncated = DATE_TRUNC("month", date)'; const { root } = parse(query); - applyTimespanLiteralsCorrections(root); + correctTimespanLiterals(root); const output = BasicPrettyPrinter.print(root); @@ -40,7 +40,7 @@ describe('getTimespanLiteralsCorrections', () => { const query = 'FROM logs | EVAL truncated = DATE_TRUNC("1 YEAR", date)'; const { root } = parse(query); - applyTimespanLiteralsCorrections(root); + correctTimespanLiterals(root); const output = BasicPrettyPrinter.print(root); @@ -53,7 +53,7 @@ describe('getTimespanLiteralsCorrections', () => { const query = 'FROM logs | EVAL truncated = DATE_TRUNC("1 year", date)'; const { root } = parse(query); - const corrections = applyTimespanLiteralsCorrections(root); + const corrections = correctTimespanLiterals(root); expect(corrections).toHaveLength(1); expect(corrections[0]).toEqual({ @@ -70,7 +70,7 @@ describe('getTimespanLiteralsCorrections', () => { const query = 'FROM logs | STATS hires = COUNT(*) BY week = BUCKET(hire_date, "1 week")'; const { root } = parse(query); - applyTimespanLiteralsCorrections(root); + correctTimespanLiterals(root); const output = BasicPrettyPrinter.print(root); @@ -83,7 +83,7 @@ describe('getTimespanLiteralsCorrections', () => { const query = 'FROM logs | STATS hires = COUNT(*) BY hour = BUCKET(hire_date, "hour")'; const { root } = parse(query); - applyTimespanLiteralsCorrections(root); + correctTimespanLiterals(root); const output = BasicPrettyPrinter.print(root); @@ -96,7 +96,7 @@ describe('getTimespanLiteralsCorrections', () => { const query = 'FROM logs | STATS hires = COUNT(*) BY week = BUCKET(hire_date, "1 WEEK")'; const { root } = parse(query); - applyTimespanLiteralsCorrections(root); + correctTimespanLiterals(root); const output = BasicPrettyPrinter.print(root); @@ -109,7 +109,7 @@ describe('getTimespanLiteralsCorrections', () => { const query = 'FROM logs | STATS hires = COUNT(*) BY hour = BUCKET(hire_date, "hour")'; const { root } = parse(query); - const corrections = applyTimespanLiteralsCorrections(root); + const corrections = correctTimespanLiterals(root); expect(corrections).toHaveLength(1); expect(corrections[0]).toEqual({ @@ -129,7 +129,7 @@ describe('getTimespanLiteralsCorrections', () => { | STATS hires = COUNT(*) BY hour = BUCKET(hire_date, "3 hour")`; const { root } = parse(query); - applyTimespanLiteralsCorrections(root); + correctTimespanLiterals(root); const output = BasicPrettyPrinter.print(root, { multiline: true, pipeTab: '' }); diff --git a/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/timespan_literals.ts b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/timespan_literals.ts index c3fbe636a2de1..039632c3d103f 100644 --- a/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/timespan_literals.ts +++ b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/corrections/timespan_literals.ts @@ -19,9 +19,7 @@ import { QueryCorrection } from './types'; * `BUCKET(@timestamp, "1 week")` => `BUCKET(@timestamp, 1 week)` * */ -export const applyTimespanLiteralsCorrections = ( - query: ESQLAstQueryExpression -): QueryCorrection[] => { +export const correctTimespanLiterals = (query: ESQLAstQueryExpression): QueryCorrection[] => { const corrections: QueryCorrection[] = []; Walker.walk(query, { diff --git a/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/typeguards.ts b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/typeguards.ts index 233625673b872..54bbe2f7300a9 100644 --- a/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/typeguards.ts +++ b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/typeguards.ts @@ -5,8 +5,19 @@ * 2.0. */ -import type { ESQLSingleAstItem, ESQLAstItem, ESQLFunction, ESQLLiteral } from '@kbn/esql-ast'; -import type { ESQLStringLiteral, ESQLDateTruncFunction, ESQLBucketFunction } from './types'; +import type { + ESQLSingleAstItem, + ESQLAstItem, + ESQLFunction, + ESQLLiteral, + ESQLColumn, +} from '@kbn/esql-ast'; +import type { + ESQLStringLiteral, + ESQLDateTruncFunction, + ESQLBucketFunction, + ESQLLikeOperator, +} from './types'; export function isSingleItem(item: ESQLAstItem): item is ESQLSingleAstItem { return Object.hasOwn(item, 'type'); @@ -16,6 +27,10 @@ export function isFunctionNode(node: ESQLAstItem): node is ESQLFunction { return isSingleItem(node) && node.type === 'function'; } +export function isColumnNode(node: ESQLAstItem): node is ESQLColumn { + return isSingleItem(node) && node.type === 'column'; +} + export function isLiteralNode(node: ESQLAstItem): node is ESQLLiteral { return isSingleItem(node) && node.type === 'literal'; } @@ -31,3 +46,7 @@ export function isDateTruncFunctionNode(node: ESQLAstItem): node is ESQLDateTrun export function isBucketFunctionNode(node: ESQLAstItem): node is ESQLBucketFunction { return isFunctionNode(node) && node.subtype === 'variadic-call' && node.name === 'bucket'; } + +export function isLikeOperatorNode(node: ESQLAstItem): node is ESQLLikeOperator { + return isFunctionNode(node) && node.subtype === 'binary-expression' && node.name === 'like'; +} diff --git a/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/types.ts b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/types.ts index dd2a9810e359e..6444f1490f3d2 100644 --- a/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/types.ts +++ b/x-pack/plugins/inference/common/tasks/nl_to_esql/ast/types.ts @@ -12,6 +12,11 @@ import { ESQLFunction, ESQLLiteral } from '@kbn/esql-ast'; */ export type ESQLDateTruncFunction = ESQLFunction<'variadic-call', 'date_trunc'>; +/** + * represents a LIKE function node. + */ +export type ESQLLikeOperator = ESQLFunction<'binary-expression', 'like'>; + /** * represents a BUCKET function node. */