diff --git a/cmd/kwild/server/build.go b/cmd/kwild/server/build.go index 2065ffeb6..eea8ae3d7 100644 --- a/cmd/kwild/server/build.go +++ b/cmd/kwild/server/build.go @@ -198,9 +198,13 @@ func buildDatasetsModule(d *coreDependencies, eng datasets.Engine, accs datasets } func buildEngine(d *coreDependencies, a *sessions.AtomicCommitter) *engine.Engine { - extensions, err := connectExtensions(d.ctx, d.cfg.AppCfg.ExtensionEndpoints) + extensions, err := getExtensions(d.ctx, d.cfg.AppCfg.ExtensionEndpoints) if err != nil { - failBuild(err, "failed to connect to extensions") + failBuild(err, "failed to get extensions") + } + + for _, ext := range extensions { + d.log.Debug("registered extension", zap.String("name", ext.Name())) } sqlCommitRegister := &sqlCommittableRegister{ diff --git a/cmd/kwild/server/utils.go b/cmd/kwild/server/utils.go index 379077eed..c46fdb0c0 100644 --- a/cmd/kwild/server/utils.go +++ b/cmd/kwild/server/utils.go @@ -11,10 +11,11 @@ import ( "github.com/kwilteam/kwil-db/core/log" types "github.com/kwilteam/kwil-db/core/types/admin" "github.com/kwilteam/kwil-db/core/types/transactions" - extensions "github.com/kwilteam/kwil-db/extensions/actions" + extActions "github.com/kwilteam/kwil-db/extensions/actions" "github.com/kwilteam/kwil-db/internal/abci" "github.com/kwilteam/kwil-db/internal/abci/cometbft/privval" "github.com/kwilteam/kwil-db/internal/engine" + "github.com/kwilteam/kwil-db/internal/extensions" "github.com/kwilteam/kwil-db/internal/kv" "github.com/kwilteam/kwil-db/internal/sessions" sqlSessions "github.com/kwilteam/kwil-db/internal/sessions/sql-session" @@ -28,9 +29,18 @@ import ( cmttypes "github.com/cometbft/cometbft/types" ) -// connectExtensions connects to the provided extension urls. -func connectExtensions(ctx context.Context, urls []string) (map[string]*extensions.Extension, error) { - exts := make(map[string]*extensions.Extension, len(urls)) +// getExtensions returns both the local and remote extensions. Remote extensions are identified by +// connecting to the specified extension URLs. +func getExtensions(ctx context.Context, urls []string) (map[string]extActions.Extension, error) { + exts := make(map[string]extActions.Extension) + + for name, ext := range extActions.RegisteredExtensions() { + _, ok := exts[name] + if ok { + return nil, fmt.Errorf("duplicate extension name: %s", name) + } + exts[name] = ext + } for _, url := range urls { ext := extensions.New(url) @@ -46,15 +56,17 @@ func connectExtensions(ctx context.Context, urls []string) (map[string]*extensio exts[ext.Name()] = ext } - return exts, nil } -func adaptExtensions(exts map[string]*extensions.Extension) map[string]engine.ExtensionInitializer { +func adaptExtensions(exts map[string]extActions.Extension) map[string]engine.ExtensionInitializer { adapted := make(map[string]engine.ExtensionInitializer, len(exts)) for name, ext := range exts { - adapted[name] = extensionInitializeFunc(ext.CreateInstance) + initializer := &extensions.ExtensionInitializer{ + Extension: ext, + } + adapted[name] = extensionInitializeFunc(initializer.CreateInstance) } return adapted diff --git a/extensions/actions/extension.go b/extensions/actions/extension.go deleted file mode 100644 index 48251abed..000000000 --- a/extensions/actions/extension.go +++ /dev/null @@ -1,72 +0,0 @@ -package extensions - -import ( - "context" - "fmt" - "strings" -) - -type Extension struct { - name string - url string - methods map[string]struct{} - - client ExtensionClient -} - -func (e *Extension) Name() string { - return e.name -} - -// New connects to the given extension, and attempts to configure it with the given config. -// If the extension is not available, an error is returned. -func New(url string) *Extension { - return &Extension{ - name: "", - url: url, - methods: make(map[string]struct{}), - } -} - -func (e *Extension) Connect(ctx context.Context) error { - extClient, err := ConnectFunc.Connect(ctx, e.url) - if err != nil { - return fmt.Errorf("failed to connect to extension at %s: %w", e.url, err) - } - - name, err := extClient.GetName(ctx) - if err != nil { - return fmt.Errorf("failed to get extension name: %w", err) - } - - e.name = name - e.client = extClient - - err = e.loadMethods(ctx) - if err != nil { - return fmt.Errorf("failed to load methods for extension %s: %w", e.name, err) - } - - return nil -} - -func (e *Extension) loadMethods(ctx context.Context) error { - methodList, err := e.client.ListMethods(ctx) - if err != nil { - return fmt.Errorf("failed to list methods for extension '%s' at target '%s': %w", e.name, e.url, err) - } - - e.methods = make(map[string]struct{}) - for _, method := range methodList { - lowerName := strings.ToLower(method) - - _, ok := e.methods[lowerName] - if ok { - return fmt.Errorf("extension %s has duplicate method %s. this is an issue with the extension", e.name, lowerName) - } - - e.methods[lowerName] = struct{}{} - } - - return nil -} diff --git a/extensions/actions/extension_registry.go b/extensions/actions/extension_registry.go new file mode 100644 index 000000000..f9ea44ec1 --- /dev/null +++ b/extensions/actions/extension_registry.go @@ -0,0 +1,28 @@ +package extensions + +import ( + "context" + "strings" +) + +type Extension interface { + Name() string + Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) + Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) +} + +var registeredExtensions = make(map[string]Extension) + +func RegisterExtension(name string, ext Extension) error { + name = strings.ToLower(name) + if _, ok := registeredExtensions[name]; ok { + panic("extension of same name already registered: " + name) + } + + registeredExtensions[name] = ext + return nil +} + +func RegisteredExtensions() map[string]Extension { + return registeredExtensions +} diff --git a/extensions/actions/extension_test.go b/extensions/actions/extension_test.go deleted file mode 100644 index 7da1c4ca2..000000000 --- a/extensions/actions/extension_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package extensions_test - -import ( - "context" - "testing" - - extensions "github.com/kwilteam/kwil-db/extensions/actions" -) - -// TODO: these tests are pretty bad. -// since this is a prototype, and the package is simple, this is good for now. -func Test_Extensions(t *testing.T) { - ctx := context.Background() - ext := extensions.New("local:8080") - - err := ext.Connect(ctx) - if err != nil { - t.Fatal(err) - } - - instance, err := ext.CreateInstance(ctx, map[string]string{ - "token_address": "0x12345", - "wallet_address": "0xabcd", - }) - if err != nil { - t.Fatal(err) - } - - results, err := instance.Execute(ctx, "method1", "0x12345") - if err != nil { - t.Fatal(err) - } - - if len(results) != 2 { - t.Fatalf("expected 2 results, got %d", len(results)) - } -} diff --git a/extensions/actions/instance.go b/extensions/actions/instance.go deleted file mode 100644 index c8bc926f6..000000000 --- a/extensions/actions/instance.go +++ /dev/null @@ -1,53 +0,0 @@ -package extensions - -import ( - "context" - "fmt" - "strings" - - "github.com/kwilteam/kwil-extensions/types" -) - -// An instance is a single instance of an extension. -// Each Kuneiform schema that uses an extension will have its own instance. -// The instance is a way to encapsulate metadata. -// For example, the instance may contain the smart contract address for an ERC20 token -// that is used by the Kuneiform schema. -type Instance struct { - metadata map[string]string - - extenstion *Extension -} - -func (e *Extension) CreateInstance(ctx context.Context, metadata map[string]string) (*Instance, error) { - newMetadata, err := e.client.Initialize(ctx, metadata) - if err != nil { - return nil, err - } - - return &Instance{ - metadata: newMetadata, - extenstion: e, - }, nil -} - -func (i *Instance) Metadata() map[string]string { - return i.metadata -} - -func (i *Instance) Name() string { - return i.extenstion.name -} - -func (i *Instance) Execute(ctx context.Context, method string, args ...any) ([]any, error) { - lowerMethod := strings.ToLower(method) - _, ok := i.extenstion.methods[lowerMethod] - if !ok { - return nil, fmt.Errorf("method '%s' is not available for extension '%s' at target '%s'", lowerMethod, i.extenstion.name, i.extenstion.url) - } - - return i.extenstion.client.CallMethod(&types.ExecutionContext{ - Ctx: ctx, - Metadata: i.metadata, - }, lowerMethod, args...) -} diff --git a/extensions/actions/math_example/math.go b/extensions/actions/math_example/math.go new file mode 100644 index 000000000..bdb3633df --- /dev/null +++ b/extensions/actions/math_example/math.go @@ -0,0 +1,182 @@ +//go:build actions_math || ext_test + +package mathexample + +import ( + "context" + "fmt" + "math/big" + + extensions "github.com/kwilteam/kwil-db/extensions/actions" +) + +func init() { + mathExt := &MathExtension{} + err := extensions.RegisterExtension("math", mathExt) + if err != nil { + panic(err) + } +} + +type MathExtension struct{} + +func (e *MathExtension) Name() string { + return "math" +} + +// this initialize function checks if round is set. If not, it sets it to "up" +func (e *MathExtension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { + _, ok := metadata["round"] + if !ok { + metadata["round"] = "up" + } + + roundVal := metadata["round"] + if roundVal != "up" && roundVal != "down" { + return nil, fmt.Errorf("round must be either 'up' or 'down'. default is 'up'") + } + + return metadata, nil +} + +func (e *MathExtension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) { + switch method { + case "add": + return e.add(ctx, metadata, args...) + case "subtract": + return e.subtract(ctx, metadata, args...) + case "multiply": + return e.multiply(ctx, metadata, args...) + case "divide": + return e.divide(ctx, metadata, args...) + default: + return nil, fmt.Errorf("method %s not found", method) + } +} + +// add takes two integers and returns their sum +func (e *MathExtension) add(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Add, got %d", len(values)) + } + + val0Int, ok := values[0].(int) + if !ok { + return nil, fmt.Errorf("Argument 1 is not an int") + } + + val1Int, ok := values[1].(int) + if !ok { + return nil, fmt.Errorf("Argument 2 is not an int") + } + + var results []any + results = append(results, val0Int+val1Int) + return results, nil +} + +// subtract takes two integers and returns their difference +func (e *MathExtension) subtract(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Add, got %d", len(values)) + } + + val0Int, ok := values[0].(int) + if !ok { + return nil, fmt.Errorf("Argument 1 is not an int") + } + + val1Int, ok := values[1].(int) + if !ok { + return nil, fmt.Errorf("Argument 2 is not an int") + } + + var results []any + results = append(results, val0Int-val1Int) + return results, nil +} + +// multiply takes two integers and returns their product +func (e *MathExtension) multiply(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Add, got %d", len(values)) + } + + val0Int, ok := values[0].(int) + if !ok { + return nil, fmt.Errorf("Argument 1 is not an int") + } + + val1Int, ok := values[1].(int) + if !ok { + return nil, fmt.Errorf("Argument 2 is not an int") + } + + var results []any + results = append(results, val0Int*val1Int) + return results, nil +} + +// divide takes two integers and returns their quotient rounded up or down depending on how the extension was initialized +func (e *MathExtension) divide(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) { + if len(values) != 2 { + return nil, fmt.Errorf("expected 2 values for method Divide, got %d", len(values)) + } + + val0Int, ok := values[0].(int) + if !ok { + return nil, fmt.Errorf("Argument 1 is not an int") + } + + val1Int, ok := values[1].(int) + if !ok { + return nil, fmt.Errorf("Argument 2 is not an int") + } + + bigVal1 := newBigFloat(float64(val0Int)) + + bigVal2 := newBigFloat(float64(val1Int)) + + result := new(big.Float).Quo(bigVal1, bigVal2) + + var IntResult *big.Int + var results []any + if metadata["round"] == "up" { + IntResult = roundUp(result) + } else { + IntResult = roundDown(result) + } + results = append(results, IntResult) + return results, nil +} + +// roundUp takes a big.Float and returns a new big.Float rounded up. +func roundUp(f *big.Float) *big.Int { + c := new(big.Float).SetPrec(precision).Copy(f) + r := new(big.Int) + f.Int(r) + + if c.Sub(c, new(big.Float).SetPrec(precision).SetInt(r)).Sign() > 0 { + r.Add(r, big.NewInt(1)) + } + + return r +} + +// roundDown takes a big.Float and returns a new big.Float rounded down. +func roundDown(f *big.Float) *big.Int { + r := new(big.Int) + f.Int(r) + + return r +} + +const ( + precision = 128 +) + +func newBigFloat(num float64) *big.Float { + bg := new(big.Float).SetPrec(precision) + + return bg.SetFloat64(num) +} diff --git a/extensions/actions/mocks_test.go b/extensions/actions/mocks_test.go deleted file mode 100644 index e092d28cc..000000000 --- a/extensions/actions/mocks_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package extensions_test - -import ( - "context" - - extensions "github.com/kwilteam/kwil-db/extensions/actions" - "github.com/kwilteam/kwil-extensions/client" - "github.com/kwilteam/kwil-extensions/types" -) - -func init() { - extensions.ConnectFunc = connecterFunc(mockConnect) -} - -// this is used to inject a mock connection function for testing -func mockConnect(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) { - return &mockClient{}, nil -} - -type connecterFunc func(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) - -func (m connecterFunc) Connect(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) { - return &mockClient{}, nil -} - -// mockClient implements the ExtensionClient interface -type mockClient struct{} - -func (m *mockClient) GetName(ctx context.Context) (string, error) { - return "mock", nil -} - -func (m *mockClient) CallMethod(ctx *types.ExecutionContext, method string, args ...any) ([]any, error) { - return []any{"val1", 2}, nil -} - -func (m *mockClient) Close() error { - return nil -} - -func (m *mockClient) ListMethods(ctx context.Context) ([]string, error) { - return []string{"method1", "method2"}, nil -} - -func (m *mockClient) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { - return metadata, nil -} diff --git a/go.mod b/go.mod index dcf7273b7..fc318c858 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/kwilteam/kuneiform v0.5.0-alpha.0.20231011193347-ab7495c55426 github.com/kwilteam/kwil-db/core v0.0.0 github.com/kwilteam/kwil-db/parse v0.0.0 - github.com/kwilteam/kwil-extensions v0.0.0-20230710163303-bfa03f64ff82 + github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7 github.com/manifoldco/promptui v0.9.0 github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.7.0 diff --git a/go.sum b/go.sum index b9dfd37fe..037c7dfe8 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/kwilteam/go-sqlite v0.0.0-20230606000142-c7eaa7111421 h1:TewJpDtkIU8Z github.com/kwilteam/go-sqlite v0.0.0-20230606000142-c7eaa7111421/go.mod h1:urRZ5yExms/OcYQHq0IAPLkNoudEbfUuQdlNvhcfrKI= github.com/kwilteam/kuneiform v0.5.0-alpha.0.20231011193347-ab7495c55426 h1:IO3Myedpq5Jr7Yo/ieqQBJPqsObA84/eEwkPexweduw= github.com/kwilteam/kuneiform v0.5.0-alpha.0.20231011193347-ab7495c55426/go.mod h1:MT8wV7wVVMz0UREaaOkkInUyvZMKO7FcHZ7E4cmsgLQ= -github.com/kwilteam/kwil-extensions v0.0.0-20230710163303-bfa03f64ff82 h1:pA0ya2WrncGSxxXB0g3dVq1jZZSm1HO6Qp0/yYn4qks= -github.com/kwilteam/kwil-extensions v0.0.0-20230710163303-bfa03f64ff82/go.mod h1:+BrFrV+3qcdYIfptqjwatE5gT19azuRHJzw77wMPY8c= +github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7 h1:YiPBu0pOeYOtOVfwKQqdWB07SUef9LvngF4bVFD+x34= +github.com/kwilteam/kwil-extensions v0.0.0-20230727040522-1cfd930226b7/go.mod h1:+BrFrV+3qcdYIfptqjwatE5gT19azuRHJzw77wMPY8c= github.com/kwilteam/sql-grammar-go v0.0.3-0.20230925230724-00685e1bac32 h1:NDMw+6BKSqLxFyfpbbCJNx8EOLB3+ugCUEnMpomXBeQ= github.com/kwilteam/sql-grammar-go v0.0.3-0.20230925230724-00685e1bac32/go.mod h1:OqmGyCwHfBZvYv/sYPrQ5Ih290dhlD5AcKOHDlUSS0Y= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= diff --git a/internal/extensions/extensions_test.go b/internal/extensions/extensions_test.go new file mode 100644 index 000000000..92ff7accf --- /dev/null +++ b/internal/extensions/extensions_test.go @@ -0,0 +1,124 @@ +//go:build actions_math || ext_test + +package extensions_test + +import ( + "context" + "math/big" + "testing" + + mathexample "github.com/kwilteam/kwil-db/extensions/actions/math_example" + extensions "github.com/kwilteam/kwil-db/internal/extensions" + + "github.com/kwilteam/kwil-extensions/client" + "github.com/kwilteam/kwil-extensions/types" + "github.com/stretchr/testify/assert" +) + +func init() { + extensions.ConnectFunc = connecterFunc(mockConnect) +} + +// this is used to inject a mock connection function for testing +func mockConnect(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) { + return &mockClient{}, nil +} + +type connecterFunc func(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) + +func (m connecterFunc) Connect(ctx context.Context, target string, opts ...client.ClientOpt) (extensions.ExtensionClient, error) { + return &mockClient{}, nil +} + +// mockClient implements the ExtensionClient interface +type mockClient struct{} + +func (m *mockClient) GetName(ctx context.Context) (string, error) { + return "mock", nil +} + +func (m *mockClient) CallMethod(ctx *types.ExecutionContext, method string, args ...any) ([]any, error) { + return []any{"val1", 2}, nil +} + +func (m *mockClient) Close() error { + return nil +} + +func (m *mockClient) ListMethods(ctx context.Context) ([]string, error) { + return []string{"method1", "method2"}, nil +} + +func (m *mockClient) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { + return metadata, nil +} + +func Test_LocalExtension(t *testing.T) { + ctx := context.Background() + metadata := map[string]string{ + "round": "down", + } + incorrectMetadata := map[string]string{ + "roundoff": "down", + } + + ext := &mathexample.MathExtension{} + + initializer := &extensions.ExtensionInitializer{ + Extension: ext, + } + + // Create instance with correct metadata + instance1, err := initializer.CreateInstance(ctx, metadata) + assert.NoError(t, err) + + result, err := instance1.Execute(ctx, "add", 1, 2) + assert.NoError(t, err) + assert.Equal(t, int(3), result[0]) + + result, err = instance1.Execute(ctx, "add", 1.2, 2.3) + assert.Error(t, err) + + result, err = instance1.Execute(ctx, "divide", 1, 2) + assert.NoError(t, err) + assert.Equal(t, big.NewInt(0), result[0]) // 1/2 rounded down to 0 + + // Create instance with incorrect metadata, uses defaults + instance2, err := initializer.CreateInstance(ctx, incorrectMetadata) + assert.NoError(t, err) + updatedMetadata := instance2.Metadata() + assert.Equal(t, updatedMetadata["round"], "up") + + result, err = instance2.Execute(ctx, "divide", 1, 2) + assert.NoError(t, err) + assert.Equal(t, big.NewInt(1), result[0]) // 1/2 rounded up -> 1 +} + +func Test_RemoteExtension(t *testing.T) { + ctx := context.Background() + ext := extensions.New("local:8080") + + err := ext.Connect(ctx) + if err != nil { + t.Fatal(err) + } + initializer := &extensions.ExtensionInitializer{ + Extension: ext, + } + instance, err := initializer.CreateInstance(ctx, map[string]string{ + "token_address": "0x12345", + "wallet_address": "0xabcd", + }) + if err != nil { + t.Fatal(err) + } + + results, err := instance.Execute(ctx, "method1", "0x12345") + if err != nil { + t.Fatal(err) + } + + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } +} diff --git a/internal/extensions/instance.go b/internal/extensions/instance.go new file mode 100644 index 000000000..5b878f14a --- /dev/null +++ b/internal/extensions/instance.go @@ -0,0 +1,32 @@ +package extensions + +import ( + "context" + "strings" + + extensions "github.com/kwilteam/kwil-db/extensions/actions" +) + +// An instance is a single instance of an extension. +// Each Kuneiform schema that uses an extension will have its own instance. +// The instance is a way to encapsulate metadata. +// For example, the instance may contain the smart contract address for an ERC20 token +// that is used by the Kuneiform schema. +type Instance struct { + metadata map[string]string + + extension extensions.Extension +} + +func (i *Instance) Metadata() map[string]string { + return i.metadata +} + +func (i *Instance) Name() string { + return i.extension.Name() +} + +func (i *Instance) Execute(ctx context.Context, method string, args ...any) ([]any, error) { + lowerMethod := strings.ToLower(method) + return i.extension.Execute(ctx, i.metadata, lowerMethod, args...) +} diff --git a/extensions/actions/interface.go b/internal/extensions/interface.go similarity index 66% rename from extensions/actions/interface.go rename to internal/extensions/interface.go index 2adf262c5..33864bf0e 100644 --- a/extensions/actions/interface.go +++ b/internal/extensions/interface.go @@ -3,10 +3,33 @@ package extensions import ( "context" + extensions "github.com/kwilteam/kwil-db/extensions/actions" "github.com/kwilteam/kwil-extensions/client" "github.com/kwilteam/kwil-extensions/types" ) +var ( + // this can be overridden for testing + ConnectFunc Connecter = extensionConnectFunc(client.NewExtensionClient) +) + +type ExtensionInitializer struct { + Extension extensions.Extension +} + +// CreateInstance creates an instance of the extension with the given metadata. +func (e *ExtensionInitializer) CreateInstance(ctx context.Context, metadata map[string]string) (*Instance, error) { + metadata, err := e.Extension.Initialize(ctx, metadata) + if err != nil { + return nil, err + } + + return &Instance{ + metadata: metadata, + extension: e.Extension, + }, nil +} + type ExtensionClient interface { CallMethod(execCtx *types.ExecutionContext, method string, args ...any) ([]any, error) Close() error @@ -24,8 +47,3 @@ type extensionConnectFunc func(ctx context.Context, target string, opts ...clien func (e extensionConnectFunc) Connect(ctx context.Context, target string, opts ...client.ClientOpt) (ExtensionClient, error) { return e(ctx, target, opts...) } - -var ( - // this can be overridden for testing - ConnectFunc Connecter = extensionConnectFunc(client.NewExtensionClient) -) diff --git a/internal/extensions/remote_extensions.go b/internal/extensions/remote_extensions.go new file mode 100644 index 000000000..def19e2d4 --- /dev/null +++ b/internal/extensions/remote_extensions.go @@ -0,0 +1,97 @@ +package extensions + +import ( + "context" + "fmt" + "strings" + + "github.com/kwilteam/kwil-extensions/types" +) + +// Remote Extension used for docker extensions defined and deployed remotely +type RemoteExtension struct { + // Name of the extension + name string + // url of the extension server + url string + // methods supported by the extension + methods map[string]struct{} + // client to connect to the server + client ExtensionClient +} + +func (e *RemoteExtension) Name() string { + return e.name +} + +// New returns a placeholder for the RemoteExtension at a given url +func New(url string) *RemoteExtension { + return &RemoteExtension{ + name: "", + url: url, + methods: make(map[string]struct{}), + } +} + +// Initialize initializes based on the given metadata and returns the updated metadata +func (e *RemoteExtension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) { + return e.client.Initialize(ctx, metadata) +} + +// Execute executes the requested method of an extension. If the method is not supported, an error is returned. +func (e *RemoteExtension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) { + _, ok := e.methods[method] + if !ok { + return nil, fmt.Errorf("method '%s' is not available for extension '%s' at target '%s'", method, e.name, e.url) + } + + return e.client.CallMethod(&types.ExecutionContext{ + Ctx: ctx, + Metadata: metadata, + }, method, args...) +} + +// Connect connects to the given extension, and attempts to configure it with the given config. +// If the extension is not available, an error is returned. +func (e *RemoteExtension) Connect(ctx context.Context) error { + extClient, err := ConnectFunc.Connect(ctx, e.url) + if err != nil { + return fmt.Errorf("failed to connect to extension at %s: %w", e.url, err) + } + + name, err := extClient.GetName(ctx) + if err != nil { + return fmt.Errorf("failed to get extension name: %w", err) + } + + e.name = name + e.client = extClient + + err = e.loadMethods(ctx) + if err != nil { + return fmt.Errorf("failed to load methods for extension %s: %w", e.name, err) + } + + return nil +} + +func (e *RemoteExtension) loadMethods(ctx context.Context) error { + methodList, err := e.client.ListMethods(ctx) + if err != nil { + return fmt.Errorf("failed to list methods for extension '%s' at target '%s': %w", e.name, e.url, err) + } + + e.methods = make(map[string]struct{}) + for _, method := range methodList { + lowerName := strings.ToLower(method) + + _, ok := e.methods[lowerName] + if ok { + return fmt.Errorf("extension %s has duplicate method %s. this is an issue with the extension", e.name, lowerName) + } + + e.methods[lowerName] = struct{}{} + } + + return nil +} diff --git a/test/integration/test-data/test_db.kf b/test/integration/test-data/test_db.kf index d6682ff76..97553a4e5 100644 --- a/test/integration/test-data/test_db.kf +++ b/test/integration/test-data/test_db.kf @@ -51,6 +51,7 @@ action delete_user_by_id ($id) public owner { WHERE id = $id AND public_key(wallet) = public_key(@caller); } + action create_post($id, $title, $content) public { INSERT INTO posts (id, user_id, title, content) VALUES ($id, ( @@ -112,9 +113,9 @@ action multi_select() public { SELECT * FROM users; } -action divide($numerator, $denominator) public view { - $up = math_up.div($numerator, $denominator); - $down = math_down.div($numerator, $denominator); +action divide($numerator1, $numerator2, $denominator) public view { + $up = math_up.div(abs($numerator1 + $numerator2), $denominator); + $down = math_down.div(abs($numerator1 + $numerator2), $denominator); select $up AS upper_value, $down AS lower_value; }