Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add normalization option to keep trailing semicolon #23

Merged
merged 1 commit into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ type normalizerConfig struct {
// RemoveSpaceBetweenParentheses specifies whether spaces should be kept between parentheses.
// Spaces are inserted between parentheses by default. but this can be disabled by setting this to true.
RemoveSpaceBetweenParentheses bool

// KeepTrailingSemicolon specifies whether the normalizer should keep the trailing semicolon.
// The trailing semicolon is removed by default, but this can be disabled by setting this to true.
// PL/SQL requires a trailing semicolon, so this should be set to true when normalizing PL/SQL.
KeepTrailingSemicolon bool
}

type normalizerOption func(*normalizerConfig)
Expand Down Expand Up @@ -72,6 +77,12 @@ func WithRemoveSpaceBetweenParentheses(removeSpaceBetweenParentheses bool) norma
}
}

func WithKeepTrailingSemicolon(keepTrailingSemicolon bool) normalizerOption {
return func(c *normalizerConfig) {
c.KeepTrailingSemicolon = keepTrailingSemicolon
}
}

type StatementMetadata struct {
Size int
Tables []string
Expand Down Expand Up @@ -135,7 +146,7 @@ func (n *Normalizer) Normalize(input string, lexerOpts ...lexerOption) (normaliz
// Dedupe collected metadata
dedupeStatementMetadata(statementMetadata)

return strings.TrimSpace(strings.TrimSuffix(normalizedSQL, ";")), statementMetadata, nil
return n.trimNormalizedSQL(normalizedSQL), statementMetadata, nil
}

func (n *Normalizer) collectMetadata(token *Token, lastToken *Token, statementMetadata *StatementMetadata) {
Expand Down Expand Up @@ -264,6 +275,7 @@ func (n *Normalizer) appendWhitespace(lastToken *Token, token *Token, normalized

switch token.Value {
case ",":
case ";":
case "=":
if lastToken.Value == ":" {
// do not add a space before an equals if a colon was
Expand All @@ -276,6 +288,14 @@ func (n *Normalizer) appendWhitespace(lastToken *Token, token *Token, normalized
}
}

func (n *Normalizer) trimNormalizedSQL(normalizedSQL string) string {
if !n.config.KeepTrailingSemicolon {
// Remove trailing semicolon
normalizedSQL = strings.TrimSuffix(normalizedSQL, ";")
}
return strings.TrimSpace(normalizedSQL)
}

func dedupeCollectedMetadata(metadata []string) (dedupedMetadata []string, size int) {
// Dedupe collected metadata
// e.g. [SELECT, JOIN, SELECT, JOIN] -> [SELECT, JOIN]
Expand Down
36 changes: 30 additions & 6 deletions normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ multiline comment */
Delete FROM user? WHERE id = ?;
END;
`,
expected: "CREATE PROCEDURE test_procedure ( ) BEGIN SELECT * FROM users WHERE id = ? ; Update test_users set name = ? WHERE id = ? ; Delete FROM user? WHERE id = ? ; END",
expected: "CREATE PROCEDURE test_procedure ( ) BEGIN SELECT * FROM users WHERE id = ?; Update test_users set name = ? WHERE id = ?; Delete FROM user? WHERE id = ?; END",
statementMetadata: StatementMetadata{
Tables: []string{"users", "test_users", "user?"},
Comments: []string{},
Expand Down Expand Up @@ -222,7 +222,7 @@ multiline comment */
FROM cte
WHERE age <= ?;
`,
expected: "WITH cte AS ( SELECT id, name, age FROM person WHERE age > ? ) UPDATE person SET age = ? WHERE id IN ( SELECT id FROM cte ) ; INSERT INTO person ( name, age ) SELECT name, ? FROM cte WHERE age <= ?",
expected: "WITH cte AS ( SELECT id, name, age FROM person WHERE age > ? ) UPDATE person SET age = ? WHERE id IN ( SELECT id FROM cte ); INSERT INTO person ( name, age ) SELECT name, ? FROM cte WHERE age <= ?",
statementMetadata: StatementMetadata{
Tables: []string{"person", "cte"},
Comments: []string{},
Expand Down Expand Up @@ -879,7 +879,7 @@ func TestNormalizerStoredProcedure(t *testing.T) {
SELECT * FROM users WHERE id = id;
END;
`,
expected: "CREATE PROCEDURE TestProcedure ( id INT ) BEGIN SELECT * FROM users WHERE id = id ; END",
expected: "CREATE PROCEDURE TestProcedure ( id INT ) BEGIN SELECT * FROM users WHERE id = id; END",
statementMetadata: StatementMetadata{
Tables: []string{"users"},
Comments: []string{},
Expand All @@ -895,7 +895,7 @@ func TestNormalizerStoredProcedure(t *testing.T) {
UPDATE users SET name = 'test' WHERE id = id;
END;
`,
expected: "CREATE PROC TestProcedure ( id INT ) BEGIN UPDATE users SET name = 'test' WHERE id = id ; END",
expected: "CREATE PROC TestProcedure ( id INT ) BEGIN UPDATE users SET name = 'test' WHERE id = id; END",
statementMetadata: StatementMetadata{
Tables: []string{"users"},
Comments: []string{},
Expand All @@ -911,7 +911,7 @@ func TestNormalizerStoredProcedure(t *testing.T) {
DELETE FROM users WHERE id = id;
END;
`,
expected: "CREATE OR REPLACE PROCEDURE TestProcedure ( id INT ) BEGIN DELETE FROM users WHERE id = id ; END",
expected: "CREATE OR REPLACE PROCEDURE TestProcedure ( id INT ) BEGIN DELETE FROM users WHERE id = id; END",
statementMetadata: StatementMetadata{
Tables: []string{"users"},
Comments: []string{},
Expand Down Expand Up @@ -958,7 +958,7 @@ func TestNormalizerWithoutSpaceBetweenParentheses(t *testing.T) {
},
{
input: "BEGIN dbms_output.enable (?); END",
expected: "BEGIN dbms_output.enable (?) ; END",
expected: "BEGIN dbms_output.enable (?); END",
},
}

Expand All @@ -971,6 +971,30 @@ func TestNormalizerWithoutSpaceBetweenParentheses(t *testing.T) {
}
}

func TestNormalizerKeepTrailingSemicolon(t *testing.T) {
tests := []struct {
input string
expected string
}{
{
input: "SELECT * FROM users;",
expected: "SELECT * FROM users;",
},
{
input: "BEGIN NULL; END;",
expected: "BEGIN NULL; END;",
},
}

for _, test := range tests {
t.Run("", func(t *testing.T) {
normalizer := NewNormalizer(WithKeepTrailingSemicolon(true))
got, _, _ := normalizer.Normalize(test.input)
assert.Equal(t, test.expected, got)
})
}
}

func ExampleNormalizer() {
normalizer := NewNormalizer(
WithCollectComments(true),
Expand Down
2 changes: 1 addition & 1 deletion obfuscate_and_normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ func ObfuscateAndNormalize(input string, obfuscator *Obfuscator, normalizer *Nor
// Dedupe collected metadata
dedupeStatementMetadata(statementMetadata)

return strings.TrimSpace(strings.TrimSuffix(normalizedSQL, ";")), statementMetadata, nil
return normalizer.trimNormalizedSQL(normalizedSQL), statementMetadata, nil
}
2 changes: 1 addition & 1 deletion obfuscate_and_normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ multiline comment */
},
{
input: "begin execute immediate 'alter session set sql_trace=true'; end;",
expected: "begin execute immediate ? ; end",
expected: "begin execute immediate ?; end",
statementMetadata: StatementMetadata{
Tables: []string{},
Comments: []string{},
Expand Down