diff --git a/go.sum b/go.sum index 4e4f8c0..d85a04e 100644 --- a/go.sum +++ b/go.sum @@ -311,8 +311,8 @@ golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= diff --git a/internal/app.go b/internal/app.go index b79aa37..bac4a48 100644 --- a/internal/app.go +++ b/internal/app.go @@ -14,33 +14,40 @@ import ( "github.com/spf13/viper" ) +type ECSClient interface { + ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) + ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) + ListTasks(ctx context.Context, params *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) + DescribeTasks(ctx context.Context, params *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) + DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) + DescribeContainerInstances(ctx context.Context, params *ecs.DescribeContainerInstancesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeContainerInstancesOutput, error) + ExecuteCommand(ctx context.Context, params *ecs.ExecuteCommandInput, optFns ...func(*ecs.Options)) (*ecs.ExecuteCommandOutput, error) +} + // App is a struct that contains information about our command state type App struct { - input chan string - err chan error - exit chan error - client *ecs.Client - region string - endpoint string - cluster string - service string - task ecsTypes.Task - tasks map[string]*ecsTypes.Task - container *ecsTypes.Container - containers []*ecsTypes.Container + input chan string + err chan error + exit chan error + client ECSClient + region string + endpoint string + cluster string + service string + task *ecsTypes.Task + tasks map[string]*ecsTypes.Task + container *ecsTypes.Container } // CreateApp initialises a new App struct with the required initial values func CreateApp() *App { - ecsClient := createEcsClient() - ssmClient := createSSMClient() + client := createEcsClient() e := &App{ input: make(chan string, 1), err: make(chan error, 1), exit: make(chan error, 1), client: client, region: client.Options().Region, - // endpoint: client.Endpoint, } return e @@ -277,13 +284,13 @@ func (e *App) getTask() { return } if len(describe.Tasks) > 0 { - e.task = describe.Tasks[0] + e.task = &describe.Tasks[0] e.getContainerOS() e.input <- "getContainer" viper.Set("task", "") // Reset the cli arg so user can navigate return } else { - fmt.Printf(Red(fmt.Sprintf("\nTask with ID %s not found in cluster %s\n", cliArg, e.cluster))) + fmt.Println(Red(fmt.Sprintf("\nTask with ID %s not found in cluster %s\n", cliArg, e.cluster))) e.input <- "getService" return } @@ -364,7 +371,7 @@ func (e *App) getTask() { e.input <- "getService" return } - e.task = *selection + e.task = selection e.getContainerOS() e.input <- "getContainer" return @@ -425,7 +432,7 @@ func (e *App) getContainer() { func (e *App) getContainerOS() { // Get associated task definition and determine OS family if EC2 launch-type if e.task.LaunchType == "EC2" { - family, err := getPlatformFamily(e.client, e.cluster, &e.task) + family, err := getPlatformFamily(e.client, e.cluster, e.task) if err != nil { e.err <- err return @@ -434,7 +441,7 @@ func (e *App) getContainerOS() { // then we refer to the container instance to determine the OS if family == "" { ec2Client := createEc2Client() - family, err = getContainerInstanceOS(e.ecsClient, ec2Client, e.cluster, *e.task.ContainerInstanceArn) + family, err = getContainerInstanceOS(e.client, ec2Client, e.cluster, *e.task.ContainerInstanceArn) if err != nil { e.err <- err return diff --git a/internal/forward.go b/internal/forward.go index 8a8e3f4..264c542 100644 --- a/internal/forward.go +++ b/internal/forward.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/spf13/viper" ) @@ -29,7 +30,8 @@ func (e *App) executeForward() error { panic(err) } client := ssm.NewFromConfig(cfg) // TODO: add region - containerPort, err := getContainerPort(e.client, *e.task.TaskDefinitionArn, *e.container.Name) + ecsClient := e.client.(*ecs.Client) + containerPort, err := getContainerPort(ecsClient, *e.task.TaskDefinitionArn, *e.container.Name) if err != nil { e.err <- err return err diff --git a/internal/internal.go b/internal/internal.go index 7dd44cd..4017dfc 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -13,6 +13,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ecs" ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/fatih/color" "github.com/spf13/viper" ) @@ -39,7 +40,12 @@ func createOpts(opts []string) []string { func createEcsClient() *ecs.Client { region := viper.GetString("region") - endpointUrl := viper.GetString("aws-endpoint-url") + getCustomAWSEndpoint := func(o *ecs.Options) { + endpointUrl := viper.GetString("aws-endpoint-url") + if endpointUrl != "" { + o.BaseEndpoint = aws.String(endpointUrl) + } + } cfg, err := config.LoadDefaultConfig(context.Background(), config.WithSharedConfigProfile(viper.GetString("profile")), config.WithRegion(region), @@ -47,13 +53,19 @@ func createEcsClient() *ecs.Client { if err != nil { panic(err) } - client := ecs.NewFromConfig(cfg) + client := ecs.NewFromConfig(cfg, getCustomAWSEndpoint) return client } func createEc2Client() *ec2.Client { region := viper.GetString("region") + getCustomAWSEndpoint := func(o *ec2.Options) { + endpointUrl := viper.GetString("aws-endpoint-url") + if endpointUrl != "" { + o.BaseEndpoint = aws.String(endpointUrl) + } + } cfg, err := config.LoadDefaultConfig(context.Background(), config.WithSharedConfigProfile(viper.GetString("profile")), config.WithRegion(region), @@ -61,27 +73,34 @@ func createEc2Client() *ec2.Client { if err != nil { panic(err) } - client := ec2.NewFromConfig(cfg) + client := ec2.NewFromConfig(cfg, getCustomAWSEndpoint) return client } -func createSSMClient() *ssm.SSM { +func createSSMClient() *ssm.Client { region := viper.GetString("region") - endpointUrl := viper.GetString("aws-endpoint-url") - sess := session.Must(session.NewSessionWithOptions(session.Options{ - Config: aws.Config{Region: aws.String(region), Endpoint: aws.String(endpointUrl)}, - Profile: viper.GetString("profile"), - SharedConfigState: session.SharedConfigEnable, - })) - client := ssm.New(sess) + getCustomAWSEndpoint := func(o *ssm.Options) { + endpointUrl := viper.GetString("aws-endpoint-url") + if endpointUrl != "" { + o.BaseEndpoint = aws.String(endpointUrl) + } + } + cfg, err := config.LoadDefaultConfig(context.Background(), + config.WithSharedConfigProfile(viper.GetString("profile")), + config.WithRegion(region), + ) + if err != nil { + panic(err) + } + client := ssm.NewFromConfig(cfg, getCustomAWSEndpoint) return client } // getPlatformFamily checks an ECS tasks properties to see if the OS can be derived from its properties, otherwise // it will check the container instance itself to determine the OS. -func getPlatformFamily(client *ecs.Client, clusterName string, task *ecsTypes.Task) (string, error) { +func getPlatformFamily(client ECSClient, clusterName string, task *ecsTypes.Task) (string, error) { taskDefinition, err := client.DescribeTaskDefinition(context.TODO(), &ecs.DescribeTaskDefinitionInput{ TaskDefinition: task.TaskDefinitionArn, }) @@ -96,7 +115,7 @@ func getPlatformFamily(client *ecs.Client, clusterName string, task *ecsTypes.Ta // getContainerInstanceOS describes the specified container instance and checks against the backing EC2 instance // to determine the platform. -func getContainerInstanceOS(ecsClient *ecs.Client, ec2Client *ec2.Client, cluster string, containerInstanceArn string) (string, error) { +func getContainerInstanceOS(ecsClient ECSClient, ec2Client *ec2.Client, cluster string, containerInstanceArn string) (string, error) { res, err := ecsClient.DescribeContainerInstances(context.TODO(), &ecs.DescribeContainerInstancesInput{ Cluster: aws.String(cluster), ContainerInstances: []string{ @@ -107,7 +126,7 @@ func getContainerInstanceOS(ecsClient *ecs.Client, ec2Client *ec2.Client, cluste return "", err } instanceId := res.ContainerInstances[0].Ec2InstanceId - instance, err := ec2Client.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{ + instance, _ := ec2Client.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{ InstanceIds: []string{ *instanceId, },