Skip to content

Commit

Permalink
Merge pull request #9 from DataDog/zhengda.lu/lexer-dbms
Browse files Browse the repository at this point in the history
dbms should be a lexer config
  • Loading branch information
lu-zhengda authored Sep 11, 2023
2 parents dcfea70 + a80bbc9 commit 70764ab
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 65 deletions.
54 changes: 29 additions & 25 deletions normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,51 +5,54 @@ import (
)

type normalizerConfig struct {
DBMS string `json:"dbms"`

// CollectTables specifies whether the normalizer should also extract the table names that a query addresses
CollectTables bool `json:"collect_tables"`
CollectTables bool

// CollectCommands specifies whether the normalizer should extract and return commands as SQL metadata
CollectCommands bool `json:"collect_commands"`
CollectCommands bool

// CollectComments specifies whether the normalizer should extract and return comments as SQL metadata
CollectComments bool `json:"collect_comments"`
CollectComments bool

// KeepSQLAlias reports whether SQL aliases ("AS") should be kept instead of being truncated.
KeepSQLAlias bool `json:"keep_sql_alias"`
}
KeepSQLAlias bool

func WithDBMS(dbms string) func(*normalizerConfig) {
return func(c *normalizerConfig) {
c.DBMS = dbms
}
// UppercaseKeywords reports whether SQL keywords should be uppercased.
UppercaseKeywords bool
}

