Skip to content

Commit

Permalink
load extensions during compile time using go build tags
Browse files Browse the repository at this point in the history
  • Loading branch information
charithabandi committed Oct 13, 2023
1 parent c4b2347 commit 21bfcbe
Show file tree
Hide file tree
Showing 28 changed files with 437 additions and 1,171 deletions.
8 changes: 5 additions & 3 deletions cmd/kwild/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/kwilteam/kwil-db/cmd/kwild/config"
extensions "github.com/kwilteam/kwil-db/extensions/actions"
"github.com/kwilteam/kwil-db/internal/abci"
"github.com/kwilteam/kwil-db/internal/abci/cometbft"
"github.com/kwilteam/kwil-db/internal/abci/snapshots"
Expand Down Expand Up @@ -195,9 +196,10 @@ func buildDatasetsModule(d *coreDependencies, eng datasets.Engine, accs datasets
}

func buildEngine(d *coreDependencies, a *sessions.AtomicCommitter) *engine.Engine {
extensions, err := connectExtensions(d.ctx, d.cfg.AppCfg.ExtensionEndpoints)
if err != nil {
failBuild(err, "failed to connect to extensions")
extensions := extensions.GetRegisteredExtensions()

for _, ext := range extensions {
d.log.Debug("registered extension", zap.String("name", ext.Name()))
}

sqlCommitRegister := &sqlCommittableRegister{
Expand Down
23 changes: 0 additions & 23 deletions cmd/kwild/server/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"context"
"errors"
"fmt"
"strings"

"github.com/kwilteam/kwil-db/core/log"
Expand All @@ -26,28 +25,6 @@ import (
cmttypes "github.com/cometbft/cometbft/types"
)

// connectExtensions connects to the provided extension urls.
func connectExtensions(ctx context.Context, urls []string) (map[string]*extensions.Extension, error) {
exts := make(map[string]*extensions.Extension, len(urls))

for _, url := range urls {
ext := extensions.New(url)
err := ext.Connect(ctx)
if err != nil {
return nil, fmt.Errorf("failed to connect extension '%s': %w", ext.Name(), err)
}

_, ok := exts[ext.Name()]
if ok {
return nil, fmt.Errorf("duplicate extension name: %s", ext.Name())
}

exts[ext.Name()] = ext
}

return exts, nil
}

func adaptExtensions(exts map[string]*extensions.Extension) map[string]engine.ExtensionInitializer {
adapted := make(map[string]engine.ExtensionInitializer, len(exts))

Expand Down
51 changes: 51 additions & 0 deletions extensions/actions/builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package extensions

import "context"

type extensionBuilder struct {
extension *Extension
}

// ExtensionBuilder is the interface for creating an extension server
type ExtensionBuilder interface {
// WithMethods specifies the methods that should be provided
// by the extension
WithMethods(map[string]MethodFunc) ExtensionBuilder
// WithInitializer is a function that initializes a new extension instance.
WithInitializer(InitializeFunc) ExtensionBuilder
// Named specifies the name of the extensions.
Named(string) ExtensionBuilder

// Build creates the extensions
Build() (*Extension, error)
}

func Builder() ExtensionBuilder {
return &extensionBuilder{
extension: &Extension{
methods: make(map[string]MethodFunc),
initializeFunc: func(ctx context.Context, metadata map[string]string) (map[string]string, error) {
return metadata, nil
},
},
}
}

func (b *extensionBuilder) Named(name string) ExtensionBuilder {
b.extension.name = name
return b
}

func (b *extensionBuilder) WithMethods(methods map[string]MethodFunc) ExtensionBuilder {
b.extension.methods = methods
return b
}

func (b *extensionBuilder) WithInitializer(fn InitializeFunc) ExtensionBuilder {
b.extension.initializeFunc = fn
return b
}

func (b *extensionBuilder) Build() (*Extension, error) {
return b.extension, nil
}
69 changes: 13 additions & 56 deletions extensions/actions/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,70 +3,27 @@ package extensions
import (
"context"
"fmt"
"strings"
)

type Extension struct {
name string
url string
methods map[string]struct{}

client ExtensionClient
}

func (e *Extension) Name() string {
return e.name
name string
methods map[string]MethodFunc
initializeFunc InitializeFunc
}

// New connects to the given extension, and attempts to configure it with the given config.
// If the extension is not available, an error is returned.
func New(url string) *Extension {
return &Extension{
name: "",
url: url,
methods: make(map[string]struct{}),
func (e *Extension) execute(ctx *ExecutionContext, method string, args ...*ScalarValue) ([]*ScalarValue, error) {
methodFn, ok := e.methods[method]
if !ok {
return nil, fmt.Errorf("method %s not found", method)
}
}

func (e *Extension) Connect(ctx context.Context) error {
extClient, err := ConnectFunc.Connect(ctx, e.url)
if err != nil {
return fmt.Errorf("failed to connect to extension at %s: %w", e.url, err)
}

name, err := extClient.GetName(ctx)
if err != nil {
return fmt.Errorf("failed to get extension name: %w", err)
}

e.name = name
e.client = extClient

err = e.loadMethods(ctx)
if err != nil {
return fmt.Errorf("failed to load methods for extension %s: %w", e.name, err)
}

return nil
return methodFn(ctx, args...)
}

func (e *Extension) loadMethods(ctx context.Context) error {
methodList, err := e.client.ListMethods(ctx)
if err != nil {
return fmt.Errorf("failed to list methods for extension '%s' at target '%s': %w", e.name, e.url, err)
}

e.methods = make(map[string]struct{})
for _, method := range methodList {
lowerName := strings.ToLower(method)

_, ok := e.methods[lowerName]
if ok {
return fmt.Errorf("extension %s has duplicate method %s. this is an issue with the extension", e.name, lowerName)
}

e.methods[lowerName] = struct{}{}
}
func (e *Extension) initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) {
return e.initializeFunc(ctx, metadata)
}

return nil
func (e *Extension) Name() string {
return e.name
}
19 changes: 19 additions & 0 deletions extensions/actions/extension_registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package extensions

import "strings"

var registeredExtensions = make(map[string]*Extension)

func RegisterExtension(name string, ext *Extension) error {
name = strings.ToLower(name)
if _, ok := registeredExtensions[name]; ok {
panic("extension of same name already registered: " + name)
}

registeredExtensions[name] = ext
return nil
}

func GetRegisteredExtensions() map[string]*Extension {
return registeredExtensions
}
62 changes: 43 additions & 19 deletions extensions/actions/extension_test.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,61 @@
//go:build ext_test

package extensions_test

import (
"context"
"testing"

extensions "github.com/kwilteam/kwil-db/extensions/actions"
"github.com/stretchr/testify/assert"
)

// TODO: these tests are pretty bad.
// since this is a prototype, and the package is simple, this is good for now.
func Test_Extensions(t *testing.T) {
ctx := context.Background()
ext := extensions.New("local:8080")
math_ext, err := extensions.NewMathExtension()
assert.NoError(t, err)

err := ext.Connect(ctx)
if err != nil {
t.Fatal(err)
// Create an instance with incorrect metadata, instance should be created with default metadata: up
incorrectMetadata := map[string]string{
"roundoff": "up",
}
instance1, err := math_ext.CreateInstance(ctx, incorrectMetadata)
assert.NoError(t, err)
assert.NotNil(t, instance1)

instance, err := ext.CreateInstance(ctx, map[string]string{
"token_address": "0x12345",
"wallet_address": "0xabcd",
})
if err != nil {
t.Fatal(err)
}
// Verify that the metadata was updated with the default value "round: up"
updatedMetadata := instance1.Metadata()
assert.Equal(t, "up", updatedMetadata["round"])

results, err := instance.Execute(ctx, "method1", "0x12345")
if err != nil {
t.Fatal(err)
// Create an instance with correct metadata
correctMetadata := map[string]string{
"round": "down",
}

if len(results) != 2 {
t.Fatalf("expected 2 results, got %d", len(results))
}
instance2, err := math_ext.CreateInstance(ctx, correctMetadata)
assert.NoError(t, err)

// test that the instance has the correct name
name := instance2.Name()
assert.Equal(t, "math", name)

// Execute an available method
// Instance1: round: up
result, err := instance1.Execute(ctx, "divide", 1, 2)
assert.NoError(t, err)
assert.Equal(t, int64(1), result[0])

// Instance2: round: down
result, err = instance2.Execute(ctx, "divide", 1, 2)
assert.NoError(t, err)
assert.Equal(t, int64(0), result[0])

// Check that the methods are case insensitive
result, err = instance2.Execute(ctx, "ADD", 1, 2)
assert.NoError(t, err)
assert.Equal(t, int64(3), result[0])

// Execute an unavailable method
_, err = instance2.Execute(ctx, "modulus", 1, 2)
assert.Error(t, err)
}
40 changes: 27 additions & 13 deletions extensions/actions/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"fmt"
"strings"

"github.com/kwilteam/kwil-extensions/types"
)

// An instance is a single instance of an extension.
Expand All @@ -16,18 +14,18 @@ import (
type Instance struct {
metadata map[string]string

extenstion *Extension
extension *Extension
}

func (e *Extension) CreateInstance(ctx context.Context, metadata map[string]string) (*Instance, error) {
newMetadata, err := e.client.Initialize(ctx, metadata)
newMetadata, err := e.initialize(ctx, metadata)
if err != nil {
return nil, err
}

return &Instance{
metadata: newMetadata,
extenstion: e,
metadata: newMetadata,
extension: e,
}, nil
}

Expand All @@ -36,18 +34,34 @@ func (i *Instance) Metadata() map[string]string {
}

func (i *Instance) Name() string {
return i.extenstion.name
return i.extension.name
}

func (i *Instance) Execute(ctx context.Context, method string, args ...any) ([]any, error) {
lowerMethod := strings.ToLower(method)
_, ok := i.extenstion.methods[lowerMethod]
if !ok {
return nil, fmt.Errorf("method '%s' is not available for extension '%s' at target '%s'", lowerMethod, i.extenstion.name, i.extenstion.url)
var encodedArgs []*ScalarValue
for _, arg := range args {
scalarVal, err := NewScalarValue(arg)
if err != nil {
return nil, fmt.Errorf("error encoding argument: %s", err.Error())
}

encodedArgs = append(encodedArgs, scalarVal)
}

return i.extenstion.client.CallMethod(&types.ExecutionContext{
execCtx := &ExecutionContext{
Ctx: ctx,
Metadata: i.metadata,
}, lowerMethod, args...)
}

lowerMethod := strings.ToLower(method)
results, err := i.extension.execute(execCtx, lowerMethod, encodedArgs...)
if err != nil {
return nil, err
}

var outputs []any
for _, result := range results {
outputs = append(outputs, result.Value)
}
return outputs, nil
}
Loading

0 comments on commit 21bfcbe

Please sign in to comment.