diff --git a/pkg/ccl/changefeedccl/helpers_test.go b/pkg/ccl/changefeedccl/helpers_test.go index 53ee7c488705..c8de4955da0d 100644 --- a/pkg/ccl/changefeedccl/helpers_test.go +++ b/pkg/ccl/changefeedccl/helpers_test.go @@ -1178,8 +1178,7 @@ func waitForJobStatus( } // TestingSetIncludeParquetMetadata adds the option to turn on adding metadata -// (primary key column names) to the parquet file which is used to convert parquet -// data to JSON format +// to the parquet file which is used in testing. func TestingSetIncludeParquetMetadata() func() { includeParquetTestMetadata = true return func() { diff --git a/pkg/ccl/changefeedccl/parquet.go b/pkg/ccl/changefeedccl/parquet.go index e11bd1d5d40b..4caeeef2efcc 100644 --- a/pkg/ccl/changefeedccl/parquet.go +++ b/pkg/ccl/changefeedccl/parquet.go @@ -25,7 +25,7 @@ type parquetWriter struct { // newParquetWriterFromRow constructs a new parquet writer which outputs to // the given sink. This function interprets the schema from the supplied row. func newParquetWriterFromRow( - row cdcevent.Row, sink io.Writer, maxRowGroupSize int64, + row cdcevent.Row, sink io.Writer, opts ...parquet.Option, ) (*parquetWriter, error) { columnNames := make([]string, len(row.ResultColumns())+1) columnTypes := make([]*types.T, len(row.ResultColumns())+1) @@ -48,7 +48,12 @@ func newParquetWriterFromRow( return nil, err } - writer, err := parquet.NewWriter(schemaDef, sink, parquet.WithMaxRowGroupLength(maxRowGroupSize)) + writerConstructor := parquet.NewWriter + if includeParquetTestMetadata { + writerConstructor = parquet.NewWriterWithReaderMeta + } + + writer, err := writerConstructor(schemaDef, sink, opts...) if err != nil { return nil, err } diff --git a/pkg/ccl/changefeedccl/parquet_test.go b/pkg/ccl/changefeedccl/parquet_test.go index 5a4843afe6a2..030e26a490cb 100644 --- a/pkg/ccl/changefeedccl/parquet_test.go +++ b/pkg/ccl/changefeedccl/parquet_test.go @@ -38,6 +38,8 @@ func TestParquetRows(t *testing.T) { // Rangefeed reader can time out under stress. skip.UnderStress(t) + defer TestingSetIncludeParquetMetadata()() + ctx := context.Background() s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) defer s.Stopper().Stop(ctx) @@ -106,7 +108,7 @@ func TestParquetRows(t *testing.T) { require.NoError(t, err) if writer == nil { - writer, err = newParquetWriterFromRow(updatedRow, f, maxRowGroupSize) + writer, err = newParquetWriterFromRow(updatedRow, f, parquet.WithMaxRowGroupLength(maxRowGroupSize)) if err != nil { t.Fatalf(err.Error()) } diff --git a/pkg/util/parquet/BUILD.bazel b/pkg/util/parquet/BUILD.bazel index a2a457617f37..ced63863101f 100644 --- a/pkg/util/parquet/BUILD.bazel +++ b/pkg/util/parquet/BUILD.bazel @@ -33,7 +33,6 @@ go_library( "@com_github_apache_arrow_go_v11//parquet/schema", "@com_github_cockroachdb_errors//:errors", "@com_github_lib_pq//oid", - "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/util/parquet/decoders.go b/pkg/util/parquet/decoders.go index e3b4e17e8e0e..9f3f37cf7176 100644 --- a/pkg/util/parquet/decoders.go +++ b/pkg/util/parquet/decoders.go @@ -17,6 +17,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/geo" "github.com/cockroachdb/cockroach/pkg/geo/geopb" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/bitarray" "github.com/cockroachdb/cockroach/pkg/util/duration" "github.com/cockroachdb/cockroach/pkg/util/timeofday" @@ -236,6 +237,76 @@ func (collatedStringDecoder) decode(v parquet.ByteArray) (tree.Datum, error) { return &tree.DCollatedString{Contents: string(v)}, nil } +// decoderFromFamilyAndType returns the decoder to use based on the type oid and +// family. Note the logical similarity to makeColumn in schema.go. This is +// intentional as each decoder returned by this function corresponds to a +// particular colWriter determined by makeColumn. +// TODO: refactor to remove the code duplication with makeColumn +func decoderFromFamilyAndType(typOid oid.Oid, family types.Family) (decoder, error) { + switch family { + case types.BoolFamily: + return boolDecoder{}, nil + case types.StringFamily: + return stringDecoder{}, nil + case types.IntFamily: + typ, ok := types.OidToType[typOid] + if !ok { + return nil, errors.AssertionFailedf("could not determine type from oid %d", typOid) + } + if typ.Oid() == oid.T_int8 { + return int64Decoder{}, nil + } + return int32Decoder{}, nil + case types.DecimalFamily: + return decimalDecoder{}, nil + case types.TimestampFamily: + return timestampDecoder{}, nil + case types.TimestampTZFamily: + return timestampTZDecoder{}, nil + case types.UuidFamily: + return uUIDDecoder{}, nil + case types.INetFamily: + return iNetDecoder{}, nil + case types.JsonFamily: + return jsonDecoder{}, nil + case types.BitFamily: + return bitDecoder{}, nil + case types.BytesFamily: + return bytesDecoder{}, nil + case types.EnumFamily: + return enumDecoder{}, nil + case types.DateFamily: + return dateDecoder{}, nil + case types.Box2DFamily: + return box2DDecoder{}, nil + case types.GeographyFamily: + return geographyDecoder{}, nil + case types.GeometryFamily: + return geometryDecoder{}, nil + case types.IntervalFamily: + return intervalDecoder{}, nil + case types.TimeFamily: + return timeDecoder{}, nil + case types.TimeTZFamily: + return timeTZDecoder{}, nil + case types.FloatFamily: + typ, ok := types.OidToType[typOid] + if !ok { + return nil, errors.AssertionFailedf("could not determine type from oid %d", typOid) + } + if typ.Oid() == oid.T_float4 { + return float32Decoder{}, nil + } + return float64Decoder{}, nil + case types.OidFamily: + return oidDecoder{}, nil + case types.CollatedStringFamily: + return collatedStringDecoder{}, nil + default: + return nil, errors.AssertionFailedf("could not find decoder for type oid %d and family %d", typOid, family) + } +} + // Defeat the linter's unused lint errors. func init() { var _, _ = boolDecoder{}.decode(false) @@ -253,7 +324,6 @@ func init() { var _, _ = enumDecoder{}.decode(parquet.ByteArray{}) var _, _ = dateDecoder{}.decode(parquet.ByteArray{}) var _, _ = box2DDecoder{}.decode(parquet.ByteArray{}) - var _, _ = box2DDecoder{}.decode(parquet.ByteArray{}) var _, _ = geographyDecoder{}.decode(parquet.ByteArray{}) var _, _ = geometryDecoder{}.decode(parquet.ByteArray{}) var _, _ = intervalDecoder{}.decode(parquet.ByteArray{}) diff --git a/pkg/util/parquet/schema.go b/pkg/util/parquet/schema.go index e68c1560a25e..ea8caf10fd07 100644 --- a/pkg/util/parquet/schema.go +++ b/pkg/util/parquet/schema.go @@ -42,7 +42,6 @@ const defaultTypeLength = -1 type column struct { node schema.Node colWriter colWriter - decoder decoder typ *types.T } @@ -99,7 +98,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c case types.BoolFamily: result.node = schema.NewBooleanNode(colName, repetitions, defaultSchemaFieldID) result.colWriter = scalarWriter(writeBool) - result.decoder = boolDecoder{} result.typ = types.Bool return result, nil case types.StringFamily: @@ -111,7 +109,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeString) - result.decoder = stringDecoder{} return result, nil case types.IntFamily: // Note: integer datums are always signed: https://www.cockroachlabs.com/docs/stable/int.html @@ -124,13 +121,11 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeInt64) - result.decoder = int64Decoder{} return result, nil } result.node = schema.NewInt32Node(colName, repetitions, defaultSchemaFieldID) result.colWriter = scalarWriter(writeInt32) - result.decoder = int32Decoder{} return result, nil case types.DecimalFamily: // According to PostgresSQL docs, scale or precision of 0 implies max @@ -155,7 +150,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeDecimal) - result.decoder = decimalDecoder{} return result, nil case types.UuidFamily: result.node, err = schema.NewPrimitiveNodeLogical(colName, @@ -165,7 +159,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeUUID) - result.decoder = uUIDDecoder{} return result, nil case types.TimestampFamily: // We do not use schema.TimestampLogicalType because the library will enforce @@ -177,7 +170,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeTimestamp) - result.decoder = timestampDecoder{} return result, nil case types.TimestampTZFamily: // We do not use schema.TimestampLogicalType because the library will enforce @@ -189,7 +181,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeTimestampTZ) - result.decoder = timestampTZDecoder{} return result, nil case types.INetFamily: result.node, err = schema.NewPrimitiveNodeLogical(colName, @@ -199,7 +190,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeINet) - result.decoder = iNetDecoder{} return result, nil case types.JsonFamily: result.node, err = schema.NewPrimitiveNodeLogical(colName, @@ -209,7 +199,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeJSON) - result.decoder = jsonDecoder{} return result, nil case types.BitFamily: result.node, err = schema.NewPrimitiveNode(colName, @@ -219,7 +208,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeBit) - result.decoder = bitDecoder{} return result, nil case types.BytesFamily: result.node, err = schema.NewPrimitiveNode(colName, @@ -229,7 +217,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeBytes) - result.decoder = bytesDecoder{} return result, nil case types.EnumFamily: result.node, err = schema.NewPrimitiveNodeLogical(colName, @@ -239,7 +226,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeEnum) - result.decoder = enumDecoder{} return result, nil case types.DateFamily: // We do not use schema.DateLogicalType because the library will enforce @@ -251,7 +237,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeDate) - result.decoder = dateDecoder{} return result, nil case types.Box2DFamily: result.node, err = schema.NewPrimitiveNodeLogical(colName, @@ -261,7 +246,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeBox2D) - result.decoder = box2DDecoder{} return result, nil case types.GeographyFamily: result.node, err = schema.NewPrimitiveNode(colName, @@ -271,7 +255,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeGeography) - result.decoder = geographyDecoder{} return result, nil case types.GeometryFamily: result.node, err = schema.NewPrimitiveNode(colName, @@ -281,7 +264,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeGeometry) - result.decoder = geometryDecoder{} return result, nil case types.IntervalFamily: result.node, err = schema.NewPrimitiveNodeLogical(colName, @@ -291,7 +273,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeInterval) - result.decoder = intervalDecoder{} return result, nil case types.TimeFamily: // CRDB stores time datums in microseconds, adjusted to UTC. @@ -303,7 +284,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeTime) - result.decoder = timeDecoder{} return result, nil case types.TimeTZFamily: // We cannot use the schema.NewTimeLogicalType because it does not support @@ -315,7 +295,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeTimeTZ) - result.decoder = timeTZDecoder{} return result, nil case types.FloatFamily: if typ.Oid() == oid.T_float4 { @@ -326,7 +305,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeFloat32) - result.decoder = float32Decoder{} return result, nil } result.node, err = schema.NewPrimitiveNode(colName, @@ -336,12 +314,10 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeFloat64) - result.decoder = float64Decoder{} return result, nil case types.OidFamily: result.node = schema.NewInt32Node(colName, repetitions, defaultSchemaFieldID) result.colWriter = scalarWriter(writeOid) - result.decoder = oidDecoder{} return result, nil case types.CollatedStringFamily: result.node, err = schema.NewPrimitiveNodeLogical(colName, @@ -351,7 +327,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c return result, err } result.colWriter = scalarWriter(writeCollatedString) - result.decoder = collatedStringDecoder{} return result, nil case types.ArrayFamily: // Arrays for type T are represented by the following: @@ -383,7 +358,6 @@ func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (c if err != nil { return result, err } - result.decoder = elementCol.decoder scalarColWriter, ok := elementCol.colWriter.(scalarWriter) if !ok { return result, errors.AssertionFailedf("expected scalar column writer") diff --git a/pkg/util/parquet/testutils.go b/pkg/util/parquet/testutils.go index e87e5652c19a..9123ce8c7919 100644 --- a/pkg/util/parquet/testutils.go +++ b/pkg/util/parquet/testutils.go @@ -11,24 +11,37 @@ package parquet import ( + "fmt" + "io" "math" "os" + "strconv" + "strings" "testing" "github.com/apache/arrow/go/v11/parquet" "github.com/apache/arrow/go/v11/parquet/file" + "github.com/apache/arrow/go/v11/parquet/metadata" "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/encoding" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" + "github.com/lib/pq/oid" "github.com/stretchr/testify/require" ) -// ReadFileAndVerifyDatums reads the parquetFile and first asserts its metadata -// matches the metadata from the writer. Then, it reads the file and asserts -// that it's data matches writtenDatums. +// NewWriterWithReaderMeta constructs a Writer that writes additional metadata +// required to use the reader utility functions below. +func NewWriterWithReaderMeta( + sch *SchemaDefinition, sink io.Writer, opts ...Option, +) (*Writer, error) { + opts = append(opts, WithMetadata(MakeReaderMetadata(sch))) + return NewWriter(sch, sink, opts...) +} + +// ReadFileAndVerifyDatums asserts that a parquet file's metadata matches the +// metadata from the writer and its data matches writtenDatums. func ReadFileAndVerifyDatums( t *testing.T, parquetFile string, @@ -37,34 +50,91 @@ func ReadFileAndVerifyDatums( writer *Writer, writtenDatums [][]tree.Datum, ) { - f, err := os.Open(parquetFile) - require.NoError(t, err) + meta, readDatums, closeReader, err := ReadFile(parquetFile) + defer func() { + require.NoError(t, closeReader()) + }() - reader, err := file.NewParquetReader(f) require.NoError(t, err) - - assert.Equal(t, reader.NumRows(), int64(numRows)) - assert.Equal(t, reader.MetaData().Schema.NumColumns(), numCols) + require.Equal(t, int64(numRows), meta.GetNumRows()) + require.Equal(t, numCols, meta.Schema.NumColumns()) numRowGroups := int(math.Ceil(float64(numRows) / float64(writer.cfg.maxRowGroupLength))) - assert.EqualValues(t, numRowGroups, reader.NumRowGroups()) + require.EqualValues(t, numRowGroups, len(meta.GetRowGroups())) - readDatums := make([][]tree.Datum, numRows) for i := 0; i < numRows; i++ { - readDatums[i] = make([]tree.Datum, numCols) + for j := 0; j < numCols; j++ { + ValidateDatum(t, writtenDatums[i][j], readDatums[i][j]) + } + } +} + +// ReadFile reads a parquet file and returns the contained metadata and datums. +// +// To use this function, the Writer must be configured to write CRDB-specific +// metadata for the reader. See NewWriterWithReaderMeta. +// +// close() should be called to release the underlying reader once the returned +// metadata is not required anymore. The returned metadata should not be used +// after close() is called. +// +// NB: The returned datums may not be hydrated or identical to the ones +// which were written. See comment on ValidateDatum for more info. +func ReadFile( + parquetFile string, +) (meta *metadata.FileMetaData, datums [][]tree.Datum, close func() error, err error) { + f, err := os.Open(parquetFile) + if err != nil { + return nil, nil, nil, err + } + + reader, err := file.NewParquetReader(f) + if err != nil { + return nil, nil, nil, err + } + + var readDatums [][]tree.Datum + + typFamiliesMeta := reader.MetaData().KeyValueMetadata().FindValue(typeFamilyMetaKey) + if typFamiliesMeta == nil { + return nil, nil, nil, + errors.AssertionFailedf("missing type family metadata. ensure the writer is configured" + + " to write reader metadata. see NewWriterWithReaderMeta()") + } + typFamilies, err := deserializeIntArray(*typFamiliesMeta) + if err != nil { + return nil, nil, nil, err + } + + typOidsMeta := reader.MetaData().KeyValueMetadata().FindValue(typeOidMetaKey) + if typOidsMeta == nil { + return nil, nil, nil, + errors.AssertionFailedf("missing type oid metadata. ensure the writer is configured" + + " to write reader metadata. see NewWriterWithReaderMeta()") + } + typOids, err := deserializeIntArray(*typOidsMeta) + if err != nil { + return nil, nil, nil, err } startingRowIdx := 0 for rg := 0; rg < reader.NumRowGroups(); rg++ { rgr := reader.RowGroup(rg) rowsInRowGroup := rgr.NumRows() + for i := int64(0); i < rowsInRowGroup; i++ { + readDatums = append(readDatums, make([]tree.Datum, rgr.NumColumns())) + } - for colIdx := 0; colIdx < numCols; colIdx++ { + for colIdx := 0; colIdx < rgr.NumColumns(); colIdx++ { col, err := rgr.Column(colIdx) - require.NoError(t, err) + if err != nil { + return nil, nil, nil, err + } - dec := writer.sch.cols[colIdx].decoder - typ := writer.sch.cols[colIdx].typ + dec, err := decoderFromFamilyAndType(oid.Oid(typOids[colIdx]), types.Family(typFamilies[colIdx])) + if err != nil { + return nil, nil, nil, err + } // Based on how we define schemas, we can detect an array by seeing if the // primitive col reader has a max repetition level of 1. See comments above @@ -73,53 +143,85 @@ func ReadFileAndVerifyDatums( switch col.Type() { case parquet.Types.Boolean: - colDatums, read, err := readBatch(col, make([]bool, 1), dec, typ, isArray) - require.NoError(t, err) - require.Equal(t, rowsInRowGroup, read) + colDatums, read, err := readBatch(col, make([]bool, 1), dec, isArray) + if err != nil { + return nil, nil, nil, err + } + if read != rowsInRowGroup { + return nil, nil, nil, + errors.AssertionFailedf("expected to read %d values but found %d", rowsInRowGroup, read) + } decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx) case parquet.Types.Int32: - colDatums, read, err := readBatch(col, make([]int32, 1), dec, typ, isArray) - require.NoError(t, err) - require.Equal(t, rowsInRowGroup, read) + colDatums, read, err := readBatch(col, make([]int32, 1), dec, isArray) + if err != nil { + return nil, nil, nil, err + } + if read != rowsInRowGroup { + return nil, nil, nil, + errors.AssertionFailedf("expected to read %d values but found %d", rowsInRowGroup, read) + } decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx) case parquet.Types.Int64: - colDatums, read, err := readBatch(col, make([]int64, 1), dec, typ, isArray) - require.NoError(t, err) - require.Equal(t, rowsInRowGroup, read) + colDatums, read, err := readBatch(col, make([]int64, 1), dec, isArray) + if err != nil { + return nil, nil, nil, err + } + if read != rowsInRowGroup { + return nil, nil, nil, + errors.AssertionFailedf("expected to read %d values but found %d", rowsInRowGroup, read) + } decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx) case parquet.Types.Int96: panic("unimplemented") case parquet.Types.Float: - arrs, read, err := readBatch(col, make([]float32, 1), dec, typ, isArray) - require.NoError(t, err) - require.Equal(t, rowsInRowGroup, read) + arrs, read, err := readBatch(col, make([]float32, 1), dec, isArray) + if err != nil { + return nil, nil, nil, err + } + if read != rowsInRowGroup { + return nil, nil, nil, + errors.AssertionFailedf("expected to read %d values but found %d", rowsInRowGroup, read) + } decodeValuesIntoDatumsHelper(arrs, readDatums, colIdx, startingRowIdx) case parquet.Types.Double: - arrs, read, err := readBatch(col, make([]float64, 1), dec, typ, isArray) - require.NoError(t, err) - require.Equal(t, rowsInRowGroup, read) + arrs, read, err := readBatch(col, make([]float64, 1), dec, isArray) + if err != nil { + return nil, nil, nil, err + } + if read != rowsInRowGroup { + return nil, nil, nil, + errors.AssertionFailedf("expected to read %d values but found %d", rowsInRowGroup, read) + } decodeValuesIntoDatumsHelper(arrs, readDatums, colIdx, startingRowIdx) case parquet.Types.ByteArray: - colDatums, read, err := readBatch(col, make([]parquet.ByteArray, 1), dec, typ, isArray) - require.NoError(t, err) - require.Equal(t, rowsInRowGroup, read) + colDatums, read, err := readBatch(col, make([]parquet.ByteArray, 1), dec, isArray) + if err != nil { + return nil, nil, nil, err + } + if read != rowsInRowGroup { + return nil, nil, nil, + errors.AssertionFailedf("expected to read %d values but found %d", rowsInRowGroup, read) + } decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx) case parquet.Types.FixedLenByteArray: - colDatums, read, err := readBatch(col, make([]parquet.FixedLenByteArray, 1), dec, typ, isArray) - require.NoError(t, err) - require.Equal(t, rowsInRowGroup, read) + colDatums, read, err := readBatch(col, make([]parquet.FixedLenByteArray, 1), dec, isArray) + if err != nil { + return nil, nil, nil, err + } + if read != rowsInRowGroup { + return nil, nil, nil, + errors.AssertionFailedf("expected to read %d values but found %d", rowsInRowGroup, read) + } decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx) } } startingRowIdx += int(rowsInRowGroup) } - require.NoError(t, reader.Close()) - - for i := 0; i < numRows; i++ { - for j := 0; j < numCols; j++ { - validateDatum(t, writtenDatums[i][j], readDatums[i][j]) - } - } + // Since reader.MetaData() is being returned, we do not close the reader. + // This is defensive - we should not assume any method or data on the reader + // is safe to read once it is closed. + return reader.MetaData(), readDatums, reader.Close, nil } type batchReader[T parquetDatatypes] interface { @@ -128,7 +230,7 @@ type batchReader[T parquetDatatypes] interface { // readBatch reads all the datums in a row group for a column. func readBatch[T parquetDatatypes]( - r file.ColumnChunkReader, valueAlloc []T, dec decoder, typ *types.T, isArray bool, + r file.ColumnChunkReader, valueAlloc []T, dec decoder, isArray bool, ) (tree.Datums, int64, error) { br, ok := r.(batchReader[T]) if !ok { @@ -156,7 +258,7 @@ func readBatch[T parquetDatatypes]( result = append(result, tree.DNull) continue } - arrDatum := tree.NewDArray(typ) + arrDatum := &tree.DArray{} result = append(result, arrDatum) // Replevel 0, Deflevel 1 represents an array which is empty. if defLevels[0] == 1 { @@ -209,15 +311,16 @@ func unwrapDatum(d tree.Datum) tree.Datum { } } -// validateDatum validates that the "contents" of the expected datum matches the +// ValidateDatum validates that the "contents" of the expected datum matches the // contents of the actual datum. For example, when validating two arrays, we // only compare the datums in the arrays. We do not compare CRDB-specific // metadata fields such as (tree.DArray).HasNulls or (tree.DArray).HasNonNulls. // // The reason for this special comparison that the parquet format is presently // used in an export use case, so we only need to test roundtripability with -// data end users see. We do not need to check roundtripability of internal CRDB data. -func validateDatum(t *testing.T, expected tree.Datum, actual tree.Datum) { +// data end users see. We do not need to check roundtripability of internal CRDB +// data. +func ValidateDatum(t *testing.T, expected tree.Datum, actual tree.Datum) { // The randgen library may generate datums wrapped in a *tree.DOidWrapper, so // we should unwrap them. We unwrap at this stage as opposed to when // generating datums to test that the writer can handle wrapped datums. @@ -246,7 +349,7 @@ func validateDatum(t *testing.T, expected tree.Datum, actual tree.Datum) { arr2 := actual.(*tree.DArray).Array require.Equal(t, len(arr1), len(arr2)) for i := 0; i < len(arr1); i++ { - validateDatum(t, arr1[i], arr2[i]) + ValidateDatum(t, arr1[i], arr2[i]) } case types.EnumFamily: require.Equal(t, expected.(*tree.DEnum).LogicalRep, actual.(*tree.DEnum).LogicalRep) @@ -276,3 +379,42 @@ func makeTestingEnumType() *types.T { } return enumType } + +const typeOidMetaKey = `crdbTypeOIDs` +const typeFamilyMetaKey = `crdbTypeFamilies` + +// MakeReaderMetadata returns column type metadata that will be written to all +// parquet files. This metadata is useful for roundtrip tests where we construct +// CRDB datums from the raw data in written to parquet files. +func MakeReaderMetadata(sch *SchemaDefinition) map[string]string { + meta := map[string]string{} + typOids := make([]uint32, 0, len(sch.cols)) + typFamilies := make([]int32, 0, len(sch.cols)) + for _, col := range sch.cols { + typOids = append(typOids, uint32(col.typ.Oid())) + typFamilies = append(typFamilies, int32(col.typ.Family())) + } + meta[typeOidMetaKey] = serializeIntArray(typOids) + meta[typeFamilyMetaKey] = serializeIntArray(typFamilies) + return meta +} + +// serializeIntArray serializes an int array to a string "23 2 32 43 32". +func serializeIntArray[I int32 | uint32](ints []I) string { + return strings.Trim(fmt.Sprint(ints), "[]") +} + +// deserializeIntArray deserializes an integer sting in the format "23 2 32 43 +// 32" to an array of ints. +func deserializeIntArray(s string) ([]uint32, error) { + vals := strings.Split(s, " ") + result := make([]uint32, 0, len(vals)) + for _, val := range vals { + intVal, err := strconv.Atoi(val) + if err != nil { + return nil, err + } + result = append(result, uint32(intVal)) + } + return result, nil +} diff --git a/pkg/util/parquet/writer_test.go b/pkg/util/parquet/writer_test.go index a745700bda29..f305d72b1e50 100644 --- a/pkg/util/parquet/writer_test.go +++ b/pkg/util/parquet/writer_test.go @@ -121,7 +121,7 @@ func TestRandomDatums(t *testing.T) { schemaDef, err := NewSchema(sch.columnNames, sch.columnTypes) require.NoError(t, err) - writer, err := NewWriter(schemaDef, f, WithMaxRowGroupLength(maxRowGroupSize)) + writer, err := NewWriterWithReaderMeta(schemaDef, f, WithMaxRowGroupLength(maxRowGroupSize)) require.NoError(t, err) for _, row := range datums { @@ -135,6 +135,8 @@ func TestRandomDatums(t *testing.T) { ReadFileAndVerifyDatums(t, f.Name(), numRows, numCols, writer, datums) } +// TestBasicDatums tests roundtripability for all supported scalar data types +// and one simple array type. func TestBasicDatums(t *testing.T) { for _, tc := range []struct { name string @@ -468,7 +470,7 @@ func TestBasicDatums(t *testing.T) { schemaDef, err := NewSchema(tc.sch.columnNames, tc.sch.columnTypes) require.NoError(t, err) - writer, err := NewWriter(schemaDef, f, WithMaxRowGroupLength(maxRowGroupSize)) + writer, err := NewWriterWithReaderMeta(schemaDef, f, WithMaxRowGroupLength(maxRowGroupSize)) require.NoError(t, err) for _, row := range datums {