diff --git a/obfuscator.go b/obfuscator.go index a0ad253..ee0e658 100644 --- a/obfuscator.go +++ b/obfuscator.go @@ -5,8 +5,11 @@ import ( ) type obfuscatorConfig struct { - ReplaceDigits bool - DollarQuotedFunc bool + DollarQuotedFunc bool + ReplaceDigits bool + ReplacePositionalParameter bool + ReplaceBoolean bool + ReplaceNull bool } type obfuscatorOption func(*obfuscatorConfig) @@ -17,6 +20,24 @@ func WithReplaceDigits(replaceDigits bool) obfuscatorOption { } } +func WithReplacePositionalParameter(replacePositionalParameter bool) obfuscatorOption { + return func(c *obfuscatorConfig) { + c.ReplacePositionalParameter = replacePositionalParameter + } +} + +func WithReplaceBoolean(replaceBoolean bool) obfuscatorOption { + return func(c *obfuscatorConfig) { + c.ReplaceBoolean = replaceBoolean + } +} + +func WithReplaceNull(replaceNull bool) obfuscatorOption { + return func(c *obfuscatorConfig) { + c.ReplaceNull = replaceNull + } +} + func WithDollarQuotedFunc(dollarQuotedFunc bool) obfuscatorOption { return func(c *obfuscatorConfig) { c.DollarQuotedFunc = dollarQuotedFunc @@ -78,7 +99,20 @@ func (o *Obfuscator) ObfuscateTokenValue(token Token, lexerOpts ...lexerOption) } case STRING, INCOMPLETE_STRING, DOLLAR_QUOTED_STRING: return StringPlaceholder + case POSITIONAL_PARAMETER: + if o.config.ReplacePositionalParameter { + return StringPlaceholder + } else { + return token.Value + } case IDENT: + if o.config.ReplaceBoolean && isBoolean(token.Value) { + return StringPlaceholder + } + if o.config.ReplaceNull && isNull(token.Value) { + return StringPlaceholder + } + if o.config.ReplaceDigits { return replaceDigits(token.Value, "?") } else { diff --git a/obfuscator_test.go b/obfuscator_test.go index 74549fe..e7d8a4d 100644 --- a/obfuscator_test.go +++ b/obfuscator_test.go @@ -9,11 +9,14 @@ import ( func TestObfuscator(t *testing.T) { tests := []struct { - input string - expected string - replaceDigits bool - dollarQuotedFunc bool - dbms DBMSType + input string + expected string + replaceDigits bool + replacePositionalParameter bool + replaceBoolean bool + replaceNull bool + dollarQuotedFunc bool + dbms DBMSType }{ { input: "SELECT * FROM users where id = 1", @@ -55,6 +58,30 @@ func TestObfuscator(t *testing.T) { expected: "SELECT * FROM users? where id = ?", replaceDigits: true, }, + { + input: "SELECT * FROM users where id is NULL and is_active = TRUE and is_admin = FALSE", + expected: "SELECT * FROM users where id is NULL and is_active = TRUE and is_admin = FALSE", + replaceBoolean: false, + replaceNull: false, + }, + { + input: "SELECT * FROM users where id is NULL and is_active = TRUE and is_admin = FALSE", + expected: "SELECT * FROM users where id is NULL and is_active = ? and is_admin = ?", + replaceBoolean: true, + replaceNull: false, + }, + { + input: "SELECT * FROM users where id is NULL and is_active = TRUE and is_admin = FALSE", + expected: "SELECT * FROM users where id is ? and is_active = TRUE and is_admin = FALSE", + replaceBoolean: false, + replaceNull: true, + }, + { + input: "SELECT * FROM users where id is NULL and is_active = TRUE and is_admin = FALSE", + expected: "SELECT * FROM users where id is ? and is_active = ? and is_admin = ?", + replaceBoolean: true, + replaceNull: true, + }, { input: "SELECT * FROM users where id = 1 -- this is a comment", expected: "SELECT * FROM users where id = ? -- this is a comment", @@ -206,6 +233,16 @@ func TestObfuscator(t *testing.T) { expected: `USING - SELECT`, replaceDigits: true, }, + { + input: "SELECT * FROM users where id = $1", + expected: `SELECT * FROM users where id = $1`, + replacePositionalParameter: false, + }, + { + input: "SELECT * FROM users where id = $1", + expected: `SELECT * FROM users where id = ?`, + replacePositionalParameter: true, + }, { input: `SELECT * FROM "public"."users" where id = 2`, expected: `SELECT * FROM "public"."users" where id = ?`, @@ -335,6 +372,9 @@ func TestObfuscator(t *testing.T) { t.Run("", func(t *testing.T) { obfuscator := NewObfuscator( WithReplaceDigits(tt.replaceDigits), + WithReplacePositionalParameter(tt.replacePositionalParameter), + WithReplaceBoolean(tt.replaceBoolean), + WithReplaceNull(tt.replaceNull), WithDollarQuotedFunc(tt.dollarQuotedFunc), ) got := obfuscator.Obfuscate(tt.input, WithDBMS(tt.dbms)) diff --git a/sqllexer.go b/sqllexer.go index 113177f..3d58016 100644 --- a/sqllexer.go +++ b/sqllexer.go @@ -17,7 +17,7 @@ const ( PUNCTUATION // punctuation DOLLAR_QUOTED_FUNCTION // dollar quoted function DOLLAR_QUOTED_STRING // dollar quoted string - NUMBERED_PARAMETER // numbered parameter + POSITIONAL_PARAMETER // numbered parameter UNKNOWN // unknown token ) @@ -121,7 +121,7 @@ func (s *Lexer) Scan() Token { case ch == '$': if isDigit(s.lookAhead(1)) { // if the dollar sign is followed by a digit, then it's a numbered parameter - return s.scanNumberedParameter() + return s.scanPositionalParameter() } if s.config.DBMS == DBMSSQLServer && isLetter(s.lookAhead(1)) { return s.scanIdentifier() @@ -393,7 +393,7 @@ func (s *Lexer) scanDollarQuotedString() Token { return Token{ERROR, s.src[s.start:s.cursor]} } -func (s *Lexer) scanNumberedParameter() Token { +func (s *Lexer) scanPositionalParameter() Token { s.start = s.cursor ch := s.nextBy(2) // consume the dollar sign and the number for { @@ -402,7 +402,7 @@ func (s *Lexer) scanNumberedParameter() Token { } ch = s.next() } - return Token{NUMBERED_PARAMETER, s.src[s.start:s.cursor]} + return Token{POSITIONAL_PARAMETER, s.src[s.start:s.cursor]} } func (s *Lexer) scanUnknown() Token { diff --git a/sqllexer_test.go b/sqllexer_test.go index f6f827f..42f3582 100644 --- a/sqllexer_test.go +++ b/sqllexer_test.go @@ -312,7 +312,7 @@ func TestLexer(t *testing.T) { {WS, " "}, {OPERATOR, "="}, {WS, " "}, - {NUMBERED_PARAMETER, "$1"}, + {POSITIONAL_PARAMETER, "$1"}, }, }, { diff --git a/sqllexer_utils.go b/sqllexer_utils.go index 2e097b6..a1b3510 100644 --- a/sqllexer_utils.go +++ b/sqllexer_utils.go @@ -116,6 +116,14 @@ func isSQLKeyword(token Token) bool { return token.Type == IDENT && keywordsRegex.MatchString(token.Value) } +func isBoolean(ident string) bool { + return strings.ToUpper(ident) == "TRUE" || strings.ToUpper(ident) == "FALSE" +} + +func isNull(ident string) bool { + return strings.ToUpper(ident) == "NULL" +} + func replaceDigits(input string, placeholder string) string { var builder strings.Builder