Skip to content

Commit

Permalink
load extensions during compile time using go build tags
Browse files Browse the repository at this point in the history
  • Loading branch information
jchappelow authored and charithabandi committed Oct 13, 2023
1 parent 2e5e6fb commit 10f07f4
Show file tree
Hide file tree
Showing 18 changed files with 654 additions and 176 deletions.
8 changes: 6 additions & 2 deletions cmd/kwild/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,13 @@ func buildDatasetsModule(d *coreDependencies, eng datasets.Engine, accs datasets
}

func buildEngine(d *coreDependencies, a *sessions.AtomicCommitter) *engine.Engine {
extensions, err := connectExtensions(d.ctx, d.cfg.AppCfg.ExtensionEndpoints)
extensions, err := getExtensions(d.ctx, d.cfg.AppCfg.ExtensionEndpoints)
if err != nil {
failBuild(err, "failed to connect to extensions")
failBuild(err, "failed to get extensions")
}

for _, ext := range extensions {
d.log.Debug("registered extension", zap.String("name", ext.Name()))
}

sqlCommitRegister := &sqlCommittableRegister{
Expand Down
55 changes: 49 additions & 6 deletions cmd/kwild/server/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (

"github.com/kwilteam/kwil-db/core/log"
types "github.com/kwilteam/kwil-db/core/types/admin"
extensions "github.com/kwilteam/kwil-db/extensions/actions"
extActions "github.com/kwilteam/kwil-db/extensions/actions"
"github.com/kwilteam/kwil-db/internal/abci"
"github.com/kwilteam/kwil-db/internal/abci/cometbft/privval"
"github.com/kwilteam/kwil-db/internal/engine"
"github.com/kwilteam/kwil-db/internal/extensions"
"github.com/kwilteam/kwil-db/internal/kv"
"github.com/kwilteam/kwil-db/internal/sessions"
sqlSessions "github.com/kwilteam/kwil-db/internal/sessions/sql-session"
Expand All @@ -27,11 +28,11 @@ import (
)

// connectExtensions connects to the provided extension urls.
func connectExtensions(ctx context.Context, urls []string) (map[string]*extensions.Extension, error) {
exts := make(map[string]*extensions.Extension, len(urls))
func getRemoteExtensions(ctx context.Context, urls []string) (map[string]extensions.ExtensionDriver, error) {
exts := make(map[string]extensions.ExtensionDriver, len(urls))

for _, url := range urls {
ext := extensions.New(url)
ext := extActions.New(url)
err := ext.Connect(ctx)
if err != nil {
return nil, fmt.Errorf("failed to connect extension '%s': %w", ext.Name(), err)
Expand All @@ -48,11 +49,53 @@ func connectExtensions(ctx context.Context, urls []string) (map[string]*extensio
return exts, nil
}

func adaptExtensions(exts map[string]*extensions.Extension) map[string]engine.ExtensionInitializer {
func getCompiledExtensions() (map[string]extensions.ExtensionDriver, error) {
var exts = make(map[string]extensions.ExtensionDriver)

for name, ext := range extActions.GetRegisteredExtensions() {
exts[name] = ext
}
return exts, nil
}

func getExtensions(ctx context.Context, urls []string) (map[string]extensions.ExtensionDriver, error) {
exts := make(map[string]extensions.ExtensionDriver)

compiledExts, err := getCompiledExtensions()
if err != nil {
return nil, err
}
for name, ext := range compiledExts {
_, ok := exts[name]
if ok {
return nil, fmt.Errorf("duplicate extension name: %s", name)
}
exts[name] = ext
}

remoteExts, err := getRemoteExtensions(ctx, urls)
if err != nil {
return nil, err
}
for name, ext := range remoteExts {
_, ok := exts[name]
if ok {
return nil, fmt.Errorf("duplicate extension name: %s", name)
}
exts[name] = ext
}

return exts, nil
}

func adaptExtensions(exts map[string]extensions.ExtensionDriver) map[string]engine.ExtensionInitializer {
adapted := make(map[string]engine.ExtensionInitializer, len(exts))

for name, ext := range exts {
adapted[name] = extensionInitializeFunc(ext.CreateInstance)
extInit := &extensions.ExtensionInitializer{
Extension: ext,
}
adapted[name] = extensionInitializeFunc(extInit.CreateInstance)
}

return adapted
Expand Down
51 changes: 51 additions & 0 deletions extensions/actions/builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package extensions

import "context"

type extensionBuilder struct {
extension *Extension
}

// ExtensionBuilder is the interface for creating an extension server
type ExtensionBuilder interface {
// WithMethods specifies the methods that should be provided
// by the extension
WithMethods(map[string]MethodFunc) ExtensionBuilder
// WithInitializer is a function that initializes a new extension instance.
WithInitializer(InitializeFunc) ExtensionBuilder
// Named specifies the name of the extensions.
Named(string) ExtensionBuilder

// Build creates the extensions
Build() (*Extension, error)
}

func Builder() ExtensionBuilder {
return &extensionBuilder{
extension: &Extension{
methods: make(map[string]MethodFunc),
initializeFunc: func(ctx context.Context, metadata map[string]string) (map[string]string, error) {
return metadata, nil
},
},
}
}

func (b *extensionBuilder) Named(name string) ExtensionBuilder {
b.extension.name = name
return b
}

func (b *extensionBuilder) WithMethods(methods map[string]MethodFunc) ExtensionBuilder {
b.extension.methods = methods
return b
}

func (b *extensionBuilder) WithInitializer(fn InitializeFunc) ExtensionBuilder {
b.extension.initializeFunc = fn
return b
}

func (b *extensionBuilder) Build() (*Extension, error) {
return b.extension, nil
}
72 changes: 26 additions & 46 deletions extensions/actions/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,70 +3,50 @@ package extensions
import (
"context"
"fmt"
"strings"
)

type Extension struct {
name string
url string
methods map[string]struct{}

client ExtensionClient
name string
methods map[string]MethodFunc
initializeFunc InitializeFunc
}

func (e *Extension) Name() string {
return e.name
}

// New connects to the given extension, and attempts to configure it with the given config.
// If the extension is not available, an error is returned.
func New(url string) *Extension {
return &Extension{
name: "",
url: url,
methods: make(map[string]struct{}),
}
}
func (e *Extension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) {
var encodedArgs []*ScalarValue
for _, arg := range args {
scalarVal, err := NewScalarValue(arg)
if err != nil {
return nil, fmt.Errorf("error encoding argument: %s", err.Error())
}

func (e *Extension) Connect(ctx context.Context) error {
extClient, err := ConnectFunc.Connect(ctx, e.url)
if err != nil {
return fmt.Errorf("failed to connect to extension at %s: %w", e.url, err)
encodedArgs = append(encodedArgs, scalarVal)
}

name, err := extClient.GetName(ctx)
if err != nil {
return fmt.Errorf("failed to get extension name: %w", err)
methodFn, ok := e.methods[method]
if !ok {
return nil, fmt.Errorf("method %s not found", method)
}

e.name = name
e.client = extClient

err = e.loadMethods(ctx)
if err != nil {
return fmt.Errorf("failed to load methods for extension %s: %w", e.name, err)
execCtx := &ExecutionContext{
Ctx: ctx,
Metadata: metadata,
}

return nil
}

func (e *Extension) loadMethods(ctx context.Context) error {
methodList, err := e.client.ListMethods(ctx)
results, err := methodFn(execCtx, encodedArgs...)
if err != nil {
return fmt.Errorf("failed to list methods for extension '%s' at target '%s': %w", e.name, e.url, err)
return nil, err
}

e.methods = make(map[string]struct{})
for _, method := range methodList {
lowerName := strings.ToLower(method)

_, ok := e.methods[lowerName]
if ok {
return fmt.Errorf("extension %s has duplicate method %s. this is an issue with the extension", e.name, lowerName)
}

e.methods[lowerName] = struct{}{}
var outputs []any
for _, result := range results {
outputs = append(outputs, result.Value)
}
return outputs, nil
}

return nil
func (e *Extension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) {
return e.initializeFunc(ctx, metadata)
}
19 changes: 19 additions & 0 deletions extensions/actions/extension_registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package extensions

import "strings"

var registeredExtensions = make(map[string]*Extension)

func RegisterExtension(name string, ext *Extension) error {
name = strings.ToLower(name)
if _, ok := registeredExtensions[name]; ok {
panic("extension of same name already registered: " + name)
}

registeredExtensions[name] = ext
return nil
}

func GetRegisteredExtensions() map[string]*Extension {
return registeredExtensions
}
62 changes: 43 additions & 19 deletions extensions/actions/extension_test.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,61 @@
//go:build ext_test

package extensions_test

import (
"context"
"testing"

extensions "github.com/kwilteam/kwil-db/extensions/actions"
"github.com/stretchr/testify/assert"
)

// TODO: these tests are pretty bad.
// since this is a prototype, and the package is simple, this is good for now.
func Test_Extensions(t *testing.T) {
ctx := context.Background()
ext := extensions.New("local:8080")
math_ext, err := extensions.NewMathExtension()
assert.NoError(t, err)

err := ext.Connect(ctx)
if err != nil {
t.Fatal(err)
// Create an instance with incorrect metadata, instance should be created with default metadata: up
incorrectMetadata := map[string]string{
"roundoff": "up",
}
instance1, err := math_ext.CreateInstance(ctx, incorrectMetadata)
assert.NoError(t, err)
assert.NotNil(t, instance1)

instance, err := ext.CreateInstance(ctx, map[string]string{
"token_address": "0x12345",
"wallet_address": "0xabcd",
})
if err != nil {
t.Fatal(err)
}
// Verify that the metadata was updated with the default value "round: up"
updatedMetadata := instance1.Metadata()
assert.Equal(t, "up", updatedMetadata["round"])

results, err := instance.Execute(ctx, "method1", "0x12345")
if err != nil {
t.Fatal(err)
// Create an instance with correct metadata
correctMetadata := map[string]string{
"round": "down",
}

if len(results) != 2 {
t.Fatalf("expected 2 results, got %d", len(results))
}
instance2, err := math_ext.CreateInstance(ctx, correctMetadata)
assert.NoError(t, err)

// test that the instance has the correct name
name := instance2.Name()
assert.Equal(t, "math", name)

// Execute an available method
// Instance1: round: up
result, err := instance1.Execute(ctx, "divide", 1, 2)
assert.NoError(t, err)
assert.Equal(t, int64(1), result[0])

// Instance2: round: down
result, err = instance2.Execute(ctx, "divide", 1, 2)
assert.NoError(t, err)
assert.Equal(t, int64(0), result[0])

// Check that the methods are case insensitive
result, err = instance2.Execute(ctx, "ADD", 1, 2)
assert.NoError(t, err)
assert.Equal(t, int64(3), result[0])

// Execute an unavailable method
_, err = instance2.Execute(ctx, "modulus", 1, 2)
assert.Error(t, err)
}
Loading

0 comments on commit 10f07f4

Please sign in to comment.