Skip to content

Commit

Permalink
feat: add ECSClient interface and replace usage of concrete ecs.Clien…
Browse files Browse the repository at this point in the history
…t type in App
  • Loading branch information
tedsmitt committed Oct 10, 2024
1 parent 98eed67 commit 289dee8
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 37 deletions.
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
Expand Down
47 changes: 27 additions & 20 deletions internal/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,40 @@ import (
"github.com/spf13/viper"
)

type ECSClient interface {
ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error)
ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error)
ListTasks(ctx context.Context, params *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error)
DescribeTasks(ctx context.Context, params *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error)
DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error)
DescribeContainerInstances(ctx context.Context, params *ecs.DescribeContainerInstancesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeContainerInstancesOutput, error)
ExecuteCommand(ctx context.Context, params *ecs.ExecuteCommandInput, optFns ...func(*ecs.Options)) (*ecs.ExecuteCommandOutput, error)
}

// App is a struct that contains information about our command state
type App struct {
input chan string
err chan error
exit chan error
client *ecs.Client
region string
endpoint string
cluster string
service string
task ecsTypes.Task
tasks map[string]*ecsTypes.Task
container *ecsTypes.Container
containers []*ecsTypes.Container
input chan string
err chan error
exit chan error
client ECSClient
region string
endpoint string
cluster string
service string
task *ecsTypes.Task
tasks map[string]*ecsTypes.Task
container *ecsTypes.Container
}

