Skip to content

Commit

Permalink
GODRIVER-1180 Remove legacy transform functions from mongo (#583)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjirewis authored May 10, 2021
1 parent 899e5cd commit 4760305
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 176 deletions.
2 changes: 1 addition & 1 deletion mongo/bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera
var i int
for _, model := range batch.models {
converted := model.(*InsertOneModel)
doc, _, err := transformAndEnsureIDv2(bw.collection.registry, converted.Document)
doc, _, err := transformAndEnsureID(bw.collection.registry, converted.Document)
if err != nil {
return operation.InsertResult{}, err
}
Expand Down
4 changes: 2 additions & 2 deletions mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{},

for i, doc := range documents {
var err error
docs[i], result[i], err = transformAndEnsureIDv2(coll.registry, doc)
docs[i], result[i], err = transformAndEnsureID(coll.registry, doc)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -746,7 +746,7 @@ func aggregate(a aggregateParams) (*Cursor, error) {
a.ctx = context.Background()
}

pipelineArr, hasOutputStage, err := transformAggregatePipelinev2(a.registry, a.pipeline)
pipelineArr, hasOutputStage, err := transformAggregatePipeline(a.registry, a.pipeline)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion mongo/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ func (db *Database) CreateCollection(ctx context.Context, name string, opts ...*
func (db *Database) CreateView(ctx context.Context, viewName, viewOn string, pipeline interface{},
opts ...*options.CreateViewOptions) error {

pipelineArray, _, err := transformAggregatePipelinev2(db.registry, pipeline)
pipelineArray, _, err := transformAggregatePipeline(db.registry, pipeline)
if err != nil {
return err
}
Expand Down
115 changes: 6 additions & 109 deletions mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,70 +72,9 @@ func (me MarshalError) Error() string {
//
type Pipeline []bson.D

// transformAndEnsureID is a hack that makes it easy to get a RawValue as the _id value. This will
// be removed when we switch from using bsonx to bsoncore for the driver package.
func transformAndEnsureID(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc, interface{}, error) {
// TODO: performance is going to be pretty bad for bsonx.Doc here since we turn it into a []byte
// only to turn it back into a bsonx.Doc. We can fix this post beta1 when we refactor the driver
// package to use bsoncore.Document instead of bsonx.Doc.
if registry == nil {
registry = bson.NewRegistryBuilder().Build()
}
switch tt := val.(type) {
case nil:
return nil, nil, ErrNilDocument
case bsonx.Doc:
val = tt.Copy()
case []byte:
// Slight optimization so we'll just use MarshalBSON and not go through the codec machinery.
val = bson.Raw(tt)
}

// TODO(skriptble): Use a pool of these instead.
buf := make([]byte, 0, 256)
b, err := bson.MarshalAppendWithRegistry(registry, buf, val)
if err != nil {
return nil, nil, MarshalError{Value: val, Err: err}
}

d, err := bsonx.ReadDoc(b)
if err != nil {
return nil, nil, err
}

var id interface{}

idx := d.IndexOf("_id")
var idElem bsonx.Elem
switch idx {
case -1:
idElem = bsonx.Elem{"_id", bsonx.ObjectID(primitive.NewObjectID())}
d = append(d, bsonx.Elem{})
copy(d[1:], d)
d[0] = idElem
default:
idElem = d[idx]
copy(d[1:idx+1], d[0:idx])
d[0] = idElem
}

idBuf := make([]byte, 0, 256)
t, data, err := idElem.Value.MarshalAppendBSONValue(idBuf[:0])
if err != nil {
return nil, nil, err
}

err = bson.RawValue{Type: t, Value: data}.UnmarshalWithRegistry(registry, &id)
if err != nil {
return nil, nil, err
}

return d, id, nil
}

// transformAndEnsureIDv2 is a hack that makes it easy to get a RawValue as the _id value. This will
// be removed when we switch from using bsonx to bsoncore for the driver package.
func transformAndEnsureIDv2(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, interface{}, error) {
// transformAndEnsureID is a hack that makes it easy to get a RawValue as the _id value.
// It will also add an ObjectID _id as the first key if it not already present in the passed-in val.
func transformAndEnsureID(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, interface{}, error) {
if registry == nil {
registry = bson.NewRegistryBuilder().Build()
}
Expand Down Expand Up @@ -237,17 +176,7 @@ func ensureID(d bsonx.Doc) (bsonx.Doc, interface{}) {
return d, id
}

func ensureDollarKey(doc bsonx.Doc) error {
if len(doc) == 0 {
return errors.New("update document must have at least one element")
}
if !strings.HasPrefix(doc[0].Key, "$") {
return errors.New("update document must contain key beginning with '$'")
}
return nil
}

func ensureDollarKeyv2(doc bsoncore.Document) error {
func ensureDollarKey(doc bsoncore.Document) error {
firstElem, err := doc.IndexErr(0)
if err != nil {
return errors.New("update document must have at least one element")
Expand All @@ -267,39 +196,7 @@ func ensureNoDollarKey(doc bsoncore.Document) error {
return nil
}

func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (bsonx.Arr, error) {
pipelineArr := bsonx.Arr{}
switch t := pipeline.(type) {
case bsoncodec.ValueMarshaler:
btype, val, err := t.MarshalBSONValue()
if err != nil {
return nil, err
}
if btype != bsontype.Array {
return nil, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array)
}
err = pipelineArr.UnmarshalBSONValue(btype, val)
if err != nil {
return nil, err
}
default:
val := reflect.ValueOf(t)
if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) {
return nil, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind())
}
for idx := 0; idx < val.Len(); idx++ {
elem, err := transformDocument(registry, val.Index(idx).Interface())
if err != nil {
return nil, err
}
pipelineArr = append(pipelineArr, bsonx.Document(elem))
}
}

return pipelineArr, nil
}

func transformAggregatePipelinev2(registry *bsoncodec.Registry, pipeline interface{}) (bsoncore.Document, bool, error) {
func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (bsoncore.Document, bool, error) {
switch t := pipeline.(type) {
case bsoncodec.ValueMarshaler:
btype, val, err := t.MarshalBSONValue()
Expand Down Expand Up @@ -350,7 +247,7 @@ func transformAggregatePipelinev2(registry *bsoncodec.Registry, pipeline interfa
}

func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, dollarKeysAllowed bool) (bsoncore.Value, error) {
documentCheckerFunc := ensureDollarKeyv2
documentCheckerFunc := ensureDollarKey
if !dollarKeysAllowed {
documentCheckerFunc = ensureNoDollarKey
}
Expand Down
Loading

0 comments on commit 4760305

Please sign in to comment.