From ce3c0434121c644762c748cf96e9a8588614289e Mon Sep 17 00:00:00 2001 From: ish Date: Fri, 15 Nov 2024 16:07:44 +0900 Subject: [PATCH] airflow, api: rest: Check if Airflow client is initialized --- lib/airflow/conneciton.go | 6 +-- lib/airflow/init.go | 23 +++++--- lib/airflow/workflow.go | 52 +++++++++--------- pkg/api/rest/controller/workflow.go | 83 ++++++++++++++++++++++++----- 4 files changed, 114 insertions(+), 50 deletions(-) diff --git a/lib/airflow/conneciton.go b/lib/airflow/conneciton.go index 5a040f1..eb86162 100644 --- a/lib/airflow/conneciton.go +++ b/lib/airflow/conneciton.go @@ -8,7 +8,7 @@ import ( "github.com/jollaman999/utils/logger" ) -func (client *client) RegisterConnection(connection *model.Connection) error { +func (client *Client) RegisterConnection(connection *model.Connection) error { ctx, cancel := Context() defer cancel() @@ -42,9 +42,9 @@ func (client *client) RegisterConnection(connection *model.Connection) error { Extra: extra, } - _, _ = client.api.ConnectionApi.DeleteConnection(ctx, connection.ID).Execute() + _, _ = client.ConnectionApi.DeleteConnection(ctx, connection.ID).Execute() - _, _, err := client.api.ConnectionApi.PostConnection(ctx).Connection(conn).Execute() + _, _, err := client.ConnectionApi.PostConnection(ctx).Connection(conn).Execute() if err != nil { errMsg := "AIRFLOW: Error occurred while registering connection. (ConnID: " + connection.ID + ", Error: " + err.Error() + ")." logger.Println(logger.ERROR, false, errMsg) diff --git a/lib/airflow/init.go b/lib/airflow/init.go index 7544d09..f6c4260 100644 --- a/lib/airflow/init.go +++ b/lib/airflow/init.go @@ -3,6 +3,7 @@ package airflow import ( "context" "crypto/tls" + "errors" "github.com/apache/airflow-client-go/airflow" "github.com/cloud-barista/cm-cicada/lib/config" "github.com/jollaman999/utils/logger" @@ -11,11 +12,19 @@ import ( "time" ) -type client struct { - api *airflow.APIClient +type Client struct { + *airflow.APIClient } -var Client *client +var airflowClient *Client + +func GetClient() (*Client, error) { + if airflowClient == nil { + return nil, errors.New("airflow client not initialized") + } + + return airflowClient, nil +} func ping(url string) error { timeout, _ := strconv.Atoi(config.CMCicadaConfig.CMCicada.AirflowServer.Timeout) @@ -55,7 +64,7 @@ func checkPing(url string) { func registerConnections() { for _, connection := range config.CMCicadaConfig.CMCicada.AirflowServer.Connections { logger.Println(logger.INFO, false, "Registering connection: ", connection) - err := Client.RegisterConnection(&connection) + err := airflowClient.RegisterConnection(&connection) if err != nil { logger.Println(logger.ERROR, false, err.Error()) } @@ -93,11 +102,11 @@ func Init() { checkPing(conf.Scheme + "://" + conf.Host) cli := airflow.NewAPIClient(conf) - conn := client{ - api: cli, + conn := Client{ + cli, } - Client = &conn + airflowClient = &conn registerConnections() } diff --git a/lib/airflow/workflow.go b/lib/airflow/workflow.go index a6fb583..1a73877 100644 --- a/lib/airflow/workflow.go +++ b/lib/airflow/workflow.go @@ -38,7 +38,7 @@ func callDagRequestLock(workflowID string) func() { } } -func (client *client) CreateDAG(workflow *model.Workflow) error { +func (client *Client) CreateDAG(workflow *model.Workflow) error { deferFunc := callDagRequestLock(workflow.ID) defer func() { deferFunc() @@ -52,7 +52,7 @@ func (client *client) CreateDAG(workflow *model.Workflow) error { return nil } -func (client *client) GetDAG(dagID string) (airflow.DAG, error) { +func (client *Client) GetDAG(dagID string) (airflow.DAG, error) { deferFunc := callDagRequestLock(dagID) defer func() { deferFunc() @@ -60,7 +60,7 @@ func (client *client) GetDAG(dagID string) (airflow.DAG, error) { ctx, cancel := Context() defer cancel() - resp, _, err := client.api.DAGApi.GetDag(ctx, dagID).Execute() + resp, _, err := client.DAGApi.GetDag(ctx, dagID).Execute() if err != nil { logger.Println(logger.ERROR, false, "AIRFLOW: Error occurred while getting the DAG. (Error: "+err.Error()+").") @@ -69,10 +69,10 @@ func (client *client) GetDAG(dagID string) (airflow.DAG, error) { return resp, err } -func (client *client) GetDAGs() (airflow.DAGCollection, error) { +func (client *Client) GetDAGs() (airflow.DAGCollection, error) { ctx, cancel := Context() defer cancel() - resp, _, err := client.api.DAGApi.GetDags(ctx).Execute() + resp, _, err := client.DAGApi.GetDags(ctx).Execute() if err != nil { logger.Println(logger.ERROR, false, "AIRFLOW: Error occurred while getting DAGs. (Error: "+err.Error()+").") @@ -80,7 +80,7 @@ func (client *client) GetDAGs() (airflow.DAGCollection, error) { return resp, err } -func (client *client) RunDAG(dagID string) (airflow.DAGRun, error) { +func (client *Client) RunDAG(dagID string) (airflow.DAGRun, error) { deferFunc := callDagRequestLock(dagID) defer func() { deferFunc() @@ -108,7 +108,7 @@ func (client *client) RunDAG(dagID string) (airflow.DAGRun, error) { ctx, cancel := Context() defer cancel() - resp, _, err := client.api.DAGRunApi.PostDagRun(ctx, dagID).DAGRun(*airflow.NewDAGRun()).Execute() + resp, _, err := client.DAGRunApi.PostDagRun(ctx, dagID).DAGRun(*airflow.NewDAGRun()).Execute() if err != nil { errMsg := "AIRFLOW: Error occurred while running the DAG. (DAG ID: " + dagID + ", Error: " + err.Error() + ")" logger.Println(logger.ERROR, false, errMsg) @@ -121,7 +121,7 @@ func (client *client) RunDAG(dagID string) (airflow.DAGRun, error) { return resp, err } -func (client *client) DeleteDAG(dagID string, deleteFolderOnly bool) error { +func (client *Client) DeleteDAG(dagID string, deleteFolderOnly bool) error { deferFunc := callDagRequestLock(dagID) defer func() { deferFunc() @@ -137,7 +137,7 @@ func (client *client) DeleteDAG(dagID string, deleteFolderOnly bool) error { if !deleteFolderOnly { ctx, cancel := Context() defer cancel() - _, err = client.api.DAGApi.DeleteDag(ctx, dagID).Execute() + _, err = client.DAGApi.DeleteDag(ctx, dagID).Execute() if err != nil { logger.Println(logger.ERROR, false, "AIRFLOW: Error occurred while deleting the DAG. (Error: "+err.Error()+").") @@ -146,14 +146,14 @@ func (client *client) DeleteDAG(dagID string, deleteFolderOnly bool) error { return err } -func (client *client) GetDAGRuns(dagID string) (airflow.DAGRunCollection, error) { +func (client *Client) GetDAGRuns(dagID string) (airflow.DAGRunCollection, error) { deferFunc := callDagRequestLock(dagID) defer func() { deferFunc() }() ctx, cancel := Context() defer cancel() - resp, _, err := client.api.DAGRunApi.GetDagRuns(ctx, dagID).Execute() + resp, _, err := client.DAGRunApi.GetDagRuns(ctx, dagID).Execute() if err != nil { logger.Println(logger.ERROR, false, "AIRFLOW: Error occurred while getting DAGRuns. (Error: "+err.Error()+").") @@ -161,14 +161,14 @@ func (client *client) GetDAGRuns(dagID string) (airflow.DAGRunCollection, error) return resp, err } -func (client *client) GetTaskInstances(dagID string, dagRunId string) (airflow.TaskInstanceCollection, error) { +func (client *Client) GetTaskInstances(dagID string, dagRunId string) (airflow.TaskInstanceCollection, error) { deferFunc := callDagRequestLock(dagID) defer func() { deferFunc() }() ctx, cancel := Context() defer cancel() - resp, _, err := client.api.TaskInstanceApi.GetTaskInstances(ctx, dagID, dagRunId).Execute() + resp, _, err := client.TaskInstanceApi.GetTaskInstances(ctx, dagID, dagRunId).Execute() if err != nil { logger.Println(logger.ERROR, false, "AIRFLOW: Error occurred while getting TaskInstances. (Error: "+err.Error()+").") @@ -176,7 +176,7 @@ func (client *client) GetTaskInstances(dagID string, dagRunId string) (airflow.T return resp, err } -func (client *client) GetTaskLogs(dagID, dagRunID, taskID string, taskTryNumber int) (airflow.InlineResponse200, error) { +func (client *Client) GetTaskLogs(dagID, dagRunID, taskID string, taskTryNumber int) (airflow.InlineResponse200, error) { deferFunc := callDagRequestLock(dagID) defer func() { deferFunc() @@ -185,7 +185,7 @@ func (client *client) GetTaskLogs(dagID, dagRunID, taskID string, taskTryNumber defer cancel() // TaskInstanceApi 인스턴스를 사용하여 로그 요청 - logs, _, err := client.api.TaskInstanceApi.GetLog(ctx, dagID, dagRunID, taskID, int32(taskTryNumber)).FullContent(true).Execute() + logs, _, err := client.TaskInstanceApi.GetLog(ctx, dagID, dagRunID, taskID, int32(taskTryNumber)).FullContent(true).Execute() logger.Println(logger.INFO, false, logs) if err != nil { logger.Println(logger.ERROR, false, @@ -195,7 +195,7 @@ func (client *client) GetTaskLogs(dagID, dagRunID, taskID string, taskTryNumber return logs, nil } -func (client *client) ClearTaskInstance(dagID string, dagRunID string, taskID string) (airflow.TaskInstanceReferenceCollection, error) { +func (client *Client) ClearTaskInstance(dagID string, dagRunID string, taskID string) (airflow.TaskInstanceReferenceCollection, error) { deferFunc := callDagRequestLock(dagID) defer func() { deferFunc() @@ -231,13 +231,13 @@ func (client *client) ClearTaskInstance(dagID string, dagRunID string, taskID st } // 요청 생성 - request := client.api.DAGApi.PostClearTaskInstances(ctx, dagID) + request := client.DAGApi.PostClearTaskInstances(ctx, dagID) // ClearTaskInstances 데이터 설정 request = request.ClearTaskInstances(clearTask) // 요청 실행 - logs, _, err := client.api.DAGApi.PostClearTaskInstancesExecute(request) + logs, _, err := client.DAGApi.PostClearTaskInstancesExecute(request) if err != nil { logger.Println(logger.ERROR, false, "AIRFLOW: Error occurred while clearing TaskInstance. (Error: "+err.Error()+").") @@ -246,7 +246,7 @@ func (client *client) ClearTaskInstance(dagID string, dagRunID string, taskID st return logs, nil } -func (client *client) GetEventLogs(dagID string, dagRunId string, taskId string) ([]byte, error) { +func (client *Client) GetEventLogs(dagID string, dagRunId string, taskId string) ([]byte, error) { deferFunc := callDagRequestLock(dagID) defer func() { deferFunc() @@ -254,12 +254,12 @@ func (client *client) GetEventLogs(dagID string, dagRunId string, taskId string) ctx, cancel := Context() defer cancel() - localBasePath, err := client.api.GetConfig().ServerURLWithContext(ctx, "EventLogApiService.GetEventLog") + localBasePath, err := client.GetConfig().ServerURLWithContext(ctx, "EventLogApiService.GetEventLog") if err != nil { fmt.Println("Error occurred while getting event logs:", err) } - baseURL := "http://" + client.api.GetConfig().Host + localBasePath + "/eventLogs" + baseURL := "http://" + client.GetConfig().Host + localBasePath + "/eventLogs" queryParams := map[string]string{ "offset": "0", "limit": "100", @@ -275,7 +275,7 @@ func (client *client) GetEventLogs(dagID string, dagRunId string, taskId string) } queryString := query.Encode() fullURL := fmt.Sprintf("%s?%s", baseURL, queryString) - httpclient := client.api.GetConfig().HTTPClient + httpclient := client.GetConfig().HTTPClient // 요청 생성 req, err := http.NewRequest("GET", fullURL, nil) @@ -299,12 +299,12 @@ func (client *client) GetEventLogs(dagID string, dagRunId string, taskId string) return body, err } -func (client *client) GetImportErrors() (airflow.ImportErrorCollection, error) { +func (client *Client) GetImportErrors() (airflow.ImportErrorCollection, error) { ctx, cancel := Context() defer cancel() // TaskInstanceApi 인스턴스를 사용하여 로그 요청 - logs, _, err := client.api.ImportErrorApi.GetImportErrors(ctx).Execute() + logs, _, err := client.ImportErrorApi.GetImportErrors(ctx).Execute() logger.Println(logger.INFO, false, logs) if err != nil { logger.Println(logger.ERROR, false, @@ -314,12 +314,12 @@ func (client *client) GetImportErrors() (airflow.ImportErrorCollection, error) { return logs, nil } -func (client *client) PatchDag(dagID string, dagBody airflow.DAG) (airflow.DAG, error) { +func (client *Client) PatchDag(dagID string, dagBody airflow.DAG) (airflow.DAG, error) { ctx, cancel := Context() defer cancel() // TaskInstanceApi 인스턴스를 사용하여 로그 요청 - logs, _, err := client.api.DAGApi.PatchDag(ctx, dagID).DAG(dagBody).Execute() + logs, _, err := client.DAGApi.PatchDag(ctx, dagID).DAG(dagBody).Execute() logger.Println(logger.INFO, false, logs) if err != nil { logger.Println(logger.ERROR, false, diff --git a/pkg/api/rest/controller/workflow.go b/pkg/api/rest/controller/workflow.go index e6fa2ce..d44a2b9 100644 --- a/pkg/api/rest/controller/workflow.go +++ b/pkg/api/rest/controller/workflow.go @@ -186,7 +186,12 @@ func CreateWorkflow(c echo.Context) error { } } - err = airflow.Client.CreateDAG(&workflow) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + + err = client.CreateDAG(&workflow) if err != nil { return common.ReturnErrorMsg(c, "Failed to create the workflow. (Error:"+err.Error()+")") } @@ -259,7 +264,12 @@ func GetWorkflow(c echo.Context) error { } } - _, err = airflow.Client.GetDAG(wfId) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + + _, err = client.GetDAG(wfId) if err != nil { return common.ReturnErrorMsg(c, "Failed to get the workflow from the airflow server.") } @@ -309,7 +319,12 @@ func GetWorkflowByName(c echo.Context) error { } } - _, err = airflow.Client.GetDAG(workflow.ID) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + + _, err = client.GetDAG(workflow.ID) if err != nil { return common.ReturnErrorMsg(c, "Failed to get the workflow from the airflow server.") } @@ -394,7 +409,12 @@ func RunWorkflow(c echo.Context) error { return common.ReturnErrorMsg(c, err.Error()) } - _, err = airflow.Client.RunDAG(workflow.ID) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + + _, err = client.RunDAG(workflow.ID) if err != nil { return common.ReturnInternalError(c, err, "Failed to run the workflow.") } @@ -513,12 +533,17 @@ func UpdateWorkflow(c echo.Context) error { return common.ReturnErrorMsg(c, err.Error()) } - err = airflow.Client.DeleteDAG(oldWorkflow.ID, true) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + + err = client.DeleteDAG(oldWorkflow.ID, true) if err != nil { return common.ReturnErrorMsg(c, "Failed to update the workflow. (Error:"+err.Error()+")") } - err = airflow.Client.CreateDAG(oldWorkflow) + err = client.CreateDAG(oldWorkflow) if err != nil { return common.ReturnErrorMsg(c, "Failed to update the workflow. (Error:"+err.Error()+")") } @@ -550,7 +575,12 @@ func DeleteWorkflow(c echo.Context) error { return common.ReturnErrorMsg(c, err.Error()) } - err = airflow.Client.DeleteDAG(workflow.ID, false) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + + err = client.DeleteDAG(workflow.ID, false) if err != nil { logger.Println(logger.ERROR, true, "AIRFLOW: "+err.Error()) } @@ -961,7 +991,11 @@ func GetTaskLogs(c echo.Context) error { if err != nil { return common.ReturnErrorMsg(c, "Invalid taskTryNum format.") } - logs, err := airflow.Client.GetTaskLogs(wfId, common.UrlDecode(wfRunId), taskInfo.Name, taskTyNumToInt) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + logs, err := client.GetTaskLogs(wfId, common.UrlDecode(wfRunId), taskInfo.Name, taskTyNumToInt) if err != nil { return common.ReturnErrorMsg(c, "Failed to get the workflow logs: "+err.Error()) } @@ -992,7 +1026,12 @@ func GetWorkflowRuns(c echo.Context) error { return common.ReturnErrorMsg(c, "Please provide the wfId.") } - runList, err := airflow.Client.GetDAGRuns(wfId) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + + runList, err := client.GetDAGRuns(wfId) if err != nil { return common.ReturnErrorMsg(c, "Failed to get the workflow runs: "+err.Error()) } @@ -1042,7 +1081,11 @@ func GetTaskInstances(c echo.Context) error { if wfRunId == "" { return common.ReturnErrorMsg(c, "Please provide the wfRunId.") } - runList, err := airflow.Client.GetTaskInstances(common.UrlDecode(wfId), common.UrlDecode(wfRunId)) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + runList, err := client.GetTaskInstances(common.UrlDecode(wfId), common.UrlDecode(wfRunId)) if err != nil { return common.ReturnErrorMsg(c, "Failed to get the taskInstances: "+err.Error()) } @@ -1122,7 +1165,11 @@ func ClearTaskInstances(c echo.Context) error { taskId = taskDBInfo.Name } var TaskInstanceReferences []model.TaskInstanceReference - clearList, err := airflow.Client.ClearTaskInstance(wfId, common.UrlDecode(wfRunId), taskId) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + clearList, err := client.ClearTaskInstance(wfId, common.UrlDecode(wfRunId), taskId) if err != nil { return common.ReturnErrorMsg(c, "Failed to get the taskInstances: "+err.Error()) } @@ -1181,7 +1228,11 @@ func GetEventLogs(c echo.Context) error { taskName = taskDBInfo.Name } var eventLogs model.EventLogs - logs, err := airflow.Client.GetEventLogs(wfId, wfRunId, taskName) + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + logs, err := client.GetEventLogs(wfId, wfRunId, taskName) if err != nil { return common.ReturnErrorMsg(c, "Failed to get the taskInstances: "+err.Error()) } @@ -1231,7 +1282,11 @@ func GetEventLogs(c echo.Context) error { // @Failure 500 {object} common.ErrorResponse "Failed to get the importErrors." // @Router /importErrors [get] func GetImportErrors(c echo.Context) error { - logs, err := airflow.Client.GetImportErrors() + client, err := airflow.GetClient() + if err != nil { + return common.ReturnErrorMsg(c, err.Error()) + } + logs, err := client.GetImportErrors() if err != nil { return common.ReturnErrorMsg(c, "Failed to get the taskInstances: "+err.Error()) } @@ -1239,7 +1294,7 @@ func GetImportErrors(c echo.Context) error { return c.JSONPretty(http.StatusOK, logs, " ") } -// ListWorkflow godoc +// ListWorkflowVersion godoc // // @ID list-workflowVersion // @Summary List workflowVersion