diff --git a/flyteplugins/go/tasks/plugins/presto/execution_state.go b/flyteplugins/go/tasks/plugins/presto/execution_state.go index 701b25d6a..5268cf449 100644 --- a/flyteplugins/go/tasks/plugins/presto/execution_state.go +++ b/flyteplugins/go/tasks/plugins/presto/execution_state.go @@ -123,10 +123,10 @@ func HandleExecutionState( newState, transformError = MonitorQuery(ctx, tCtx, currentState, executionsCache) case PhaseQuerySucceeded: - if currentState.QueryCount < 1 { + if currentState.QueryCount < 4 { // If there are still Presto statements to execute, increment the query count, reset the phase to 'queued' // and continue executing the remaining statements. In this case, we won't request another allocation token - // as the 2 statements that get executed are all considered to be part of the same "query" + // as the 5 statements that get executed are all considered to be part of the same "query" currentState.PreviousPhase = currentState.CurrentPhase currentState.CurrentPhase = PhaseQueued } else { @@ -304,18 +304,7 @@ func GetNextQuery( user = tCtx.TaskExecutionMetadata().GetNamespace() } - externalLocation, err := tCtx.DataStore().ConstructReference(ctx, tCtx.OutputWriter().GetRawOutputPrefix(), "") - if err != nil { - return Query{}, err - } - - queryWrapTemplate := ` - CREATE TABLE hive.flyte_temporary_tables."%s_temp" - WITH (format = 'PARQUET', external_location = '%s') - AS (%s) - ` - - statement = fmt.Sprintf(queryWrapTemplate, tempTableName, externalLocation, statement) + statement = fmt.Sprintf(`CREATE TABLE hive.flyte_temporary_tables."%s_temp" AS %s`, tempTableName, statement) prestoQuery := Query{ Statement: statement, @@ -328,16 +317,46 @@ func GetNextQuery( }, TempTableName: tempTableName + "_temp", ExternalTableName: tempTableName + "_external", - ExternalLocation: externalLocation.String(), } return prestoQuery, nil case 1: + externalLocation, err := tCtx.DataStore().ConstructReference(ctx, tCtx.OutputWriter().GetRawOutputPrefix(), "") + if err != nil { + return Query{}, err + } + + statement := fmt.Sprintf(` +CREATE TABLE hive.flyte_temporary_tables."%s" (LIKE hive.flyte_temporary_tables."%s") +WITH (format = 'PARQUET', external_location = '%s')`, + currentState.CurrentPrestoQuery.ExternalTableName, + currentState.CurrentPrestoQuery.TempTableName, + externalLocation, + ) + currentState.CurrentPrestoQuery.Statement = statement + currentState.CurrentPrestoQuery.ExternalLocation = externalLocation.String() + return currentState.CurrentPrestoQuery, nil + + case 2: + statement := ` +INSERT INTO hive.flyte_temporary_tables."%s" +SELECT * +FROM hive.flyte_temporary_tables."%s"` + statement = fmt.Sprintf(statement, currentState.CurrentPrestoQuery.ExternalTableName, currentState.CurrentPrestoQuery.TempTableName) + currentState.CurrentPrestoQuery.Statement = statement + return currentState.CurrentPrestoQuery, nil + + case 3: statement := fmt.Sprintf(`DROP TABLE hive.flyte_temporary_tables."%s"`, currentState.CurrentPrestoQuery.TempTableName) currentState.CurrentPrestoQuery.Statement = statement return currentState.CurrentPrestoQuery, nil + case 4: + statement := fmt.Sprintf(`DROP TABLE hive.flyte_temporary_tables."%s"`, currentState.CurrentPrestoQuery.ExternalTableName) + currentState.CurrentPrestoQuery.Statement = statement + return currentState.CurrentPrestoQuery, nil + default: return currentState.CurrentPrestoQuery, nil }