Setup test suite to verify different query flavors per dbms (#25)

lu-zhengda authored Nov 10, 2023
1 parent 5e11e41 commit 692fac0
Showing 91 changed files with 1,871 additions and 22 deletions.
142 changes: 142 additions & 0 deletions dbms_test.go
@@ -0,0 +1,142 @@
package sqllexer

import (
	"embed"
	"encoding/json"
	"path/filepath"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
)

//go:embed testdata/*
var testdata embed.FS

type output struct {
	Expected          string             `json:"expected"`
	ObfuscatorConfig  *obfuscatorConfig  `json:"obfuscator_config,omitempty"`
	NormalizerConfig  *normalizerConfig  `json:"normalizer_config,omitempty"`
	StatementMetadata *StatementMetadata `json:"statement_metadata,omitempty"`
}

type testcase struct {
	Input   string   `json:"input"`
	Outputs []output `json:"outputs"`
}

// TestQueriesPerDBMS tests a preset of queries and expected outputs per DBMS.
// Test folder structure:
//
//	-- testdata
//	   -- dbms_type
//	      -- query_type
//	         -- query_name.json
func TestQueriesPerDBMS(t *testing.T) {
	dbmsTypes := []DBMSType{
		DBMSPostgres,
	}

	for _, dbms := range dbmsTypes {
		// Get all query type subdirectories of this DBMS's testdata folder
		baseDir := filepath.Join("testdata", string(dbms))
		queryTypes, err := testdata.ReadDir(baseDir)
		if err != nil {
			t.Fatal(err)
		}

		for _, qt := range queryTypes {
			dirPath := filepath.Join(baseDir, qt.Name())
			files, err := testdata.ReadDir(dirPath)
			if err != nil {
				t.Fatal(err)
			}

			for _, file := range files {
				testName := strings.TrimSuffix(file.Name(), ".json")
				t.Run(testName, func(t *testing.T) {
					queryPath := filepath.Join(dirPath, file.Name())

					testfile, err := testdata.ReadFile(queryPath)
					if err != nil {
						t.Fatal(err)
					}

					var tt testcase
					if err := json.Unmarshal(testfile, &tt); err != nil {
						t.Fatal(err)
					}

					var defaultObfuscatorConfig *obfuscatorConfig
					var defaultNormalizerConfig *normalizerConfig

					for _, output := range tt.Outputs {
						// If the test case has a custom obfuscator or normalizer config,
						// use it; otherwise use the default config
						if output.ObfuscatorConfig != nil {
							defaultObfuscatorConfig = output.ObfuscatorConfig
						} else {
							defaultObfuscatorConfig = &obfuscatorConfig{
								DollarQuotedFunc:           true,
								ReplaceDigits:              true,
								ReplacePositionalParameter: true,
								ReplaceBoolean:             true,
								ReplaceNull:                true,
							}
						}

						if output.NormalizerConfig != nil {
							defaultNormalizerConfig = output.NormalizerConfig
						} else {
							defaultNormalizerConfig = &normalizerConfig{
								CollectComments:               true,
								CollectCommands:               true,
								CollectTables:                 true,
								CollectProcedure:              true,
								KeepSQLAlias:                  false,
								UppercaseKeywords:             false,
								RemoveSpaceBetweenParentheses: false,
								KeepTrailingSemicolon:         false,
							}
						}

						obfuscator := NewObfuscator(
							WithDollarQuotedFunc(defaultObfuscatorConfig.DollarQuotedFunc),
							WithReplaceDigits(defaultObfuscatorConfig.ReplaceDigits),
							WithReplacePositionalParameter(defaultObfuscatorConfig.ReplacePositionalParameter),
							WithReplaceBoolean(defaultObfuscatorConfig.ReplaceBoolean),
							WithReplaceNull(defaultObfuscatorConfig.ReplaceNull),
						)

						normalizer := NewNormalizer(
							WithCollectComments(defaultNormalizerConfig.CollectComments),
							WithCollectCommands(defaultNormalizerConfig.CollectCommands),
							WithCollectTables(defaultNormalizerConfig.CollectTables),
							WithCollectProcedures(defaultNormalizerConfig.CollectProcedure),
							WithKeepSQLAlias(defaultNormalizerConfig.KeepSQLAlias),
							WithUppercaseKeywords(defaultNormalizerConfig.UppercaseKeywords),
							WithRemoveSpaceBetweenParentheses(defaultNormalizerConfig.RemoveSpaceBetweenParentheses),
							WithKeepTrailingSemicolon(defaultNormalizerConfig.KeepTrailingSemicolon),
						)

						got, statementMetadata, err := ObfuscateAndNormalize(tt.Input, obfuscator, normalizer, WithDBMS(dbms))
						if err != nil {
							t.Fatal(err)
						}

						// Compare the expected output with the actual output
						assert.Equal(t, output.Expected, got)

						// Compare the expected statement metadata with the actual statement metadata
						if output.StatementMetadata != nil {
							assert.Equal(t, output.StatementMetadata, statementMetadata)
						}
					}
				})
			}
		}
	}
}
41 changes: 25 additions & 16 deletions normalizer.go
@@ -6,31 +6,31 @@ import (

type normalizerConfig struct {
	// CollectTables specifies whether the normalizer should also extract the table names that a query addresses
-	CollectTables bool
+	CollectTables bool `json:"collect_tables"`

	// CollectCommands specifies whether the normalizer should extract and return commands as SQL metadata
-	CollectCommands bool
+	CollectCommands bool `json:"collect_commands"`

	// CollectComments specifies whether the normalizer should extract and return comments as SQL metadata
-	CollectComments bool
+	CollectComments bool `json:"collect_comments"`

	// CollectProcedure specifies whether the normalizer should extract and return the procedure name as SQL metadata
-	CollectProcedure bool
+	CollectProcedure bool `json:"collect_procedure"`

	// KeepSQLAlias specifies whether SQL aliases ("AS") should be truncated.
-	KeepSQLAlias bool
+	KeepSQLAlias bool `json:"keep_sql_alias"`

	// UppercaseKeywords specifies whether SQL keywords should be uppercased.
-	UppercaseKeywords bool
+	UppercaseKeywords bool `json:"uppercase_keywords"`

	// RemoveSpaceBetweenParentheses specifies whether spaces should be kept between parentheses.
	// Spaces are inserted between parentheses by default, but this can be disabled by setting this to true.
-	RemoveSpaceBetweenParentheses bool
+	RemoveSpaceBetweenParentheses bool `json:"remove_space_between_parentheses"`

	// KeepTrailingSemicolon specifies whether the normalizer should keep the trailing semicolon.
	// The trailing semicolon is removed by default, but this can be disabled by setting this to true.
	// PL/SQL requires a trailing semicolon, so this should be set to true when normalizing PL/SQL.
-	KeepTrailingSemicolon bool
+	KeepTrailingSemicolon bool `json:"keep_trailing_semicolon"`
}

type normalizerOption func(*normalizerConfig)
@@ -84,11 +84,11 @@ func WithKeepTrailingSemicolon(keepTrailingSemicolon bool) normalizerOption {
}

type StatementMetadata struct {
-	Size       int
-	Tables     []string
-	Comments   []string
-	Commands   []string
-	Procedures []string
+	Size       int      `json:"size"`
+	Tables     []string `json:"tables"`
+	Comments   []string `json:"comments"`
+	Commands   []string `json:"commands"`
+	Procedures []string `json:"procedures"`
}

type groupablePlaceholder struct {
@@ -162,7 +162,7 @@ func (n *Normalizer) collectMetadata(token *Token, lastToken *Token, statementMe
	if n.config.CollectCommands && isCommand(strings.ToUpper(tokenVal)) {
		// Collect commands
		statementMetadata.Commands = append(statementMetadata.Commands, strings.ToUpper(tokenVal))
-	} else if n.config.CollectTables && isTableIndicator(strings.ToUpper(lastToken.Value)) {
+	} else if n.config.CollectTables && isTableIndicator(strings.ToUpper(lastToken.Value)) && !isSQLKeyword(token) {
		// Collect table names
		statementMetadata.Tables = append(statementMetadata.Tables, tokenVal)
	} else if n.config.CollectProcedure && isProcedure(lastToken) {
@@ -217,7 +217,7 @@ func (n *Normalizer) normalizeSQL(token *Token, lastToken *Token, normalizedSQLB
	}

	// group consecutive obfuscated values into a single placeholder
-	if n.isObfuscatedValueGroupable(token, lastToken, groupablePlaceholder) {
+	if n.isObfuscatedValueGroupable(token, lastToken, groupablePlaceholder, normalizedSQLBuilder) {
		// return the token but do not write it to the normalizedSQLBuilder
		*lastToken = *token
		return
@@ -239,7 +239,7 @@ func (n *Normalizer) writeToken(token *Token, normalizedSQLBuilder *strings.Buil
	}
}

-func (n *Normalizer) isObfuscatedValueGroupable(token *Token, lastToken *Token, groupablePlaceholder *groupablePlaceholder) bool {
+func (n *Normalizer) isObfuscatedValueGroupable(token *Token, lastToken *Token, groupablePlaceholder *groupablePlaceholder, normalizedSQLBuilder *strings.Builder) bool {
	if token.Value == NumberPlaceholder || token.Value == StringPlaceholder {
		if lastToken.Value == "(" || lastToken.Value == "[" {
			// if the last token is "(" or "[", and the current token is a placeholder,
@@ -258,6 +258,15 @@ func (n *Normalizer) isObfuscatedValueGroupable(token *Token, lastToken *Token,
	if groupablePlaceholder.groupable && (token.Value == ")" || token.Value == "]") {
		// end of groupable placeholders
		groupablePlaceholder.groupable = false
+		return false
	}

+	if groupablePlaceholder.groupable && token.Value != NumberPlaceholder && token.Value != StringPlaceholder && lastToken.Value == "," {
+		// This is a tricky edge case. If we are inside a groupable block and the current token is not a placeholder,
+		// we want to write not only the current token to the normalizedSQLBuilder, but also the last comma that we skipped.
+		// For example, (?, ARRAY[?, ?, ?]) should be normalized as (?, ARRAY[?])
+		normalizedSQLBuilder.WriteString(lastToken.Value)
+		return false
+	}

	return false
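The grouping edge case above is easiest to see end to end. Below is a minimal sketch of how the grouped normalization surfaces through the public API; `NewObfuscator`, `NewNormalizer`, `ObfuscateAndNormalize`, `WithDBMS`, and `DBMSPostgres` all appear in this commit's test code, while the module path and the sample query are assumptions:

```go
package main

import (
	"fmt"

	sqllexer "github.com/DataDog/go-sqllexer" // assumed module path
)

func main() {
	obfuscator := sqllexer.NewObfuscator()
	normalizer := sqllexer.NewNormalizer()

	// Consecutive obfuscated values inside (...) or [...] are grouped into a
	// single placeholder, and the comma skipped before a non-placeholder token
	// is written back, e.g. (?, ARRAY[?, ?, ?]) becomes (?, ARRAY[?]).
	got, _, err := sqllexer.ObfuscateAndNormalize(
		"SELECT * FROM users WHERE id IN (1, ARRAY[2, 3, 4])",
		obfuscator,
		normalizer,
		sqllexer.WithDBMS(sqllexer.DBMSPostgres),
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(got)
}
```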
10 changes: 5 additions & 5 deletions obfuscator.go
@@ -5,11 +5,11 @@
)

type obfuscatorConfig struct {
-	DollarQuotedFunc           bool
-	ReplaceDigits              bool
-	ReplacePositionalParameter bool
-	ReplaceBoolean             bool
-	ReplaceNull                bool
+	DollarQuotedFunc           bool `json:"dollar_quoted_func"`
+	ReplaceDigits              bool `json:"replace_digits"`
+	ReplacePositionalParameter bool `json:"replace_positional_parameter"`
+	ReplaceBoolean             bool `json:"replace_boolean"`
+	ReplaceNull                bool `json:"replace_null"`
}

type obfuscatorOption func(*obfuscatorConfig)
2 changes: 1 addition & 1 deletion sqllexer.go
@@ -33,7 +33,7 @@ type Token struct {
}

type LexerConfig struct {
-	DBMS DBMSType
+	DBMS DBMSType `json:"dbms,omitempty"`
}

type lexerOption func(*LexerConfig)
2 changes: 2 additions & 0 deletions sqllexer_utils.go
@@ -135,6 +135,8 @@ var keywords = map[string]bool{
	"UNLOGGED":  true,
	"RECURSIVE": true,
	"RETURNING": true,
+	"OFFSET":    true,
+	"OF":        true,
}

func isWhitespace(ch rune) bool {
54 changes: 54 additions & 0 deletions testdata/README.md
@@ -0,0 +1,54 @@
# Test Suite

The test suite is a collection of test SQL statements organized per DBMS. It is used to test the SQL obfuscator and normalizer for correctness and completeness, and to cover DBMS-specific edge cases that are not covered by the generic unit tests.

## Test Suite Structure

The test suite is organized in the following way:

```text
testdata
├── README.md
├── dbms1
│   ├── query_type1
│   │   ├── test1.json
│   └── query_type2
│   ├── test1.json
dbms_test.go
```

The test suite is organized per DBMS. Each DBMS has a number of query types, each query type has a number of test cases, and each test case consists of a SQL statement and the expected outputs of the obfuscator/normalizer.

## Test File Format

The test files are simple JSON files in which each test case pairs one input SQL statement with an array of expected outputs.
Each expected output can optionally come with a configuration for the obfuscator and normalizer; the configuration is optional because the default configuration is used whenever none is provided.

testcase.json:

```json
{
    "input": "SELECT * FROM table1",
    "outputs": [
        {
            // Test case 1
            "expected": "SELECT * FROM table1",
            "obfuscator_config": {...}, // optional
            "normalizer_config": {...} // optional
        },
        {
            // Test case 2
            "expected": "SELECT * FROM table1",
            "obfuscator_config": {...}, // optional
            "normalizer_config": {...} // optional
        }
    ]
}
```
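For a fuller picture, here is a hypothetical complete test case; the query and expected values are invented for illustration, while the config field names are the json tags this commit adds to `obfuscatorConfig`, `normalizerConfig`, and `StatementMetadata`. In the test data shipped with this commit, `size` equals the total character count of the collected tables, commands, comments, and procedures (here `"users"` + `"SELECT"` = 11):

```json
{
    "input": "SELECT id FROM users WHERE org_id = 42",
    "outputs": [
        {
            "expected": "SELECT id FROM users WHERE org_id = ?",
            "obfuscator_config": {
                "dollar_quoted_func": true,
                "replace_digits": true,
                "replace_positional_parameter": true,
                "replace_boolean": true,
                "replace_null": true
            },
            "normalizer_config": {
                "collect_tables": true,
                "collect_commands": true,
                "collect_comments": true,
                "collect_procedure": true,
                "keep_sql_alias": false,
                "uppercase_keywords": false,
                "remove_space_between_parentheses": false,
                "keep_trailing_semicolon": false
            },
            "statement_metadata": {
                "size": 11,
                "tables": ["users"],
                "commands": ["SELECT"],
                "comments": [],
                "procedures": []
            }
        }
    ]
}
```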

## How to write a new test case

1. Create a new directory for the DBMS, if it does not exist yet (this step is rarely necessary).
2. Create a new directory for the query type, if it does not exist yet.
3. Create a new test case `.json` file with the SQL statement and expected output. Refer to the [test file format](#test-file-format) or the `testcase` struct in [dbms_test.go](../dbms_test.go) for more details.
4. Run the test suite to verify that the test case works as expected (see the example command below).
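
A minimal way to run step 4, assuming you are at the repository root (the test lives in the root package, so the plain `go test` invocation below targets it directly):

```sh
# Run only the per-DBMS test suite, with verbose per-case output
go test -v -run TestQueriesPerDBMS .
```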
19 changes: 19 additions & 0 deletions testdata/postgresql/complex/delete-complex-subqueries-joins.json
@@ -0,0 +1,19 @@
{
    "input": "DELETE FROM \n users u\nUSING \n orders o,\n order_items oi,\n products p\nWHERE \n u.id = o.user_id\nAND o.id = oi.order_id\nAND oi.product_id = p.id\nAND p.category = 'obsolete'\nAND o.order_date < NOW() - INTERVAL '5 years';",
    "outputs": [
        {
            "expected": "DELETE FROM users u USING orders o, order_items oi, products p WHERE u.id = o.user_id AND o.id = oi.order_id AND oi.product_id = p.id AND p.category = ? AND o.order_date < NOW ( ) - INTERVAL ?",
            "statement_metadata": {
                "size": 11,
                "tables": [
                    "users"
                ],
                "commands": [
                    "DELETE"
                ],
                "comments": [],
                "procedures": []
            }
        }
    ]
}
24 changes: 24 additions & 0 deletions testdata/postgresql/complex/insert-complex-select-joins.json
@@ -0,0 +1,24 @@
{
    "input": "INSERT INTO order_summaries (order_id, product_count, total_amount, average_product_price)\nSELECT \n o.id,\n COUNT(p.id),\n SUM(oi.amount),\n AVG(p.price)\nFROM \n orders o\nJOIN order_items oi ON o.id = oi.order_id\nJOIN products p ON oi.product_id = p.id\nGROUP BY \n o.id\nHAVING \n SUM(oi.amount) > 1000;",
    "outputs": [
        {
            "expected": "INSERT INTO order_summaries ( order_id, product_count, total_amount, average_product_price ) SELECT o.id, COUNT ( p.id ), SUM ( oi.amount ), AVG ( p.price ) FROM orders o JOIN order_items oi ON o.id = oi.order_id JOIN products p ON oi.product_id = p.id GROUP BY o.id HAVING SUM ( oi.amount ) > ?",
            "statement_metadata": {
                "size": 56,
                "tables": [
                    "order_summaries",
                    "orders",
                    "order_items",
                    "products"
                ],
                "commands": [
                    "INSERT",
                    "SELECT",
                    "JOIN"
                ],
                "comments": [],
                "procedures": []
            }
        }
    ]
}