From bc9562feb19831bb98cecebdb848a9f6cece1f1d Mon Sep 17 00:00:00 2001
From: Ted Xu
Date: Fri, 8 Nov 2024 08:30:57 +0800
Subject: [PATCH] enhance: avoid memory copy and serde in mix compaction
 (#37479)

See: #37234

---------

Signed-off-by: Ted Xu
---
 internal/datanode/compaction/mix_compactor.go |  81 +++++++++---
 .../datanode/compaction/segment_writer.go     |  74 ++++++++++-
 internal/storage/serde.go                     | 106 +++++++++++++--
 internal/storage/serde_events.go              | 121 +++++++++---------
 internal/storage/serde_events_test.go         |  79 ++++++++++++
 5 files changed, 361 insertions(+), 100 deletions(-)

diff --git a/internal/datanode/compaction/mix_compactor.go b/internal/datanode/compaction/mix_compactor.go
index a12ef12edec6f..f8f646c46436b 100644
--- a/internal/datanode/compaction/mix_compactor.go
+++ b/internal/datanode/compaction/mix_compactor.go
@@ -23,6 +23,7 @@ import (
 	"math"
 	"time"
 
+	"github.com/apache/arrow/go/v12/arrow/array"
 	"github.com/cockroachdb/errors"
 	"github.com/samber/lo"
 	"go.opentelemetry.io/otel"
@@ -33,6 +34,7 @@ import (
 	"github.com/milvus-io/milvus/internal/flushcommon/io"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/metrics"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
@@ -199,26 +201,43 @@ func (t *mixCompactionTask) writeSegment(ctx context.Context,
 		log.Warn("compact wrong, fail to merge deltalogs", zap.Error(err))
 		return
 	}
-	isValueDeleted := func(v *storage.Value) bool {
-		ts, ok := delta[v.PK.GetValue()]
+
+	isValueDeleted := func(pk any, ts typeutil.Timestamp) bool {
+		oldts, ok := delta[pk]
 		// insert task and delete task has the same ts when upsert
 		// here should be < instead of <=
 		// to avoid the upsert data to be deleted after compact
-		if ok && uint64(v.Timestamp) < ts {
+		if ok && ts < oldts {
+			deletedRowCount++
+			return true
+		}
+		// Filter out expired entities
+		if isExpiredEntity(t.plan.GetCollectionTtl(), t.currentTs, typeutil.Timestamp(ts)) {
+			expiredRowCount++
 			return true
 		}
 		return false
 	}
 
-	iter, err := storage.NewBinlogDeserializeReader(blobs, pkField.GetFieldID())
+	reader, err := storage.NewCompositeBinlogRecordReader(blobs)
 	if err != nil {
 		log.Warn("compact wrong, failed to new insert binlogs reader", zap.Error(err))
 		return
 	}
-	defer iter.Close()
+	defer reader.Close()
+
+	writeSlice := func(r storage.Record, start, end int) error {
+		sliced := r.Slice(start, end)
+		defer sliced.Release()
+		err = mWriter.WriteRecord(sliced)
+		if err != nil {
+			log.Warn("compact wrong, failed to write row", zap.Error(err))
+			return err
+		}
+		return nil
+	}
@@ -228,23 +247,45 @@ func (t *mixCompactionTask) writeSegment(ctx context.Context,
 	for {
-		err = iter.Next()
+		err = reader.Next()
 		if err != nil {
 			if err == sio.EOF {
 				err = nil
 				break
 			} else {
 				log.Warn("compact wrong, failed to iter through data", zap.Error(err))
 				return
 			}
 		}
-		v := iter.Value()
-
-		if isValueDeleted(v) {
-			deletedRowCount++
-			continue
-		}
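+		// Scan the batch row by row, tracking runs of consecutive surviving
+		// rows; each completed run is flushed through writeSlice as a single
+		// record slice instead of being re-serialized value by value.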
+		r := reader.Record()
+		pkArray := r.Column(pkField.FieldID)
+		tsArray := r.Column(common.TimeStampField).(*array.Int64)
+
+		sliceStart := -1
+		rows := r.Len()
+		for i := 0; i < rows; i++ {
+			// Filter out deleted entities
+			var pk any
+			switch pkField.DataType {
+			case schemapb.DataType_Int64:
+				pk = pkArray.(*array.Int64).Value(i)
+			case schemapb.DataType_VarChar:
+				pk = pkArray.(*array.String).Value(i)
+			default:
+				panic("invalid data type")
+			}
+			ts := typeutil.Timestamp(tsArray.Value(i))
+			if isValueDeleted(pk, ts) {
+				if sliceStart != -1 {
+					err = writeSlice(r, sliceStart, i)
+					if err != nil {
+						return
+					}
+					sliceStart = -1
+				}
+				continue
+			}
 
-		// Filtering expired entity
-		if isExpiredEntity(t.plan.GetCollectionTtl(), t.currentTs, typeutil.Timestamp(v.Timestamp)) {
-			expiredRowCount++
-			continue
+			if sliceStart == -1 {
+				sliceStart = i
+			}
 		}
 
-		err = mWriter.Write(v)
-		if err != nil {
-			log.Warn("compact wrong, failed to writer row", zap.Error(err))
-			return
+		if sliceStart != -1 {
+			err = writeSlice(r, sliceStart, r.Len())
+			if err != nil {
+				return
+			}
 		}
 	}
 	return

diff --git a/internal/datanode/compaction/segment_writer.go b/internal/datanode/compaction/segment_writer.go
index b6171b2b6f829..6eb900da6ebfe 100644
--- a/internal/datanode/compaction/segment_writer.go
+++ b/internal/datanode/compaction/segment_writer.go
@@ -9,6 +9,7 @@ import (
 	"fmt"
 	"math"
 
+	"github.com/apache/arrow/go/v12/arrow/array"
 	"github.com/samber/lo"
 	"go.uber.org/atomic"
 	"go.uber.org/zap"
@@ -20,6 +21,7 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/proto/etcdpb"
 	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
@@ -183,12 +185,7 @@ func (w *MultiSegmentWriter) getWriter() (*SegmentWriter, error) {
 	return w.writers[w.current], nil
 }
 
-func (w *MultiSegmentWriter) Write(v *storage.Value) error {
-	writer, err := w.getWriter()
-	if err != nil {
-		return err
-	}
-
+func (w *MultiSegmentWriter) writeInternal(writer *SegmentWriter) error {
 	if writer.IsFull() {
 		// init segment fieldBinlogs if it is not exist
 		if _, ok := w.cachedMeta[writer.segmentID]; !ok {
@@ -206,6 +203,29 @@ func (w *MultiSegmentWriter) Write(v *storage.Value) error {
 		mergeFieldBinlogs(w.cachedMeta[writer.segmentID], partialBinlogs)
 	}
 
+	return nil
+}
+
+func (w *MultiSegmentWriter) WriteRecord(r storage.Record) error {
+	writer, err := w.getWriter()
+	if err != nil {
+		return err
+	}
+	if err := w.writeInternal(writer); err != nil {
+		return err
+	}
+
+	return writer.WriteRecord(r)
+}
+
+func (w *MultiSegmentWriter) Write(v *storage.Value) error {
+	writer, err := w.getWriter()
+	if err != nil {
+		return err
+	}
+	if err := w.writeInternal(writer); err != nil {
+		return err
+	}
 	return writer.Write(v)
 }
 
@@ -358,6 +378,48 @@ func (w *SegmentWriter) WrittenMemorySize() uint64 {
 	return w.writer.WrittenMemorySize()
 }
 
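+// WriteRecord writes an arrow record into the segment wholesale: it widens
+// the segment's timestamp range, updates primary-key and BM25 statistics row
+// by row, then hands the record to the underlying writer without converting
+// it back into per-row values.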
+func (w *SegmentWriter) WriteRecord(r storage.Record) error {
+	tsArray := r.Column(common.TimeStampField).(*array.Int64)
+	rows := r.Len()
+	for i := 0; i < rows; i++ {
+		ts := typeutil.Timestamp(tsArray.Value(i))
+		if ts < w.tsFrom {
+			w.tsFrom = ts
+		}
+		if ts > w.tsTo {
+			w.tsTo = ts
+		}
+
+		switch schemapb.DataType(w.pkstats.PkType) {
+		case schemapb.DataType_Int64:
+			pkArray := r.Column(w.GetPkID()).(*array.Int64)
+			pk := &storage.Int64PrimaryKey{
+				Value: pkArray.Value(i),
+			}
+			w.pkstats.Update(pk)
+		case schemapb.DataType_VarChar:
+			pkArray := r.Column(w.GetPkID()).(*array.String)
+			pk := &storage.VarCharPrimaryKey{
+				Value: pkArray.Value(i),
+			}
+			w.pkstats.Update(pk)
+		default:
+			panic("invalid data type")
+		}
+
+		for fieldID, stats := range w.bm25Stats {
+			field, ok := r.Column(fieldID).(*array.Binary)
+			if !ok {
+				return fmt.Errorf("bm25 field value not found")
+			}
+			stats.AppendBytes(field.Value(i))
+		}
+
+		w.rowCount.Inc()
+	}
+	return w.writer.WriteRecord(r)
+}
+
 func (w *SegmentWriter) Write(v *storage.Value) error {
 	ts := typeutil.Timestamp(v.Timestamp)
 	if ts < w.tsFrom {
 		w.tsFrom = ts
 	}

diff --git a/internal/storage/serde.go b/internal/storage/serde.go
index fc6f6468b2265..7da70f282bf56 100644
--- a/internal/storage/serde.go
+++ b/internal/storage/serde.go
@@ -24,9 +24,12 @@ import (
 
 	"github.com/apache/arrow/go/v12/arrow"
 	"github.com/apache/arrow/go/v12/arrow/array"
+	"github.com/apache/arrow/go/v12/arrow/memory"
 	"github.com/apache/arrow/go/v12/parquet"
 	"github.com/apache/arrow/go/v12/parquet/compress"
 	"github.com/apache/arrow/go/v12/parquet/pqarrow"
+	"github.com/cockroachdb/errors"
+	"github.com/samber/lo"
 	"go.uber.org/atomic"
 	"google.golang.org/protobuf/proto"
 
@@ -40,6 +43,7 @@ type Record interface {
 	Column(i FieldID) arrow.Array
 	Len() int
 	Release()
+	Slice(start, end int) Record
 }
 
 type RecordReader interface {
@@ -50,6 +54,7 @@ type RecordReader interface {
 
 type RecordWriter interface {
 	Write(r Record) error
+	GetWrittenUncompressed() uint64
 	Close()
 }
 
@@ -64,6 +69,8 @@ type compositeRecord struct {
 	schema map[FieldID]schemapb.DataType
 }
 
+var _ Record = (*compositeRecord)(nil)
+
 func (r *compositeRecord) Column(i FieldID) arrow.Array {
 	return r.recs[i].Column(0)
 }
@@ -93,6 +100,17 @@ func (r *compositeRecord) ArrowSchema() *arrow.Schema {
 	return arrow.NewSchema(fields, nil)
 }
 
+func (r *compositeRecord) Slice(start, end int) Record {
+	slices := make(map[FieldID]arrow.Record)
+	for i, rec := range r.recs {
+		slices[i] = rec.NewSlice(int64(start), int64(end))
+	}
+	return &compositeRecord{
+		recs:   slices,
+		schema: r.schema,
+	}
+}
+
 type serdeEntry struct {
 	// arrowType returns the arrow type for the given dimension
 	arrowType func(int) arrow.DataType
@@ -527,6 +545,17 @@ func (deser *DeserializeReader[T]) Next() error {
 	return nil
 }
 
+func (deser *DeserializeReader[T]) NextRecord() (Record, error) {
+	if len(deser.values) != 0 {
+		return nil, errors.New("deserialize result is not empty")
+	}
+
+	if err := deser.rr.Next(); err != nil {
+		return nil, err
+	}
+	return deser.rr.Record(), nil
+}
+
 func (deser *DeserializeReader[T]) Value() T {
 	return deser.values[deser.pos]
 }
@@ -580,6 +609,16 @@ func (r *selectiveRecord) Release() {
 	// do nothing.
 }
 
+func (r *selectiveRecord) Slice(start, end int) Record {
+	panic("not implemented")
+}
+
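+// calculateArraySize sums the lengths of the raw buffers backing an arrow
+// array, approximating its uncompressed in-memory size.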
+func calculateArraySize(a arrow.Array) int {
+	return lo.SumBy[*memory.Buffer, int](a.Data().Buffers(), func(b *memory.Buffer) int {
+		return b.Len()
+	})
+}
+
 func newSelectiveRecord(r Record, selectedFieldId FieldID) *selectiveRecord {
 	dt, ok := r.Schema()[selectedFieldId]
 	if !ok {
@@ -594,16 +633,29 @@ func newSelectiveRecord(r Record, selectedFieldId FieldID) *selectiveRecord {
 	}
 }
 
-var _ RecordWriter = (*compositeRecordWriter)(nil)
+var _ RecordWriter = (*CompositeRecordWriter)(nil)
 
-type compositeRecordWriter struct {
+type CompositeRecordWriter struct {
 	writers map[FieldID]RecordWriter
+
+	writtenUncompressed uint64
+}
+
+func (crw *CompositeRecordWriter) GetWrittenUncompressed() uint64 {
+	return crw.writtenUncompressed
 }
 
-func (crw *compositeRecordWriter) Write(r Record) error {
+func (crw *CompositeRecordWriter) Write(r Record) error {
 	if len(r.Schema()) != len(crw.writers) {
 		return fmt.Errorf("schema length mismatch %d, expected %d", len(r.Schema()), len(crw.writers))
 	}
+
+	var bytes uint64
+	for fid := range r.Schema() {
+		arr := r.Column(fid)
+		bytes += uint64(calculateArraySize(arr))
+	}
+	crw.writtenUncompressed += bytes
+
 	for fieldId, w := range crw.writers {
 		sr := newSelectiveRecord(r, fieldId)
 		if err := w.Write(sr); err != nil {
@@ -613,7 +665,7 @@ func (crw *compositeRecordWriter) Write(r Record) error {
 	return nil
 }
 
-func (crw *compositeRecordWriter) Close() {
+func (crw *CompositeRecordWriter) Close() {
 	if crw != nil {
 		for _, w := range crw.writers {
 			if w != nil {
@@ -623,8 +675,8 @@ func (crw *compositeRecordWriter) Close() {
 	}
 }
 
-func newCompositeRecordWriter(writers map[FieldID]RecordWriter) *compositeRecordWriter {
-	return &compositeRecordWriter{
+func NewCompositeRecordWriter(writers map[FieldID]RecordWriter) *CompositeRecordWriter {
+	return &CompositeRecordWriter{
 		writers: writers,
 	}
 }
@@ -640,22 +692,29 @@ func WithRecordWriterProps(writerProps *parquet.WriterProperties) RecordWriterOp
 }
 
 type singleFieldRecordWriter struct {
-	fw      *pqarrow.FileWriter
-	fieldId FieldID
-	schema  *arrow.Schema
-
-	numRows int
+	fw          *pqarrow.FileWriter
+	fieldId     FieldID
+	schema      *arrow.Schema
 	writerProps *parquet.WriterProperties
+
+	numRows             int
+	writtenUncompressed uint64
 }
 
 func (sfw *singleFieldRecordWriter) Write(r Record) error {
 	sfw.numRows += r.Len()
 	a := r.Column(sfw.fieldId)
+
+	sfw.writtenUncompressed += uint64(a.Data().Buffers()[0].Len())
 	rec := array.NewRecord(sfw.schema, []arrow.Array{a}, int64(r.Len()))
 	defer rec.Release()
 	return sfw.fw.WriteBuffered(rec)
 }
 
+func (sfw *singleFieldRecordWriter) GetWrittenUncompressed() uint64 {
+	return sfw.writtenUncompressed
+}
+
 func (sfw *singleFieldRecordWriter) Close() {
 	sfw.fw.Close()
 }
@@ -687,7 +746,8 @@ type multiFieldRecordWriter struct {
 	fieldIds []FieldID
 	schema   *arrow.Schema
 
-	numRows int
+	numRows             int
+	writtenUncompressed uint64
 }
 
 func (mfw *multiFieldRecordWriter) Write(r Record) error {
@@ -695,12 +755,17 @@ func (mfw *multiFieldRecordWriter) Write(r Record) error {
 	columns := make([]arrow.Array, len(mfw.fieldIds))
 	for i, fieldId := range mfw.fieldIds {
 		columns[i] = r.Column(fieldId)
+		mfw.writtenUncompressed += uint64(calculateArraySize(columns[i]))
 	}
 	rec := array.NewRecord(mfw.schema, columns, int64(r.Len()))
 	defer rec.Release()
 	return mfw.fw.WriteBuffered(rec)
 }
 
+func (mfw *multiFieldRecordWriter) GetWrittenUncompressed() uint64 {
+	return mfw.writtenUncompressed
+}
+
 func (mfw *multiFieldRecordWriter) Close() {
 	mfw.fw.Close()
 }
 
@@ -765,6 +830,23 @@ func (sw *SerializeWriter[T]) Write(value T) error {
 	return nil
 }
 
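+// WriteRecord sends an already-assembled arrow record straight to the
+// underlying RecordWriter, skipping the row buffer (it refuses to run while
+// buffered rows are pending) and accounting the record's uncompressed size.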
+func (sw *SerializeWriter[T]) WriteRecord(r Record) error {
+	if len(sw.buffer) != 0 {
+		return errors.New("serialize buffer is not empty")
+	}
+
+	if err := sw.rw.Write(r); err != nil {
+		return err
+	}
+
+	size := 0
+	for fid := range r.Schema() {
+		size += calculateArraySize(r.Column(fid))
+	}
+	sw.writtenMemorySize.Add(uint64(size))
+	return nil
+}
+
 func (sw *SerializeWriter[T]) WrittenMemorySize() uint64 {
 	return sw.writtenMemorySize.Load()
 }

diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go
index 255bc1cef8042..da319a959e257 100644
--- a/internal/storage/serde_events.go
+++ b/internal/storage/serde_events.go
@@ -28,7 +28,6 @@ import (
 	"github.com/apache/arrow/go/v12/arrow"
 	"github.com/apache/arrow/go/v12/arrow/array"
 	"github.com/apache/arrow/go/v12/arrow/memory"
-	"github.com/samber/lo"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/pkg/common"
@@ -38,9 +37,9 @@ import (
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
-var _ RecordReader = (*compositeBinlogRecordReader)(nil)
+var _ RecordReader = (*CompositeBinlogRecordReader)(nil)
 
-type compositeBinlogRecordReader struct {
+type CompositeBinlogRecordReader struct {
 	blobs [][]*Blob
 
 	blobPos int
@@ -51,7 +50,7 @@ type compositeBinlogRecordReader struct {
 	r compositeRecord
 }
 
-func (crr *compositeBinlogRecordReader) iterateNextBatch() error {
+func (crr *CompositeBinlogRecordReader) iterateNextBatch() error {
 	if crr.closers != nil {
 		for _, close := range crr.closers {
 			if close != nil {
@@ -91,7 +90,7 @@ func (crr *compositeBinlogRecordReader) iterateNextBatch() error {
 	return nil
 }
 
-func (crr *compositeBinlogRecordReader) Next() error {
+func (crr *CompositeBinlogRecordReader) Next() error {
 	if crr.rrs == nil {
 		if crr.blobs == nil || len(crr.blobs) == 0 {
 			return io.EOF
@@ -135,11 +134,11 @@ func (crr *compositeBinlogRecordReader) Next() error {
 	return nil
 }
 
-func (crr *compositeBinlogRecordReader) Record() Record {
+func (crr *CompositeBinlogRecordReader) Record() Record {
 	return &crr.r
 }
 
-func (crr *compositeBinlogRecordReader) Close() {
+func (crr *CompositeBinlogRecordReader) Close() {
 	for _, close := range crr.closers {
 		if close != nil {
 			close()
@@ -158,7 +157,7 @@ func parseBlobKey(blobKey string) (colId FieldID, logId UniqueID) {
 	return InvalidUniqueID, InvalidUniqueID
 }
 
-func newCompositeBinlogRecordReader(blobs []*Blob) (*compositeBinlogRecordReader, error) {
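+// NewCompositeBinlogRecordReader groups binlog blobs by field ID, orders each
+// field's blobs by log ID, and exposes them as a stream of arrow record
+// batches covering all fields.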
+func NewCompositeBinlogRecordReader(blobs []*Blob) (*CompositeBinlogRecordReader, error) {
 	blobMap := make(map[FieldID][]*Blob)
 	for _, blob := range blobs {
 		colId, _ := parseBlobKey(blob.Key)
@@ -178,13 +177,13 @@ func newCompositeBinlogRecordReader(blobs []*Blob) (*compositeBinlogRecordReader
 		})
 		sortedBlobs = append(sortedBlobs, blobsForField)
 	}
-	return &compositeBinlogRecordReader{
+	return &CompositeBinlogRecordReader{
 		blobs: sortedBlobs,
 	}, nil
 }
 
 func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*DeserializeReader[*Value], error) {
-	reader, err := newCompositeBinlogRecordReader(blobs)
+	reader, err := NewCompositeBinlogRecordReader(blobs)
 	if err != nil {
 		return nil, err
 	}
@@ -234,7 +233,7 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize
 }
 
 func newDeltalogOneFieldReader(blobs []*Blob) (*DeserializeReader[*DeleteLog], error) {
-	reader, err := newCompositeBinlogRecordReader(blobs)
+	reader, err := NewCompositeBinlogRecordReader(blobs)
 	if err != nil {
 		return nil, err
 	}
@@ -264,8 +263,6 @@ type BinlogStreamWriter struct {
 	segmentID   UniqueID
 	fieldSchema *schemapb.FieldSchema
 
-	memorySize int // To be updated on the fly
-
 	buf bytes.Buffer
 	rw  *singleFieldRecordWriter
 }
@@ -306,7 +303,7 @@ func (bsw *BinlogStreamWriter) Finalize() (*Blob, error) {
 		Key:        strconv.Itoa(int(bsw.fieldSchema.FieldID)),
 		Value:      b.Bytes(),
 		RowNum:     int64(bsw.rw.numRows),
-		MemorySize: int64(bsw.memorySize),
+		MemorySize: int64(bsw.rw.writtenUncompressed),
 	}, nil
 }
 
@@ -319,7 +316,7 @@ func (bsw *BinlogStreamWriter) writeBinlogHeaders(w io.Writer) error {
 	de := NewBaseDescriptorEvent(bsw.collectionID, bsw.partitionID, bsw.segmentID)
 	de.PayloadDataType = bsw.fieldSchema.DataType
 	de.FieldID = bsw.fieldSchema.FieldID
-	de.descriptorEventData.AddExtra(originalSizeKey, strconv.Itoa(bsw.memorySize))
+	de.descriptorEventData.AddExtra(originalSizeKey, strconv.Itoa(int(bsw.rw.writtenUncompressed)))
 	de.descriptorEventData.AddExtra(nullableKey, bsw.fieldSchema.Nullable)
 	if err := de.Write(w); err != nil {
 		return err
@@ -356,6 +353,50 @@ func NewBinlogStreamWriters(collectionID, partitionID, segmentID UniqueID,
 	return bws
 }
 
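+// ValueSerializer converts a batch of row-oriented values into a single arrow
+// record, returning the record together with its uncompressed size in bytes.
+// It is the serialization path formerly inlined in NewBinlogSerializeWriter,
+// extracted so that record-level writers can reuse it.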
+func ValueSerializer(v []*Value, fieldSchema []*schemapb.FieldSchema) (Record, uint64, error) {
+	builders := make(map[FieldID]array.Builder, len(fieldSchema))
+	types := make(map[FieldID]schemapb.DataType, len(fieldSchema))
+	for _, f := range fieldSchema {
+		dim, _ := typeutil.GetDim(f)
+		builders[f.FieldID] = array.NewBuilder(memory.DefaultAllocator, serdeMap[f.DataType].arrowType(int(dim)))
+		types[f.FieldID] = f.DataType
+	}
+
+	var memorySize uint64
+	for _, vv := range v {
+		m := vv.Value.(map[FieldID]any)
+
+		for fid, e := range m {
+			typeEntry, ok := serdeMap[types[fid]]
+			if !ok {
+				panic("unknown type")
+			}
+			ok = typeEntry.serialize(builders[fid], e)
+			if !ok {
+				return nil, 0, merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", types[fid]))
+			}
+		}
+	}
+	arrays := make([]arrow.Array, len(types))
+	fields := make([]arrow.Field, len(types))
+	field2Col := make(map[FieldID]int, len(types))
+	i := 0
+	for fid, builder := range builders {
+		arrays[i] = builder.NewArray()
+		memorySize += uint64(calculateArraySize(arrays[i]))
+
+		builder.Release()
+		fields[i] = arrow.Field{
+			Name:     strconv.Itoa(int(fid)),
+			Type:     arrays[i].DataType(),
+			Nullable: true, // No nullable check here.
+		}
+		field2Col[fid] = i
+		i++
+	}
+	return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), types, field2Col), memorySize, nil
+}
+
 func NewBinlogSerializeWriter(schema *schemapb.CollectionSchema, partitionID, segmentID UniqueID,
 	eventWriters map[FieldID]*BinlogStreamWriter, batchSize int,
 ) (*SerializeWriter[*Value], error) {
@@ -368,53 +409,9 @@ func NewBinlogSerializeWriter(schema *schemapb.CollectionSchema, partitionID, se
 		}
 		rws[fid] = rw
 	}
-	compositeRecordWriter := newCompositeRecordWriter(rws)
+	compositeRecordWriter := NewCompositeRecordWriter(rws)
 	return NewSerializeRecordWriter[*Value](compositeRecordWriter, func(v []*Value) (Record, uint64, error) {
-		builders := make(map[FieldID]array.Builder, len(schema.Fields))
-		types := make(map[FieldID]schemapb.DataType, len(schema.Fields))
-		for _, f := range schema.Fields {
-			dim, _ := typeutil.GetDim(f)
-			builders[f.FieldID] = array.NewBuilder(memory.DefaultAllocator, serdeMap[f.DataType].arrowType(int(dim)))
-			types[f.FieldID] = f.DataType
-		}
-
-		var memorySize uint64
-		for _, vv := range v {
-			m := vv.Value.(map[FieldID]any)
-
-			for fid, e := range m {
-				typeEntry, ok := serdeMap[types[fid]]
-				if !ok {
-					panic("unknown type")
-				}
-				ok = typeEntry.serialize(builders[fid], e)
-				if !ok {
-					return nil, 0, merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", types[fid]))
-				}
-			}
-		}
-		arrays := make([]arrow.Array, len(types))
-		fields := make([]arrow.Field, len(types))
-		field2Col := make(map[FieldID]int, len(types))
-		i := 0
-		for fid, builder := range builders {
-			arrays[i] = builder.NewArray()
-			size := lo.SumBy[*memory.Buffer, int](arrays[i].Data().Buffers(), func(b *memory.Buffer) int {
-				return b.Len()
-			})
-			eventWriters[fid].memorySize += size
-			memorySize += uint64(size)
-
-			builder.Release()
-			fields[i] = arrow.Field{
-				Name:     strconv.Itoa(int(fid)),
-				Type:     arrays[i].DataType(),
-				Nullable: true, // No nullable check here.
-			}
-			field2Col[fid] = i
-			i++
-		}
-		return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), types, field2Col), memorySize, nil
+		return ValueSerializer(v, schema.Fields)
 	}, batchSize), nil
 }
 
@@ -515,7 +512,7 @@ func newDeltalogSerializeWriter(eventWriter *DeltalogStreamWriter, batchSize int
 		return nil, err
 	}
 	rws[0] = rw
-	compositeRecordWriter := newCompositeRecordWriter(rws)
+	compositeRecordWriter := NewCompositeRecordWriter(rws)
 	return NewSerializeRecordWriter[*DeleteLog](compositeRecordWriter, func(v []*DeleteLog) (Record, uint64, error) {
 		builder := array.NewBuilder(memory.DefaultAllocator, arrow.BinaryTypes.String)

diff --git a/internal/storage/serde_events_test.go b/internal/storage/serde_events_test.go
index f3c19842bebb9..d97c17fd5e1f9 100644
--- a/internal/storage/serde_events_test.go
+++ b/internal/storage/serde_events_test.go
@@ -19,9 +19,13 @@ package storage
 import (
 	"bytes"
 	"context"
+	"fmt"
 	"io"
+	"math/rand"
+	"sort"
 	"strconv"
 	"testing"
+	"time"
 
 	"github.com/apache/arrow/go/v12/arrow"
 	"github.com/apache/arrow/go/v12/arrow/array"
@@ -29,9 +33,12 @@ import (
 	"github.com/apache/arrow/go/v12/parquet/file"
 	"github.com/apache/arrow/go/v12/parquet/pqarrow"
 	"github.com/stretchr/testify/assert"
+	"go.uber.org/zap"
 
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/pkg/common"
+	"github.com/milvus-io/milvus/pkg/log"
 )
 
 func TestBinlogDeserializeReader(t *testing.T) {
@@ -182,6 +189,78 @@ func TestBinlogSerializeWriter(t *testing.T) {
 	})
 }
 
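+// BenchmarkSerializeWriter measures binlog serialize-writer throughput on
+// 200k pre-sorted rows across several batch sizes.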
+func BenchmarkSerializeWriter(b *testing.B) {
+	const (
+		dim     = 128
+		numRows = 200000
+	)
+
+	var (
+		rId = &schemapb.FieldSchema{FieldID: common.RowIDField, Name: common.RowIDFieldName, DataType: schemapb.DataType_Int64}
+		ts  = &schemapb.FieldSchema{FieldID: common.TimeStampField, Name: common.TimeStampFieldName, DataType: schemapb.DataType_Int64}
+		pk  = &schemapb.FieldSchema{FieldID: 100, Name: "pk", IsPrimaryKey: true, DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{{Key: common.MaxLengthKey, Value: "100"}}}
+		f   = &schemapb.FieldSchema{FieldID: 101, Name: "random", DataType: schemapb.DataType_Double}
+		// fVec = &schemapb.FieldSchema{FieldID: 102, Name: "vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: strconv.Itoa(dim)}}}
+	)
+	schema := &schemapb.CollectionSchema{Name: "test-aaa", Fields: []*schemapb.FieldSchema{rId, ts, pk, f}}
+
+	// prepare data values
+	start := time.Now()
+	vec := make([]float32, dim)
+	for j := 0; j < dim; j++ {
+		vec[j] = rand.Float32()
+	}
+	values := make([]*Value, numRows)
+	for i := 0; i < numRows; i++ {
+		value := &Value{}
+		value.Value = make(map[int64]interface{}, len(schema.GetFields()))
+		m := value.Value.(map[int64]interface{})
+		for _, field := range schema.GetFields() {
+			switch field.GetDataType() {
+			case schemapb.DataType_Int64:
+				m[field.GetFieldID()] = int64(i)
+			case schemapb.DataType_VarChar:
+				k := fmt.Sprintf("test_pk_%d", i)
+				m[field.GetFieldID()] = k
+				value.PK = &VarCharPrimaryKey{
+					Value: k,
+				}
+			case schemapb.DataType_Double:
+				m[field.GetFieldID()] = float64(i)
+			case schemapb.DataType_FloatVector:
+				m[field.GetFieldID()] = vec
+			}
+		}
+		value.ID = int64(i)
+		value.Timestamp = int64(0)
+		value.IsDeleted = false
+		value.Value = m
+		values[i] = value
+	}
+	sort.Slice(values, func(i, j int) bool {
+		return values[i].PK.LT(values[j].PK)
+	})
+	log.Info("prepare data done", zap.Int("len", len(values)), zap.Duration("dur", time.Since(start)))
+
+	b.ResetTimer()
+
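+	// One sub-benchmark per batch size; larger batches amortize builder and
+	// writer overhead across more rows.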
+	sizes := []int{100, 1000, 10000, 100000}
+	for _, s := range sizes {
+		b.Run(fmt.Sprintf("batch size=%d", s), func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				writers := NewBinlogStreamWriters(0, 0, 0, schema.Fields)
+				writer, err := NewBinlogSerializeWriter(schema, 0, 0, writers, s)
+				assert.NoError(b, err)
+				for _, v := range values {
+					err = writer.Write(v)
+					assert.NoError(b, err)
+				}
+				writer.Close()
+			}
+		})
+	}
+}
+
 func TestNull(t *testing.T) {
 	t.Run("test null", func(t *testing.T) {
 		schema := generateTestSchema()