Skip to content

Commit

Permalink
util/parquet: add support for arrays
Browse files Browse the repository at this point in the history
This change extends and refactors the util/parquet library to
be able to read and write arrays.

Release note: None

Informs: #99028
Epic: https://cockroachlabs.atlassian.net/browse/CRDB-15071
  • Loading branch information
jayshrivastava committed Apr 20, 2023
1 parent 3df495f commit 3cbca9e
Show file tree
Hide file tree
Showing 6 changed files with 374 additions and 123 deletions.
84 changes: 58 additions & 26 deletions pkg/util/parquet/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ import (
"github.com/lib/pq/oid"
)

// Setting parquet.Repetitions.Optional makes a parquet column nullable. When
// writing a datum, we will always specify a definition level to indicate if the
// datum is null or not. See comments on nonNilDefLevel or nilDefLevel for more info.
var defaultRepetitions = parquet.Repetitions.Optional

// A schema field is an internal identifier for schema nodes used by the parquet library.
// A value of -1 will let the library auto-assign values. This does not affect reading
// or writing parquet files.
Expand All @@ -36,7 +41,7 @@ const defaultTypeLength = -1
// A column stores column metadata.
type column struct {
node schema.Node
colWriter writeFn
colWriter colWriter
decoder decoder
typ *types.T
}
Expand Down Expand Up @@ -67,7 +72,7 @@ func NewSchema(columnNames []string, columnTypes []*types.T) (*SchemaDefinition,
fields := make([]schema.Node, 0)

for i := 0; i < len(columnNames); i++ {
parquetCol, err := makeColumn(columnNames[i], columnTypes[i])
parquetCol, err := makeColumn(columnNames[i], columnTypes[i], defaultRepetitions)
if err != nil {
return nil, err
}
Expand All @@ -87,50 +92,44 @@ func NewSchema(columnNames []string, columnTypes []*types.T) (*SchemaDefinition,
}

// makeColumn constructs a column.
func makeColumn(colName string, typ *types.T) (column, error) {
// Setting parquet.Repetitions.Optional makes parquet interpret all columns as nullable.
// When writing data, we will specify a definition level of 0 (null) or 1 (not null).
// See https://blog.twitter.com/engineering/en_us/a/2013/dremel-made-simple-with-parquet
// for more information regarding definition levels.
defaultRepetitions := parquet.Repetitions.Optional

func makeColumn(colName string, typ *types.T, repetitions parquet.Repetition) (column, error) {
result := column{typ: typ}
var err error
switch typ.Family() {
case types.BoolFamily:
result.node = schema.NewBooleanNode(colName, defaultRepetitions, defaultSchemaFieldID)
result.colWriter = writeBool
result.node = schema.NewBooleanNode(colName, repetitions, defaultSchemaFieldID)
result.colWriter = scalarWriter(writeBool)
result.decoder = boolDecoder{}
result.typ = types.Bool
return result, nil
case types.StringFamily:
result.node, err = schema.NewPrimitiveNodeLogical(colName,
defaultRepetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
defaultTypeLength, defaultSchemaFieldID)

if err != nil {
return result, err
}
result.colWriter = writeString
result.colWriter = scalarWriter(writeString)
result.decoder = stringDecoder{}
return result, nil
case types.IntFamily:
// Note: integer datums are always signed: https://www.cockroachlabs.com/docs/stable/int.html
if typ.Oid() == oid.T_int8 {
result.node, err = schema.NewPrimitiveNodeLogical(colName,
defaultRepetitions, schema.NewIntLogicalType(64, true),
repetitions, schema.NewIntLogicalType(64, true),
parquet.Types.Int64, defaultTypeLength,
defaultSchemaFieldID)
if err != nil {
return result, err
}
result.colWriter = writeInt64
result.colWriter = scalarWriter(writeInt64)
result.decoder = int64Decoder{}
return result, nil
}

result.node = schema.NewInt32Node(colName, defaultRepetitions, defaultSchemaFieldID)
result.colWriter = writeInt32
result.node = schema.NewInt32Node(colName, repetitions, defaultSchemaFieldID)
result.colWriter = scalarWriter(writeInt32)
result.decoder = int32Decoder{}
return result, nil
case types.DecimalFamily:
Expand All @@ -149,37 +148,71 @@ func makeColumn(colName string, typ *types.T) (column, error) {
}

result.node, err = schema.NewPrimitiveNodeLogical(colName,
defaultRepetitions, schema.NewDecimalLogicalType(precision,
repetitions, schema.NewDecimalLogicalType(precision,
scale), parquet.Types.ByteArray, defaultTypeLength,
defaultSchemaFieldID)
if err != nil {
return result, err
}
result.colWriter = writeDecimal
result.colWriter = scalarWriter(writeDecimal)
result.decoder = decimalDecoder{}
return result, nil
case types.UuidFamily:
result.node, err = schema.NewPrimitiveNodeLogical(colName,
defaultRepetitions, schema.UUIDLogicalType{},
repetitions, schema.UUIDLogicalType{},
parquet.Types.FixedLenByteArray, uuid.Size, defaultSchemaFieldID)
if err != nil {
return result, err
}
result.colWriter = writeUUID
result.colWriter = scalarWriter(writeUUID)
result.decoder = uUIDDecoder{}
return result, nil
case types.TimestampFamily:
// Note that all timestamp datums are in UTC: https://www.cockroachlabs.com/docs/stable/timestamp.html
result.node, err = schema.NewPrimitiveNodeLogical(colName,
defaultRepetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
repetitions, schema.StringLogicalType{}, parquet.Types.ByteArray,
defaultTypeLength, defaultSchemaFieldID)
if err != nil {
return result, err
}

result.colWriter = writeTimestamp
result.colWriter = scalarWriter(writeTimestamp)
result.decoder = timestampDecoder{}
return result, nil
case types.ArrayFamily:
// Arrays for type T are represented by the following:
// message schema { -- toplevel schema
// optional group a (LIST) { -- list column
// repeated group list {
// optional T element;
// }
// }
// }
// Representing arrays this way makes it easier to differentiate NULL, [NULL],
// and [] when encoding.
// There is more info about encoding arrays here:
// https://arrow.apache.org/blog/2022/10/08/arrow-parquet-encoding-part-2/
elementCol, err := makeColumn("element", typ.ArrayContents(), parquet.Repetitions.Optional)
if err != nil {
return result, err
}
innerListFields := []schema.Node{elementCol.node}
innerListNode, err := schema.NewGroupNode("list", parquet.Repetitions.Repeated, innerListFields, defaultSchemaFieldID)
if err != nil {
return result, err
}
outerListFields := []schema.Node{innerListNode}
result.node, err = schema.NewGroupNodeLogical(colName, parquet.Repetitions.Optional, outerListFields, schema.ListLogicalType{}, defaultSchemaFieldID)
if err != nil {
return result, err
}
result.decoder = elementCol.decoder
scalarColWriter, ok := elementCol.colWriter.(scalarWriter)
if !ok {
return result, errors.AssertionFailedf("expected scalar column writer")
}
result.colWriter = arrayWriter(scalarColWriter)
result.typ = elementCol.typ
return result, nil

// TODO(#99028): implement support for the remaining types.
// case types.INetFamily:
Expand All @@ -196,8 +229,7 @@ func makeColumn(colName string, typ *types.T) (column, error) {
// case types.TimeTZFamily:
// case types.IntervalFamily:
// case types.TimestampTZFamily:
// case types.ArrayFamily:
default:
return result, pgerror.Newf(pgcode.FeatureNotSupported, "parquet export does not support the %v type yet", typ.Family())
return result, pgerror.Newf(pgcode.FeatureNotSupported, "parquet export does not support the %v type", typ.Family())
}
}
162 changes: 113 additions & 49 deletions pkg/util/parquet/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/apache/arrow/go/v11/parquet"
"github.com/apache/arrow/go/v11/parquet/file"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -55,41 +56,51 @@ func ReadFileAndVerifyDatums(
for rg := 0; rg < reader.NumRowGroups(); rg++ {
rgr := reader.RowGroup(rg)
rowsInRowGroup := rgr.NumRows()
defLevels := make([]int16, rowsInRowGroup)

for colIdx := 0; colIdx < numCols; colIdx++ {
col, err := rgr.Column(colIdx)
require.NoError(t, err)

dec := writer.sch.cols[colIdx].decoder
typ := writer.sch.cols[colIdx].typ

// Based on how we define schemas, we can detect an array by seeing if the
// primitive col reader has a max repetition level of 1. See comments above
// arrayEntryRepLevel for more info.
isArray := col.Descriptor().MaxRepetitionLevel() == 1

switch col.Type() {
case parquet.Types.Boolean:
values := make([]bool, rowsInRowGroup)
readBatchHelper(t, col, rowsInRowGroup, values, defLevels)
decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels)
colDatums, read, err := readBatch(col, make([]bool, 1), dec, typ, isArray)
require.NoError(t, err)
require.Equal(t, rowsInRowGroup, read)
decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx)
case parquet.Types.Int32:
values := make([]int32, numRows)
readBatchHelper(t, col, rowsInRowGroup, values, defLevels)
decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels)
colDatums, read, err := readBatch(col, make([]int32, 1), dec, typ, isArray)
require.NoError(t, err)
require.Equal(t, rowsInRowGroup, read)
decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx)
case parquet.Types.Int64:
values := make([]int64, rowsInRowGroup)
readBatchHelper(t, col, rowsInRowGroup, values, defLevels)
decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels)
colDatums, read, err := readBatch(col, make([]int64, 1), dec, typ, isArray)
require.NoError(t, err)
require.Equal(t, rowsInRowGroup, read)
decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx)
case parquet.Types.Int96:
panic("unimplemented")
case parquet.Types.Float:
panic("unimplemented")
case parquet.Types.Double:
panic("unimplemented")
case parquet.Types.ByteArray:
values := make([]parquet.ByteArray, rowsInRowGroup)
readBatchHelper(t, col, rowsInRowGroup, values, defLevels)
decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels)
colDatums, read, err := readBatch(col, make([]parquet.ByteArray, 1), dec, typ, isArray)
require.NoError(t, err)
require.Equal(t, rowsInRowGroup, read)
decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx)
case parquet.Types.FixedLenByteArray:
values := make([]parquet.FixedLenByteArray, rowsInRowGroup)
readBatchHelper(t, col, rowsInRowGroup, values, defLevels)
decodeValuesIntoDatumsHelper(t, readDatums, colIdx, startingRowIdx, dec, values, defLevels)
colDatums, read, err := readBatch(col, make([]parquet.FixedLenByteArray, 1), dec, typ, isArray)
require.NoError(t, err)
require.Equal(t, rowsInRowGroup, read)
decodeValuesIntoDatumsHelper(colDatums, readDatums, colIdx, startingRowIdx)
}
}
startingRowIdx += int(rowsInRowGroup)
Expand All @@ -98,54 +109,107 @@ func ReadFileAndVerifyDatums(

for i := 0; i < numRows; i++ {
for j := 0; j < numCols; j++ {
assert.Equal(t, writtenDatums[i][j], readDatums[i][j])
validateDatum(t, writtenDatums[i][j], readDatums[i][j])
}
}
}

