diff --git a/CHANGELOG.md b/CHANGELOG.md index efa430a7..ae8a17a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added +- Support for queries using secondary indexes added. - Switched to using the upstream Go driver https://github.com/scylladb/gocql instead of the regular driver. The goal is performance gains by using that shard awareness feature as well as providing proper more real testing of the driver. diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index e46f933e..6d3b96a8 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -225,7 +225,7 @@ func validationJob(schema *gemini.Schema, table gemini.Table, s *gemini.Session, if verbose { fmt.Printf("%s (values=%v)\n", checkQuery, checkValues) } - err := s.Check(checkQuery, checkValues...) + err := s.Check(table, checkQuery, checkValues...) if err == nil { testStatus.ReadOps++ } else { diff --git a/go.mod b/go.mod index 776efb72..35cfabb4 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/kr/pretty v0.1.0 // indirect github.com/mattn/go-colorable v0.1.1 // indirect github.com/mattn/go-isatty v0.0.6 // indirect + github.com/scylladb/go-set v1.0.2 github.com/segmentio/ksuid v1.0.2 github.com/spf13/cobra v0.0.3 github.com/spf13/pflag v1.0.3 // indirect diff --git a/go.sum b/go.sum index f625b612..a9ee4bf2 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/set v0.2.1/go.mod h1:+RKtMCH+favT2+3YecHGxcc0b4KyVWA1QWWJUs4E0CI= github.com/gocql/gocql v0.0.0-20190301043612-f6df8288f9b4 h1:vF83LI8tAakwEwvWZtrIEx7pOySacl2TOxx6eXk4ePo= github.com/gocql/gocql v0.0.0-20190301043612-f6df8288f9b4/go.mod h1:4Fw1eo5iaEhDUs8XyuhSVCVy52Jq3L+/3GJgYkwc+/0= github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk= @@ -30,6 +31,9 @@ github.com/mattn/go-isatty v0.0.6 h1:SrwhHcpV4nWrMGdNcC2kXpMfcBVYGDuTArqyhocJgvA github.com/mattn/go-isatty v0.0.6/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/scylladb/go-set v1.0.2 h1:SkvlMCKhP0wyyct6j+0IHJkBkSZL+TDzZ4E7f7BCcRE= +github.com/scylladb/go-set v1.0.2/go.mod h1:DkpGd78rljTxKAnTDPFqXSGxvETQnJyuSOQwsHycqfs= +github.com/scylladb/gocql v1.0.1 h1:LVWuLOTllhzKNh4QzPEe/gbsTVww1Li1xTHLv/vaTvY= github.com/scylladb/gocql v1.0.1/go.mod h1:4Fw1eo5iaEhDUs8XyuhSVCVy52Jq3L+/3GJgYkwc+/0= github.com/segmentio/ksuid v1.0.2 h1:9yBfKyw4ECGTdALaF09Snw3sLJmYIX6AbPJrAy6MrDc= github.com/segmentio/ksuid v1.0.2/go.mod h1:BXuJDr2byAiHuQaQtSKoXh1J0YmUDurywOXgB2w+OSU= diff --git a/schema.go b/schema.go index bd86aeb6..67d8c687 100644 --- a/schema.go +++ b/schema.go @@ -17,11 +17,17 @@ type ColumnDef struct { Type string } +type IndexDef struct { + Name string + Column ColumnDef +} + type Table struct { Name string `json:"name"` PartitionKeys []ColumnDef `json:"partition_keys"` ClusteringKeys []ColumnDef `json:"clustering_keys"` Columns []ColumnDef `json:"columns"` + Indexes []IndexDef `json:"indexes"` } type Stmt struct { @@ -63,6 +69,10 @@ func genColumnDef(prefix string, idx int) ColumnDef { } } +func genIndexName(prefix string, idx int) string { + return fmt.Sprintf("%s_idx", genColumnName(prefix, idx)) +} + const ( MaxPartitionKeys = 2 MaxClusteringKeys = 4 @@ -75,26 +85,34 @@ func GenSchema() *Schema { Name: "ks1", } builder.Keyspace(keyspace) - partitionKeys := []ColumnDef{} + var partitionKeys []ColumnDef numPartitionKeys := rand.Intn(MaxPartitionKeys-1) + 1 for i := 0; i < numPartitionKeys; i++ { partitionKeys = append(partitionKeys, ColumnDef{Name: genColumnName("pk", i), Type: genColumnType()}) } - clusteringKeys := []ColumnDef{} + var clusteringKeys []ColumnDef numClusteringKeys := rand.Intn(MaxClusteringKeys) for i := 0; i < numClusteringKeys; i++ { clusteringKeys = append(clusteringKeys, ColumnDef{Name: genColumnName("ck", i), Type: genColumnType()}) } - columns := []ColumnDef{} + var columns []ColumnDef numColumns := rand.Intn(MaxColumns) for i := 0; i < numColumns; i++ { columns = append(columns, ColumnDef{Name: genColumnName("col", i), Type: genColumnType()}) } + var indexes []IndexDef + if numColumns > 0 { + numIndexes := rand.Intn(numColumns) + for i := 0; i < numIndexes; i++ { + indexes = append(indexes, IndexDef{Name: genIndexName("col", i), Column: columns[i]}) + } + } table := Table{ Name: "table1", PartitionKeys: partitionKeys, ClusteringKeys: clusteringKeys, Columns: columns, + Indexes: indexes, } builder.Table(table) return builder.Build() @@ -159,9 +177,11 @@ func (s *Schema) GetCreateSchema() []string { stmts := []string{createKeyspace} for _, t := range s.Tables { - partitionKeys := []string{} - clusteringKeys := []string{} - columns := []string{} + var ( + partitionKeys []string + clusteringKeys []string + columns []string + ) for _, pk := range t.PartitionKeys { partitionKeys = append(partitionKeys, pk.Name) columns = append(columns, fmt.Sprintf("%s %s", pk.Name, pk.Type)) @@ -181,13 +201,18 @@ func (s *Schema) GetCreateSchema() []string { strings.Join(partitionKeys, ","), strings.Join(clusteringKeys, ",")) } stmts = append(stmts, createTable) + for _, idef := range t.Indexes { + stmts = append(stmts, fmt.Sprintf("CREATE INDEX %s ON %s.%s (%s)", idef.Name, s.Keyspace.Name, t.Name, idef.Column.Name)) + } } return stmts } func (s *Schema) GenInsertStmt(t Table, p *PartitionRange) *Stmt { - columns := []string{} - placeholders := []string{} + var ( + columns []string + placeholders []string + ) values := make([]interface{}, 0) for _, pk := range t.PartitionKeys { columns = append(columns, pk.Name) @@ -214,8 +239,10 @@ func (s *Schema) GenInsertStmt(t Table, p *PartitionRange) *Stmt { } func (s *Schema) GenDeleteRows(t Table, p *PartitionRange) *Stmt { - relations := []string{} - values := make([]interface{}, 0) + var ( + relations []string + values []interface{} + ) for _, pk := range t.PartitionKeys { relations = append(relations, fmt.Sprintf("%s = ?", pk.Name)) values = genValue(pk.Type, p, values) @@ -242,11 +269,16 @@ func (s *Schema) GenMutateStmt(t Table, p *PartitionRange) *Stmt { default: return s.GenInsertStmt(t, p) } - return nil } func (s *Schema) GenCheckStmt(t Table, p *PartitionRange) *Stmt { - switch n := rand.Intn(4); n { + var n int + if len(t.Indexes) > 0 { + n = rand.Intn(5) + } else { + n = rand.Intn(4) + } + switch n { case 0: return s.genSinglePartitionQuery(t, p) case 1: @@ -255,12 +287,14 @@ func (s *Schema) GenCheckStmt(t Table, p *PartitionRange) *Stmt { return s.genClusteringRangeQuery(t, p) case 3: return s.genMultiplePartitionClusteringRangeQuery(t, p) + case 4: + return s.genSingleIndexQuery(t, p) } return nil } func (s *Schema) genSinglePartitionQuery(t Table, p *PartitionRange) *Stmt { - relations := []string{} + var relations []string values := make([]interface{}, 0) for _, pk := range t.PartitionKeys { relations = append(relations, fmt.Sprintf("%s = ?", pk.Name)) @@ -276,10 +310,12 @@ func (s *Schema) genSinglePartitionQuery(t Table, p *PartitionRange) *Stmt { } func (s *Schema) genMultiplePartitionQuery(t Table, p *PartitionRange) *Stmt { - relations := []string{} - values := make([]interface{}, 0) + var ( + relations []string + pkNames []string + values []interface{} + ) pkNum := rand.Intn(10) - pkNames := []string{} for _, pk := range t.PartitionKeys { pkNames = append(pkNames, pk.Name) relations = append(relations, fmt.Sprintf("%s IN (%s)", pk.Name, strings.TrimRight(strings.Repeat("?,", pkNum), ","))) @@ -297,8 +333,10 @@ func (s *Schema) genMultiplePartitionQuery(t Table, p *PartitionRange) *Stmt { } func (s *Schema) genClusteringRangeQuery(t Table, p *PartitionRange) *Stmt { - relations := []string{} - values := make([]interface{}, 0) + var ( + relations []string + values []interface{} + ) for _, pk := range t.PartitionKeys { relations = append(relations, fmt.Sprintf("%s = ?", pk.Name)) values = genValue(pk.Type, p, values) @@ -317,10 +355,12 @@ func (s *Schema) genClusteringRangeQuery(t Table, p *PartitionRange) *Stmt { } func (s *Schema) genMultiplePartitionClusteringRangeQuery(t Table, p *PartitionRange) *Stmt { - relations := []string{} + var ( + relations []string + pkNames []string + values []interface{} + ) pkNum := rand.Intn(10) - pkNames := []string{} - values := make([]interface{}, 0) for _, pk := range t.PartitionKeys { pkNames = append(pkNames, pk.Name) relations = append(relations, fmt.Sprintf("%s IN (%s)", pk.Name, strings.TrimRight(strings.Repeat("?,", pkNum), ","))) @@ -341,6 +381,21 @@ func (s *Schema) genMultiplePartitionClusteringRangeQuery(t Table, p *PartitionR } } +func (s *Schema) genSingleIndexQuery(t Table, p *PartitionRange) *Stmt { + if len(t.Indexes) == 0 { + return nil + } + idx := rand.Intn(len(t.Indexes)) + query := fmt.Sprintf("SELECT * FROM %s.%s WHERE %s=?", s.Keyspace.Name, t.Name, t.Indexes[idx].Column.Name) + values := genValue(t.Indexes[idx].Column.Type, p, nil) + return &Stmt{ + Query: query, + Values: func() []interface{} { + return values + }, + } +} + type SchemaBuilder interface { Keyspace(Keyspace) SchemaBuilder Table(Table) SchemaBuilder diff --git a/session.go b/session.go index 0c4c7fe6..9c6faaf4 100644 --- a/session.go +++ b/session.go @@ -3,15 +3,20 @@ package gemini import ( "errors" "fmt" + "sort" + "strconv" + "strings" "time" "github.com/gocql/gocql" "github.com/google/go-cmp/cmp" + "github.com/scylladb/go-set/strset" ) type Session struct { testSession *gocql.Session oracleSession *gocql.Session + schema *Schema } var ( @@ -54,23 +59,112 @@ func (s *Session) Mutate(query string, values ...interface{}) error { return nil } -func (s *Session) Check(query string, values ...interface{}) error { +func (s *Session) Check(table Table, query string, values ...interface{}) error { testIter := s.testSession.Query(query, values...).Iter() oracleIter := s.oracleSession.Query(query, values...).Iter() - for { - testRow := make(map[string]interface{}) - if !testIter.MapScan(testRow) { - break - } - oracleRow := make(map[string]interface{}) - if !oracleIter.MapScan(oracleRow) { - break - } + defer func() { + testIter.Close() + oracleIter.Close() + }() + + testRows := loadSet(testIter) + oracleRows := loadSet(oracleIter) + if len(testRows) == 0 && len(oracleRows) == 0 { + return ErrReadNoDataReturned + } + if len(testRows) != len(oracleRows) { + testSet := strset.New(pks(table, testRows)...) + oracleSet := strset.New(pks(table, oracleRows)...) + fmt.Printf("Missing in Test: %s\n", strset.Difference(oracleSet, testSet).List()) + fmt.Printf("Missing in Oracle: %s\n", strset.Difference(testSet, oracleSet).List()) + return fmt.Errorf("row count differ (%d ne %d)", len(testRows), len(oracleRows)) + } + sort.SliceStable(testRows, func(i, j int) bool { + return lt(testRows[i], testRows[j]) + }) + sort.SliceStable(oracleRows, func(i, j int) bool { + return lt(oracleRows[i], oracleRows[j]) + }) + for i, oracleRow := range oracleRows { + testRow := testRows[i] diff := cmp.Diff(oracleRow, testRow) if diff != "" { return fmt.Errorf("rows differ (-%v +%v): %v", oracleRow, testRow, diff) } - return nil } - return ErrReadNoDataReturned + return nil +} + +func pks(t Table, rows []map[string]interface{}) []string { + var keySet []string + for _, row := range rows { + keys := make([]string, 0, len(t.PartitionKeys)+len(t.ClusteringKeys)) + keys = extractRowValues(keys, t.PartitionKeys, row) + keys = extractRowValues(keys, t.ClusteringKeys, row) + keySet = append(keySet, strings.Join(keys, ", ")) + } + return keySet +} + +func extractRowValues(values []string, columns []ColumnDef, row map[string]interface{}) []string { + for _, pk := range columns { + cv := row[pk.Name] + switch pk.Type { + case "int": + v, _ := cv.(int) + values = append(values, pk.Name+"="+strconv.Itoa(v)) + case "bigint": + v, _ := cv.(int64) + values = append(values, pk.Name+"="+strconv.FormatInt(v, 10)) + case "uuid": + v, _ := cv.(gocql.UUID) + values = append(values, pk.Name+"="+v.String()) + case "blob": + v, _ := cv.([]byte) + values = append(values, pk.Name+"="+string(v)) + case "text", "varchar": + v, _ := cv.(string) + values = append(values, pk.Name+"="+v) + case "timestamp", "date": + v, _ := cv.(time.Time) + values = append(values, pk.Name+"="+v.String()) + default: + panic(fmt.Sprintf("not supported type %s", pk)) + } + } + return values +} + +func lt(mi, mj map[string]interface{}) bool { + switch mis := mi["pk0"].(type) { + case []byte: + mjs, _ := mj["pk0"].([]byte) + return string(mis) < string(mjs) + case string: + mjs, _ := mj["pk0"].(string) + return mis < mjs + case int: + mjs, _ := mj["pk0"].(int) + return mis < mjs + case gocql.UUID: + mjs, _ := mj["pk0"].(gocql.UUID) + return mis.String() < mjs.String() + case time.Time: + mjs, _ := mj["pk0"].(time.Time) + return mis.UnixNano() < mjs.UnixNano() + default: + panic(fmt.Sprintf("unhandled type %T!\n", mis)) + } +} + +func loadSet(iter *gocql.Iter) []map[string]interface{} { + var rows []map[string]interface{} + for { + row := make(map[string]interface{}) + if !iter.MapScan(row) { + break + } + rows = append(rows, row) + } + return rows }