Skip to content

Commit

Permalink
airflow, api: rest: Check if Airflow client is initialized
Browse files Browse the repository at this point in the history
  • Loading branch information
ish-hcc committed Nov 15, 2024
1 parent fe4b82e commit ce3c043
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 50 deletions.
6 changes: 3 additions & 3 deletions lib/airflow/conneciton.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/jollaman999/utils/logger"
)

func (client *client) RegisterConnection(connection *model.Connection) error {
func (client *Client) RegisterConnection(connection *model.Connection) error {
ctx, cancel := Context()
defer cancel()

Expand Down Expand Up @@ -42,9 +42,9 @@ func (client *client) RegisterConnection(connection *model.Connection) error {
Extra: extra,
}

_, _ = client.api.ConnectionApi.DeleteConnection(ctx, connection.ID).Execute()
_, _ = client.ConnectionApi.DeleteConnection(ctx, connection.ID).Execute()

_, _, err := client.api.ConnectionApi.PostConnection(ctx).Connection(conn).Execute()
_, _, err := client.ConnectionApi.PostConnection(ctx).Connection(conn).Execute()
if err != nil {
errMsg := "AIRFLOW: Error occurred while registering connection. (ConnID: " + connection.ID + ", Error: " + err.Error() + ")."
logger.Println(logger.ERROR, false, errMsg)
Expand Down
23 changes: 16 additions & 7 deletions lib/airflow/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package airflow
import (
"context"
"crypto/tls"
"errors"
"github.com/apache/airflow-client-go/airflow"
"github.com/cloud-barista/cm-cicada/lib/config"
"github.com/jollaman999/utils/logger"
Expand All @@ -11,11 +12,19 @@ import (
"time"
)

type client struct {
api *airflow.APIClient
type Client struct {
*airflow.APIClient
}

var Client *client
var airflowClient *Client

func GetClient() (*Client, error) {
if airflowClient == nil {
return nil, errors.New("airflow client not initialized")
}

return airflowClient, nil
}

func ping(url string) error {
timeout, _ := strconv.Atoi(config.CMCicadaConfig.CMCicada.AirflowServer.Timeout)
Expand Down Expand Up @@ -55,7 +64,7 @@ func checkPing(url string) {
func registerConnections() {
for _, connection := range config.CMCicadaConfig.CMCicada.AirflowServer.Connections {
logger.Println(logger.INFO, false, "Registering connection: ", connection)
err := Client.RegisterConnection(&connection)
err := airflowClient.RegisterConnection(&connection)
if err != nil {
logger.Println(logger.ERROR, false, err.Error())
}
Expand Down Expand Up @@ -93,11 +102,11 @@ func Init() {
checkPing(conf.Scheme + "://" + conf.Host)

cli := airflow.NewAPIClient(conf)
conn := client{
api: cli,
conn := Client{
cli,
}

Client = &conn
airflowClient = &conn

registerConnections()
}
52 changes: 26 additions & 26 deletions lib/airflow/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func callDagRequestLock(workflowID string) func() {
}
}

func (client *client) CreateDAG(workflow *model.Workflow) error {
func (client *Client) CreateDAG(workflow *model.Workflow) error {
deferFunc := callDagRequestLock(workflow.ID)
defer func() {
deferFunc()
Expand All @@ -52,15 +52,15 @@ func (client *client) CreateDAG(workflow *model.Workflow) error {
return nil
}

func (client *client) GetDAG(dagID string) (airflow.DAG, error) {
func (client *Client) GetDAG(dagID string) (airflow.DAG, error) {
deferFunc := callDagRequestLock(dagID)
defer func() {
deferFunc()
}()

ctx, cancel := Context()
defer cancel()
resp, _, err := client.api.DAGApi.GetDag(ctx, dagID).Execute()
resp, _, err := client.DAGApi.GetDag(ctx, dagID).Execute()
if err != nil {
logger.Println(logger.ERROR, false,
"AIRFLOW: Error occurred while getting the DAG. (Error: "+err.Error()+").")
Expand All @@ -69,18 +69,18 @@ func (client *client) GetDAG(dagID string) (airflow.DAG, error) {
return resp, err
}

func (client *client) GetDAGs() (airflow.DAGCollection, error) {
func (client *Client) GetDAGs() (airflow.DAGCollection, error) {
ctx, cancel := Context()
defer cancel()
resp, _, err := client.api.DAGApi.GetDags(ctx).Execute()
resp, _, err := client.DAGApi.GetDags(ctx).Execute()
if err != nil {
logger.Println(logger.ERROR, false,
"AIRFLOW: Error occurred while getting DAGs. (Error: "+err.Error()+").")
}
return resp, err
}

func (client *client) RunDAG(dagID string) (airflow.DAGRun, error) {
func (client *Client) RunDAG(dagID string) (airflow.DAGRun, error) {
deferFunc := callDagRequestLock(dagID)
defer func() {
deferFunc()
Expand Down Expand Up @@ -108,7 +108,7 @@ func (client *client) RunDAG(dagID string) (airflow.DAGRun, error) {

ctx, cancel := Context()
defer cancel()
resp, _, err := client.api.DAGRunApi.PostDagRun(ctx, dagID).DAGRun(*airflow.NewDAGRun()).Execute()
resp, _, err := client.DAGRunApi.PostDagRun(ctx, dagID).DAGRun(*airflow.NewDAGRun()).Execute()
if err != nil {
errMsg := "AIRFLOW: Error occurred while running the DAG. (DAG ID: " + dagID + ", Error: " + err.Error() + ")"
logger.Println(logger.ERROR, false, errMsg)
Expand All @@ -121,7 +121,7 @@ func (client *client) RunDAG(dagID string) (airflow.DAGRun, error) {
return resp, err
}

func (client *client) DeleteDAG(dagID string, deleteFolderOnly bool) error {
func (client *Client) DeleteDAG(dagID string, deleteFolderOnly bool) error {
deferFunc := callDagRequestLock(dagID)
defer func() {
deferFunc()
Expand All @@ -137,7 +137,7 @@ func (client *client) DeleteDAG(dagID string, deleteFolderOnly bool) error {
if !deleteFolderOnly {
ctx, cancel := Context()
defer cancel()
_, err = client.api.DAGApi.DeleteDag(ctx, dagID).Execute()
_, err = client.DAGApi.DeleteDag(ctx, dagID).Execute()
if err != nil {
logger.Println(logger.ERROR, false,
"AIRFLOW: Error occurred while deleting the DAG. (Error: "+err.Error()+").")
Expand All @@ -146,37 +146,37 @@ func (client *client) DeleteDAG(dagID string, deleteFolderOnly bool) error {

return err
}
func (client *client) GetDAGRuns(dagID string) (airflow.DAGRunCollection, error) {
func (client *Client) GetDAGRuns(dagID string) (airflow.DAGRunCollection, error) {
deferFunc := callDagRequestLock(dagID)
defer func() {
deferFunc()
}()
ctx, cancel := Context()
defer cancel()
resp, _, err := client.api.DAGRunApi.GetDagRuns(ctx, dagID).Execute()
resp, _, err := client.DAGRunApi.GetDagRuns(ctx, dagID).Execute()
if err != nil {
logger.Println(logger.ERROR, false,
"AIRFLOW: Error occurred while getting DAGRuns. (Error: "+err.Error()+").")
}
return resp, err
}

func (client *client) GetTaskInstances(dagID string, dagRunId string) (airflow.TaskInstanceCollection, error) {
func (client *Client) GetTaskInstances(dagID string, dagRunId string) (airflow.TaskInstanceCollection, error) {
deferFunc := callDagRequestLock(dagID)
defer func() {
deferFunc()
}()
ctx, cancel := Context()
defer cancel()
resp, _, err := client.api.TaskInstanceApi.GetTaskInstances(ctx, dagID, dagRunId).Execute()
resp, _, err := client.TaskInstanceApi.GetTaskInstances(ctx, dagID, dagRunId).Execute()
if err != nil {
logger.Println(logger.ERROR, false,
"AIRFLOW: Error occurred while getting TaskInstances. (Error: "+err.Error()+").")
}
return resp, err
}

func (client *client) GetTaskLogs(dagID, dagRunID, taskID string, taskTryNumber int) (airflow.InlineResponse200, error) {
func (client *Client) GetTaskLogs(dagID, dagRunID, taskID string, taskTryNumber int) (airflow.InlineResponse200, error) {
deferFunc := callDagRequestLock(dagID)
defer func() {
deferFunc()
Expand All @@ -185,7 +185,7 @@ func (client *client) GetTaskLogs(dagID, dagRunID, taskID string, taskTryNumber
defer cancel()

// TaskInstanceApi 인스턴스를 사용하여 로그 요청
logs, _, err := client.api.TaskInstanceApi.GetLog(ctx, dagID, dagRunID, taskID, int32(taskTryNumber)).FullContent(true).Execute()
logs, _, err := client.TaskInstanceApi.GetLog(ctx, dagID, dagRunID, taskID, int32(taskTryNumber)).FullContent(true).Execute()
logger.Println(logger.INFO, false, logs)
if err != nil {
logger.Println(logger.ERROR, false,
Expand All @@ -195,7 +195,7 @@ func (client *client) GetTaskLogs(dagID, dagRunID, taskID string, taskTryNumber
return logs, nil
}

func (client *client) ClearTaskInstance(dagID string, dagRunID string, taskID string) (airflow.TaskInstanceReferenceCollection, error) {
func (client *Client) ClearTaskInstance(dagID string, dagRunID string, taskID string) (airflow.TaskInstanceReferenceCollection, error) {
deferFunc := callDagRequestLock(dagID)
defer func() {
deferFunc()
Expand Down Expand Up @@ -231,13 +231,13 @@ func (client *client) ClearTaskInstance(dagID string, dagRunID string, taskID st
}

// 요청 생성
request := client.api.DAGApi.PostClearTaskInstances(ctx, dagID)
request := client.DAGApi.PostClearTaskInstances(ctx, dagID)

// ClearTaskInstances 데이터 설정
request = request.ClearTaskInstances(clearTask)

// 요청 실행
logs, _, err := client.api.DAGApi.PostClearTaskInstancesExecute(request)
logs, _, err := client.DAGApi.PostClearTaskInstancesExecute(request)
if err != nil {
logger.Println(logger.ERROR, false,
"AIRFLOW: Error occurred while clearing TaskInstance. (Error: "+err.Error()+").")
Expand All @@ -246,20 +246,20 @@ func (client *client) ClearTaskInstance(dagID string, dagRunID string, taskID st

return logs, nil
}
func (client *client) GetEventLogs(dagID string, dagRunId string, taskId string) ([]byte, error) {
func (client *Client) GetEventLogs(dagID string, dagRunId string, taskId string) ([]byte, error) {
deferFunc := callDagRequestLock(dagID)
defer func() {
deferFunc()
}()
ctx, cancel := Context()
defer cancel()

localBasePath, err := client.api.GetConfig().ServerURLWithContext(ctx, "EventLogApiService.GetEventLog")
localBasePath, err := client.GetConfig().ServerURLWithContext(ctx, "EventLogApiService.GetEventLog")
if err != nil {
fmt.Println("Error occurred while getting event logs:", err)
}

baseURL := "http://" + client.api.GetConfig().Host + localBasePath + "/eventLogs"
baseURL := "http://" + client.GetConfig().Host + localBasePath + "/eventLogs"
queryParams := map[string]string{
"offset": "0",
"limit": "100",
Expand All @@ -275,7 +275,7 @@ func (client *client) GetEventLogs(dagID string, dagRunId string, taskId string)
}
queryString := query.Encode()
fullURL := fmt.Sprintf("%s?%s", baseURL, queryString)
httpclient := client.api.GetConfig().HTTPClient
httpclient := client.GetConfig().HTTPClient

// 요청 생성
req, err := http.NewRequest("GET", fullURL, nil)
Expand All @@ -299,12 +299,12 @@ func (client *client) GetEventLogs(dagID string, dagRunId string, taskId string)
return body, err
}

func (client *client) GetImportErrors() (airflow.ImportErrorCollection, error) {
func (client *Client) GetImportErrors() (airflow.ImportErrorCollection, error) {
ctx, cancel := Context()
defer cancel()

// TaskInstanceApi 인스턴스를 사용하여 로그 요청
logs, _, err := client.api.ImportErrorApi.GetImportErrors(ctx).Execute()
logs, _, err := client.ImportErrorApi.GetImportErrors(ctx).Execute()
logger.Println(logger.INFO, false, logs)
if err != nil {
logger.Println(logger.ERROR, false,
Expand All @@ -314,12 +314,12 @@ func (client *client) GetImportErrors() (airflow.ImportErrorCollection, error) {
return logs, nil
}

func (client *client) PatchDag(dagID string, dagBody airflow.DAG) (airflow.DAG, error) {
func (client *Client) PatchDag(dagID string, dagBody airflow.DAG) (airflow.DAG, error) {
ctx, cancel := Context()
defer cancel()

// TaskInstanceApi 인스턴스를 사용하여 로그 요청
logs, _, err := client.api.DAGApi.PatchDag(ctx, dagID).DAG(dagBody).Execute()
logs, _, err := client.DAGApi.PatchDag(ctx, dagID).DAG(dagBody).Execute()
logger.Println(logger.INFO, false, logs)
if err != nil {
logger.Println(logger.ERROR, false,
Expand Down
Loading

0 comments on commit ce3c043

Please sign in to comment.