Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix tokenize named parameter and temp table #27

Merged
merged 1 commit into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions obfuscator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,16 @@ func TestObfuscator(t *testing.T) {
expected: "begin execute immediate ?; end;",
dbms: DBMSOracle,
},
{
input: "SELECT * FROM #users where id = @id and name = @1",
expected: "SELECT * FROM #users where id = @id and name = @1",
dbms: DBMSSQLServer,
},
{
input: "SELECT * FROM users where id = :id and name = :1",
expected: "SELECT * FROM users where id = :id and name = :1",
dbms: DBMSOracle,
},
}

for _, tt := range tests {
Expand Down
11 changes: 8 additions & 3 deletions sqllexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ func (s *Lexer) Scan() Token {
}
return s.scanDollarQuotedString()
case ch == ':':
if s.config.DBMS == DBMSOracle && isLetter(s.lookAhead(1)) {
if s.config.DBMS == DBMSOracle && isAlphaNumeric(s.lookAhead(1)) {
return s.scanBindParameter()
}
return s.scanOperator(ch)
case ch == '@':
if s.config.DBMS == DBMSSQLServer && isLetter(s.lookAhead(1)) {
if s.config.DBMS == DBMSSQLServer && isAlphaNumeric(s.lookAhead(1)) {
return s.scanBindParameter()
}
return s.scanOperator(ch)
Expand All @@ -147,6 +147,11 @@ func (s *Lexer) Scan() Token {
return s.scanDoubleQuotedIdentifier('`')
}
fallthrough
case ch == '#':
if s.config.DBMS == DBMSSQLServer {
return s.scanIdentifier('#')
}
fallthrough
case isOperator(ch):
return s.scanOperator(ch)
case isPunctuation(ch):
Expand Down Expand Up @@ -443,7 +448,7 @@ func (s *Lexer) scanBindParameter() Token {
s.start = s.cursor
ch := s.nextBy(2) // consume the (colon|at sign) and the char
for {
if !isLetter(ch) && !isDigit(ch) {
if !isAlphaNumeric(ch) {
break
}
ch = s.next()
Expand Down
34 changes: 32 additions & 2 deletions sqllexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ func TestLexer(t *testing.T) {
},
{
name: "select with bind parameter",
input: "SELECT * FROM users where id = :id",
input: "SELECT * FROM users where id = :id and name = :1",
expected: []Token{
{IDENT, "SELECT"},
{WS, " "},
Expand All @@ -489,12 +489,20 @@ func TestLexer(t *testing.T) {
{OPERATOR, "="},
{WS, " "},
{BIND_PARAMETER, ":id"},
{WS, " "},
{IDENT, "and"},
{WS, " "},
{IDENT, "name"},
{WS, " "},
{OPERATOR, "="},
{WS, " "},
{BIND_PARAMETER, ":1"},
},
lexerOpts: []lexerOption{WithDBMS(DBMSOracle)},
},
{
name: "select with bind parameter",
input: "SELECT * FROM users where id = @id",
input: "SELECT * FROM users where id = @id and name = @1",
expected: []Token{
{IDENT, "SELECT"},
{WS, " "},
Expand All @@ -511,6 +519,14 @@ func TestLexer(t *testing.T) {
{OPERATOR, "="},
{WS, " "},
{BIND_PARAMETER, "@id"},
{WS, " "},
{IDENT, "and"},
{WS, " "},
{IDENT, "name"},
{WS, " "},
{OPERATOR, "="},
{WS, " "},
{BIND_PARAMETER, "@1"},
},
lexerOpts: []lexerOption{WithDBMS(DBMSSQLServer)},
},
Expand Down Expand Up @@ -574,6 +590,20 @@ func TestLexer(t *testing.T) {
{IDENT, "users"},
},
},
{
name: "Tokenize temp table",
input: `SELECT * FROM #temp`,
expected: []Token{
{IDENT, "SELECT"},
{WS, " "},
{WILDCARD, "*"},
{WS, " "},
{IDENT, "FROM"},
{WS, " "},
{IDENT, "#temp"},
},
lexerOpts: []lexerOption{WithDBMS(DBMSSQLServer)},
},
}

for _, tt := range tests {
Expand Down
4 changes: 4 additions & 0 deletions sqllexer_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ func isLetter(ch rune) bool {
return unicode.IsLetter(ch) || ch == '_'
}

func isAlphaNumeric(ch rune) bool {
return isLetter(ch) || isDigit(ch)
}

func isDoubleQuote(ch rune) bool {
return ch == '"'
}
Expand Down