diff --git a/obfuscator_test.go b/obfuscator_test.go index e35d7b2..f238b59 100644 --- a/obfuscator_test.go +++ b/obfuscator_test.go @@ -398,6 +398,16 @@ func TestObfuscator(t *testing.T) { expected: "begin execute immediate ?; end;", dbms: DBMSOracle, }, + { + input: "SELECT * FROM #users where id = @id and name = @1", + expected: "SELECT * FROM #users where id = @id and name = @1", + dbms: DBMSSQLServer, + }, + { + input: "SELECT * FROM users where id = :id and name = :1", + expected: "SELECT * FROM users where id = :id and name = :1", + dbms: DBMSOracle, + }, } for _, tt := range tests { diff --git a/sqllexer.go b/sqllexer.go index 749e8e6..6651849 100644 --- a/sqllexer.go +++ b/sqllexer.go @@ -133,12 +133,12 @@ func (s *Lexer) Scan() Token { } return s.scanDollarQuotedString() case ch == ':': - if s.config.DBMS == DBMSOracle && isLetter(s.lookAhead(1)) { + if s.config.DBMS == DBMSOracle && isAlphaNumeric(s.lookAhead(1)) { return s.scanBindParameter() } return s.scanOperator(ch) case ch == '@': - if s.config.DBMS == DBMSSQLServer && isLetter(s.lookAhead(1)) { + if s.config.DBMS == DBMSSQLServer && isAlphaNumeric(s.lookAhead(1)) { return s.scanBindParameter() } return s.scanOperator(ch) @@ -147,6 +147,11 @@ func (s *Lexer) Scan() Token { return s.scanDoubleQuotedIdentifier('`') } fallthrough + case ch == '#': + if s.config.DBMS == DBMSSQLServer { + return s.scanIdentifier('#') + } + fallthrough case isOperator(ch): return s.scanOperator(ch) case isPunctuation(ch): @@ -443,7 +448,7 @@ func (s *Lexer) scanBindParameter() Token { s.start = s.cursor ch := s.nextBy(2) // consume the (colon|at sign) and the char for { - if !isLetter(ch) && !isDigit(ch) { + if !isAlphaNumeric(ch) { break } ch = s.next() diff --git a/sqllexer_test.go b/sqllexer_test.go index 6827347..4e0bed7 100644 --- a/sqllexer_test.go +++ b/sqllexer_test.go @@ -472,7 +472,7 @@ func TestLexer(t *testing.T) { }, { name: "select with bind parameter", - input: "SELECT * FROM users where id = :id", + input: "SELECT * FROM users where id = :id and name = :1", expected: []Token{ {IDENT, "SELECT"}, {WS, " "}, @@ -489,12 +489,20 @@ func TestLexer(t *testing.T) { {OPERATOR, "="}, {WS, " "}, {BIND_PARAMETER, ":id"}, + {WS, " "}, + {IDENT, "and"}, + {WS, " "}, + {IDENT, "name"}, + {WS, " "}, + {OPERATOR, "="}, + {WS, " "}, + {BIND_PARAMETER, ":1"}, }, lexerOpts: []lexerOption{WithDBMS(DBMSOracle)}, }, { name: "select with bind parameter", - input: "SELECT * FROM users where id = @id", + input: "SELECT * FROM users where id = @id and name = @1", expected: []Token{ {IDENT, "SELECT"}, {WS, " "}, @@ -511,6 +519,14 @@ func TestLexer(t *testing.T) { {OPERATOR, "="}, {WS, " "}, {BIND_PARAMETER, "@id"}, + {WS, " "}, + {IDENT, "and"}, + {WS, " "}, + {IDENT, "name"}, + {WS, " "}, + {OPERATOR, "="}, + {WS, " "}, + {BIND_PARAMETER, "@1"}, }, lexerOpts: []lexerOption{WithDBMS(DBMSSQLServer)}, }, @@ -574,6 +590,20 @@ func TestLexer(t *testing.T) { {IDENT, "users"}, }, }, + { + name: "Tokenize temp table", + input: `SELECT * FROM #temp`, + expected: []Token{ + {IDENT, "SELECT"}, + {WS, " "}, + {WILDCARD, "*"}, + {WS, " "}, + {IDENT, "FROM"}, + {WS, " "}, + {IDENT, "#temp"}, + }, + lexerOpts: []lexerOption{WithDBMS(DBMSSQLServer)}, + }, } for _, tt := range tests { diff --git a/sqllexer_utils.go b/sqllexer_utils.go index 9382cd1..49bd7a7 100644 --- a/sqllexer_utils.go +++ b/sqllexer_utils.go @@ -160,6 +160,10 @@ func isLetter(ch rune) bool { return unicode.IsLetter(ch) || ch == '_' } +func isAlphaNumeric(ch rune) bool { + return isLetter(ch) || isDigit(ch) +} + func isDoubleQuote(ch rune) bool { return ch == '"' }