func WithCollectTables(collectTables bool) func(*normalizerConfig) {
type normalizerOption func(*normalizerConfig)

// WithCollectTables returns a normalizer option that stores collectTables in
// the CollectTables field of the normalizer configuration it is applied to.
func WithCollectTables(collectTables bool) normalizerOption {
	return normalizerOption(func(cfg *normalizerConfig) {
		cfg.CollectTables = collectTables
	})
}

func WithCollectCommands(collectCommands bool) func(*normalizerConfig) {
// WithCollectCommands returns a normalizer option that stores collectCommands
// in the CollectCommands field of the normalizer configuration it is applied to.
func WithCollectCommands(collectCommands bool) normalizerOption {
	return normalizerOption(func(cfg *normalizerConfig) {
		cfg.CollectCommands = collectCommands
	})
}

func WithCollectComments(collectComments bool) func(*normalizerConfig) {
// WithCollectComments returns a normalizer option that stores collectComments
// in the CollectComments field of the normalizer configuration it is applied to.
func WithCollectComments(collectComments bool) normalizerOption {
	return normalizerOption(func(cfg *normalizerConfig) {
		cfg.CollectComments = collectComments
	})
}

func WithKeepSQLAlias(keepSQLAlias bool) func(*normalizerConfig) {
// WithKeepSQLAlias returns a normalizer option that stores keepSQLAlias in
// the KeepSQLAlias field of the normalizer configuration it is applied to.
func WithKeepSQLAlias(keepSQLAlias bool) normalizerOption {
	return normalizerOption(func(cfg *normalizerConfig) {
		cfg.KeepSQLAlias = keepSQLAlias
	})
}

// WithUppercaseKeywords returns a normalizer option that stores
// uppercaseKeywords in the UppercaseKeywords field of the normalizer
// configuration it is applied to.
func WithUppercaseKeywords(uppercaseKeywords bool) normalizerOption {
	return normalizerOption(func(cfg *normalizerConfig) {
		cfg.UppercaseKeywords = uppercaseKeywords
	})
}

type StatementMetadata struct {
Tables []string
Comments []string
Expand All @@ -60,7 +63,7 @@ type Normalizer struct {
config *normalizerConfig
}

func NewNormalizer(opts ...func(*normalizerConfig)) *Normalizer {
func NewNormalizer(opts ...normalizerOption) *Normalizer {
normalizer := Normalizer{
config: &normalizerConfig{},
}
Expand All @@ -80,8 +83,11 @@ const (
// Normalize takes an input SQL string and returns a normalized SQL string, a StatementMetadata struct, and an error.
// The normalizer collapses input SQL into a compact format, groups obfuscated values into a single placeholder,
// and collects metadata such as table names, comments, and commands.
func (n *Normalizer) Normalize(input string) (normalized string, info *StatementMetadata, err error) {
lexer := New(input)
func (n *Normalizer) Normalize(input string, lexerOpts ...lexerOption) (normalized string, info *StatementMetadata, err error) {
lexer := New(
input,
lexerOpts...,
)

var normalizedSQL string
var statementMetadata = &StatementMetadata{
Expand Down Expand Up @@ -110,7 +116,7 @@ func (n *Normalizer) Normalize(input string) (normalized string, info *Statement
}
}

normalizedSQL = normalizeSQL(token, lastToken, normalizedSQL)
normalizedSQL = normalizeSQL(token, lastToken, normalizedSQL, n.config)

// TODO: We rely on the WS token to determine if we should add a whitespace
// This is not ideal, as SQLs with slightly different formatting will NOT be normalized into single family
Expand All @@ -136,21 +142,19 @@ func (n *Normalizer) Normalize(input string) (normalized string, info *Statement
return strings.TrimSpace(normalizedSQL), statementMetadata, nil
}

func normalizeSQL(token Token, lastToken Token, statement string) string {
func normalizeSQL(token Token, lastToken Token, statement string, config *normalizerConfig) string {
if token.Type == WS || token.Type == COMMENT || token.Type == MULTILINE_COMMENT {
// We don't rely on the WS token to determine if we should add a whitespace
return statement
}

// determine if we should add a whitespace
statement = appendWhitespace(lastToken, token, statement)

// UPPER CASE SQL keywords
if isSQLKeyword(token) {
if isSQLKeyword(token) && config.UppercaseKeywords {
statement += strings.ToUpper(token.Value)
return statement
} else {
statement += token.Value
}
statement += token.Value

return statement
}
Expand Down
67 changes: 44 additions & 23 deletions normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestNormalizer(t *testing.T) {
},
{
input: "SELECT * FROM users WHERE id IN (?, ?) and name IN ARRAY[?, ?]",
expected: "SELECT * FROM users WHERE id IN ( ? ) AND name IN ARRAY [ ? ]",
expected: "SELECT * FROM users WHERE id IN ( ? ) and name IN ARRAY [ ? ]",
statementMetadata: StatementMetadata{
Tables: []string{"users"},
Comments: []string{},
Expand All @@ -50,7 +50,7 @@ func TestNormalizer(t *testing.T) {
JOIN vs?.host_alias ha on ha.host_id = h.id
WHERE ha.org_id = ? AND ha.name = ANY ( ?, ? )
`,
expected: "SELECT h.id, h.org_id, h.name, ha.name, h.created FROM vs?.host h JOIN vs?.host_alias ha ON ha.host_id = h.id WHERE ha.org_id = ? AND ha.name = ANY ( ? )",
expected: "SELECT h.id, h.org_id, h.name, ha.name, h.created FROM vs?.host h JOIN vs?.host_alias ha on ha.host_id = h.id WHERE ha.org_id = ? AND ha.name = ANY ( ? )",
statementMetadata: StatementMetadata{
Tables: []string{"vs?.host", "vs?.host_alias"},
Comments: []string{},
Expand Down Expand Up @@ -139,7 +139,7 @@ multiline comment */
Delete FROM user? WHERE id = ?;
END
`,
expected: "CREATE PROCEDURE test_procedure ( ) BEGIN SELECT * FROM users WHERE id = ? ; UPDATE test_users SET name = ? WHERE id = ? ; DELETE FROM user? WHERE id = ? ; END",
expected: "CREATE PROCEDURE test_procedure ( ) BEGIN SELECT * FROM users WHERE id = ? ; Update test_users set name = ? WHERE id = ? ; Delete FROM user? WHERE id = ? ; END",
statementMetadata: StatementMetadata{
Tables: []string{"users", "test_users", "user?"},
Comments: []string{},
Expand Down Expand Up @@ -280,6 +280,24 @@ multiline comment */
Commands: []string{"SELECT"},
},
},
{
input: "SELECT d.id, d.uuid, d.org_id, d.creator_id, d.updater_id, d.monitor_id, d.parent_id, d.original_parent_id, d.scope, d.start_dt, d.end_dt, d.canceled_dt, d.active, d.disabled, d.created, d.modified, d.message, d.monitor_tags, d.recurrence, d.mute_first_recovery_notification, d.scope_v2_query, d.scope_v2 FROM monitor_downtime d, org o WHERE o.id = d.org_id AND d.modified >= ? AND o.partition_num = ANY (?, ?, ?)",
expected: "SELECT d.id, d.uuid, d.org_id, d.creator_id, d.updater_id, d.monitor_id, d.parent_id, d.original_parent_id, d.scope, d.start_dt, d.end_dt, d.canceled_dt, d.active, d.disabled, d.created, d.modified, d.message, d.monitor_tags, d.recurrence, d.mute_first_recovery_notification, d.scope_v2_query, d.scope_v2 FROM monitor_downtime d, org o WHERE o.id = d.org_id AND d.modified >= ? AND o.partition_num = ANY ( ? )",
statementMetadata: StatementMetadata{
Tables: []string{"monitor_downtime"},
Comments: []string{},
Commands: []string{"SELECT"},
},
},
{
input: "SELECT set_host_tags_bigint (? ARRAY[?, ?, ?])",
expected: "SELECT set_host_tags_bigint ( ? ARRAY [ ? ] )",
statementMetadata: StatementMetadata{
Tables: []string{},
Comments: []string{},
Commands: []string{"SELECT"},
},
},
}

normalizer := NewNormalizer(
Expand Down Expand Up @@ -325,7 +343,7 @@ func TestNormalizerNotCollectMetadata(t *testing.T) {
},
{
input: "SELECT id as ID, name as Name FROM users WHERE id IN (?, ?)",
expected: "SELECT id AS ID, name AS Name FROM users WHERE id IN ( ? )",
expected: "SELECT id as ID, name as Name FROM users WHERE id IN ( ? )",
statementMetadata: StatementMetadata{
Tables: []string{},
Comments: []string{},
Expand Down Expand Up @@ -362,41 +380,44 @@ func TestNormalizerNotCollectMetadata(t *testing.T) {

func TestNormalizerFormatting(t *testing.T) {
tests := []struct {
queries []string
expected string
queries []string
expected string
uppercaseKeywords bool
}{
{
queries: []string{
"SELECT id,name, address FROM users where id = ?",
"select id, name, address FROM users where id = ?",
"select id as ID, name as Name, address FROM users where id = ?",
"SELECT id, name, address FROM users where id = ?",
"SELECT id as ID, name as Name, address FROM users where id = ?",
},
expected: "SELECT id, name, address FROM users WHERE id = ?",
expected: "SELECT id, name, address FROM users where id = ?",
},
{
queries: []string{
"SELECT id,name, address FROM users where id IN (?, ?,?, ?)",
"select id, name, address FROM users where id IN ( ? )",
"select id, name, address FROM users where id IN ( ? )",
"select id, name, address FROM users where id IN (?,?,?)",
"SELECT id, name, address FROM users where id IN ( ? )",
"SELECT id, name, address FROM users where id IN ( ? )",
"SELECT id, name, address FROM users where id IN (?,?,?)",
},
expected: "SELECT id, name, address FROM users WHERE id IN ( ? )",
expected: "SELECT id, name, address FROM users where id IN ( ? )",
},
{
queries: []string{
"SELECT * FROM discount where description LIKE ?",
"select * from discount where description LIKE ?",
"select * from discount where description like ?",
},
expected: "SELECT * FROM discount WHERE description LIKE ?",
expected: "SELECT * FROM discount WHERE description LIKE ?",
uppercaseKeywords: true,
},
}

normalizer := NewNormalizer(
WithCollectComments(false),
)
for _, test := range tests {
t.Run("", func(t *testing.T) {
normalizer := NewNormalizer(
WithCollectComments(false),
WithUppercaseKeywords(test.uppercaseKeywords),
)
for _, query := range test.queries {
got, _, err := normalizer.Normalize(query)
assert.NoError(t, err)
Expand Down Expand Up @@ -491,7 +512,7 @@ func TestNormalizeDeobfuscatedSQL(t *testing.T) {
}{
{
input: "SELECT id,name, address FROM users where id = 1",
expected: "SELECT id, name, address FROM users WHERE id = 1",
expected: "SELECT id, name, address FROM users where id = 1",
statementMetadata: StatementMetadata{
Tables: []string{"users"},
Comments: []string{},
Expand All @@ -506,7 +527,7 @@ func TestNormalizeDeobfuscatedSQL(t *testing.T) {
},
{
input: "SELECT id,name, address FROM users where id IN (1, 2, 3, 4)",
expected: "SELECT id, name, address FROM users WHERE id IN ( 1, 2, 3, 4 )",
expected: "SELECT id, name, address FROM users where id IN ( 1, 2, 3, 4 )",
statementMetadata: StatementMetadata{
Tables: []string{"users"},
Comments: []string{},
Expand All @@ -523,7 +544,7 @@ func TestNormalizeDeobfuscatedSQL(t *testing.T) {
input: `
/* test comment */
SELECT id,name, address FROM users where id IN (1, 2, 3, 4)`,
expected: "SELECT id, name, address FROM users WHERE id IN ( 1, 2, 3, 4 )",
expected: "SELECT id, name, address FROM users where id IN ( 1, 2, 3, 4 )",
statementMetadata: StatementMetadata{
Tables: []string{"users"},
Comments: []string{"/* test comment */"},
Expand All @@ -541,7 +562,7 @@ func TestNormalizeDeobfuscatedSQL(t *testing.T) {
input: `
/* test comment */
SELECT id,name, address FROM users where id IN (1, 2, 3, 4)`,
expected: "SELECT id, name, address FROM users WHERE id IN ( 1, 2, 3, 4 )",
expected: "SELECT id, name, address FROM users where id IN ( 1, 2, 3, 4 )",
statementMetadata: StatementMetadata{
Tables: []string{"users"},
Comments: []string{},
Expand Down Expand Up @@ -588,7 +609,7 @@ func TestNormalizeDeobfuscatedSQL(t *testing.T) {
},
{
input: "SELECT u.id as ID, u.name as Name FROM users as u WHERE u.id = '123'",
expected: "SELECT u.id AS ID, u.name AS Name FROM users AS u WHERE u.id = '123'",
expected: "SELECT u.id as ID, u.name as Name FROM users as u WHERE u.id = '123'",
statementMetadata: StatementMetadata{
Tables: []string{"users"},
Comments: []string{},
Expand Down Expand Up @@ -699,6 +720,6 @@ func ExampleNormalizer() {

fmt.Println(normalizedSQL)
fmt.Println(statementMetadata)
// Output: SELECT * FROM users WHERE id IN ( ? )
// Output: SELECT * FROM users WHERE id in ( ? )
// &{[users] [/* this is a comment */] [SELECT]}
}
17 changes: 11 additions & 6 deletions obfuscator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ type obfuscatorConfig struct {
DollarQuotedFunc bool
}

func WithReplaceDigits(replaceDigits bool) func(*obfuscatorConfig) {
type obfuscatorOption func(*obfuscatorConfig)

// WithReplaceDigits returns an obfuscator option that stores replaceDigits in
// the ReplaceDigits field of the obfuscator configuration it is applied to.
func WithReplaceDigits(replaceDigits bool) obfuscatorOption {
	return obfuscatorOption(func(cfg *obfuscatorConfig) {
		cfg.ReplaceDigits = replaceDigits
	})
}

func WithDollarQuotedFunc(dollarQuotedFunc bool) func(*obfuscatorConfig) {
func WithDollarQuotedFunc(dollarQuotedFunc bool) obfuscatorOption {
return func(c *obfuscatorConfig) {
c.DollarQuotedFunc = dollarQuotedFunc
}
Expand All @@ -25,7 +27,7 @@ type Obfuscator struct {
config *obfuscatorConfig
}

func NewObfuscator(opts ...func(*obfuscatorConfig)) *Obfuscator {
func NewObfuscator(opts ...obfuscatorOption) *Obfuscator {
obfuscator := &Obfuscator{
config: &obfuscatorConfig{},
}
Expand All @@ -44,10 +46,13 @@ const (

// Obfuscate takes an input SQL string and returns an obfuscated SQL string.
// The obfuscator replaces all literal values with a single placeholder
func (o *Obfuscator) Obfuscate(input string) string {
func (o *Obfuscator) Obfuscate(input string, lexerOpts ...lexerOption) string {
var obfuscatedSQL strings.Builder

lexer := New(input)
lexer := New(
input,
lexerOpts...,
)
for _, token := range lexer.ScanAll() {
switch token.Type {
case NUMBER:
Expand All @@ -57,7 +62,7 @@ func (o *Obfuscator) Obfuscate(input string) string {
// obfuscate the content of dollar quoted function
quotedFunc := token.Value[6 : len(token.Value)-6] // remove the $func$ prefix and suffix
obfuscatedSQL.WriteString("$func$")
obfuscatedSQL.WriteString(o.Obfuscate(quotedFunc))
obfuscatedSQL.WriteString(o.Obfuscate(quotedFunc, lexerOpts...))
obfuscatedSQL.WriteString("$func$")
} else {
obfuscatedSQL.WriteString(StringPlaceholder)
Expand Down
Loading

0 comments on commit 70764ab

Please sign in to comment.