Skip to content

Commit

Permalink
load extensions during compile time using go build tags
Browse files Browse the repository at this point in the history
  • Loading branch information
charithabandi committed Oct 20, 2023
1 parent 8fd3d74 commit a2cc5ad
Show file tree
Hide file tree
Showing 15 changed files with 645 additions and 175 deletions.
8 changes: 6 additions & 2 deletions cmd/kwild/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,13 @@ func buildDatasetsModule(d *coreDependencies, eng datasets.Engine, accs datasets
}

func buildEngine(d *coreDependencies, a *sessions.AtomicCommitter) *engine.Engine {
extensions, err := connectExtensions(d.ctx, d.cfg.AppCfg.ExtensionEndpoints)
extensions, err := getExtensions(d.ctx, d.cfg.AppCfg.ExtensionEndpoints)
if err != nil {
failBuild(err, "failed to connect to extensions")
failBuild(err, "failed to get extensions")
}

for _, ext := range extensions {
d.log.Debug("registered extension", zap.String("name", ext.Name()))
}

sqlCommitRegister := &sqlCommittableRegister{
Expand Down
26 changes: 19 additions & 7 deletions cmd/kwild/server/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (

"github.com/kwilteam/kwil-db/core/log"
types "github.com/kwilteam/kwil-db/core/types/admin"
extensions "github.com/kwilteam/kwil-db/extensions/actions"
extActions "github.com/kwilteam/kwil-db/extensions/actions"
"github.com/kwilteam/kwil-db/internal/abci"
"github.com/kwilteam/kwil-db/internal/abci/cometbft/privval"
"github.com/kwilteam/kwil-db/internal/engine"
"github.com/kwilteam/kwil-db/internal/extensions"
"github.com/kwilteam/kwil-db/internal/kv"
"github.com/kwilteam/kwil-db/internal/sessions"
sqlSessions "github.com/kwilteam/kwil-db/internal/sessions/sql-session"
Expand All @@ -26,9 +27,18 @@ import (
cmttypes "github.com/cometbft/cometbft/types"
)

// connectExtensions connects to the provided extension urls.
func connectExtensions(ctx context.Context, urls []string) (map[string]*extensions.Extension, error) {
exts := make(map[string]*extensions.Extension, len(urls))
// getExtensions returns both the local and remote extensions. Remote extensions are identified by
// connecting to the specified extension URLs.
func getExtensions(ctx context.Context, urls []string) (map[string]extensions.ExtensionDriver, error) {
exts := make(map[string]extensions.ExtensionDriver)

for name, ext := range extActions.GetRegisteredExtensions() {
_, ok := exts[name]
if ok {
return nil, fmt.Errorf("duplicate extension name: %s", name)
}
exts[name] = ext
}

for _, url := range urls {
ext := extensions.New(url)
Expand All @@ -44,15 +54,17 @@ func connectExtensions(ctx context.Context, urls []string) (map[string]*extensio

exts[ext.Name()] = ext
}

return exts, nil
}

func adaptExtensions(exts map[string]*extensions.Extension) map[string]engine.ExtensionInitializer {
func adaptExtensions(exts map[string]extensions.ExtensionDriver) map[string]engine.ExtensionInitializer {
adapted := make(map[string]engine.ExtensionInitializer, len(exts))

for name, ext := range exts {
adapted[name] = extensionInitializeFunc(ext.CreateInstance)
initializer := &extensions.ExtensionInitializer{
Extension: ext,
}
adapted[name] = extensionInitializeFunc(initializer.CreateInstance)
}

return adapted
Expand Down
76 changes: 30 additions & 46 deletions extensions/actions/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,70 +3,54 @@ package extensions
import (
"context"
"fmt"
"strings"
)

// Local Extension
type Extension struct {
name string
url string
methods map[string]struct{}

client ExtensionClient
// Extension name
name string
// Supported methods by the extension
methods map[string]MethodFunc
// Initializer that initializes the extension
initializeFunc InitializeFunc
}

func (e *Extension) Name() string {
return e.name
}

// New connects to the given extension, and attempts to configure it with the given config.
// If the extension is not available, an error is returned.
func New(url string) *Extension {
return &Extension{
name: "",
url: url,
methods: make(map[string]struct{}),
}
}
func (e *Extension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) {
var encodedArgs []*ScalarValue
for _, arg := range args {
scalarVal, err := NewScalarValue(arg)
if err != nil {
return nil, fmt.Errorf("error encoding argument: %s", err.Error())
}

func (e *Extension) Connect(ctx context.Context) error {
extClient, err := ConnectFunc.Connect(ctx, e.url)
if err != nil {
return fmt.Errorf("failed to connect to extension at %s: %w", e.url, err)
encodedArgs = append(encodedArgs, scalarVal)
}

name, err := extClient.GetName(ctx)
if err != nil {
return fmt.Errorf("failed to get extension name: %w", err)
methodFn, ok := e.methods[method]
if !ok {
return nil, fmt.Errorf("method %s not found", method)
}

e.name = name
e.client = extClient

err = e.loadMethods(ctx)
if err != nil {
return fmt.Errorf("failed to load methods for extension %s: %w", e.name, err)
execCtx := &ExecutionContext{
Ctx: ctx,
Metadata: metadata,
}

return nil
}

func (e *Extension) loadMethods(ctx context.Context) error {
methodList, err := e.client.ListMethods(ctx)
results, err := methodFn(execCtx, encodedArgs...)
if err != nil {
return fmt.Errorf("failed to list methods for extension '%s' at target '%s': %w", e.name, e.url, err)
return nil, err
}

e.methods = make(map[string]struct{})
for _, method := range methodList {
lowerName := strings.ToLower(method)

_, ok := e.methods[lowerName]
if ok {
return fmt.Errorf("extension %s has duplicate method %s. this is an issue with the extension", e.name, lowerName)
}

e.methods[lowerName] = struct{}{}
var outputs []any
for _, result := range results {
outputs = append(outputs, result.Value)
}
return outputs, nil
}

return nil
func (e *Extension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) {
return e.initializeFunc(ctx, metadata)
}
19 changes: 19 additions & 0 deletions extensions/actions/extension_registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package extensions

import "strings"

var registeredExtensions = make(map[string]*Extension)

func RegisterExtension(name string, ext *Extension) error {
name = strings.ToLower(name)
if _, ok := registeredExtensions[name]; ok {
panic("extension of same name already registered: " + name)
}

registeredExtensions[name] = ext
return nil
}

func GetRegisteredExtensions() map[string]*Extension {
return registeredExtensions
}
37 changes: 0 additions & 37 deletions extensions/actions/extension_test.go

This file was deleted.

Loading

0 comments on commit a2cc5ad

Please sign in to comment.