diff --git a/internal/verifier/change_stream.go b/internal/verifier/change_stream.go index 372de6bc..409ee722 100644 --- a/internal/verifier/change_stream.go +++ b/internal/verifier/change_stream.go @@ -226,7 +226,7 @@ func (verifier *Verifier) iterateChangeStream(ctx context.Context, cs *mongo.Cha break } - if curTs.After(writesOffTs) { + if !curTs.Before(writesOffTs) { verifier.logger.Debug(). Interface("currentTimestamp", curTs). Interface("writesOffTimestamp", writesOffTs). diff --git a/internal/verifier/clustertime.go b/internal/verifier/clustertime.go index 30a2c579..6eb273bd 100644 --- a/internal/verifier/clustertime.go +++ b/internal/verifier/clustertime.go @@ -7,6 +7,7 @@ import ( "github.com/10gen/migration-verifier/internal/retry" "github.com/10gen/migration-verifier/internal/util" "github.com/10gen/migration-verifier/mbson" + "github.com/10gen/migration-verifier/option" "github.com/pkg/errors" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" @@ -17,8 +18,8 @@ import ( const opTimeKeyInServerResponse = "operationTime" -// GetNewClusterTime advances the cluster time and returns that time. -// All shards’ cluster times will meet or exceed the returned time. +// GetNewClusterTime creates a new cluster time, updates all shards’ +// cluster times to meet or exceed that time, then returns it. func GetNewClusterTime( ctx context.Context, logger *logger.Logger, @@ -35,7 +36,12 @@ func GetNewClusterTime( logger, func(_ *retry.Info) error { var err error - clusterTime, err = fetchClusterTime(ctx, client) + clusterTime, err = runAppendOplogNote( + ctx, + client, + "new ts", + option.None[primitive.Timestamp](), + ) return err }, ) @@ -52,7 +58,12 @@ func GetNewClusterTime( logger, func(_ *retry.Info) error { var err error - _, err = syncClusterTimeAcrossShards(ctx, client, clusterTime) + _, err = runAppendOplogNote( + ctx, + client, + "sync ts", + option.Some(clusterTime), + ) return err }, ) @@ -65,46 +76,31 @@ func GetNewClusterTime( return clusterTime, nil } -// Use this when we just need the correct cluster time without -// actually changing any shards’ oplogs. -func fetchClusterTime( +func runAppendOplogNote( ctx context.Context, client *mongo.Client, + note string, + maxClusterTimeOpt option.Option[primitive.Timestamp], ) (primitive.Timestamp, error) { - cmd, rawResponse, err := runAppendOplogNote( - ctx, - client, - "expect StaleClusterTime error", - primitive.Timestamp{1, 0}, - ) - - // We expect an error here; if we didn't get one then something is - // amiss on the server. - if err == nil { - return primitive.Timestamp{}, errors.Errorf("server request unexpectedly succeeded: %v", cmd) + cmd := bson.D{ + {"appendOplogNote", 1}, + {"data", bson.D{ + {"migration-verifier", note}, + }}, } - if !util.IsStaleClusterTimeError(err) { - return primitive.Timestamp{}, errors.Wrap( - err, - "unexpected error (expected StaleClusterTime) from request", - ) + if maxClusterTime, has := maxClusterTimeOpt.Get(); has { + cmd = append(cmd, bson.E{"maxClusterTime", maxClusterTime}) } - return getOpTimeFromRawResponse(rawResponse) -} + resp := client. + Database( + "admin", + options.Database().SetWriteConcern(writeconcern.Majority()), + ). + RunCommand(ctx, cmd) -func syncClusterTimeAcrossShards( - ctx context.Context, - client *mongo.Client, - maxTime primitive.Timestamp, -) (primitive.Timestamp, error) { - _, rawResponse, err := runAppendOplogNote( - ctx, - client, - "syncing cluster time", - maxTime, - ) + rawResponse, err := resp.Raw() // If any shard’s cluster time >= maxTime, the mongos will return a // StaleClusterTime error. This particular error doesn’t indicate a @@ -119,36 +115,6 @@ func syncClusterTimeAcrossShards( return getOpTimeFromRawResponse(rawResponse) } -func runAppendOplogNote( - ctx context.Context, - client *mongo.Client, - note string, - maxClusterTime primitive.Timestamp, -) (bson.D, bson.Raw, error) { - cmd := bson.D{ - {"appendOplogNote", 1}, - {"maxClusterTime", maxClusterTime}, - {"data", bson.D{ - {"migration-verifier", note}, - }}, - } - - resp := client. - Database( - "admin", - options.Database().SetWriteConcern(writeconcern.Majority()), - ). - RunCommand(ctx, cmd) - - raw, err := resp.Raw() - - return cmd, raw, errors.Wrapf( - err, - "command (%v) failed unexpectedly", - cmd, - ) -} - func getOpTimeFromRawResponse(rawResponse bson.Raw) (primitive.Timestamp, error) { // Get the `operationTime` from the response and return it. var optime primitive.Timestamp diff --git a/internal/verifier/clustertime_test.go b/internal/verifier/clustertime_test.go index 6d2e5acb..6a13c02d 100644 --- a/internal/verifier/clustertime_test.go +++ b/internal/verifier/clustertime_test.go @@ -1,13 +1,35 @@ package verifier -import "context" +import ( + "context" -func (suite *IntegrationTestSuite) TestGetClusterTime() { + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" +) + +func (suite *IntegrationTestSuite) TestGetNewClusterTime() { ctx := context.Background() logger, _ := getLoggerAndWriter("stdout") + sess, err := suite.srcMongoClient.StartSession() + suite.Require().NoError(err) + + _, err = suite.srcMongoClient. + Database(suite.DBNameForTest()). + Collection("mycoll"). + InsertOne(mongo.NewSessionContext(ctx, sess), bson.D{}) + suite.Require().NoError(err) + + clusterTimeVal, err := sess.ClusterTime().LookupErr("$clusterTime", "clusterTime") + suite.Require().NoError(err, "should extract cluster time from %+v", sess.ClusterTime()) + + clusterT, clusterI, ok := clusterTimeVal.TimestampOK() + suite.Require().True(ok, "session cluster time (%s: %v) must be a timestamp", clusterTimeVal.Type, clusterTimeVal) + ts, err := GetNewClusterTime(ctx, logger, suite.srcMongoClient) suite.Require().NoError(err) - suite.Assert().NotZero(ts, "timestamp should be nonzero") + suite.Require().NotZero(ts, "timestamp should be nonzero") + suite.Assert().True(ts.After(primitive.Timestamp{T: clusterT, I: clusterI})) } diff --git a/option/bson.go b/option/bson.go new file mode 100644 index 00000000..3394185c --- /dev/null +++ b/option/bson.go @@ -0,0 +1,49 @@ +package option + +import ( + "github.com/pkg/errors" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +// MarshalBSONValue implements bson.ValueMarshaler. +func (o Option[T]) MarshalBSONValue() (bsontype.Type, []byte, error) { + val, exists := o.Get() + if !exists { + return bson.MarshalValue(primitive.Null{}) + } + + return bson.MarshalValue(val) +} + +// UnmarshalBSONValue implements bson.ValueUnmarshaler. +func (o *Option[T]) UnmarshalBSONValue(bType bsontype.Type, raw []byte) error { + + switch bType { + case bson.TypeNull: + o.val = nil + + default: + valPtr := new(T) + + err := bson.UnmarshalValue(bType, raw, &valPtr) + if err != nil { + return errors.Wrapf(err, "failed to unmarshal %T", *o) + } + + // This may not even be possible, but we should still check. + if isNil(*valPtr) { + return errors.Wrapf(err, "refuse to unmarshal nil %T value", *o) + } + + o.val = valPtr + } + + return nil +} + +// IsZero implements bsoncodec.Zeroer. +func (o Option[T]) IsZero() bool { + return o.IsNone() +} diff --git a/option/json.go b/option/json.go new file mode 100644 index 00000000..b5290bbf --- /dev/null +++ b/option/json.go @@ -0,0 +1,37 @@ +package option + +import ( + "bytes" + "encoding/json" +) + +var _ json.Marshaler = &Option[int]{} +var _ json.Unmarshaler = &Option[int]{} + +// MarshalJSON encodes Option into json. +func (o Option[T]) MarshalJSON() ([]byte, error) { + val, exists := o.Get() + if exists { + return json.Marshal(val) + } + + return json.Marshal(nil) +} + +// UnmarshalJSON decodes Option from json. +func (o *Option[T]) UnmarshalJSON(b []byte) error { + if bytes.Equal(b, []byte("null")) { + o.val = nil + } else { + val := *new(T) + + err := json.Unmarshal(b, &val) + if err != nil { + return err + } + + o.val = &val + } + + return nil +} diff --git a/option/option.go b/option/option.go new file mode 100644 index 00000000..03706a9d --- /dev/null +++ b/option/option.go @@ -0,0 +1,153 @@ +// Package option implements [option types] in Go. +// It takes inspiration from [samber/mo] but also works with BSON and exposes +// a (hopefully) more refined interface. +// +// Option types facilitate avoidance of nil-dereference bugs, at the cost of a +// bit more overhead. +// +// A couple special notes: +// - nil values inside the Option, like `Some([]int(nil))`, are forbidden. +// - Option’s BSON marshaling/unmarshaling interoperates with the [bson] +// package’s handling of nilable pointers. So any code that uses nilable +// pointers to represent optional values can switch to Option and +// should continue working with existing persisted data. +// - Because encoding/json provides no equivalent to bsoncodec.Zeroer, +// Option always marshals to JSON null if empty. +// +// Prefer Option to nilable pointers in all new code, and consider +// changing existing code to use it. +// +// [option types]: https://en.wikipedia.org/wiki/Option_type +package option + +import ( + "fmt" + "reflect" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsoncodec" +) + +var _ bson.ValueMarshaler = &Option[int]{} +var _ bson.ValueUnmarshaler = &Option[int]{} +var _ bsoncodec.Zeroer = &Option[int]{} + +// Option represents a possibly-empty value. +// Its zero value is the empty case. +type Option[T any] struct { + val *T +} + +// Some creates an Option with a value. +func Some[T any](value T) Option[T] { + if isNil(value) { + panic(fmt.Sprintf("Option forbids nil value (%T).", value)) + } + + return Option[T]{&value} +} + +// None creates an Option with no value. +// +// Note that `None[T]()` is interchangeable with `Option[T]{}`. +func None[T any]() Option[T] { + return Option[T]{} +} + +// FromPointer will convert a nilable pointer into its +// equivalent Option. +func FromPointer[T any](valPtr *T) Option[T] { + if valPtr == nil { + return None[T]() + } + + if isNil(*valPtr) { + panic(fmt.Sprintf("Given pointer (%T) refers to nil, which is forbidden.", valPtr)) + } + + myCopy := *valPtr + + return Option[T]{&myCopy} +} + +// IfNotZero returns an Option that’s populated if & only if +// the given value is a non-zero value. (NB: The zero value +// for slices & maps is nil, not empty!) +// +// This is useful, e.g., to interface with code that uses +// nil to indicate a missing slice or map. +func IfNotZero[T any](value T) Option[T] { + + // copied from samber/mo.EmptyableToOption: + if reflect.ValueOf(&value).Elem().IsZero() { + return Option[T]{} + } + + return Option[T]{&value} +} + +// Get “unboxes” the Option’s internal value. +// The boolean indicates whether the value exists. +func (o Option[T]) Get() (T, bool) { + if o.val == nil { + return *new(T), false + } + + return *o.val, true +} + +// MustGet is like Get but panics if the Option is empty. +func (o Option[T]) MustGet() T { + val, exists := o.Get() + if !exists { + panic(fmt.Sprintf("MustGet() called on empty %T", o)) + } + + return val +} + +// OrZero returns either the Option’s internal value or +// the type’s zero value. +func (o Option[T]) OrZero() T { + val, exists := o.Get() + if exists { + return val + } + + return *new(T) +} + +// OrElse returns either the Option’s internal value or +// the given `fallback`. +func (o Option[T]) OrElse(fallback T) T { + val, exists := o.Get() + if exists { + return val + } + + return fallback +} + +// ToPointer converts the Option to a nilable pointer. +// The internal value (if it exists) is (shallow-)copied. +func (o Option[T]) ToPointer() *T { + val, exists := o.Get() + if exists { + theCopy := val + return &theCopy + } + + return nil +} + +// IsNone returns a boolean indicating whether or not the option is a None +// value. +func (o Option[T]) IsNone() bool { + return o.val == nil +} + +// IsSome returns a boolean indicating whether or not the option is a Some +// value. +func (o Option[T]) IsSome() bool { + return o.val != nil +} diff --git a/option/unit_test.go b/option/unit_test.go new file mode 100644 index 00000000..157bac89 --- /dev/null +++ b/option/unit_test.go @@ -0,0 +1,351 @@ +package option + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/suite" + "go.mongodb.org/mongo-driver/bson" +) + +type mySuite struct { + suite.Suite +} + +func TestUnitTestSuite(t *testing.T) { + suite.Run(t, &mySuite{}) +} + +func (s *mySuite) Test_Option_BSON() { + type MyType struct { + IsNone Option[int] + IsNoneOmitEmpty Option[int] `bson:",omitempty"` + IsSome Option[bool] + } + + type MyTypePtrs struct { + IsNone *int + IsNoneOmitEmpty *int `bson:",omitempty"` + IsSome *bool + } + + s.Run( + "marshal pointer, unmarshal Option", + func() { + + bytes, err := bson.Marshal(MyTypePtrs{ + IsNoneOmitEmpty: pointerTo(234), + IsSome: pointerTo(false), + }) + s.Require().NoError(err) + + rt := MyType{} + s.Require().NoError(bson.Unmarshal(bytes, &rt)) + + s.Assert().Equal( + MyType{ + IsNoneOmitEmpty: Some(234), + IsSome: Some(false), + }, + rt, + ) + }, + ) + + s.Run( + "marshal Option, unmarshal pointer", + func() { + + bytes, err := bson.Marshal(MyType{ + IsNoneOmitEmpty: Some(234), + IsSome: Some(false), + }) + s.Require().NoError(err) + + rt := MyTypePtrs{} + s.Require().NoError(bson.Unmarshal(bytes, &rt)) + + s.Assert().Equal( + MyTypePtrs{ + IsNoneOmitEmpty: pointerTo(234), + IsSome: pointerTo(false), + }, + rt, + ) + }, + ) + + s.Run( + "round-trip bson.D", + func() { + simpleDoc := bson.D{ + {"a", None[int]()}, + {"b", Some(123)}, + } + + bytes, err := bson.Marshal(simpleDoc) + s.Require().NoError(err) + + rt := bson.D{} + s.Require().NoError(bson.Unmarshal(bytes, &rt)) + + s.Assert().Equal( + bson.D{{"a", nil}, {"b", int32(123)}}, + rt, + ) + }, + ) + + s.Run( + "round-trip struct", + func() { + myThing := MyType{None[int](), None[int](), Some(true)} + + bytes, err := bson.Marshal(&myThing) + s.Require().NoError(err) + + // Unmarshal to a bson.D to test `omitempty`. + rtDoc := bson.D{} + s.Require().NoError(bson.Unmarshal(bytes, &rtDoc)) + + keys := make([]string, 0) + for _, el := range rtDoc { + keys = append(keys, el.Key) + } + + s.Assert().ElementsMatch( + []string{"isnone", "issome"}, + keys, + ) + + rtStruct := MyType{} + s.Require().NoError(bson.Unmarshal(bytes, &rtStruct)) + s.Assert().Equal( + myThing, + rtStruct, + ) + }, + ) +} + +func (s *mySuite) Test_Option_JSON() { + type MyType struct { + IsNone Option[int] + Omitted Option[int] + IsSome Option[bool] + } + + type MyTypePtrs struct { + IsNone *int + Omitted *int + IsSome *bool + } + + s.Run( + "marshal pointer, unmarshal Option", + func() { + + bytes, err := json.Marshal(MyTypePtrs{ + IsNone: pointerTo(234), + IsSome: pointerTo(false), + }) + s.Require().NoError(err) + + rt := MyType{} + s.Require().NoError(json.Unmarshal(bytes, &rt)) + + s.Assert().Equal( + MyType{ + IsNone: Some(234), + IsSome: Some(false), + }, + rt, + ) + }, + ) + + s.Run( + "marshal Option, unmarshal pointer", + func() { + + bytes, err := json.Marshal(MyType{ + IsNone: Some(234), + IsSome: Some(false), + }) + s.Require().NoError(err) + + rt := MyTypePtrs{} + s.Require().NoError(json.Unmarshal(bytes, &rt)) + + s.Assert().Equal( + MyTypePtrs{ + IsNone: pointerTo(234), + IsSome: pointerTo(false), + }, + rt, + ) + }, + ) + + s.Run( + "round-trip bson.D", + func() { + simpleDoc := bson.D{ + {"a", None[int]()}, + {"b", Some(123)}, + } + + bytes, err := json.Marshal(simpleDoc) + s.Require().NoError(err) + + rt := bson.D{} + s.Require().NoError(json.Unmarshal(bytes, &rt)) + + s.Assert().Equal( + bson.D{{"a", nil}, {"b", float64(123)}}, + rt, + ) + }, + ) + + s.Run( + "round-trip struct", + func() { + myThing := MyType{None[int](), None[int](), Some(true)} + + bytes, err := json.Marshal(&myThing) + s.Require().NoError(err) + + rtStruct := MyType{} + s.Require().NoError(json.Unmarshal(bytes, &rtStruct)) + s.Assert().Equal( + myThing, + rtStruct, + ) + }, + ) +} + +func (s *mySuite) Test_Option_NoNilSome() { + assertPanics(s, (chan int)(nil)) + assertPanics(s, (func())(nil)) + assertPanics(s, any(nil)) + assertPanics(s, map[int]any(nil)) + assertPanics(s, []any(nil)) + assertPanics(s, (*any)(nil)) +} + +func (s *mySuite) Test_Option_Pointer() { + opt := Some(123) + ptr := opt.ToPointer() + *ptr = 1234 + + s.Assert().Equal( + Some(123), + opt, + "ToPointer() sholuldn’t let caller alter Option value", + ) + + opt2 := FromPointer(ptr) + *ptr = 2345 + s.Assert().Equal( + Some(1234), + opt2, + "FromPointer() sholuldn’t let caller alter Option value", + ) +} + +func (s *mySuite) Test_Option() { + + //nolint:testifylint // None is, in fact, the expected value. + s.Assert().Equal( + None[int](), + Option[int]{}, + "zero value is None", + ) + + //nolint:testifylint + s.Assert().Equal(Some(1), Some(1), "same internal value") + s.Assert().NotEqual(Some(1), Some(2), "different internal value") + + foo := "foo" + fooPtr := Some(foo).ToPointer() + + s.Assert().Equal(&foo, fooPtr) + + s.Assert().Equal(Some(foo), FromPointer(fooPtr)) + + s.Assert().Equal( + foo, + Some(foo).OrZero(), + ) + + s.Assert().Equal( + "", + None[string]().OrZero(), + ) + + s.Assert().Equal( + "elf", + None[string]().OrElse("elf"), + ) + + val, has := Some(123).Get() + s.Assert().True(has) + s.Assert().Equal(123, val) + + val, has = None[int]().Get() + s.Assert().False(has) + s.Assert().Equal(0, val) + + some := Some(456) + s.Assert().True(some.IsSome()) + s.Assert().False(some.IsNone()) + + none := None[int]() + s.Assert().False(none.IsSome()) + s.Assert().True(none.IsNone()) +} + +func (s *mySuite) Test_Option_IfNonZero() { + assertIfNonZero(s, 0, 1) + assertIfNonZero(s, "", "a") + assertIfNonZero(s, []int(nil), []int{}) + assertIfNonZero(s, map[int]int(nil), map[int]int{}) + assertIfNonZero(s, any(nil), any(0)) + assertIfNonZero(s, bson.D(nil), bson.D{}) + + type myStruct struct { + name string + } + + assertIfNonZero(s, myStruct{}, myStruct{"foo"}) +} + +func assertIfNonZero[T any](s *mySuite, zeroVal, nonZeroVal T) { + noneOpt := IfNotZero(zeroVal) + someOpt := IfNotZero(nonZeroVal) + + s.Assert().Equal(None[T](), noneOpt) + s.Assert().Equal(Some(nonZeroVal), someOpt) +} + +func pointerTo[T any](val T) *T { + return &val +} + +func assertPanics[T any](s *mySuite, val T) { + s.T().Helper() + + s.Assert().Panics( + func() { Some(val) }, + "Some(%T)", + val, + ) + + s.Assert().Panics( + func() { FromPointer(&val) }, + "FromPointer(&%T)", + val, + ) +} diff --git a/option/validate.go b/option/validate.go new file mode 100644 index 00000000..698ae75b --- /dev/null +++ b/option/validate.go @@ -0,0 +1,28 @@ +package option + +import ( + "reflect" + + mapset "github.com/deckarep/golang-set/v2" +) + +var nilable = mapset.NewThreadUnsafeSet( + reflect.Chan, + reflect.Func, + reflect.Interface, + reflect.Map, + reflect.Pointer, + reflect.Slice, +) + +func isNil(val any) bool { + if val == nil { + return true + } + + if nilable.Contains(reflect.TypeOf(val).Kind()) { + return reflect.ValueOf(val).IsNil() + } + + return false +}