-
Notifications
You must be signed in to change notification settings - Fork 0
/
sql_query_tool.ts
133 lines (127 loc) · 4.75 KB
/
sql_query_tool.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import { EventEmitter } from "node:events";
import OpenAI from "openai";
import { AssistantStreamEvent } from "openai/resources/beta/assistants";
import { RequiredActionFunctionToolCall } from "openai/resources/beta/threads/index";
import {inspect} from "util";
import { AST, Parser } from "node-sql-parser";
import { TABLE_SCHEMA } from ".";
/**
 * Handles Assistants stream events and answers the tool calls of runs that
 * enter the `requires_action` state.
 *
 * Events produced by the follow-up stream (opened when tool outputs are
 * submitted) are re-emitted on this emitter under the name `"event"`.
 */
export class SQLQueryToolEventHandler extends EventEmitter {
  private readonly client: OpenAI;

  /**
   * @param client OpenAI client used to submit tool outputs.
   */
  constructor(client: OpenAI) {
    super();
    this.client = client;
  }

  /**
   * Entry point for a single stream event.
   *
   * Logs every event whose name does not contain "delta" and dispatches
   * `thread.run.requires_action` events to {@link handleRequiresAction}.
   * Errors are logged, never rethrown.
   *
   * @param event Event received from the assistant stream.
   */
  async onEvent(event: AssistantStreamEvent): Promise<void> {
    try {
      // Delta events are skipped to keep the log readable.
      if (!event.event.includes("delta")) {
        console.log(inspect(event, false, null));
      }
      // Only 'requires_action' events carry the tool calls we must answer.
      if (event.event === "thread.run.requires_action") {
        await this.handleRequiresAction(event);
      }
    } catch (error) {
      console.error("Error handling event:", error);
    }
  }

  /**
   * Produces one output per requested tool call and submits them together.
   *
   * `sql_query` calls go through {@link handleSQLQuery}; any other tool name
   * is answered with an explicit error payload so that the submitted list
   * never contains `undefined` entries.
   *
   * @param event The `thread.run.requires_action` event.
   */
  async handleRequiresAction(event: AssistantStreamEvent.ThreadRunRequiresAction): Promise<void> {
    const data = event.data;
    try {
      const toolCalls = data?.required_action?.submit_tool_outputs?.tool_calls ?? [];
      const toolOutputs = await Promise.all(
        toolCalls.map(async (toolCall) => {
          if (toolCall.function.name === "sql_query") {
            console.log("run sql query");
            return this.handleSQLQuery(event, toolCall);
          }
          return this.unsupportedToolOutput(toolCall);
        }),
      );
      // Submit all the tool outputs at the same time.
      await this.submitToolOutputs(toolOutputs, data.id, data.thread_id);
    } catch (error) {
      console.error("Error processing required action:", error);
    }
  }

  /**
   * Validates the arguments of a `sql_query` tool call and builds its output.
   *
   * Parsing or validation failures do not reject; they are reported inside the
   * returned payload as `{ success: false, error, output: null }`.
   *
   * NOTE: the query is only validated here. The successful result is a fixed
   * placeholder string; no query is executed by this method.
   *
   * @param event The originating event (currently unused).
   * @param toolCall The tool call to answer.
   * @returns The tool output, with the payload JSON-encoded in `output`.
   */
  async handleSQLQuery(event: AssistantStreamEvent.ThreadRunRequiresAction, toolCall: RequiredActionFunctionToolCall): Promise<OpenAI.Beta.Threads.Runs.RunSubmitToolOutputsParams.ToolOutput> {
    let output: { success: boolean; error: string | null; output: string | null };
    try {
      const jsonArgs: unknown = JSON.parse(toolCall.function.arguments);
      if (typeof jsonArgs !== "object" || jsonArgs === null) {
        throw new Error("Tool arguments must be a JSON object");
      }
      // Constructing the tool validates the arguments and throws on failure.
      new SQLQueryTool(jsonArgs);
      output = {
        success: true,
        error: null,
        output: "page_views\n512",
      };
    } catch (e) {
      output = {
        success: false,
        error: String(e),
        output: null,
      };
    }
    console.log(inspect(output, false, null));
    return {
      output: JSON.stringify(output),
      tool_call_id: toolCall.id,
    };
  }

