diff --git a/pkg/controller/nodes/end/handler.go b/pkg/controller/nodes/end/handler.go index a26678286..eebd6eceb 100644 --- a/pkg/controller/nodes/end/handler.go +++ b/pkg/controller/nodes/end/handler.go @@ -41,11 +41,11 @@ func (e endHandler) Handle(ctx context.Context, executionContext handler.NodeExe return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(nil)), nil } -func (e endHandler) Abort(ctx context.Context, executionContext handler.NodeExecutionContext, reason string) error { +func (e endHandler) Abort(_ context.Context, _ handler.NodeExecutionContext, _ string) error { return nil } -func (e endHandler) Finalize(ctx context.Context, executionContext handler.NodeExecutionContext) error { +func (e endHandler) Finalize(_ context.Context, _ handler.NodeExecutionContext) error { return nil } diff --git a/pkg/controller/nodes/end/handler_test.go b/pkg/controller/nodes/end/handler_test.go index 897e08c74..aa47990b6 100644 --- a/pkg/controller/nodes/end/handler_test.go +++ b/pkg/controller/nodes/end/handler_test.go @@ -2,6 +2,7 @@ package end import ( "context" + "fmt" "testing" "github.com/golang/protobuf/proto" @@ -86,6 +87,15 @@ func TestEndHandler_Handle(t *testing.T) { return nCtx } + t.Run("InputReadFailure", func(t *testing.T) { + ir := &mocks2.InputReader{} + ir.OnGetMatch(mock.Anything).Return(nil, fmt.Errorf("err")) + nCtx := &mocks.NodeExecutionContext{} + nCtx.OnInputReader().Return(ir) + _, err := e.Handle(ctx, nCtx) + assert.Error(t, err) + }) + t.Run("NoInputs", func(t *testing.T) { nCtx := createNodeCtx(nil, nil) s, err := e.Handle(ctx, nCtx) @@ -122,3 +132,18 @@ func TestEndHandler_Handle(t *testing.T) { assert.Equal(t, handler.UnknownTransition, s) }) } + +func TestEndHandler_Abort(t *testing.T) { + e := New() + assert.NoError(t, e.Abort(context.TODO(), nil, "")) +} + +func TestEndHandler_Finalize(t *testing.T) { + e := New() + assert.NoError(t, e.Finalize(context.TODO(), nil)) +} + +func TestEndHandler_FinalizeRequired(t *testing.T) { + e := New() + assert.False(t, e.FinalizeRequired()) +} diff --git a/pkg/controller/nodes/handler/transition_info.go b/pkg/controller/nodes/handler/transition_info.go index 88174f6a6..7cc0e9058 100644 --- a/pkg/controller/nodes/handler/transition_info.go +++ b/pkg/controller/nodes/handler/transition_info.go @@ -105,22 +105,6 @@ func (p PhaseInfo) GetReason() string { return p.reason } -func (p *PhaseInfo) SetOcurredAt(t time.Time) { - p.occurredAt = t -} - -func (p *PhaseInfo) SetErr(err *core.ExecutionError) { - p.err = err -} - -func (p *PhaseInfo) SetInfo(info *ExecutionInfo) { - p.info = info -} - -func (p *PhaseInfo) SetReason() string { - return p.reason -} - var PhaseInfoUndefined = PhaseInfo{p: EPhaseUndefined} func phaseInfo(p EPhase, err *core.ExecutionError, info *ExecutionInfo, reason string) PhaseInfo { diff --git a/pkg/controller/nodes/handler/transition_info_test.go b/pkg/controller/nodes/handler/transition_info_test.go index 9b5bb6911..9f4febb82 100644 --- a/pkg/controller/nodes/handler/transition_info_test.go +++ b/pkg/controller/nodes/handler/transition_info_test.go @@ -3,6 +3,7 @@ package handler import ( "testing" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" ) @@ -17,10 +18,13 @@ func TestEPhase_String(t *testing.T) { p EPhase }{ {"queued", EPhaseQueued}, + {"not-ready", EPhaseNotReady}, + {"timedout", EPhaseTimedout}, {"undefined", EPhaseUndefined}, {"success", EPhaseSuccess}, {"skip", EPhaseSkip}, {"failed", EPhaseFailed}, + {"running", EPhaseRunning}, {"retryable-fail", EPhaseRetryableFailure}, } for _, tt := range tests { @@ -31,3 +35,140 @@ func TestEPhase_String(t *testing.T) { }) } } + +func TestEPhase_IsTerminal(t *testing.T) { + tests := []struct { + name string + p EPhase + want bool + }{ + {"success", EPhaseSuccess, true}, + {"failure", EPhaseFailed, true}, + {"timeout", EPhaseTimedout, true}, + {"skip", EPhaseSkip, true}, + {"any", EPhaseQueued, false}, + {"retryable", EPhaseRetryableFailure, false}, + {"run", EPhaseRunning, false}, + {"nr", EPhaseNotReady, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.p.IsTerminal(); got != tt.want { + t.Errorf("IsTerminal() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPhaseInfo(t *testing.T) { + t.Run("undefined", func(t *testing.T) { + assert.Equal(t, EPhaseUndefined, PhaseInfoUndefined.GetPhase()) + }) + + t.Run("success", func(t *testing.T) { + i := &ExecutionInfo{} + p := PhaseInfoSuccess(i) + assert.Equal(t, EPhaseSuccess, p.GetPhase()) + assert.Equal(t, i, p.GetInfo()) + assert.Nil(t, p.GetErr()) + assert.NotNil(t, p.GetOccurredAt()) + }) + + t.Run("not-ready", func(t *testing.T) { + p := PhaseInfoNotReady("reason") + assert.Equal(t, EPhaseNotReady, p.GetPhase()) + assert.Nil(t, p.GetErr()) + assert.NotNil(t, p.GetOccurredAt()) + assert.Equal(t, "reason", p.GetReason()) + }) + + t.Run("queued", func(t *testing.T) { + p := PhaseInfoQueued("reason") + assert.Equal(t, EPhaseQueued, p.GetPhase()) + assert.Nil(t, p.GetErr()) + assert.NotNil(t, p.GetOccurredAt()) + assert.Equal(t, "reason", p.GetReason()) + }) + + t.Run("running", func(t *testing.T) { + i := &ExecutionInfo{} + p := PhaseInfoRunning(i) + assert.Equal(t, EPhaseRunning, p.GetPhase()) + assert.Equal(t, i, p.GetInfo()) + assert.Nil(t, p.GetErr()) + assert.NotNil(t, p.GetOccurredAt()) + }) + + t.Run("skip", func(t *testing.T) { + i := &ExecutionInfo{} + p := PhaseInfoSkip(i, "reason") + assert.Equal(t, EPhaseSkip, p.GetPhase()) + assert.Equal(t, i, p.GetInfo()) + assert.Nil(t, p.GetErr()) + assert.NotNil(t, p.GetOccurredAt()) + assert.Equal(t, "reason", p.GetReason()) + }) + + t.Run("timeout", func(t *testing.T) { + i := &ExecutionInfo{} + p := PhaseInfoTimedOut(i, "reason") + assert.Equal(t, EPhaseTimedout, p.GetPhase()) + assert.Equal(t, i, p.GetInfo()) + assert.Nil(t, p.GetErr()) + assert.NotNil(t, p.GetOccurredAt()) + assert.Equal(t, "reason", p.GetReason()) + }) + + t.Run("failure", func(t *testing.T) { + i := &ExecutionInfo{} + p := PhaseInfoFailure("code", "reason", i) + assert.Equal(t, EPhaseFailed, p.GetPhase()) + assert.Equal(t, i, p.GetInfo()) + if assert.NotNil(t, p.GetErr()) { + assert.Equal(t, "code", p.GetErr().Code) + assert.Equal(t, "reason", p.GetErr().Message) + } + assert.NotNil(t, p.GetOccurredAt()) + }) + + t.Run("failure-err", func(t *testing.T) { + i := &ExecutionInfo{} + e := &core.ExecutionError{} + p := PhaseInfoFailureErr(e, i) + assert.Equal(t, EPhaseFailed, p.GetPhase()) + assert.Equal(t, i, p.GetInfo()) + assert.Equal(t, e, p.GetErr()) + assert.NotNil(t, p.GetOccurredAt()) + }) + + t.Run("failure-err", func(t *testing.T) { + i := &ExecutionInfo{} + p := PhaseInfoFailureErr(nil, i) + assert.Equal(t, EPhaseFailed, p.GetPhase()) + assert.Equal(t, i, p.GetInfo()) + assert.NotNil(t, p.GetErr()) + assert.NotNil(t, p.GetOccurredAt()) + }) + + t.Run("retryable-fail", func(t *testing.T) { + i := &ExecutionInfo{} + p := PhaseInfoRetryableFailure("code", "reason", i) + assert.Equal(t, EPhaseRetryableFailure, p.GetPhase()) + assert.Equal(t, i, p.GetInfo()) + if assert.NotNil(t, p.GetErr()) { + assert.Equal(t, "code", p.GetErr().Code) + assert.Equal(t, "reason", p.GetErr().Message) + } + assert.NotNil(t, p.GetOccurredAt()) + }) + + t.Run("retryable-fail-err", func(t *testing.T) { + i := &ExecutionInfo{} + e := &core.ExecutionError{} + p := PhaseInfoRetryableFailureErr(e, i) + assert.Equal(t, EPhaseRetryableFailure, p.GetPhase()) + assert.Equal(t, i, p.GetInfo()) + assert.Equal(t, e, p.GetErr()) + assert.NotNil(t, p.GetOccurredAt()) + }) +} diff --git a/pkg/controller/nodes/handler/transition_test.go b/pkg/controller/nodes/handler/transition_test.go index 0292cde12..9a7ce1bca 100644 --- a/pkg/controller/nodes/handler/transition_test.go +++ b/pkg/controller/nodes/handler/transition_test.go @@ -23,3 +23,10 @@ func TestDoTransition(t *testing.T) { assert.Equal(t, storage.DataReference("uri"), tr.Info().GetInfo().OutputInfo.OutputURI) }) } + +func TestTransition_WithInfo(t *testing.T) { + tr := DoTransition(TransitionTypeEphemeral, PhaseInfoQueued("queued")) + assert.Equal(t, EPhaseQueued, tr.info.p) + tr = tr.WithInfo(PhaseInfoSuccess(&ExecutionInfo{})) + assert.Equal(t, EPhaseSuccess, tr.info.p) +} diff --git a/pkg/controller/nodes/start/handler_test.go b/pkg/controller/nodes/start/handler_test.go index 73c79d701..574e84188 100644 --- a/pkg/controller/nodes/start/handler_test.go +++ b/pkg/controller/nodes/start/handler_test.go @@ -22,7 +22,7 @@ func TestStartNodeHandler_Initialize(t *testing.T) { assert.NoError(t, h.Setup(context.TODO(), nil)) } -func TestStartNodeHandler_StartNode(t *testing.T) { +func TestStartNodeHandler_Handle(t *testing.T) { ctx := context.Background() h := New() t.Run("Any", func(t *testing.T) { @@ -31,3 +31,18 @@ func TestStartNodeHandler_StartNode(t *testing.T) { assert.Equal(t, handler.EPhaseSuccess, s.Info().GetPhase()) }) } + +func TestEndHandler_Abort(t *testing.T) { + e := New() + assert.NoError(t, e.Abort(context.TODO(), nil, "")) +} + +func TestEndHandler_Finalize(t *testing.T) { + e := New() + assert.NoError(t, e.Finalize(context.TODO(), nil)) +} + +func TestEndHandler_FinalizeRequired(t *testing.T) { + e := New() + assert.False(t, e.FinalizeRequired()) +} diff --git a/pkg/utils/encoder.go b/pkg/utils/encoder.go index 5f49368d8..397fe406b 100644 --- a/pkg/utils/encoder.go +++ b/pkg/utils/encoder.go @@ -19,10 +19,8 @@ func FixedLengthUniqueID(inputID string, maxLength int) (string, error) { } hasher := fnv.New32a() - _, err := hasher.Write([]byte(inputID)) - if err != nil { - return "", err - } + // Using 32a an error can never happen, so this will always remain not covered by a unit test + _, _ = hasher.Write([]byte(inputID)) // #nosec b := hasher.Sum(nil) // expected length after this step is 8 chars (1 + 7 chars from base32Encoder.EncodeToString(b)) finalStr := "f" + base32Encoder.EncodeToString(b) @@ -39,16 +37,12 @@ func FixedLengthUniqueIDForParts(maxLength int, parts ...string) (string, error) b := strings.Builder{} for i, p := range parts { if i > 0 && b.Len() > 0 { - _, err := b.WriteRune('-') - if err != nil { - return "", err - } + // Ignoring the error as it always returns a nil error + _, _ = b.WriteRune('-') // #nosec } - _, err := b.WriteString(p) - if err != nil { - return "", err - } + // Ignoring the error as this is always nil + _, _ = b.WriteString(p) // #nosec } return FixedLengthUniqueID(b.String(), maxLength) diff --git a/pkg/utils/event_helpers.go b/pkg/utils/event_helpers.go deleted file mode 100644 index 61af9f01d..000000000 --- a/pkg/utils/event_helpers.go +++ /dev/null @@ -1,32 +0,0 @@ -package utils - -import ( - "context" - - "github.com/lyft/flyteidl/clients/go/events" - eventsErr "github.com/lyft/flyteidl/clients/go/events/errors" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/event" - "github.com/lyft/flytestdlib/logger" -) - -// Construct task event recorder to pass down to plugin. This is a just a wrapper around the normal -// taskEventRecorder that can encapsulate logic to validate and handle errors. -func NewPluginTaskEventRecorder(taskEventRecorder events.TaskEventRecorder) events.TaskEventRecorder { - return &pluginTaskEventRecorder{ - taskEventRecorder: taskEventRecorder, - } -} - -type pluginTaskEventRecorder struct { - taskEventRecorder events.TaskEventRecorder -} - -func (r pluginTaskEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent) error { - err := r.taskEventRecorder.RecordTaskEvent(ctx, event) - if err != nil && eventsErr.IsAlreadyExists(err) { - logger.Infof(ctx, "Task event phase: %s, taskId %s, retry attempt %d - already exists", - event.Phase.String(), event.GetTaskId(), event.RetryAttempt) - return nil - } - return err -} diff --git a/pkg/utils/failing_datastore_test.go b/pkg/utils/failing_datastore_test.go index 9adb3e885..7c617f40e 100644 --- a/pkg/utils/failing_datastore_test.go +++ b/pkg/utils/failing_datastore_test.go @@ -21,6 +21,7 @@ func TestFailingRawStore(t *testing.T) { _, err = f.ReadRaw(ctx, "") assert.Error(t, err) - err = f.WriteRaw(ctx, "", 0, storage.Options{}, bytes.NewReader(nil)) - assert.Error(t, err) + assert.Error(t, f.WriteRaw(ctx, "", 0, storage.Options{}, bytes.NewReader(nil))) + + assert.Error(t, f.CopyRaw(ctx, "", "", storage.Options{})) } diff --git a/pkg/utils/k8s.go b/pkg/utils/k8s.go index b0bf36122..c5eb32969 100644 --- a/pkg/utils/k8s.go +++ b/pkg/utils/k8s.go @@ -80,17 +80,6 @@ func ToK8sResourceRequirements(resources *core.Resources) (*v1.ResourceRequireme return res, nil } -func GetWorkflowIDFromObject(obj metav1.Object) (v1alpha1.WorkflowID, error) { - controller := metav1.GetControllerOf(obj) - if controller == nil { - return "", NotTheOwnerError - } - if controller.Kind == v1alpha1.FlyteWorkflowKind { - return obj.GetNamespace() + "/" + controller.Name, nil - } - return "", NotTheOwnerError -} - func GetWorkflowIDFromOwner(reference *metav1.OwnerReference, namespace string) (v1alpha1.WorkflowID, error) { if reference == nil { return "", NotTheOwnerError diff --git a/pkg/utils/k8s_test.go b/pkg/utils/k8s_test.go index 806d4dd49..89948230a 100644 --- a/pkg/utils/k8s_test.go +++ b/pkg/utils/k8s_test.go @@ -6,13 +6,12 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/stretchr/testify/assert" - v12 "k8s.io/api/batch/v1" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - v13 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" ) func TestToK8sEnvVar(t *testing.T) { @@ -119,51 +118,6 @@ func TestToK8sResourceRequirements(t *testing.T) { } } -func TestGetWorkflowIDFromObject(t *testing.T) { - { - b := true - j := &v12.Job{ - ObjectMeta: v13.ObjectMeta{ - Namespace: "ns", - OwnerReferences: []v13.OwnerReference{ - { - APIVersion: "test", - Kind: v1alpha1.FlyteWorkflowKind, - Name: "my-id", - UID: "blah", - BlockOwnerDeletion: &b, - Controller: &b, - }, - }, - }, - } - w, err := GetWorkflowIDFromObject(j) - assert.NoError(t, err) - assert.Equal(t, "ns/my-id", w) - } - { - b := true - j := &v12.Job{ - ObjectMeta: v13.ObjectMeta{ - Namespace: "ns", - OwnerReferences: []v13.OwnerReference{ - { - APIVersion: "test", - Kind: "some-other", - Name: "my-id", - UID: "blah", - BlockOwnerDeletion: &b, - Controller: &b, - }, - }, - }, - } - _, err := GetWorkflowIDFromObject(j) - assert.Error(t, err) - } - -} - func TestGetProtoTime(t *testing.T) { assert.NotNil(t, GetProtoTime(nil)) n := time.Now() @@ -192,3 +146,15 @@ func TestGetWorkflowIDFromOwner(t *testing.T) { }) } } + +func TestSanitizeLabelValue(t *testing.T) { + assert.Equal(t, "a-b-c", SanitizeLabelValue("a.b.c")) + assert.Equal(t, "a-b-c", SanitizeLabelValue("a.B.c")) + assert.Equal(t, "a-9-c", SanitizeLabelValue("a.9.c")) + assert.Equal(t, "a-b-c", SanitizeLabelValue("a-b-c")) + assert.Equal(t, "a-b-c", SanitizeLabelValue("a-b-c/")) + assert.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", SanitizeLabelValue("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")) + assert.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", SanitizeLabelValue("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.")) + assert.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", SanitizeLabelValue("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab")) + assert.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", SanitizeLabelValue("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.")) +} diff --git a/pkg/utils/literals.go b/pkg/utils/literals.go index d882a3f14..900cb18f5 100644 --- a/pkg/utils/literals.go +++ b/pkg/utils/literals.go @@ -176,76 +176,77 @@ func MustMakeDefaultLiteralForType(typ *core.LiteralType) *core.Literal { } func MakeDefaultLiteralForType(typ *core.LiteralType) (*core.Literal, error) { - switch t := typ.GetType().(type) { - case *core.LiteralType_Simple: - switch t.Simple { - case core.SimpleType_NONE: - return MakeLiteral(nil) - case core.SimpleType_INTEGER: - return MakeLiteral(int(0)) - case core.SimpleType_FLOAT: - return MakeLiteral(float64(0)) - case core.SimpleType_STRING: - return MakeLiteral("") - case core.SimpleType_BOOLEAN: - return MakeLiteral(false) - case core.SimpleType_DATETIME: - return MakeLiteral(time.Now()) - case core.SimpleType_DURATION: - return MakeLiteral(time.Second) - case core.SimpleType_BINARY: - return MakeLiteral([]byte{}) - //case core.SimpleType_WAITABLE: - //case core.SimpleType_ERROR: - } - return nil, errors.Errorf("Not yet implemented. Default creation is not yet implemented. ") + if typ != nil { + switch t := typ.GetType().(type) { + case *core.LiteralType_Simple: + switch t.Simple { + case core.SimpleType_NONE: + return MakeLiteral(nil) + case core.SimpleType_INTEGER: + return MakeLiteral(int(0)) + case core.SimpleType_FLOAT: + return MakeLiteral(float64(0)) + case core.SimpleType_STRING: + return MakeLiteral("") + case core.SimpleType_BOOLEAN: + return MakeLiteral(false) + case core.SimpleType_DATETIME: + return MakeLiteral(time.Now()) + case core.SimpleType_DURATION: + return MakeLiteral(time.Second) + case core.SimpleType_BINARY: + return MakeLiteral([]byte{}) + // case core.SimpleType_ERROR: + // case core.SimpleType_STRUCT: + } + return nil, errors.Errorf("Not yet implemented. Default creation is not yet implemented. ") - case *core.LiteralType_Blob: - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Blob{ - Blob: &core.Blob{ - Metadata: &core.BlobMetadata{ - Type: t.Blob, + case *core.LiteralType_Blob: + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Blob{ + Blob: &core.Blob{ + Metadata: &core.BlobMetadata{ + Type: t.Blob, + }, + Uri: "/tmp/somepath", }, - Uri: "/tmp/somepath", }, }, }, - }, - }, nil - case *core.LiteralType_CollectionType: - single, err := MakeDefaultLiteralForType(t.CollectionType) - if err != nil { - return nil, err - } + }, nil + case *core.LiteralType_CollectionType: + single, err := MakeDefaultLiteralForType(t.CollectionType) + if err != nil { + return nil, err + } - return &core.Literal{ - Value: &core.Literal_Collection{ - Collection: &core.LiteralCollection{ - Literals: []*core.Literal{single}, + return &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{single}, + }, }, - }, - }, nil - case *core.LiteralType_MapValueType: - single, err := MakeDefaultLiteralForType(t.MapValueType) - if err != nil { - return nil, err - } + }, nil + case *core.LiteralType_MapValueType: + single, err := MakeDefaultLiteralForType(t.MapValueType) + if err != nil { + return nil, err + } - return &core.Literal{ - Value: &core.Literal_Map{ - Map: &core.LiteralMap{ - Literals: map[string]*core.Literal{ - "itemKey": single, + return &core.Literal{ + Value: &core.Literal_Map{ + Map: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "itemKey": single, + }, }, }, - }, - }, nil - //case *core.LiteralType_Schema: + }, nil + // case *core.LiteralType_Schema: + } } - return nil, errors.Errorf("Failed to convert to a known Literal. Input Type [%v] not supported", typ.String()) } diff --git a/pkg/utils/literals_test.go b/pkg/utils/literals_test.go index 05dd11ff7..558ce3ee6 100644 --- a/pkg/utils/literals_test.go +++ b/pkg/utils/literals_test.go @@ -54,6 +54,8 @@ func TestMakePrimitive(t *testing.T) { j, err := ptypes.TimestampProto(v) assert.NoError(t, err) assert.Equal(t, j, p.GetDatetime()) + _, err = MakePrimitive(time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC)) + assert.Error(t, err) } { v := time.Second * 10 @@ -170,31 +172,67 @@ func TestMakeBinaryLiteral(t *testing.T) { } func TestMakeDefaultLiteralForType(t *testing.T) { - - tests := [][]interface{}{ - {"Integer", core.SimpleType_INTEGER, "*core.Primitive_Integer"}, - {"Float", core.SimpleType_FLOAT, "*core.Primitive_FloatValue"}, - {"String", core.SimpleType_STRING, "*core.Primitive_StringValue"}, - {"Boolean", core.SimpleType_BOOLEAN, "*core.Primitive_Boolean"}, - {"Duration", core.SimpleType_DURATION, "*core.Primitive_Duration"}, - {"Datetime", core.SimpleType_DATETIME, "*core.Primitive_Datetime"}, + type args struct { + name string + ty core.SimpleType + tyName string + isPrimitive bool } - - for i := range tests { - name := tests[i][0].(string) - ty := tests[i][1].(core.SimpleType) - tyName := tests[i][2].(string) - - t.Run(name, func(t *testing.T) { - l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{Simple: ty}}) + tests := []args{ + {"None", core.SimpleType_NONE, "*core.Scalar_NoneType", false}, + {"Binary", core.SimpleType_BINARY, "*core.Scalar_Binary", false}, + {"Integer", core.SimpleType_INTEGER, "*core.Primitive_Integer", true}, + {"Float", core.SimpleType_FLOAT, "*core.Primitive_FloatValue", true}, + {"String", core.SimpleType_STRING, "*core.Primitive_StringValue", true}, + {"Boolean", core.SimpleType_BOOLEAN, "*core.Primitive_Boolean", true}, + {"Duration", core.SimpleType_DURATION, "*core.Primitive_Duration", true}, + {"Datetime", core.SimpleType_DATETIME, "*core.Primitive_Datetime", true}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{Simple: test.ty}}) assert.NoError(t, err) - assert.Equal(t, tyName, reflect.TypeOf(l.GetScalar().GetPrimitive().Value).String()) + if test.isPrimitive { + assert.Equal(t, test.tyName, reflect.TypeOf(l.GetScalar().GetPrimitive().Value).String()) + } else { + assert.Equal(t, test.tyName, reflect.TypeOf(l.GetScalar().Value).String()) + } }) } - t.Run("Binary", func(t *testing.T) { - s, err := MakeLiteral([]byte{'h'}) + t.Run("Blob", func(t *testing.T) { + l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Blob{}}) assert.NoError(t, err) - assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) + assert.Equal(t, "*core.Scalar_Blob", reflect.TypeOf(l.GetScalar().Value).String()) + }) + + t.Run("Collection", func(t *testing.T) { + l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_CollectionType{CollectionType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}}) + assert.NoError(t, err) + assert.Equal(t, "*core.LiteralCollection", reflect.TypeOf(l.GetCollection()).String()) + }) + + t.Run("Map", func(t *testing.T) { + l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_MapValueType{MapValueType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}}) + assert.NoError(t, err) + assert.Equal(t, "*core.LiteralMap", reflect.TypeOf(l.GetMap()).String()) + }) + + t.Run("error", func(t *testing.T) { + _, err := MakeDefaultLiteralForType(nil) + assert.Error(t, err) + }) +} + +func TestMustMakeDefaultLiteralForType(t *testing.T) { + t.Run("error", func(t *testing.T) { + assert.Panics(t, func() { + MustMakeDefaultLiteralForType(nil) + }) + }) + + t.Run("Blob", func(t *testing.T) { + l := MustMakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Blob{}}) + assert.Equal(t, "*core.Scalar_Blob", reflect.TypeOf(l.GetScalar().Value).String()) }) }