func readBatchHelper[T parquetDatatypes](
t *testing.T, r file.ColumnChunkReader, expectedrowsInRowGroup int64, values []T, defLvls []int16,
) {
numRead, err := readBatch(r, expectedrowsInRowGroup, values, defLvls)
require.NoError(t, err)
require.Equal(t, numRead, expectedrowsInRowGroup)
}

type batchReader[T parquetDatatypes] interface {
ReadBatch(batchSize int64, values []T, defLvls []int16, repLvls []int16) (total int64, valuesRead int, err error)
}

// readBatch reads all the datums in a row group for a column.
func readBatch[T parquetDatatypes](
r file.ColumnChunkReader, batchSize int64, values []T, defLvls []int16,
) (int64, error) {
r file.ColumnChunkReader, valueAlloc []T, dec decoder, typ *types.T, isArray bool,
) (tree.Datums, int64, error) {
br, ok := r.(batchReader[T])
if !ok {
return 0, errors.AssertionFailedf("expected batchReader of type %T, but found %T instead", values, r)
return nil, 0, errors.AssertionFailedf("expected batchReader for type %T, but found %T instead", valueAlloc, r)
}
numRowsRead, _, err := br.ReadBatch(batchSize, values, defLvls, nil)
return numRowsRead, err

result := make([]tree.Datum, 0)
defLevels := [1]int16{}
repLevels := [1]int16{}

for {
numRowsRead, _, err := br.ReadBatch(1, valueAlloc, defLevels[:], repLevels[:])
if err != nil {
return nil, 0, err
}
if numRowsRead == 0 {
break
}

if isArray {
// Replevel 0 indicates the start of a new array.
if repLevels[0] == 0 {
// Replevel 0, Deflevel 0 represents a NULL array.
if defLevels[0] == 0 {
result = append(result, tree.DNull)
continue
}
arrDatum := tree.NewDArray(typ)
result = append(result, arrDatum)
// Replevel 0, Deflevel 1 represents an array which is empty.
if defLevels[0] == 1 {
continue
}
}
currentArrayDatum := result[len(result)-1].(*tree.DArray)
// Deflevel 2 represents a null value in an array.
if defLevels[0] == 2 {
currentArrayDatum.Array = append(currentArrayDatum.Array, tree.DNull)
continue
}
// Deflevel 3 represents a non-null datum in an array.
d, err := decode(dec, valueAlloc[0])
if err != nil {
return nil, 0, err
}
currentArrayDatum.Array = append(currentArrayDatum.Array, d)
} else {
// Deflevel 0 represents a null value.
// Deflevel 1 represents a non-null value.
d := tree.DNull
if defLevels[0] != 0 {
d, err = decode(dec, valueAlloc[0])
if err != nil {
return nil, 0, err
}
}
result = append(result, d)
}
}

return result, int64(len(result)), nil
}

