From a2cc5ad108f434d0c91826c16bc75baca7ee1ba4 Mon Sep 17 00:00:00 2001 From: charithabandi Date: Tue, 10 Oct 2023 17:01:57 -0500 Subject: [PATCH] load extensions during compile time using go build tags --- cmd/kwild/server/build.go | 8 +- cmd/kwild/server/utils.go | 26 ++- extensions/actions/extension.go | 76 +++---- extensions/actions/extension_registry.go | 19 ++ extensions/actions/extension_test.go | 37 ---- extensions/actions/math.go | 188 ++++++++++++++++++ extensions/actions/mocks_test.go | 47 ----- extensions/actions/utils.go | 127 ++++++++++++ go.mod | 3 +- go.sum | 4 +- .../extensions/driver.go | 33 ++- internal/extensions/extensions_test.go | 117 +++++++++++ .../extensions}/instance.go | 31 +-- internal/extensions/remote_extensions.go | 97 +++++++++ test/integration/test-data/test_db.kf | 7 +- 15 files changed, 645 insertions(+), 175 deletions(-) create mode 100644 extensions/actions/extension_registry.go delete mode 100644 extensions/actions/extension_test.go create mode 100644 extensions/actions/math.go delete mode 100644 extensions/actions/mocks_test.go create mode 100644 extensions/actions/utils.go rename extensions/actions/interface.go => internal/extensions/driver.go (59%) create mode 100644 internal/extensions/extensions_test.go rename {extensions/actions => internal/extensions}/instance.go (51%) create mode 100644 internal/extensions/remote_extensions.go diff --git a/cmd/kwild/server/build.go b/cmd/kwild/server/build.go index 9ae2665b2..7c046ccb2 100644 --- a/cmd/kwild/server/build.go +++ b/cmd/kwild/server/build.go @@ -195,9 +195,13 @@ func buildDatasetsModule(d *coreDependencies, eng datasets.Engine, accs datasets } func buildEngine(d *coreDependencies, a *sessions.AtomicCommitter) *engine.Engine { - extensions, err := connectExtensions(d.ctx, d.cfg.AppCfg.ExtensionEndpoints) + extensions, err := getExtensions(d.ctx, d.cfg.AppCfg.ExtensionEndpoints) if err != nil { - failBuild(err, "failed to connect to extensions") + failBuild(err, "failed to get extensions") + } + + for _, ext := range extensions { + d.log.Debug("registered extension", zap.String("name", ext.Name())) } sqlCommitRegister := &sqlCommittableRegister{ diff --git a/cmd/kwild/server/utils.go b/cmd/kwild/server/utils.go index cfb2bbd18..3959302a3 100644 --- a/cmd/kwild/server/utils.go +++ b/cmd/kwild/server/utils.go @@ -9,10 +9,11 @@ import ( "github.com/kwilteam/kwil-db/core/log" types "github.com/kwilteam/kwil-db/core/types/admin" - extensions "github.com/kwilteam/kwil-db/extensions/actions" + extActions "github.com/kwilteam/kwil-db/extensions/actions" "github.com/kwilteam/kwil-db/internal/abci" "github.com/kwilteam/kwil-db/internal/abci/cometbft/privval" "github.com/kwilteam/kwil-db/internal/engine" + "github.com/kwilteam/kwil-db/internal/extensions" "github.com/kwilteam/kwil-db/internal/kv" "github.com/kwilteam/kwil-db/internal/sessions" sqlSessions "github.com/kwilteam/kwil-db/internal/sessions/sql-session" @@ -26,9 +27,18 @@ import ( cmttypes "github.com/cometbft/cometbft/types" ) -// connectExtensions connects to the provided extension urls. -func connectExtensions(ctx context.Context, urls []string) (map[string]*extensions.Extension, error) { - exts := make(map[string]*extensions.Extension, len(urls)) +// getExtensions returns both the local and remote extensions. Remote extensions are identified by +// connecting to the specified extension URLs. +func getExtensions(ctx context.Context, urls []string) (map[string]extensions.ExtensionDriver, error) { + exts := make(map[string]extensions.ExtensionDriver) + + for name, ext := range extActions.GetRegisteredExtensions() { + _, ok := exts[name] + if ok { + return nil, fmt.Errorf("duplicate extension name: %s", name) + } + exts[name] = ext + } for _, url := range urls { ext := extensions.New(url) @@ -44,15 +54,17 @@ func connectExtensions(ctx context.Context, urls []string) (map[string]*extensio exts[ext.Name()] = ext } - return exts, nil } -func adaptExtensions(exts map[string]*extensions.Extension) map[string]engine.ExtensionInitializer { +func adaptExtensions(exts map[string]extensions.ExtensionDriver) map[string]engine.ExtensionInitializer { adapted := make(map[string]engine.ExtensionInitializer, len(exts)) for name, ext := range exts { - adapted[name] = extensionInitializeFunc(ext.CreateInstance) + initializer := &extensions.ExtensionInitializer{ + Extension: ext, + } + adapted[name] = extensionInitializeFunc(initializer.CreateInstance) } return adapted diff --git a/extensions/actions/extension.go b/extensions/actions/extension.go index 48251abed..e62dcc566 100644 --- a/extensions/actions/extension.go +++ b/extensions/actions/extension.go @@ -3,70 +3,54 @@ package extensions import ( "context" "fmt" - "strings" ) +// Local Extension type Extension struct { - name string - url string - methods map[string]struct{} - - client ExtensionClient + // Extension name + name string + // Supported methods by the extension + methods map[string]MethodFunc + // Initializer that initializes the extension + initializeFunc InitializeFunc } func (e *Extension) Name() string { return e.name } -// New connects to the given extension, and attempts to configure it with the given config. -// If the extension is not available, an error is returned. -func New(url string) *Extension { - return &Extension{ - name: "", - url: url, - methods: make(map[string]struct{}), - } -} +func (e *Extension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) { + var encodedArgs []*ScalarValue + for _, arg := range args { + scalarVal, err := NewScalarValue(arg) + if err != nil { + return nil, fmt.Errorf("error encoding argument: %s", err.Error()) + } -func (e *Extension) Connect(ctx context.Context) error { - extClient, err := ConnectFunc.Connect(ctx, e.url) - if err != nil { - return fmt.Errorf("failed to connect to extension at %s: %w", e.url, err) + encodedArgs = append(encodedArgs, scalarVal) } - name, err := extClient.GetName(ctx) - if err != nil { - return fmt.Errorf("failed to get extension name: %w", err) + methodFn, ok := e.methods[method] + if !ok { + return nil, fmt.Errorf("method %s not found", method) } - e.name = name - e.client = extClient - - err = e.loadMethods(ctx) - if err != nil { - return fmt.Errorf("failed to load methods for extension %s: %w", e.name, err) + execCtx := &ExecutionContext{ + Ctx: ctx, + Metadata: metadata, } - - return nil -} - -func (e *Extension) loadMethods(ctx context.Context) error { - methodList, err := e.client.ListMethods(ctx) + results, err := methodFn(execCtx, encodedArgs...) if err != nil { - return fmt.Errorf("failed to list methods for extension '%s' at target '%s': %w", e.name, e.url, err) + return nil, err } - e.methods = make(map[string]struct{}) - for _, method := range methodList { - lowerName := strings.ToLower(method) - - _, ok := e.methods[lowerName] - if ok { - return fmt.Errorf("extension %s has duplicate method %s. this is an issue with the extension", e.name, lowerName) - } - - e.methods[lowerName] = struct{}{} + var outputs []any + for _, result := range results { + outputs = append(outputs, result.Value) } + return outputs, nil +} - return nil +func (e *Extension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { + return e.initializeFunc(ctx, metadata) } diff --git a/extensions/actions/extension_registry.go b/extensions/actions/extension_registry.go new file mode 100644 index 000000000..a88e66afb --- /dev/null +++ b/extensions/actions/extension_registry.go @@ -0,0 +1,19 @@ +package extensions + +import "strings" + +var registeredExtensions = make(map[string]*Extension) + +func RegisterExtension(name string, ext *Extension) error { + name = strings.ToLower(name) + if _, ok := registeredExtensions[name]; ok { + panic("extension of same name already registered: " + name) + } + + registeredExtensions[name] = ext + return nil +} + +func GetRegisteredExtensions() map[string]*Extension { + return registeredExtensions +} diff --git a/extensions/actions/extension_test.go b/extensions/actions/extension_test.go deleted file mode 100644 index 7da1c4ca2..000000000 --- a/extensions/actions/extension_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package extensions_test - -import ( - "context" - "testing" - - extensions "github.com/kwilteam/kwil-db/extensions/actions" -) - -// TODO: these tests are pretty bad. -// since this is a prototype, and the package is simple, this is good for now. -func Test_Extensions(t *testing.T) { - ctx := context.Background() - ext := extensions.New("local:8080") - - err := ext.Connect(ctx) - if err != nil { - t.Fatal(err) - } - - instance, err := ext.CreateInstance(ctx, map[string]string{ - "token_address": "0x12345", - "wallet_address": "0xabcd", - }) - if err != nil { - t.Fatal(err) - } - - results, err := instance.Execute(ctx, "method1", "0x12345") - if err != nil { - t.Fatal(err) - } - - if len(results) != 2 { - t.Fatalf("expected 2 results, got %d", len(results)) - } -} diff --git a/extensions/actions/math.go b/extensions/actions/math.go new file mode 100644 index 000000000..c3261b717 --- /dev/null +++ b/extensions/actions/math.go @@ -0,0 +1,188 @@ +//go:build actions_math || ext_test + +package extensions + +import ( + "context" + "fmt" + "math/big" +) + +func init() { + ext, err := NewMathExtension() + if err != nil { + panic(err) + } + + err = RegisterExtension("math", ext) + if err != nil { + panic(err) + } +} + +type MathExtension struct{} + +func NewMathExtension() (*Extension, error) { + mathExt := &MathExtension{} + methods := map[string]MethodFunc{ + "add": mathExt.add, + "subtract": mathExt.subtract, + "multiply": mathExt.multiply, + "divide": mathExt.divide, + } + + ext, err := Builder().Named("math").WithMethods(methods).WithInitializer(initialize).Build() + if err != nil { + return nil, err + } + return ext, nil +} + +func (e *MathExtension) Name() string { + return "math" +} + +// this initialize function checks if round is set. If not, it sets it to "up" +func initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { + _, ok := metadata["round"] + if !ok { + metadata["round"] = "up" + } + + roundVal := metadata["round"] + if roundVal != "up" && roundVal != "down" { + return nil, fmt.Errorf("round must be either 'up' or 'down'. default is 'up'") + } + + return metadata, nil +} + +func (e *MathExtension) add(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Add, got %d", len(values)) + } + + val0Int, err := values[0].Int() + if err != nil { + return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) + } + + val1Int, err := values[1].Int() + if err != nil { + return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) + } + + return encodeScalarValues(val0Int + val1Int) +} + +func (e *MathExtension) subtract(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Subtract, got %d", len(values)) + } + + val0Int, err := values[0].Int() + if err != nil { + return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) + } + + val1Int, err := values[1].Int() + if err != nil { + return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) + } + + return encodeScalarValues(val0Int - val1Int) +} + +func (e *MathExtension) multiply(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Multiply, got %d", len(values)) + } + + val0Int, err := values[0].Int() + if err != nil { + return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) + } + + val1Int, err := values[1].Int() + if err != nil { + return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) + } + + return encodeScalarValues(val0Int * val1Int) +} + +func (e *MathExtension) divide(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Divide, got %d", len(values)) + } + + val0Int, err := values[0].Int() + if err != nil { + return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int) + } + + val1Int, err := values[1].Int() + if err != nil { + return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int) + } + + bigVal1 := newBigFloat(float64(val0Int)) + + bigVal2 := newBigFloat(float64(val1Int)) + + result := new(big.Float).Quo(bigVal1, bigVal2) + + var IntResult *big.Int + if ctx.Metadata["round"] == "up" { + IntResult = roundUp(result) + } else { + IntResult = roundDown(result) + } + + return encodeScalarValues(IntResult.Int64()) +} + +// roundUp takes a big.Float and returns a new big.Float rounded up. +func roundUp(f *big.Float) *big.Int { + c := new(big.Float).SetPrec(precision).Copy(f) + r := new(big.Int) + f.Int(r) + + if c.Sub(c, new(big.Float).SetPrec(precision).SetInt(r)).Sign() > 0 { + r.Add(r, big.NewInt(1)) + } + + return r +} + +// roundDown takes a big.Float and returns a new big.Float rounded down. +func roundDown(f *big.Float) *big.Int { + r := new(big.Int) + f.Int(r) + + return r +} + +func encodeScalarValues(values ...any) ([]*ScalarValue, error) { + scalarValues := make([]*ScalarValue, len(values)) + for i, v := range values { + scalarValue, err := NewScalarValue(v) + if err != nil { + return nil, err + } + + scalarValues[i] = scalarValue + } + + return scalarValues, nil +} + +const ( + precision = 128 +) + +func newBigFloat(num float64) *big.Float { + bg := new(big.Float).SetPrec(precision) + + return bg.SetFloat64(num) +} diff --git a/extensions/actions/mocks_test.go b/extensions/actions/mocks_test.go deleted file mode 100644 index e092d28cc..000000000 --- a/extensions/actions/mocks_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package extensions_test - -import ( - "context" - - extensions "github.com/kwilteam/kwil-db/extensions/actions" - "github.com/kwilteam/kwil-extensions/client" - "github.com/kwilteam/kwil-extensions/types" -) - -func init() { - extensions.ConnectFunc = connecterFunc(mockConnect) -} - -// this is used to inject a mock connection function for testing -func mockConnect(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) { - return &mockClient{}, nil -} - -type connecterFunc func(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) - -func (m connecterFunc) Connect(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) { - return &mockClient{}, nil -} - -// mockClient implements the ExtensionClient interface -type mockClient struct{} - -func (m *mockClient) GetName(ctx context.Context) (string, error) { - return "mock", nil -} - -func (m *mockClient) CallMethod(ctx *types.ExecutionContext, method string, args ...any) ([]any, error) { - return []any{"val1", 2}, nil -} - -func (m *mockClient) Close() error { - return nil -} - -func (m *mockClient) ListMethods(ctx context.Context) ([]string, error) { - return []string{"method1", "method2"}, nil -} - -func (m *mockClient) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { - return metadata, nil -} diff --git a/extensions/actions/utils.go b/extensions/actions/utils.go new file mode 100644 index 000000000..01c07fa1b --- /dev/null +++ b/extensions/actions/utils.go @@ -0,0 +1,127 @@ +package extensions + +import ( + "context" + "fmt" + "reflect" + + "github.com/cstockton/go-conv" +) + +type extensionBuilder struct { + extension *Extension +} + +// ExtensionBuilder is the interface for creating an extension server +type ExtensionBuilder interface { + // WithMethods specifies the methods that should be provided + // by the extension + WithMethods(map[string]MethodFunc) ExtensionBuilder + // WithInitializer is a function that initializes a new extension instance. + WithInitializer(InitializeFunc) ExtensionBuilder + // Named specifies the name of the extensions. + Named(string) ExtensionBuilder + + // Build creates the extensions + Build() (*Extension, error) +} + +func Builder() ExtensionBuilder { + return &extensionBuilder{ + extension: &Extension{ + methods: make(map[string]MethodFunc), + initializeFunc: func(ctx context.Context, metadata map[string]string) (map[string]string, error) { + return metadata, nil + }, + }, + } +} + +func (b *extensionBuilder) Named(name string) ExtensionBuilder { + b.extension.name = name + return b +} + +func (b *extensionBuilder) WithMethods(methods map[string]MethodFunc) ExtensionBuilder { + b.extension.methods = methods + return b +} + +func (b *extensionBuilder) WithInitializer(fn InitializeFunc) ExtensionBuilder { + b.extension.initializeFunc = fn + return b +} + +func (b *extensionBuilder) Build() (*Extension, error) { + return b.extension, nil +} + +type ExecutionContext struct { + Ctx context.Context + Metadata map[string]string +} + +// MethodFunc is a function that executes a method +type MethodFunc func(ctx *ExecutionContext, inputs ...*ScalarValue) ([]*ScalarValue, error) + +// InitializeFunc is a function that creates a new instance of an extension. +// In most cases, this should just validate the metadata being sent. +type InitializeFunc func(ctx context.Context, metadata map[string]string) (map[string]string, error) + +// WithInputsCheck checks the number of inputs. +// If the number of inputs is not equal to numInputs, it returns an error. +func WithInputsCheck(fn MethodFunc, numInputs int) MethodFunc { + return func(ctx *ExecutionContext, inputs ...*ScalarValue) ([]*ScalarValue, error) { + if len(inputs) != numInputs { + return nil, fmt.Errorf("expected %d args, got %d", numInputs, len(inputs)) + } + return fn(ctx, inputs...) + } +} + +// WithOutputsCheck checks the number of outputs. +// If the number of outputs is not equal to numOutputs, it returns an error. +func WithOutputsCheck(fn MethodFunc, numOutputs int) MethodFunc { + return func(ctx *ExecutionContext, inputs ...*ScalarValue) ([]*ScalarValue, error) { + res, err := fn(ctx, inputs...) + if err != nil { + return nil, err + } + + if len(res) != numOutputs { + return nil, fmt.Errorf("expected %d returns, got %d", numOutputs, len(res)) + } + + return res, nil + } +} + +type ScalarValue struct { + Value any +} + +func NewScalarValue(v any) (*ScalarValue, error) { + valueType := reflect.TypeOf(v) + switch valueType.Kind() { + case reflect.String, reflect.Float32, reflect.Float64: + return &ScalarValue{ + Value: v, + }, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return &ScalarValue{ + Value: v, + }, nil + default: + return nil, fmt.Errorf("invalid scalar type: %s", valueType.Kind()) + } +} + +// String returns the string representation of the value. +func (s *ScalarValue) String() (string, error) { + return conv.String(s.Value) +} + +// Int returns the int representation of the value. +func (s *ScalarValue) Int() (int64, error) { + return conv.Int64(s.Value) +} diff --git a/go.mod b/go.mod index dcf7273b7..1ee1c8053 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21 replace ( github.com/kwilteam/kwil-db/core => ./core + github.com/kwilteam/kwil-db/extensions => ./extensions github.com/kwilteam/kwil-db/parse => ./parse ) @@ -22,7 +23,6 @@ require ( github.com/kwilteam/kuneiform v0.5.0-alpha.0.20231011193347-ab7495c55426 github.com/kwilteam/kwil-db/core v0.0.0 github.com/kwilteam/kwil-db/parse v0.0.0 - github.com/kwilteam/kwil-extensions v0.0.0-20230710163303-bfa03f64ff82 github.com/manifoldco/promptui v0.9.0 github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.7.0 @@ -82,6 +82,7 @@ require ( github.com/jmhodges/levigo v1.0.0 // indirect github.com/klauspost/compress v1.16.5 // indirect github.com/kwilteam/action-grammar-go v0.0.1-0.20230926160920-472768e1186c // indirect + github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7 github.com/kwilteam/sql-grammar-go v0.0.3-0.20230925230724-00685e1bac32 // indirect github.com/lib/pq v1.10.7 // indirect github.com/libp2p/go-buffer-pool v0.1.0 // indirect diff --git a/go.sum b/go.sum index b9dfd37fe..037c7dfe8 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/kwilteam/go-sqlite v0.0.0-20230606000142-c7eaa7111421 h1:TewJpDtkIU8Z github.com/kwilteam/go-sqlite v0.0.0-20230606000142-c7eaa7111421/go.mod h1:urRZ5yExms/OcYQHq0IAPLkNoudEbfUuQdlNvhcfrKI= github.com/kwilteam/kuneiform v0.5.0-alpha.0.20231011193347-ab7495c55426 h1:IO3Myedpq5Jr7Yo/ieqQBJPqsObA84/eEwkPexweduw= github.com/kwilteam/kuneiform v0.5.0-alpha.0.20231011193347-ab7495c55426/go.mod h1:MT8wV7wVVMz0UREaaOkkInUyvZMKO7FcHZ7E4cmsgLQ= -github.com/kwilteam/kwil-extensions v0.0.0-20230710163303-bfa03f64ff82 h1:pA0ya2WrncGSxxXB0g3dVq1jZZSm1HO6Qp0/yYn4qks= -github.com/kwilteam/kwil-extensions v0.0.0-20230710163303-bfa03f64ff82/go.mod h1:+BrFrV+3qcdYIfptqjwatE5gT19azuRHJzw77wMPY8c= +github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7 h1:YiPBu0pOeYOtOVfwKQqdWB07SUef9LvngF4bVFD+x34= +github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7/go.mod h1:+BrFrV+3qcdYIfptqjwatE5gT19azuRHJzw77wMPY8c= github.com/kwilteam/sql-grammar-go v0.0.3-0.20230925230724-00685e1bac32 h1:NDMw+6BKSqLxFyfpbbCJNx8EOLB3+ugCUEnMpomXBeQ= github.com/kwilteam/sql-grammar-go v0.0.3-0.20230925230724-00685e1bac32/go.mod h1:OqmGyCwHfBZvYv/sYPrQ5Ih290dhlD5AcKOHDlUSS0Y= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= diff --git a/extensions/actions/interface.go b/internal/extensions/driver.go similarity index 59% rename from extensions/actions/interface.go rename to internal/extensions/driver.go index 2adf262c5..653794337 100644 --- a/extensions/actions/interface.go +++ b/internal/extensions/driver.go @@ -7,6 +7,34 @@ import ( "github.com/kwilteam/kwil-extensions/types" ) +var ( + // this can be overridden for testing + ConnectFunc Connecter = extensionConnectFunc(client.NewExtensionClient) +) + +type ExtensionDriver interface { + Name() string + Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) + Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) +} + +type ExtensionInitializer struct { + Extension ExtensionDriver +} + +// CreateInstance creates an instance of the extension with the given metadata. +func (e *ExtensionInitializer) CreateInstance(ctx context.Context, metadata map[string]string) (*Instance, error) { + metadata, err := e.Extension.Initialize(ctx, metadata) + if err != nil { + return nil, err + } + + return &Instance{ + metadata: metadata, + extension: e.Extension, + }, nil +} + type ExtensionClient interface { CallMethod(execCtx *types.ExecutionContext, method string, args ...any) ([]any, error) Close() error @@ -24,8 +52,3 @@ type extensionConnectFunc func(ctx context.Context, target string, opts ...clien func (e extensionConnectFunc) Connect(ctx context.Context, target string, opts ...client.ClientOpt) (ExtensionClient, error) { return e(ctx, target, opts...) } - -var ( - // this can be overridden for testing - ConnectFunc Connecter = extensionConnectFunc(client.NewExtensionClient) -) diff --git a/internal/extensions/extensions_test.go b/internal/extensions/extensions_test.go new file mode 100644 index 000000000..b662b5154 --- /dev/null +++ b/internal/extensions/extensions_test.go @@ -0,0 +1,117 @@ +//go:build actions_math || ext_test + +package extensions_test + +import ( + "context" + "testing" + + extActions "github.com/kwilteam/kwil-db/extensions/actions" + extensions "github.com/kwilteam/kwil-db/internal/extensions" + + "github.com/kwilteam/kwil-extensions/client" + "github.com/kwilteam/kwil-extensions/types" + "github.com/stretchr/testify/assert" +) + +func init() { + extensions.ConnectFunc = connecterFunc(mockConnect) +} + +// this is used to inject a mock connection function for testing +func mockConnect(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) { + return &mockClient{}, nil +} + +type connecterFunc func(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) + +func (m connecterFunc) Connect(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) { + return &mockClient{}, nil +} + +// mockClient implements the ExtensionClient interface +type mockClient struct{} + +func (m *mockClient) GetName(ctx context.Context) (string, error) { + return "mock", nil +} + +func (m *mockClient) CallMethod(ctx *types.ExecutionContext, method string, args ...any) ([]any, error) { + return []any{"val1", 2}, nil +} + +func (m *mockClient) Close() error { + return nil +} + +func (m *mockClient) ListMethods(ctx context.Context) ([]string, error) { + return []string{"method1", "method2"}, nil +} + +func (m *mockClient) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { + return metadata, nil +} + +func Test_LocalExtension(t *testing.T) { + ctx := context.Background() + metadata := map[string]string{ + "round": "down", + } + incorrectMetadata := map[string]string{ + "roundoff": "down", + } + + ext, err := extActions.NewMathExtension() + assert.NoError(t, err) + + initializer := &extensions.ExtensionInitializer{ + Extension: ext, + } + + // Create instance with correct metadata + instance1, err := initializer.CreateInstance(ctx, metadata) + assert.NoError(t, err) + + result, err := instance1.Execute(ctx, "divide", 1, 2) + assert.NoError(t, err) + assert.Equal(t, int64(0), result[0]) // 1/2 rounded down to 0 + + // Create instance with incorrect metadata, uses defaults + instance2, err := initializer.CreateInstance(ctx, incorrectMetadata) + assert.NoError(t, err) + updatedMetadata := instance2.Metadata() + assert.Equal(t, updatedMetadata["round"], "up") + + result, err = instance2.Execute(ctx, "divide", 1, 2) + assert.NoError(t, err) + assert.Equal(t, int64(1), result[0]) // 1/2 rounded up -> 1 +} + +func Test_RemoteExtension(t *testing.T) { + ctx := context.Background() + ext := extensions.New("local:8080") + + err := ext.Connect(ctx) + if err != nil { + t.Fatal(err) + } + initializer := &extensions.ExtensionInitializer{ + Extension: ext, + } + instance, err := initializer.CreateInstance(ctx, map[string]string{ + "token_address": "0x12345", + "wallet_address": "0xabcd", + }) + if err != nil { + t.Fatal(err) + } + + results, err := instance.Execute(ctx, "method1", "0x12345") + if err != nil { + t.Fatal(err) + } + + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } +} diff --git a/extensions/actions/instance.go b/internal/extensions/instance.go similarity index 51% rename from extensions/actions/instance.go rename to internal/extensions/instance.go index c8bc926f6..816c01c40 100644 --- a/extensions/actions/instance.go +++ b/internal/extensions/instance.go @@ -2,10 +2,7 @@ package extensions import ( "context" - "fmt" "strings" - - "github.com/kwilteam/kwil-extensions/types" ) // An instance is a single instance of an extension. @@ -16,19 +13,7 @@ import ( type Instance struct { metadata map[string]string - extenstion *Extension -} - -func (e *Extension) CreateInstance(ctx context.Context, metadata map[string]string) (*Instance, error) { - newMetadata, err := e.client.Initialize(ctx, metadata) - if err != nil { - return nil, err - } - - return &Instance{ - metadata: newMetadata, - extenstion: e, - }, nil + extension ExtensionDriver } func (i *Instance) Metadata() map[string]string { @@ -36,18 +21,14 @@ func (i *Instance) Metadata() map[string]string { } func (i *Instance) Name() string { - return i.extenstion.name + return i.extension.Name() } func (i *Instance) Execute(ctx context.Context, method string, args ...any) ([]any, error) { lowerMethod := strings.ToLower(method) - _, ok := i.extenstion.methods[lowerMethod] - if !ok { - return nil, fmt.Errorf("method '%s' is not available for extension '%s' at target '%s'", lowerMethod, i.extenstion.name, i.extenstion.url) + results, err := i.extension.Execute(ctx, i.metadata, lowerMethod, args...) + if err != nil { + return nil, err } - - return i.extenstion.client.CallMethod(&types.ExecutionContext{ - Ctx: ctx, - Metadata: i.metadata, - }, lowerMethod, args...) + return results, nil } diff --git a/internal/extensions/remote_extensions.go b/internal/extensions/remote_extensions.go new file mode 100644 index 000000000..def19e2d4 --- /dev/null +++ b/internal/extensions/remote_extensions.go @@ -0,0 +1,97 @@ +package extensions + +import ( + "context" + "fmt" + "strings" + + "github.com/kwilteam/kwil-extensions/types" +) + +// Remote Extension used for docker extensions defined and deployed remotely +type RemoteExtension struct { + // Name of the extension + name string + // url of the extension server + url string + // methods supported by the extension + methods map[string]struct{} + // client to connect to the server + client ExtensionClient +} + +func (e *RemoteExtension) Name() string { + return e.name +} + +// New returns a placeholder for the RemoteExtension at a given url +func New(url string) *RemoteExtension { + return &RemoteExtension{ + name: "", + url: url, + methods: make(map[string]struct{}), + } +} + +// Initialize initializes based on the given metadata and returns the updated metadata +func (e *RemoteExtension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { + return e.client.Initialize(ctx, metadata) +} + +// Execute executes the requested method of an extension. If the method is not supported, an error is returned. +func (e *RemoteExtension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) { + _, ok := e.methods[method] + if !ok { + return nil, fmt.Errorf("method '%s' is not available for extension '%s' at target '%s'", method, e.name, e.url) + } + + return e.client.CallMethod(&types.ExecutionContext{ + Ctx: ctx, + Metadata: metadata, + }, method, args...) +} + +// Connect connects to the given extension, and attempts to configure it with the given config. +// If the extension is not available, an error is returned. +func (e *RemoteExtension) Connect(ctx context.Context) error { + extClient, err := ConnectFunc.Connect(ctx, e.url) + if err != nil { + return fmt.Errorf("failed to connect to extension at %s: %w", e.url, err) + } + + name, err := extClient.GetName(ctx) + if err != nil { + return fmt.Errorf("failed to get extension name: %w", err) + } + + e.name = name + e.client = extClient + + err = e.loadMethods(ctx) + if err != nil { + return fmt.Errorf("failed to load methods for extension %s: %w", e.name, err) + } + + return nil +} + +func (e *RemoteExtension) loadMethods(ctx context.Context) error { + methodList, err := e.client.ListMethods(ctx) + if err != nil { + return fmt.Errorf("failed to list methods for extension '%s' at target '%s': %w", e.name, e.url, err) + } + + e.methods = make(map[string]struct{}) + for _, method := range methodList { + lowerName := strings.ToLower(method) + + _, ok := e.methods[lowerName] + if ok { + return fmt.Errorf("extension %s has duplicate method %s. this is an issue with the extension", e.name, lowerName) + } + + e.methods[lowerName] = struct{}{} + } + + return nil +} diff --git a/test/integration/test-data/test_db.kf b/test/integration/test-data/test_db.kf index d6682ff76..97553a4e5 100644 --- a/test/integration/test-data/test_db.kf +++ b/test/integration/test-data/test_db.kf @@ -51,6 +51,7 @@ action delete_user_by_id ($id) public owner { WHERE id = $id AND public_key(wallet) = public_key(@caller); } + action create_post($id, $title, $content) public { INSERT INTO posts (id, user_id, title, content) VALUES ($id, ( @@ -112,9 +113,9 @@ action multi_select() public { SELECT * FROM users; } -action divide($numerator, $denominator) public view { - $up = math_up.div($numerator, $denominator); - $down = math_down.div($numerator, $denominator); +action divide($numerator1, $numerator2, $denominator) public view { + $up = math_up.div(abs($numerator1 + $numerator2), $denominator); + $down = math_down.div(abs($numerator1 + $numerator2), $denominator); select $up AS upper_value, $down AS lower_value; }