Skip to content

Commit

Permalink
Add command metadata so we can pass in more stuff for Hive Queries #m…
Browse files Browse the repository at this point in the history
…inor (flyteorg#117)

* Add command metadata so we can pass in more stuff for Hive Queries

* k

* k

* k

* feedback

* k

* Trigger Build
  • Loading branch information
akashkatipally authored Sep 3, 2020
1 parent adfa4b2 commit a1e2aee
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 24 deletions.
18 changes: 9 additions & 9 deletions go/tasks/plugins/hive/client/mocks/qubole_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 12 additions & 3 deletions go/tasks/plugins/hive/client/qubole_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ type QuboleCommandDetails struct {
URI url.URL
}

type CommandMetadata struct {
TaskName string
Domain string
Project string
}

// QuboleClient API Request Body, meant to be passed into JSON.marshal
// Any nil, 0 or "" fields will not be marshaled
type RequestBody struct {
Expand All @@ -62,7 +68,7 @@ type RequestBody struct {
// Interface to interact with QuboleClient for hive tasks
type QuboleClient interface {
ExecuteHiveCommand(ctx context.Context, commandStr string, timeoutVal uint32, clusterPrimaryLabel string,
accountKey string, tags []string) (*QuboleCommandDetails, error)
accountKey string, tags []string, commandMetadata CommandMetadata) (*QuboleCommandDetails, error)
KillCommand(ctx context.Context, commandID string, accountKey string) error
GetCommandStatus(ctx context.Context, commandID string, accountKey string) (QuboleStatus, error)
}
Expand Down Expand Up @@ -124,7 +130,8 @@ func closeBody(ctx context.Context, response *http.Response) {
}

// Helper method to execute the requests
func (q *quboleClient) executeRequest(ctx context.Context, method string, u *url.URL, body *RequestBody, accountKey string) (*http.Response, error) {
func (q *quboleClient) executeRequest(ctx context.Context, method string, u *url.URL,
body *RequestBody, accountKey string) (*http.Response, error) {
var req *http.Request
var err error

Expand Down Expand Up @@ -159,6 +166,7 @@ func (q *quboleClient) executeRequest(ctx context.Context, method string, u *url
param: string commandStr: the query to execute
param: uint32 timeoutVal: timeout for the query to execute in seconds
param: string ClusterLabel: label for cluster on which to execute the Hive Command.
param: CommandMetadata _: additional labels for the command
return: *int64: CommandID for the command executed
return: error: error in-case of a failure
*/
Expand All @@ -168,7 +176,8 @@ func (q *quboleClient) ExecuteHiveCommand(
timeoutVal uint32,
clusterPrimaryLabel string,
accountKey string,
tags []string) (*QuboleCommandDetails, error) {
tags []string,
_ CommandMetadata) (*QuboleCommandDetails, error) {

requestBody := RequestBody{
CommandType: hiveCommandType,
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/hive/client/qubole_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestQuboleClient_GetCommandStatus(t *testing.T) {
func TestQuboleClient_ExecuteHiveCommand(t *testing.T) {
client := createQuboleClient(createCommandResponse)
details, err := client.ExecuteHiveCommand(context.Background(),
"", 0, "clusterLabel", "", nil)
"", 0, "clusterLabel", "", nil, CommandMetadata{})
assert.NoError(t, err)
assert.Equal(t, int64(3850), details.ID)
assert.Equal(t, QuboleStatusWaiting, details.Status)
Expand All @@ -98,7 +98,7 @@ func TestQuboleClient_KillCommand(t *testing.T) {
func TestQuboleClient_ExecuteHiveCommandError(t *testing.T) {
client := createQuboleErrorClient("bad token")
details, err := client.ExecuteHiveCommand(context.Background(),
"", 0, "clusterLabel", "", nil)
"", 0, "clusterLabel", "", nil, CommandMetadata{})
assert.Error(t, err)
assert.Nil(t, details)
}
Expand Down
22 changes: 14 additions & 8 deletions go/tasks/plugins/hive/execution_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func ConstructTaskInfo(e ExecutionState) *core.TaskInfo {
}

func composeResourceNamespaceWithClusterPrimaryLabel(ctx context.Context, tCtx core.TaskExecutionContext) (core.ResourceNamespace, error) {
_, clusterLabelOverride, _, _, err := GetQueryInfo(ctx, tCtx)
_, clusterLabelOverride, _, _, _, err := GetQueryInfo(ctx, tCtx)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -230,26 +230,27 @@ func validateQuboleHiveJob(hiveJob plugins.QuboleHiveJob) error {
// This function is the link between the output written by the SDK, and the execution side. It extracts the query
// out of the task template.
func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) (
query string, cluster string, tags []string, timeoutSec uint32, err error) {
query string, cluster string, tags []string, timeoutSec uint32, taskName string, err error) {

taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
return "", "", []string{}, 0, err
return "", "", []string{}, 0, "", err
}

hiveJob := plugins.QuboleHiveJob{}
err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &hiveJob)
if err != nil {
return "", "", []string{}, 0, err
return "", "", []string{}, 0, "", err
}

if err := validateQuboleHiveJob(hiveJob); err != nil {
return "", "", []string{}, 0, err
return "", "", []string{}, 0, "", err
}

query = hiveJob.Query.GetQuery()
cluster = hiveJob.ClusterLabel
timeoutSec = hiveJob.Query.TimeoutSec
taskName = taskTemplate.Id.Name
tags = hiveJob.Tags
tags = append(tags, fmt.Sprintf("ns:%s", tCtx.TaskExecutionMetadata().GetNamespace()))
for k, v := range tCtx.TaskExecutionMetadata().GetLabels() {
Expand Down Expand Up @@ -334,15 +335,20 @@ func KickOffQuery(ctx context.Context, tCtx core.TaskExecutionContext, currentSt
return currentState, errors.Wrapf(errors.RuntimeFailure, err, "Failed to read token from secrets manager")
}

query, clusterLabelOverride, tags, timeoutSec, err := GetQueryInfo(ctx, tCtx)
query, clusterLabelOverride, tags, timeoutSec, taskName, err := GetQueryInfo(ctx, tCtx)
if err != nil {
return currentState, err
}

clusterPrimaryLabel := getClusterPrimaryLabel(ctx, tCtx, clusterLabelOverride)

taskExecutionIdentifier := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID()
commandMetadata := client.CommandMetadata{TaskName: taskName,
Domain: taskExecutionIdentifier.GetTaskId().GetDomain(),
Project: taskExecutionIdentifier.GetNodeExecutionId().GetExecutionId().GetProject()}

cmdDetails, err := quboleClient.ExecuteHiveCommand(ctx, query, timeoutSec,
clusterPrimaryLabel, apiKey, tags)
clusterPrimaryLabel, apiKey, tags, commandMetadata)
if err != nil {
// If we failed, we'll keep the NotStarted state
currentState.CreationFailureCount = currentState.CreationFailureCount + 1
Expand All @@ -366,7 +372,7 @@ func KickOffQuery(ctx context.Context, tCtx core.TaskExecutionContext, currentSt
if err != nil {
// This means that our cache has fundamentally broken... return a system error
logger.Errorf(ctx, "Cache failed to GetOrCreate for execution [%s] cache key [%s], owner [%s]. Error %s",
tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueID,
taskExecutionIdentifier, uniqueID,
tCtx.TaskExecutionMetadata().GetOwnerReference(), err)
return currentState, err
}
Expand Down
5 changes: 3 additions & 2 deletions go/tasks/plugins/hive/execution_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,13 @@ func TestGetQueryInfo(t *testing.T) {
taskMetadata.On("GetLabels").Return(map[string]string{"sample": "label"})
mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata)

query, cluster, tags, timeout, err := GetQueryInfo(ctx, &mockTaskExecutionContext)
query, cluster, tags, timeout, taskName, err := GetQueryInfo(ctx, &mockTaskExecutionContext)
assert.NoError(t, err)
assert.Equal(t, "select 'one'", query)
assert.Equal(t, "default", cluster)
assert.Equal(t, []string{"flyte_plugin_test", "ns:myproject-staging", "sample:label"}, tags)
assert.Equal(t, 500, int(timeout))
assert.Equal(t, "sample_hive_task_test_name", taskName)
}

func TestValidateQuboleHiveJob(t *testing.T) {
Expand Down Expand Up @@ -327,7 +328,7 @@ func TestKickOffQuery(t *testing.T) {
}
mockQubole := &quboleMocks.QuboleClient{}
mockQubole.OnExecuteHiveCommandMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything,
mock.Anything, mock.Anything).Run(func(_ mock.Arguments) {
mock.Anything, mock.Anything, mock.Anything).Run(func(_ mock.Arguments) {
quboleCalled = true
}).Return(quboleCommandDetails, nil)

Expand Down
3 changes: 3 additions & 0 deletions go/tasks/plugins/hive/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ func GetMockTaskExecutionMetadata() core.TaskExecutionMetadata {

tID := &coreMock.TaskExecutionID{}
tID.On("GetID").Return(idlCore.TaskExecutionIdentifier{
TaskId: &idlCore.Identifier{
Domain: "production",
},
NodeExecutionId: &idlCore.NodeExecutionIdentifier{
ExecutionId: &idlCore.WorkflowExecutionIdentifier{
Name: "my_wf_exec_name",
Expand Down

0 comments on commit a1e2aee

Please sign in to comment.