diff --git a/cmd/kwild/server/utils.go b/cmd/kwild/server/utils.go index 751fc24a2..c46fdb0c0 100644 --- a/cmd/kwild/server/utils.go +++ b/cmd/kwild/server/utils.go @@ -31,10 +31,10 @@ import ( // getExtensions returns both the local and remote extensions. Remote extensions are identified by // connecting to the specified extension URLs. -func getExtensions(ctx context.Context, urls []string) (map[string]extensions.ExtensionDriver, error) { - exts := make(map[string]extensions.ExtensionDriver) +func getExtensions(ctx context.Context, urls []string) (map[string]extActions.Extension, error) { + exts := make(map[string]extActions.Extension) - for name, ext := range extActions.GetRegisteredExtensions() { + for name, ext := range extActions.RegisteredExtensions() { _, ok := exts[name] if ok { return nil, fmt.Errorf("duplicate extension name: %s", name) @@ -59,7 +59,7 @@ func getExtensions(ctx context.Context, urls []string) (map[string]extensions.Ex return exts, nil } -func adaptExtensions(exts map[string]extensions.ExtensionDriver) map[string]engine.ExtensionInitializer { +func adaptExtensions(exts map[string]extActions.Extension) map[string]engine.ExtensionInitializer { adapted := make(map[string]engine.ExtensionInitializer, len(exts)) for name, ext := range exts { diff --git a/extensions/actions/extension.go b/extensions/actions/extension.go deleted file mode 100644 index e62dcc566..000000000 --- a/extensions/actions/extension.go +++ /dev/null @@ -1,56 +0,0 @@ -package extensions - -import ( - "context" - "fmt" -) - -// Local Extension -type Extension struct { - // Extension name - name string - // Supported methods by the extension - methods map[string]MethodFunc - // Initializer that initializes the extension - initializeFunc InitializeFunc -} - -func (e *Extension) Name() string { - return e.name -} - -func (e *Extension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) { - var encodedArgs []*ScalarValue - for _, arg := range args { - scalarVal, err := NewScalarValue(arg) - if err != nil { - return nil, fmt.Errorf("error encoding argument: %s", err.Error()) - } - - encodedArgs = append(encodedArgs, scalarVal) - } - - methodFn, ok := e.methods[method] - if !ok { - return nil, fmt.Errorf("method %s not found", method) - } - - execCtx := &ExecutionContext{ - Ctx: ctx, - Metadata: metadata, - } - results, err := methodFn(execCtx, encodedArgs...) - if err != nil { - return nil, err - } - - var outputs []any - for _, result := range results { - outputs = append(outputs, result.Value) - } - return outputs, nil -} - -func (e *Extension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { - return e.initializeFunc(ctx, metadata) -} diff --git a/extensions/actions/extension_registry.go b/extensions/actions/extension_registry.go index a88e66afb..f9ea44ec1 100644 --- a/extensions/actions/extension_registry.go +++ b/extensions/actions/extension_registry.go @@ -1,10 +1,19 @@ package extensions -import "strings" +import ( + "context" + "strings" +) -var registeredExtensions = make(map[string]*Extension) +type Extension interface { + Name() string + Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) + Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) +} + +var registeredExtensions = make(map[string]Extension) -func RegisterExtension(name string, ext *Extension) error { +func RegisterExtension(name string, ext Extension) error { name = strings.ToLower(name) if _, ok := registeredExtensions[name]; ok { panic("extension of same name already registered: " + name) @@ -14,6 +23,6 @@ func RegisterExtension(name string, ext *Extension) error { return nil } -func GetRegisteredExtensions() map[string]*Extension { +func RegisteredExtensions() map[string]Extension { return registeredExtensions } diff --git a/extensions/actions/math.go b/extensions/actions/math.go deleted file mode 100644 index c3261b717..000000000 --- a/extensions/actions/math.go +++ /dev/null @@ -1,188 +0,0 @@ -//go:build actions_math || ext_test - -package extensions - -import ( - "context" - "fmt" - "math/big" -) - -func init() { - ext, err := NewMathExtension() - if err != nil { - panic(err) - } - - err = RegisterExtension("math", ext) - if err != nil { - panic(err) - } -} - -type MathExtension struct{} - -func NewMathExtension() (*Extension, error) { - mathExt := &MathExtension{} - methods := map[string]MethodFunc{ - "add": mathExt.add, - "subtract": mathExt.subtract, - "multiply": mathExt.multiply, - "divide": mathExt.divide, - } - - ext, err := Builder().Named("math").WithMethods(methods).WithInitializer(initialize).Build() - if err != nil { - return nil, err - } - return ext, nil -} - -func (e *MathExtension) Name() string { - return "math" -} - -// this initialize function checks if round is set. If not, it sets it to "up" -func initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { - _, ok := metadata["round"] - if !ok { - metadata["round"] = "up" - } - - roundVal := metadata["round"] - if roundVal != "up" && roundVal != "down" { - return nil, fmt.Errorf("round must be either 'up' or 'down'. default is 'up'") - } - - return metadata, nil -} - -func (e *MathExtension) add(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { - if len(values) != 2 { - return nil, fmt.Errorf("expected 2 values for method Add, got %d", len(values)) - } - - val0Int, err := values[0].Int() - if err != nil { - return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) - } - - val1Int, err := values[1].Int() - if err != nil { - return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) - } - - return encodeScalarValues(val0Int + val1Int) -} - -func (e *MathExtension) subtract(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { - if len(values) != 2 { - return nil, fmt.Errorf("expected 2 values for method Subtract, got %d", len(values)) - } - - val0Int, err := values[0].Int() - if err != nil { - return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) - } - - val1Int, err := values[1].Int() - if err != nil { - return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) - } - - return encodeScalarValues(val0Int - val1Int) -} - -func (e *MathExtension) multiply(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { - if len(values) != 2 { - return nil, fmt.Errorf("expected 2 values for method Multiply, got %d", len(values)) - } - - val0Int, err := values[0].Int() - if err != nil { - return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) - } - - val1Int, err := values[1].Int() - if err != nil { - return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) - } - - return encodeScalarValues(val0Int * val1Int) -} - -func (e *MathExtension) divide(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { - if len(values) != 2 { - return nil, fmt.Errorf("expected 2 values for method Divide, got %d", len(values)) - } - - val0Int, err := values[0].Int() - if err != nil { - return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) - } - - val1Int, err := values[1].Int() - if err != nil { - return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) - } - - bigVal1 := newBigFloat(float64(val0Int)) - - bigVal2 := newBigFloat(float64(val1Int)) - - result := new(big.Float).Quo(bigVal1, bigVal2) - - var IntResult *big.Int - if ctx.Metadata["round"] == "up" { - IntResult = roundUp(result) - } else { - IntResult = roundDown(result) - } - - return encodeScalarValues(IntResult.Int64()) -} - -// roundUp takes a big.Float and returns a new big.Float rounded up. -func roundUp(f *big.Float) *big.Int { - c := new(big.Float).SetPrec(precision).Copy(f) - r := new(big.Int) - f.Int(r) - - if c.Sub(c, new(big.Float).SetPrec(precision).SetInt(r)).Sign() > 0 { - r.Add(r, big.NewInt(1)) - } - - return r -} - -// roundDown takes a big.Float and returns a new big.Float rounded down. -func roundDown(f *big.Float) *big.Int { - r := new(big.Int) - f.Int(r) - - return r -} - -func encodeScalarValues(values ...any) ([]*ScalarValue, error) { - scalarValues := make([]*ScalarValue, len(values)) - for i, v := range values { - scalarValue, err := NewScalarValue(v) - if err != nil { - return nil, err - } - - scalarValues[i] = scalarValue - } - - return scalarValues, nil -} - -const ( - precision = 128 -) - -func newBigFloat(num float64) *big.Float { - bg := new(big.Float).SetPrec(precision) - - return bg.SetFloat64(num) -} diff --git a/extensions/actions/math_example/math.go b/extensions/actions/math_example/math.go new file mode 100644 index 000000000..bdb3633df --- /dev/null +++ b/extensions/actions/math_example/math.go @@ -0,0 +1,182 @@ +//go:build actions_math || ext_test + +package mathexample + +import ( + "context" + "fmt" + "math/big" + + extensions "github.com/kwilteam/kwil-db/extensions/actions" +) + +func init() { + mathExt := &MathExtension{} + err := extensions.RegisterExtension("math", mathExt) + if err != nil { + panic(err) + } +} + +type MathExtension struct{} + +func (e *MathExtension) Name() string { + return "math" +} + +// this initialize function checks if round is set. If not, it sets it to "up" +func (e *MathExtension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { + _, ok := metadata["round"] + if !ok { + metadata["round"] = "up" + } + + roundVal := metadata["round"] + if roundVal != "up" && roundVal != "down" { + return nil, fmt.Errorf("round must be either 'up' or 'down'. default is 'up'") + } + + return metadata, nil +} + +func (e *MathExtension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) { + switch method { + case "add": + return e.add(ctx, metadata, args...) + case "subtract": + return e.subtract(ctx, metadata, args...) + case "multiply": + return e.multiply(ctx, metadata, args...) + case "divide": + return e.divide(ctx, metadata, args...) + default: + return nil, fmt.Errorf("method %s not found", method) + } +} + +// add takes two integers and returns their sum +func (e *MathExtension) add(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Add, got %d", len(values)) + } + + val0Int, ok := values[0].(int) + if !ok { + return nil, fmt.Errorf("Argument 1 is not an int") + } + + val1Int, ok := values[1].(int) + if !ok { + return nil, fmt.Errorf("Argument 2 is not an int") + } + + var results []any + results = append(results, val0Int+val1Int) + return results, nil +} + +// subtract takes two integers and returns their difference +func (e *MathExtension) subtract(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Add, got %d", len(values)) + } + + val0Int, ok := values[0].(int) + if !ok { + return nil, fmt.Errorf("Argument 1 is not an int") + } + + val1Int, ok := values[1].(int) + if !ok { + return nil, fmt.Errorf("Argument 2 is not an int") + } + + var results []any + results = append(results, val0Int-val1Int) + return results, nil +} + +// multiply takes two integers and returns their product +func (e *MathExtension) multiply(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Add, got %d", len(values)) + } + + val0Int, ok := values[0].(int) + if !ok { + return nil, fmt.Errorf("Argument 1 is not an int") + } + + val1Int, ok := values[1].(int) + if !ok { + return nil, fmt.Errorf("Argument 2 is not an int") + } + + var results []any + results = append(results, val0Int*val1Int) + return results, nil +} + +// divide takes two integers and returns their quotient rounded up or down depending on how the extension was initialized +func (e *MathExtension) divide(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Divide, got %d", len(values)) + } + + val0Int, ok := values[0].(int) + if !ok { + return nil, fmt.Errorf("Argument 1 is not an int") + } + + val1Int, ok := values[1].(int) + if !ok { + return nil, fmt.Errorf("Argument 2 is not an int") + } + + bigVal1 := newBigFloat(float64(val0Int)) + + bigVal2 := newBigFloat(float64(val1Int)) + + result := new(big.Float).Quo(bigVal1, bigVal2) + + var IntResult *big.Int + var results []any + if metadata["round"] == "up" { + IntResult = roundUp(result) + } else { + IntResult = roundDown(result) + } + results = append(results, IntResult) + return results, nil +} + +// roundUp takes a big.Float and returns a new big.Float rounded up. +func roundUp(f *big.Float) *big.Int { + c := new(big.Float).SetPrec(precision).Copy(f) + r := new(big.Int) + f.Int(r) + + if c.Sub(c, new(big.Float).SetPrec(precision).SetInt(r)).Sign() > 0 { + r.Add(r, big.NewInt(1)) + } + + return r +} + +// roundDown takes a big.Float and returns a new big.Float rounded down. +func roundDown(f *big.Float) *big.Int { + r := new(big.Int) + f.Int(r) + + return r +} + +const ( + precision = 128 +) + +func newBigFloat(num float64) *big.Float { + bg := new(big.Float).SetPrec(precision) + + return bg.SetFloat64(num) +} diff --git a/extensions/actions/utils.go b/extensions/actions/utils.go deleted file mode 100644 index 01c07fa1b..000000000 --- a/extensions/actions/utils.go +++ /dev/null @@ -1,127 +0,0 @@ -package extensions - -import ( - "context" - "fmt" - "reflect" - - "github.com/cstockton/go-conv" -) - -type extensionBuilder struct { - extension *Extension -} - -// ExtensionBuilder is the interface for creating an extension server -type ExtensionBuilder interface { - // WithMethods specifies the methods that should be provided - // by the extension - WithMethods(map[string]MethodFunc) ExtensionBuilder - // WithInitializer is a function that initializes a new extension instance. - WithInitializer(InitializeFunc) ExtensionBuilder - // Named specifies the name of the extensions. - Named(string) ExtensionBuilder - - // Build creates the extensions - Build() (*Extension, error) -} - -func Builder() ExtensionBuilder { - return &extensionBuilder{ - extension: &Extension{ - methods: make(map[string]MethodFunc), - initializeFunc: func(ctx context.Context, metadata map[string]string) (map[string]string, error) { - return metadata, nil - }, - }, - } -} - -func (b *extensionBuilder) Named(name string) ExtensionBuilder { - b.extension.name = name - return b -} - -func (b *extensionBuilder) WithMethods(methods map[string]MethodFunc) ExtensionBuilder { - b.extension.methods = methods - return b -} - -func (b *extensionBuilder) WithInitializer(fn InitializeFunc) ExtensionBuilder { - b.extension.initializeFunc = fn - return b -} - -func (b *extensionBuilder) Build() (*Extension, error) { - return b.extension, nil -} - -type ExecutionContext struct { - Ctx context.Context - Metadata map[string]string -} - -// MethodFunc is a function that executes a method -type MethodFunc func(ctx *ExecutionContext, inputs ...*ScalarValue) ([]*ScalarValue, error) - -// InitializeFunc is a function that creates a new instance of an extension. -// In most cases, this should just validate the metadata being sent. -type InitializeFunc func(ctx context.Context, metadata map[string]string) (map[string]string, error) - -// WithInputsCheck checks the number of inputs. -// If the number of inputs is not equal to numInputs, it returns an error. -func WithInputsCheck(fn MethodFunc, numInputs int) MethodFunc { - return func(ctx *ExecutionContext, inputs ...*ScalarValue) ([]*ScalarValue, error) { - if len(inputs) != numInputs { - return nil, fmt.Errorf("expected %d args, got %d", numInputs, len(inputs)) - } - return fn(ctx, inputs...) - } -} - -// WithOutputsCheck checks the number of outputs. -// If the number of outputs is not equal to numOutputs, it returns an error. -func WithOutputsCheck(fn MethodFunc, numOutputs int) MethodFunc { - return func(ctx *ExecutionContext, inputs ...*ScalarValue) ([]*ScalarValue, error) { - res, err := fn(ctx, inputs...) - if err != nil { - return nil, err - } - - if len(res) != numOutputs { - return nil, fmt.Errorf("expected %d returns, got %d", numOutputs, len(res)) - } - - return res, nil - } -} - -type ScalarValue struct { - Value any -} - -func NewScalarValue(v any) (*ScalarValue, error) { - valueType := reflect.TypeOf(v) - switch valueType.Kind() { - case reflect.String, reflect.Float32, reflect.Float64: - return &ScalarValue{ - Value: v, - }, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return &ScalarValue{ - Value: v, - }, nil - default: - return nil, fmt.Errorf("invalid scalar type: %s", valueType.Kind()) - } -} - -// String returns the string representation of the value. -func (s *ScalarValue) String() (string, error) { - return conv.String(s.Value) -} - -// Int returns the int representation of the value. -func (s *ScalarValue) Int() (int64, error) { - return conv.Int64(s.Value) -} diff --git a/go.mod b/go.mod index 1ee1c8053..fc318c858 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.21 replace ( github.com/kwilteam/kwil-db/core => ./core - github.com/kwilteam/kwil-db/extensions => ./extensions github.com/kwilteam/kwil-db/parse => ./parse ) @@ -23,6 +22,7 @@ require ( github.com/kwilteam/kuneiform v0.5.0-alpha.0.20231011193347-ab7495c55426 github.com/kwilteam/kwil-db/core v0.0.0 github.com/kwilteam/kwil-db/parse v0.0.0 + github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7 github.com/manifoldco/promptui v0.9.0 github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.7.0 @@ -82,7 +82,6 @@ require ( github.com/jmhodges/levigo v1.0.0 // indirect github.com/klauspost/compress v1.16.5 // indirect github.com/kwilteam/action-grammar-go v0.0.1-0.20230926160920-472768e1186c // indirect - github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7 github.com/kwilteam/sql-grammar-go v0.0.3-0.20230925230724-00685e1bac32 // indirect github.com/lib/pq v1.10.7 // indirect github.com/libp2p/go-buffer-pool v0.1.0 // indirect diff --git a/internal/extensions/extensions_test.go b/internal/extensions/extensions_test.go index b662b5154..92ff7accf 100644 --- a/internal/extensions/extensions_test.go +++ b/internal/extensions/extensions_test.go @@ -4,9 +4,10 @@ package extensions_test import ( "context" + "math/big" "testing" - extActions "github.com/kwilteam/kwil-db/extensions/actions" + mathexample "github.com/kwilteam/kwil-db/extensions/actions/math_example" extensions "github.com/kwilteam/kwil-db/internal/extensions" "github.com/kwilteam/kwil-extensions/client" @@ -61,8 +62,7 @@ func Test_LocalExtension(t *testing.T) { "roundoff": "down", } - ext, err := extActions.NewMathExtension() - assert.NoError(t, err) + ext := &mathexample.MathExtension{} initializer := &extensions.ExtensionInitializer{ Extension: ext, @@ -72,9 +72,16 @@ func Test_LocalExtension(t *testing.T) { instance1, err := initializer.CreateInstance(ctx, metadata) assert.NoError(t, err) - result, err := instance1.Execute(ctx, "divide", 1, 2) + result, err := instance1.Execute(ctx, "add", 1, 2) + assert.NoError(t, err) + assert.Equal(t, int(3), result[0]) + + result, err = instance1.Execute(ctx, "add", 1.2, 2.3) + assert.Error(t, err) + + result, err = instance1.Execute(ctx, "divide", 1, 2) assert.NoError(t, err) - assert.Equal(t, int64(0), result[0]) // 1/2 rounded down to 0 + assert.Equal(t, big.NewInt(0), result[0]) // 1/2 rounded down to 0 // Create instance with incorrect metadata, uses defaults instance2, err := initializer.CreateInstance(ctx, incorrectMetadata) @@ -84,7 +91,7 @@ func Test_LocalExtension(t *testing.T) { result, err = instance2.Execute(ctx, "divide", 1, 2) assert.NoError(t, err) - assert.Equal(t, int64(1), result[0]) // 1/2 rounded up -> 1 + assert.Equal(t, big.NewInt(1), result[0]) // 1/2 rounded up -> 1 } func Test_RemoteExtension(t *testing.T) { diff --git a/internal/extensions/instance.go b/internal/extensions/instance.go index 816c01c40..5b878f14a 100644 --- a/internal/extensions/instance.go +++ b/internal/extensions/instance.go @@ -3,6 +3,8 @@ package extensions import ( "context" "strings" + + extensions "github.com/kwilteam/kwil-db/extensions/actions" ) // An instance is a single instance of an extension. @@ -13,7 +15,7 @@ import ( type Instance struct { metadata map[string]string - extension ExtensionDriver + extension extensions.Extension } func (i *Instance) Metadata() map[string]string { @@ -26,9 +28,5 @@ func (i *Instance) Name() string { func (i *Instance) Execute(ctx context.Context, method string, args ...any) ([]any, error) { lowerMethod := strings.ToLower(method) - results, err := i.extension.Execute(ctx, i.metadata, lowerMethod, args...) - if err != nil { - return nil, err - } - return results, nil + return i.extension.Execute(ctx, i.metadata, lowerMethod, args...) } diff --git a/internal/extensions/driver.go b/internal/extensions/interface.go similarity index 84% rename from internal/extensions/driver.go rename to internal/extensions/interface.go index 653794337..33864bf0e 100644 --- a/internal/extensions/driver.go +++ b/internal/extensions/interface.go @@ -3,6 +3,7 @@ package extensions import ( "context" + extensions "github.com/kwilteam/kwil-db/extensions/actions" "github.com/kwilteam/kwil-extensions/client" "github.com/kwilteam/kwil-extensions/types" ) @@ -12,14 +13,8 @@ var ( ConnectFunc Connecter = extensionConnectFunc(client.NewExtensionClient) ) -type ExtensionDriver interface { - Name() string - Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) - Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) -} - type ExtensionInitializer struct { - Extension ExtensionDriver + Extension extensions.Extension } // CreateInstance creates an instance of the extension with the given metadata.