Skip to content

Commit

Permalink
Merge pull request #16 from DataDog/zhengda.lu/unicode-ident
Browse files Browse the repository at this point in the history
scan unicode identifier
  • Loading branch information
lu-zhengda authored Oct 19, 2023
2 parents 6af0438 + 4ca113d commit a69f508
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 21 deletions.
45 changes: 24 additions & 21 deletions sqllexer.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package sqllexer

import "unicode/utf8"

type TokenType int

const (
Expand Down Expand Up @@ -98,7 +100,7 @@ func (s *Lexer) Scan() Token {
case isWhitespace(ch):
return s.scanWhitespace()
case isLetter(ch):
return s.scanIdentifier()
return s.scanIdentifier(ch)
case isDoubleQuote(ch):
return s.scanDoubleQuotedIdentifier('"')
case isSingleQuote(ch):
Expand All @@ -114,9 +116,9 @@ func (s *Lexer) Scan() Token {
if isDigit(nextCh) || nextCh == '.' {
return s.scanNumberWithLeadingSign()
}
return s.scanOperator()
return s.scanOperator(ch)
case isDigit(ch):
return s.scanNumber()
return s.scanNumber(ch)
case isWildcard(ch):
return s.scanWildcard()
case ch == '$':
Expand All @@ -125,26 +127,26 @@ func (s *Lexer) Scan() Token {
return s.scanPositionalParameter()
}
if s.config.DBMS == DBMSSQLServer && isLetter(s.lookAhead(1)) {
return s.scanIdentifier()
return s.scanIdentifier(ch)
}
return s.scanDollarQuotedString()
case ch == ':':
if s.config.DBMS == DBMSOracle && isLetter(s.lookAhead(1)) {
return s.scanBindParameter()
}
return s.scanOperator()
return s.scanOperator(ch)
case ch == '@':
if s.config.DBMS == DBMSSQLServer && isLetter(s.lookAhead(1)) {
return s.scanBindParameter()
}
return s.scanOperator()
return s.scanOperator(ch)
case ch == '`':
if s.config.DBMS == DBMSMySQL {
return s.scanDoubleQuotedIdentifier('`')
}
fallthrough
case isOperator(ch):
return s.scanOperator()
return s.scanOperator(ch)
case isPunctuation(ch):
if ch == '[' && s.config.DBMS == DBMSSQLServer {
return s.scanDoubleQuotedIdentifier('[')
Expand All @@ -162,7 +164,8 @@ func (s *Lexer) lookAhead(n int) rune {
if s.cursor+n >= len(s.src) || s.cursor+n < 0 {
return 0
}
return rune(s.src[s.cursor+n])
r, _ := utf8.DecodeRuneInString(s.src[s.cursor+n:])
return r
}

// peek returns the rune at the cursor position.
Expand All @@ -177,10 +180,11 @@ func (s *Lexer) nextBy(n int) rune {
return 0
}
s.cursor += n
if s.cursor == len(s.src) {
if s.cursor >= len(s.src) {
return 0
}
return rune(s.src[s.cursor])
r, _ := utf8.DecodeRuneInString(s.src[s.cursor:])
return r
}

// next advances the cursor by 1 position and returns the rune at the cursor position.
Expand All @@ -202,17 +206,17 @@ func (s *Lexer) matchAt(match []rune) bool {

func (s *Lexer) scanNumberWithLeadingSign() Token {
s.start = s.cursor
s.next() // consume the leading sign
return s.scanNumberic()
ch := s.next() // consume the leading sign
return s.scanNumberic(ch)
}

func (s *Lexer) scanNumber() Token {
func (s *Lexer) scanNumber(ch rune) Token {
s.start = s.cursor
return s.scanNumberic()
return s.scanNumberic(ch)
}

func (s *Lexer) scanNumberic() Token {
if s.peek() == '0' {
func (s *Lexer) scanNumberic(ch rune) Token {
if ch == '0' {
nextCh := s.lookAhead(1)
if nextCh == 'x' || nextCh == 'X' {
return s.scanHexNumber()
Expand Down Expand Up @@ -293,12 +297,12 @@ func (s *Lexer) scanString() Token {
}
}

func (s *Lexer) scanIdentifier() Token {
func (s *Lexer) scanIdentifier(ch rune) Token {
// NOTE: this func does not distinguish between SQL keywords and identifiers
s.start = s.cursor
ch := s.next()
ch = s.nextBy(utf8.RuneLen(ch))
for isLetter(ch) || isDigit(ch) || ch == '.' || ch == '?' || ch == '$' {
ch = s.next()
ch = s.nextBy(utf8.RuneLen(ch))
}
// return the token as uppercase so that we can do case insensitive matching
return Token{IDENT, s.src[s.start:s.cursor]}
Expand Down Expand Up @@ -344,9 +348,8 @@ func (s *Lexer) scanWhitespace() Token {
return Token{WS, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanOperator() Token {
func (s *Lexer) scanOperator(lastCh rune) Token {
s.start = s.cursor
lastCh := s.peek()
ch := s.next()
for isOperator(ch) && !(lastCh == '=' && ch == '?') {
// hack: we don't want to treat "=?" as an single operator
Expand Down
59 changes: 59 additions & 0 deletions sqllexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,65 @@ func TestLexer(t *testing.T) {
}
}

func TestLexerUnicode(t *testing.T) {
tests := []struct {
input string
expected []Token
lexerOpts []lexerOption
}{
{
input: `Descripció_CAT`,
expected: []Token{
{IDENT, `Descripció_CAT`},
},
},
{
input: `世界`,
expected: []Token{
{IDENT, `世界`},
},
},
{
input: `こんにちは`,
expected: []Token{
{IDENT, `こんにちは`},
},
},
{
input: `안녕하세요`,
expected: []Token{
{IDENT, `안녕하세요`},
},
},
{
input: `über`,
expected: []Token{
{IDENT, `über`},
},
},
{
input: `résumé`,
expected: []Token{
{IDENT, `résumé`},
},
},
{
input: `"über"`,
expected: []Token{
{IDENT, `"über"`},
},
},
}

for _, tt := range tests {
t.Run("", func(t *testing.T) {
lexer := New(tt.input, tt.lexerOpts...)
tokens := lexer.ScanAll()
assert.Equal(t, tt.expected, tokens)
})
}
}

func ExampleLexer() {
query := "SELECT * FROM users WHERE id = 1"
lexer := New(query)
Expand Down

0 comments on commit a69f508

Please sign in to comment.