Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposal/WIP: action interface refactoring #331

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions go/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ import (
)

// Generator is the interface used to query an AI model.
type Generator interface {
type GeneratorAction interface {
core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]

// If the streaming callback is non-nil:
// - Each response candidate will be passed to that callback instead of
// populating the result's Candidates field.
Expand All @@ -38,7 +40,7 @@ type Generator interface {
}

// GeneratorStreamingCallback is the type for the streaming callback of a generator.
type GeneratorStreamingCallback = func(context.Context, *Candidate) error
type GeneratorStreamingCallback = func(context.Context, *GenerateResponseChunk) error

// GeneratorCapabilities describes various capabilities of the generator.
type GeneratorCapabilities struct {
Expand All @@ -56,7 +58,7 @@ type GeneratorMetadata struct {

// DefineGenerator registers the given generate function as an action, and returns a
// [Generator] whose Generate method runs it.
func DefineGenerator(provider, name string, metadata *GeneratorMetadata, generate func(context.Context, *GenerateRequest, GeneratorStreamingCallback) (*GenerateResponse, error)) Generator {
func DefineGenerator(provider, name string, metadata *GeneratorMetadata, generate func(context.Context, *GenerateRequest, GeneratorStreamingCallback) (*GenerateResponse, error)) GeneratorAction {
metadataMap := map[string]any{}
if metadata != nil {
if metadata.Label != "" {
Expand All @@ -70,22 +72,41 @@ func DefineGenerator(provider, name string, metadata *GeneratorMetadata, generat
}
metadataMap["supports"] = supports
}
a := core.DefineStreamingAction(provider, name, core.ActionTypeModel, map[string]any{
act := core.NewStreamingAction(name, core.ActionTypeModel, map[string]any{
"model": metadataMap,
}, generate)
return generator{a}
ga := &generator{act}
core.RegisterAction(provider, ga)
return ga
}

type generator struct {
generateAction *core.Action[*GenerateRequest, *GenerateResponse, *Candidate]
generateAction core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]
}

// Name returns the Action's name.
func (a *generator) Name() string { return a.generateAction.Name() }

// Description returns the Action's description.
func (a *generator) Description() string { return a.generateAction.Description() }

// ActionType returns the Action's type.
func (a *generator) ActionType() core.ActionType { return a.generateAction.ActionType() }

// Metadata returns the Action's metadata.
func (a *generator) Metadata() map[string]any { return a.generateAction.Metadata() }

// Run run the action...
func (a *generator) Run(ctx context.Context, req *GenerateRequest, cb GeneratorStreamingCallback) (*GenerateResponse, error) {
return a.generateAction.Run(ctx, req, cb)
}

func (g generator) Generate(ctx context.Context, req *GenerateRequest, cb GeneratorStreamingCallback) (*GenerateResponse, error) {
return g.generateAction.Run(ctx, req, cb)
}

// Generate applies a [Generator] to some input, handling tool requests.
func Generate(ctx context.Context, g Generator, req *GenerateRequest, cb GeneratorStreamingCallback) (*GenerateResponse, error) {
func Generate(ctx context.Context, g GeneratorAction, req *GenerateRequest, cb GeneratorStreamingCallback) (*GenerateResponse, error) {
if err := conformOutput(req); err != nil {
return nil, err
}
Expand Down Expand Up @@ -114,21 +135,17 @@ func Generate(ctx context.Context, g Generator, req *GenerateRequest, cb Generat
}
}

// generatorActionType is the instantiated core.Action type registered
// by RegisterGenerator.
type generatorActionType = core.Action[*GenerateRequest, *GenerateResponse, *Candidate]

// LookupGenerator looks up a [Generator] registered by [DefineGenerator].
func LookupGenerator(provider, name string) (Generator, error) {
func LookupGenerator(provider, name string) (GeneratorAction, error) {
action := core.LookupAction(core.ActionTypeModel, provider, name)
if action == nil {
return nil, fmt.Errorf("LookupGenerator: no generator action named %q/%q", provider, name)
}
actionInst, ok := action.(*generatorActionType)
actionInst, ok := action.(GeneratorAction)
if !ok {
return nil, fmt.Errorf("LookupGenerator: generator action %q has type %T, want %T", name, action, &generatorActionType{})
return nil, fmt.Errorf("LookupGenerator: generator action %q has type %T, want %T", name, action, &generator{})
}
return generator{actionInst}, nil
return actionInst, nil
}

// conformOutput appends a message to the request indicating conformance to the expected schema.
Expand Down
71 changes: 34 additions & 37 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,53 +52,61 @@ type streamingCallback[Stream any] func(context.Context, Stream) error
// and JSON Schemas for its input and output.
//
// Each time an Action is run, it results in a new trace span.
type Action[In, Out, Stream any] struct {
type action[In, Out, Stream any] struct {
name string
atype ActionType
fn Func[In, Out, Stream]
tstate *tracing.State
inputSchema *jsonschema.Schema
outputSchema *jsonschema.Schema
// optional
Description string
Metadata map[string]any
description string
metadata map[string]any
}

type Action[In, Out, Stream any] interface {
Run(ctx context.Context, input In, cb func(context.Context, Stream) error) (output Out, err error)
Name() string
Description() string
ActionType() ActionType
Metadata() map[string]any
}

// See js/core/src/action.ts

// DefineAction creates a new Action and registers it.
func DefineAction[In, Out any](provider, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
func DefineAction[In, Out any](provider, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) Action[In, Out, struct{}] {
return defineAction(globalRegistry, provider, name, atype, metadata, fn)
}

func defineAction[In, Out any](r *registry, provider, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
func defineAction[In, Out any](r *registry, provider, name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) Action[In, Out, struct{}] {
a := NewAction(name, atype, metadata, fn)
r.registerAction(provider, a)
return a
}

func DefineStreamingAction[In, Out, Stream any](provider, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
func DefineStreamingAction[In, Out, Stream any](provider, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) Action[In, Out, Stream] {
return defineStreamingAction(globalRegistry, provider, name, atype, metadata, fn)
}

func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) Action[In, Out, Stream] {
a := NewStreamingAction(name, atype, metadata, fn)
r.registerAction(provider, a)
return a
}

// NewAction creates a new Action with the given name and non-streaming function.
func NewAction[In, Out any](name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
func NewAction[In, Out any](name string, atype ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) Action[In, Out, struct{}] {
return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb NoStream) (Out, error) {
return fn(ctx, in)
})
}

// NewStreamingAction creates a new Action with the given name and streaming function.
func NewStreamingAction[In, Out, Stream any](name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
func NewStreamingAction[In, Out, Stream any](name string, atype ActionType, metadata map[string]any, fn Func[In, Out, Stream]) Action[In, Out, Stream] {
var i In
var o Out
return &Action[In, Out, Stream]{
return &action[In, Out, Stream]{
name: name,
atype: atype,
fn: func(ctx context.Context, input In, sc func(context.Context, Stream) error) (Out, error) {
Expand All @@ -107,20 +115,27 @@ func NewStreamingAction[In, Out, Stream any](name string, atype ActionType, meta
},
inputSchema: inferJSONSchema(i),
outputSchema: inferJSONSchema(o),
Metadata: metadata,
metadata: metadata,
}
}

// Name returns the Action's name.
func (a *Action[In, Out, Stream]) Name() string { return a.name }
func (a *action[In, Out, Stream]) Name() string { return a.name }

func (a *Action[In, Out, Stream]) actionType() ActionType { return a.atype }
// Description returns the Action's description.
func (a *action[In, Out, Stream]) Description() string { return a.description }

// Metadata returns the Action's metadata.
func (a *action[In, Out, Stream]) Metadata() map[string]any { return a.metadata }

// ActionType returns the Action's type.
func (a *action[In, Out, Stream]) ActionType() ActionType { return a.atype }

// setTracingState sets the action's tracing.State.
func (a *Action[In, Out, Stream]) setTracingState(tstate *tracing.State) { a.tstate = tstate }
func (a *action[In, Out, Stream]) setTracingState(tstate *tracing.State) { a.tstate = tstate }

// Run executes the Action's function in a new trace span.
func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(context.Context, Stream) error) (output Out, err error) {
func (a *action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(context.Context, Stream) error) (output Out, err error) {
logger.FromContext(ctx).Debug("Action.Run",
"name", a.name,
"input", fmt.Sprintf("%#v", input))
Expand Down Expand Up @@ -161,7 +176,7 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con
})
}

func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) {
func (a *action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := ValidateJSON(input, a.inputSchema); err != nil {
return nil, err
Expand Down Expand Up @@ -191,24 +206,6 @@ func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMes
return json.RawMessage(bytes), nil
}

// action is the type that all Action[I, O, S] have in common.
type action interface {
Name() string
actionType() ActionType

// runJSON uses encoding/json to unmarshal the input,
// calls Action.Run, then returns the marshaled result.
runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error)

// desc returns a description of the action.
// It should set all fields of actionDesc except Key, which
// the registry will set.
desc() actionDesc

// setTracingState set's the action's tracing.State.
setTracingState(*tracing.State)
}

// An actionDesc is a description of an Action.
// It is used to provide a list of registered actions.
type actionDesc struct {
Expand All @@ -227,11 +224,11 @@ func (d1 actionDesc) equal(d2 actionDesc) bool {
maps.Equal(d1.Metadata, d2.Metadata)
}

func (a *Action[I, O, S]) desc() actionDesc {
func (a *action[I, O, S]) desc() actionDesc {
ad := actionDesc{
Name: a.name,
Description: a.Description,
Metadata: a.Metadata,
Description: a.description,
Metadata: a.metadata,
InputSchema: a.inputSchema,
OutputSchema: a.outputSchema,
}
Expand Down
Loading