diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/error_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/error_test.go index 4e0968205d..d709fa6803 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/error_test.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/error_test.go @@ -1,7 +1,6 @@ package v1alpha1 import ( - "encoding/json" "testing" "github.com/stretchr/testify/assert" @@ -10,21 +9,28 @@ import ( ) func TestExecutionErrorJSONMarshalling(t *testing.T) { - execError := &core.ExecutionError{ - Code: "TestCode", - Message: "Test error message", - ErrorUri: "Test error uri", + execError := ExecutionError{ + &core.ExecutionError{ + Code: "TestCode", + Message: "Test error message", + ErrorUri: "Test error uri", + }, } - execErr := &ExecutionError{ExecutionError: execError} - data, jErr := json.Marshal(execErr) - assert.Nil(t, jErr) + expected, mockErr := mockMarshalPbToBytes(execError.ExecutionError) + assert.Nil(t, mockErr) - newExecErr := &ExecutionError{} - uErr := json.Unmarshal(data, newExecErr) + // MarshalJSON + execErrorBytes, mErr := execError.MarshalJSON() + assert.Nil(t, mErr) + assert.Equal(t, expected, execErrorBytes) + + // UnmarshalJSON + execErrorObj := &ExecutionError{} + uErr := execErrorObj.UnmarshalJSON(execErrorBytes) assert.Nil(t, uErr) - assert.Equal(t, execError.Code, newExecErr.ExecutionError.Code) - assert.Equal(t, execError.Message, newExecErr.ExecutionError.Message) - assert.Equal(t, execError.ErrorUri, newExecErr.ExecutionError.ErrorUri) + assert.Equal(t, execError.Code, execErrorObj.Code) + assert.Equal(t, execError.Message, execError.Message) + assert.Equal(t, execError.ErrorUri, execErrorObj.ErrorUri) } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate_test.go new file mode 100644 index 0000000000..2fc532d4e4 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate_test.go @@ -0,0 +1,150 @@ +package v1alpha1 + +import ( + "bytes" + "testing" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" +) + +func mockMarshalPbToBytes(msg proto.Message) ([]byte, error) { + var buf bytes.Buffer + jMarshaller := jsonpb.Marshaler{} + if err := jMarshaller.Marshal(&buf, msg); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func TestApproveConditionJSONMarshalling(t *testing.T) { + approveCondition := ApproveCondition{ + &core.ApproveCondition{ + SignalId: "TestSignalId", + }, + } + + expected, mockErr := mockMarshalPbToBytes(approveCondition.ApproveCondition) + assert.Nil(t, mockErr) + + // MarshalJSON + approveConditionBytes, mErr := approveCondition.MarshalJSON() + assert.Nil(t, mErr) + assert.Equal(t, expected, approveConditionBytes) + + // UnmarshalJSON + approveConditionObj := &ApproveCondition{} + uErr := approveConditionObj.UnmarshalJSON(approveConditionBytes) + assert.Nil(t, uErr) + assert.Equal(t, approveCondition.SignalId, approveConditionObj.SignalId) +} + +func TestSignalConditionJSONMarshalling(t *testing.T) { + signalCondition := SignalCondition{ + &core.SignalCondition{ + SignalId: "TestSignalId", + }, + } + + expected, mockErr := mockMarshalPbToBytes(signalCondition.SignalCondition) + assert.Nil(t, mockErr) + + // MarshalJSON + signalConditionBytes, mErr := signalCondition.MarshalJSON() + assert.Nil(t, mErr) + assert.Equal(t, expected, signalConditionBytes) + + // UnmarshalJSON + signalConditionObj := &SignalCondition{} + uErr := signalConditionObj.UnmarshalJSON(signalConditionBytes) + assert.Nil(t, uErr) + assert.Equal(t, signalCondition.SignalId, signalConditionObj.SignalId) +} + +func TestSleepConditionJSONMarshalling(t *testing.T) { + sleepCondition := SleepCondition{ + &core.SleepCondition{ + Duration: &durationpb.Duration{ + Seconds: 10, + Nanos: 10, + }, + }, + } + + expected, mockErr := mockMarshalPbToBytes(sleepCondition.SleepCondition) + assert.Nil(t, mockErr) + + // MarshalJSON + sleepConditionBytes, mErr := sleepCondition.MarshalJSON() + assert.Nil(t, mErr) + assert.Equal(t, expected, sleepConditionBytes) + + // UnmarshalJSON + sleepConditionObj := &SleepCondition{} + uErr := sleepConditionObj.UnmarshalJSON(sleepConditionBytes) + assert.Nil(t, uErr) + assert.Equal(t, sleepCondition.Duration, sleepConditionObj.Duration) +} + +func TestGateNodeSpec_GetKind(t *testing.T) { + kind := ConditionKindApprove + gateNodeSpec := GateNodeSpec{ + Kind: kind, + } + + if gateNodeSpec.GetKind() != kind { + t.Errorf("Expected %s, but got %s", kind, gateNodeSpec.GetKind()) + } +} + +func TestGateNodeSpec_GetApprove(t *testing.T) { + approveCondition := &ApproveCondition{ + &core.ApproveCondition{ + SignalId: "TestSignalId", + }, + } + gateNodeSpec := GateNodeSpec{ + Approve: approveCondition, + } + + if gateNodeSpec.GetApprove() != approveCondition.ApproveCondition { + t.Errorf("Expected approveCondition, but got a different value") + } +} + +func TestGateNodeSpec_GetSignal(t *testing.T) { + signalCondition := &SignalCondition{ + &core.SignalCondition{ + SignalId: "TestSignalId", + }, + } + gateNodeSpec := GateNodeSpec{ + Signal: signalCondition, + } + + if gateNodeSpec.GetSignal() != signalCondition.SignalCondition { + t.Errorf("Expected signalCondition, but got a different value") + } +} + +func TestGateNodeSpec_GetSleep(t *testing.T) { + sleepCondition := &SleepCondition{ + &core.SleepCondition{ + Duration: &durationpb.Duration{ + Seconds: 10, + Nanos: 10, + }, + }, + } + gateNodeSpec := GateNodeSpec{ + Sleep: sleepCondition, + } + + if gateNodeSpec.GetSleep() != sleepCondition.SleepCondition { + t.Errorf("Expected sleepCondition, but got a different value") + } +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier_test.go new file mode 100644 index 0000000000..1267aec09b --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/identifier_test.go @@ -0,0 +1,115 @@ +package v1alpha1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" +) + +func TestIdentifierJSONMarshalling(t *testing.T) { + identifier := Identifier{ + &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "TestProject", + Domain: "TestDomain", + Name: "TestName", + Version: "TestVersion", + }, + } + + expected, mockErr := mockMarshalPbToBytes(identifier.Identifier) + assert.Nil(t, mockErr) + + // MarshalJSON + identifierBytes, mErr := identifier.MarshalJSON() + assert.Nil(t, mErr) + assert.Equal(t, expected, identifierBytes) + + // UnmarshalJSON + identifierObj := &Identifier{} + uErr := identifierObj.UnmarshalJSON(identifierBytes) + assert.Nil(t, uErr) + assert.Equal(t, identifier.Project, identifierObj.Project) + assert.Equal(t, identifier.Domain, identifierObj.Domain) + assert.Equal(t, identifier.Name, identifierObj.Name) + assert.Equal(t, identifier.Version, identifierObj.Version) +} + +func TestIdentifier_DeepCopyInto(t *testing.T) { + identifier := Identifier{ + &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "TestProject", + Domain: "TestDomain", + Name: "TestName", + Version: "TestVersion", + }, + } + + identifierCopy := Identifier{} + identifier.DeepCopyInto(&identifierCopy) + assert.Equal(t, identifier.Project, identifierCopy.Project) + assert.Equal(t, identifier.Domain, identifierCopy.Domain) + assert.Equal(t, identifier.Name, identifierCopy.Name) + assert.Equal(t, identifier.Version, identifierCopy.Version) +} + +func TestWorkflowExecutionIdentifier_DeepCopyInto(t *testing.T) { + weIdentifier := WorkflowExecutionIdentifier{ + &core.WorkflowExecutionIdentifier{ + Project: "TestProject", + Domain: "TestDomain", + Name: "TestName", + Org: "TestOrg", + }, + } + + weIdentifierCopy := WorkflowExecutionIdentifier{} + weIdentifier.DeepCopyInto(&weIdentifierCopy) + assert.Equal(t, weIdentifier.Project, weIdentifierCopy.Project) + assert.Equal(t, weIdentifier.Domain, weIdentifierCopy.Domain) + assert.Equal(t, weIdentifier.Name, weIdentifierCopy.Name) + assert.Equal(t, weIdentifier.Org, weIdentifierCopy.Org) +} + +func TestTaskExecutionIdentifier_DeepCopyInto(t *testing.T) { + teIdentifier := TaskExecutionIdentifier{ + &core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "TestProject", + Domain: "TestDomain", + Name: "TestName", + Version: "TestVersion", + Org: "TestOrg", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "TestProject", + Domain: "TestDomain", + Name: "TestName", + Org: "TestOrg", + }, + NodeId: "TestNodeId", + }, + RetryAttempt: 1, + }, + } + + teIdentifierCopy := TaskExecutionIdentifier{} + teIdentifier.DeepCopyInto(&teIdentifierCopy) + assert.Equal(t, teIdentifier.TaskId.ResourceType, teIdentifierCopy.TaskId.ResourceType) + assert.Equal(t, teIdentifier.TaskId.Project, teIdentifierCopy.TaskId.Project) + assert.Equal(t, teIdentifier.TaskId.Domain, teIdentifierCopy.TaskId.Domain) + assert.Equal(t, teIdentifier.TaskId.Name, teIdentifierCopy.TaskId.Name) + assert.Equal(t, teIdentifier.TaskId.Version, teIdentifierCopy.TaskId.Version) + assert.Equal(t, teIdentifier.TaskId.Org, teIdentifierCopy.TaskId.Org) + assert.Equal(t, teIdentifier.NodeExecutionId.ExecutionId.Project, teIdentifierCopy.NodeExecutionId.ExecutionId.Project) + assert.Equal(t, teIdentifier.NodeExecutionId.ExecutionId.Domain, teIdentifierCopy.NodeExecutionId.ExecutionId.Domain) + assert.Equal(t, teIdentifier.NodeExecutionId.ExecutionId.Name, teIdentifierCopy.NodeExecutionId.ExecutionId.Name) + assert.Equal(t, teIdentifier.NodeExecutionId.ExecutionId.Org, teIdentifierCopy.NodeExecutionId.ExecutionId.Org) + assert.Equal(t, teIdentifier.NodeExecutionId.NodeId, teIdentifierCopy.NodeExecutionId.NodeId) + assert.Equal(t, teIdentifier.RetryAttempt, teIdentifierCopy.RetryAttempt) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register.go index 065b8a8852..8e0d96ada3 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register.go @@ -13,7 +13,7 @@ const FlyteWorkflowKind = "flyteworkflow" // SchemeGroupVersion is group version used to register these objects var SchemeGroupVersion = schema.GroupVersion{Group: flyteworkflow.GroupName, Version: "v1alpha1"} -// GetKind takes an unqualified kind and returns back a Group qualified GroupKind +// Kind takes an unqualified kind and returns back a Group qualified GroupKind func Kind(kind string) schema.GroupKind { return SchemeGroupVersion.WithKind(kind).GroupKind() } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register_test.go new file mode 100644 index 0000000000..f55a596ad8 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/register_test.go @@ -0,0 +1,28 @@ +package v1alpha1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/runtime" +) + +func TestKind(t *testing.T) { + kind := "test kind" + got := Kind(kind) + want := SchemeGroupVersion.WithKind(kind).GroupKind() + assert.Equal(t, got, want) +} + +func TestResource(t *testing.T) { + resource := "test resource" + got := Resource(resource) + want := SchemeGroupVersion.WithResource(resource).GroupResource() + assert.Equal(t, got, want) +} + +func Test_addKnownTypes(t *testing.T) { + scheme := runtime.NewScheme() + err := addKnownTypes(scheme) + assert.Nil(t, err) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/subworkflow_test.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/subworkflow_test.go new file mode 100644 index 0000000000..c0534ef9ea --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/subworkflow_test.go @@ -0,0 +1,36 @@ +package v1alpha1 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" +) + +func TestWorkflowNodeSpec_GetLaunchPlanRefID(t *testing.T) { + wfNodeSpec := &WorkflowNodeSpec{ + LaunchPlanRefID: &LaunchPlanRefID{ + &core.Identifier{ + Project: "TestProject", + }, + }, + } + + nilWfNodeSpec := &WorkflowNodeSpec{} + + assert.Equal(t, wfNodeSpec.GetLaunchPlanRefID(), wfNodeSpec.LaunchPlanRefID) + assert.Empty(t, nilWfNodeSpec.GetLaunchPlanRefID()) +} + +func TestWorkflowNodeSpec_GetSubWorkflowRef(t *testing.T) { + workflowID := "TestWorkflowID" + wfNodeSpec := &WorkflowNodeSpec{ + SubWorkflowReference: &workflowID, + } + + nilWfNodeSpec := &WorkflowNodeSpec{} + + assert.Equal(t, wfNodeSpec.GetSubWorkflowRef(), wfNodeSpec.SubWorkflowReference) + assert.Empty(t, nilWfNodeSpec.GetSubWorkflowRef()) +}