From 8905443daf5f2fc87d42f02d9e80c262c66dfcf6 Mon Sep 17 00:00:00 2001 From: Kadin Rabo Date: Thu, 10 Oct 2024 16:17:56 -0400 Subject: [PATCH 1/7] feat: support custom types in extension yamls --- extensions/simple_extension_test.go | 30 +++++++++++++++++++++++++++++ types/parser/type_parser.go | 3 ++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/extensions/simple_extension_test.go b/extensions/simple_extension_test.go index 7e17206..ec10cf8 100644 --- a/extensions/simple_extension_test.go +++ b/extensions/simple_extension_test.go @@ -41,6 +41,36 @@ types: assert.Nil(t, f.Types[3].Structure) } +func TestUnmarshalCustomScalarFunction(t *testing.T) { + const customDef = ` +scalar_functions: + - name: "scalar1" + impls: + - args: + - name: arg1 + value: u!customtype1 + return: i64 + - name: "scalar2" + impls: + - args: + - name: arg1 + value: i64 + return: u!customtype2? +` + + var f extensions.SimpleExtensionFile + require.NoError(t, yaml.Unmarshal([]byte(customDef), &f)) + + assert.Len(t, f.ScalarFunctions, 2) + assert.Equal(t, "scalar1", f.ScalarFunctions[0].Name) + assert.IsType(t, extensions.ValueArg{}, f.ScalarFunctions[0].Impls[0].Args[0]) + arg1 := f.ScalarFunctions[0].Impls[0].Args[0].(extensions.ValueArg) + assert.Equal(t, "u!customtype1", arg1.Value.String()) + assert.Equal(t, "scalar2", f.ScalarFunctions[1].Name) + assert.IsType(t, extensions.ValueArg{}, f.ScalarFunctions[1].Impls[0].Args[0]) + assert.Equal(t, "u!customtype2?", f.ScalarFunctions[1].Impls[0].Return.String()) +} + func TestUnmarshalSimpleExtensionScalarFunction(t *testing.T) { const addDef = ` scalar_functions: diff --git a/types/parser/type_parser.go b/types/parser/type_parser.go index 8b546cd..3662bac 100644 --- a/types/parser/type_parser.go +++ b/types/parser/type_parser.go @@ -122,7 +122,7 @@ func (t *typename) Capture(values []string) error { } type nonParamType struct { - TypeName typename `parser:"@(IntType | Boolean | FPType | Temporal | BinaryType)"` + TypeName 
typename `parser:"@(IntType | Boolean | FPType | Temporal | BinaryType | UserDefinedType)"` Nullability bool `parser:"@'?'?"` // Variation int `parser:"'[' @\d+ ']'?"` } @@ -598,6 +598,7 @@ var ( {Name: "LengthType", Pattern: `fixedchar|varchar|fixedbinary|precision_timestamp_tz|precision_timestamp|interval_day`}, {Name: "Int", Pattern: `[-+]?\d+`}, {Name: "ParamType", Pattern: `(?i)(struct|list|decimal|map)`}, + {Name: "UserDefinedType", Pattern: `u![a-zA-Z_][a-zA-Z0-9_]*`}, {Name: "Identifier", Pattern: `[a-zA-Z_$][a-zA-Z_$0-9]*`}, {Name: "Ident", Pattern: `([a-zA-Z_]\w*)|[><,?]`}, }) From 37f29fca6f879d88f407d37986fb49b11b42406e Mon Sep 17 00:00:00 2001 From: Kadin Rabo Date: Fri, 11 Oct 2024 14:44:27 -0400 Subject: [PATCH 2/7] feat: update nonParamType.RetType() for UDTs --- extensions/simple_extension_test.go | 23 +++++++++++++++++++++-- types/parser/type_parser.go | 6 +++++- types/types.go | 2 ++ 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/extensions/simple_extension_test.go b/extensions/simple_extension_test.go index ec10cf8..8e52c6d 100644 --- a/extensions/simple_extension_test.go +++ b/extensions/simple_extension_test.go @@ -3,6 +3,9 @@ package extensions_test import ( + "github.com/substrait-io/substrait-go/proto" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parser" "strings" "testing" @@ -60,15 +63,31 @@ scalar_functions: var f extensions.SimpleExtensionFile require.NoError(t, yaml.Unmarshal([]byte(customDef), &f)) - assert.Len(t, f.ScalarFunctions, 2) + assert.Equal(t, "scalar1", f.ScalarFunctions[0].Name) assert.IsType(t, extensions.ValueArg{}, f.ScalarFunctions[0].Impls[0].Args[0]) arg1 := f.ScalarFunctions[0].Impls[0].Args[0].(extensions.ValueArg) assert.Equal(t, "u!customtype1", arg1.Value.String()) + if def, ok := arg1.Value.Expr.(*parser.Type); assert.True(t, ok, "expected *parser.Type") { + if retDef, _ := def.TypeDef.(parser.Def); assert.True(t, ok, "expected parser.Def") { + 
typ, _ := retDef.RetType() + assert.IsType(t, &types.UserDefinedType{}, typ) + assert.Equal(t, proto.Type_NULLABILITY_REQUIRED, typ.GetNullability(), "expected Type_NULLABILITY_REQUIRED") + } + } + assert.Equal(t, "scalar2", f.ScalarFunctions[1].Name) assert.IsType(t, extensions.ValueArg{}, f.ScalarFunctions[1].Impls[0].Args[0]) - assert.Equal(t, "u!customtype2?", f.ScalarFunctions[1].Impls[0].Return.String()) + ret := f.ScalarFunctions[1].Impls[0].Return + assert.Equal(t, "u!customtype2?", ret.String()) + if def, ok := ret.Expr.(*parser.Type); assert.True(t, ok, "expected *parser.Type") { + if retDef, _ := def.TypeDef.(parser.Def); assert.True(t, ok, "expected parser.Def") { + typ, _ := retDef.RetType() + assert.IsType(t, &types.UserDefinedType{}, typ) + assert.Equal(t, proto.Type_NULLABILITY_NULLABLE, typ.GetNullability(), "expected NULLABILITY_NULLABLE") + } + } } func TestUnmarshalSimpleExtensionScalarFunction(t *testing.T) { diff --git a/types/parser/type_parser.go b/types/parser/type_parser.go index 3662bac..cb466ee 100644 --- a/types/parser/type_parser.go +++ b/types/parser/type_parser.go @@ -148,7 +148,11 @@ func (t *nonParamType) RetType() (types.Type, error) { } else { n = types.NullabilityRequired } - typ, err := types.SimpleTypeNameToType(types.TypeName(t.TypeName)) + typName := t.TypeName + if strings.HasPrefix(string(typName), "u!") { + typName = "u!" + } + typ, err := types.SimpleTypeNameToType(types.TypeName(typName)) if err == nil { return typ.WithNullability(n), nil } diff --git a/types/types.go b/types/types.go index 946af59..14a9e08 100644 --- a/types/types.go +++ b/types/types.go @@ -43,6 +43,7 @@ const ( TypeNameIntervalYear TypeName = "interval_year" TypeNameIntervalDay TypeName = "interval_day" TypeNameUUID TypeName = "uuid" + TypeNameUDT TypeName = "u!" 
TypeNameFixedBinary TypeName = "fixedbinary" TypeNameFixedChar TypeName = "fixedchar" @@ -68,6 +69,7 @@ var simpleTypeNameMap = map[TypeName]Type{ TypeNameTimestampTz: &TimestampTzType{}, TypeNameIntervalYear: &IntervalYearType{}, TypeNameUUID: &UUIDType{}, + TypeNameUDT: &UserDefinedType{}, } var fixedTypeNameMap = map[TypeName]FixedType{ From b0f17c962d073ea1288507f91beab56c28f0502e Mon Sep 17 00:00:00 2001 From: Kadin Rabo Date: Tue, 15 Oct 2024 09:56:22 -0400 Subject: [PATCH 3/7] fix: linter --- extensions/simple_extension_test.go | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/extensions/simple_extension_test.go b/extensions/simple_extension_test.go index 8e52c6d..a4688f1 100644 --- a/extensions/simple_extension_test.go +++ b/extensions/simple_extension_test.go @@ -70,11 +70,9 @@ scalar_functions: arg1 := f.ScalarFunctions[0].Impls[0].Args[0].(extensions.ValueArg) assert.Equal(t, "u!customtype1", arg1.Value.String()) if def, ok := arg1.Value.Expr.(*parser.Type); assert.True(t, ok, "expected *parser.Type") { - if retDef, _ := def.TypeDef.(parser.Def); assert.True(t, ok, "expected parser.Def") { - typ, _ := retDef.RetType() - assert.IsType(t, &types.UserDefinedType{}, typ) - assert.Equal(t, proto.Type_NULLABILITY_REQUIRED, typ.GetNullability(), "expected Type_NULLABILITY_REQUIRED") - } + typ, _ := def.TypeDef.RetType() + assert.IsType(t, &types.UserDefinedType{}, typ) + assert.Equal(t, proto.Type_NULLABILITY_REQUIRED, typ.GetNullability(), "expected Type_NULLABILITY_REQUIRED") } assert.Equal(t, "scalar2", f.ScalarFunctions[1].Name) @@ -82,11 +80,9 @@ scalar_functions: ret := f.ScalarFunctions[1].Impls[0].Return assert.Equal(t, "u!customtype2?", ret.String()) if def, ok := ret.Expr.(*parser.Type); assert.True(t, ok, "expected *parser.Type") { - if retDef, _ := def.TypeDef.(parser.Def); assert.True(t, ok, "expected parser.Def") { - typ, _ := retDef.RetType() - assert.IsType(t, &types.UserDefinedType{}, typ) - assert.Equal(t, 
proto.Type_NULLABILITY_NULLABLE, typ.GetNullability(), "expected NULLABILITY_NULLABLE") - } + typ, _ := def.TypeDef.RetType() + assert.IsType(t, &types.UserDefinedType{}, typ) + assert.Equal(t, proto.Type_NULLABILITY_NULLABLE, typ.GetNullability(), "expected NULLABILITY_NULLABLE") } } From 612f6b3675d5f76e862e71fca8752922e41699c9 Mon Sep 17 00:00:00 2001 From: Kadin Rabo Date: Fri, 25 Oct 2024 13:30:16 -0400 Subject: [PATCH 4/7] fix: simpler test --- extensions/simple_extension_test.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/extensions/simple_extension_test.go b/extensions/simple_extension_test.go index a4688f1..0c02db4 100644 --- a/extensions/simple_extension_test.go +++ b/extensions/simple_extension_test.go @@ -69,21 +69,19 @@ scalar_functions: assert.IsType(t, extensions.ValueArg{}, f.ScalarFunctions[0].Impls[0].Args[0]) arg1 := f.ScalarFunctions[0].Impls[0].Args[0].(extensions.ValueArg) assert.Equal(t, "u!customtype1", arg1.Value.String()) - if def, ok := arg1.Value.Expr.(*parser.Type); assert.True(t, ok, "expected *parser.Type") { - typ, _ := def.TypeDef.RetType() - assert.IsType(t, &types.UserDefinedType{}, typ) - assert.Equal(t, proto.Type_NULLABILITY_REQUIRED, typ.GetNullability(), "expected Type_NULLABILITY_REQUIRED") - } + typ, err := arg1.Value.Expr.(*parser.Type).TypeDef.RetType() + assert.NoError(t, err) + assert.IsType(t, &types.UserDefinedType{}, typ) + assert.Equal(t, proto.Type_NULLABILITY_REQUIRED, typ.GetNullability(), "expected NULLABILITY_REQUIRED") assert.Equal(t, "scalar2", f.ScalarFunctions[1].Name) assert.IsType(t, extensions.ValueArg{}, f.ScalarFunctions[1].Impls[0].Args[0]) ret := f.ScalarFunctions[1].Impls[0].Return assert.Equal(t, "u!customtype2?", ret.String()) - if def, ok := ret.Expr.(*parser.Type); assert.True(t, ok, "expected *parser.Type") { - typ, _ := def.TypeDef.RetType() - assert.IsType(t, &types.UserDefinedType{}, typ) - assert.Equal(t, proto.Type_NULLABILITY_NULLABLE, 
typ.GetNullability(), "expected NULLABILITY_NULLABLE") - } + typ, err = ret.Expr.(*parser.Type).TypeDef.RetType() + assert.NoError(t, err) + assert.IsType(t, &types.UserDefinedType{}, typ) + assert.Equal(t, proto.Type_NULLABILITY_NULLABLE, typ.GetNullability(), "expected NULLABILITY_NULLABLE") } func TestUnmarshalSimpleExtensionScalarFunction(t *testing.T) { From 70e29fdb6618bc46ba59c828b7b1a52420cc2f6a Mon Sep 17 00:00:00 2001 From: Kadin Rabo Date: Mon, 28 Oct 2024 17:31:49 -0400 Subject: [PATCH 5/7] feat: add TestParseUDT --- types/parser/type_parser_test.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/types/parser/type_parser_test.go b/types/parser/type_parser_test.go index eb8576e..3ac37b2 100644 --- a/types/parser/type_parser_test.go +++ b/types/parser/type_parser_test.go @@ -3,6 +3,7 @@ package parser_test import ( + "github.com/substrait-io/substrait-go/proto" "reflect" "testing" @@ -100,3 +101,30 @@ func TestParserRetType(t *testing.T) { }) } } + +func TestParseUDT(t *testing.T) { + tests := []struct { + expr string + expected string + expectedTyp types.Type + expectedNullability proto.Type_Nullability + }{ + {"u!customtype1", "u!customtype1", &types.UserDefinedType{}, proto.Type_NULLABILITY_REQUIRED}, + {"u!customtype2?", "u!customtype2?", &types.UserDefinedType{}, proto.Type_NULLABILITY_NULLABLE}, + } + + p, err := parser.New() + require.NoError(t, err) + + for _, td := range tests { + t.Run(td.expr, func(t *testing.T) { + d, err := p.ParseString(td.expr) + assert.NoError(t, err) + assert.Equal(t, td.expected, d.Expr.String()) + retType, err := d.Expr.(*parser.Type).RetType() + assert.NoError(t, err) + assert.Equal(t, td.expectedNullability, retType.GetNullability()) + assert.Equal(t, reflect.TypeOf(td.expectedTyp), reflect.TypeOf(retType)) + }) + } +} From 7a28cb065edd1a07b6483e3c380f1674fddcf587 Mon Sep 17 00:00:00 2001 From: Kadin Rabo Date: Mon, 28 Oct 2024 18:12:03 -0400 Subject: [PATCH 6/7] test: 
TestParserListType --- types/parser/type_parser_test.go | 57 ++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/types/parser/type_parser_test.go b/types/parser/type_parser_test.go index 3ac37b2..97f19ad 100644 --- a/types/parser/type_parser_test.go +++ b/types/parser/type_parser_test.go @@ -102,15 +102,67 @@ func TestParserRetType(t *testing.T) { } } +func TestParserListType(t *testing.T) { + tests := []struct { + expr string + expected string + expectedTyp types.Type + }{ + { + expr: "list<i32>", + expected: "list<i32>", + expectedTyp: &types.ListType{ + Nullability: types.NullabilityRequired, + Type: &types.Int32Type{Nullability: types.NullabilityRequired}, + }, + }, + { + expr: "list?<i16?>", + expected: "list?<i16?>", + expectedTyp: &types.ListType{ + Nullability: types.NullabilityNullable, + Type: &types.Int16Type{Nullability: types.NullabilityNullable}, + }, + }, + { + expr: "list<i16>", + expected: "list<i16>", + expectedTyp: &types.ListType{ + Nullability: types.NullabilityRequired, + Type: &types.Int16Type{Nullability: types.NullabilityRequired}, + }, + }, + } + + p, err := parser.New() + require.NoError(t, err) + + for _, td := range tests { + t.Run(td.expr, func(t *testing.T) { + d, err := p.ParseString(td.expr) + assert.NoError(t, err) + assert.Equal(t, td.expected, d.Expr.String()) + + if tExpr, ok := d.Expr.(*parser.Type); ok { + retType, err := tExpr.RetType() + assert.NoError(t, err) + assert.Equal(t, reflect.TypeOf(td.expectedTyp), reflect.TypeOf(retType)) + assert.Equal(t, td.expectedTyp, retType) + } + }) + } +} + func TestParseUDT(t *testing.T) { tests := []struct { expr string expected string expectedTyp types.Type expectedNullability proto.Type_Nullability + expectedOptional bool }{ - {"u!customtype1", "u!customtype1", &types.UserDefinedType{}, proto.Type_NULLABILITY_REQUIRED}, - {"u!customtype2?", "u!customtype2?", &types.UserDefinedType{}, proto.Type_NULLABILITY_NULLABLE}, + {"u!customtype1", "u!customtype1", 
&types.UserDefinedType{}, proto.Type_NULLABILITY_REQUIRED, false}, + {"u!customtype2?", "u!customtype2?", &types.UserDefinedType{}, proto.Type_NULLABILITY_NULLABLE, true}, } p, err := parser.New() @@ -124,6 +176,7 @@ func TestParseUDT(t *testing.T) { retType, err := d.Expr.(*parser.Type).RetType() assert.NoError(t, err) assert.Equal(t, td.expectedNullability, retType.GetNullability()) + assert.Equal(t, td.expectedOptional, d.Expr.(*parser.Type).Optional()) assert.Equal(t, reflect.TypeOf(td.expectedTyp), reflect.TypeOf(retType)) }) } From dd2fc7d5a5c054e81967fdba77f36e87e52ad424 Mon Sep 17 00:00:00 2001 From: Kadin Rabo Date: Tue, 29 Oct 2024 14:11:04 -0400 Subject: [PATCH 7/7] refactor: gci import order --- extensions/simple_extension_test.go | 6 +++--- types/parser/type_parser_test.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/extensions/simple_extension_test.go b/extensions/simple_extension_test.go index 0c02db4..dd20215 100644 --- a/extensions/simple_extension_test.go +++ b/extensions/simple_extension_test.go @@ -3,9 +3,6 @@ package extensions_test import ( - "github.com/substrait-io/substrait-go/proto" - "github.com/substrait-io/substrait-go/types" - "github.com/substrait-io/substrait-go/types/parser" "strings" "testing" @@ -13,6 +10,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/extensions" + "github.com/substrait-io/substrait-go/proto" + "github.com/substrait-io/substrait-go/types" + "github.com/substrait-io/substrait-go/types/parser" ) func TestUnmarshalSimpleExtension(t *testing.T) { diff --git a/types/parser/type_parser_test.go b/types/parser/type_parser_test.go index 97f19ad..fa7372c 100644 --- a/types/parser/type_parser_test.go +++ b/types/parser/type_parser_test.go @@ -3,12 +3,12 @@ package parser_test import ( - "github.com/substrait-io/substrait-go/proto" "reflect" "testing" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/proto" "github.com/substrait-io/substrait-go/types" "github.com/substrait-io/substrait-go/types/integer_parameters" "github.com/substrait-io/substrait-go/types/parser"