diff --git a/normalizer.go b/normalizer.go
index 4900638..5bf8d6e 100644
--- a/normalizer.go
+++ b/normalizer.go
@@ -83,14 +83,15 @@ const (
 // Normalize takes an input SQL string and returns a normalized SQL string, a StatementMetadata struct, and an error.
 // The normalizer collapses input SQL into compact format, groups obfuscated values into single placeholder,
 // and collects metadata such as table names, comments, and commands.
-func (n *Normalizer) Normalize(input string, lexerOpts ...lexerOption) (normalized string, info *StatementMetadata, err error) {
+func (n *Normalizer) Normalize(input string, lexerOpts ...lexerOption) (normalizedSQL string, statementMetadata *StatementMetadata, err error) {
 	lexer := New(
 		input,
 		lexerOpts...,
 	)
 
-	var normalizedSQL string
-	var statementMetadata = &StatementMetadata{
+	var normalizedSQLBuilder strings.Builder
+
+	statementMetadata = &StatementMetadata{
 		Tables:   []string{},
 		Comments: []string{},
 		Commands: []string{},
@@ -99,41 +100,15 @@ func (n *Normalizer) Normalize(input string, lexerOpts ...lexerOption) (normaliz
 	var lastToken Token // The last token that is not whitespace or comment
 
 	for _, token := range lexer.ScanAll() {
-		if token.Type == COMMENT || token.Type == MULTILINE_COMMENT {
-			// Collect comments
-			if n.config.CollectComments {
-				statementMetadata.Comments = append(statementMetadata.Comments, token.Value)
-			}
-		} else if token.Type == IDENT {
-			if isCommand(strings.ToUpper(token.Value)) && n.config.CollectCommands {
-				// Collect commands
-				statementMetadata.Commands = append(statementMetadata.Commands, strings.ToUpper(token.Value))
-			} else if isTableIndicator(strings.ToUpper(lastToken.Value)) {
-				// Collect table names
-				if n.config.CollectTables {
-					statementMetadata.Tables = append(statementMetadata.Tables, token.Value)
-				}
-			}
-		}
-
-		normalizedSQL = normalizeSQL(token, lastToken, normalizedSQL, n.config)
-
-		// TODO: We rely on the WS token to determine if we should add a whitespace
-		// This is not ideal, as SQLs with slightly different formatting will NOT be normalized into single family
-		// e.g. "SELECT * FROM table where id = ?" and "SELECT * FROM table where id= ?" will be normalized into different family
-		if token.Type != WS && token.Type != COMMENT && token.Type != MULTILINE_COMMENT {
-			lastToken = token
-		}
+		n.collectMetadata(token, lastToken, statementMetadata)
+		lastToken = n.normalizeSQL(token, lastToken, &normalizedSQLBuilder)
 	}
 
-	// We use regex to group consecutive obfuscated values into single placeholder.
-	// This is "less" performant than token by token processing,
-	// but it is much simpler to implement and maintain.
-	// The trade off made here is assuming normalization runs on backend
-	// where performance is not as critical as the agent.
+	normalizedSQL = normalizedSQLBuilder.String()
+
 	normalizedSQL = groupObfuscatedValues(normalizedSQL)
 	if !n.config.KeepSQLAlias {
-		normalizedSQL = DiscardSQLAlias(normalizedSQL)
+		normalizedSQL = discardSQLAlias(normalizedSQL)
 	}
 
 	// Dedupe collected metadata
@@ -142,27 +117,46 @@ func (n *Normalizer) Normalize(input string, lexerOpts ...lexerOption) (normaliz
 	return strings.TrimSpace(normalizedSQL), statementMetadata, nil
 }
 
-func normalizeSQL(token Token, lastToken Token, statement string, config *normalizerConfig) string {
-	if token.Type == WS || token.Type == COMMENT || token.Type == MULTILINE_COMMENT {
-		// We don't rely on the WS token to determine if we should add a whitespace
-		return statement
+func (n *Normalizer) collectMetadata(token Token, lastToken Token, statementMetadata *StatementMetadata) {
+	if n.config.CollectComments && (token.Type == COMMENT || token.Type == MULTILINE_COMMENT) {
+		// Collect comments
+		statementMetadata.Comments = append(statementMetadata.Comments, token.Value)
+	} else if token.Type == IDENT {
+		if n.config.CollectCommands && isCommand(strings.ToUpper(token.Value)) {
+			// Collect commands
+			statementMetadata.Commands = append(statementMetadata.Commands, strings.ToUpper(token.Value))
+		} else if n.config.CollectTables && isTableIndicator(strings.ToUpper(lastToken.Value)) {
+			// Collect table names
+			statementMetadata.Tables = append(statementMetadata.Tables, token.Value)
+		}
 	}
+}
 
-	// determine if we should add a whitespace
-	statement = appendWhitespace(lastToken, token, statement)
-	if isSQLKeyword(token) && config.UppercaseKeywords {
-		statement += strings.ToUpper(token.Value)
-	} else {
-		statement += token.Value
+func (n *Normalizer) normalizeSQL(token Token, lastToken Token, normalizedSQLBuilder *strings.Builder) Token {
+	if token.Type != WS && token.Type != COMMENT && token.Type != MULTILINE_COMMENT {
+		// determine if we should add a whitespace
+		appendWhitespace(lastToken, token, normalizedSQLBuilder)
+		if n.config.UppercaseKeywords && isSQLKeyword(token) {
+			normalizedSQLBuilder.WriteString(strings.ToUpper(token.Value))
+		} else {
+			normalizedSQLBuilder.WriteString(token.Value)
+		}
+
+		lastToken = token
 	}
-	return statement
+	return lastToken
 }
 
 // groupObfuscatedValues groups consecutive obfuscated values in a SQL query into a single placeholder.
 // It replaces "(?, ?, ...)" and "[?, ?, ...]" with "( ? )" and "[ ? ]", respectively.
 // Returns the modified SQL query as a string.
 func groupObfuscatedValues(input string) string {
+	// We use a regex to group consecutive obfuscated values into a single placeholder.
+	// This is "less" performant than token-by-token processing,
+	// but it is much simpler to implement and maintain.
+	// The trade-off made here assumes normalization runs on the backend,
+	// where performance is not as critical as it is in the agent.
 	grouped := groupableRegex.ReplaceAllStringFunc(input, func(match string) string {
 		if match[0] == '(' {
 			return ArrayPlaceholder
@@ -172,11 +166,11 @@ func groupObfuscatedValues(input string) string {
 	return grouped
 }
 
-// DiscardSQLAlias removes any SQL alias from the input string and returns the modified string.
+// discardSQLAlias removes any SQL alias from the input string and returns the modified string.
 // It uses a regular expression to match the alias pattern and replace it with an empty string.
 // The function is case-insensitive and matches the pattern "AS ".
 // The input string is not modified in place.
-func DiscardSQLAlias(input string) string {
+func discardSQLAlias(input string) string {
 	return sqlAliasRegex.ReplaceAllString(input, "")
 }
 
@@ -201,7 +195,7 @@ func dedupeStatementMetadata(info *StatementMetadata) {
 	info.Commands = dedupeCollectedMetadata(info.Commands)
 }
 
-func appendWhitespace(lastToken Token, token Token, normalizedSQL string) string {
+func appendWhitespace(lastToken Token, token Token, normalizedSQLBuilder *strings.Builder) {
 	switch token.Value {
 	case ",":
 	case "=":
@@ -212,8 +206,6 @@ func appendWhitespace(lastToken Token, token Token, normalizedSQL string) string
 		}
 		fallthrough
 	default:
-		normalizedSQL += " "
+		normalizedSQLBuilder.WriteString(" ")
 	}
-
-	return normalizedSQL
 }
diff --git a/obfuscate_and_normalize.go b/obfuscate_and_normalize.go
new file mode 100644
index 0000000..316d764
--- /dev/null
+++ b/obfuscate_and_normalize.go
@@ -0,0 +1,40 @@
+package sqllexer
+
+import "strings"
+
+// ObfuscateAndNormalize takes an input SQL string and returns a normalized SQL string with statement metadata.
+// It is a convenience function that combines the Obfuscator and the Normalizer in a single pass.
+func ObfuscateAndNormalize(input string, obfuscator *Obfuscator, normalizer *Normalizer, lexerOpts ...lexerOption) (normalizedSQL string, statementMetadata *StatementMetadata, err error) {
+	lexer := New(
+		input,
+		lexerOpts...,
+	)
+
+	var normalizedSQLBuilder strings.Builder
+
+	statementMetadata = &StatementMetadata{
+		Tables:   []string{},
+		Comments: []string{},
+		Commands: []string{},
+	}
+
+	var lastToken Token // The last token that is not whitespace or comment
+
+	for _, token := range lexer.ScanAll() {
+		obfuscatedToken := Token{Type: token.Type, Value: obfuscator.ObfuscateTokenValue(token, lexerOpts...)}
+		normalizer.collectMetadata(obfuscatedToken, lastToken, statementMetadata)
+		lastToken = normalizer.normalizeSQL(obfuscatedToken, lastToken, &normalizedSQLBuilder)
+	}
+
+	normalizedSQL = normalizedSQLBuilder.String()
+
+	normalizedSQL = groupObfuscatedValues(normalizedSQL)
+	if !normalizer.config.KeepSQLAlias {
+		normalizedSQL = discardSQLAlias(normalizedSQL)
+	}
+
+	// Dedupe collected metadata
+	dedupeStatementMetadata(statementMetadata)
+
+	return strings.TrimSpace(normalizedSQL), statementMetadata, nil
+}
diff --git a/obfuscate_and_normalize_bench_test.go b/obfuscate_and_normalize_bench_test.go
new file mode 100644
index 0000000..ba0b19d
--- /dev/null
+++ b/obfuscate_and_normalize_bench_test.go
@@ -0,0 +1,112 @@
+package sqllexer
+
+import (
+	"fmt"
+	"strconv"
+	"testing"
+)
+
+// Benchmark ObfuscateAndNormalize using SQL statements of varying size and complexity
+func BenchmarkObfuscationAndNormalization(b *testing.B) {
+	// LargeQuery is sourced from https://stackoverflow.com/questions/12607667/issues-with-a-very-large-sql-query/12711494
+	var LargeQuery = `SELECT '%c%' as Chapter,
+(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
+WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status IN ('new','assigned') ) AS 'New',
+(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
+WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='document_interface' ) AS 'Document\
+ Interface',
+(SELECT count(ticket.id) AS Matches FROM 
engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket +WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='interface_development' ) AS 'Inter\ +face Development', +(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket +WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='interface_check' ) AS 'Interface C\ +heck', +(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket +WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='document_routine' ) AS 'Document R\ +outine', +(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket +WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='full_development' ) AS 'Full Devel\ +opment', +(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket +WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='peer_review_1' ) AS 'Peer Review O\ +ne', +(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket +WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%'AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='peer_review_2' ) AS 'Peer Review Tw\ +o', +(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket +WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='qa' ) AS 'QA', +(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket +WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%'AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='closed' ) AS 'Closed', +count(id) AS Total, +ticket.id AS _id +FROM engine.ticket +INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket +WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine'` + + // query3 is sourced from https://www.ibm.com/support/knowledgecenter/SSCRJT_6.0.0/com.ibm.swg.im.bigsql.doc/doc/tut_bsql_uc_complex_query.html + var ComplexQuery = `WITH + sales AS + (SELECT sf.* + FROM gosalesdw.sls_order_method_dim AS md, + gosalesdw.sls_product_dim AS pd, + gosalesdw.emp_employee_dim AS ed, + gosalesdw.sls_sales_fact AS sf + WHERE pd.product_key = sf.product_key + AND pd.product_number > 10000 + AND pd.base_product_key > 30 + AND md.order_method_key = sf.order_method_key + AND md.order_method_code > 5 + AND ed.employee_key = sf.employee_key + AND 
ed.manager_code1 > 20), + inventory AS + (SELECT if.* + FROM gosalesdw.go_branch_dim AS bd, + gosalesdw.dist_inventory_fact AS if + WHERE if.branch_key = bd.branch_key + AND bd.branch_code > 20) +SELECT sales.product_key AS PROD_KEY, + SUM(CAST (inventory.quantity_shipped AS BIGINT)) AS INV_SHIPPED, + SUM(CAST (sales.quantity AS BIGINT)) AS PROD_QUANTITY, + RANK() OVER ( ORDER BY SUM(CAST (sales.quantity AS BIGINT)) DESC) AS PROD_RANK +FROM sales, inventory + WHERE sales.product_key = inventory.product_key +GROUP BY sales.product_key; +` + + var superLargeQuery = "select top ? percent IdTrebEmpresa, CodCli, NOMEMP, Baixa, CASE WHEN IdCentreTreball IS ? THEN ? ELSE CONVERT ( VARCHAR ( ? ) IdCentreTreball ) END, CASE WHEN NOMESTAB IS ? THEN ? ELSE NOMESTAB END, TIPUS, CASE WHEN IdLloc IS ? THEN ? ELSE CONVERT ( VARCHAR ( ? ) IdLloc ) END, CASE WHEN NomLlocComplert IS ? THEN ? ELSE NomLlocComplert END, CASE WHEN DesLloc IS ? THEN ? ELSE DesLloc END, IdLlocTreballUnic From ( SELECT ?, dbo.Treb_Empresa.IdTrebEmpresa, dbo.Treb_Empresa.IdTreballador, dbo.Treb_Empresa.CodCli, dbo.Clients.NOMEMP, dbo.Treb_Empresa.Baixa, dbo.Treb_Empresa.IdCentreTreball, dbo.Cli_Establiments.NOMESTAB, ?, ?, dbo.Treb_Empresa.DataInici, dbo.Treb_Empresa.DataFi, CASE WHEN dbo.Treb_Empresa.DesLloc IS ? THEN ? ELSE dbo.Treb_Empresa.DesLloc END DesLloc, dbo.Treb_Empresa.IdLlocTreballUnic FROM dbo.Clients WITH ( NOLOCK ) INNER JOIN dbo.Treb_Empresa WITH ( NOLOCK ) ON dbo.Clients.CODCLI = dbo.Treb_Empresa.CodCli LEFT OUTER JOIN dbo.Cli_Establiments WITH ( NOLOCK ) ON dbo.Cli_Establiments.Id_ESTAB_CLI = dbo.Treb_Empresa.IdCentreTreball AND dbo.Cli_Establiments.CODCLI = dbo.Treb_Empresa.CodCli WHERE dbo.Treb_Empresa.IdTreballador = ? AND Treb_Empresa.IdTecEIRLLlocTreball IS ? AND IdMedEIRLLlocTreball IS ? AND IdLlocTreballTemporal IS ? UNION ALL SELECT ?, dbo.Treb_Empresa.IdTrebEmpresa, dbo.Treb_Empresa.IdTreballador, dbo.Treb_Empresa.CodCli, dbo.Clients.NOMEMP, dbo.Treb_Empresa.Baixa, dbo.Treb_Empresa.IdCentreTreball, dbo.Cli_Establiments.NOMESTAB, dbo.Treb_Empresa.IdTecEIRLLlocTreball, dbo.fn_NomLlocComposat ( dbo.Treb_Empresa.IdTecEIRLLlocTreball ), dbo.Treb_Empresa.DataInici, dbo.Treb_Empresa.DataFi, CASE WHEN dbo.Treb_Empresa.DesLloc IS ? THEN ? ELSE dbo.Treb_Empresa.DesLloc END DesLloc, dbo.Treb_Empresa.IdLlocTreballUnic FROM dbo.Clients WITH ( NOLOCK ) INNER JOIN dbo.Treb_Empresa WITH ( NOLOCK ) ON dbo.Clients.CODCLI = dbo.Treb_Empresa.CodCli LEFT OUTER JOIN dbo.Cli_Establiments WITH ( NOLOCK ) ON dbo.Cli_Establiments.Id_ESTAB_CLI = dbo.Treb_Empresa.IdCentreTreball AND dbo.Cli_Establiments.CODCLI = dbo.Treb_Empresa.CodCli WHERE ( dbo.Treb_Empresa.IdTreballador = ? ) AND ( NOT ( dbo.Treb_Empresa.IdTecEIRLLlocTreball IS ? ) ) UNION ALL SELECT ?, dbo.Treb_Empresa.IdTrebEmpresa, dbo.Treb_Empresa.IdTreballador, dbo.Treb_Empresa.CodCli, dbo.Clients.NOMEMP, dbo.Treb_Empresa.Baixa, dbo.Treb_Empresa.IdCentreTreball, dbo.Cli_Establiments.NOMESTAB, dbo.Treb_Empresa.IdMedEIRLLlocTreball, dbo.fn_NomMedEIRLLlocComposat ( dbo.Treb_Empresa.IdMedEIRLLlocTreball ), dbo.Treb_Empresa.DataInici, dbo.Treb_Empresa.DataFi, CASE WHEN dbo.Treb_Empresa.DesLloc IS ? THEN ? 
ELSE dbo.Treb_Empresa.DesLloc END DesLloc, dbo.Treb_Empresa.IdLlocTreballUnic FROM dbo.Clients WITH ( NOLOCK ) INNER JOIN dbo.Treb_Empresa WITH ( NOLOCK ) ON dbo.Clients.CODCLI = dbo.Treb_Empresa.CodCli LEFT OUTER JOIN dbo.Cli_Establiments WITH ( NOLOCK ) ON dbo.Cli_Establiments.Id_ESTAB_CLI = dbo.Treb_Empresa.IdCentreTreball AND dbo.Cli_Establiments.CODCLI = dbo.Treb_Empresa.CodCli WHERE ( dbo.Treb_Empresa.IdTreballador = ? ) AND ( Treb_Empresa.IdTecEIRLLlocTreball IS ? ) AND ( NOT ( dbo.Treb_Empresa.IdMedEIRLLlocTreball IS ? ) ) UNION ALL SELECT ?, dbo.Treb_Empresa.IdTrebEmpresa, dbo.Treb_Empresa.IdTreballador, dbo.Treb_Empresa.CodCli, dbo.Clients.NOMEMP, dbo.Treb_Empresa.Baixa, dbo.Treb_Empresa.IdCentreTreball, dbo.Cli_Establiments.NOMESTAB, dbo.Treb_Empresa.IdLlocTreballTemporal, dbo.Lloc_Treball_Temporal.NomLlocTreball, dbo.Treb_Empresa.DataInici, dbo.Treb_Empresa.DataFi, CASE WHEN dbo.Treb_Empresa.DesLloc IS ? THEN ? ELSE dbo.Treb_Empresa.DesLloc END DesLloc, dbo.Treb_Empresa.IdLlocTreballUnic FROM dbo.Clients WITH ( NOLOCK ) INNER JOIN dbo.Treb_Empresa WITH ( NOLOCK ) ON dbo.Clients.CODCLI = dbo.Treb_Empresa.CodCli INNER JOIN dbo.Lloc_Treball_Temporal WITH ( NOLOCK ) ON dbo.Treb_Empresa.IdLlocTreballTemporal = dbo.Lloc_Treball_Temporal.IdLlocTreballTemporal LEFT OUTER JOIN dbo.Cli_Establiments WITH ( NOLOCK ) ON dbo.Cli_Establiments.Id_ESTAB_CLI = dbo.Treb_Empresa.IdCentreTreball AND dbo.Cli_Establiments.CODCLI = dbo.Treb_Empresa.CodCli WHERE dbo.Treb_Empresa.IdTreballador = ? AND Treb_Empresa.IdTecEIRLLlocTreball IS ? AND IdMedEIRLLlocTreball IS ? ) Where ? = %d"
+
+	benchmarks := []struct {
+		name  string
+		query string
+	}{
+		{"Escaping", `INSERT INTO delayed_jobs (attempts, created_at, failed_at, handler, last_error, locked_at, locked_by, priority, queue, run_at, updated_at) VALUES (0, '2016-12-04 17:09:59', NULL, '--- !ruby/object:Delayed::PerformableMethod\nobject: !ruby/object:Item\n store:\n - a simple string\n - an \'escaped \' string\n - another \'escaped\' string\n - 42\n string: a string with many \\\\\'escapes\\\\\'\nmethod_name: :show_store\nargs: []\n', NULL, NULL, NULL, 0, NULL, '2016-12-04 17:09:59', '2016-12-04 17:09:59')`},
+		{"Grouping", `INSERT INTO delayed_jobs (created_at, failed_at, handler) VALUES (0, '2016-12-04 17:09:59', NULL), (0, '2016-12-04 17:09:59', NULL), (0, '2016-12-04 17:09:59', NULL), (0, '2016-12-04 17:09:59', NULL)`},
+		{"Large", LargeQuery},
+		{"Complex", ComplexQuery},
+		{"SuperLarge", fmt.Sprintf(superLargeQuery, 1)},
+	}
+	obfuscator := NewObfuscator(
+		WithReplaceDigits(true),
+		WithDollarQuotedFunc(true),
+	)
+
+	normalizer := NewNormalizer(
+		WithCollectComments(true),
+		WithCollectCommands(true),
+		WithCollectTables(true),
+		WithKeepSQLAlias(false),
+	)
+
+	for _, bm := range benchmarks {
+		b.Run(bm.name+"/"+strconv.Itoa(len(bm.query)), func(b *testing.B) {
+			b.ResetTimer()
+			b.ReportAllocs()
+			for i := 0; i < b.N; i++ {
+				_, _, err := ObfuscateAndNormalize(bm.query, obfuscator, normalizer)
+				if err != nil {
+					b.Fatal(err)
+				}
+			}
+		})
+	}
+}
diff --git a/obfuscate_and_normalize_test.go b/obfuscate_and_normalize_test.go
new file mode 100644
index 0000000..02ce7e0
--- /dev/null
+++ b/obfuscate_and_normalize_test.go
@@ -0,0 +1,113 @@
+package sqllexer
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestObfuscationAndNormalization(t *testing.T) {
+	tests := []struct {
+		input             string
+		expected          string
+		statementMetadata StatementMetadata
+	}{
+		{
+			input:    "SELECT 1",
+			expected: "SELECT ?",
+			statementMetadata: StatementMetadata{
+				Tables:   []string{},
+				Comments: []string{},
+				Commands: []string{"SELECT"},
+			},
+		},
+		{
+			input: `
+	/*dddbs='orders-mysql',dde='dbm-agent-integration',ddps='orders-app',ddpv='7825a16',traceparent='00-000000000000000068e229d784ee697c-569d1b940c1fb3ac-00'*/
+	/* date='12%2F31',key='val' */
+	SELECT * FROM users WHERE id = 1`,
+			expected: "SELECT * FROM users WHERE id = ?",
+			statementMetadata: StatementMetadata{
+				Tables:   []string{"users"},
+				Comments: []string{"/*dddbs='orders-mysql',dde='dbm-agent-integration',ddps='orders-app',ddpv='7825a16',traceparent='00-000000000000000068e229d784ee697c-569d1b940c1fb3ac-00'*/", "/* date='12%2F31',key='val' */"},
+				Commands: []string{"SELECT"},
+			},
+		},
+		{
+			input:    "SELECT * FROM users WHERE id IN (1, 2) and name IN ARRAY[3, 4]",
+			expected: "SELECT * FROM users WHERE id IN ( ? ) and name IN ARRAY [ ? ]",
+			statementMetadata: StatementMetadata{
+				Tables:   []string{"users"},
+				Comments: []string{},
+				Commands: []string{"SELECT"},
+			},
+		},
+		{
+			input: `
+	SELECT h.id, h.org_id, h.name, ha.name as alias, h.created
+	FROM vs?.host h
+	JOIN vs?.host_alias ha on ha.host_id = h.id
+	WHERE ha.org_id = 1 AND ha.name = ANY ('3', '4')
+	`,
+			expected: "SELECT h.id, h.org_id, h.name, ha.name, h.created FROM vs?.host h JOIN vs?.host_alias ha on ha.host_id = h.id WHERE ha.org_id = ? AND ha.name = ANY ( ? )",
+			statementMetadata: StatementMetadata{
+				Tables:   []string{"vs?.host", "vs?.host_alias"},
+				Comments: []string{},
+				Commands: []string{"SELECT", "JOIN"},
+			},
+		},
+		{
+			input:    "/* this is a comment */ SELECT * FROM users WHERE id = '2'",
+			expected: "SELECT * FROM users WHERE id = ?",
+			statementMetadata: StatementMetadata{
+				Tables:   []string{"users"},
+				Comments: []string{"/* this is a comment */"},
+				Commands: []string{"SELECT"},
+			},
+		},
+		{
+			input: `
+	/* this is a 
+multiline comment */
+	SELECT * FROM users /* comment comment */ WHERE id = 'XXX'
+	-- this is another comment
+	`,
+			expected: "SELECT * FROM users WHERE id = ?",
+			statementMetadata: StatementMetadata{
+				Tables:   []string{"users"},
+				Comments: []string{"/* this is a \nmultiline comment */", "/* comment comment */", "-- this is another comment"},
+				Commands: []string{"SELECT"},
+			},
+		},
+		{
+			input:    "SELECT u.id as ID, u.name as Name FROM users as u WHERE u.id = 1",
+			expected: "SELECT u.id, u.name FROM users WHERE u.id = ?",
+			statementMetadata: StatementMetadata{
+				Tables:   []string{"users"},
+				Comments: []string{},
+				Commands: []string{"SELECT"},
+			},
+		},
+	}
+
+	obfuscator := NewObfuscator(
+		WithReplaceDigits(true),
+		WithDollarQuotedFunc(true),
+	)
+
+	normalizer := NewNormalizer(
+		WithCollectComments(true),
+		WithCollectCommands(true),
+		WithCollectTables(true),
+		WithKeepSQLAlias(false),
+	)
+
+	for _, test := range tests {
+		t.Run("", func(t *testing.T) {
+			got, statementMetadata, err := ObfuscateAndNormalize(test.input, obfuscator, normalizer)
+			assert.NoError(t, err)
+			assert.Equal(t, test.expected, got)
+			assert.Equal(t, &test.statementMetadata, statementMetadata)
+		})
+	}
+}
diff --git a/obfuscator.go b/obfuscator.go
index 84e2067..a0ad253 100644
--- a/obfuscator.go
+++ b/obfuscator.go
@@ -54,31 +54,38 @@ func (o *Obfuscator) Obfuscate(input string, lexerOpts ...lexerOption) string {
 		lexerOpts...,
 	)
 	for _, token := range lexer.ScanAll() {
-		switch token.Type {
-		case NUMBER:
-			obfuscatedSQL.WriteString(NumberPlaceholder)
-		case DOLLAR_QUOTED_FUNCTION:
-			if o.config.DollarQuotedFunc {
-				// obfuscate the content of dollar quoted function
-				quotedFunc := token.Value[6 : len(token.Value)-6] // remove the $func$ prefix and suffix
-				obfuscatedSQL.WriteString("$func$")
-				obfuscatedSQL.WriteString(o.Obfuscate(quotedFunc, lexerOpts...))
-				obfuscatedSQL.WriteString("$func$")
-			} else {
-				obfuscatedSQL.WriteString(StringPlaceholder)
-			}
-		case STRING, INCOMPLETE_STRING, DOLLAR_QUOTED_STRING:
-			obfuscatedSQL.WriteString(StringPlaceholder)
-		case IDENT:
-			if o.config.ReplaceDigits {
-				obfuscatedSQL.WriteString(replaceDigits(token.Value, "?"))
-			} else {
-				obfuscatedSQL.WriteString(token.Value)
-			}
-		default:
-			obfuscatedSQL.WriteString(token.Value)
-		}
+		obfuscatedSQL.WriteString(o.ObfuscateTokenValue(token, lexerOpts...))
 	}
 	return strings.TrimSpace(obfuscatedSQL.String())
 }
+
+// ObfuscateTokenValue returns the obfuscated value for a single token, driven by the token type and the obfuscator config.
+func (o *Obfuscator) ObfuscateTokenValue(token Token, lexerOpts ...lexerOption) string {
+	switch token.Type {
+	case NUMBER:
+		return NumberPlaceholder
+	case DOLLAR_QUOTED_FUNCTION:
+		if o.config.DollarQuotedFunc {
+			// obfuscate the content of the dollar-quoted function
+			quotedFunc := token.Value[6 : len(token.Value)-6] // remove the $func$ prefix and suffix
+			var obfuscatedDollarQuotedFunc strings.Builder
+			obfuscatedDollarQuotedFunc.WriteString("$func$")
+			obfuscatedDollarQuotedFunc.WriteString(o.Obfuscate(quotedFunc, lexerOpts...))
+			obfuscatedDollarQuotedFunc.WriteString("$func$")
+			return obfuscatedDollarQuotedFunc.String()
+		} else {
+			return StringPlaceholder
+		}
+	case STRING, INCOMPLETE_STRING, DOLLAR_QUOTED_STRING:
+		return StringPlaceholder
+	case IDENT:
+		if o.config.ReplaceDigits {
+			return replaceDigits(token.Value, "?")
+		} else {
+			return token.Value
+		}
+	default:
+		return token.Value
+	}
+}
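
For reviewers: a minimal sketch of how the new single-pass API is intended to be called from an application, mirroring the obfuscator/normalizer setup used in the tests above. The module import path is an assumption for illustration (the diff only shows package sqllexer); substitute the real module path.

package main

import (
	"fmt"

	// NOTE: import path assumed for illustration; the diff only defines package sqllexer.
	sqllexer "github.com/DataDog/go-sqllexer"
)

func main() {
	obfuscator := sqllexer.NewObfuscator(
		sqllexer.WithReplaceDigits(true),
		sqllexer.WithDollarQuotedFunc(true),
	)
	normalizer := sqllexer.NewNormalizer(
		sqllexer.WithCollectComments(true),
		sqllexer.WithCollectCommands(true),
		sqllexer.WithCollectTables(true),
		sqllexer.WithKeepSQLAlias(false),
	)

	// Single lexer pass: each token is obfuscated via ObfuscateTokenValue,
	// then fed into the normalizer's metadata collection and SQL reassembly.
	normalizedSQL, metadata, err := sqllexer.ObfuscateAndNormalize(
		"SELECT * FROM users WHERE id = 1 -- lookup by id",
		obfuscator,
		normalizer,
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(normalizedSQL)     // SELECT * FROM users WHERE id = ?
	fmt.Println(metadata.Tables)   // [users]
	fmt.Println(metadata.Commands) // [SELECT]
	fmt.Println(metadata.Comments) // [-- lookup by id]
}

Compared with calling Obfuscate and then Normalize separately, this lexes the input only once, which is the point of factoring the per-token logic out into ObfuscateTokenValue.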