Merge pull request #7 from DataDog/zhengda.lu/lexer
improve sqllexer and obfuscator operation and memory efficiency
lu-zhengda authored Sep 11, 2023
2 parents 4404b23 + 0ead7c8 commit dcfea70
Showing 8 changed files with 309 additions and 175 deletions.
16 changes: 6 additions & 10 deletions normalizer.go
@@ -90,9 +90,9 @@ func (n *Normalizer) Normalize(input string) (normalized string, info *Statement
Commands: []string{},
}

var lastToken *Token // The last token that is not whitespace or comment
var lastToken Token // The last token that is not whitespace or comment

for token := range lexer.ScanAllTokens() {
for _, token := range lexer.ScanAll() {
if token.Type == COMMENT || token.Type == MULTILINE_COMMENT {
// Collect comments
if n.config.CollectComments {
@@ -102,7 +102,7 @@ func (n *Normalizer) Normalize(input string) (normalized string, info *Statement
if isCommand(strings.ToUpper(token.Value)) && n.config.CollectCommands {
// Collect commands
statementMetadata.Commands = append(statementMetadata.Commands, strings.ToUpper(token.Value))
} else if lastToken != nil && isTableIndicator(strings.ToUpper(lastToken.Value)) {
} else if isTableIndicator(strings.ToUpper(lastToken.Value)) {
// Collect table names
if n.config.CollectTables {
statementMetadata.Tables = append(statementMetadata.Tables, token.Value)
@@ -136,7 +136,7 @@ func (n *Normalizer) Normalize(input string) (normalized string, info *Statement
return strings.TrimSpace(normalizedSQL), statementMetadata, nil
}

func normalizeSQL(token *Token, lastToken *Token, statement string) string {
func normalizeSQL(token Token, lastToken Token, statement string) string {
if token.Type == WS || token.Type == COMMENT || token.Type == MULTILINE_COMMENT {
// We don't rely on the WS token to determine if we should add a whitespace
return statement
@@ -197,15 +197,11 @@ func dedupeStatementMetadata(info *StatementMetadata) {
info.Commands = dedupeCollectedMetadata(info.Commands)
}

func isSQLKeyword(token *Token) bool {
return token.Type == IDENT && keywordsRegex.MatchString(token.Value)
}

func appendWhitespace(lastToken *Token, token *Token, normalizedSQL string) string {
func appendWhitespace(lastToken Token, token Token, normalizedSQL string) string {
switch token.Value {
case ",":
case "=":
if lastToken != nil && lastToken.Value == ":" {
if lastToken.Value == ":" {
// do not add a space before an equals if a colon was
// present before it.
break
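
The switch from *Token to plain Token values above lets the struct's zero value stand in for "no previous token yet", which is why the nil checks around lastToken could be dropped. A minimal, self-contained sketch of the pattern (the TokenType constants and isTableIndicator here are stand-ins for the package's own):

package main

import (
	"fmt"
	"strings"
)

type TokenType int

const (
	UNKNOWN TokenType = iota
	IDENT
	WS
)

// Token travels by value; its zero value means "no token yet".
type Token struct {
	Type  TokenType
	Value string
}

func isTableIndicator(s string) bool { return s == "FROM" || s == "JOIN" }

func main() {
	tokens := []Token{
		{IDENT, "SELECT"}, {WS, " "}, {IDENT, "name"}, {WS, " "},
		{IDENT, "FROM"}, {WS, " "}, {IDENT, "users"},
	}
	var lastToken Token // zero value replaces a nil *Token
	for _, token := range tokens {
		// On the first iteration lastToken.Value is "", which simply
		// fails the lookup instead of dereferencing a nil pointer.
		if token.Type == IDENT && isTableIndicator(strings.ToUpper(lastToken.Value)) {
			fmt.Println("table:", token.Value) // prints "table: users"
		}
		if token.Type != WS {
			lastToken = token // copied by value; no per-token heap allocation
		}
	}
}
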
26 changes: 13 additions & 13 deletions normalizer_test.go
@@ -293,8 +293,8 @@ multiline comment */
t.Run("", func(t *testing.T) {
got, statementMetadata, err := normalizer.Normalize(test.input)
assert.NoError(t, err)
assert.Equal(t, got, test.expected)
assert.Equal(t, statementMetadata, &test.statementMetadata)
assert.Equal(t, test.expected, got)
assert.Equal(t, &test.statementMetadata, statementMetadata)
})
}
}
@@ -354,8 +354,8 @@ func TestNormalizerNotCollectMetadata(t *testing.T) {
t.Run("", func(t *testing.T) {
got, statementMetadata, err := normalizer.Normalize(test.input)
assert.NoError(t, err)
assert.Equal(t, got, test.expected)
assert.Equal(t, statementMetadata, &test.statementMetadata)
assert.Equal(t, test.expected, got)
assert.Equal(t, &test.statementMetadata, statementMetadata)
})
}
}
@@ -400,7 +400,7 @@ func TestNormalizerFormatting(t *testing.T) {
for _, query := range test.queries {
got, _, err := normalizer.Normalize(query)
assert.NoError(t, err)
assert.Equal(t, got, test.expected)
assert.Equal(t, test.expected, got)
}
})
}
@@ -465,14 +465,14 @@ func TestNormalizerRepeatedExecution(t *testing.T) {
for i := 0; i < 10; i++ {
got, statementMetadata, err := normalizer.Normalize(input)
assert.NoError(t, err)
assert.Equal(t, got, test.expected)
assert.Equal(t, statementMetadata.Commands, test.statementMetadata.Commands)
assert.Equal(t, statementMetadata.Tables, test.statementMetadata.Tables)
assert.Equal(t, test.expected, got)
assert.Equal(t, test.statementMetadata.Commands, statementMetadata.Commands)
assert.Equal(t, test.statementMetadata.Tables, statementMetadata.Tables)
// comments are stripped after the first execution
if i == 0 {
assert.Equal(t, statementMetadata.Comments, test.statementMetadata.Comments)
assert.Equal(t, test.statementMetadata.Comments, statementMetadata.Comments)
} else {
assert.Equal(t, statementMetadata.Comments, []string{})
assert.Equal(t, []string{}, statementMetadata.Comments)
}
input = got
}
@@ -613,8 +613,8 @@ func TestNormalizeDeobfuscatedSQL(t *testing.T) {
)
got, statementMetadata, err := normalizer.Normalize(test.input)
assert.NoError(t, err)
assert.Equal(t, got, test.expected)
assert.Equal(t, statementMetadata, &test.statementMetadata)
assert.Equal(t, test.expected, got)
assert.Equal(t, &test.statementMetadata, statementMetadata)
})
}
}
@@ -677,7 +677,7 @@ func TestGroupObfuscatedValues(t *testing.T) {
for _, test := range tests {
t.Run("", func(t *testing.T) {
got := groupObfuscatedValues(test.input)
assert.Equal(t, got, test.expected)
assert.Equal(t, test.expected, got)
})
}
}
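
The assertion edits above are argument-order fixes: testify's signature is assert.Equal(t, expected, actual, ...), and with the operands reversed a failing test prints its labels swapped, attributing the fixture value to the code under test. A small illustration of the convention (assuming the stretchr/testify module):

package sqllexer

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestArgumentOrder(t *testing.T) {
	got := "SELECT * FROM users WHERE id = ?"
	// Expected value first, then the observed one; a mismatch now
	// reports expected/actual under the correct labels.
	assert.Equal(t, "SELECT * FROM users WHERE id = ?", got)
}
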
37 changes: 10 additions & 27 deletions obfuscator.go
@@ -48,45 +48,28 @@ func (o *Obfuscator) Obfuscate(input string) string {
var obfuscatedSQL strings.Builder

lexer := New(input)
for token := range lexer.ScanAllTokens() {
for _, token := range lexer.ScanAll() {
switch token.Type {
case NUMBER:
obfuscatedSQL.WriteString(NumberPlaceholder)
case STRING:
obfuscatedSQL.WriteString(StringPlaceholder)
case INCOMPLETE_STRING:
obfuscatedSQL.WriteString(StringPlaceholder)
case IDENT:
if o.config.ReplaceDigits {
// regex to replace digits in identifier
// we try to avoid using regex as much as possible,
// as regex isn't the most performant,
// but it's the easiest to implement and maintain
obfuscatedSQL.WriteString(digitsRegex.ReplaceAllString(token.Value, "?"))
} else {
obfuscatedSQL.WriteString(token.Value)
}
case COMMENT:
obfuscatedSQL.WriteString(token.Value)
case MULTILINE_COMMENT:
obfuscatedSQL.WriteString(token.Value)
case DOLLAR_QUOTED_STRING:
obfuscatedSQL.WriteString(StringPlaceholder)
case DOLLAR_QUOTED_FUNCTION:
if o.config.DollarQuotedFunc {
// obfuscate the content of dollar quoted function
quotedFunc := strings.TrimPrefix(token.Value, "$func$")
quotedFunc = strings.TrimSuffix(quotedFunc, "$func$")
quotedFunc := token.Value[6 : len(token.Value)-6] // remove the $func$ prefix and suffix
obfuscatedSQL.WriteString("$func$")
obfuscatedSQL.WriteString(o.Obfuscate(quotedFunc))
obfuscatedSQL.WriteString("$func$")
} else {
// treat dollar quoted function as dollar quoted string
obfuscatedSQL.WriteString(StringPlaceholder)
}
case ERROR | UNKNOWN:
// if we encounter an error or unknown token, we just append the value
obfuscatedSQL.WriteString(token.Value)
case STRING, INCOMPLETE_STRING, DOLLAR_QUOTED_STRING:
obfuscatedSQL.WriteString(StringPlaceholder)
case IDENT:
if o.config.ReplaceDigits {
obfuscatedSQL.WriteString(replaceDigits(token.Value, "?"))
} else {
obfuscatedSQL.WriteString(token.Value)
}
default:
obfuscatedSQL.WriteString(token.Value)
}
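
Two things changed above: the per-type cases that all wrote StringPlaceholder were collapsed into one case list plus a default, and the $func$ delimiters are now stripped with a single index slice instead of TrimPrefix/TrimSuffix. Both forms re-slice the string without copying; the indexed version just skips two redundant delimiter checks, which is safe because the lexer only emits DOLLAR_QUOTED_FUNCTION for values that start and end with the 6-byte $func$ tag. A quick sketch of the equivalence:

package main

import (
	"fmt"
	"strings"
)

func main() {
	token := "$func$SELECT * FROM t WHERE id = 1$func$"

	// Before: two conditional trims, each re-checking the delimiter.
	before := strings.TrimSuffix(strings.TrimPrefix(token, "$func$"), "$func$")

	// After: one slice, relying on the lexer's delimiter guarantee.
	after := token[6 : len(token)-6]

	fmt.Println(before == after) // true; neither form copies the bytes
}
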
21 changes: 18 additions & 3 deletions obfuscator_test.go
@@ -19,11 +19,26 @@ func TestObfuscator(t *testing.T) {
expected: "SELECT * FROM users where id = ?",
replaceDigits: true,
},
{
input: "SELECT * FROM users where id = 0x124af",
expected: "SELECT * FROM users where id = ?",
replaceDigits: true,
},
{
input: "SELECT * FROM users where id = 0617",
expected: "SELECT * FROM users where id = ?",
replaceDigits: true,
},
{
input: "SELECT * FROM users where id = '12'",
expected: "SELECT * FROM users where id = ?",
replaceDigits: true,
},
{
input: "SELECT * FROM users where id = 'j\\'s'",
expected: "SELECT * FROM users where id = ?",
replaceDigits: false,
},
{
input: "SELECT * FROM \"users table\" where id = 1",
expected: "SELECT * FROM \"users table\" where id = ?",
@@ -176,8 +191,8 @@ func TestObfuscator(t *testing.T) {
replaceDigits: true,
},
{
input: `SELECT * FROM Items WHERE id = -1 OR id = -01 OR id = -108 OR id = -.018 OR id = -.08 OR id = -908129`,
expected: `SELECT * FROM Items WHERE id = ? OR id = ? OR id = ? OR id = ? OR id = ? OR id = ?`,
input: `SELECT * FROM Items WHERE id = -1 OR id = +01 OR id = -108 OR id = -.018 OR id = -.08 OR id = -908129 OR id = 1e2 OR id = 1e-1`,
expected: `SELECT * FROM Items WHERE id = ? OR id = ? OR id = ? OR id = ? OR id = ? OR id = ? OR id = ? OR id = ?`,
replaceDigits: true,
},
{
@@ -321,7 +336,7 @@ func TestObfuscator(t *testing.T) {
WithDollarQuotedFunc(tt.dollarQuotedFunc),
)
got := obfuscator.Obfuscate(tt.input)
assert.Equal(t, got, tt.expected)
assert.Equal(t, tt.expected, got)
})
}
}
206 changes: 100 additions & 106 deletions sqllexer.go
@@ -54,8 +54,8 @@ func (s *Lexer) ScanAll() []Token {

// ScanAllTokens scans the entire input string and returns a channel of tokens.
// Use this if you want to process the tokens as they are scanned.
func (s *Lexer) ScanAllTokens() <-chan *Token {
tokenCh := make(chan *Token)
func (s *Lexer) ScanAllTokens() <-chan Token {
tokenCh := make(chan Token)

go func() {
defer close(tokenCh)
@@ -66,7 +66,7 @@ func (s *Lexer) ScanAllTokens() <-chan *Token {
// don't include EOF token in the result
break
}
tokenCh <- &token
tokenCh <- token
}
}()

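For reference, a minimal consumer of the streaming API after this change; each Token now crosses the channel by value, so no per-token pointer escapes to the heap (import path assumed to be the published module, github.com/DataDog/go-sqllexer):

package main

import (
	"fmt"

	sqllexer "github.com/DataDog/go-sqllexer"
)

func main() {
	lexer := sqllexer.New("SELECT * FROM users WHERE id = 1")
	// The goroutine inside ScanAllTokens closes the channel at EOF,
	// so a plain range loop terminates cleanly.
	for token := range lexer.ScanAllTokens() {
		fmt.Printf("%d %q\n", token.Type, token.Value)
	}
}
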
@@ -90,11 +90,11 @@ func (s *Lexer) Scan() Token {
case isMultiLineComment(ch, s.lookAhead(1)):
return s.scanMultiLineComment()
case isLeadingSign(ch):
// if the leading sign is followed by a digit,
// and preceded by whitespace, then it's a number
// if the leading sign is followed by a digit, then it's a number
// although this is not strictly true, it's good enough for our purposes
if (isDigit(s.lookAhead(1)) || s.lookAhead(1) == '.') && isWhitespace(s.lookAhead(-1)) {
return s.scanNumber()
nextCh := s.lookAhead(1)
if isDigit(nextCh) || nextCh == '.' {
return s.scanNumberWithLeadingSign()
}
return s.scanOperator()
case isDigit(ch):
@@ -106,11 +106,6 @@ func (s *Lexer) Scan() Token {
// if the dollar sign is followed by a digit, then it's a numbered parameter
return s.scanNumberedParameter()
}
if isDollarQuotedFunction(s.peekBy(6)) {
// check for dollar quoted function
// dollar quoted function will be obfuscated as a SQL query
return s.scanDollarQuotedFunction()
}
return s.scanDollarQuotedString()
case isOperator(ch):
return s.scanOperator()
@@ -131,14 +126,6 @@ func (s *Lexer) lookAhead(n int) rune {
return rune(s.src[s.cursor+n])
}

// peekBy returns a slice of runes n positions ahead of the cursor.
func (s *Lexer) peekBy(n int) []rune {
if s.cursor+n >= len(s.src) || s.cursor+n < 0 {
return []rune{}
}
return []rune(s.src[s.cursor : s.cursor+n])
}

// peek returns the rune at the cursor position.
func (s *Lexer) peek() rune {
return s.lookAhead(0)
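
The removed peekBy above allocated a fresh []rune slice on every lookahead; its replacement, the matchAt helper added in the next hunk, does the comparison in place against the source string. A simplified version of the idea (collapsed to a string compare; the committed helper loops over the characters one by one):

package main

import "fmt"

func matchAt(src string, cursor int, match string) bool {
	if cursor+len(match) > len(src) {
		return false
	}
	// Both operands are views into existing strings: no allocation.
	return src[cursor:cursor+len(match)] == match
}

func main() {
	src := `"foo"."bar"`
	fmt.Println(matchAt(src, 4, `"."`)) // true: cursor is on the quote closing "foo"
}
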
@@ -162,24 +149,44 @@ func (s *Lexer) next() rune {
return s.nextBy(1)
}

func (s *Lexer) matchAt(match []rune) bool {
if s.cursor+len(match) > len(s.src) {
return false
}
for i, ch := range match {
if s.src[s.cursor+i] != byte(ch) {
return false
}
}
return true
}

func (s *Lexer) scanNumberWithLeadingSign() Token {
s.start = s.cursor
s.next() // consume the leading sign
return s.scanNumberic()
}

func (s *Lexer) scanNumber() Token {
s.start = s.cursor
ch := s.peek()
nextCh := s.lookAhead(1)
return s.scanNumberic()
}

// check for hex or octal number
if ch == '0' {
func (s *Lexer) scanNumberic() Token {
if s.peek() == '0' {
nextCh := s.lookAhead(1)
if nextCh == 'x' || nextCh == 'X' {
return s.scanHexNumber()
} else if nextCh >= '0' && nextCh <= '7' {
return s.scanOctalNumber()
}
}

// optional leading sign e.g. +1, -1
if isLeadingSign(ch) {
ch = s.next()
}
return s.scanDecimalNumber()
}

func (s *Lexer) scanDecimalNumber() Token {
ch := s.next()

// scan digits
for isDigit(ch) || ch == '.' || isExpontent(ch) {
@@ -196,55 +203,61 @@ func (s *Lexer) scanNumber() Token {
}

func (s *Lexer) scanHexNumber() Token {
s.start = s.cursor
s.nextBy(2) // consume the leading 0x
for digitVal(s.peek()) < 16 {
s.next()
ch := s.nextBy(2) // consume the leading 0x

for isDigit(ch) || ('a' <= ch && ch <= 'f') || ('A' <= ch && ch <= 'F') {
ch = s.next()
}
return Token{NUMBER, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanOctalNumber() Token {
s.start = s.cursor
s.next() // consume the leading 0
for digitVal(s.peek()) < 8 {
s.next()
ch := s.nextBy(2) // consume the leading 0 and number

for '0' <= ch && ch <= '7' {
ch = s.next()
}
return Token{NUMBER, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanString() Token {
s.start = s.cursor
s.next() // consume the opening quote
ch := s.next() // consume the opening quote
escaped := false

for {
ch := s.peek()
if ch == '\'' {
// encountered the closing quote
break
if escaped {
// encountered an escape character
// reset the escaped flag and continue
escaped = false
ch = s.next()
continue
}

if ch == '\\' {
// Encountered an escape character, look ahead the next character
// If it's a single quote or a backslash, consume the escape character
nextCh := s.lookAhead(1)
if nextCh == '\'' || nextCh == '\\' {
s.next()
}
escaped = true
ch = s.next()
continue
}

if ch == '\'' {
s.next() // consume the closing quote
return Token{STRING, s.src[s.start:s.cursor]}
}

if isEOF(ch) {
// encountered EOF before closing quote
// this usually happens when the string is truncated
return Token{INCOMPLETE_STRING, s.src[s.start:s.cursor]}
}
s.next()
ch = s.next()
}
s.next() // consume the closing quote
return Token{STRING, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanIdentifier() Token {
// NOTE: this func does not distinguish between SQL keywords and identifiers
s.start = s.cursor
ch := s.peek()
ch := s.next()
for isLetter(ch) || isDigit(ch) || ch == '.' || ch == '?' {
ch = s.next()
}
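
The scanString rewrite above swaps the old one-character lookahead for a single escaped flag, so any backslash-escaped character, not just \' and \\, is consumed uniformly, and a string truncated before its closing quote falls out as INCOMPLETE_STRING. Reduced to its core, the loop looks like this (a sketch, not the committed code):

package main

import "fmt"

// scanSingleQuoted finds the end of a single-quoted SQL string using the
// escaped-flag state machine; ok is false when the string is truncated.
func scanSingleQuoted(src string) (token string, ok bool) {
	escaped := false
	for i := 1; i < len(src); i++ { // src[0] is the opening quote
		switch {
		case escaped:
			escaped = false // the escaped character is taken literally
		case src[i] == '\\':
			escaped = true
		case src[i] == '\'':
			return src[:i+1], true // include the closing quote
		}
	}
	return src, false // hit EOF first: incomplete string
}

func main() {
	tok, ok := scanSingleQuoted(`'j\'s'`)
	fmt.Printf("%q %v\n", tok, ok) // "'j\\'s'" true: one STRING token
}
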
@@ -254,23 +267,22 @@ func (s *Lexer) scanIdentifier() Token {

func (s *Lexer) scanDoubleQuotedIdentifier() Token {
s.start = s.cursor
s.next() // consume the opening quote
ch := s.next() // consume the opening quote
for {
ch := s.peek()
// encountered the closing quote
// BUT if it's followed by .", then we should keep going
// e.g. postgres "foo"."bar"
if ch == '"' {
if string(s.peekBy(3)) == `"."` {
s.nextBy(3) // consume the "."
if s.matchAt([]rune(`"."`)) {
ch = s.nextBy(3) // consume the "."
continue
}
break
}
if isEOF(ch) {
return Token{ERROR, s.src[s.start:s.cursor]}
}
s.next()
ch = s.next()
}
s.next() // consume the closing quote
return Token{IDENT, s.src[s.start:s.cursor]}
@@ -279,16 +291,18 @@ func (s *Lexer) scanDoubleQuotedIdentifier() Token {
func (s *Lexer) scanWhitespace() Token {
// scan whitespace, tab, newline, carriage return
s.start = s.cursor
for isWhitespace(s.peek()) {
s.next()
ch := s.next()
for isWhitespace(ch) {
ch = s.next()
}
return Token{WS, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanOperator() Token {
s.start = s.cursor
for isOperator(s.peek()) {
s.next()
ch := s.next()
for isOperator(ch) {
ch = s.next()
}
return Token{OPERATOR, s.src[s.start:s.cursor]}
}
@@ -301,28 +315,28 @@ func (s *Lexer) scanWildcard() Token {

func (s *Lexer) scanSingleLineComment() Token {
s.start = s.cursor
for s.peek() != '\n' && !isEOF(s.peek()) {
s.next()
ch := s.nextBy(2) // consume the opening dashes
for ch != '\n' && !isEOF(ch) {
ch = s.next()
}
return Token{COMMENT, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanMultiLineComment() Token {
s.start = s.cursor
s.nextBy(2) // consume the opening slash and asterisk
ch := s.nextBy(2) // consume the opening slash and asterisk
for {
ch := s.peek()
if ch == '*' && s.lookAhead(1) == '/' {
s.nextBy(2) // consume the closing asterisk and slash
break
}
if ch == 0 {
if isEOF(ch) {
// encountered EOF before closing comment
// this usually happens when the comment is truncated
return Token{ERROR, s.src[s.start:s.cursor]}
}
s.next()
ch = s.next()
}
s.nextBy(2) // consume the closing asterisk and slash
return Token{MULTILINE_COMMENT, s.src[s.start:s.cursor]}
}

@@ -332,58 +346,38 @@ func (s *Lexer) scanPunctuation() Token {
return Token{PUNCTUATION, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanDollarQuotedFunction() Token {
s.start = s.cursor
s.nextBy(6) // consume the opening dollar and the function name
for {
ch := s.peek()
if ch == '$' && isDollarQuotedFunction(s.peekBy(6)) {
break
}
if isEOF(ch) {
return Token{ERROR, s.src[s.start:s.cursor]}
}
s.next()
}
s.nextBy(6) // consume the closing dollar quoted function
return Token{DOLLAR_QUOTED_FUNCTION, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanDollarQuotedString() Token {
s.start = s.cursor
s.next() // consume the dollar sign
dollars := 1
for {
ch := s.peek()
if ch == '$' {
// keep track of the number of dollar signs we've seen
// a valid dollar quoted string is either $$string$$ or $tag$string$tag$
// we technically should see 4 dollar signs
// this is a bit of a hack because we don't want to keep track of the tag
// but it's good enough for our purposes of tokenization and obfuscation
dollars += 1
}
if isEOF(ch) {
break
ch := s.next() // consume the dollar sign
tagStart := s.cursor

for s.cursor < len(s.src) && ch != '$' {
ch = s.next()
}
s.next() // consume the closing dollar sign of the tag
tag := s.src[tagStart-1 : s.cursor] // include the opening and closing dollar sign e.g. $tag$

for s.cursor < len(s.src) {
if s.matchAt([]rune(tag)) {
s.nextBy(len(tag)) // consume the closing tag
if tag == "$func$" {
return Token{DOLLAR_QUOTED_FUNCTION, s.src[s.start:s.cursor]}
}
return Token{DOLLAR_QUOTED_STRING, s.src[s.start:s.cursor]}
}
s.next()
}

if dollars != 4 {
return Token{UNKNOWN, s.src[s.start:s.cursor]}
}
return Token{DOLLAR_QUOTED_STRING, s.src[s.start:s.cursor]}
return Token{ERROR, s.src[s.start:s.cursor]}
}

func (s *Lexer) scanNumberedParameter() Token {
s.start = s.cursor
s.next() // consume the dollar sign
ch := s.nextBy(2) // consume the dollar sign and the number
for {
ch := s.peek()
if !isDigit(ch) {
break
}
s.next()
ch = s.next()
}
return Token{NUMBERED_PARAMETER, s.src[s.start:s.cursor]}
}
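
The scanDollarQuotedString rewrite above is the most significant change in this file: instead of counting dollar signs (the old "expect exactly 4" heuristic), it captures the opening tag verbatim, then walks forward until matchAt finds the identical closing tag, so arbitrary tags work and $func$ bodies keep their dedicated token type. A small demonstration (import path assumed to be the published module):

package main

import (
	"fmt"

	sqllexer "github.com/DataDog/go-sqllexer"
)

func main() {
	for _, input := range []string{
		"$$plain$$",                           // anonymous tag
		"$tag$body with a stray $ sign$tag$",  // custom tag; an inner $ is no longer miscounted
		"$func$BEGIN RETURN 1; END$func$",     // still classified DOLLAR_QUOTED_FUNCTION
	} {
		lexer := sqllexer.New(input)
		for _, token := range lexer.ScanAll() {
			fmt.Printf("%d %q\n", token.Type, token.Value)
		}
	}
}
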
98 changes: 98 additions & 0 deletions sqllexer_bench_test.go
@@ -0,0 +1,98 @@
package sqllexer

import (
"fmt"
"strconv"
"testing"
)

func BenchmarkLexer(b *testing.B) {
// LargeQuery is sourced from https://stackoverflow.com/questions/12607667/issues-with-a-very-large-sql-query/12711494
var LargeQuery = `SELECT '%c%' as Chapter,
(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status IN ('new','assigned') ) AS 'New',
(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='document_interface' ) AS 'Document\
Interface',
(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='interface_development' ) AS 'Inter\
face Development',
(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='interface_check' ) AS 'Interface C\
heck',
(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='document_routine' ) AS 'Document R\
outine',
(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='full_development' ) AS 'Full Devel\
opment',
(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='peer_review_1' ) AS 'Peer Review O\
ne',
(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%'AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='peer_review_2' ) AS 'Peer Review Tw\
o',
(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='qa' ) AS 'QA',
(SELECT count(ticket.id) AS Matches FROM engine.ticket INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%'AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine' AND ticket.status='closed' ) AS 'Closed',
count(id) AS Total,
ticket.id AS _id
FROM engine.ticket
INNER JOIN engine.ticket_custom ON ticket.id = ticket_custom.ticket
WHERE ticket_custom.name='chapter' AND ticket_custom.value LIKE '%c%' AND type='New material' AND milestone='1.1.12' AND component NOT LIKE 'internal_engine'`

// ComplexQuery is sourced from https://www.ibm.com/support/knowledgecenter/SSCRJT_6.0.0/com.ibm.swg.im.bigsql.doc/doc/tut_bsql_uc_complex_query.html
var ComplexQuery = `WITH
sales AS
(SELECT sf.*
FROM gosalesdw.sls_order_method_dim AS md,
gosalesdw.sls_product_dim AS pd,
gosalesdw.emp_employee_dim AS ed,
gosalesdw.sls_sales_fact AS sf
WHERE pd.product_key = sf.product_key
AND pd.product_number > 10000
AND pd.base_product_key > 30
AND md.order_method_key = sf.order_method_key
AND md.order_method_code > 5
AND ed.employee_key = sf.employee_key
AND ed.manager_code1 > 20),
inventory AS
(SELECT if.*
FROM gosalesdw.go_branch_dim AS bd,
gosalesdw.dist_inventory_fact AS if
WHERE if.branch_key = bd.branch_key
AND bd.branch_code > 20)
SELECT sales.product_key AS PROD_KEY,
SUM(CAST (inventory.quantity_shipped AS BIGINT)) AS INV_SHIPPED,
SUM(CAST (sales.quantity AS BIGINT)) AS PROD_QUANTITY,
RANK() OVER ( ORDER BY SUM(CAST (sales.quantity AS BIGINT)) DESC) AS PROD_RANK
FROM sales, inventory
WHERE sales.product_key = inventory.product_key
GROUP BY sales.product_key;
`

var superLargeQuery = "select top ? percent IdTrebEmpresa, CodCli, NOMEMP, Baixa, CASE WHEN IdCentreTreball IS ? THEN ? ELSE CONVERT ( VARCHAR ( ? ) IdCentreTreball ) END, CASE WHEN NOMESTAB IS ? THEN ? ELSE NOMESTAB END, TIPUS, CASE WHEN IdLloc IS ? THEN ? ELSE CONVERT ( VARCHAR ( ? ) IdLloc ) END, CASE WHEN NomLlocComplert IS ? THEN ? ELSE NomLlocComplert END, CASE WHEN DesLloc IS ? THEN ? ELSE DesLloc END, IdLlocTreballUnic From ( SELECT ?, dbo.Treb_Empresa.IdTrebEmpresa, dbo.Treb_Empresa.IdTreballador, dbo.Treb_Empresa.CodCli, dbo.Clients.NOMEMP, dbo.Treb_Empresa.Baixa, dbo.Treb_Empresa.IdCentreTreball, dbo.Cli_Establiments.NOMESTAB, ?, ?, dbo.Treb_Empresa.DataInici, dbo.Treb_Empresa.DataFi, CASE WHEN dbo.Treb_Empresa.DesLloc IS ? THEN ? ELSE dbo.Treb_Empresa.DesLloc END DesLloc, dbo.Treb_Empresa.IdLlocTreballUnic FROM dbo.Clients WITH ( NOLOCK ) INNER JOIN dbo.Treb_Empresa WITH ( NOLOCK ) ON dbo.Clients.CODCLI = dbo.Treb_Empresa.CodCli LEFT OUTER JOIN dbo.Cli_Establiments WITH ( NOLOCK ) ON dbo.Cli_Establiments.Id_ESTAB_CLI = dbo.Treb_Empresa.IdCentreTreball AND dbo.Cli_Establiments.CODCLI = dbo.Treb_Empresa.CodCli WHERE dbo.Treb_Empresa.IdTreballador = ? AND Treb_Empresa.IdTecEIRLLlocTreball IS ? AND IdMedEIRLLlocTreball IS ? AND IdLlocTreballTemporal IS ? UNION ALL SELECT ?, dbo.Treb_Empresa.IdTrebEmpresa, dbo.Treb_Empresa.IdTreballador, dbo.Treb_Empresa.CodCli, dbo.Clients.NOMEMP, dbo.Treb_Empresa.Baixa, dbo.Treb_Empresa.IdCentreTreball, dbo.Cli_Establiments.NOMESTAB, dbo.Treb_Empresa.IdTecEIRLLlocTreball, dbo.fn_NomLlocComposat ( dbo.Treb_Empresa.IdTecEIRLLlocTreball ), dbo.Treb_Empresa.DataInici, dbo.Treb_Empresa.DataFi, CASE WHEN dbo.Treb_Empresa.DesLloc IS ? THEN ? ELSE dbo.Treb_Empresa.DesLloc END DesLloc, dbo.Treb_Empresa.IdLlocTreballUnic FROM dbo.Clients WITH ( NOLOCK ) INNER JOIN dbo.Treb_Empresa WITH ( NOLOCK ) ON dbo.Clients.CODCLI = dbo.Treb_Empresa.CodCli LEFT OUTER JOIN dbo.Cli_Establiments WITH ( NOLOCK ) ON dbo.Cli_Establiments.Id_ESTAB_CLI = dbo.Treb_Empresa.IdCentreTreball AND dbo.Cli_Establiments.CODCLI = dbo.Treb_Empresa.CodCli WHERE ( dbo.Treb_Empresa.IdTreballador = ? ) AND ( NOT ( dbo.Treb_Empresa.IdTecEIRLLlocTreball IS ? ) ) UNION ALL SELECT ?, dbo.Treb_Empresa.IdTrebEmpresa, dbo.Treb_Empresa.IdTreballador, dbo.Treb_Empresa.CodCli, dbo.Clients.NOMEMP, dbo.Treb_Empresa.Baixa, dbo.Treb_Empresa.IdCentreTreball, dbo.Cli_Establiments.NOMESTAB, dbo.Treb_Empresa.IdMedEIRLLlocTreball, dbo.fn_NomMedEIRLLlocComposat ( dbo.Treb_Empresa.IdMedEIRLLlocTreball ), dbo.Treb_Empresa.DataInici, dbo.Treb_Empresa.DataFi, CASE WHEN dbo.Treb_Empresa.DesLloc IS ? THEN ? ELSE dbo.Treb_Empresa.DesLloc END DesLloc, dbo.Treb_Empresa.IdLlocTreballUnic FROM dbo.Clients WITH ( NOLOCK ) INNER JOIN dbo.Treb_Empresa WITH ( NOLOCK ) ON dbo.Clients.CODCLI = dbo.Treb_Empresa.CodCli LEFT OUTER JOIN dbo.Cli_Establiments WITH ( NOLOCK ) ON dbo.Cli_Establiments.Id_ESTAB_CLI = dbo.Treb_Empresa.IdCentreTreball AND dbo.Cli_Establiments.CODCLI = dbo.Treb_Empresa.CodCli WHERE ( dbo.Treb_Empresa.IdTreballador = ? ) AND ( Treb_Empresa.IdTecEIRLLlocTreball IS ? ) AND ( NOT ( dbo.Treb_Empresa.IdMedEIRLLlocTreball IS ? 
) ) UNION ALL SELECT ?, dbo.Treb_Empresa.IdTrebEmpresa, dbo.Treb_Empresa.IdTreballador, dbo.Treb_Empresa.CodCli, dbo.Clients.NOMEMP, dbo.Treb_Empresa.Baixa, dbo.Treb_Empresa.IdCentreTreball, dbo.Cli_Establiments.NOMESTAB, dbo.Treb_Empresa.IdLlocTreballTemporal, dbo.Lloc_Treball_Temporal.NomLlocTreball, dbo.Treb_Empresa.DataInici, dbo.Treb_Empresa.DataFi, CASE WHEN dbo.Treb_Empresa.DesLloc IS ? THEN ? ELSE dbo.Treb_Empresa.DesLloc END DesLloc, dbo.Treb_Empresa.IdLlocTreballUnic FROM dbo.Clients WITH ( NOLOCK ) INNER JOIN dbo.Treb_Empresa WITH ( NOLOCK ) ON dbo.Clients.CODCLI = dbo.Treb_Empresa.CodCli INNER JOIN dbo.Lloc_Treball_Temporal WITH ( NOLOCK ) ON dbo.Treb_Empresa.IdLlocTreballTemporal = dbo.Lloc_Treball_Temporal.IdLlocTreballTemporal LEFT OUTER JOIN dbo.Cli_Establiments WITH ( NOLOCK ) ON dbo.Cli_Establiments.Id_ESTAB_CLI = dbo.Treb_Empresa.IdCentreTreball AND dbo.Cli_Establiments.CODCLI = dbo.Treb_Empresa.CodCli WHERE dbo.Treb_Empresa.IdTreballador = ? AND Treb_Empresa.IdTecEIRLLlocTreball IS ? AND IdMedEIRLLlocTreball IS ? ) Where ? = %d"

benchmarks := []struct {
name string
query string
}{
{"Escaping", `INSERT INTO delayed_jobs (attempts, created_at, failed_at, handler, last_error, locked_at, locked_by, priority, queue, run_at, updated_at) VALUES (0, '2016-12-04 17:09:59', NULL, '--- !ruby/object:Delayed::PerformableMethod\nobject: !ruby/object:Item\n store:\n - a simple string\n - an \'escaped \' string\n - another \'escaped\' string\n - 42\n string: a string with many \\\\\'escapes\\\\\'\nmethod_name: :show_store\nargs: []\n', NULL, NULL, NULL, 0, NULL, '2016-12-04 17:09:59', '2016-12-04 17:09:59')`},
{"Grouping", `INSERT INTO delayed_jobs (created_at, failed_at, handler) VALUES (0, '2016-12-04 17:09:59', NULL), (0, '2016-12-04 17:09:59', NULL), (0, '2016-12-04 17:09:59', NULL), (0, '2016-12-04 17:09:59', NULL)`},
{"Large", LargeQuery},
{"Complex", ComplexQuery},
{"SuperLarge", fmt.Sprintf(superLargeQuery, 1)},
}

for _, bm := range benchmarks {
b.Run(bm.name+"/"+strconv.Itoa(len(bm.query)), func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
lexer := New(bm.query)
lexer.ScanAll()
}
})
}
}
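
A note on reading these benchmarks: every case calls b.ReportAllocs, so a plain go test -bench=BenchmarkLexer run reports B/op and allocs/op alongside ns/op, which is where the value-token and regex-removal changes in this PR should show up. The sub-benchmark names append the query length in bytes (strconv.Itoa(len(bm.query))), making it easy to compare throughput across input sizes.
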
44 changes: 43 additions & 1 deletion sqllexer_test.go
@@ -273,6 +273,27 @@ func TestLexer(t *testing.T) {
{DOLLAR_QUOTED_STRING, "$tag$test$tag$"},
},
},
{
name: "dollar quoted string",
input: "SELECT * FROM users where id = $$test$$",
expected: []Token{
{IDENT, "SELECT"},
{WS, " "},
{WILDCARD, "*"},
{WS, " "},
{IDENT, "FROM"},
{WS, " "},
{IDENT, "users"},
{WS, " "},
{IDENT, "where"},
{WS, " "},
{IDENT, "id"},
{WS, " "},
{OPERATOR, "="},
{WS, " "},
{DOLLAR_QUOTED_STRING, "$$test$$"},
},
},
{
name: "numbered parameter",
input: "SELECT * FROM users where id = $1",
@@ -407,13 +428,34 @@ func TestLexer(t *testing.T) {
{IDENT, `"public"."users table"`},
},
},
{
name: "select with escaped string",
input: "SELECT * FROM users where id = 'j\\'s'",
expected: []Token{
{IDENT, "SELECT"},
{WS, " "},
{WILDCARD, "*"},
{WS, " "},
{IDENT, "FROM"},
{WS, " "},
{IDENT, "users"},
{WS, " "},
{IDENT, "where"},
{WS, " "},
{IDENT, "id"},
{WS, " "},
{OPERATOR, "="},
{WS, " "},
{STRING, "'j\\'s'"},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
lexer := New(tt.input)
tokens := lexer.ScanAll()
assert.Equal(t, tokens, tt.expected)
assert.Equal(t, tt.expected, tokens)
})
}
}
36 changes: 21 additions & 15 deletions sqllexer_utils.go
@@ -2,6 +2,7 @@ package sqllexer

import (
"regexp"
"strings"
"unicode"
)

@@ -41,8 +42,6 @@ var tableIndicators = map[string]bool{
"TABLE": true,
}

var digitsRegex = regexp.MustCompile(`\d+`)

var keywordsRegex = regexp.MustCompile(`(?i)^(SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP|GRANT|REVOKE|ADD|ALL|AND|ANY|AS|ASC|BEGIN|BETWEEN|BY|CASE|CHECK|COLUMN|COMMIT|CONSTRAINT|DATABASE|DECLARE|DEFAULT|DESC|DISTINCT|ELSE|END|EXEC|EXISTS|FOREIGN|FROM|GROUP|HAVING|IN|INDEX|INNER|INTO|IS|JOIN|KEY|LEFT|LIKE|LIMIT|NOT|ON|OR|ORDER|OUTER|PRIMARY|PROCEDURE|REPLACE|RETURNS|RIGHT|ROLLBACK|ROWNUM|SET|SOME|TABLE|TOP|TRUNCATE|UNION|UNIQUE|USE|VALUES|VIEW|WHERE|CUBE|ROLLUP|LITERAL|WINDOW|VACCUM|ANALYZE|ILIKE|USING|ASSERTION|DOMAIN|CLUSTER|COPY|EXPLAIN|PLPGSQL|TRIGGER|TEMPORARY|UNLOGGED|RECURSIVE|RETURNING)$`)

var groupableRegex = regexp.MustCompile(`(\()\s*\?(?:\s*,\s*\?\s*)*\s*(\))|(\[)\s*\?(?:\s*,\s*\?\s*)*\s*(\])`)
@@ -101,10 +100,6 @@ func isEOF(ch rune) bool {
return ch == 0
}

func isDollarQuotedFunction(chs []rune) bool {
return string(chs) == "$func$"
}

func isCommand(ident string) bool {
_, ok := Commands[ident]
return ok
@@ -115,14 +110,25 @@ func isTableIndicator(ident string) bool {
return ok
}

func digitVal(ch rune) int {
switch {
case '0' <= ch && ch <= '9':
return int(ch) - '0'
case 'a' <= ch && ch <= 'f':
return int(ch) - 'a' + 10
case 'A' <= ch && ch <= 'F':
return int(ch) - 'A' + 10
func isSQLKeyword(token Token) bool {
return token.Type == IDENT && keywordsRegex.MatchString(token.Value)
}

func replaceDigits(input string, placeholder string) string {
var builder strings.Builder

i := 0
for i < len(input) {
if isDigit(rune(input[i])) {
builder.WriteString(placeholder)
for i < len(input) && isDigit(rune(input[i])) {
i++
}
} else {
builder.WriteByte(input[i])
i++
}
}
return 16 // larger than any legal digit val

return builder.String()
}
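
The new replaceDigits above supersedes digitsRegex.ReplaceAllString(input, "?"): a single pass that collapses each run of consecutive digits into one placeholder and copies everything else through, with no regexp machinery. A standalone copy for illustration (using unicode.IsDigit in place of the package's isDigit helper):

package main

import (
	"fmt"
	"strings"
	"unicode"
)

func replaceDigits(input, placeholder string) string {
	var builder strings.Builder
	i := 0
	for i < len(input) {
		if unicode.IsDigit(rune(input[i])) {
			builder.WriteString(placeholder)
			// Skip the whole digit run; it maps to one placeholder.
			for i < len(input) && unicode.IsDigit(rune(input[i])) {
				i++
			}
		} else {
			builder.WriteByte(input[i])
			i++
		}
	}
	return builder.String()
}

func main() {
	fmt.Println(replaceDigits("users_2023_backup01", "?")) // users_?_backup?
}
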