  /**
   * Submits tool outputs for a run and re-emits every event of the resulting
   * stream as `"event"`. Errors are logged, never rethrown.
   *
   * @param toolOutputs Outputs to submit, one per tool call.
   * @param runId Identifier of the run awaiting tool outputs.
   * @param threadId Identifier of the thread the run belongs to.
   */
  async submitToolOutputs(
    toolOutputs: OpenAI.Beta.Threads.Runs.RunSubmitToolOutputsParams.ToolOutput[],
    runId: string,
    threadId: string,
  ): Promise<void> {
    try {
      // inspect() is used because interpolating the array prints "[object Object]".
      console.log(`submitting: ${inspect(toolOutputs, false, null)}`);
      const stream = await this.client.beta.threads.runs.submitToolOutputsStream(
        threadId,
        runId,
        { tool_outputs: toolOutputs },
      );
      for await (const event of stream) {
        this.emit("event", event);
      }
    } catch (error) {
      console.error("Error submitting tool outputs:", error);
    }
  }

  /** Builds a failure output for a tool call this handler does not implement. */
  private unsupportedToolOutput(toolCall: RequiredActionFunctionToolCall): OpenAI.Beta.Threads.Runs.RunSubmitToolOutputsParams.ToolOutput {
    return {
      output: JSON.stringify({
        success: false,
        error: `Unsupported tool: ${toolCall.function.name}`,
        output: null,
      }),
      tool_call_id: toolCall.id,
    };
  }
}
/** Arguments expected in the JSON payload of a `sql_query` tool call. */
interface SQLToolArgs {
  /** Schema name supplied with the call. */
  schema_name: string;
  /** Table name supplied with the call. */
  table_name: string;
  /** SQL text that is checked against the table and column whitelists. */
  sql: string;
}
/**
 * Validates the arguments of a `sql_query` tool call.
 *
 * Construction succeeds only if the arguments contain the required string
 * fields and the SQL passes the table and column whitelist checks built from
 * `TABLE_SCHEMA`; otherwise the constructor throws.
 */
class SQLQueryTool {
  private args: SQLToolArgs;
  private sqlParser: Parser;

  /**
   * @param args Parsed tool-call arguments (not yet validated).
   * @throws Error if a required field is missing or not a string, or if the
   *   SQL references a table or column outside the whitelist.
   */
  public constructor(args: object) {
    this.sqlParser = new Parser();
    if (!this.parseSQL(args)) {
      // parseSQL currently either returns true or throws; this guard keeps
      // `this.args` definitely assigned if that ever changes.
      throw new Error("Invalid sql_query tool arguments");
    }
    this.args = args;
  }

  /**
   * Checks that `args` has the shape of {@link SQLToolArgs} and that its SQL
   * only touches whitelisted tables and columns.
   *
   * @param args Candidate arguments.
   * @returns `true` when validation passes (it throws otherwise).
   * @throws Error describing the first failed check.
   */
  parseSQL(args: object): args is SQLToolArgs {
    const candidate = args as Record<string, unknown>;
    // Reads a required field, rejecting absent, undefined and non-string values.
    const readString = (field: keyof SQLToolArgs): string => {
      const value = candidate[field];
      if (value === undefined) {
        throw new Error(`Field '${field}' is missing or undefined`);
      }
      if (typeof value !== "string") {
        throw new Error(`Field '${field}' must be a string`);
      }
      return value;
    };
    readString("schema_name");
    readString("table_name");
    const sql = readString("sql");

    const schemaNames = TABLE_SCHEMA.map(t => t.schema_name);
    const tableNames = TABLE_SCHEMA.map(t => t.table_name);
    const columnNames = TABLE_SCHEMA.flatMap(t => Object.keys(t.fields));
    // Whitelist entries are regular expressions; they are anchored so that a
    // listed name cannot match as a mere substring of another identifier.
    const tableAuthority = `^select::(${schemaNames.join("|")})::(${tableNames.join("|")})$`;
    const columnAuthority = columnNames.map(name => `^select::null::${name}$`);

    // NOTE(review): node-sql-parser appears to throw from whiteListCheck on a
    // violation rather than return the Error; both paths are handled here.
    // Confirm against the installed version.
    const tableCheck = this.sqlParser.whiteListCheck(sql, [tableAuthority], {
      database: "Redshift",
      type: "table",
    });
    const columnCheck = this.sqlParser.whiteListCheck(sql, columnAuthority, {
      database: "Redshift",
      type: "column",
    });
    if (tableCheck != null) {
      throw new Error(`${tableCheck.name}: ${tableCheck.message}`);
    }
    if (columnCheck != null) {
      throw new Error(`${columnCheck.name}: ${columnCheck.message}`);
    }
    return true;
  }
}