diff --git a/context.go b/context.go index fb05ea7..221e321 100644 --- a/context.go +++ b/context.go @@ -48,7 +48,7 @@ type Context interface { Awakeable(options ...options.AwakeableOption) Awakeable // ResolveAwakeable allows an awakeable (not necessarily from this service) to be // resolved with a particular value. - ResolveAwakeable(id string, value any, options ...options.ResolveAwakeableOption) error + ResolveAwakeable(id string, value any, options ...options.ResolveAwakeableOption) // ResolveAwakeable allows an awakeable (not necessarily from this service) to be // rejected with a particular error. RejectAwakeable(id string, reason error) @@ -81,11 +81,11 @@ type Awakeable interface { // CallClient represents all the different ways you can invoke a particular service/key/method tuple. type CallClient interface { // RequestFuture makes a call and returns a handle on a future response - RequestFuture(input any) (ResponseFuture, error) + RequestFuture(input any) ResponseFuture // Request makes a call and blocks on getting the response which is stored in output Request(input any, output any) error // Send makes a one-way call which is executed in the background - Send(input any, delay time.Duration) error + Send(input any, delay time.Duration) } // ResponseFuture is a handle on a potentially not-yet completed outbound call. @@ -166,7 +166,8 @@ type ObjectSharedContext interface { // KeyValueReader is the set of read-only methods which can be used in all Virtual Object handlers. type KeyValueReader interface { // Get gets value associated with key and stores it in value - // If key does not exist, this function returns ErrKeyNotFound + // If key does not exist, this function returns [ErrKeyNotFound] + // If the invocation was cancelled while obtaining the state, a cancellation error is returned // Note: Use GetAs generic helper function to avoid passing in a value pointer Get(key string, value any, options ...options.GetOption) error // Keys returns a list of all associated key @@ -179,7 +180,7 @@ type KeyValueReader interface { // KeyValueWriter is the set of mutating methods which can be used in exclusive-mode Virtual Object handlers. type KeyValueWriter interface { // Set sets a value against a key, using the provided codec (defaults to JSON) - Set(key string, value any, options ...options.SetOption) error + Set(key string, value any, options ...options.SetOption) // Clear deletes a key Clear(key string) // ClearAll drops all stored state associated with key diff --git a/examples/codegen/main.go b/examples/codegen/main.go index f1cc262..354d66d 100644 --- a/examples/codegen/main.go +++ b/examples/codegen/main.go @@ -45,14 +45,10 @@ func (c counter) Add(ctx restate.ObjectContext, req *helloworld.AddRequest) (*he } count += req.Delta - if err := ctx.Set("counter", count); err != nil { - return nil, err - } + ctx.Set("counter", count) for _, awakeableID := range watchers { - if err := ctx.ResolveAwakeable(awakeableID, count); err != nil { - return nil, err - } + ctx.ResolveAwakeable(awakeableID, count) } ctx.Clear("watchers") @@ -74,9 +70,7 @@ func (c counter) AddWatcher(ctx restate.ObjectContext, req *helloworld.AddWatche return nil, err } watchers = append(watchers, req.AwakeableId) - if err := ctx.Set("watchers", watchers); err != nil { - return nil, err - } + ctx.Set("watchers", watchers) return &helloworld.AddWatcherResponse{}, nil } diff --git a/examples/ticketreservation/ticket_service.go b/examples/ticketreservation/ticket_service.go index 1544157..4282f3d 100644 --- a/examples/ticketreservation/ticket_service.go +++ b/examples/ticketreservation/ticket_service.go @@ -27,7 +27,8 @@ func (t *ticketService) Reserve(ctx restate.ObjectContext, _ restate.Void) (bool } if status == TicketAvailable { - return true, ctx.Set("status", TicketReserved) + ctx.Set("status", TicketReserved) + return true, nil } return false, nil @@ -59,7 +60,8 @@ func (t *ticketService) MarkAsSold(ctx restate.ObjectContext, _ restate.Void) (v } if status == TicketReserved { - return void, ctx.Set("status", TicketSold) + ctx.Set("status", TicketSold) + return void, nil } return void, nil @@ -69,10 +71,5 @@ func (t *ticketService) Status(ctx restate.ObjectSharedContext, _ restate.Void) ticketId := ctx.Key() ctx.Log().Info("mark ticket as sold", "ticket", ticketId) - status, err := restate.GetAs[TicketStatus](ctx, "status") - if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { - return status, err - } - - return status, nil + return restate.GetAs[TicketStatus](ctx, "status") } diff --git a/examples/ticketreservation/user_session.go b/examples/ticketreservation/user_session.go index c6cb346..3bacca0 100644 --- a/examples/ticketreservation/user_session.go +++ b/examples/ticketreservation/user_session.go @@ -30,20 +30,14 @@ func (u *userSession) AddTicket(ctx restate.ObjectContext, ticketId string) (boo // add ticket to list of tickets tickets, err := restate.GetAs[[]string](ctx, "tickets") - if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return false, err } tickets = append(tickets, ticketId) - if err := ctx.Set("tickets", tickets); err != nil { - return false, err - } - - if err := ctx.Object(UserSessionServiceName, userId, "ExpireTicket").Send(ticketId, 15*time.Minute); err != nil { - return false, err - } + ctx.Set("tickets", tickets) + ctx.Object(UserSessionServiceName, userId, "ExpireTicket").Send(ticketId, 15*time.Minute) return true, nil } @@ -66,11 +60,10 @@ func (u *userSession) ExpireTicket(ctx restate.ObjectContext, ticketId string) ( return void, nil } - if err := ctx.Set("tickets", tickets); err != nil { - return void, err - } + ctx.Set("tickets", tickets) + ctx.Object(TicketServiceName, ticketId, "Unreserve").Send(nil, 0) - return void, ctx.Object(TicketServiceName, ticketId, "Unreserve").Send(nil, 0) + return void, nil } func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, error) { @@ -88,11 +81,8 @@ func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, timeout := ctx.After(time.Minute) - request, err := restate.CallAs[PaymentResponse](ctx.Object(CheckoutServiceName, "", "Payment")). + request := restate.CallAs[PaymentResponse](ctx.Object(CheckoutServiceName, "", "Payment")). RequestFuture(PaymentRequest{UserID: userId, Tickets: tickets}) - if err != nil { - return false, err - } // race between the request and the timeout switch ctx.Select(timeout, request).Select() { @@ -113,9 +103,7 @@ func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, for _, ticket := range tickets { call := ctx.Object(TicketServiceName, ticket, "MarkAsSold") - if err := call.Send(nil, 0); err != nil { - return false, err - } + call.Send(nil, 0) } ctx.Clear("tickets") diff --git a/facilitators.go b/facilitators.go index c476fe4..4f8e0e8 100644 --- a/facilitators.go +++ b/facilitators.go @@ -8,6 +8,9 @@ import ( // GetAs get the value for a key, returning a typed response instead of accepting a pointer. // If there is no associated value with key, [ErrKeyNotFound] is returned +// If the invocation was cancelled while obtaining the state, a cancellation error is returned, however this +// can currently only occur if RESTATE_WORKER__INVOKER__DISABLE_EAGER_STATE is set to true (default false). +// If this flag is not true, err will always be ErrKeyNotFound or nil. func GetAs[T any](ctx ObjectSharedContext, key string, options ...options.GetOption) (output T, err error) { err = ctx.Get(key, &output, options...) return @@ -51,11 +54,11 @@ func AwakeableAs[T any](ctx Context, options ...options.AwakeableOption) TypedAw // TypedCallClient is an extension of [CallClient] which deals in typed values type TypedCallClient[I any, O any] interface { // RequestFuture makes a call and returns a handle on a future response - RequestFuture(input I) (TypedResponseFuture[O], error) + RequestFuture(input I) TypedResponseFuture[O] // Request makes a call and blocks on getting the response Request(input I) (O, error) // Send makes a one-way call which is executed in the background - Send(input I, delay time.Duration) error + Send(input I, delay time.Duration) } type typedCallClient[I any, O any] struct { @@ -70,24 +73,16 @@ func NewTypedCallClient[I any, O any](client CallClient) TypedCallClient[I, O] { } func (t typedCallClient[I, O]) Request(input I) (output O, err error) { - fut, err := t.inner.RequestFuture(input) - if err != nil { - return output, err - } - err = fut.Response(&output) + err = t.inner.RequestFuture(input).Response(&output) return } -func (t typedCallClient[I, O]) RequestFuture(input I) (TypedResponseFuture[O], error) { - fut, err := t.inner.RequestFuture(input) - if err != nil { - return nil, err - } - return typedResponseFuture[O]{fut}, nil +func (t typedCallClient[I, O]) RequestFuture(input I) TypedResponseFuture[O] { + return typedResponseFuture[O]{t.inner.RequestFuture(input)} } -func (t typedCallClient[I, O]) Send(input I, delay time.Duration) error { - return t.inner.Send(input, delay) +func (t typedCallClient[I, O]) Send(input I, delay time.Duration) { + t.inner.Send(input, delay) } // TypedResponseFuture is an extension of [ResponseFuture] which returns typed responses instead of accepting a pointer diff --git a/handler.go b/handler.go index 26f25d3..2d1641b 100644 --- a/handler.go +++ b/handler.go @@ -82,7 +82,8 @@ func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { bytes, err = encoding.Marshal(h.options.Codec, output) if err != nil { - return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) + // we don't use a terminal error here as this is hot-fixable by changing the return type + return nil, fmt.Errorf("failed to serialize output: %w", err) } return bytes, nil @@ -170,7 +171,8 @@ func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, err bytes, err = encoding.Marshal(h.options.Codec, output) if err != nil { - return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) + // we don't use a terminal error here as this is hot-fixable by changing the return type + return nil, fmt.Errorf("failed to serialize output: %w", err) } return bytes, nil diff --git a/internal/state/call.go b/internal/state/call.go index 9e3e186..b96304d 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -10,7 +10,6 @@ import ( restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/encoding" protocol "github.com/restatedev/sdk-go/generated/dev/restate/service" - "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" "github.com/restatedev/sdk-go/internal/options" "github.com/restatedev/sdk-go/internal/wire" @@ -25,22 +24,24 @@ type serviceCall struct { } // RequestFuture makes a call and returns a handle on the response -func (c *serviceCall) RequestFuture(input any) (restate.ResponseFuture, error) { +func (c *serviceCall) RequestFuture(input any) restate.ResponseFuture { bytes, err := encoding.Marshal(c.options.Codec, input) if err != nil { - return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal RequestFuture input: %w", err)) + panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal RequestFuture input: %w", err))) } entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, c.options.Headers, bytes) return decodingResponseFuture{ futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex, func(err error) any { return c.machine.newProtocolViolation(entry, err) }), + c.machine, c.options, - }, nil + } } type decodingResponseFuture struct { *futures.ResponseFuture + machine *Machine options options.CallOptions } @@ -51,7 +52,7 @@ func (d decodingResponseFuture) Response(output any) (err error) { } if err := encoding.Unmarshal(d.options.Codec, bytes, output); err != nil { - return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Call response into output: %w", err)) + panic(d.machine.newCodecFailure(fmt.Errorf("failed to unmarshal Call response into output: %w", err))) } return nil @@ -59,21 +60,17 @@ func (d decodingResponseFuture) Response(output any) (err error) { // Request makes a call and blocks on the response func (c *serviceCall) Request(input any, output any) error { - fut, err := c.RequestFuture(input) - if err != nil { - return err - } - return fut.Response(output) + return c.RequestFuture(input).Response(output) } // Send runs a call in the background after delay duration -func (c *serviceCall) Send(input any, delay time.Duration) error { +func (c *serviceCall) Send(input any, delay time.Duration) { bytes, err := encoding.Marshal(c.options.Codec, input) if err != nil { - return errors.NewTerminalError(fmt.Errorf("failed to marshal Send input: %w", err)) + panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal Send input: %w", err))) } c.machine.sendCall(c.service, c.key, c.method, c.options.Headers, bytes, delay) - return nil + return } func (m *Machine) doCall(service, key, method string, headersMap map[string]string, params []byte) (*wire.CallEntryMessage, uint32) { diff --git a/internal/state/state.go b/internal/state/state.go index 7306f89..0423e43 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -55,7 +55,7 @@ func (c *Context) Rand() *rand.Rand { return c.machine.rand } -func (c *Context) Set(key string, value any, opts ...options.SetOption) error { +func (c *Context) Set(key string, value any, opts ...options.SetOption) { o := options.SetOptions{} for _, opt := range opts { opt.BeforeSet(&o) @@ -66,11 +66,11 @@ func (c *Context) Set(key string, value any, opts ...options.SetOption) error { bytes, err := encoding.Marshal(o.Codec, value) if err != nil { - return errors.NewTerminalError(fmt.Errorf("failed to marshal Set value: %w", err)) + panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal Set value: %w", err))) } c.machine.set(key, bytes) - return nil + return } func (c *Context) Clear(key string) { @@ -93,13 +93,13 @@ func (c *Context) Get(key string, output any, opts ...options.GetOption) error { o.Codec = encoding.JSONCodec } - bytes := c.machine.get(key) - if len(bytes) == 0 { - return errors.ErrKeyNotFound + bytes, err := c.machine.get(key) + if err != nil { + return err } if err := encoding.Unmarshal(o.Codec, bytes, output); err != nil { - return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Get state into output: %w", err)) + panic(c.machine.newCodecFailure(fmt.Errorf("failed to unmarshal Get state into output: %w", err))) } return nil @@ -169,7 +169,7 @@ func (c *Context) Run(fn func(ctx restate.RunContext) (any, error), output any, bytes, err := encoding.Marshal(o.Codec, output) if err != nil { - return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal Run output: %w", err)) + panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal Run output: %w", err))) } return bytes, nil @@ -179,7 +179,7 @@ func (c *Context) Run(fn func(ctx restate.RunContext) (any, error), output any, } if err := encoding.Unmarshal(o.Codec, bytes, output); err != nil { - return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Run output: %w", err)) + panic(c.machine.newCodecFailure(fmt.Errorf("failed to unmarshal Run output: %w", err))) } return nil @@ -201,12 +201,13 @@ func (c *Context) Awakeable(opts ...options.AwakeableOption) restate.Awakeable { if o.Codec == nil { o.Codec = encoding.JSONCodec } - return decodingAwakeable{c.machine.awakeable(), o.Codec} + return decodingAwakeable{c.machine.awakeable(), c.machine, o.Codec} } type decodingAwakeable struct { *futures.Awakeable - codec encoding.Codec + machine *Machine + codec encoding.Codec } func (d decodingAwakeable) Id() string { return d.Awakeable.Id() } @@ -216,12 +217,12 @@ func (d decodingAwakeable) Result(output any) (err error) { return err } if err := encoding.Unmarshal(d.codec, bytes, output); err != nil { - return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Awakeable result into output: %w", err)) + panic(d.machine.newCodecFailure(fmt.Errorf("failed to unmarshal Awakeable result into output: %w", err))) } return } -func (c *Context) ResolveAwakeable(id string, value any, opts ...options.ResolveAwakeableOption) error { +func (c *Context) ResolveAwakeable(id string, value any, opts ...options.ResolveAwakeableOption) { o := options.ResolveAwakeableOptions{} for _, opt := range opts { opt.BeforeResolveAwakeable(&o) @@ -231,10 +232,9 @@ func (c *Context) ResolveAwakeable(id string, value any, opts ...options.Resolve } bytes, err := encoding.Marshal(o.Codec, value) if err != nil { - return errors.NewTerminalError(fmt.Errorf("failed to marshal ResolveAwakeable value: %w", err)) + panic(c.machine.newCodecFailure(fmt.Errorf("failed to marshal ResolveAwakeable value: %w", err))) } c.machine.resolveAwakeable(id, bytes) - return nil } func (c *Context) RejectAwakeable(id string, reason error) { @@ -432,6 +432,21 @@ The journal entry at position %d was: m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending failure message", log.Error(typ.err)) } + return + case *codecFailure: + m.log.LogAttrs(m.ctx, slog.LevelError, "Encoding failed, returning error to Restate", log.Error(typ.err)) + + if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + ErrorMessage: protocol.ErrorMessage{ + Code: uint32(restate.ErrorCode(typ.err)), + Message: typ.err.Error(), + RelatedEntryIndex: &typ.entryIndex, + RelatedEntryType: wire.AwakeableEntryMessageType.UInt32(), + }, + }); err != nil { + m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending failure message", log.Error(typ.err)) + } + return case *clientGoneAway: m.log.LogAttrs(m.ctx, slog.LevelWarn, "Cancelling invocation as the incoming request context was cancelled", log.Error(typ.err)) diff --git a/internal/state/state_test.go b/internal/state/state_test.go index b5536a2..57456ef 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -283,10 +283,7 @@ func TestInvocationCanceled(t *testing.T) { { name: "call should return cancelled error", fn: func(ctx restate.Context) error { - fut, err := ctx.Service("foo", "bar").RequestFuture(restate.Void{}) - if err != nil { - return err - } + fut := ctx.Service("foo", "bar").RequestFuture(restate.Void{}) return fut.Response(restate.Void{}) }, }, diff --git a/internal/state/sys.go b/internal/state/sys.go index f50c818..efaddc4 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -125,7 +125,7 @@ func (m *Machine) _clearAll() { ) } -func (m *Machine) get(key string) []byte { +func (m *Machine) get(key string) ([]byte, error) { entry, entryIndex := replayOrNew( m, func(entry *wire.GetStateEntryMessage) *wire.GetStateEntryMessage { @@ -145,10 +145,12 @@ func (m *Machine) get(key string) []byte { switch value := entry.Result.(type) { case *protocol.GetStateEntryMessage_Empty: - return nil + return nil, restate.ErrKeyNotFound case *protocol.GetStateEntryMessage_Value: m.current[key] = value.Value - return value.Value + return value.Value, nil + case *protocol.GetStateEntryMessage_Failure: + return nil, errors.ErrorFromFailure(value.Failure) default: panic(m.newProtocolViolation(entry, fmt.Errorf("get state entry had invalid result: %v", entry.Result))) } @@ -357,3 +359,14 @@ func (m *Machine) newRunFailure(err error) *runFailure { m.failure = s return s } + +type codecFailure struct { + entryIndex uint32 + err error +} + +func (m *Machine) newCodecFailure(err error) *codecFailure { + c := &codecFailure{m.entryIndex, err} + m.failure = c + return c +} diff --git a/reflect.go b/reflect.go index dc153a4..fbe07b3 100644 --- a/reflect.go +++ b/reflect.go @@ -185,7 +185,8 @@ func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, er bytes, err := encoding.Marshal(h.options.Codec, outI) if err != nil { - return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) + // we don't use a terminal error here as this is hot-fixable by changing the return type + return nil, fmt.Errorf("failed to serialize output: %w", err) } return bytes, nil @@ -219,7 +220,8 @@ func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) bytes, err := encoding.Marshal(h.options.Codec, outI) if err != nil { - return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) + // we don't use a terminal error here as this is hot-fixable by changing the return type + return nil, fmt.Errorf("failed to serialize output: %w", err) } return bytes, nil diff --git a/test-services/awakeableholder.go b/test-services/awakeableholder.go index bf5934b..68e9adf 100644 --- a/test-services/awakeableholder.go +++ b/test-services/awakeableholder.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" restate "github.com/restatedev/sdk-go" @@ -13,15 +14,13 @@ func init() { restate.NewObject("AwakeableHolder"). Handler("hold", restate.NewObjectHandler( func(ctx restate.ObjectContext, id string) (restate.Void, error) { - if err := ctx.Set(ID_KEY, id); err != nil { - return restate.Void{}, err - } + ctx.Set(ID_KEY, id) return restate.Void{}, nil })). Handler("hasAwakeable", restate.NewObjectHandler( func(ctx restate.ObjectContext, _ restate.Void) (bool, error) { _, err := restate.GetAs[string](ctx, ID_KEY) - if err != nil && err != restate.ErrKeyNotFound { + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return false, err } return err == nil, nil @@ -30,14 +29,12 @@ func init() { func(ctx restate.ObjectContext, payload string) (restate.Void, error) { id, err := restate.GetAs[string](ctx, ID_KEY) if err != nil { - if err == restate.ErrKeyNotFound { + if errors.Is(err, restate.ErrKeyNotFound) { return restate.Void{}, restate.TerminalError(fmt.Errorf("No awakeable registered"), 404) } return restate.Void{}, err } - if err := ctx.ResolveAwakeable(id, payload); err != nil { - return restate.Void{}, err - } + ctx.ResolveAwakeable(id, payload) return restate.Void{}, nil }))) } diff --git a/test-services/canceltest.go b/test-services/canceltest.go index fdaee39..9da4b6d 100644 --- a/test-services/canceltest.go +++ b/test-services/canceltest.go @@ -24,7 +24,8 @@ func init() { func(ctx restate.ObjectContext, operation BlockingOperation) (restate.Void, error) { if err := ctx.Object("CancelTestBlockingService", "", "block").Request(operation, restate.Void{}); err != nil { if restate.ErrorCode(err) == 409 { - return restate.Void{}, ctx.Set(CANCELED_STATE, true) + ctx.Set(CANCELED_STATE, true) + return restate.Void{}, nil } return restate.Void{}, err } @@ -32,11 +33,7 @@ func init() { })). Handler("verifyTest", restate.NewObjectHandler( func(ctx restate.ObjectContext, _ restate.Void) (bool, error) { - canceled, err := restate.GetAs[bool](ctx, CANCELED_STATE) - if err != nil && err != restate.ErrKeyNotFound { - return false, err - } - return canceled, nil + return restate.GetAs[bool](ctx, CANCELED_STATE) }))) REGISTRY.AddDefinition( restate.NewObject("CancelTestBlockingService"). diff --git a/test-services/counter.go b/test-services/counter.go index 89c8ec3..7a89afc 100644 --- a/test-services/counter.go +++ b/test-services/counter.go @@ -24,45 +24,32 @@ func init() { })). Handler("get", restate.NewObjectSharedHandler( func(ctx restate.ObjectSharedContext, _ restate.Void) (int64, error) { - c, err := restate.GetAs[int64](ctx, COUNTER_KEY) - if errors.Is(err, restate.ErrKeyNotFound) { - c = 0 - } else if err != nil { - return 0, err - } - return c, nil + return restate.GetAs[int64](ctx, COUNTER_KEY) })). Handler("add", restate.NewObjectHandler( func(ctx restate.ObjectContext, addend int64) (CounterUpdateResponse, error) { oldValue, err := restate.GetAs[int64](ctx, COUNTER_KEY) - if errors.Is(err, restate.ErrKeyNotFound) { - oldValue = 0 - } else if err != nil { + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return CounterUpdateResponse{}, err } newValue := oldValue + addend - err = ctx.Set(COUNTER_KEY, newValue) + ctx.Set(COUNTER_KEY, newValue) return CounterUpdateResponse{ OldValue: oldValue, NewValue: newValue, - }, err + }, nil })). Handler("addThenFail", restate.NewObjectHandler( func(ctx restate.ObjectContext, addend int64) (restate.Void, error) { oldValue, err := restate.GetAs[int64](ctx, COUNTER_KEY) - if errors.Is(err, restate.ErrKeyNotFound) { - oldValue = 0 - } else if err != nil { + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return restate.Void{}, err } newValue := oldValue + addend - err = ctx.Set(COUNTER_KEY, newValue) - if err != nil { - return restate.Void{}, err - } + ctx.Set(COUNTER_KEY, newValue) return restate.Void{}, restate.TerminalError(fmt.Errorf("%s", ctx.Key())) }))) diff --git a/test-services/kill.go b/test-services/kill.go index 95a6c0f..9120ffd 100644 --- a/test-services/kill.go +++ b/test-services/kill.go @@ -14,9 +14,7 @@ func init() { Handler("recursiveCall", restate.NewObjectHandler( func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) { awakeable := ctx.Awakeable() - if err := ctx.Object("AwakeableHolder", "kill", "hold").Send(awakeable.Id(), 0); err != nil { - return restate.Void{}, err - } + ctx.Object("AwakeableHolder", "kill", "hold").Send(awakeable.Id(), 0) if err := awakeable.Result(restate.Void{}); err != nil { return restate.Void{}, err } diff --git a/test-services/listobject.go b/test-services/listobject.go index 9a6c187..5aceedc 100644 --- a/test-services/listobject.go +++ b/test-services/listobject.go @@ -1,6 +1,8 @@ package main import ( + "errors" + restate "github.com/restatedev/sdk-go" ) @@ -12,17 +14,17 @@ func init() { Handler("append", restate.NewObjectHandler( func(ctx restate.ObjectContext, value string) (restate.Void, error) { list, err := restate.GetAs[[]string](ctx, LIST_KEY) - if err != nil && err != restate.ErrKeyNotFound { + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return restate.Void{}, err } - list = append(list, value) - return restate.Void{}, ctx.Set(LIST_KEY, list) + ctx.Set(LIST_KEY, list) + return restate.Void{}, nil })). Handler("get", restate.NewObjectHandler( func(ctx restate.ObjectContext, _ restate.Void) ([]string, error) { list, err := restate.GetAs[[]string](ctx, LIST_KEY) - if err != nil && err != restate.ErrKeyNotFound { + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return nil, err } if list == nil { @@ -35,7 +37,7 @@ func init() { Handler("clear", restate.NewObjectHandler( func(ctx restate.ObjectContext, _ restate.Void) ([]string, error) { list, err := restate.GetAs[[]string](ctx, LIST_KEY) - if err != nil && err != restate.ErrKeyNotFound { + if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return nil, err } if list == nil { diff --git a/test-services/mapobject.go b/test-services/mapobject.go index 1332403..3a06c73 100644 --- a/test-services/mapobject.go +++ b/test-services/mapobject.go @@ -14,15 +14,12 @@ func init() { restate.NewObject("MapObject"). Handler("set", restate.NewObjectHandler( func(ctx restate.ObjectContext, value Entry) (restate.Void, error) { - return restate.Void{}, ctx.Set(value.Key, value.Value) + ctx.Set(value.Key, value.Value) + return restate.Void{}, nil })). Handler("get", restate.NewObjectHandler( func(ctx restate.ObjectContext, key string) (string, error) { - value, err := restate.GetAs[string](ctx, key) - if err != nil && err != restate.ErrKeyNotFound { - return "", err - } - return value, nil + return restate.GetAs[string](ctx, key) })). Handler("clearAll", restate.NewObjectHandler( func(ctx restate.ObjectContext, _ restate.Void) ([]Entry, error) { diff --git a/test-services/nondeterministic.go b/test-services/nondeterministic.go index c8bf03f..39010e6 100644 --- a/test-services/nondeterministic.go +++ b/test-services/nondeterministic.go @@ -22,8 +22,8 @@ func init() { invocationCounts[countKey] += 1 return invocationCounts[countKey]%2 == 1 } - incrementCounter := func(ctx restate.ObjectContext) error { - return ctx.Object("Counter", ctx.Key(), "add").Send(int64(1), 0) + incrementCounter := func(ctx restate.ObjectContext) { + ctx.Object("Counter", ctx.Key(), "add").Send(int64(1), 0) } REGISTRY.AddDefinition( @@ -40,7 +40,8 @@ func init() { // This is required to cause a suspension after the non-deterministic operation ctx.Sleep(100 * time.Millisecond) - return restate.Void{}, incrementCounter(ctx) + incrementCounter(ctx) + return restate.Void{}, nil })). Handler("callDifferentMethod", restate.NewObjectHandler( func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) { @@ -56,38 +57,33 @@ func init() { // This is required to cause a suspension after the non-deterministic operation ctx.Sleep(100 * time.Millisecond) - return restate.Void{}, incrementCounter(ctx) + incrementCounter(ctx) + return restate.Void{}, nil })). Handler("backgroundInvokeWithDifferentTargets", restate.NewObjectHandler( func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) { if doLeftAction(ctx) { - if err := ctx.Object("Counter", "abc", "get").Send(restate.Void{}, 0); err != nil { - return restate.Void{}, err - } + ctx.Object("Counter", "abc", "get").Send(restate.Void{}, 0) } else { - if err := ctx.Object("Counter", "abc", "reset").Send(restate.Void{}, 0); err != nil { - return restate.Void{}, err - } + ctx.Object("Counter", "abc", "reset").Send(restate.Void{}, 0) } // This is required to cause a suspension after the non-deterministic operation ctx.Sleep(100 * time.Millisecond) - return restate.Void{}, incrementCounter(ctx) + incrementCounter(ctx) + return restate.Void{}, nil })). Handler("setDifferentKey", restate.NewObjectHandler( func(ctx restate.ObjectContext, _ restate.Void) (restate.Void, error) { if doLeftAction(ctx) { - if err := ctx.Set(STATE_A, "my-state"); err != nil { - return restate.Void{}, err - } + ctx.Set(STATE_A, "my-state") } else { - if err := ctx.Set(STATE_B, "my-state"); err != nil { - return restate.Void{}, err - } + ctx.Set(STATE_B, "my-state") } // This is required to cause a suspension after the non-deterministic operation ctx.Sleep(100 * time.Millisecond) - return restate.Void{}, incrementCounter(ctx) + incrementCounter(ctx) + return restate.Void{}, nil }))) } diff --git a/test-services/proxy.go b/test-services/proxy.go index 1b7e775..a679316 100644 --- a/test-services/proxy.go +++ b/test-services/proxy.go @@ -47,7 +47,8 @@ func init() { // We need to use []int because Golang takes the opinionated choice of treating []byte as Base64 func(ctx restate.Context, req ProxyRequest) (restate.Void, error) { input := intArrayToByteArray(req.Message) - return restate.Void{}, req.ToTarget(ctx).Send(input, 0) + req.ToTarget(ctx).Send(input, 0) + return restate.Void{}, nil })). Handler("manyCalls", restate.NewServiceHandler( // We need to use []int because Golang takes the opinionated choice of treating []byte as Base64 @@ -57,14 +58,9 @@ func init() { for _, req := range requests { input := intArrayToByteArray(req.ProxyRequest.Message) if req.OneWayCall { - if err := req.ProxyRequest.ToTarget(ctx).Send(input, 0); err != nil { - return restate.Void{}, err - } + req.ProxyRequest.ToTarget(ctx).Send(input, 0) } else { - fut, err := req.ProxyRequest.ToTarget(ctx).RequestFuture(input) - if err != nil { - return restate.Void{}, err - } + fut := req.ProxyRequest.ToTarget(ctx).RequestFuture(input) if req.AwaitAtTheEnd { toAwait = append(toAwait, fut) } diff --git a/test-services/upgradetest.go b/test-services/upgradetest.go index d890bb6..d13dfac 100644 --- a/test-services/upgradetest.go +++ b/test-services/upgradetest.go @@ -24,15 +24,11 @@ func init() { return "", fmt.Errorf("executeComplex should not be invoked with version different from 1!") } awakeable := restate.AwakeableAs[string](ctx) - if err := ctx.Object("AwakeableHolder", "upgrade", "hold").Send(awakeable.Id(), 0); err != nil { - return "", err - } + ctx.Object("AwakeableHolder", "upgrade", "hold").Send(awakeable.Id(), 0) if _, err := awakeable.Result(); err != nil { return "", err } - if err := ctx.Object("ListObject", "upgrade-test", "append").Send(version(), 0); err != nil { - return "", err - } + ctx.Object("ListObject", "upgrade-test", "append").Send(version(), 0) return version(), nil }))) }