diff --git a/cmd/kwild/server/utils.go b/cmd/kwild/server/utils.go index 751fc24a2..2f252762c 100644 --- a/cmd/kwild/server/utils.go +++ b/cmd/kwild/server/utils.go @@ -31,8 +31,8 @@ import ( // getExtensions returns both the local and remote extensions. Remote extensions are identified by // connecting to the specified extension URLs. -func getExtensions(ctx context.Context, urls []string) (map[string]extensions.ExtensionDriver, error) { - exts := make(map[string]extensions.ExtensionDriver) +func getExtensions(ctx context.Context, urls []string) (map[string]extActions.Extension, error) { + exts := make(map[string]extActions.Extension) for name, ext := range extActions.GetRegisteredExtensions() { _, ok := exts[name] @@ -59,7 +59,7 @@ func getExtensions(ctx context.Context, urls []string) (map[string]extensions.Ex return exts, nil } -func adaptExtensions(exts map[string]extensions.ExtensionDriver) map[string]engine.ExtensionInitializer { +func adaptExtensions(exts map[string]extActions.Extension) map[string]engine.ExtensionInitializer { adapted := make(map[string]engine.ExtensionInitializer, len(exts)) for name, ext := range exts { diff --git a/extensions/actions/extension.go b/extensions/actions/extension.go deleted file mode 100644 index e62dcc566..000000000 --- a/extensions/actions/extension.go +++ /dev/null @@ -1,56 +0,0 @@ -package extensions - -import ( - "context" - "fmt" -) - -// Local Extension -type Extension struct { - // Extension name - name string - // Supported methods by the extension - methods map[string]MethodFunc - // Initializer that initializes the extension - initializeFunc InitializeFunc -} - -func (e *Extension) Name() string { - return e.name -} - -func (e *Extension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) { - var encodedArgs []*ScalarValue - for _, arg := range args { - scalarVal, err := NewScalarValue(arg) - if err != nil { - return nil, fmt.Errorf("error encoding argument: %s", err.Error()) - } - - encodedArgs = append(encodedArgs, scalarVal) - } - - methodFn, ok := e.methods[method] - if !ok { - return nil, fmt.Errorf("method %s not found", method) - } - - execCtx := &ExecutionContext{ - Ctx: ctx, - Metadata: metadata, - } - results, err := methodFn(execCtx, encodedArgs...) - if err != nil { - return nil, err - } - - var outputs []any - for _, result := range results { - outputs = append(outputs, result.Value) - } - return outputs, nil -} - -func (e *Extension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { - return e.initializeFunc(ctx, metadata) -} diff --git a/extensions/actions/extension_registry.go b/extensions/actions/extension_registry.go index a88e66afb..c742d0e23 100644 --- a/extensions/actions/extension_registry.go +++ b/extensions/actions/extension_registry.go @@ -1,10 +1,19 @@ package extensions -import "strings" +import ( + "context" + "strings" +) -var registeredExtensions = make(map[string]*Extension) +type Extension interface { + Name() string + Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) + Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) +} + +var registeredExtensions = make(map[string]Extension) -func RegisterExtension(name string, ext *Extension) error { +func RegisterExtension(name string, ext Extension) error { name = strings.ToLower(name) if _, ok := registeredExtensions[name]; ok { panic("extension of same name already registered: " + name) @@ -14,6 +23,6 @@ func RegisterExtension(name string, ext *Extension) error { return nil } -func GetRegisteredExtensions() map[string]*Extension { +func GetRegisteredExtensions() map[string]Extension { return registeredExtensions } diff --git a/extensions/actions/math.go b/extensions/actions/math.go index c3261b717..8d8cf4b5c 100644 --- a/extensions/actions/math.go +++ b/extensions/actions/math.go @@ -6,15 +6,13 @@ import ( "context" "fmt" "math/big" + + "github.com/cstockton/go-conv" ) func init() { - ext, err := NewMathExtension() - if err != nil { - panic(err) - } - - err = RegisterExtension("math", ext) + mathExt := &MathExtension{} + err := RegisterExtension("math", mathExt) if err != nil { panic(err) } @@ -22,28 +20,12 @@ func init() { type MathExtension struct{} -func NewMathExtension() (*Extension, error) { - mathExt := &MathExtension{} - methods := map[string]MethodFunc{ - "add": mathExt.add, - "subtract": mathExt.subtract, - "multiply": mathExt.multiply, - "divide": mathExt.divide, - } - - ext, err := Builder().Named("math").WithMethods(methods).WithInitializer(initialize).Build() - if err != nil { - return nil, err - } - return ext, nil -} - func (e *MathExtension) Name() string { return "math" } // this initialize function checks if round is set. If not, it sets it to "up" -func initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { +func (e *MathExtension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { _, ok := metadata["round"] if !ok { metadata["round"] = "up" @@ -57,71 +39,92 @@ func initialize(ctx context.Context, metadata map[string]string) (map[string]str return metadata, nil } -func (e *MathExtension) add(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { +func (e *MathExtension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) { + switch method { + case "add": + return e.add(ctx, metadata, args...) + case "subtract": + return e.subtract(ctx, metadata, args...) + case "multiply": + return e.multiply(ctx, metadata, args...) + case "divide": + return e.divide(ctx, metadata, args...) + default: + return nil, fmt.Errorf("method %s not found", method) + } +} + +func (e *MathExtension) add(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { if len(values) != 2 { return nil, fmt.Errorf("expected 2 values for method Add, got %d", len(values)) } - val0Int, err := values[0].Int() + val0Int, err := conv.Int(values[0]) if err != nil { return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) } - val1Int, err := values[1].Int() + val1Int, err := conv.Int(values[1]) if err != nil { return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) } - return encodeScalarValues(val0Int + val1Int) + var results []any + results = append(results, val0Int+val1Int) + return results, nil } -func (e *MathExtension) subtract(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { +func (e *MathExtension) subtract(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { if len(values) != 2 { return nil, fmt.Errorf("expected 2 values for method Subtract, got %d", len(values)) } - val0Int, err := values[0].Int() + val0Int, err := conv.Int(values[0]) if err != nil { return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) } - val1Int, err := values[1].Int() + val1Int, err := conv.Int(values[1]) if err != nil { return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) } - return encodeScalarValues(val0Int - val1Int) + var results []any + results = append(results, val0Int-val1Int) + return results, nil } -func (e *MathExtension) multiply(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { +func (e *MathExtension) multiply(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { if len(values) != 2 { return nil, fmt.Errorf("expected 2 values for method Multiply, got %d", len(values)) } - val0Int, err := values[0].Int() + val0Int, err := conv.Int(values[0]) if err != nil { return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) } - val1Int, err := values[1].Int() + val1Int, err := conv.Int(values[1]) if err != nil { return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) } - return encodeScalarValues(val0Int * val1Int) + var results []any + results = append(results, val0Int*val1Int) + return results, nil } -func (e *MathExtension) divide(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { +func (e *MathExtension) divide(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { if len(values) != 2 { return nil, fmt.Errorf("expected 2 values for method Divide, got %d", len(values)) } - val0Int, err := values[0].Int() + val0Int, err := conv.Int(values[0]) if err != nil { return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) } - val1Int, err := values[1].Int() + val1Int, err := conv.Int(values[1]) if err != nil { return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) } @@ -133,13 +136,14 @@ func (e *MathExtension) divide(ctx *ExecutionContext, values ...*ScalarValue) ([ result := new(big.Float).Quo(bigVal1, bigVal2) var IntResult *big.Int - if ctx.Metadata["round"] == "up" { + var results []any + if metadata["round"] == "up" { IntResult = roundUp(result) } else { IntResult = roundDown(result) } - - return encodeScalarValues(IntResult.Int64()) + results = append(results, IntResult) + return results, nil } // roundUp takes a big.Float and returns a new big.Float rounded up. @@ -163,20 +167,6 @@ func roundDown(f *big.Float) *big.Int { return r } -func encodeScalarValues(values ...any) ([]*ScalarValue, error) { - scalarValues := make([]*ScalarValue, len(values)) - for i, v := range values { - scalarValue, err := NewScalarValue(v) - if err != nil { - return nil, err - } - - scalarValues[i] = scalarValue - } - - return scalarValues, nil -} - const ( precision = 128 ) diff --git a/extensions/actions/utils.go b/extensions/actions/utils.go deleted file mode 100644 index 01c07fa1b..000000000 --- a/extensions/actions/utils.go +++ /dev/null @@ -1,127 +0,0 @@ -package extensions - -import ( - "context" - "fmt" - "reflect" - - "github.com/cstockton/go-conv" -) - -type extensionBuilder struct { - extension *Extension -} - -// ExtensionBuilder is the interface for creating an extension server -type ExtensionBuilder interface { - // WithMethods specifies the methods that should be provided - // by the extension - WithMethods(map[string]MethodFunc) ExtensionBuilder - // WithInitializer is a function that initializes a new extension instance. - WithInitializer(InitializeFunc) ExtensionBuilder - // Named specifies the name of the extensions. - Named(string) ExtensionBuilder - - // Build creates the extensions - Build() (*Extension, error) -} - -func Builder() ExtensionBuilder { - return &extensionBuilder{ - extension: &Extension{ - methods: make(map[string]MethodFunc), - initializeFunc: func(ctx context.Context, metadata map[string]string) (map[string]string, error) { - return metadata, nil - }, - }, - } -} - -func (b *extensionBuilder) Named(name string) ExtensionBuilder { - b.extension.name = name - return b -} - -func (b *extensionBuilder) WithMethods(methods map[string]MethodFunc) ExtensionBuilder { - b.extension.methods = methods - return b -} - -func (b *extensionBuilder) WithInitializer(fn InitializeFunc) ExtensionBuilder { - b.extension.initializeFunc = fn - return b -} - -func (b *extensionBuilder) Build() (*Extension, error) { - return b.extension, nil -} - -type ExecutionContext struct { - Ctx context.Context - Metadata map[string]string -} - -// MethodFunc is a function that executes a method -type MethodFunc func(ctx *ExecutionContext, inputs ...*ScalarValue) ([]*ScalarValue, error) - -// InitializeFunc is a function that creates a new instance of an extension. -// In most cases, this should just validate the metadata being sent. -type InitializeFunc func(ctx context.Context, metadata map[string]string) (map[string]string, error) - -// WithInputsCheck checks the number of inputs. -// If the number of inputs is not equal to numInputs, it returns an error. -func WithInputsCheck(fn MethodFunc, numInputs int) MethodFunc { - return func(ctx *ExecutionContext, inputs ...*ScalarValue) ([]*ScalarValue, error) { - if len(inputs) != numInputs { - return nil, fmt.Errorf("expected %d args, got %d", numInputs, len(inputs)) - } - return fn(ctx, inputs...) - } -} - -// WithOutputsCheck checks the number of outputs. -// If the number of outputs is not equal to numOutputs, it returns an error. -func WithOutputsCheck(fn MethodFunc, numOutputs int) MethodFunc { - return func(ctx *ExecutionContext, inputs ...*ScalarValue) ([]*ScalarValue, error) { - res, err := fn(ctx, inputs...) - if err != nil { - return nil, err - } - - if len(res) != numOutputs { - return nil, fmt.Errorf("expected %d returns, got %d", numOutputs, len(res)) - } - - return res, nil - } -} - -type ScalarValue struct { - Value any -} - -func NewScalarValue(v any) (*ScalarValue, error) { - valueType := reflect.TypeOf(v) - switch valueType.Kind() { - case reflect.String, reflect.Float32, reflect.Float64: - return &ScalarValue{ - Value: v, - }, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return &ScalarValue{ - Value: v, - }, nil - default: - return nil, fmt.Errorf("invalid scalar type: %s", valueType.Kind()) - } -} - -// String returns the string representation of the value. -func (s *ScalarValue) String() (string, error) { - return conv.String(s.Value) -} - -// Int returns the int representation of the value. -func (s *ScalarValue) Int() (int64, error) { - return conv.Int64(s.Value) -} diff --git a/go.mod b/go.mod index 1ee1c8053..49e20065a 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.21 replace ( github.com/kwilteam/kwil-db/core => ./core - github.com/kwilteam/kwil-db/extensions => ./extensions github.com/kwilteam/kwil-db/parse => ./parse ) diff --git a/internal/extensions/driver.go b/internal/extensions/driver.go index 653794337..33864bf0e 100644 --- a/internal/extensions/driver.go +++ b/internal/extensions/driver.go @@ -3,6 +3,7 @@ package extensions import ( "context" + extensions "github.com/kwilteam/kwil-db/extensions/actions" "github.com/kwilteam/kwil-extensions/client" "github.com/kwilteam/kwil-extensions/types" ) @@ -12,14 +13,8 @@ var ( ConnectFunc Connecter = extensionConnectFunc(client.NewExtensionClient) ) -type ExtensionDriver interface { - Name() string - Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) - Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) -} - type ExtensionInitializer struct { - Extension ExtensionDriver + Extension extensions.Extension } // CreateInstance creates an instance of the extension with the given metadata. diff --git a/internal/extensions/extensions_test.go b/internal/extensions/extensions_test.go index b662b5154..5892523e8 100644 --- a/internal/extensions/extensions_test.go +++ b/internal/extensions/extensions_test.go @@ -4,6 +4,7 @@ package extensions_test import ( "context" + "math/big" "testing" extActions "github.com/kwilteam/kwil-db/extensions/actions" @@ -61,8 +62,7 @@ func Test_LocalExtension(t *testing.T) { "roundoff": "down", } - ext, err := extActions.NewMathExtension() - assert.NoError(t, err) + ext := &extActions.MathExtension{} initializer := &extensions.ExtensionInitializer{ Extension: ext, @@ -72,9 +72,9 @@ func Test_LocalExtension(t *testing.T) { instance1, err := initializer.CreateInstance(ctx, metadata) assert.NoError(t, err) - result, err := instance1.Execute(ctx, "divide", 1, 2) + result, err := instance1.Execute(ctx, "divide", 1.2, 2.3) assert.NoError(t, err) - assert.Equal(t, int64(0), result[0]) // 1/2 rounded down to 0 + assert.Equal(t, big.NewInt(0), result[0]) // 1/2 rounded down to 0 // Create instance with incorrect metadata, uses defaults instance2, err := initializer.CreateInstance(ctx, incorrectMetadata) @@ -82,9 +82,9 @@ func Test_LocalExtension(t *testing.T) { updatedMetadata := instance2.Metadata() assert.Equal(t, updatedMetadata["round"], "up") - result, err = instance2.Execute(ctx, "divide", 1, 2) + result, err = instance2.Execute(ctx, "divide", 1, 2.3) assert.NoError(t, err) - assert.Equal(t, int64(1), result[0]) // 1/2 rounded up -> 1 + assert.Equal(t, big.NewInt(1), result[0]) // 1/2 rounded up -> 1 } func Test_RemoteExtension(t *testing.T) { diff --git a/internal/extensions/instance.go b/internal/extensions/instance.go index 816c01c40..101569899 100644 --- a/internal/extensions/instance.go +++ b/internal/extensions/instance.go @@ -3,6 +3,8 @@ package extensions import ( "context" "strings" + + extensions "github.com/kwilteam/kwil-db/extensions/actions" ) // An instance is a single instance of an extension. @@ -13,7 +15,7 @@ import ( type Instance struct { metadata map[string]string - extension ExtensionDriver + extension extensions.Extension } func (i *Instance) Metadata() map[string]string {