From 5e11e41b9056c7fedcadbae762fb5bbf5921c4d3 Mon Sep 17 00:00:00 2001
From: Zhengda Lu
Date: Mon, 6 Nov 2023 13:37:12 -0500
Subject: [PATCH] Support tokenize SQL function (#24)

* Support tokenize SQL function

* add unit test to verify tokenize function
---
 normalizer.go      |  4 ++--
 normalizer_test.go |  4 ++--
 sqllexer.go        |  6 +++++-
 sqllexer_test.go   | 16 ++++++++++++++++
 4 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/normalizer.go b/normalizer.go
index fe9de55..4b2d7c1 100644
--- a/normalizer.go
+++ b/normalizer.go
@@ -153,7 +153,7 @@ func (n *Normalizer) collectMetadata(token *Token, lastToken *Token, statementMe
 	if n.config.CollectComments && (token.Type == COMMENT || token.Type == MULTILINE_COMMENT) {
 		// Collect comments
 		statementMetadata.Comments = append(statementMetadata.Comments, token.Value)
-	} else if token.Type == IDENT || token.Type == QUOTED_IDENT {
+	} else if token.Type == IDENT || token.Type == QUOTED_IDENT || token.Type == FUNCTION {
 		tokenVal := token.Value
 		if token.Type == QUOTED_IDENT {
 			// remove all open and close quotes
@@ -265,7 +265,7 @@ func (n *Normalizer) isObfuscatedValueGroupable(token *Token, lastToken *Token,
 
 func (n *Normalizer) appendWhitespace(lastToken *Token, token *Token, normalizedSQLBuilder *strings.Builder) {
 	// do not add a space between parentheses if RemoveSpaceBetweenParentheses is true
-	if n.config.RemoveSpaceBetweenParentheses && (lastToken.Value == "(" || lastToken.Value == "[") {
+	if n.config.RemoveSpaceBetweenParentheses && (lastToken.Type == FUNCTION || lastToken.Value == "(" || lastToken.Value == "[") {
 		return
 	}
 
diff --git a/normalizer_test.go b/normalizer_test.go
index 89bae84..e28a347 100644
--- a/normalizer_test.go
+++ b/normalizer_test.go
@@ -946,10 +946,10 @@ func TestNormalizerWithoutSpaceBetweenParentheses(t *testing.T) {
 	}{
 		{
 			input:    "SELECT count(*) FROM users",
-			expected: "SELECT count (*) FROM users",
+			expected: "SELECT count(*) FROM users",
 		},
 		{
-			input:    "SELECT * FROM users WHERE id IN(?, ?)",
+			input:    "SELECT * FROM users WHERE id IN (?, ?)",
 			expected: "SELECT * FROM users WHERE id IN (?)",
 		},
 		{
diff --git a/sqllexer.go b/sqllexer.go
index c164bb4..3154d2c 100644
--- a/sqllexer.go
+++ b/sqllexer.go
@@ -22,6 +22,7 @@ const (
 	DOLLAR_QUOTED_STRING // dollar quoted string
 	POSITIONAL_PARAMETER // numbered parameter
 	BIND_PARAMETER       // bind parameter
+	FUNCTION             // function
 	UNKNOWN              // unknown token
 )
 
@@ -305,7 +306,10 @@ func (s *Lexer) scanIdentifier(ch rune) Token {
 	for isLetter(ch) || isDigit(ch) || ch == '.' || ch == '?' || ch == '$' || ch == '#' {
 		ch = s.nextBy(utf8.RuneLen(ch))
 	}
-	// return the token as uppercase so that we can do case insensitive matching
+	if ch == '(' {
+		// if the identifier is followed by a (, then it's a function
+		return Token{FUNCTION, s.src[s.start:s.cursor]}
+	}
 	return Token{IDENT, s.src[s.start:s.cursor]}
 }
 
diff --git a/sqllexer_test.go b/sqllexer_test.go
index ea0b11e..6827347 100644
--- a/sqllexer_test.go
+++ b/sqllexer_test.go
@@ -558,6 +558,22 @@ func TestLexer(t *testing.T) {
 			},
 			lexerOpts: []lexerOption{WithDBMS(DBMSMySQL)},
 		},
+		{
+			name:  "Tokenize function",
+			input: "SELECT count(*) FROM users",
+			expected: []Token{
+				{IDENT, "SELECT"},
+				{WS, " "},
+				{FUNCTION, "count"},
+				{PUNCTUATION, "("},
+				{WILDCARD, "*"},
+				{PUNCTUATION, ")"},
+				{WS, " "},
+				{IDENT, "FROM"},
+				{WS, " "},
+				{IDENT, "users"},
+			},
+		},
 	}
 
 	for _, tt := range tests {
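
Usage sketch (illustration only, not part of the patch): with this change, an identifier that is immediately followed by "(" is emitted as a FUNCTION token instead of a plain IDENT, the normalizer collects it as table/identifier metadata, and RemoveSpaceBetweenParentheses no longer inserts a space after the function name. The snippet below assumes the package's exported New constructor and a ScanAll helper on Lexer as the entry points; if the public API differs, the same check applies to whatever token stream the lexer returns.

package main

import (
	"fmt"

	sqllexer "github.com/DataDog/go-sqllexer"
)

func main() {
	// "count" is immediately followed by "(", so the lexer should now emit it as FUNCTION.
	lexer := sqllexer.New("SELECT count(*) FROM users")
	for _, token := range lexer.ScanAll() { // ScanAll is assumed here; see the note above
		if token.Type == sqllexer.FUNCTION {
			fmt.Printf("function token: %q\n", token.Value) // expected to print: function token: "count"
		}
	}
}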