Skip to content

Commit

Permalink
Fix flyteplugins unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario committed Jun 6, 2024
1 parent c3fc152 commit 9a7ceb8
Show file tree
Hide file tree
Showing 13 changed files with 21 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ func TestToK8sContainer(t *testing.T) {
"foo": "bar",
})
mockTaskExecMetadata.OnGetNamespace().Return("my-namespace")
mockTaskExecMetadata.OnGetConsoleURL().Return("")

tCtx := &mocks.TaskExecutionContext{}
tCtx.OnTaskExecutionMetadata().Return(&mockTaskExecMetadata)
Expand Down Expand Up @@ -477,6 +478,7 @@ func getTemplateParametersForTest(resourceRequirements, platformResources *v1.Re
mockTaskExecMetadata.OnGetPlatformResources().Return(platformResources)
mockTaskExecMetadata.OnGetEnvironmentVariables().Return(nil)
mockTaskExecMetadata.OnGetNamespace().Return("my-namespace")
mockTaskExecMetadata.OnGetConsoleURL().Return("")

mockInputReader := mocks2.InputReader{}
mockInputPath := storage.DataReference("s3://input/path")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (

func TestGetExecutionEnvVars(t *testing.T) {
mock := mockTaskExecutionIdentifier{}
envVars := GetExecutionEnvVars(mock)
envVars := GetExecutionEnvVars(mock, "")
assert.Len(t, envVars, 12)
}

Expand Down Expand Up @@ -257,7 +257,7 @@ func TestDecorateEnvVars(t *testing.T) {
defer os.Setenv("value", originalEnvVal)

expected := append(defaultEnv, GetContextEnvVars(ctx)...)
expected = append(expected, GetExecutionEnvVars(mockTaskExecutionIdentifier{})...)
expected = append(expected, GetExecutionEnvVars(mockTaskExecutionIdentifier{}, "")...)

aggregated := append(expected, v12.EnvVar{Name: "k", Value: "v"})
type args struct {
Expand All @@ -270,20 +270,21 @@ func TestDecorateEnvVars(t *testing.T) {
additionEnvVar map[string]string
additionEnvVarFromEnv map[string]string
executionEnvVar map[string]string
consoleURL string
want []v12.EnvVar
}{
{"no-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, emptyEnvVar, expected},
{"with-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, additionalEnv, emptyEnvVar, emptyEnvVar, aggregated},
{"from-env", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, envVarsFromEnv, emptyEnvVar, aggregated},
{"from-execution-metadata", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, additionalEnv, aggregated},
{"no-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, emptyEnvVar, "", expected},
{"with-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, additionalEnv, emptyEnvVar, emptyEnvVar, "", aggregated},
{"from-env", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, envVarsFromEnv, emptyEnvVar, "", aggregated},
{"from-execution-metadata", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, additionalEnv, "", aggregated},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
DefaultEnvVars: tt.additionEnvVar,
DefaultEnvVarsFromEnv: tt.additionEnvVarFromEnv,
}))
if got, _ := DecorateEnvVars(ctx, tt.args.envVars, tt.executionEnvVar, tt.args.id); !reflect.DeepEqual(got, tt.want) {
if got, _ := DecorateEnvVars(ctx, tt.args.envVars, tt.executionEnvVar, tt.args.id, tt.consoleURL); !reflect.DeepEqual(got, tt.want) {
t.Errorf("DecorateEnvVars() = %v, want %v", got, tt.want)
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements, extendedReso
taskExecutionMetadata.On("IsInterruptible").Return(true)
taskExecutionMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskExecutionMetadata.OnGetConsoleURL().Return("")
return taskExecutionMetadata
}

Expand Down
1 change: 1 addition & 0 deletions flyteplugins/go/tasks/plugins/array/k8s/management_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.Ta
tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
tMeta.OnGetInterruptibleFailureThreshold().Return(2)
tMeta.OnGetEnvironmentVariables().Return(nil)
tMeta.OnGetConsoleURL().Return("")

ow := &mocks2.OutputWriter{}
ow.OnGetOutputPrefixPath().Return("/prefix/")
Expand Down
1 change: 1 addition & 0 deletions flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskExecutionMetadata.OnGetK8sServiceAccount().Return(defaultServiceAccountName)
taskExecutionMetadata.OnGetNamespace().Return(defaultNamespace)
taskExecutionMetadata.OnGetConsoleURL().Return("")
overrides := &mocks.TaskOverrides{}
overrides.OnGetResources().Return(resources)
overrides.OnGetExtendedResources().Return(extendedResources)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso
taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount)
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskExecutionMetadata.OnGetConsoleURL().Return("")
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.
taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount)
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskExecutionMetadata.OnGetConsoleURL().Return("")
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *core
taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount)
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskExecutionMetadata.OnGetConsoleURL().Return("")
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}
Expand Down
1 change: 1 addition & 0 deletions flyteplugins/go/tasks/plugins/k8s/pod/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements, extendedReso
taskMetadata.On("GetOverrides").Return(to)
taskMetadata.On("IsInterruptible").Return(true)
taskMetadata.On("GetEnvironmentVariables").Return(nil)
taskMetadata.OnGetConsoleURL().Return("")
return taskMetadata
}

Expand Down
1 change: 1 addition & 0 deletions flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func dummySidecarTaskMetadata(resources *v1.ResourceRequirements, extendedResour
to.On("GetContainerImage").Return("")
taskMetadata.On("GetOverrides").Return(to)
taskMetadata.On("GetEnvironmentVariables").Return(nil)
taskMetadata.On("GetConsoleURL").Return("")

return taskMetadata
}
Expand Down
1 change: 1 addition & 0 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso
RunAs: &core.Identity{K8SServiceAccount: serviceAccount},
})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskExecutionMetadata.OnGetConsoleURL().Return("")
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}
Expand Down
1 change: 1 addition & 0 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool)
taskExecutionMetadata.On("GetPlatformResources").Return(nil)
taskExecutionMetadata.On("GetOverrides").Return(overrides)
taskExecutionMetadata.On("GetK8sServiceAccount").Return("new-val")
taskExecutionMetadata.On("GetConsoleURL").Return("")
taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata)
return taskCtx
}
Expand Down
1 change: 1 addition & 0 deletions flyteplugins/tests/end_to_end.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i
tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
tMeta.OnGetInterruptibleFailureThreshold().Return(2)
tMeta.OnGetEnvironmentVariables().Return(nil)
tMeta.OnGetConsoleURL().Return("")

catClient := &catalogMocks.Client{}
catData := sync.Map{}
Expand Down

0 comments on commit 9a7ceb8

Please sign in to comment.