Skip to content

Commit

Permalink
[NL-to-ESQL] autocorrect bad LIKE wildcards (#202464)
Browse files Browse the repository at this point in the history
## Summary

Part of #198942

Add autocorrection for wrong `LIKE` wildcards.

The LLM can make mistakes and use SQL wildcards in `LIKE` patterns (`_`
instead of `?` and `%` instead of `*`).


Examples

**generated**
```
FROM logs | WHERE message LIKE "a%" AND TO_UPPER(level) LIKE "err%" | WHERE foo LIKE "ba_"
```
**corrected**
```
FROM logs | WHERE message LIKE "a*" AND TO_UPPER(level) LIKE "err*" | WHERE foo LIKE "ba?"
```

---------

Co-authored-by: kibanamachine <[email protected]>
  • Loading branch information
pgayvallet and kibanamachine authored Dec 3, 2024
1 parent d1c2e04 commit 2ace6ff
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

import type { ESQLAstQueryExpression } from '@kbn/esql-ast';
import type { QueryCorrection } from './types';
import { applyTimespanLiteralsCorrections } from './timespan_literals';
import { correctTimespanLiterals } from './timespan_literals';
import { correctLikeWildcards } from './like';

export type { QueryCorrection } from './types';

/**
 * Runs the query correction passes against the given query AST and returns
 * every correction they report, concatenated in invocation order.
 *
 * @param query root AST node of the query to correct
 * @returns the corrections reported by all passes
 */
export const correctAll = (query: ESQLAstQueryExpression): QueryCorrection[] => {
const corrections: QueryCorrection[] = [];
// NOTE(review): the next two lines look like the removed/added pair of a rename
// shown side by side by the diff view — confirm only one of them is current.
corrections.push(...applyTimespanLiteralsCorrections(query));
corrections.push(...correctTimespanLiterals(query));
// LIKE wildcard corrections are collected after the timespan ones.
corrections.push(...correctLikeWildcards(query));
return corrections;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { parse, BasicPrettyPrinter } from '@kbn/esql-ast';
import { correctLikeWildcards } from './like';

describe('correctLikeWildcards', () => {
  // Parses the query, applies the LIKE wildcard correction to its AST,
  // and prints the (possibly mutated) AST back to a query string.
  const correctAndPrint = (query: string): string => {
    const { root } = parse(query);
    correctLikeWildcards(root);
    return BasicPrettyPrinter.print(root);
  };

  it('replaces badly used "_" wildcard', () => {
    const output = correctAndPrint('FROM logs | WHERE message LIKE "ba_"');

    expect(output).toEqual('FROM logs | WHERE message LIKE "ba?"');
  });

  it('replaces badly used "%" wildcard', () => {
    const output = correctAndPrint('FROM logs | WHERE message LIKE "b%"');

    expect(output).toEqual('FROM logs | WHERE message LIKE "b*"');
  });

  it('replaces multiple bad wildcards', () => {
    const output = correctAndPrint('FROM logs | WHERE message LIKE "a__t%"');

    expect(output).toEqual('FROM logs | WHERE message LIKE "a??t*"');
  });

  it('replaces bad wildcards in multiple commands and functions', () => {
    const output = correctAndPrint(
      'FROM logs | WHERE message LIKE "a%" AND TO_UPPER(level) LIKE "err%" | WHERE foo LIKE "ba_"'
    );

    expect(output).toEqual(
      'FROM logs | WHERE message LIKE "a*" AND TO_UPPER(level) LIKE "err*" | WHERE foo LIKE "ba?"'
    );
  });

  it('does not replace escaped characters', () => {
    // The wildcard is preceded by a backslash, so the query must round-trip unchanged.
    const escapedQuery = 'FROM logs | WHERE message LIKE "ba\\\\_"';

    expect(correctAndPrint(escapedQuery)).toEqual('FROM logs | WHERE message LIKE "ba\\\\_"');
  });
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { Walker, type ESQLAstQueryExpression } from '@kbn/esql-ast';
import { isLikeOperatorNode, isStringLiteralNode } from '../typeguards';
import type { ESQLLikeOperator, ESQLStringLiteral } from '../types';
import type { QueryCorrection } from './types';

/**
 * Fixes SQL-style wildcards used in LIKE patterns.
 *
 * An LLM may write LIKE patterns with SQL wildcards; this pass walks every
 * function node of the query, and hands each LIKE operator to `checkLikeNode`,
 * which rewrites the pattern in place.
 *
 * E.g.
 * `column LIKE "ba_"` => `column LIKE "ba?"`
 * `column LIKE "ba%"` => `column LIKE "ba*"`
 *
 * @param query root AST node of the query; its LIKE patterns may be mutated
 * @returns one correction per LIKE operator whose pattern was changed
 */
export const correctLikeWildcards = (query: ESQLAstQueryExpression): QueryCorrection[] => {
  const found: QueryCorrection[] = [];

  Walker.walk(query, {
    visitFunction: (fn) => {
      // Only LIKE operators are of interest; ignore every other function node.
      if (!isLikeOperatorNode(fn)) {
        return;
      }
      for (const correction of checkLikeNode(fn)) {
        found.push(correction);
      }
    },
  });

  return found;
};

/**
 * Inspects a single LIKE operator node and rewrites SQL-style wildcards in its pattern.
 *
 * Only nodes with exactly two arguments whose second argument is a string
 * literal are handled; any other shape is left untouched. When the pattern
 * changes, the literal's `value` and `name` are both updated in place.
 *
 * @param node the LIKE operator node to inspect
 * @returns a one-element list describing the correction, or an empty list when nothing changed
 */
function checkLikeNode(node: ESQLLikeOperator): QueryCorrection[] {
  const pattern = node.args[1];
  if (node.args.length !== 2 || !isStringLiteralNode(pattern)) {
    return [];
  }

  const literal = pattern as ESQLStringLiteral;
  const before = literal.value;

  // `%` and `_` that are not immediately preceded by a backslash become `*` and `?`.
  const after = before.replace(/(?<!\\)%/g, '*').replace(/(?<!\\)_/g, '?');
  literal.value = after;

  if (after === before) {
    return [];
  }

  // Keep the literal's name aligned with its rewritten value.
  literal.name = after;

  return [
    {
      type: 'wrong_like_wildcard',
      node,
      description: `Replaced wrong like wildcard in LIKE operator at position ${node.location.min}`,
    },
  ];
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
*/

import { parse, BasicPrettyPrinter } from '@kbn/esql-ast';
import { applyTimespanLiteralsCorrections } from './timespan_literals';
import { correctTimespanLiterals } from './timespan_literals';

describe('getTimespanLiteralsCorrections', () => {
describe('correctTimespanLiterals', () => {
describe('with DATE_TRUNC', () => {
it('replaces a timespan with a proper timespan literal', () => {
const query = 'FROM logs | EVAL truncated = DATE_TRUNC("1 year", date)';
const { root } = parse(query);

applyTimespanLiteralsCorrections(root);
correctTimespanLiterals(root);

const output = BasicPrettyPrinter.print(root);

Expand All @@ -27,7 +27,7 @@ describe('getTimespanLiteralsCorrections', () => {
const query = 'FROM logs | EVAL truncated = DATE_TRUNC("month", date)';
const { root } = parse(query);

applyTimespanLiteralsCorrections(root);
correctTimespanLiterals(root);

const output = BasicPrettyPrinter.print(root);

Expand All @@ -40,7 +40,7 @@ describe('getTimespanLiteralsCorrections', () => {
const query = 'FROM logs | EVAL truncated = DATE_TRUNC("1 YEAR", date)';
const { root } = parse(query);

applyTimespanLiteralsCorrections(root);
correctTimespanLiterals(root);

const output = BasicPrettyPrinter.print(root);

Expand All @@ -53,7 +53,7 @@ describe('getTimespanLiteralsCorrections', () => {
const query = 'FROM logs | EVAL truncated = DATE_TRUNC("1 year", date)';
const { root } = parse(query);

const corrections = applyTimespanLiteralsCorrections(root);
const corrections = correctTimespanLiterals(root);

expect(corrections).toHaveLength(1);
expect(corrections[0]).toEqual({
Expand All @@ -70,7 +70,7 @@ describe('getTimespanLiteralsCorrections', () => {
const query = 'FROM logs | STATS hires = COUNT(*) BY week = BUCKET(hire_date, "1 week")';
const { root } = parse(query);

applyTimespanLiteralsCorrections(root);
correctTimespanLiterals(root);

const output = BasicPrettyPrinter.print(root);

Expand All @@ -83,7 +83,7 @@ describe('getTimespanLiteralsCorrections', () => {
const query = 'FROM logs | STATS hires = COUNT(*) BY hour = BUCKET(hire_date, "hour")';
const { root } = parse(query);

applyTimespanLiteralsCorrections(root);
correctTimespanLiterals(root);

const output = BasicPrettyPrinter.print(root);

Expand All @@ -96,7 +96,7 @@ describe('getTimespanLiteralsCorrections', () => {
const query = 'FROM logs | STATS hires = COUNT(*) BY week = BUCKET(hire_date, "1 WEEK")';
const { root } = parse(query);

applyTimespanLiteralsCorrections(root);
correctTimespanLiterals(root);

const output = BasicPrettyPrinter.print(root);

Expand All @@ -109,7 +109,7 @@ describe('getTimespanLiteralsCorrections', () => {
const query = 'FROM logs | STATS hires = COUNT(*) BY hour = BUCKET(hire_date, "hour")';
const { root } = parse(query);

const corrections = applyTimespanLiteralsCorrections(root);
const corrections = correctTimespanLiterals(root);

expect(corrections).toHaveLength(1);
expect(corrections[0]).toEqual({
Expand All @@ -129,7 +129,7 @@ describe('getTimespanLiteralsCorrections', () => {
| STATS hires = COUNT(*) BY hour = BUCKET(hire_date, "3 hour")`;
const { root } = parse(query);

applyTimespanLiteralsCorrections(root);
correctTimespanLiterals(root);

const output = BasicPrettyPrinter.print(root, { multiline: true, pipeTab: '' });

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ import { QueryCorrection } from './types';
* `BUCKET(@timestamp, "1 week")` => `BUCKET(@timestamp, 1 week)`
*
*/
export const applyTimespanLiteralsCorrections = (
query: ESQLAstQueryExpression
): QueryCorrection[] => {
export const correctTimespanLiterals = (query: ESQLAstQueryExpression): QueryCorrection[] => {
const corrections: QueryCorrection[] = [];

Walker.walk(query, {
Expand Down
23 changes: 21 additions & 2 deletions x-pack/plugins/inference/common/tasks/nl_to_esql/ast/typeguards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,19 @@
* 2.0.
*/

import type { ESQLSingleAstItem, ESQLAstItem, ESQLFunction, ESQLLiteral } from '@kbn/esql-ast';
import type { ESQLStringLiteral, ESQLDateTruncFunction, ESQLBucketFunction } from './types';
import type {
ESQLSingleAstItem,
ESQLAstItem,
ESQLFunction,
ESQLLiteral,
ESQLColumn,
} from '@kbn/esql-ast';
import type {
ESQLStringLiteral,
ESQLDateTruncFunction,
ESQLBucketFunction,
ESQLLikeOperator,
} from './types';

export function isSingleItem(item: ESQLAstItem): item is ESQLSingleAstItem {
return Object.hasOwn(item, 'type');
Expand All @@ -16,6 +27,10 @@ export function isFunctionNode(node: ESQLAstItem): node is ESQLFunction {
return isSingleItem(node) && node.type === 'function';
}

/**
 * Type guard: true when the item is a single AST item whose `type` is `'column'`.
 */
export function isColumnNode(node: ESQLAstItem): node is ESQLColumn {
  if (!isSingleItem(node)) {
    return false;
  }
  return node.type === 'column';
}

/**
 * Type guard: true when the item is a single AST item whose `type` is `'literal'`.
 */
export function isLiteralNode(node: ESQLAstItem): node is ESQLLiteral {
  if (!isSingleItem(node)) {
    return false;
  }
  return node.type === 'literal';
}
Expand All @@ -31,3 +46,7 @@ export function isDateTruncFunctionNode(node: ESQLAstItem): node is ESQLDateTrun
/**
 * Type guard: true when the item is a function node with the
 * `'variadic-call'` subtype and the name `'bucket'`.
 */
export function isBucketFunctionNode(node: ESQLAstItem): node is ESQLBucketFunction {
  if (!isFunctionNode(node)) {
    return false;
  }
  return node.subtype === 'variadic-call' && node.name === 'bucket';
}

/**
 * Type guard: true when the item is a function node with the
 * `'binary-expression'` subtype and the name `'like'`.
 */
export function isLikeOperatorNode(node: ESQLAstItem): node is ESQLLikeOperator {
  if (!isFunctionNode(node)) {
    return false;
  }
  return node.subtype === 'binary-expression' && node.name === 'like';
}
5 changes: 5 additions & 0 deletions x-pack/plugins/inference/common/tasks/nl_to_esql/ast/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ import { ESQLFunction, ESQLLiteral } from '@kbn/esql-ast';
*/
export type ESQLDateTruncFunction = ESQLFunction<'variadic-call', 'date_trunc'>;

/**
 * Represents a LIKE operator node: an `ESQLFunction` parameterized with
 * `'binary-expression'` and `'like'`.
 */
export type ESQLLikeOperator = ESQLFunction<'binary-expression', 'like'>;

/**
* represents a BUCKET function node.
*/
Expand Down

0 comments on commit 2ace6ff

Please sign in to comment.