func decodeValuesIntoDatumsHelper[T parquetDatatypes](
t *testing.T,
datums [][]tree.Datum,
colIdx int,
startingRowIdx int,
dec decoder,
values []T,
defLevels []int16,
func decodeValuesIntoDatumsHelper(
colDatums []tree.Datum, datumRows [][]tree.Datum, colIdx int, startingRowIdx int,
) {
var err error
// If the defLevel of a datum is 0, parquet will not write it to the column.
// Use valueReadIdx to only read from the front of the values array, where datums are defined.
valueReadIdx := 0
for rowOffset, defLevel := range defLevels {
d := tree.DNull
if defLevel != 0 {
d, err = decode(dec, values[valueReadIdx])
require.NoError(t, err)
valueReadIdx++
for rowOffset, datum := range colDatums {
datumRows[startingRowIdx+rowOffset][colIdx] = datum
}
}

// validateDatum validates that the "contents" of the expected datum matches the
// contents of the actual datum. For example, when validating two arrays, we
// only compare the datums in the arrays. We do not compare CRDB-specific
// metadata fields such as (tree.DArray).HasNulls or (tree.DArray).HasNonNulls.
//
// The reason for this special comparison is that the parquet format is presently
// used in an export use case, so we only need to test roundtripability with
// data end users see. We do not need to check roundtripability of internal CRDB data.
func validateDatum(t *testing.T, expected tree.Datum, actual tree.Datum) {
switch expected.ResolvedType().Family() {
case types.ArrayFamily:
arr1 := expected.(*tree.DArray).Array
arr2 := actual.(*tree.DArray).Array
require.Equal(t, len(arr1), len(arr2))
for i := 0; i < len(arr1); i++ {
validateDatum(t, arr1[i], arr2[i])
}
datums[startingRowIdx+rowOffset][colIdx] = d
default:
require.Equal(t, expected, actual)
}
}
Loading

0 comments on commit 3cbca9e

Please sign in to comment.