diff --git a/normalizer.go b/normalizer.go index 6c8c12e..e9b91d2 100644 --- a/normalizer.go +++ b/normalizer.go @@ -17,11 +17,15 @@ type normalizerConfig struct { // CollectProcedure specifies whether the normalizer should extract and return procedure name as SQL metadata CollectProcedure bool - // KeepSQLAlias reports whether SQL aliases ("AS") should be truncated. + // KeepSQLAlias specifies whether SQL aliases ("AS") should be truncated. KeepSQLAlias bool - // UppercaseKeywords reports whether SQL keywords should be uppercased. + // UppercaseKeywords specifies whether SQL keywords should be uppercased. UppercaseKeywords bool + + // RemoveSpaceBetweenParentheses specifies whether spaces should be kept between parentheses. + // Spaces are inserted between parentheses by default. but this can be disabled by setting this to true. + RemoveSpaceBetweenParentheses bool } type normalizerOption func(*normalizerConfig) @@ -62,6 +66,12 @@ func WithCollectProcedures(collectProcedure bool) normalizerOption { } } +func WithRemoveSpaceBetweenParentheses(removeSpaceBetweenParentheses bool) normalizerOption { + return func(c *normalizerConfig) { + c.RemoveSpaceBetweenParentheses = removeSpaceBetweenParentheses + } +} + type StatementMetadata struct { Size int Tables []string @@ -189,7 +199,7 @@ func (n *Normalizer) normalizeSQL(token *Token, lastToken *Token, normalizedSQLB // if the last token is AS and the current token is not IDENT, // this could be a CTE like WITH ... AS (...), // so we do not discard the current token - appendWhitespace(lastToken, token, normalizedSQLBuilder) + n.appendWhitespace(lastToken, token, normalizedSQLBuilder) n.writeToken(lastToken, normalizedSQLBuilder) } } @@ -203,7 +213,7 @@ func (n *Normalizer) normalizeSQL(token *Token, lastToken *Token, normalizedSQLB } // determine if we should add a whitespace - appendWhitespace(lastToken, token, normalizedSQLBuilder) + n.appendWhitespace(lastToken, token, normalizedSQLBuilder) n.writeToken(token, normalizedSQLBuilder) *lastToken = *token @@ -242,6 +252,30 @@ func (n *Normalizer) isObfuscatedValueGroupable(token *Token, lastToken *Token, return false } +func (n *Normalizer) appendWhitespace(lastToken *Token, token *Token, normalizedSQLBuilder *strings.Builder) { + // do not add a space between parentheses if RemoveSpaceBetweenParentheses is true + if n.config.RemoveSpaceBetweenParentheses && (lastToken.Value == "(" || lastToken.Value == "[") { + return + } + + if n.config.RemoveSpaceBetweenParentheses && (token.Value == ")" || token.Value == "]") { + return + } + + switch token.Value { + case ",": + case "=": + if lastToken.Value == ":" { + // do not add a space before an equals if a colon was + // present before it. + break + } + fallthrough + default: + normalizedSQLBuilder.WriteString(" ") + } +} + func dedupeCollectedMetadata(metadata []string) (dedupedMetadata []string, size int) { // Dedupe collected metadata // e.g. [SELECT, JOIN, SELECT, JOIN] -> [SELECT, JOIN] @@ -265,18 +299,3 @@ func dedupeStatementMetadata(info *StatementMetadata) { info.Procedures, procedureSize = dedupeCollectedMetadata(info.Procedures) info.Size += tablesSize + commentsSize + commandsSize + procedureSize } - -func appendWhitespace(lastToken *Token, token *Token, normalizedSQLBuilder *strings.Builder) { - switch token.Value { - case ",": - case "=": - if lastToken.Value == ":" { - // do not add a space before an equals if a colon was - // present before it. - break - } - fallthrough - default: - normalizedSQLBuilder.WriteString(" ") - } -} diff --git a/normalizer_test.go b/normalizer_test.go index a3dcc07..b55ec34 100644 --- a/normalizer_test.go +++ b/normalizer_test.go @@ -760,19 +760,19 @@ func TestGroupObfuscatedValues(t *testing.T) { expected string }{ { - input: "( ? )", + input: "(?)", expected: "( ? )", }, { - input: "(?, ?)", + input: "( ? )", expected: "( ? )", }, { - input: "( ?, ?, ? )", + input: "(?, ?)", expected: "( ? )", }, { - input: "( ? )", + input: "( ?, ?, ? )", expected: "( ? )", }, { @@ -815,6 +815,46 @@ func TestGroupObfuscatedValues(t *testing.T) { input: "ANY(?, ?)", expected: "ANY ( ? )", }, + { + input: "(?)", + expected: "( ? )", + }, + { + input: "( ? )", + expected: "( ? )", + }, + { + input: "(?, ?)", + expected: "( ? )", + }, + { + input: "( ?, ?, ? )", + expected: "( ? )", + }, + { + input: "( ?, ? )", + expected: "( ? )", + }, + { + input: "( ?,?)", + expected: "( ? )", + }, + { + input: "[ ? ]", + expected: "[ ? ]", + }, + { + input: "[?, ?]", + expected: "[ ? ]", + }, + { + input: "[ ?, ?, ? ]", + expected: "[ ? ]", + }, + { + input: "[ ? ]", + expected: "[ ? ]", + }, } for _, test := range tests { @@ -899,6 +939,38 @@ func TestNormalizerStoredProcedure(t *testing.T) { } } +func TestNormalizerWithoutSpaceBetweenParentheses(t *testing.T) { + tests := []struct { + input string + expected string + }{ + { + input: "SELECT count(*) FROM users", + expected: "SELECT count (*) FROM users", + }, + { + input: "SELECT * FROM users WHERE id IN(?, ?)", + expected: "SELECT * FROM users WHERE id IN (?)", + }, + { + input: "INSERT INTO my_table (numbers) VALUES (array[1,2,3])", + expected: "INSERT INTO my_table (numbers) VALUES (array [1, 2, 3])", + }, + { + input: "BEGIN dbms_output.enable (?); END", + expected: "BEGIN dbms_output.enable (?) ; END", + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + normalizer := NewNormalizer(WithRemoveSpaceBetweenParentheses(true)) + got, _, _ := normalizer.Normalize(test.input) + assert.Equal(t, test.expected, got) + }) + } +} + func ExampleNormalizer() { normalizer := NewNormalizer( WithCollectComments(true),