diff --git a/flyteplugins/go/tasks/plugins/hive/execution_state.go b/flyteplugins/go/tasks/plugins/hive/execution_state.go index 978c33893c..b833f37256 100644 --- a/flyteplugins/go/tasks/plugins/hive/execution_state.go +++ b/flyteplugins/go/tasks/plugins/hive/execution_state.go @@ -196,8 +196,12 @@ func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) ( query = hiveJob.Query.GetQuery() cluster = hiveJob.ClusterLabel - tags = hiveJob.Tags timeoutSec = hiveJob.Query.TimeoutSec + tags = hiveJob.Tags + tags = append(tags, fmt.Sprintf("ns:%s", tCtx.TaskExecutionMetadata().GetNamespace())) + for k, v := range tCtx.TaskExecutionMetadata().GetLabels() { + tags = append(tags, fmt.Sprintf("%s:%s", k, v)) + } return } diff --git a/flyteplugins/go/tasks/plugins/hive/execution_state_test.go b/flyteplugins/go/tasks/plugins/hive/execution_state_test.go index d080646fe6..a9abb756da 100644 --- a/flyteplugins/go/tasks/plugins/hive/execution_state_test.go +++ b/flyteplugins/go/tasks/plugins/hive/execution_state_test.go @@ -2,10 +2,11 @@ package hive import ( "context" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" "net/url" "testing" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + mocks2 "github.com/lyft/flytestdlib/cache/mocks" "github.com/stretchr/testify/assert" @@ -13,6 +14,7 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + pluginsCoreMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client" quboleMocks "github.com/lyft/flyteplugins/go/tasks/plugins/hive/client/mocks" "github.com/lyft/flyteplugins/go/tasks/plugins/hive/config" @@ -70,18 +72,23 @@ func TestGetQueryInfo(t *testing.T) { mockTaskExecutionContext := mocks.TaskExecutionContext{} mockTaskExecutionContext.On("TaskReader").Return(mockTaskReader) + taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} + taskMetadata.On("GetNamespace").Return("myproject-staging") + taskMetadata.On("GetLabels").Return(map[string]string{"sample": "label"}) + mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata) + query, cluster, tags, timeout, err := GetQueryInfo(ctx, &mockTaskExecutionContext) assert.NoError(t, err) assert.Equal(t, "select 'one'", query) assert.Equal(t, "default", cluster) - assert.Equal(t, []string{"flyte_plugin_test"}, tags) + assert.Equal(t, []string{"flyte_plugin_test", "ns:myproject-staging", "sample:label"}, tags) assert.Equal(t, 500, int(timeout)) } func TestValidateQuboleHiveJob(t *testing.T) { hiveJob := plugins.QuboleHiveJob{ ClusterLabel: "default", - Tags: []string{"flyte_plugin_test"}, + Tags: []string{"flyte_plugin_test", "sample:label"}, Query: nil, } err := validateQuboleHiveJob(hiveJob)