From 2d25302a2c530d973f5fe51d1bca104d88000583 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Wed, 21 Aug 2024 14:51:33 +0100 Subject: [PATCH] Take all the methods out of ctx --- context.go | 112 +----------- error.go | 16 +- .../codegen/proto/helloworld_restate.pb.go | 30 ++-- facilitators.go | 160 +++++++++++------- handler.go | 47 ++--- interfaces/interfaces.go | 67 -------- internal/errors/error.go | 18 ++ internal/state/awakeable.go | 25 +-- internal/state/call.go | 24 ++- internal/state/completion.go | 4 +- internal/state/interfaces.go | 27 +++ internal/state/state.go | 107 ++++++------ internal/state/state_test.go | 18 +- internal/state/sys.go | 53 +++--- reflect.go | 15 +- router.go | 23 +-- test-services/proxy.go | 13 +- test-services/testutils.go | 3 +- 18 files changed, 339 insertions(+), 423 deletions(-) delete mode 100644 interfaces/interfaces.go create mode 100644 internal/state/interfaces.go diff --git a/context.go b/context.go index fea02ef..5c1a33b 100644 --- a/context.go +++ b/context.go @@ -3,132 +3,36 @@ package restate import ( "context" "log/slog" - "time" - "github.com/restatedev/sdk-go/interfaces" - "github.com/restatedev/sdk-go/internal/futures" - "github.com/restatedev/sdk-go/internal/options" - "github.com/restatedev/sdk-go/internal/rand" + "github.com/restatedev/sdk-go/internal/state" ) -// Context is the base set of operations that all Restate handlers may perform. +// Context is passed to Restate service handlers and enables interaction with Restate type Context interface { RunContext - - // Rand returns a random source which will give deterministic results for a given invocation - // The source wraps the stdlib rand.Rand but with some extra helper methods - // This source is not safe for use inside .Run() - Rand() *rand.Rand - - // Sleep for the duration d. Can return a terminal error in the case where the invocation was cancelled mid-sleep. - Sleep(d time.Duration) error - // After is an alternative to Context.Sleep which allows you to complete other tasks concurrently - // with the sleep. This is particularly useful when combined with Context.Select to race between - // the sleep and other Selectable operations. - After(d time.Duration) interfaces.After - - // Service gets a Service request client by service and method name - // Note: use module-level [Service] to deserialise return values - Service(service, method string, opts ...options.ClientOption) interfaces.Client - - // Object gets an Object request client by service name, key and method name - // Note: use module-level [Object] to receive serialised values - Object(object, key, method string, opts ...options.ClientOption) interfaces.Client - - // Run runs the function (fn), storing final results (including terminal errors) - // durably in the journal, or otherwise for transient errors stopping execution - // so Restate can retry the invocation. Replays will produce the same value, so - // all non-deterministic operations (eg, generating a unique ID) *must* happen - // inside Run blocks. - // Note: use module-level [Run] to get typed output values instead of providing an output pointer - Run(fn func(ctx RunContext) (any, error), output any, opts ...options.RunOption) error - - // Awakeable returns a Restate awakeable; a 'promise' to a future - // value or error, that can be resolved or rejected by other services. - // Note: use module-level [Awakeable] to avoid having to pass a output pointer to Awakeable.Result() - Awakeable(options ...options.AwakeableOption) interfaces.Awakeable - // ResolveAwakeable allows an awakeable (not necessarily from this service) to be - // resolved with a particular value. - ResolveAwakeable(id string, value any, options ...options.ResolveAwakeableOption) - // ResolveAwakeable allows an awakeable (not necessarily from this service) to be - // rejected with a particular error. - RejectAwakeable(id string, reason error) - - // Select returns an iterator over blocking Restate operations (sleep, call, awakeable) - // which allows you to safely run them in parallel. The Selector will store the order - // that things complete in durably inside Restate, so that on replay the same order - // can be used. This avoids non-determinism. It is *not* safe to use goroutines or channels - // outside of [Run] functions, as they do not behave deterministically. - Select(futs ...futures.Selectable) interfaces.Selector + inner() *state.Context } -// RunContext methods are the only methods of [Context] that are safe to call from inside a .Run() -// Calling any other method inside a Run() will panic. +// RunContext is passed to [Run] closures and provides the limited set of Restate operations that are safe to use there. type RunContext interface { context.Context - // Log obtains a handle on a slog.Logger which already has some useful fields (invocationID and method) // By default, this logger will not output messages if the invocation is currently replaying // The log handler can be set with `.WithLogger()` on the server object Log() *slog.Logger // Request gives extra information about the request that started this invocation - Request() *Request + Request() *state.Request } -type Request struct { - // The unique id that identifies the current function invocation. This id is guaranteed to be - // unique across invocations, but constant across reties and suspensions. - ID []byte - // Request headers - the following headers capture the original invocation headers, as provided to - // the ingress. - Headers map[string]string - // Attempt headers - the following headers are sent by the restate runtime. - // These headers are attempt specific, generated by the restate runtime uniquely for each attempt. - // These headers might contain information such as the W3C trace context, and attempt specific information. - AttemptHeaders map[string][]string - // Raw unparsed request body - Body []byte -} - -// ObjectContext is an extension of [Context] which can be used in exclusive-mode Virtual Object handlers, +// ObjectContext is an extension of [Context] which is passed to exclusive-mode Virtual Object handlers. // giving mutable access to state. type ObjectContext interface { - Context - KeyValueReader - KeyValueWriter + ObjectSharedContext } -// ObjectContext is an extension of [Context] which can be used in shared-mode Virtual Object handlers, +// ObjectContext is an extension of [Context] which is passed to shared-mode Virtual Object handlers, // giving read-only access to a snapshot of state. type ObjectSharedContext interface { Context - KeyValueReader -} - -// KeyValueReader is the set of read-only methods which can be used in all Virtual Object handlers. -type KeyValueReader interface { - // Get gets value associated with key and stores it in value - // If key does not exist, this function returns [ErrKeyNotFound] - // If the invocation was cancelled while obtaining the state (only possible if eager state is disabled), - // a cancellation error is returned. - // Note: Use GetAs generic helper function to avoid passing in a value pointer - Get(key string, value any, options ...options.GetOption) error - // Keys returns a list of all associated key - // If the invocation was cancelled while obtaining the state (only possible if eager state is disabled), - // a cancellation error is returned. - Keys() ([]string, error) - // Key retrieves the key for this virtual object invocation. This is a no-op and is - // always safe to call. - Key() string -} - -// KeyValueWriter is the set of mutating methods which can be used in exclusive-mode Virtual Object handlers. -type KeyValueWriter interface { - // Set sets a value against a key, using the provided codec (defaults to JSON) - Set(key string, value any, options ...options.SetOption) - // Clear deletes a key - Clear(key string) - // ClearAll drops all stored state associated with this Object key - ClearAll() } diff --git a/error.go b/error.go index 25c2599..b821d02 100644 --- a/error.go +++ b/error.go @@ -1,7 +1,6 @@ package restate import ( - stderrors "errors" "fmt" "github.com/restatedev/sdk-go/internal/errors" @@ -41,19 +40,10 @@ func TerminalErrorf(format string, a ...any) error { // IsTerminalError checks if err is terminal - ie, that returning it in a handler or Run function will finish // the invocation with the error as a result. func IsTerminalError(err error) bool { - if err == nil { - return false - } - var t *errors.TerminalError - return stderrors.As(err, &t) + return errors.IsTerminalError(err) } // ErrorCode returns [Code] associated with error, defaulting to 500 -func ErrorCode(err error) Code { - var e *errors.CodeError - if stderrors.As(err, &e) { - return e.Code - } - - return 500 +func ErrorCode(err error) errors.Code { + return errors.ErrorCode(err) } diff --git a/examples/codegen/proto/helloworld_restate.pb.go b/examples/codegen/proto/helloworld_restate.pb.go index 0e33b4a..01f3c21 100644 --- a/examples/codegen/proto/helloworld_restate.pb.go +++ b/examples/codegen/proto/helloworld_restate.pb.go @@ -13,7 +13,7 @@ import ( // GreeterClient is the client API for Greeter service. type GreeterClient interface { - SayHello(opts ...sdk_go.ClientOption) sdk_go.TypedClient[*HelloRequest, *HelloResponse] + SayHello(opts ...sdk_go.ClientOption) sdk_go.Client[*HelloRequest, *HelloResponse] } type greeterClient struct { @@ -28,12 +28,12 @@ func NewGreeterClient(ctx sdk_go.Context, opts ...sdk_go.ClientOption) GreeterCl cOpts, } } -func (c *greeterClient) SayHello(opts ...sdk_go.ClientOption) sdk_go.TypedClient[*HelloRequest, *HelloResponse] { +func (c *greeterClient) SayHello(opts ...sdk_go.ClientOption) sdk_go.Client[*HelloRequest, *HelloResponse] { cOpts := c.options if len(opts) > 0 { cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...) } - return sdk_go.NewTypedClient[*HelloRequest, *HelloResponse](c.ctx.Service("Greeter", "SayHello", cOpts...)) + return sdk_go.WithRequestType[*HelloRequest](sdk_go.Service[*HelloResponse](c.ctx, "Greeter", "SayHello", cOpts...)) } // GreeterServer is the server API for Greeter service. @@ -79,13 +79,13 @@ func NewGreeterServer(srv GreeterServer, opts ...sdk_go.ServiceDefinitionOption) // CounterClient is the client API for Counter service. type CounterClient interface { // Mutate the value - Add(opts ...sdk_go.ClientOption) sdk_go.TypedClient[*AddRequest, *GetResponse] + Add(opts ...sdk_go.ClientOption) sdk_go.Client[*AddRequest, *GetResponse] // Get the current value - Get(opts ...sdk_go.ClientOption) sdk_go.TypedClient[*GetRequest, *GetResponse] + Get(opts ...sdk_go.ClientOption) sdk_go.Client[*GetRequest, *GetResponse] // Internal method to store an awakeable ID for the Watch method - AddWatcher(opts ...sdk_go.ClientOption) sdk_go.TypedClient[*AddWatcherRequest, *AddWatcherResponse] + AddWatcher(opts ...sdk_go.ClientOption) sdk_go.Client[*AddWatcherRequest, *AddWatcherResponse] // Wait for the counter to change and then return the new value - Watch(opts ...sdk_go.ClientOption) sdk_go.TypedClient[*WatchRequest, *GetResponse] + Watch(opts ...sdk_go.ClientOption) sdk_go.Client[*WatchRequest, *GetResponse] } type counterClient struct { @@ -102,36 +102,36 @@ func NewCounterClient(ctx sdk_go.Context, key string, opts ...sdk_go.ClientOptio cOpts, } } -func (c *counterClient) Add(opts ...sdk_go.ClientOption) sdk_go.TypedClient[*AddRequest, *GetResponse] { +func (c *counterClient) Add(opts ...sdk_go.ClientOption) sdk_go.Client[*AddRequest, *GetResponse] { cOpts := c.options if len(opts) > 0 { cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...) } - return sdk_go.NewTypedClient[*AddRequest, *GetResponse](c.ctx.Object("Counter", c.key, "Add", cOpts...)) + return sdk_go.WithRequestType[*AddRequest](sdk_go.Object[*GetResponse](c.ctx, "Counter", c.key, "Add", cOpts...)) } -func (c *counterClient) Get(opts ...sdk_go.ClientOption) sdk_go.TypedClient[*GetRequest, *GetResponse] { +func (c *counterClient) Get(opts ...sdk_go.ClientOption) sdk_go.Client[*GetRequest, *GetResponse] { cOpts := c.options if len(opts) > 0 { cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...) } - return sdk_go.NewTypedClient[*GetRequest, *GetResponse](c.ctx.Object("Counter", c.key, "Get", cOpts...)) + return sdk_go.WithRequestType[*GetRequest](sdk_go.Object[*GetResponse](c.ctx, "Counter", c.key, "Get", cOpts...)) } -func (c *counterClient) AddWatcher(opts ...sdk_go.ClientOption) sdk_go.TypedClient[*AddWatcherRequest, *AddWatcherResponse] { +func (c *counterClient) AddWatcher(opts ...sdk_go.ClientOption) sdk_go.Client[*AddWatcherRequest, *AddWatcherResponse] { cOpts := c.options if len(opts) > 0 { cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...) } - return sdk_go.NewTypedClient[*AddWatcherRequest, *AddWatcherResponse](c.ctx.Object("Counter", c.key, "AddWatcher", cOpts...)) + return sdk_go.WithRequestType[*AddWatcherRequest](sdk_go.Object[*AddWatcherResponse](c.ctx, "Counter", c.key, "AddWatcher", cOpts...)) } -func (c *counterClient) Watch(opts ...sdk_go.ClientOption) sdk_go.TypedClient[*WatchRequest, *GetResponse] { +func (c *counterClient) Watch(opts ...sdk_go.ClientOption) sdk_go.Client[*WatchRequest, *GetResponse] { cOpts := c.options if len(opts) > 0 { cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...) } - return sdk_go.NewTypedClient[*WatchRequest, *GetResponse](c.ctx.Object("Counter", c.key, "Watch", cOpts...)) + return sdk_go.WithRequestType[*WatchRequest](sdk_go.Object[*GetResponse](c.ctx, "Counter", c.key, "Watch", cOpts...)) } // CounterServer is the server API for Counter service. diff --git a/facilitators.go b/facilitators.go index 9e9099f..21d122d 100644 --- a/facilitators.go +++ b/facilitators.go @@ -4,87 +4,120 @@ import ( "errors" "time" - "github.com/restatedev/sdk-go/interfaces" "github.com/restatedev/sdk-go/internal/futures" "github.com/restatedev/sdk-go/internal/options" "github.com/restatedev/sdk-go/internal/rand" + "github.com/restatedev/sdk-go/internal/state" ) // Rand returns a random source which will give deterministic results for a given invocation // The source wraps the stdlib rand.Rand but with some extra helper methods // This source is not safe for use inside .Run() func Rand(ctx Context) *rand.Rand { - return ctx.Rand() + return ctx.inner().Rand() } // Sleep for the duration d. Can return a terminal error in the case where the invocation was cancelled mid-sleep. func Sleep(ctx Context, d time.Duration) error { - return ctx.Sleep(d) + return ctx.inner().Sleep(d) } // After is an alternative to [Sleep] which allows you to complete other tasks concurrently // with the sleep. This is particularly useful when combined with [Select] to race between // the sleep and other Selectable operations. -func After(ctx Context, d time.Duration) interfaces.After { - return ctx.After(d) +func After(ctx Context, d time.Duration) AfterFuture { + return ctx.inner().After(d) +} + +// After is a handle on a Sleep operation which allows you to do other work concurrently +// with the sleep. +type AfterFuture interface { + // Done blocks waiting on the remaining duration of the sleep. + // It is *not* safe to call this in a goroutine - use Context.Select if you want to wait on multiple + // results at once. Can return a terminal error in the case where the invocation was cancelled mid-sleep, + // hence Done() should always be called, even after using Context.Select. + Done() error + futures.Selectable } // Service gets a Service request client by service and method name -func Service[O any](ctx Context, service string, method string, options ...options.ClientOption) TypedClient[any, O] { - return typedClient[any, O]{ctx.Service(service, method, options...)} +func Service[O any](ctx Context, service string, method string, options ...options.ClientOption) Client[any, O] { + return outputClient[O]{ctx.inner().Service(service, method, options...)} } // Service gets a Service send client by service and method name -func ServiceSend(ctx Context, service string, method string, options ...options.ClientOption) interfaces.SendClient { - return ctx.Service(service, method, options...) +func ServiceSend(ctx Context, service string, method string, options ...options.ClientOption) SendClient[any] { + return ctx.inner().Service(service, method, options...) } // Object gets an Object request client by service name, key and method name -func Object[O any](ctx Context, service string, key string, method string, options ...options.ClientOption) TypedClient[any, O] { - return typedClient[any, O]{ctx.Object(service, key, method, options...)} +func Object[O any](ctx Context, service string, key string, method string, options ...options.ClientOption) Client[any, O] { + return outputClient[O]{ctx.inner().Object(service, key, method, options...)} } // ObjectSend gets an Object send client by service name, key and method name -func ObjectSend(ctx Context, service string, key string, method string, options ...options.ClientOption) interfaces.SendClient { - return ctx.Object(service, key, method, options...) +func ObjectSend(ctx Context, service string, key string, method string, options ...options.ClientOption) SendClient[any] { + return ctx.inner().Object(service, key, method, options...) } -// TypedClient is an extension of [interfaces.Client] and [interfaces.SendClient] which deals in typed values -type TypedClient[I any, O any] interface { +// Client represents all the different ways you can invoke a particular service-method. +type Client[I any, O any] interface { // RequestFuture makes a call and returns a handle on a future response - RequestFuture(input I, options ...options.RequestOption) TypedResponseFuture[O] + RequestFuture(input I, options ...options.RequestOption) ResponseFuture[O] // Request makes a call and blocks on getting the response Request(input I, options ...options.RequestOption) (O, error) + SendClient[I] +} + +// SendClient allows making one-way invocations +type SendClient[I any] interface { // Send makes a one-way call which is executed in the background Send(input I, options ...options.SendOption) } -type typedClient[I any, O any] struct { - inner interfaces.Client +type outputClient[O any] struct { + inner *state.Client +} + +func (t outputClient[O]) Request(input any, options ...options.RequestOption) (output O, err error) { + err = t.inner.RequestFuture(input, options...).Response(&output) + return +} + +func (t outputClient[O]) RequestFuture(input any, options ...options.RequestOption) ResponseFuture[O] { + return responseFuture[O]{t.inner.RequestFuture(input, options...)} } -// NewTypedClient is primarily intended to be called from generated code, to provide +func (t outputClient[O]) Send(input any, options ...options.SendOption) { + t.inner.Send(input, options...) +} + +type client[I any, O any] struct { + inner Client[any, O] +} + +// WithRequestType is primarily intended to be called from generated code, to provide // type safety of input types. In other contexts it's generally less cumbersome to use [CallAs], // as the output type can be inferred. -func NewTypedClient[I any, O any](client interfaces.Client) TypedClient[I, O] { - return typedClient[I, O]{client} +func WithRequestType[I any, O any](inner Client[any, O]) Client[I, O] { + return client[I, O]{inner} } -func (t typedClient[I, O]) Request(input I, options ...options.RequestOption) (output O, err error) { - err = t.inner.RequestFuture(input, options...).Response(&output) +func (t client[I, O]) Request(input I, options ...options.RequestOption) (output O, err error) { + output, err = t.inner.RequestFuture(input, options...).Response() return } -func (t typedClient[I, O]) RequestFuture(input I, options ...options.RequestOption) TypedResponseFuture[O] { - return typedResponseFuture[O]{t.inner.RequestFuture(input, options...)} +func (t client[I, O]) RequestFuture(input I, options ...options.RequestOption) ResponseFuture[O] { + return t.inner.RequestFuture(input, options...) } -func (t typedClient[I, O]) Send(input I, options ...options.SendOption) { +func (t client[I, O]) Send(input I, options ...options.SendOption) { t.inner.Send(input, options...) } -// TypedResponseFuture is an extension of [ResponseFuture] which returns typed responses instead of accepting a pointer -type TypedResponseFuture[O any] interface { +// ResponseFuture is a handle on a potentially not-yet completed outbound call. +type ResponseFuture[O any] interface { // Response blocks on the response to the call and returns it or the associated error // It is *not* safe to call this in a goroutine - use Context.Select if you // want to wait on multiple results at once. @@ -92,23 +125,23 @@ type TypedResponseFuture[O any] interface { futures.Selectable } -type typedResponseFuture[O any] struct { - interfaces.ResponseFuture +type responseFuture[O any] struct { + state.DecodingResponseFuture } -func (t typedResponseFuture[O]) Response() (output O, err error) { - err = t.ResponseFuture.Response(&output) +func (t responseFuture[O]) Response() (output O, err error) { + err = t.DecodingResponseFuture.Response(&output) return } // Awakeable returns a Restate awakeable; a 'promise' to a future // value or error, that can be resolved or rejected by other services. -func Awakeable[T any](ctx Context, options ...options.AwakeableOption) TypedAwakeable[T] { - return typedAwakeable[T]{ctx.Awakeable(options...)} +func Awakeable[T any](ctx Context, options ...options.AwakeableOption) AwakeableFuture[T] { + return awakeable[T]{ctx.inner().Awakeable(options...)} } -// TypedAwakeable is an extension of [Awakeable] which returns typed responses instead of accepting a pointer -type TypedAwakeable[T any] interface { +// AwakeableFuture is a 'promise' to a future value or error, that can be resolved or rejected by other services. +type AwakeableFuture[T any] interface { // Id returns the awakeable ID, which can be stored or sent to a another service Id() string // Result blocks on receiving the result of the awakeable, storing the value it was @@ -119,29 +152,42 @@ type TypedAwakeable[T any] interface { futures.Selectable } -type typedAwakeable[T any] struct { - interfaces.Awakeable +type awakeable[T any] struct { + state.DecodingAwakeable } -func (t typedAwakeable[T]) Result() (output T, err error) { - err = t.Awakeable.Result(&output) +func (t awakeable[T]) Result() (output T, err error) { + err = t.DecodingAwakeable.Result(&output) return } // ResolveAwakeable allows an awakeable (not necessarily from this service) to be // resolved with a particular value. func ResolveAwakeable[T any](ctx Context, id string, value T, options ...options.ResolveAwakeableOption) { - ctx.ResolveAwakeable(id, value, options...) + ctx.inner().ResolveAwakeable(id, value, options...) } // ResolveAwakeable allows an awakeable (not necessarily from this service) to be // rejected with a particular error. func RejectAwakeable[T any](ctx Context, id string, reason error) { - ctx.RejectAwakeable(id, reason) + ctx.inner().RejectAwakeable(id, reason) +} + +func Select(ctx Context, futs ...futures.Selectable) Selector { + return ctx.inner().Select(futs...) } -func Select(ctx Context, futs ...interfaces.Selectable) interfaces.Selector { - return ctx.Select(futs...) +type Selectable = futures.Selectable + +// Selector is an iterator over a list of blocking Restate operations that are running +// in the background. +type Selector interface { + // Remaining returns whether there are still operations that haven't been returned by Select(). + // There will always be exactly the same number of results as there were operations + // given to Context.Select + Remaining() bool + // Select blocks on the next completed operation or returns nil if there are none left + Select() futures.Selectable } // Run runs the function (fn), storing final results (including terminal errors) @@ -150,7 +196,7 @@ func Select(ctx Context, futs ...interfaces.Selectable) interfaces.Selector { // all non-deterministic operations (eg, generating a unique ID) *must* happen // inside Run blocks. func Run[T any](ctx Context, fn func(ctx RunContext) (T, error), options ...options.RunOption) (output T, err error) { - err = ctx.Run(func(ctx RunContext) (any, error) { + err = ctx.inner().Run(func(ctx state.RunContext) (any, error) { return fn(ctx) }, &output, options...) @@ -161,8 +207,8 @@ func Run[T any](ctx Context, fn func(ctx RunContext) (T, error), options ...opti // To check explicitly for this case use ctx.Get directly or pass a pointer eg *string as T. // If the invocation was cancelled while obtaining the state (only possible if eager state is disabled), // a cancellation error is returned. -func Get[T any](ctx KeyValueReader, key string, options ...options.GetOption) (output T, err error) { - if err := ctx.Get(key, &output, options...); !errors.Is(err, ErrKeyNotFound) { +func Get[T any](ctx ObjectSharedContext, key string, options ...options.GetOption) (output T, err error) { + if err := ctx.inner().Get(key, &output, options...); !errors.Is(err, ErrKeyNotFound) { return output, err } else { return output, nil @@ -171,27 +217,27 @@ func Get[T any](ctx KeyValueReader, key string, options ...options.GetOption) (o // If the invocation was cancelled while obtaining the state (only possible if eager state is disabled), // a cancellation error is returned. -func Keys(ctx KeyValueReader) ([]string, error) { - return ctx.Keys() +func Keys(ctx ObjectSharedContext) ([]string, error) { + return ctx.inner().Keys() } // Key retrieves the key for this virtual object invocation. This is a no-op and is // always safe to call. -func Key(ctx KeyValueReader) string { - return ctx.Key() +func Key(ctx ObjectSharedContext) string { + return ctx.inner().Key() } // Set sets a value against a key, using the provided codec (defaults to JSON) -func Set[T any](ctx KeyValueWriter, key string, value T, options ...options.SetOption) { - ctx.Set(key, value, options...) +func Set[T any](ctx ObjectContext, key string, value T, options ...options.SetOption) { + ctx.inner().Set(key, value, options...) } // Clear deletes a key -func Clear(ctx KeyValueWriter, key string) { - ctx.Clear(key) +func Clear(ctx ObjectContext, key string) { + ctx.inner().Clear(key) } // ClearAll drops all stored state associated with this Object key -func ClearAll(ctx KeyValueWriter) { - ctx.ClearAll() +func ClearAll(ctx ObjectContext) { + ctx.inner().ClearAll() } diff --git a/handler.go b/handler.go index 2d1641b..c54d180 100644 --- a/handler.go +++ b/handler.go @@ -7,6 +7,7 @@ import ( "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/internal" "github.com/restatedev/sdk-go/internal/options" + "github.com/restatedev/sdk-go/internal/state" ) // Void is a placeholder to signify 'no value' where a type is otherwise needed. It can be used in several contexts: @@ -18,26 +19,6 @@ import ( // 5. The output type for an awakeable - the result body will be ignored. A pointer is also accepted. type Void = encoding.Void -// ObjectHandler is the required set of methods for a Virtual Object handler. -type ObjectHandler interface { - Call(ctx ObjectContext, request []byte) (output []byte, err error) - Handler -} - -// ServiceHandler is the required set of methods for a Service handler. -type ServiceHandler interface { - Call(ctx Context, request []byte) (output []byte, err error) - Handler -} - -// Handler is implemented by all Restate handlers -type Handler interface { - getOptions() *options.HandlerOptions - InputPayload() *encoding.InputPayload - OutputPayload() *encoding.OutputPayload - HandlerType() *internal.ServiceHandlerType -} - // ServiceHandlerFn is the signature for a Service handler function type ServiceHandlerFn[I any, O any] func(ctx Context, input I) (O, error) @@ -52,7 +33,7 @@ type serviceHandler[I any, O any] struct { options options.HandlerOptions } -var _ ServiceHandler = (*serviceHandler[struct{}, struct{}])(nil) +var _ state.ServiceHandler = (*serviceHandler[struct{}, struct{}])(nil) // NewServiceHandler converts a function of signature [ServiceHandlerFn] into a handler on a Restate service. func NewServiceHandler[I any, O any](fn ServiceHandlerFn[I, O], opts ...options.HandlerOption) *serviceHandler[I, O] { @@ -66,14 +47,14 @@ func NewServiceHandler[I any, O any](fn ServiceHandlerFn[I, O], opts ...options. } } -func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { +func (h *serviceHandler[I, O]) Call(ctx *state.Context, bytes []byte) ([]byte, error) { var input I if err := encoding.Unmarshal(h.options.Codec, bytes, &input); err != nil { return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } output, err := h.fn( - ctx, + ctxWrapper{ctx}, input, ) if err != nil { @@ -103,7 +84,7 @@ func (h *serviceHandler[I, O]) HandlerType() *internal.ServiceHandlerType { return nil } -func (h *serviceHandler[I, O]) getOptions() *options.HandlerOptions { +func (h *serviceHandler[I, O]) GetOptions() *options.HandlerOptions { return &h.options } @@ -115,7 +96,7 @@ type objectHandler[I any, O any] struct { handlerType internal.ServiceHandlerType } -var _ ObjectHandler = (*objectHandler[struct{}, struct{}])(nil) +var _ state.ObjectHandler = (*objectHandler[struct{}, struct{}])(nil) // NewObjectHandler converts a function of signature [ObjectHandlerFn] into an exclusive-mode handler on a Virtual Object. // The handler will have access to a full [ObjectContext] which may mutate state. @@ -145,7 +126,15 @@ func NewObjectSharedHandler[I any, O any](fn ObjectSharedHandlerFn[I, O], opts . } } -func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { +type ctxWrapper struct { + *state.Context +} + +func (o ctxWrapper) inner() *state.Context { + return o.Context +} + +func (h *objectHandler[I, O]) Call(ctx *state.Context, bytes []byte) ([]byte, error) { var input I if err := encoding.Unmarshal(h.options.Codec, bytes, &input); err != nil { return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) @@ -156,12 +145,12 @@ func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, err switch h.handlerType { case internal.ServiceHandlerType_EXCLUSIVE: output, err = h.exclusiveFn( - ctx, + ctxWrapper{ctx}, input, ) case internal.ServiceHandlerType_SHARED: output, err = h.sharedFn( - ctx, + ctxWrapper{ctx}, input, ) } @@ -188,7 +177,7 @@ func (h *objectHandler[I, O]) OutputPayload() *encoding.OutputPayload { return encoding.OutputPayloadFor(h.options.Codec, o) } -func (h *objectHandler[I, O]) getOptions() *options.HandlerOptions { +func (h *objectHandler[I, O]) GetOptions() *options.HandlerOptions { return &h.options } diff --git a/interfaces/interfaces.go b/interfaces/interfaces.go deleted file mode 100644 index 6c33892..0000000 --- a/interfaces/interfaces.go +++ /dev/null @@ -1,67 +0,0 @@ -package interfaces - -import ( - "github.com/restatedev/sdk-go/internal/futures" - "github.com/restatedev/sdk-go/internal/options" -) - -type Selectable = futures.Selectable - -// After is a handle on a Sleep operation which allows you to do other work concurrently -// with the sleep. -type After interface { - // Done blocks waiting on the remaining duration of the sleep. - // It is *not* safe to call this in a goroutine - use Context.Select if you want to wait on multiple - // results at once. Can return a terminal error in the case where the invocation was cancelled mid-sleep, - // hence Done() should always be called, even after using Context.Select. - Done() error - Selectable -} - -// Awakeable is the Go representation of a Restate awakeable; a 'promise' to a future -// value or error, that can be resolved or rejected by other services. -type Awakeable interface { - // Id returns the awakeable ID, which can be stored or sent to a another service - Id() string - // Result blocks on receiving the result of the awakeable, storing the value it was - // resolved with in output or otherwise returning the error it was rejected with. - // It is *not* safe to call this in a goroutine - use Context.Select if you - // want to wait on multiple results at once. - // Note: use the AwakeableAs helper function to avoid having to pass a output pointer - Result(output any) error - Selectable -} - -// Client represents all the different ways you can invoke a particular service/key/method tuple. -type Client interface { - // RequestFuture makes a call and returns a handle on a future response - RequestFuture(input any, options ...options.RequestOption) ResponseFuture - // Request makes a call and blocks on getting the response which is stored in output - Request(input any, output any, options ...options.RequestOption) error - SendClient -} - -type SendClient interface { - // Send makes a one-way call which is executed in the background - Send(input any, options ...options.SendOption) -} - -// ResponseFuture is a handle on a potentially not-yet completed outbound call. -type ResponseFuture interface { - // Response blocks on the response to the call and stores it in output, or returns the associated error - // It is *not* safe to call this in a goroutine - use Context.Select if you - // want to wait on multiple results at once. - Response(output any) error - Selectable -} - -// Selector is an iterator over a list of blocking Restate operations that are running -// in the background. -type Selector interface { - // Remaining returns whether there are still operations that haven't been returned by Select(). - // There will always be exactly the same number of results as there were operations - // given to Context.Select - Remaining() bool - // Select blocks on the next completed operation or returns nil if there are none left - Select() Selectable -} diff --git a/internal/errors/error.go b/internal/errors/error.go index 5f62648..c5d16c8 100644 --- a/internal/errors/error.go +++ b/internal/errors/error.go @@ -1,6 +1,7 @@ package errors import ( + "errors" "fmt" protocol "github.com/restatedev/sdk-go/generated/dev/restate/service" @@ -30,6 +31,15 @@ func (e *CodeError) Unwrap() error { return e.Inner } +func ErrorCode(err error) Code { + var e *CodeError + if errors.As(err, &e) { + return e.Code + } + + return 500 +} + type TerminalError struct { Inner error } @@ -42,6 +52,14 @@ func (e *TerminalError) Unwrap() error { return e.Inner } +func IsTerminalError(err error) bool { + if err == nil { + return false + } + var t *TerminalError + return errors.As(err, &t) +} + func ErrorFromFailure(failure *protocol.Failure) error { return &CodeError{Inner: &TerminalError{Inner: fmt.Errorf(failure.Message)}, Code: Code(failure.Code)} } diff --git a/internal/state/awakeable.go b/internal/state/awakeable.go index dfc4e96..a56ca46 100644 --- a/internal/state/awakeable.go +++ b/internal/state/awakeable.go @@ -3,8 +3,9 @@ package state import ( "bytes" - restate "github.com/restatedev/sdk-go" + "github.com/restatedev/sdk-go/encoding" protocol "github.com/restatedev/sdk-go/generated/dev/restate/service" + "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" "github.com/restatedev/sdk-go/internal/wire" ) @@ -30,7 +31,7 @@ func (c *Machine) _awakeable() *wire.AwakeableEntryMessage { func (m *Machine) resolveAwakeable(id string, value []byte) { _, _ = replayOrNew( m, - func(entry *wire.CompleteAwakeableEntryMessage) restate.Void { + func(entry *wire.CompleteAwakeableEntryMessage) encoding.Void { messageValue, ok := entry.Result.(*protocol.CompleteAwakeableEntryMessage_Value) if entry.Id != id || !ok || !bytes.Equal(messageValue.Value, value) { panic(m.newEntryMismatch(&wire.CompleteAwakeableEntryMessage{ @@ -40,11 +41,11 @@ func (m *Machine) resolveAwakeable(id string, value []byte) { }, }, entry)) } - return restate.Void{} + return encoding.Void{} }, - func() restate.Void { + func() encoding.Void { m._resolveAwakeable(id, value) - return restate.Void{} + return encoding.Void{} }, ) } @@ -61,24 +62,24 @@ func (c *Machine) _resolveAwakeable(id string, value []byte) { func (m *Machine) rejectAwakeable(id string, reason error) { _, _ = replayOrNew( m, - func(entry *wire.CompleteAwakeableEntryMessage) restate.Void { + func(entry *wire.CompleteAwakeableEntryMessage) encoding.Void { messageFailure, ok := entry.Result.(*protocol.CompleteAwakeableEntryMessage_Failure) - if entry.Id != id || !ok || messageFailure.Failure.Code != uint32(restate.ErrorCode(reason)) || messageFailure.Failure.Message != reason.Error() { + if entry.Id != id || !ok || messageFailure.Failure.Code != uint32(errors.ErrorCode(reason)) || messageFailure.Failure.Message != reason.Error() { panic(m.newEntryMismatch(&wire.CompleteAwakeableEntryMessage{ CompleteAwakeableEntryMessage: protocol.CompleteAwakeableEntryMessage{ Id: id, Result: &protocol.CompleteAwakeableEntryMessage_Failure{Failure: &protocol.Failure{ - Code: uint32(restate.ErrorCode(reason)), + Code: uint32(errors.ErrorCode(reason)), Message: reason.Error(), }}, }, }, entry)) } - return restate.Void{} + return encoding.Void{} }, - func() restate.Void { + func() encoding.Void { m._rejectAwakeable(id, reason) - return restate.Void{} + return encoding.Void{} }, ) } @@ -88,7 +89,7 @@ func (c *Machine) _rejectAwakeable(id string, reason error) { CompleteAwakeableEntryMessage: protocol.CompleteAwakeableEntryMessage{ Id: id, Result: &protocol.CompleteAwakeableEntryMessage_Failure{Failure: &protocol.Failure{ - Code: uint32(restate.ErrorCode(reason)), + Code: uint32(errors.ErrorCode(reason)), Message: reason.Error(), }}, }, diff --git a/internal/state/call.go b/internal/state/call.go index c7e8bbb..baf9637 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -7,16 +7,14 @@ import ( "slices" "time" - restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/encoding" protocol "github.com/restatedev/sdk-go/generated/dev/restate/service" - "github.com/restatedev/sdk-go/interfaces" "github.com/restatedev/sdk-go/internal/futures" "github.com/restatedev/sdk-go/internal/options" "github.com/restatedev/sdk-go/internal/wire" ) -type serviceCall struct { +type Client struct { options options.ClientOptions machine *Machine service string @@ -25,7 +23,7 @@ type serviceCall struct { } // RequestFuture makes a call and returns a handle on the response -func (c *serviceCall) RequestFuture(input any, opts ...options.RequestOption) interfaces.ResponseFuture { +func (c *Client) RequestFuture(input any, opts ...options.RequestOption) DecodingResponseFuture { o := options.RequestOptions{} for _, opt := range opts { opt.BeforeRequest(&o) @@ -38,20 +36,20 @@ func (c *serviceCall) RequestFuture(input any, opts ...options.RequestOption) in entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, o.Headers, bytes) - return decodingResponseFuture{ + return DecodingResponseFuture{ futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex, func(err error) any { return c.machine.newProtocolViolation(entry, err) }), c.machine, c.options, } } -type decodingResponseFuture struct { +type DecodingResponseFuture struct { *futures.ResponseFuture machine *Machine options options.ClientOptions } -func (d decodingResponseFuture) Response(output any) (err error) { +func (d DecodingResponseFuture) Response(output any) (err error) { bytes, err := d.ResponseFuture.Response() if err != nil { return err @@ -65,12 +63,12 @@ func (d decodingResponseFuture) Response(output any) (err error) { } // Request makes a call and blocks on the response -func (c *serviceCall) Request(input any, output any, opts ...options.RequestOption) error { +func (c *Client) Request(input any, output any, opts ...options.RequestOption) error { return c.RequestFuture(input, opts...).Response(output) } // Send runs a call in the background after delay duration -func (c *serviceCall) Send(input any, opts ...options.SendOption) { +func (c *Client) Send(input any, opts ...options.SendOption) { o := options.SendOptions{} for _, opt := range opts { opt.BeforeSend(&o) @@ -162,7 +160,7 @@ func (m *Machine) sendCall(service, key, method string, headersMap map[string]st _, _ = replayOrNew( m, - func(entry *wire.OneWayCallEntryMessage) restate.Void { + func(entry *wire.OneWayCallEntryMessage) encoding.Void { if entry.ServiceName != service || entry.Key != key || entry.HandlerName != method || @@ -179,11 +177,11 @@ func (m *Machine) sendCall(service, key, method string, headersMap map[string]st }, entry)) } - return restate.Void{} + return encoding.Void{} }, - func() restate.Void { + func() encoding.Void { m._sendCall(service, key, method, headers, body, delay) - return restate.Void{} + return encoding.Void{} }, ) } diff --git a/internal/state/completion.go b/internal/state/completion.go index b8bdd65..db8eae9 100644 --- a/internal/state/completion.go +++ b/internal/state/completion.go @@ -48,7 +48,7 @@ func (m *Machine) Write(message wire.Message) { } typ := wire.MessageType(message) m.log.LogAttrs(m.ctx, log.LevelTrace, "Sending message to runtime", log.Stringer("type", typ)) - if err := m.protocol.Write(typ, message); err != nil { + if err := m.Protocol.Write(typ, message); err != nil { panic(m.newWriteError(message, err)) } } @@ -67,7 +67,7 @@ func (m *Machine) newWriteError(entry wire.Message, err error) *writeError { func (m *Machine) handleCompletionsAcks() { for { - msg, _, err := m.protocol.Read() + msg, _, err := m.Protocol.Read() if err != nil { if m.ctx.Err() == nil { m.log.LogAttrs(m.ctx, log.LevelTrace, "Request body closed; next blocking operation will suspend") diff --git a/internal/state/interfaces.go b/internal/state/interfaces.go new file mode 100644 index 0000000..967bb89 --- /dev/null +++ b/internal/state/interfaces.go @@ -0,0 +1,27 @@ +package state + +import ( + "github.com/restatedev/sdk-go/encoding" + "github.com/restatedev/sdk-go/internal" + "github.com/restatedev/sdk-go/internal/options" +) + +// ObjectHandler is the required set of methods for a Virtual Object handler. +type ObjectHandler interface { + Call(ctx *Context, request []byte) (output []byte, err error) + Handler +} + +// ServiceHandler is the required set of methods for a Service handler. +type ServiceHandler interface { + Call(ctx *Context, request []byte) (output []byte, err error) + Handler +} + +// Handler is implemented by all Restate handlers +type Handler interface { + GetOptions() *options.HandlerOptions + InputPayload() *encoding.InputPayload + OutputPayload() *encoding.OutputPayload + HandlerType() *internal.ServiceHandlerType +} diff --git a/internal/state/state.go b/internal/state/state.go index 003d5c3..cf11798 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -12,10 +12,8 @@ import ( "sync/atomic" "time" - restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/encoding" protocol "github.com/restatedev/sdk-go/generated/dev/restate/service" - "github.com/restatedev/sdk-go/interfaces" "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" "github.com/restatedev/sdk-go/internal/log" @@ -39,16 +37,11 @@ type Context struct { machine *Machine } -var _ restate.ObjectContext = &Context{} -var _ restate.ObjectSharedContext = &Context{} -var _ restate.Context = &Context{} -var _ restate.RunContext = &Context{} - func (c *Context) Log() *slog.Logger { return c.machine.userLog } -func (c *Context) Request() *restate.Request { +func (c *Context) Request() *Request { return &c.machine.request } @@ -114,11 +107,11 @@ func (c *Context) Sleep(d time.Duration) error { return c.machine.sleep(d) } -func (c *Context) After(d time.Duration) interfaces.After { +func (c *Context) After(d time.Duration) *futures.After { return c.machine.after(d) } -func (c *Context) Service(service, method string, opts ...options.ClientOption) interfaces.Client { +func (c *Context) Service(service, method string, opts ...options.ClientOption) *Client { o := options.ClientOptions{} for _, opt := range opts { opt.BeforeClient(&o) @@ -127,7 +120,7 @@ func (c *Context) Service(service, method string, opts ...options.ClientOption) o.Codec = encoding.JSONCodec } - return &serviceCall{ + return &Client{ options: o, machine: c.machine, service: service, @@ -135,7 +128,7 @@ func (c *Context) Service(service, method string, opts ...options.ClientOption) } } -func (c *Context) Object(service, key, method string, opts ...options.ClientOption) interfaces.Client { +func (c *Context) Object(service, key, method string, opts ...options.ClientOption) *Client { o := options.ClientOptions{} for _, opt := range opts { opt.BeforeClient(&o) @@ -144,7 +137,7 @@ func (c *Context) Object(service, key, method string, opts ...options.ClientOpti o.Codec = encoding.JSONCodec } - return &serviceCall{ + return &Client{ options: o, machine: c.machine, service: service, @@ -153,7 +146,7 @@ func (c *Context) Object(service, key, method string, opts ...options.ClientOpti } } -func (c *Context) Run(fn func(ctx restate.RunContext) (any, error), output any, opts ...options.RunOption) error { +func (c *Context) Run(fn func(ctx RunContext) (any, error), output any, opts ...options.RunOption) error { o := options.RunOptions{} for _, opt := range opts { opt.BeforeRun(&o) @@ -162,7 +155,7 @@ func (c *Context) Run(fn func(ctx restate.RunContext) (any, error), output any, o.Codec = encoding.JSONCodec } - bytes, err := c.machine.run(func(ctx restate.RunContext) ([]byte, error) { + bytes, err := c.machine.run(func(ctx RunContext) ([]byte, error) { output, err := fn(ctx) if err != nil { return nil, err @@ -194,7 +187,7 @@ type AwakeableOption interface { beforeAwakeable(*awakeableOptions) } -func (c *Context) Awakeable(opts ...options.AwakeableOption) interfaces.Awakeable { +func (c *Context) Awakeable(opts ...options.AwakeableOption) DecodingAwakeable { o := options.AwakeableOptions{} for _, opt := range opts { opt.BeforeAwakeable(&o) @@ -202,17 +195,17 @@ func (c *Context) Awakeable(opts ...options.AwakeableOption) interfaces.Awakeabl if o.Codec == nil { o.Codec = encoding.JSONCodec } - return decodingAwakeable{c.machine.awakeable(), c.machine, o.Codec} + return DecodingAwakeable{c.machine.awakeable(), c.machine, o.Codec} } -type decodingAwakeable struct { +type DecodingAwakeable struct { *futures.Awakeable machine *Machine codec encoding.Codec } -func (d decodingAwakeable) Id() string { return d.Awakeable.Id() } -func (d decodingAwakeable) Result(output any) (err error) { +func (d DecodingAwakeable) Id() string { return d.Awakeable.Id() } +func (d DecodingAwakeable) Result(output any) (err error) { bytes, err := d.Awakeable.Result() if err != nil { return err @@ -242,7 +235,7 @@ func (c *Context) RejectAwakeable(id string, reason error) { c.machine.rejectAwakeable(id, reason) } -func (c *Context) Select(futs ...futures.Selectable) interfaces.Selector { +func (c *Context) Select(futs ...futures.Selectable) *selector { return c.machine.selector(futs...) } @@ -267,12 +260,12 @@ type Machine struct { suspensionCtx context.Context suspend func(error) - handler restate.Handler - protocol wire.Protocol + handler Handler + Protocol wire.Protocol // state key string - request restate.Request + request Request partial bool current map[string][]byte @@ -294,17 +287,17 @@ type Machine struct { failure any } -func NewMachine(handler restate.Handler, conn io.ReadWriter, attemptHeaders map[string][]string) *Machine { +func NewMachine(handler Handler, conn io.ReadWriter, attemptHeaders map[string][]string) *Machine { m := &Machine{ handler: handler, current: make(map[string][]byte), pendingAcks: map[uint32]wire.AckableMessage{}, pendingCompletions: map[uint32]wire.CompleteableMessage{}, - request: restate.Request{ + request: Request{ AttemptHeaders: attemptHeaders, }, } - m.protocol = wire.NewProtocol(conn) + m.Protocol = wire.NewProtocol(conn) return m } @@ -312,7 +305,7 @@ func (m *Machine) Log() *slog.Logger { return m.log } // Start starts the state machine func (m *Machine) Start(inner context.Context, dropReplayLogs bool, logHandler slog.Handler) error { - msg, _, err := m.protocol.Read() + msg, _, err := m.Protocol.Read() if err != nil { return err } @@ -357,7 +350,7 @@ func (m *Machine) invoke(ctx *Context, outputSeen bool) error { case *protocolViolation: m.log.LogAttrs(m.ctx, slog.LevelError, "Protocol violation", log.Error(typ.err)) - if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + if err := m.Protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ Code: uint32(errors.ErrProtocolViolation), Message: fmt.Sprintf("Protocol violation: %v", typ.err), @@ -370,7 +363,7 @@ func (m *Machine) invoke(ctx *Context, outputSeen bool) error { case *concurrentContextUse: m.log.LogAttrs(m.ctx, slog.LevelError, "Concurrent context use detected; either a Context method was used while a Run() is in progress, or Context methods are being called from multiple goroutines. Failing invocation.", slog.Uint64("entryType", uint64(typ.entryType))) - if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + if err := m.Protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ Code: uint32(errors.ErrProtocolViolation), Message: "Concurrent context use detected; either a Context method was used while a Run() is in progress, or Context methods are being called from multiple goroutines.", @@ -390,7 +383,7 @@ func (m *Machine) invoke(ctx *Context, outputSeen bool) error { slog.String("actualMessage", string(actual))) // journal entry mismatch - if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + if err := m.Protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ Code: uint32(errors.ErrJournalMismatch), Message: fmt.Sprintf(`Journal mismatch: Replayed journal entries did not correspond to the user code. The user code has to be deterministic! @@ -409,7 +402,7 @@ The journal entry at position %d was: case *writeError: m.log.LogAttrs(m.ctx, slog.LevelError, "Failed to write entry to Restate, shutting down state machine", log.Error(typ.err)) // don't even check for failure here because most likely the http2 conn is closed anyhow - _ = m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + _ = m.Protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ Code: uint32(errors.ErrProtocolViolation), Message: typ.err.Error(), @@ -422,9 +415,9 @@ The journal entry at position %d was: case *runFailure: m.log.LogAttrs(m.ctx, slog.LevelError, "Run returned a failure, returning error to Restate", log.Error(typ.err)) - if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + if err := m.Protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ - Code: uint32(restate.ErrorCode(typ.err)), + Code: uint32(errors.ErrorCode(typ.err)), Message: typ.err.Error(), RelatedEntryIndex: &typ.entryIndex, RelatedEntryType: wire.AwakeableEntryMessageType.UInt32(), @@ -437,9 +430,9 @@ The journal entry at position %d was: case *codecFailure: m.log.LogAttrs(m.ctx, slog.LevelError, "Encoding failed, returning error to Restate", log.Error(typ.err)) - if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + if err := m.Protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ - Code: uint32(restate.ErrorCode(typ.err)), + Code: uint32(errors.ErrorCode(typ.err)), Message: typ.err.Error(), RelatedEntryIndex: &typ.entryIndex, RelatedEntryType: wire.AwakeableEntryMessageType.UInt32(), @@ -449,7 +442,7 @@ The journal entry at position %d was: } return - case *clientGoneAway: + case *ClientGoneAway: m.log.LogAttrs(m.ctx, slog.LevelWarn, "Cancelling invocation as the incoming request context was cancelled", log.Error(typ.err)) return case *wire.SuspensionPanic: @@ -464,7 +457,7 @@ The journal entry at position %d was: if stderrors.Is(typ.Err, io.EOF) { m.log.LogAttrs(m.ctx, slog.LevelInfo, "Suspending invocation", slog.Any("entryIndexes", typ.EntryIndexes)) - if err := m.protocol.Write(wire.SuspensionMessageType, &wire.SuspensionMessage{ + if err := m.Protocol.Write(wire.SuspensionMessageType, &wire.SuspensionMessage{ SuspensionMessage: protocol.SuspensionMessage{ EntryIndexes: typ.EntryIndexes, }, @@ -475,9 +468,9 @@ The journal entry at position %d was: m.log.LogAttrs(m.ctx, slog.LevelError, "Unexpected error reading completions; shutting down state machine", log.Error(typ.Err), slog.Any("entryIndexes", typ.EntryIndexes)) // don't check for error here, most likely we will fail to send if we are in such a bad state - _ = m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + _ = m.Protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ - Code: uint32(restate.ErrorCode(typ.Err)), + Code: uint32(errors.ErrorCode(typ.Err)), Message: fmt.Sprintf("problem reading completions: %v", typ.Err), }, }) @@ -489,7 +482,7 @@ The journal entry at position %d was: // unknown panic! // send an error message (retryable) - if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + if err := m.Protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ Code: 500, Message: fmt.Sprint(typ), @@ -508,27 +501,27 @@ The journal entry at position %d was: if outputSeen { m.log.WarnContext(m.ctx, "Invocation already completed; ending immediately") - return m.protocol.Write(wire.EndMessageType, &wire.EndMessage{}) + return m.Protocol.Write(wire.EndMessageType, &wire.EndMessage{}) } var bytes []byte var err error switch handler := m.handler.(type) { - case restate.ObjectHandler: + case ObjectHandler: bytes, err = handler.Call(ctx, m.request.Body) - case restate.ServiceHandler: + case ServiceHandler: bytes, err = handler.Call(ctx, m.request.Body) } - if err != nil && restate.IsTerminalError(err) { + if err != nil && errors.IsTerminalError(err) { m.log.LogAttrs(m.ctx, slog.LevelError, "Invocation returned a terminal failure", log.Error(err)) // terminal errors. - if err := m.protocol.Write(wire.OutputEntryMessageType, &wire.OutputEntryMessage{ + if err := m.Protocol.Write(wire.OutputEntryMessageType, &wire.OutputEntryMessage{ OutputEntryMessage: protocol.OutputEntryMessage{ Result: &protocol.OutputEntryMessage_Failure{ Failure: &protocol.Failure{ - Code: uint32(restate.ErrorCode(err)), + Code: uint32(errors.ErrorCode(err)), Message: err.Error(), }, }, @@ -536,21 +529,21 @@ The journal entry at position %d was: }); err != nil { return err } - return m.protocol.Write(wire.EndMessageType, &wire.EndMessage{}) + return m.Protocol.Write(wire.EndMessageType, &wire.EndMessage{}) } else if err != nil { m.log.LogAttrs(m.ctx, slog.LevelError, "Invocation returned a non-terminal failure", log.Error(err)) // non terminal error - no end message - return m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + return m.Protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ - Code: uint32(restate.ErrorCode(err)), + Code: uint32(errors.ErrorCode(err)), Message: err.Error(), }, }) } else { m.log.InfoContext(m.ctx, "Invocation completed successfully") - if err := m.protocol.Write(wire.OutputEntryMessageType, &wire.OutputEntryMessage{ + if err := m.Protocol.Write(wire.OutputEntryMessageType, &wire.OutputEntryMessage{ OutputEntryMessage: protocol.OutputEntryMessage{ Result: &protocol.OutputEntryMessage_Value{ Value: bytes, @@ -560,7 +553,7 @@ The journal entry at position %d was: return err } - return m.protocol.Write(wire.EndMessageType, &wire.EndMessage{}) + return m.Protocol.Write(wire.EndMessageType, &wire.EndMessage{}) } } @@ -571,7 +564,7 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { m.partial = start.PartialState // expect input message - msg, _, err := m.protocol.Read() + msg, _, err := m.Protocol.Read() if err != nil { return err } @@ -595,7 +588,7 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { // we don't track the poll input entry for i := uint32(1); i < start.KnownEntries; i++ { - msg, typ, err := m.protocol.Read() + msg, typ, err := m.Protocol.Read() if err != nil { return fmt.Errorf("failed to read entry: %w", err) } @@ -706,12 +699,12 @@ func (m *Machine) newConcurrentContextUse(entry wire.Type) *concurrentContextUse return c } -type clientGoneAway struct { +type ClientGoneAway struct { err error } -func (m *Machine) newClientGoneAway(err error) *clientGoneAway { - c := &clientGoneAway{err} +func (m *Machine) newClientGoneAway(err error) *ClientGoneAway { + c := &ClientGoneAway{err} m.failure = c return c } diff --git a/internal/state/state_test.go b/internal/state/state_test.go index 1a0d4d3..614fc30 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -1,4 +1,4 @@ -package state +package state_test import ( "context" @@ -10,8 +10,8 @@ import ( restate "github.com/restatedev/sdk-go" protocol "github.com/restatedev/sdk-go/generated/dev/restate/service" - "github.com/restatedev/sdk-go/interfaces" "github.com/restatedev/sdk-go/internal/errors" + "github.com/restatedev/sdk-go/internal/state" "github.com/restatedev/sdk-go/internal/wire" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" @@ -26,12 +26,12 @@ type testParams struct { var clientDisconnectError = fmt.Errorf("client disconnected") -func testHandler(handler restate.Handler) testParams { - machine := NewMachine(handler, nil, nil) +func testHandler(handler state.Handler) testParams { + machine := state.NewMachine(handler, nil, nil) inputC := make(chan wire.Message) outputC := make(chan wire.Message) ctx, cancel := context.WithCancelCause(context.Background()) - machine.protocol = mockProtocol{input: inputC, output: outputC} + machine.Protocol = mockProtocol{input: inputC, output: outputC} eg := errgroup.Group{} eg.Go(func() error { @@ -103,7 +103,7 @@ func TestResponseClosed(t *testing.T) { afterCancel: func(ctx restate.Context, _ any) { restate.Awakeable[restate.Void](ctx) }, - expectedPanic: &clientGoneAway{}, + expectedPanic: &state.ClientGoneAway{}, }, { name: "starting run should lead to client gone away panic", @@ -112,7 +112,7 @@ func TestResponseClosed(t *testing.T) { panic("run should not be executed") }) }, - expectedPanic: &clientGoneAway{}, + expectedPanic: &state.ClientGoneAway{}, }, { name: "awaiting sleep should lead to suspension panic", @@ -120,7 +120,7 @@ func TestResponseClosed(t *testing.T) { return restate.After(ctx, time.Minute) }, afterCancel: func(ctx restate.Context, setupState any) { - setupState.(interfaces.After).Done() + setupState.(restate.AfterFuture).Done() }, producedEntries: 1, expectedPanic: &wire.SuspensionPanic{}, @@ -212,7 +212,7 @@ func TestInFlightRunDisconnect(t *testing.T) { require.NoError(t, tp.wait()) require.Nil(t, beforeCancelErr, "run context should not be cancelled early") require.Equal(t, context.Canceled, afterCancelErr, "run context should be cancelled") - require.IsType(t, &clientGoneAway{}, seenPanic, "after the run should lead to a client gone away panic") + require.IsType(t, &state.ClientGoneAway{}, seenPanic, "after the run should lead to a client gone away panic") } // suspension mid-run should commit the run result to the runtime, but then panic with suspension when diff --git a/internal/state/sys.go b/internal/state/sys.go index 5ac2e1d..aaa81f2 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -8,7 +8,7 @@ import ( "sort" "time" - restate "github.com/restatedev/sdk-go" + "github.com/restatedev/sdk-go/encoding" protocol "github.com/restatedev/sdk-go/generated/dev/restate/service" "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" @@ -44,7 +44,7 @@ func (m *Machine) newProtocolViolation(entry wire.Message, err error) *protocolV func (m *Machine) set(key string, value []byte) { _, _ = replayOrNew( m, - func(entry *wire.SetStateEntryMessage) (void restate.Void) { + func(entry *wire.SetStateEntryMessage) (void encoding.Void) { if string(entry.Key) != key || !bytes.Equal(entry.Value, value) { panic(m.newEntryMismatch(&wire.SetStateEntryMessage{ SetStateEntryMessage: protocol.SetStateEntryMessage{ @@ -54,7 +54,7 @@ func (m *Machine) set(key string, value []byte) { }, entry)) } return - }, func() (void restate.Void) { + }, func() (void encoding.Void) { m._set(key, value) return void }) @@ -75,7 +75,7 @@ func (m *Machine) _set(key string, value []byte) { func (m *Machine) clear(key string) { _, _ = replayOrNew( m, - func(entry *wire.ClearStateEntryMessage) (void restate.Void) { + func(entry *wire.ClearStateEntryMessage) (void encoding.Void) { if string(entry.Key) != key { panic(m.newEntryMismatch(&wire.ClearStateEntryMessage{ ClearStateEntryMessage: protocol.ClearStateEntryMessage{ @@ -85,9 +85,9 @@ func (m *Machine) clear(key string) { } return - }, func() restate.Void { + }, func() encoding.Void { m._clear(key) - return restate.Void{} + return encoding.Void{} }, ) @@ -107,11 +107,11 @@ func (m *Machine) _clear(key string) { func (m *Machine) clearAll() { _, _ = replayOrNew( m, - func(entry *wire.ClearAllStateEntryMessage) (void restate.Void) { + func(entry *wire.ClearAllStateEntryMessage) (void encoding.Void) { return - }, func() restate.Void { + }, func() encoding.Void { m._clearAll() - return restate.Void{} + return encoding.Void{} }, ) m.current = map[string][]byte{} @@ -145,7 +145,7 @@ func (m *Machine) get(key string) ([]byte, error) { switch value := entry.Result.(type) { case *protocol.GetStateEntryMessage_Empty: - return nil, restate.ErrKeyNotFound + return nil, errors.ErrKeyNotFound case *protocol.GetStateEntryMessage_Value: m.current[key] = value.Value return value.Value, nil @@ -280,7 +280,7 @@ func (m *Machine) _sleep(d time.Duration) *wire.SleepEntryMessage { return msg } -func (m *Machine) run(fn func(restate.RunContext) ([]byte, error)) ([]byte, error) { +func (m *Machine) run(fn func(RunContext) ([]byte, error)) ([]byte, error) { entry, entryIndex := replayOrNew( m, func(entry *wire.RunEntryMessage) *wire.RunEntryMessage { @@ -307,25 +307,40 @@ func (m *Machine) run(fn func(restate.RunContext) ([]byte, error)) ([]byte, erro } } -type runContext struct { +type RunContext struct { context.Context log *slog.Logger - request *restate.Request + request *Request } -func (r runContext) Log() *slog.Logger { return r.log } -func (r runContext) Request() *restate.Request { return r.request } +func (r RunContext) Log() *slog.Logger { return r.log } +func (r RunContext) Request() *Request { return r.request } + +type Request struct { + // The unique id that identifies the current function invocation. This id is guaranteed to be + // unique across invocations, but constant across reties and suspensions. + ID []byte + // Request headers - the following headers capture the original invocation headers, as provided to + // the ingress. + Headers map[string]string + // Attempt headers - the following headers are sent by the restate runtime. + // These headers are attempt specific, generated by the restate runtime uniquely for each attempt. + // These headers might contain information such as the W3C trace context, and attempt specific information. + AttemptHeaders map[string][]string + // Raw unparsed request body + Body []byte +} -func (m *Machine) _run(fn func(restate.RunContext) ([]byte, error)) *wire.RunEntryMessage { - bytes, err := fn(runContext{m.ctx, m.userLog, &m.request}) +func (m *Machine) _run(fn func(RunContext) ([]byte, error)) *wire.RunEntryMessage { + bytes, err := fn(RunContext{m.ctx, m.userLog, &m.request}) if err != nil { - if restate.IsTerminalError(err) { + if errors.IsTerminalError(err) { msg := &wire.RunEntryMessage{ RunEntryMessage: protocol.RunEntryMessage{ Result: &protocol.RunEntryMessage_Failure{ Failure: &protocol.Failure{ - Code: uint32(restate.ErrorCode(err)), + Code: uint32(errors.ErrorCode(err)), Message: err.Error(), }, }, diff --git a/reflect.go b/reflect.go index fbe07b3..e2e04dd 100644 --- a/reflect.go +++ b/reflect.go @@ -8,6 +8,7 @@ import ( "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/internal" "github.com/restatedev/sdk-go/internal/options" + "github.com/restatedev/sdk-go/internal/state" ) type serviceNamer interface { @@ -141,7 +142,7 @@ type reflectHandler struct { handlerType *internal.ServiceHandlerType } -func (h *reflectHandler) getOptions() *options.HandlerOptions { +func (h *reflectHandler) GetOptions() *options.HandlerOptions { return &h.options } @@ -161,9 +162,9 @@ type objectReflectHandler struct { reflectHandler } -var _ ObjectHandler = (*objectReflectHandler)(nil) +var _ state.ObjectHandler = (*objectReflectHandler)(nil) -func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { +func (h *objectReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) { input := reflect.New(h.input) if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil { @@ -173,7 +174,7 @@ func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, er // we are sure about the fn signature so it's safe to do this output := h.fn.Call([]reflect.Value{ h.receiver, - reflect.ValueOf(ctx), + reflect.ValueOf(ctxWrapper{ctx}), input.Elem(), }) @@ -196,9 +197,9 @@ type serviceReflectHandler struct { reflectHandler } -var _ ServiceHandler = (*serviceReflectHandler)(nil) +var _ state.ServiceHandler = (*serviceReflectHandler)(nil) -func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) { +func (h *serviceReflectHandler) Call(ctx *state.Context, bytes []byte) ([]byte, error) { input := reflect.New(h.input) if err := encoding.Unmarshal(h.options.Codec, bytes, input.Interface()); err != nil { @@ -208,7 +209,7 @@ func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) // we are sure about the fn signature so it's safe to do this output := h.fn.Call([]reflect.Value{ h.receiver, - reflect.ValueOf(ctx), + reflect.ValueOf(ctxWrapper{ctx}), input.Elem(), }) diff --git a/router.go b/router.go index da148e9..db36a87 100644 --- a/router.go +++ b/router.go @@ -4,6 +4,7 @@ import ( "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/internal" "github.com/restatedev/sdk-go/internal/options" + "github.com/restatedev/sdk-go/internal/state" ) // ServiceDefinition is the set of methods implemented by both services and virtual objects @@ -11,13 +12,13 @@ type ServiceDefinition interface { Name() string Type() internal.ServiceType // Set of handlers associated with this service definition - Handlers() map[string]Handler + Handlers() map[string]state.Handler } // serviceDefinition stores a list of handlers under a named service type serviceDefinition struct { name string - handlers map[string]Handler + handlers map[string]state.Handler options options.ServiceDefinitionOptions typ internal.ServiceType } @@ -30,7 +31,7 @@ func (r *serviceDefinition) Name() string { } // Handlers returns the list of handlers in this service definition -func (r *serviceDefinition) Handlers() map[string]Handler { +func (r *serviceDefinition) Handlers() map[string]state.Handler { return r.handlers } @@ -55,7 +56,7 @@ func NewService(name string, opts ...options.ServiceDefinitionOption) *service { return &service{ serviceDefinition: serviceDefinition{ name: name, - handlers: make(map[string]Handler), + handlers: make(map[string]state.Handler), options: o, typ: internal.ServiceType_SERVICE, }, @@ -63,9 +64,9 @@ func NewService(name string, opts ...options.ServiceDefinitionOption) *service { } // Handler registers a new Service handler by name -func (r *service) Handler(name string, handler ServiceHandler) *service { - if handler.getOptions().Codec == nil { - handler.getOptions().Codec = r.options.DefaultCodec +func (r *service) Handler(name string, handler state.ServiceHandler) *service { + if handler.GetOptions().Codec == nil { + handler.GetOptions().Codec = r.options.DefaultCodec } r.handlers[name] = handler return r @@ -87,7 +88,7 @@ func NewObject(name string, opts ...options.ServiceDefinitionOption) *object { return &object{ serviceDefinition: serviceDefinition{ name: name, - handlers: make(map[string]Handler), + handlers: make(map[string]state.Handler), options: o, typ: internal.ServiceType_VIRTUAL_OBJECT, }, @@ -95,9 +96,9 @@ func NewObject(name string, opts ...options.ServiceDefinitionOption) *object { } // Handler registers a new Virtual Object handler by name -func (r *object) Handler(name string, handler ObjectHandler) *object { - if handler.getOptions().Codec == nil { - handler.getOptions().Codec = r.options.DefaultCodec +func (r *object) Handler(name string, handler state.ObjectHandler) *object { + if handler.GetOptions().Codec == nil { + handler.GetOptions().Codec = r.options.DefaultCodec } r.handlers[name] = handler return r diff --git a/test-services/proxy.go b/test-services/proxy.go index c38497a..d9ce340 100644 --- a/test-services/proxy.go +++ b/test-services/proxy.go @@ -2,7 +2,6 @@ package main import ( restate "github.com/restatedev/sdk-go" - "github.com/restatedev/sdk-go/interfaces" ) type ProxyRequest struct { @@ -13,15 +12,17 @@ type ProxyRequest struct { Message []int `json:"message"` } -func (req *ProxyRequest) ToTarget(ctx restate.Context) restate.TypedClient[[]byte, []byte] { +func (req *ProxyRequest) ToTarget(ctx restate.Context) restate.Client[[]byte, []byte] { if req.VirtualObjectKey != nil { - return restate.NewTypedClient[[]byte, []byte](ctx.Object( + return restate.WithRequestType[[]byte](restate.Object[[]byte]( + ctx, req.ServiceName, *req.VirtualObjectKey, req.HandlerName, restate.WithBinary)) } else { - return restate.NewTypedClient[[]byte, []byte](ctx.Service( + return restate.WithRequestType[[]byte](restate.Service[[]byte]( + ctx, req.ServiceName, req.HandlerName, restate.WithBinary)) @@ -54,7 +55,7 @@ func init() { Handler("manyCalls", restate.NewServiceHandler( // We need to use []int because Golang takes the opinionated choice of treating []byte as Base64 func(ctx restate.Context, requests []ManyCallRequest) (restate.Void, error) { - var toAwait []interfaces.Selectable + var toAwait []restate.Selectable for _, req := range requests { input := intArrayToByteArray(req.ProxyRequest.Message) @@ -71,7 +72,7 @@ func init() { selector := restate.Select(ctx, toAwait...) for selector.Remaining() { result := selector.Select() - if _, err := result.(restate.TypedResponseFuture[[]byte]).Response(); err != nil { + if _, err := result.(restate.ResponseFuture[[]byte]).Response(); err != nil { return restate.Void{}, err } } diff --git a/test-services/testutils.go b/test-services/testutils.go index 1fe1481..2cee784 100644 --- a/test-services/testutils.go +++ b/test-services/testutils.go @@ -6,7 +6,6 @@ import ( "time" restate "github.com/restatedev/sdk-go" - "github.com/restatedev/sdk-go/interfaces" ) type CreateAwakeableAndAwaitItRequest struct { @@ -72,7 +71,7 @@ func init() { })). Handler("sleepConcurrently", restate.NewServiceHandler( func(ctx restate.Context, millisDuration []int64) (restate.Void, error) { - timers := make([]interfaces.Selectable, 0, len(millisDuration)) + timers := make([]restate.Selectable, 0, len(millisDuration)) for _, d := range millisDuration { timers = append(timers, restate.After(ctx, time.Duration(d)*time.Millisecond)) }