Skip to content

Commit

Permalink
remove scalar values and Extension struct for local extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
charithabandi committed Oct 26, 2023
1 parent 48fb093 commit e3eefeb
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 260 deletions.
6 changes: 3 additions & 3 deletions cmd/kwild/server/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import (

// getExtensions returns both the local and remote extensions. Remote extensions are identified by
// connecting to the specified extension URLs.
func getExtensions(ctx context.Context, urls []string) (map[string]extensions.ExtensionDriver, error) {
exts := make(map[string]extensions.ExtensionDriver)
func getExtensions(ctx context.Context, urls []string) (map[string]extActions.Extension, error) {
exts := make(map[string]extActions.Extension)

for name, ext := range extActions.GetRegisteredExtensions() {
_, ok := exts[name]
Expand All @@ -59,7 +59,7 @@ func getExtensions(ctx context.Context, urls []string) (map[string]extensions.Ex
return exts, nil
}

func adaptExtensions(exts map[string]extensions.ExtensionDriver) map[string]engine.ExtensionInitializer {
func adaptExtensions(exts map[string]extActions.Extension) map[string]engine.ExtensionInitializer {
adapted := make(map[string]engine.ExtensionInitializer, len(exts))

for name, ext := range exts {
Expand Down
56 changes: 0 additions & 56 deletions extensions/actions/extension.go

This file was deleted.

17 changes: 13 additions & 4 deletions extensions/actions/extension_registry.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
package extensions

import "strings"
import (
"context"
"strings"
)

var registeredExtensions = make(map[string]*Extension)
type Extension interface {
Name() string
Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error)
Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error)
}

var registeredExtensions = make(map[string]Extension)

func RegisterExtension(name string, ext *Extension) error {
func RegisterExtension(name string, ext Extension) error {
name = strings.ToLower(name)
if _, ok := registeredExtensions[name]; ok {
panic("extension of same name already registered: " + name)
Expand All @@ -14,6 +23,6 @@ func RegisterExtension(name string, ext *Extension) error {
return nil
}

func GetRegisteredExtensions() map[string]*Extension {
func GetRegisteredExtensions() map[string]Extension {
return registeredExtensions
}
100 changes: 45 additions & 55 deletions extensions/actions/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,26 @@ import (
"context"
"fmt"
"math/big"

"github.com/cstockton/go-conv"
)

func init() {
ext, err := NewMathExtension()
if err != nil {
panic(err)
}

err = RegisterExtension("math", ext)
mathExt := &MathExtension{}
err := RegisterExtension("math", mathExt)
if err != nil {
panic(err)
}
}

type MathExtension struct{}

func NewMathExtension() (*Extension, error) {
mathExt := &MathExtension{}
methods := map[string]MethodFunc{
"add": mathExt.add,
"subtract": mathExt.subtract,
"multiply": mathExt.multiply,
"divide": mathExt.divide,
}

ext, err := Builder().Named("math").WithMethods(methods).WithInitializer(initialize).Build()
if err != nil {
return nil, err
}
return ext, nil
}

func (e *MathExtension) Name() string {
return "math"
}

// this initialize function checks if round is set. If not, it sets it to "up"
func initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) {
func (e *MathExtension) Initialize(ctx context.Context, metadata map[string]string) (map[string]string, error) {
_, ok := metadata["round"]
if !ok {
metadata["round"] = "up"
Expand All @@ -57,71 +39,92 @@ func initialize(ctx context.Context, metadata map[string]string) (map[string]str
return metadata, nil
}

func (e *MathExtension) add(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) {
func (e *MathExtension) Execute(ctx context.Context, metadata map[string]string, method string, args ...any) ([]any, error) {
switch method {
case "add":
return e.add(ctx, metadata, args...)
case "subtract":
return e.subtract(ctx, metadata, args...)
case "multiply":
return e.multiply(ctx, metadata, args...)
case "divide":
return e.divide(ctx, metadata, args...)
default:
return nil, fmt.Errorf("method %s not found", method)
}
}

func (e *MathExtension) add(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) {
if len(values) != 2 {
return nil, fmt.Errorf("expected 2 values for method Add, got %d", len(values))
}

val0Int, err := values[0].Int()
val0Int, err := conv.Int(values[0])
if err != nil {
return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int)
}

val1Int, err := values[1].Int()
val1Int, err := conv.Int(values[1])
if err != nil {
return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int)
}

return encodeScalarValues(val0Int + val1Int)
var results []any
results = append(results, val0Int+val1Int)
return results, nil
}

func (e *MathExtension) subtract(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) {
func (e *MathExtension) subtract(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) {
if len(values) != 2 {
return nil, fmt.Errorf("expected 2 values for method Subtract, got %d", len(values))
}

val0Int, err := values[0].Int()
val0Int, err := conv.Int(values[0])
if err != nil {
return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int)
}

val1Int, err := values[1].Int()
val1Int, err := conv.Int(values[1])
if err != nil {
return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int)
}

return encodeScalarValues(val0Int - val1Int)
var results []any
results = append(results, val0Int-val1Int)
return results, nil
}

func (e *MathExtension) multiply(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) {
func (e *MathExtension) multiply(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) {
if len(values) != 2 {
return nil, fmt.Errorf("expected 2 values for method Multiply, got %d", len(values))
}

val0Int, err := values[0].Int()
val0Int, err := conv.Int(values[0])
if err != nil {
return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int)
}

val1Int, err := values[1].Int()
val1Int, err := conv.Int(values[1])
if err != nil {
return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int)
}

return encodeScalarValues(val0Int * val1Int)
var results []any
results = append(results, val0Int*val1Int)
return results, nil
}

func (e *MathExtension) divide(ctx *ExecutionContext, values ...*ScalarValue) ([]*ScalarValue, error) {
func (e *MathExtension) divide(ctx context.Context, metadata map[string]string, values ...any) ([]any, error) {
if len(values) != 2 {
return nil, fmt.Errorf("expected 2 values for method Divide, got %d", len(values))
}

val0Int, err := values[0].Int()
val0Int, err := conv.Int(values[0])
if err != nil {
return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val0Int)
}

val1Int, err := values[1].Int()
val1Int, err := conv.Int(values[1])
if err != nil {
return nil, fmt.Errorf("failed to convert value to int: %w. \nreceived value: %v", err, val1Int)
}
Expand All @@ -133,13 +136,14 @@ func (e *MathExtension) divide(ctx *ExecutionContext, values ...*ScalarValue) ([
result := new(big.Float).Quo(bigVal1, bigVal2)

var IntResult *big.Int
if ctx.Metadata["round"] == "up" {
var results []any
if metadata["round"] == "up" {
IntResult = roundUp(result)
} else {
IntResult = roundDown(result)
}

return encodeScalarValues(IntResult.Int64())
results = append(results, IntResult)
return results, nil
}

// roundUp takes a big.Float and returns a new big.Float rounded up.
Expand All @@ -163,20 +167,6 @@ func roundDown(f *big.Float) *big.Int {
return r
}

func encodeScalarValues(values ...any) ([]*ScalarValue, error) {
scalarValues := make([]*ScalarValue, len(values))
for i, v := range values {
scalarValue, err := NewScalarValue(v)
if err != nil {
return nil, err
}

scalarValues[i] = scalarValue
}

return scalarValues, nil
}

const (
precision = 128
)
Expand Down
Loading

0 comments on commit e3eefeb

Please sign in to comment.