Skip to content

Commit

Permalink
Render task template in the agent client (flyteorg#384)
Browse files Browse the repository at this point in the history
* Render task template

Signed-off-by: Kevin Su <[email protected]>

* Render task template in the agent client

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Aug 9, 2023
1 parent c977cf8 commit 44ee7d3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
10 changes: 9 additions & 1 deletion flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"k8s.io/apimachinery/pkg/util/rand"
"k8s.io/utils/strings/slices"
)

type MockPlugin struct {
Expand All @@ -39,7 +40,11 @@ type MockPlugin struct {
type MockClient struct {
}

func (m *MockClient) CreateTask(_ context.Context, _ *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) {
func (m *MockClient) CreateTask(_ context.Context, createTaskRequest *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) {
expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"}
if slices.Equal(createTaskRequest.Template.GetContainer().Args, expectedArgs) {
return nil, fmt.Errorf("args not as expected")
}
return &admin.CreateTaskResponse{ResourceMeta: []byte{1, 2, 3, 4}}, nil
}

Expand Down Expand Up @@ -95,6 +100,9 @@ func TestEndToEnd(t *testing.T) {
template := flyteIdlCore.TaskTemplate{
Type: "bigquery_query_job_task",
Custom: st,
Target: &flyteIdlCore.TaskTemplate_Container{
Container: &flyteIdlCore.Container{Args: []string{"pyflyte-fast-execute", "--output-prefix", "{{.outputPrefix}}"}},
},
}
basePrefix := storage.DataReference("fake://bucket/prefix/")

Expand Down
21 changes: 17 additions & 4 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi"
"github.com/flyteorg/flytestdlib/promutils"
Expand Down Expand Up @@ -68,6 +68,19 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
return nil, nil, err
}

if taskTemplate.GetContainer() != nil {
templateParameters := template.Parameters{
TaskExecMetadata: taskCtx.TaskExecutionMetadata(),
Inputs: taskCtx.InputReader(),
OutputPath: taskCtx.OutputWriter(),
Task: taskCtx.TaskReader(),
}
modifiedArgs, err := template.Render(ctx, taskTemplate.GetContainer().Args, templateParameters)
if err != nil {
return nil, nil, err
}
taskTemplate.GetContainer().Args = modifiedArgs
}
outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

agent, err := getFinalAgent(taskTemplate.Type, p.cfg)
Expand Down Expand Up @@ -150,7 +163,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase

switch resource.State {
case admin.State_RUNNING:
return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil
return core.PhaseInfoRunning(core.DefaultPhaseVersion, taskInfo), nil
case admin.State_PERMANENT_FAILURE:
return core.PhaseInfoFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil
case admin.State_RETRYABLE_FAILURE:
Expand All @@ -164,7 +177,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
}
return core.PhaseInfoSuccess(taskInfo), nil
}
return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State)
return core.PhaseInfoUndefined, pluginErrors.Errorf(core.SystemErrorCode, "unknown execution phase [%v].", resource.State)
}

func getFinalAgent(taskType string, cfg *Config) (*Agent, error) {
Expand Down Expand Up @@ -225,7 +238,7 @@ func getClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent
return service.NewAsyncAgentServiceClient(conn), nil
}

func buildTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata) admin.TaskExecutionMetadata {
func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata) admin.TaskExecutionMetadata {
taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID()
return admin.TaskExecutionMetadata{
TaskExecutionId: &taskExecutionID,
Expand Down

0 comments on commit 44ee7d3

Please sign in to comment.