// CreateApp initialises a new App struct with the required initial values
func CreateApp() *App {
ecsClient := createEcsClient()
ssmClient := createSSMClient()
client := createEcsClient()
e := &App{
input: make(chan string, 1),
err: make(chan error, 1),
exit: make(chan error, 1),
client: client,
region: client.Options().Region,
// endpoint: client.Endpoint,
}

return e
Expand Down Expand Up @@ -277,13 +284,13 @@ func (e *App) getTask() {
return
}
if len(describe.Tasks) > 0 {
e.task = describe.Tasks[0]
e.task = &describe.Tasks[0]
e.getContainerOS()
e.input <- "getContainer"
viper.Set("task", "") // Reset the cli arg so user can navigate
return
} else {
fmt.Printf(Red(fmt.Sprintf("\nTask with ID %s not found in cluster %s\n", cliArg, e.cluster)))
fmt.Println(Red(fmt.Sprintf("\nTask with ID %s not found in cluster %s\n", cliArg, e.cluster)))
e.input <- "getService"
return
}
Expand Down Expand Up @@ -364,7 +371,7 @@ func (e *App) getTask() {
e.input <- "getService"
return
}
e.task = *selection
e.task = selection
e.getContainerOS()
e.input <- "getContainer"
return
Expand Down Expand Up @@ -425,7 +432,7 @@ func (e *App) getContainer() {
func (e *App) getContainerOS() {
// Get associated task definition and determine OS family if EC2 launch-type
if e.task.LaunchType == "EC2" {
family, err := getPlatformFamily(e.client, e.cluster, &e.task)
family, err := getPlatformFamily(e.client, e.cluster, e.task)
if err != nil {
e.err <- err
return
Expand All @@ -434,7 +441,7 @@ func (e *App) getContainerOS() {
// then we refer to the container instance to determine the OS
if family == "" {
ec2Client := createEc2Client()
family, err = getContainerInstanceOS(e.ecsClient, ec2Client, e.cluster, *e.task.ContainerInstanceArn)
family, err = getContainerInstanceOS(e.client, ec2Client, e.cluster, *e.task.ContainerInstanceArn)
if err != nil {
e.err <- err
return
Expand Down
4 changes: 3 additions & 1 deletion internal/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ecs"
"github.com/aws/aws-sdk-go-v2/service/ssm"
"github.com/spf13/viper"
)
Expand All @@ -29,7 +30,8 @@ func (e *App) executeForward() error {
panic(err)
}
client := ssm.NewFromConfig(cfg) // TODO: add region
containerPort, err := getContainerPort(e.client, *e.task.TaskDefinitionArn, *e.container.Name)
ecsClient := e.client.(*ecs.Client)
containerPort, err := getContainerPort(ecsClient, *e.task.TaskDefinitionArn, *e.container.Name)
if err != nil {
e.err <- err
return err
Expand Down
47 changes: 33 additions & 14 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ecs"
ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types"
"github.com/aws/aws-sdk-go-v2/service/ssm"
"github.com/fatih/color"
"github.com/spf13/viper"
)
Expand All @@ -39,49 +40,67 @@ func createOpts(opts []string) []string {

func createEcsClient() *ecs.Client {
region := viper.GetString("region")
endpointUrl := viper.GetString("aws-endpoint-url")
getCustomAWSEndpoint := func(o *ecs.Options) {
endpointUrl := viper.GetString("aws-endpoint-url")
if endpointUrl != "" {
o.BaseEndpoint = aws.String(endpointUrl)
}
}
cfg, err := config.LoadDefaultConfig(context.Background(),
config.WithSharedConfigProfile(viper.GetString("profile")),
config.WithRegion(region),
)
if err != nil {
panic(err)
}
client := ecs.NewFromConfig(cfg)
client := ecs.NewFromConfig(cfg, getCustomAWSEndpoint)

return client
}

func createEc2Client() *ec2.Client {
region := viper.GetString("region")
getCustomAWSEndpoint := func(o *ec2.Options) {
endpointUrl := viper.GetString("aws-endpoint-url")
if endpointUrl != "" {
o.BaseEndpoint = aws.String(endpointUrl)
}
}
cfg, err := config.LoadDefaultConfig(context.Background(),
config.WithSharedConfigProfile(viper.GetString("profile")),
config.WithRegion(region),
)
if err != nil {
panic(err)
}
client := ec2.NewFromConfig(cfg)
client := ec2.NewFromConfig(cfg, getCustomAWSEndpoint)

return client
}

func createSSMClient() *ssm.SSM {
func createSSMClient() *ssm.Client {
region := viper.GetString("region")
endpointUrl := viper.GetString("aws-endpoint-url")
sess := session.Must(session.NewSessionWithOptions(session.Options{
Config: aws.Config{Region: aws.String(region), Endpoint: aws.String(endpointUrl)},
Profile: viper.GetString("profile"),
SharedConfigState: session.SharedConfigEnable,
}))
client := ssm.New(sess)
getCustomAWSEndpoint := func(o *ssm.Options) {
endpointUrl := viper.GetString("aws-endpoint-url")
if endpointUrl != "" {
o.BaseEndpoint = aws.String(endpointUrl)
}
}
cfg, err := config.LoadDefaultConfig(context.Background(),
config.WithSharedConfigProfile(viper.GetString("profile")),
config.WithRegion(region),
)
if err != nil {
panic(err)
}
client := ssm.NewFromConfig(cfg, getCustomAWSEndpoint)

return client
}

// getPlatformFamily checks an ECS tasks properties to see if the OS can be derived from its properties, otherwise
// it will check the container instance itself to determine the OS.
func getPlatformFamily(client *ecs.Client, clusterName string, task *ecsTypes.Task) (string, error) {
func getPlatformFamily(client ECSClient, clusterName string, task *ecsTypes.Task) (string, error) {
taskDefinition, err := client.DescribeTaskDefinition(context.TODO(), &ecs.DescribeTaskDefinitionInput{
TaskDefinition: task.TaskDefinitionArn,
})
Expand All @@ -96,7 +115,7 @@ func getPlatformFamily(client *ecs.Client, clusterName string, task *ecsTypes.Ta

// getContainerInstanceOS describes the specified container instance and checks against the backing EC2 instance
// to determine the platform.
func getContainerInstanceOS(ecsClient *ecs.Client, ec2Client *ec2.Client, cluster string, containerInstanceArn string) (string, error) {
func getContainerInstanceOS(ecsClient ECSClient, ec2Client *ec2.Client, cluster string, containerInstanceArn string) (string, error) {
res, err := ecsClient.DescribeContainerInstances(context.TODO(), &ecs.DescribeContainerInstancesInput{
Cluster: aws.String(cluster),
ContainerInstances: []string{
Expand All @@ -107,7 +126,7 @@ func getContainerInstanceOS(ecsClient *ecs.Client, ec2Client *ec2.Client, cluste
return "", err
}
instanceId := res.ContainerInstances[0].Ec2InstanceId
instance, err := ec2Client.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{
instance, _ := ec2Client.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{
InstanceIds: []string{
*instanceId,
},
Expand Down

0 comments on commit 289dee8

Please sign in to comment.