diff --git a/flow/connectors/core.go b/flow/connectors/core.go index 30197a461d..6d008f8950 100644 --- a/flow/connectors/core.go +++ b/flow/connectors/core.go @@ -485,6 +485,7 @@ var ( _ CreateTablesFromExistingConnector = &connsnowflake.SnowflakeConnector{} _ QRepPullConnector = &connpostgres.PostgresConnector{} + _ QRepPullConnector = &connmysql.MySqlConnector{} _ QRepPullConnector = &connsqlserver.SQLServerConnector{} _ QRepPullPgConnector = &connpostgres.PostgresConnector{} diff --git a/flow/connectors/mysql/cdc.go b/flow/connectors/mysql/cdc.go index 07e728fff1..c685f5ce56 100644 --- a/flow/connectors/mysql/cdc.go +++ b/flow/connectors/mysql/cdc.go @@ -66,74 +66,13 @@ func (c *MySqlConnector) getTableSchemaForTable( } columns := make([]*protos.FieldDescription, 0, len(rs.Values)) primary := make([]string, 0) + for _, field := range rs.Fields { - var qkind qvalue.QValueKind - switch field.Type { - case mysql.MYSQL_TYPE_DECIMAL: - qkind = qvalue.QValueKindNumeric - case mysql.MYSQL_TYPE_TINY: - qkind = qvalue.QValueKindInt16 // TODO qvalue.QValueKindInt8 - case mysql.MYSQL_TYPE_SHORT: - qkind = qvalue.QValueKindInt16 - case mysql.MYSQL_TYPE_LONG: - qkind = qvalue.QValueKindInt32 - case mysql.MYSQL_TYPE_FLOAT: - qkind = qvalue.QValueKindFloat32 - case mysql.MYSQL_TYPE_DOUBLE: - qkind = qvalue.QValueKindFloat64 - case mysql.MYSQL_TYPE_NULL: - qkind = qvalue.QValueKindInvalid // TODO qvalue.QValueKindNothing - case mysql.MYSQL_TYPE_TIMESTAMP: - qkind = qvalue.QValueKindTimestamp - case mysql.MYSQL_TYPE_LONGLONG: - qkind = qvalue.QValueKindInt64 - case mysql.MYSQL_TYPE_INT24: - qkind = qvalue.QValueKindInt32 - case mysql.MYSQL_TYPE_DATE: - qkind = qvalue.QValueKindDate - case mysql.MYSQL_TYPE_TIME: - qkind = qvalue.QValueKindTime - case mysql.MYSQL_TYPE_DATETIME: - qkind = qvalue.QValueKindTimestamp - case mysql.MYSQL_TYPE_YEAR: - qkind = qvalue.QValueKindInt16 - case mysql.MYSQL_TYPE_NEWDATE: - qkind = qvalue.QValueKindDate - case mysql.MYSQL_TYPE_VARCHAR: - qkind = qvalue.QValueKindString - case mysql.MYSQL_TYPE_BIT: - qkind = qvalue.QValueKindInt64 - case mysql.MYSQL_TYPE_TIMESTAMP2: - qkind = qvalue.QValueKindTimestamp - case mysql.MYSQL_TYPE_DATETIME2: - qkind = qvalue.QValueKindTimestamp - case mysql.MYSQL_TYPE_TIME2: - qkind = qvalue.QValueKindTime - case mysql.MYSQL_TYPE_JSON: - qkind = qvalue.QValueKindJSON - case mysql.MYSQL_TYPE_NEWDECIMAL: - qkind = qvalue.QValueKindNumeric - case mysql.MYSQL_TYPE_ENUM: - qkind = qvalue.QValueKindInt64 - case mysql.MYSQL_TYPE_SET: - qkind = qvalue.QValueKindInt64 - case mysql.MYSQL_TYPE_TINY_BLOB: - qkind = qvalue.QValueKindBytes - case mysql.MYSQL_TYPE_MEDIUM_BLOB: - qkind = qvalue.QValueKindBytes - case mysql.MYSQL_TYPE_LONG_BLOB: - qkind = qvalue.QValueKindBytes - case mysql.MYSQL_TYPE_BLOB: - qkind = qvalue.QValueKindBytes - case mysql.MYSQL_TYPE_VAR_STRING: - qkind = qvalue.QValueKindString - case mysql.MYSQL_TYPE_STRING: - qkind = qvalue.QValueKindString - case mysql.MYSQL_TYPE_GEOMETRY: - qkind = qvalue.QValueKindGeometry - default: - return nil, fmt.Errorf("unknown mysql type %d", field.Type) + qkind, err := qkindFromMysql(field.Type) + if err != nil { + return nil, err } + column := &protos.FieldDescription{ Name: string(field.Name), Type: string(qkind), @@ -242,61 +181,6 @@ func (c *MySqlConnector) RemoveTablesFromPublication(ctx context.Context, req *p return nil } -func qvalueFromMysql(mytype byte, qkind qvalue.QValueKind, val any) qvalue.QValue { - // TODO signedness, in ev.Table, need to extend QValue system - // See go-mysql row_event.go for mapping - switch val := val.(type) { - case nil: - return qvalue.QValueNull(qkind) - case int8: // TODO qvalue.Int8 - return qvalue.QValueInt16{Val: int16(val)} - case int16: - return qvalue.QValueInt16{Val: val} - case int32: - return qvalue.QValueInt32{Val: val} - case int64: - return qvalue.QValueInt64{Val: val} - case float32: - return qvalue.QValueFloat32{Val: val} - case float64: - return qvalue.QValueFloat64{Val: val} - case decimal.Decimal: - return qvalue.QValueNumeric{Val: val} - case int: - // YEAR: https://dev.mysql.com/doc/refman/8.4/en/year.html - return qvalue.QValueInt16{Val: int16(val)} - case time.Time: - return qvalue.QValueTimestamp{Val: val} - case *replication.JsonDiff: - // TODO support somehow?? - return qvalue.QValueNull(qvalue.QValueKindJSON) - case []byte: - switch mytype { - case mysql.MYSQL_TYPE_BLOB: - return qvalue.QValueBytes{Val: val} - case mysql.MYSQL_TYPE_JSON: - return qvalue.QValueJSON{Val: string(val)} - case mysql.MYSQL_TYPE_GEOMETRY: - // TODO figure out mysql geo encoding - return qvalue.QValueGeometry{Val: string(val)} - } - case string: - switch mytype { - case mysql.MYSQL_TYPE_TIME: - // TODO parse - case mysql.MYSQL_TYPE_TIME2: - // TODO parse - case mysql.MYSQL_TYPE_DATE: - // TODO parse - case mysql.MYSQL_TYPE_VARCHAR, - mysql.MYSQL_TYPE_VAR_STRING, - mysql.MYSQL_TYPE_STRING: - return qvalue.QValueString{Val: val} - } - } - panic(fmt.Sprintf("unexpected type %T for mysql type %d", val, mytype)) -} - func (c *MySqlConnector) PullRecords( ctx context.Context, catalogPool *pgxpool.Pool, @@ -372,7 +256,7 @@ func (c *MySqlConnector) PullRecords( items := model.NewRecordItems(len(row)) for idx, val := range row { fd := schema.Columns[idx] - items.AddColumn(fd.Name, qvalueFromMysql(ev.Table.ColumnType[idx], qvalue.QValueKind(fd.Type), val)) + items.AddColumn(fd.Name, qvalueFromMysqlRowEvent(ev.Table.ColumnType[idx], qvalue.QValueKind(fd.Type), val)) } recordCount += 1 @@ -392,13 +276,13 @@ func (c *MySqlConnector) PullRecords( oldItems := model.NewRecordItems(len(oldRow)) for idx, val := range oldRow { fd := schema.Columns[idx] - oldItems.AddColumn(fd.Name, qvalueFromMysql(ev.Table.ColumnType[idx], qvalue.QValueKind(fd.Type), val)) + oldItems.AddColumn(fd.Name, qvalueFromMysqlRowEvent(ev.Table.ColumnType[idx], qvalue.QValueKind(fd.Type), val)) } newRow := ev.Rows[idx+1] newItems := model.NewRecordItems(len(newRow)) for idx, val := range ev.Rows[idx+1] { fd := schema.Columns[idx] - newItems.AddColumn(fd.Name, qvalueFromMysql(ev.Table.ColumnType[idx], qvalue.QValueKind(fd.Type), val)) + newItems.AddColumn(fd.Name, qvalueFromMysqlRowEvent(ev.Table.ColumnType[idx], qvalue.QValueKind(fd.Type), val)) } recordCount += 1 @@ -417,7 +301,7 @@ func (c *MySqlConnector) PullRecords( items := model.NewRecordItems(len(row)) for idx, val := range row { fd := schema.Columns[idx] - items.AddColumn(fd.Name, qvalueFromMysql(ev.Table.ColumnType[idx], qvalue.QValueKind(fd.Type), val)) + items.AddColumn(fd.Name, qvalueFromMysqlRowEvent(ev.Table.ColumnType[idx], qvalue.QValueKind(fd.Type), val)) } recordCount += 1 @@ -440,3 +324,58 @@ func (c *MySqlConnector) PullRecords( } } } + +func qvalueFromMysqlRowEvent(mytype byte, qkind qvalue.QValueKind, val any) qvalue.QValue { + // TODO signedness, in ev.Table, need to extend QValue system + // See go-mysql row_event.go for mapping + switch val := val.(type) { + case nil: + return qvalue.QValueNull(qkind) + case int8: // TODO qvalue.Int8 + return qvalue.QValueInt16{Val: int16(val)} + case int16: + return qvalue.QValueInt16{Val: val} + case int32: + return qvalue.QValueInt32{Val: val} + case int64: + return qvalue.QValueInt64{Val: val} + case float32: + return qvalue.QValueFloat32{Val: val} + case float64: + return qvalue.QValueFloat64{Val: val} + case decimal.Decimal: + return qvalue.QValueNumeric{Val: val} + case int: + // YEAR: https://dev.mysql.com/doc/refman/8.4/en/year.html + return qvalue.QValueInt16{Val: int16(val)} + case time.Time: + return qvalue.QValueTimestamp{Val: val} + case *replication.JsonDiff: + // TODO support somehow?? + return qvalue.QValueNull(qvalue.QValueKindJSON) + case []byte: + switch mytype { + case mysql.MYSQL_TYPE_BLOB: + return qvalue.QValueBytes{Val: val} + case mysql.MYSQL_TYPE_JSON: + return qvalue.QValueJSON{Val: string(val)} + case mysql.MYSQL_TYPE_GEOMETRY: + // TODO figure out mysql geo encoding + return qvalue.QValueGeometry{Val: string(val)} + } + case string: + switch mytype { + case mysql.MYSQL_TYPE_TIME: + // TODO parse + case mysql.MYSQL_TYPE_TIME2: + // TODO parse + case mysql.MYSQL_TYPE_DATE: + // TODO parse + case mysql.MYSQL_TYPE_VARCHAR, + mysql.MYSQL_TYPE_VAR_STRING, + mysql.MYSQL_TYPE_STRING: + return qvalue.QValueString{Val: val} + } + } + panic(fmt.Sprintf("unexpected type %T for mysql type %d", val, mytype)) +} diff --git a/flow/connectors/mysql/mysql.go b/flow/connectors/mysql/mysql.go index e2929ad8ae..827b606769 100644 --- a/flow/connectors/mysql/mysql.go +++ b/flow/connectors/mysql/mysql.go @@ -5,6 +5,7 @@ package connmysql import ( "context" "crypto/tls" + "errors" "fmt" "log/slog" "time" @@ -16,6 +17,7 @@ import ( metadataStore "github.com/PeerDB-io/peer-flow/connectors/external_metadata" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/PeerDB-io/peer-flow/shared" ) @@ -76,6 +78,7 @@ func (c *MySqlConnector) connect(ctx context.Context, options ...client.Option) func (c *MySqlConnector) Execute(ctx context.Context, cmd string, args ...interface{}) (*mysql.Result, error) { reconnects := 3 for { + // TODO need new connection if ctx changes between calls, or make upstream PR if c.conn == nil { var err error var argF []client.Option @@ -154,3 +157,116 @@ func (c *MySqlConnector) GetVersion(ctx context.Context) (string, error) { c.logger.Info("[mysql] version", slog.String("version", version)) return version, nil } + +func qkindFromMysql(ty uint8) (qvalue.QValueKind, error) { + switch ty { + case mysql.MYSQL_TYPE_DECIMAL: + return qvalue.QValueKindNumeric, nil + case mysql.MYSQL_TYPE_TINY: + return qvalue.QValueKindInt16, nil // TODO qvalue.QValueKindInt8 + case mysql.MYSQL_TYPE_SHORT: + return qvalue.QValueKindInt16, nil + case mysql.MYSQL_TYPE_LONG: + return qvalue.QValueKindInt32, nil + case mysql.MYSQL_TYPE_FLOAT: + return qvalue.QValueKindFloat32, nil + case mysql.MYSQL_TYPE_DOUBLE: + return qvalue.QValueKindFloat64, nil + case mysql.MYSQL_TYPE_NULL: + return qvalue.QValueKindInvalid, nil // TODO qvalue.QValueKindNothing + case mysql.MYSQL_TYPE_TIMESTAMP: + return qvalue.QValueKindTimestamp, nil + case mysql.MYSQL_TYPE_LONGLONG: + return qvalue.QValueKindInt64, nil + case mysql.MYSQL_TYPE_INT24: + return qvalue.QValueKindInt32, nil + case mysql.MYSQL_TYPE_DATE: + return qvalue.QValueKindDate, nil + case mysql.MYSQL_TYPE_TIME: + return qvalue.QValueKindTime, nil + case mysql.MYSQL_TYPE_DATETIME: + return qvalue.QValueKindTimestamp, nil + case mysql.MYSQL_TYPE_YEAR: + return qvalue.QValueKindInt16, nil + case mysql.MYSQL_TYPE_NEWDATE: + return qvalue.QValueKindDate, nil + case mysql.MYSQL_TYPE_VARCHAR: + return qvalue.QValueKindString, nil + case mysql.MYSQL_TYPE_BIT: + return qvalue.QValueKindInt64, nil + case mysql.MYSQL_TYPE_TIMESTAMP2: + return qvalue.QValueKindTimestamp, nil + case mysql.MYSQL_TYPE_DATETIME2: + return qvalue.QValueKindTimestamp, nil + case mysql.MYSQL_TYPE_TIME2: + return qvalue.QValueKindTime, nil + case mysql.MYSQL_TYPE_JSON: + return qvalue.QValueKindJSON, nil + case mysql.MYSQL_TYPE_NEWDECIMAL: + return qvalue.QValueKindNumeric, nil + case mysql.MYSQL_TYPE_ENUM: + return qvalue.QValueKindInt64, nil + case mysql.MYSQL_TYPE_SET: + return qvalue.QValueKindInt64, nil + case mysql.MYSQL_TYPE_TINY_BLOB: + return qvalue.QValueKindBytes, nil + case mysql.MYSQL_TYPE_MEDIUM_BLOB: + return qvalue.QValueKindBytes, nil + case mysql.MYSQL_TYPE_LONG_BLOB: + return qvalue.QValueKindBytes, nil + case mysql.MYSQL_TYPE_BLOB: + return qvalue.QValueKindBytes, nil + case mysql.MYSQL_TYPE_VAR_STRING: + return qvalue.QValueKindString, nil + case mysql.MYSQL_TYPE_STRING: + return qvalue.QValueKindString, nil + case mysql.MYSQL_TYPE_GEOMETRY: + return qvalue.QValueKindGeometry, nil + default: + return qvalue.QValueKind(""), fmt.Errorf("unknown mysql type %d", ty) + } +} + +func qvalueFromMysqlFieldValue(qkind qvalue.QValueKind, fv mysql.FieldValue) (qvalue.QValue, error) { + // TODO fill this in, maybe contribute upstream, figvure out how numeric etc fit in + switch v := fv.Value().(type) { + case nil: + return qvalue.QValueNull(qkind), nil + case uint64: + // TODO unsigned integers + return nil, errors.New("mysql unsigned integers not supported") + case int64: + switch qkind { + case qvalue.QValueKindInt16: + return qvalue.QValueInt16{Val: int16(v)}, nil + case qvalue.QValueKindInt32: + return qvalue.QValueInt32{Val: int32(v)}, nil + case qvalue.QValueKindInt64: + return qvalue.QValueInt64{Val: v}, nil + default: + return nil, fmt.Errorf("cannot convert int to %s", qkind) + } + case float64: + switch qkind { + case qvalue.QValueKindFloat32: + return qvalue.QValueFloat32{Val: float32(v)}, nil + case qvalue.QValueKindFloat64: + return qvalue.QValueFloat64{Val: float64(v)}, nil + default: + return nil, fmt.Errorf("cannot convert float to %s", qkind) + } + case string: + switch qkind { + case qvalue.QValueKindString: + return qvalue.QValueString{Val: v}, nil + case qvalue.QValueKindBytes: + return qvalue.QValueBytes{Val: []byte(v)}, nil + case qvalue.QValueKindJSON: + return qvalue.QValueJSON{Val: v}, nil + default: + return nil, fmt.Errorf("cannot convert string to %s", qkind) + } + default: + return nil, fmt.Errorf("unexpected mysql type %T", v) + } +} diff --git a/flow/connectors/mysql/qrep.go b/flow/connectors/mysql/qrep.go new file mode 100644 index 0000000000..9ac4b57158 --- /dev/null +++ b/flow/connectors/mysql/qrep.go @@ -0,0 +1,239 @@ +package connmysql + +import ( + "bytes" + "context" + "errors" + "fmt" + "log/slog" + "text/template" + + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/google/uuid" + "go.temporal.io/sdk/log" + + utils "github.com/PeerDB-io/peer-flow/connectors/utils/partition" + "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" +) + +func (c *MySqlConnector) GetQRepPartitions( + ctx context.Context, + config *protos.QRepConfig, + last *protos.QRepPartition, +) ([]*protos.QRepPartition, error) { + if config.WatermarkTable == "" { + c.logger.Info("watermark table is empty, doing full table refresh") + return []*protos.QRepPartition{ + { + PartitionId: uuid.New().String(), + FullTablePartition: true, + }, + }, nil + } + + if config.NumRowsPerPartition <= 0 { + return nil, errors.New("num rows per partition must be greater than 0 for sql server") + } + + var err error + numRowsPerPartition := int64(config.NumRowsPerPartition) + quotedWatermarkColumn := fmt.Sprintf("\"%s\"", config.WatermarkColumn) + + whereClause := "" + if last != nil && last.Range != nil { + whereClause = fmt.Sprintf("WHERE %s > $1", quotedWatermarkColumn) + } + + // Query to get the total number of rows in the table + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", config.WatermarkTable, whereClause) + var minVal interface{} + var totalRows int64 + if last != nil && last.Range != nil { + switch lastRange := last.Range.Range.(type) { + case *protos.PartitionRange_IntRange: + minVal = lastRange.IntRange.End + case *protos.PartitionRange_TimestampRange: + minVal = lastRange.TimestampRange.End.AsTime() + } + c.logger.Info(fmt.Sprintf("count query: %s - minVal: %v", countQuery, minVal)) + + rs, err := c.Execute(ctx, countQuery, minVal) + if err != nil { + return nil, err + } + + totalRows, err = rs.GetInt(0, 0) + if err != nil { + return nil, fmt.Errorf("failed to query for total rows: %w", err) + } + } else { + rs, err := c.Execute(ctx, countQuery) + if err != nil { + return nil, err + } + + totalRows, err = rs.GetInt(0, 0) + if err != nil { + return nil, fmt.Errorf("failed to query for total rows: %w", err) + } + } + + if totalRows == 0 { + c.logger.Warn("no records to replicate, returning") + return make([]*protos.QRepPartition, 0), nil + } + + // Calculate the number of partitions + numPartitions := totalRows / numRowsPerPartition + if totalRows%numRowsPerPartition != 0 { + numPartitions++ + } + c.logger.Info(fmt.Sprintf("total rows: %d, num partitions: %d, num rows per partition: %d", + totalRows, numPartitions, numRowsPerPartition)) + var rs *mysql.Result + if minVal != nil { + // Query to get partitions using window functions + partitionsQuery := fmt.Sprintf( + `SELECT bucket_v, MIN(v_from) AS start_v, MAX(v_from) AS end_v + FROM ( + SELECT NTILE(%d) OVER (ORDER BY %s) AS bucket_v, %s as v_from + FROM %s WHERE %s > $1 + ) AS subquery + GROUP BY bucket_v + ORDER BY start_v`, + numPartitions, + quotedWatermarkColumn, + quotedWatermarkColumn, + config.WatermarkTable, + quotedWatermarkColumn, + ) + c.logger.Info(fmt.Sprintf("partitions query: %s - minVal: %v", partitionsQuery, minVal)) + rs, err = c.Execute(ctx, partitionsQuery, minVal) + } else { + partitionsQuery := fmt.Sprintf( + `SELECT bucket_v, MIN(v_from) AS start_v, MAX(v_from) AS end_v + FROM ( + SELECT NTILE(%d) OVER (ORDER BY %s) AS bucket_v, %s as v_from + FROM %s + ) AS subquery + GROUP BY bucket_v + ORDER BY start_v`, + numPartitions, + quotedWatermarkColumn, + quotedWatermarkColumn, + config.WatermarkTable, + ) + c.logger.Info("partitions query: " + partitionsQuery) + rs, err = c.Execute(ctx, partitionsQuery) + } + if err != nil { + return nil, fmt.Errorf("failed to query for partitions: %w", err) + } + + partitionHelper := utils.NewPartitionHelper() + for _, row := range rs.Values { + if err := partitionHelper.AddPartition(row[1].Value(), row[2].Value()); err != nil { + return nil, fmt.Errorf("failed to add partition: %w", err) + } + } + + return partitionHelper.GetPartitions(), nil +} + +// TODO use ExecuteStreamingSelect +func (c *MySqlConnector) PullQRepRecords( + ctx context.Context, + config *protos.QRepConfig, + last *protos.QRepPartition, + stream *model.QRecordStream, +) (int, error) { + // Build the query to pull records within the range from the source table + // Be sure to order the results by the watermark column to ensure consistency across pulls + query, err := BuildQuery(c.logger, config.Query) + if err != nil { + return 0, err + } + + var rs *mysql.Result + if last.FullTablePartition { + var err error + // this is a full table partition, so just run the query + rs, err = c.Execute(ctx, query) + if err != nil { + return 0, err + } + } else { + var rangeStart interface{} + var rangeEnd interface{} + + // Depending on the type of the range, convert the range into the correct type + switch x := last.Range.Range.(type) { + case *protos.PartitionRange_IntRange: + rangeStart = x.IntRange.Start + rangeEnd = x.IntRange.End + case *protos.PartitionRange_TimestampRange: + rangeStart = x.TimestampRange.Start.AsTime() + rangeEnd = x.TimestampRange.End.AsTime() + default: + return 0, fmt.Errorf("unknown range type: %v", x) + } + + var err error + rs, err = c.Execute(ctx, query, rangeStart, rangeEnd) + if err != nil { + return 0, err + } + } + + schema := make([]qvalue.QField, 0, len(rs.Fields)) + for _, field := range rs.Fields { + qkind, err := qkindFromMysql(field.Type) + if err != nil { + return 0, err + } + + schema = append(schema, qvalue.QField{ + Name: string(field.Name), + Type: qkind, + Precision: 0, // TODO numerics + Scale: 0, // TODO numerics + Nullable: (field.Flag & mysql.NOT_NULL_FLAG) == 0, + }) + } + stream.SetSchema(qvalue.QRecordSchema{Fields: schema}) + for _, row := range rs.Values { + record := make([]qvalue.QValue, 0, len(row)) + for idx, val := range row { + qv, err := qvalueFromMysqlFieldValue(schema[idx].Type, val) + if err != nil { + return 0, err + } + record = append(record, qv) + } + stream.Records <- record + } + return len(rs.Values), nil +} + +func BuildQuery(logger log.Logger, query string) (string, error) { + tmpl, err := template.New("query").Parse(query) + if err != nil { + return "", err + } + + data := map[string]interface{}{ + "start": "$1", + "end": "$2", + } + + buf := new(bytes.Buffer) + if err := tmpl.Execute(buf, data); err != nil { + return "", err + } + res := buf.String() + + logger.Info("[mysql] templated query", slog.String("query", res)) + return res, nil +} diff --git a/flow/connectors/sqlserver/qrep.go b/flow/connectors/sqlserver/qrep.go index a725fd0102..aa35d455a0 100644 --- a/flow/connectors/sqlserver/qrep.go +++ b/flow/connectors/sqlserver/qrep.go @@ -46,7 +46,7 @@ func (c *SQLServerConnector) GetQRepPartitions( // Query to get the total number of rows in the table countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", config.WatermarkTable, whereClause) - var minVal interface{} = nil + var minVal interface{} var totalRows pgtype.Int8 if last != nil && last.Range != nil { switch lastRange := last.Range.Range.(type) { @@ -78,11 +78,8 @@ func (c *SQLServerConnector) GetQRepPartitions( if err != nil { return nil, fmt.Errorf("failed to query for total rows: %w", err) } - } else { - row := c.db.QueryRowContext(ctx, countQuery) - if err = row.Scan(&totalRows); err != nil { - return nil, fmt.Errorf("failed to query for total rows: %w", err) - } + } else if err := c.db.QueryRowContext(ctx, countQuery).Scan(&totalRows); err != nil { + return nil, fmt.Errorf("failed to query for total rows: %w", err) } if totalRows.Int64 == 0 { @@ -150,8 +147,7 @@ func (c *SQLServerConnector) GetQRepPartitions( return nil, fmt.Errorf("failed to scan row: %w", err) } - err = partitionHelper.AddPartition(start, end) - if err != nil { + if err := partitionHelper.AddPartition(start, end); err != nil { return nil, fmt.Errorf("failed to add partition: %w", err) } } @@ -162,7 +158,7 @@ func (c *SQLServerConnector) GetQRepPartitions( func (c *SQLServerConnector) PullQRepRecords( ctx context.Context, config *protos.QRepConfig, - partition *protos.QRepPartition, + last *protos.QRepPartition, stream *model.QRecordStream, ) (int, error) { // Build the query to pull records within the range from the source table @@ -172,40 +168,40 @@ func (c *SQLServerConnector) PullQRepRecords( return 0, err } - if partition.FullTablePartition { + var qbatch *model.QRecordBatch + if last.FullTablePartition { // this is a full table partition, so just run the query - qbatch, err := c.ExecuteAndProcessQuery(ctx, query) + var err error + qbatch, err = c.ExecuteAndProcessQuery(ctx, query) if err != nil { return 0, err } - qbatch.FeedToQRecordStream(stream) - return len(qbatch.Records), nil - } + } else { + var rangeStart interface{} + var rangeEnd interface{} - var rangeStart interface{} - var rangeEnd interface{} - - // Depending on the type of the range, convert the range into the correct type - switch x := partition.Range.Range.(type) { - case *protos.PartitionRange_IntRange: - rangeStart = x.IntRange.Start - rangeEnd = x.IntRange.End - case *protos.PartitionRange_TimestampRange: - rangeStart = x.TimestampRange.Start.AsTime() - rangeEnd = x.TimestampRange.End.AsTime() - default: - return 0, fmt.Errorf("unknown range type: %v", x) - } + // Depending on the type of the range, convert the range into the correct type + switch x := last.Range.Range.(type) { + case *protos.PartitionRange_IntRange: + rangeStart = x.IntRange.Start + rangeEnd = x.IntRange.End + case *protos.PartitionRange_TimestampRange: + rangeStart = x.TimestampRange.Start.AsTime() + rangeEnd = x.TimestampRange.End.AsTime() + default: + return 0, fmt.Errorf("unknown range type: %v", x) + } - rangeParams := map[string]interface{}{ - "startRange": rangeStart, - "endRange": rangeEnd, + var err error + qbatch, err = c.NamedExecuteAndProcessQuery(ctx, query, map[string]interface{}{ + "startRange": rangeStart, + "endRange": rangeEnd, + }) + if err != nil { + return 0, err + } } - qbatch, err := c.NamedExecuteAndProcessQuery(ctx, query, rangeParams) - if err != nil { - return 0, err - } qbatch.FeedToQRecordStream(stream) return len(qbatch.Records), nil }