From 36673fd519cc1e3abf3bd5510810618034c736da Mon Sep 17 00:00:00 2001 From: Seena Burns Date: Fri, 13 Aug 2021 08:09:25 -0700 Subject: [PATCH 1/3] Add map, repeated, type conversions --- go/protomodule/BUILD | 3 + go/protomodule/protomodule_enum.go | 4 + go/protomodule/protomodule_list.go | 191 +++++++++++++++++++++ go/protomodule/protomodule_map.go | 173 +++++++++++++++++++ go/protomodule/protomodule_test.go | 259 +++++++++++++++++++++++++++++ go/protomodule/type_conversions.go | 249 +++++++++++++++++++++++++++ 6 files changed, 879 insertions(+) create mode 100644 go/protomodule/protomodule_list.go create mode 100644 go/protomodule/protomodule_map.go create mode 100644 go/protomodule/type_conversions.go diff --git a/go/protomodule/BUILD b/go/protomodule/BUILD index 60dadde..92896b6 100644 --- a/go/protomodule/BUILD +++ b/go/protomodule/BUILD @@ -5,8 +5,11 @@ go_library( srcs = [ "protomodule.go", "protomodule_enum.go", + "protomodule_list.go", + "protomodule_map.go", "protomodule_message_type.go", "protomodule_package.go", + "type_conversions.go", ], importpath = "github.com/stripe/skycfg/go/protomodule", visibility = ["//visibility:public"], diff --git a/go/protomodule/protomodule_enum.go b/go/protomodule/protomodule_enum.go index 0a16125..35171a7 100644 --- a/go/protomodule/protomodule_enum.go +++ b/go/protomodule/protomodule_enum.go @@ -93,6 +93,10 @@ func (v *protoEnumValue) Hash() (uint32, error) { return starlark.MakeInt64(int64(v.value.Number())).Hash() } +func (v *protoEnumValue) enumNumber() protoreflect.EnumNumber { + return v.value.Number() +} + func (v *protoEnumValue) CompareSameType(op syntax.Token, y starlark.Value, depth int) (bool, error) { other := y.(*protoEnumValue) switch op { diff --git a/go/protomodule/protomodule_list.go b/go/protomodule/protomodule_list.go new file mode 100644 index 0000000..83f9cba --- /dev/null +++ b/go/protomodule/protomodule_list.go @@ -0,0 +1,191 @@ +// Copyright 2021 The Skycfg Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package protomodule + +import ( + "fmt" + + "go.starlark.net/starlark" + "go.starlark.net/syntax" + "google.golang.org/protobuf/reflect/protoreflect" +) + +var allowedListMethods = map[string]func(*protoRepeated) starlark.Value{ + "clear": nil, + "append": (*protoRepeated).wrapAppend, + "extend": (*protoRepeated).wrapExtend, +} + +// protoRepeated wraps an underlying starlark.List to provide typechecking on +// wrties +// +// starlark.List is heterogeneous, where protoRepeated enforces all values +// conform to the given fieldDesc +type protoRepeated struct { + fieldDesc protoreflect.FieldDescriptor + list *starlark.List +} + +var _ starlark.Value = (*protoRepeated)(nil) +var _ starlark.Iterable = (*protoRepeated)(nil) +var _ starlark.Sequence = (*protoRepeated)(nil) +var _ starlark.Indexable = (*protoRepeated)(nil) +var _ starlark.HasAttrs = (*protoRepeated)(nil) +var _ starlark.HasSetIndex = (*protoRepeated)(nil) +var _ starlark.HasBinary = (*protoRepeated)(nil) +var _ starlark.Comparable = (*protoRepeated)(nil) + +func newProtoRepeated(fieldDesc protoreflect.FieldDescriptor) *protoRepeated { + return &protoRepeated{fieldDesc, starlark.NewList(nil)} +} + +func newProtoRepeatedFromList(fieldDesc protoreflect.FieldDescriptor, l *starlark.List) (*protoRepeated, error) { + out := &protoRepeated{fieldDesc, l} + for i := 0; i < l.Len(); i++ { + err := scalarTypeCheck(fieldDesc, l.Index(i)) + if err != nil { + return nil, err + } + } + return out, nil +} + +func (r *protoRepeated) Attr(name string) (starlark.Value, error) { + wrapper, ok := allowedListMethods[name] + if !ok { + return nil, nil + } + if wrapper != nil { + return wrapper(r), nil + } + return r.list.Attr(name) +} + +func (r *protoRepeated) AttrNames() []string { return r.list.AttrNames() } +func (r *protoRepeated) Freeze() { r.list.Freeze() } +func (r *protoRepeated) Hash() (uint32, error) { return r.list.Hash() } +func (r *protoRepeated) Index(i int) starlark.Value { return r.list.Index(i) } +func (r *protoRepeated) Iterate() starlark.Iterator { return r.list.Iterate() } +func (r *protoRepeated) Len() int { return r.list.Len() } +func (r *protoRepeated) Slice(x, y, step int) starlark.Value { return r.list.Slice(x, y, step) } +func (r *protoRepeated) String() string { return r.list.String() } +func (r *protoRepeated) Truth() starlark.Bool { return r.list.Truth() } + +func (r *protoRepeated) Type() string { + return fmt.Sprintf("list<%s>", typeName(r.fieldDesc)) +} + +func (r *protoRepeated) CompareSameType(op syntax.Token, y starlark.Value, depth int) (bool, error) { + other, ok := y.(*protoRepeated) + if !ok { + return false, nil + } + + return starlark.CompareDepth(op, r.list, other.list, depth) +} + +func (r *protoRepeated) Append(v starlark.Value) error { + err := scalarTypeCheck(r.fieldDesc, v) + if err != nil { + return err + } + + err = r.list.Append(v) + if err != nil { + return err + } + + return nil +} + +func (r *protoRepeated) SetIndex(i int, v starlark.Value) error { + err := scalarTypeCheck(r.fieldDesc, v) + if err != nil { + return err + } + + r.list.SetIndex(i, v) + if err != nil { + return err + } + + return nil +} + +func (r *protoRepeated) Extend(iterable starlark.Iterable) error { + iter := iterable.Iterate() + defer iter.Done() + + var val starlark.Value + for iter.Next(&val) { + err := r.Append(val) + if err != nil { + return err + } + } + + return nil +} + +func (r *protoRepeated) Binary(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) { + if op == syntax.PLUS { + if side == starlark.Left { + switch y := y.(type) { + case *starlark.List: + return starlark.Binary(op, r.list, y) + case *protoRepeated: + return starlark.Binary(op, r.list, y.list) + } + return nil, nil + } + if side == starlark.Right { + if _, ok := y.(*starlark.List); ok { + return starlark.Binary(op, y, r.list) + } + return nil, nil + } + } + return nil, nil +} + +func (r *protoRepeated) wrapAppend() starlark.Value { + impl := func(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + var val starlark.Value + if err := starlark.UnpackPositionalArgs("append", args, kwargs, 1, &val); err != nil { + return nil, err + } + if err := r.Append(val); err != nil { + return nil, err + } + return starlark.None, nil + } + return starlark.NewBuiltin("append", impl).BindReceiver(r) +} + +func (r *protoRepeated) wrapExtend() starlark.Value { + impl := func(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + var val starlark.Iterable + if err := starlark.UnpackPositionalArgs("extend", args, kwargs, 1, &val); err != nil { + return nil, err + } + if err := r.Extend(val); err != nil { + return nil, err + } + return starlark.None, nil + } + return starlark.NewBuiltin("extend", impl).BindReceiver(r) +} diff --git a/go/protomodule/protomodule_map.go b/go/protomodule/protomodule_map.go new file mode 100644 index 0000000..4818ed2 --- /dev/null +++ b/go/protomodule/protomodule_map.go @@ -0,0 +1,173 @@ +// Copyright 2021 The Skycfg Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package protomodule + +import ( + "fmt" + + "go.starlark.net/starlark" + "go.starlark.net/syntax" + "google.golang.org/protobuf/reflect/protoreflect" +) + +var allowedDictMethods = map[string]func(*protoMap) starlark.Value{ + "clear": nil, + "get": nil, + "items": nil, + "keys": nil, + "setdefault": (*protoMap).wrapSetDefault, + "update": (*protoMap).wrapUpdate, + "values": nil, +} + +// protoMap wraps an underlying starlark.Dict to enforce typechecking +type protoMap struct { + mapKey protoreflect.FieldDescriptor + mapValue protoreflect.FieldDescriptor + dict *starlark.Dict +} + +var _ starlark.Value = (*protoMap)(nil) +var _ starlark.Iterable = (*protoMap)(nil) +var _ starlark.Sequence = (*protoMap)(nil) +var _ starlark.HasAttrs = (*protoMap)(nil) +var _ starlark.HasSetKey = (*protoMap)(nil) +var _ starlark.Comparable = (*protoMap)(nil) + +func newProtoMap(mapKey protoreflect.FieldDescriptor, mapValue protoreflect.FieldDescriptor) *protoMap { + return &protoMap{ + mapKey: mapKey, + mapValue: mapValue, + dict: starlark.NewDict(0), + } +} + +func newProtoMapFromDict(mapKey protoreflect.FieldDescriptor, mapValue protoreflect.FieldDescriptor, d *starlark.Dict) (*protoMap, error) { + out := &protoMap{ + mapKey: mapKey, + mapValue: mapValue, + dict: d, + } + + for _, item := range d.Items() { + err := out.typeCheck(item[0], item[1]) + if err != nil { + return nil, err + } + } + + return out, nil +} + +func (m *protoMap) Attr(name string) (starlark.Value, error) { + wrapper, ok := allowedDictMethods[name] + if !ok { + return nil, nil + } + if wrapper != nil { + return wrapper(m), nil + } + return m.dict.Attr(name) +} + +func (m *protoMap) AttrNames() []string { return m.dict.AttrNames() } +func (m *protoMap) Freeze() { m.dict.Freeze() } +func (m *protoMap) Hash() (uint32, error) { return m.dict.Hash() } +func (m *protoMap) Get(k starlark.Value) (starlark.Value, bool, error) { return m.dict.Get(k) } +func (m *protoMap) Iterate() starlark.Iterator { return m.dict.Iterate() } +func (m *protoMap) Len() int { return m.dict.Len() } +func (m *protoMap) String() string { return m.dict.String() } +func (m *protoMap) Truth() starlark.Bool { return m.dict.Truth() } +func (m *protoMap) Items() []starlark.Tuple { return m.dict.Items() } + +func (m *protoMap) Type() string { + return fmt.Sprintf("map<%s, %s>", typeName(m.mapKey), typeName(m.mapValue)) +} + +func (m *protoMap) CompareSameType(op syntax.Token, y starlark.Value, depth int) (bool, error) { + other, ok := y.(*protoMap) + if !ok { + return false, nil + } + + return starlark.CompareDepth(op, m.dict, other.dict, depth) +} + +func (m *protoMap) wrapSetDefault() starlark.Value { + impl := func(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + var key, defaultValue starlark.Value = nil, starlark.None + if err := starlark.UnpackPositionalArgs("setdefault", args, kwargs, 1, &key, &defaultValue); err != nil { + return nil, err + } + if val, ok, err := m.dict.Get(key); err != nil { + return nil, err + } else if ok { + return val, nil + } + return defaultValue, m.SetKey(key, defaultValue) + } + return starlark.NewBuiltin("setdefault", impl).BindReceiver(m) +} + +func (m *protoMap) wrapUpdate() starlark.Value { + impl := func(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + // Use the underlying starlark `dict.update()` to get a Dict containing + // all the new values, so we don't have to recreate the API here. After + // the temp dict is constructed, type check. + tempDict := &starlark.Dict{} + tempUpdate, _ := tempDict.Attr("update") + if _, err := starlark.Call(thread, tempUpdate, args, kwargs); err != nil { + return nil, err + } + for _, item := range tempDict.Items() { + if err := m.SetKey(item[0], item[1]); err != nil { + return nil, err + } + } + + return starlark.None, nil + } + return starlark.NewBuiltin("update", impl).BindReceiver(m) +} + +func (m *protoMap) SetKey(k, v starlark.Value) error { + err := m.typeCheck(k, v) + if err != nil { + return err + } + + err = m.dict.SetKey(k, v) + if err != nil { + return err + } + + return nil +} + +func (m *protoMap) typeCheck(k, v starlark.Value) error { + err := scalarTypeCheck(m.mapKey, k) + if err != nil { + return err + } + + err = scalarTypeCheck(m.mapValue, v) + if err != nil { + return err + } + + return nil +} diff --git a/go/protomodule/protomodule_test.go b/go/protomodule/protomodule_test.go index a8a57c7..3cf60d2 100644 --- a/go/protomodule/protomodule_test.go +++ b/go/protomodule/protomodule_test.go @@ -23,6 +23,7 @@ import ( "go.starlark.net/resolve" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" + "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" pb "github.com/stripe/skycfg/internal/testdata/test_proto" @@ -200,6 +201,264 @@ func TestEnumType(t *testing.T) { } } +func TestListType(t *testing.T) { + var listFieldDesc protoreflect.FieldDescriptor + msg := (&pb.MessageV3{}).ProtoReflect().Descriptor() + listFieldDesc = msg.Fields().ByName("r_string") + + globals := starlark.StringDict{ + "list": starlark.NewBuiltin("list", func( + t *starlark.Thread, + fn *starlark.Builtin, + args starlark.Tuple, + kwargs []starlark.Tuple, + ) (starlark.Value, error) { + return newProtoRepeated(listFieldDesc), nil + }), + } + + tests := []struct { + expr string + exprFun string + want string + wantErr error + }{ + { + expr: `list()`, + want: `[]`, + }, + { + expr: `dir(list())`, + want: `["append", "clear", "extend", "index", "insert", "pop", "remove"]`, + }, + // List methods + { + exprFun: ` +def fun(): + l = list() + l.append("some string") + return l +`, + want: `["some string"]`, + }, + { + exprFun: ` +def fun(): + l = list() + l.extend(["a", "b"]) + return l +`, + want: `["a", "b"]`, + }, + { + exprFun: ` +def fun(): + l = list() + l.extend(["a", "b"]) + l.clear() + return l +`, + want: `[]`, + }, + { + exprFun: ` +def fun(): + l = list() + l.extend(["a", "b"]) + l[1] = "c" + return l +`, + want: `["a", "c"]`, + }, + + // List typechecking + { + expr: `list().append(1)`, + wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), + }, + { + expr: `list().extend([1,2])`, + wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), + }, + { + exprFun: ` +def fun(): + l = list() + l.extend(["a", "b"]) + l[1] = 1 + return l +`, + wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), + }, + } + for _, test := range tests { + t.Run("", func(t *testing.T) { + var val starlark.Value + var err error + if test.expr != "" { + val, err = starlark.Eval(&starlark.Thread{}, "", test.expr, globals) + } else { + val, err = evalFunc(test.exprFun, globals) + } + + if test.wantErr != nil { + if !checkError(err, test.wantErr) { + t.Fatalf("eval(%q): expected error %v, got %v", test.expr, test.wantErr, err) + } + return + } + if err != nil { + t.Fatalf("eval(%q): %v", test.expr, err) + } + if test.want != val.String() { + t.Errorf("eval(%q): expected value %q, got %q", test.expr, test.want, val.String()) + } + }) + } +} + +func TestMapType(t *testing.T) { + var mapFieldDesc protoreflect.FieldDescriptor + msg := (&pb.MessageV3{}).ProtoReflect().Descriptor() + mapFieldDesc = msg.Fields().ByName("map_string") + + globals := starlark.StringDict{ + "map": starlark.NewBuiltin("map", func( + t *starlark.Thread, + fn *starlark.Builtin, + args starlark.Tuple, + kwargs []starlark.Tuple, + ) (starlark.Value, error) { + return newProtoMap(mapFieldDesc.MapKey(), mapFieldDesc.MapValue()), nil + }), + } + + tests := []struct { + expr string + exprFun string + want string + wantErr error + }{ + { + expr: `map()`, + want: `{}`, + }, + { + expr: `dir(map())`, + want: `["clear", "get", "items", "keys", "pop", "popitem", "setdefault", "update", "values"]`, + }, + // Map methods + { + exprFun: ` +def fun(): + m = map() + m["a"] = "A" + m.setdefault('a', 'Z') + m.setdefault('b', 'Z') + return m +`, + want: `{"a": "A", "b": "Z"}`, + }, + { + exprFun: ` +def fun(): + m = map() + m["a"] = "some string" + return m +`, + want: `{"a": "some string"}`, + }, + { + exprFun: ` +def fun(): + m = map() + m.update([("a", "a_string"), ("b", "b_string")]) + return m +`, + want: `{"a": "a_string", "b": "b_string"}`, + }, + { + exprFun: ` +def fun(): + m = map() + m["a"] = "some string" + m.clear() + return m +`, + want: `{}`, + }, + { + exprFun: ` +def fun(): + l = list() + l.extend(["a", "b"]) + l[1] = "c" + return l +`, + want: `["a", "c"]`, + }, + + // Map typechecking + { + exprFun: ` +def fun(): + m = map() + m["a"] = 1 + return m +`, + wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), + }, + { + expr: `map().update([("a", 1)])`, + wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), + }, + { + expr: `map().setdefault("a", 1)`, + wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), + }, + } + for _, test := range tests { + t.Run("", func(t *testing.T) { + var val starlark.Value + var err error + if test.expr != "" { + val, err = starlark.Eval(&starlark.Thread{}, "", test.expr, globals) + } else { + val, err = evalFunc(test.exprFun, globals) + } + + if test.wantErr != nil { + if !checkError(err, test.wantErr) { + t.Fatalf("eval(%q): expected error %v, got %v", test.expr, test.wantErr, err) + } + return + } + if err != nil { + t.Fatalf("eval(%q): %v", test.expr, err) + } + if test.want != val.String() { + t.Errorf("eval(%q): expected value %q, got %q", test.expr, test.want, val.String()) + } + }) + } +} + +func evalFunc(src string, globals starlark.StringDict) (starlark.Value, error) { + globals, err := starlark.ExecFile(&starlark.Thread{}, "", src, globals) + if err != nil { + return nil, err + } + v, ok := globals["fun"] + if !ok { + return nil, errors.New(`Expected function "fun", not found`) + } + fun, ok := v.(starlark.Callable) + if !ok { + return nil, errors.New("Fun not callable") + } + return starlark.Call(&starlark.Thread{}, fun, nil, nil) +} + func checkError(got, want error) bool { if got == nil { return false diff --git a/go/protomodule/type_conversions.go b/go/protomodule/type_conversions.go new file mode 100644 index 0000000..249c125 --- /dev/null +++ b/go/protomodule/type_conversions.go @@ -0,0 +1,249 @@ +// Copyright 2021 The Skycfg Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// type_conversions.go provides protomodule-to-starlark and +// starlark-to-protomodule conversions +package protomodule + +import ( + "fmt" + "math" + + "go.starlark.net/starlark" + "google.golang.org/protobuf/reflect/protoreflect" +) + +func valueFromStarlark(msg protoreflect.Message, fieldDesc protoreflect.FieldDescriptor, val starlark.Value) (protoreflect.Value, error) { + if fieldDesc.IsList() { + if list, ok := val.(*protoRepeated); ok { + protoListValue := msg.New().NewField(fieldDesc) + protoList := protoListValue.List() + for i := 0; i < list.Len(); i++ { + v, err := scalarValueFromStarlark(fieldDesc, list.Index(i)) + if err != nil { + return protoreflect.Value{}, err + } + protoList.Append(v) + } + + return protoListValue, nil + } + + return protoreflect.Value{}, typeError(fieldDesc, val, false) + } else if fieldDesc.IsMap() { + if mapVal, ok := val.(*protoMap); ok { + protoMapValue := msg.New().NewField(fieldDesc) + protoMap := protoMapValue.Map() + for _, item := range mapVal.Items() { + protoK, err := scalarValueFromStarlark(fieldDesc.MapKey(), item[0]) + if err != nil { + return protoreflect.Value{}, err + } + + protoV, err := scalarValueFromStarlark(fieldDesc.MapValue(), item[1]) + if err != nil { + return protoreflect.Value{}, err + } + + protoMap.Set(protoreflect.MapKey(protoK), protoV) + } + + return protoMapValue, nil + } + return protoreflect.Value{}, typeError(fieldDesc, val, false) + } + + return scalarValueFromStarlark(fieldDesc, val) +} + +func scalarValueFromStarlark(fieldDesc protoreflect.FieldDescriptor, val starlark.Value) (protoreflect.Value, error) { + k := fieldDesc.Kind() + switch k { + case protoreflect.BoolKind: + if val, ok := val.(starlark.Bool); ok { + return protoreflect.ValueOf(bool(val)), nil + } + case protoreflect.StringKind: + if val, ok := val.(starlark.String); ok { + return protoreflect.ValueOf(string(val)), nil + } + case protoreflect.DoubleKind: + if val, ok := starlark.AsFloat(val); ok { + return protoreflect.ValueOf(val), nil + } + case protoreflect.FloatKind: + if val, ok := starlark.AsFloat(val); ok { + return protoreflect.ValueOf(float32(val)), nil + } + case protoreflect.Int64Kind: + if valInt, ok := val.(starlark.Int); ok { + if val, ok := valInt.Int64(); ok { + return protoreflect.ValueOf(val), nil + } + return protoreflect.Value{}, fmt.Errorf("ValueError: value %v overflows type \"int64\".", valInt) + } + case protoreflect.Uint64Kind: + if valInt, ok := val.(starlark.Int); ok { + if val, ok := valInt.Uint64(); ok { + return protoreflect.ValueOf(val), nil + } + return protoreflect.Value{}, fmt.Errorf("ValueError: value %v overflows type \"uint64\".", valInt) + } + case protoreflect.Int32Kind: + if valInt, ok := val.(starlark.Int); ok { + if val, ok := valInt.Int64(); ok && val >= math.MinInt32 && val <= math.MaxInt32 { + return protoreflect.ValueOf(int32(val)), nil + } + return protoreflect.Value{}, fmt.Errorf("ValueError: value %v overflows type \"int32\".", valInt) + } + case protoreflect.Uint32Kind: + if valInt, ok := val.(starlark.Int); ok { + if val, ok := valInt.Uint64(); ok && val <= math.MaxUint32 { + return protoreflect.ValueOf(uint32(val)), nil + } + return protoreflect.Value{}, fmt.Errorf("ValueError: value %v overflows type \"uint32\".", valInt) + } + case protoreflect.MessageKind: + return protoreflect.Value{}, fmt.Errorf("MessageKind: Unimplemented") + case protoreflect.EnumKind: + if enum, ok := val.(*protoEnumValue); ok { + return protoreflect.ValueOf(enum.enumNumber()), nil + } + case protoreflect.BytesKind: + if valString, ok := val.(starlark.String); ok { + return protoreflect.ValueOf([]byte(valString)), nil + } + } + + return protoreflect.Value{}, typeError(fieldDesc, val, true) +} + +// Wrap a protobuf field value as a starlark.Value +func valueToStarlark(val protoreflect.Value, fieldDesc protoreflect.FieldDescriptor) (starlark.Value, error) { + if fieldDesc.IsList() { + if listVal, ok := val.Interface().(protoreflect.List); ok { + out := newProtoRepeated(fieldDesc) + for i := 0; i < listVal.Len(); i++ { + starlarkValue, err := scalarValueToStarlark(listVal.Get(i), fieldDesc) + if err != nil { + return starlark.None, err + } + out.Append(starlarkValue) + } + return out, nil + } else if val.Interface() == nil { + return newProtoRepeated(fieldDesc), nil + } + return starlark.None, fmt.Errorf("TypeError: cannot convert %T into list", val.Interface()) + } else if fieldDesc.IsMap() { + if mapVal, ok := val.Interface().(protoreflect.Map); ok { + out := newProtoMap(fieldDesc.MapKey(), fieldDesc.MapValue()) + var rangeErr error + mapVal.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + starlarkKey, err := scalarValueToStarlark(protoreflect.Value(k), fieldDesc.MapKey()) + if err != nil { + rangeErr = err + return false + } + + starlarkValue, err := scalarValueToStarlark(v, fieldDesc.MapValue()) + if err != nil { + rangeErr = err + return false + } + + out.SetKey(starlarkKey, starlarkValue) + return true + }) + if rangeErr != nil { + return starlark.None, rangeErr + } + + return out, nil + } else if val.Interface() == nil { + return newProtoMap(fieldDesc.MapKey(), fieldDesc.MapValue()), nil + } + return starlark.None, fmt.Errorf("TypeError: cannot convert %T into map", val.Interface()) + } + + return scalarValueToStarlark(val, fieldDesc) +} + +func scalarValueToStarlark(val protoreflect.Value, fieldDesc protoreflect.FieldDescriptor) (starlark.Value, error) { + switch fieldDesc.Kind() { + case protoreflect.BoolKind: + return starlark.Bool(val.Bool()), nil + case protoreflect.Int32Kind: + return starlark.MakeInt64(val.Int()), nil + case protoreflect.Int64Kind: + return starlark.MakeInt64(val.Int()), nil + case protoreflect.Uint32Kind: + return starlark.MakeUint64(val.Uint()), nil + case protoreflect.Uint64Kind: + return starlark.MakeUint64(val.Uint()), nil + case protoreflect.FloatKind: + return starlark.Float(val.Float()), nil + case protoreflect.DoubleKind: + return starlark.Float(val.Float()), nil + case protoreflect.StringKind: + return starlark.String(val.String()), nil + case protoreflect.BytesKind: + // Handle []byte ([]uint8) -> string special case. + return starlark.String(val.Bytes()), nil + case protoreflect.MessageKind: + return nil, fmt.Errorf("MessageKind: Unimplemented") + } + + return starlark.None, fmt.Errorf("valueToStarlark: Value unuspported: %s\n", string(fieldDesc.FullName())) +} + +// Verify v can act as fieldDesc +func scalarTypeCheck(fieldDesc protoreflect.FieldDescriptor, v starlark.Value) error { + _, err := scalarValueFromStarlark(fieldDesc, v) + return err +} + +func typeError(fieldDesc protoreflect.FieldDescriptor, val starlark.Value, scalar bool) error { + expectedType := typeName(fieldDesc) + + // FieldDescriptor has the same typeName for []string and string + // and typeError needs to distinguish setting a []string = int versus + // appending a value in []string + if !scalar { + if fieldDesc.IsList() { + expectedType = fmt.Sprintf("[]%s", typeName(fieldDesc)) + } else if fieldDesc.IsMap() { + expectedType = fmt.Sprintf("map[%s]%s", typeName(fieldDesc.MapKey()), typeName(fieldDesc.MapValue())) + } + } + + return fmt.Errorf("TypeError: value %s (type %q) can't be assigned to type %q.", + val.String(), val.Type(), expectedType, + ) +} + +// Returns a type name for a descriptor, ignoring list/map qualifiers +func typeName(fieldDesc protoreflect.FieldDescriptor) string { + k := fieldDesc.Kind() + switch k { + case protoreflect.EnumKind: + return string(fieldDesc.Enum().FullName()) + case protoreflect.MessageKind: + return string(fieldDesc.Message().FullName()) + default: + return k.String() + } +} From fbadfe7e9b41ffc792cf5fb286621f64ccb60f2f Mon Sep 17 00:00:00 2001 From: Seena Burns Date: Tue, 17 Aug 2021 12:39:34 -0700 Subject: [PATCH 2/3] Correct date, missing err --- go/protomodule/protomodule_list.go | 4 ++-- go/protomodule/protomodule_map.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/protomodule/protomodule_list.go b/go/protomodule/protomodule_list.go index 83f9cba..bb3e263 100644 --- a/go/protomodule/protomodule_list.go +++ b/go/protomodule/protomodule_list.go @@ -1,4 +1,4 @@ -// Copyright 2021 The Skycfg Authors. +// Copyright 2020 The Skycfg Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -118,7 +118,7 @@ func (r *protoRepeated) SetIndex(i int, v starlark.Value) error { return err } - r.list.SetIndex(i, v) + err = r.list.SetIndex(i, v) if err != nil { return err } diff --git a/go/protomodule/protomodule_map.go b/go/protomodule/protomodule_map.go index 4818ed2..891a56e 100644 --- a/go/protomodule/protomodule_map.go +++ b/go/protomodule/protomodule_map.go @@ -1,4 +1,4 @@ -// Copyright 2021 The Skycfg Authors. +// Copyright 2020 The Skycfg Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From 3a578da610668f4a75da756c93a7e76b8b276ab7 Mon Sep 17 00:00:00 2001 From: Seena Burns Date: Tue, 17 Aug 2021 13:48:56 -0700 Subject: [PATCH 3/3] PR feedback --- go/protomodule/protomodule_list.go | 16 ++-------- go/protomodule/protomodule_map.go | 14 ++------- go/protomodule/protomodule_test.go | 50 +++++++++++++++++++++++------- 3 files changed, 44 insertions(+), 36 deletions(-) diff --git a/go/protomodule/protomodule_list.go b/go/protomodule/protomodule_list.go index bb3e263..baf5ea7 100644 --- a/go/protomodule/protomodule_list.go +++ b/go/protomodule/protomodule_list.go @@ -31,7 +31,7 @@ var allowedListMethods = map[string]func(*protoRepeated) starlark.Value{ } // protoRepeated wraps an underlying starlark.List to provide typechecking on -// wrties +// writes // // starlark.List is heterogeneous, where protoRepeated enforces all values // conform to the given fieldDesc @@ -104,12 +104,7 @@ func (r *protoRepeated) Append(v starlark.Value) error { return err } - err = r.list.Append(v) - if err != nil { - return err - } - - return nil + return r.list.Append(v) } func (r *protoRepeated) SetIndex(i int, v starlark.Value) error { @@ -118,12 +113,7 @@ func (r *protoRepeated) SetIndex(i int, v starlark.Value) error { return err } - err = r.list.SetIndex(i, v) - if err != nil { - return err - } - - return nil + return r.list.SetIndex(i, v) } func (r *protoRepeated) Extend(iterable starlark.Iterable) error { diff --git a/go/protomodule/protomodule_map.go b/go/protomodule/protomodule_map.go index 891a56e..6d1cf2c 100644 --- a/go/protomodule/protomodule_map.go +++ b/go/protomodule/protomodule_map.go @@ -150,12 +150,7 @@ func (m *protoMap) SetKey(k, v starlark.Value) error { return err } - err = m.dict.SetKey(k, v) - if err != nil { - return err - } - - return nil + return m.dict.SetKey(k, v) } func (m *protoMap) typeCheck(k, v starlark.Value) error { @@ -164,10 +159,5 @@ func (m *protoMap) typeCheck(k, v starlark.Value) error { return err } - err = scalarTypeCheck(m.mapValue, v) - if err != nil { - return err - } - - return nil + return scalarTypeCheck(m.mapValue, v) } diff --git a/go/protomodule/protomodule_test.go b/go/protomodule/protomodule_test.go index 3cf60d2..2718eb3 100644 --- a/go/protomodule/protomodule_test.go +++ b/go/protomodule/protomodule_test.go @@ -218,21 +218,25 @@ func TestListType(t *testing.T) { } tests := []struct { + name string expr string exprFun string want string wantErr error }{ { + name: "new list", expr: `list()`, want: `[]`, }, { + name: "list AttrNames", expr: `dir(list())`, want: `["append", "clear", "extend", "index", "insert", "pop", "remove"]`, }, // List methods { + name: "list.Append", exprFun: ` def fun(): l = list() @@ -242,6 +246,7 @@ def fun(): want: `["some string"]`, }, { + name: "list.Extend", exprFun: ` def fun(): l = list() @@ -251,6 +256,7 @@ def fun(): want: `["a", "b"]`, }, { + name: "list.Clear", exprFun: ` def fun(): l = list() @@ -261,6 +267,7 @@ def fun(): want: `[]`, }, { + name: "list.SetIndex", exprFun: ` def fun(): l = list() @@ -270,17 +277,33 @@ def fun(): `, want: `["a", "c"]`, }, + { + name: "list binary add operation", + exprFun: ` +def fun(): + l = list() + l2 = list() + l2.extend(["a", "b"]) + l += l2 + l += ["c", "d"] + return l +`, + want: `["a", "b", "c", "d"]`, + }, // List typechecking { + name: "list append typchecks", expr: `list().append(1)`, wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), }, { + name: "list extend typchecks", expr: `list().extend([1,2])`, wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), }, { + name: "list set index typchecks", exprFun: ` def fun(): l = list() @@ -292,7 +315,7 @@ def fun(): }, } for _, test := range tests { - t.Run("", func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { var val starlark.Value var err error if test.expr != "" { @@ -334,21 +357,25 @@ func TestMapType(t *testing.T) { } tests := []struct { + name string expr string exprFun string want string wantErr error }{ { + name: "new map", expr: `map()`, want: `{}`, }, { + name: "map AttrNames", expr: `dir(map())`, want: `["clear", "get", "items", "keys", "pop", "popitem", "setdefault", "update", "values"]`, }, // Map methods { + name: "map.SetDefault", exprFun: ` def fun(): m = map() @@ -360,6 +387,7 @@ def fun(): want: `{"a": "A", "b": "Z"}`, }, { + name: "map.SetKey", exprFun: ` def fun(): m = map() @@ -369,6 +397,7 @@ def fun(): want: `{"a": "some string"}`, }, { + name: "map.Update", exprFun: ` def fun(): m = map() @@ -378,6 +407,7 @@ def fun(): want: `{"a": "a_string", "b": "b_string"}`, }, { + name: "map.Clear", exprFun: ` def fun(): m = map() @@ -387,19 +417,10 @@ def fun(): `, want: `{}`, }, - { - exprFun: ` -def fun(): - l = list() - l.extend(["a", "b"]) - l[1] = "c" - return l -`, - want: `["a", "c"]`, - }, // Map typechecking { + name: "map.SetKey typechecks", exprFun: ` def fun(): m = map() @@ -409,13 +430,20 @@ def fun(): wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), }, { + name: "map.Update typechecks", expr: `map().update([("a", 1)])`, wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), }, { + name: "map.SetDefault typechecks", expr: `map().setdefault("a", 1)`, wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), }, + { + name: "map.SetDefault typechecks key", + expr: `map().setdefault(1, "a")`, + wantErr: errors.New(`TypeError: value 1 (type "int") can't be assigned to type "string".`), + }, } for _, test := range tests { t.Run("", func(t *testing.T) {