Skip to content

Commit

Permalink
Merge branch 'dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
xxx0624 authored Oct 22, 2024
2 parents f7bae94 + 79f17a5 commit b7e0b28
Show file tree
Hide file tree
Showing 823 changed files with 165,623 additions and 130,125 deletions.
2 changes: 2 additions & 0 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3806,6 +3806,8 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) {
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
}
tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody)
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 37 additions & 15 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
"github.com/aws/amazon-ecs-agent/ecs-agent/metrics"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils"
v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4"
Expand All @@ -56,7 +57,7 @@ const (
requestTimeoutSeconds = 5
// Commands that will be used to start/stop/check fault.
iptablesNewChainCmd = "iptables -w %d -N %s"
iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s --dport %s -j DROP"
iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s -d %s --dport %s -j %s"
iptablesInsertChainCmd = "iptables -w %d -I %s -j %s"
iptablesChainExistCmd = "iptables -w %d -C %s -p %s --dport %s -j DROP"
iptablesClearChainCmd = "iptables -w %d -F %s"
Expand All @@ -71,6 +72,9 @@ const (
tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 2 u32 match ip dst %s flowid 1:1"
tcDeleteQdiscParentCommandString = "tc qdisc del dev %s parent 1:1 handle 10:"
tcDeleteQdiscRootCommandString = "tc qdisc del dev %s root handle 1: prio"
allIPv4CIDR = "0.0.0.0/0"
dropTarget = "DROP"
acceptTarget = "ACCEPT"
)

type FaultHandler struct {
Expand Down Expand Up @@ -220,24 +224,42 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol,
"taskArn": taskArn,
})

// Appending a new rule based on the protocol and port number from the request body
appendRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, requestTimeoutSeconds, chain, protocol, port)
cmdOutput, err = h.runExecCommand(ctx, strings.Split(appendRuleCmdString, " "))
if err != nil {
logger.Error("Unable to append rule to chain", logger.Fields{
"netns": netNs,
"command": appendRuleCmdString,
// Helper function to run iptables rule change commands
var execRuleChangeCommand = func(cmdString string) (string, error) {
// Appending a new rule based on the protocol and port number from the request body
cmdOutput, err = h.runExecCommand(ctx, strings.Split(cmdString, " "))
if err != nil {
logger.Error("Unable to add rule to chain", logger.Fields{
"netns": netNs,
"command": cmdString,
"output": string(cmdOutput),
"taskArn": taskArn,
"error": err,
})
return string(cmdOutput), err
}
logger.Info("Successfully added new rule to iptable chain", logger.Fields{
"command": cmdString,
"output": string(cmdOutput),
"taskArn": taskArn,
"error": err,
})
return string(cmdOutput), err
return "", nil
}

// Add a rule to accept all traffic to TMDS
protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks,
acceptTarget)
if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil {
return out, err
}

// Add a rule to drop all traffic to the port that the fault targets
faultRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
requestTimeoutSeconds, chain, protocol, allIPv4CIDR, port, dropTarget)
if out, err := execRuleChangeCommand(faultRuleCmdString); err != nil {
return out, err
}
logger.Info("Successfully appended new rule to iptable chain", logger.Fields{
"command": appendRuleCmdString,
"output": string(cmdOutput),
"taskArn": taskArn,
})

// Inserting the chain into the built-in INPUT/OUTPUT table
insertChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesInsertChainCmd, requestTimeoutSeconds, insertTable, chain)
Expand Down
34 changes: 33 additions & 1 deletion ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
Expand Down Expand Up @@ -554,6 +556,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
Expand All @@ -578,7 +582,7 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
},
},
{
name: fmt.Sprintf("%s fail append rule to chain", startNetworkBlackHolePortTestPrefix),
name: fmt.Sprintf("%s fail append ACCEPT rule to chain", startNetworkBlackHolePortTestPrefix),
expectedStatusCode: 500,
requestBody: happyBlackHolePortReqBody,
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
Expand All @@ -603,6 +607,34 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
)
},
},
{
name: fmt.Sprintf("%s fail append DROP rule to chain", startNetworkBlackHolePortTestPrefix),
expectedStatusCode: 500,
requestBody: happyBlackHolePortReqBody,
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
Return(happyTaskResponse, nil).
Times(1)
},
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
gomock.InOrder(
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")),
)
},
},
{
name: fmt.Sprintf("%s fail insert chain to table", startNetworkBlackHolePortTestPrefix),
expectedStatusCode: 500,
Expand Down
6 changes: 4 additions & 2 deletions ecs-agent/tmds/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ import (

const (
// TMDS IP and port
IPv4 = "127.0.0.1"
Port = 51679
IPv4 = "127.0.0.1"
Port = 51679
IPForTasks = "169.254.170.2"
PortForTasks = "80"
)

// IPv4 address for TMDS
Expand Down
16 changes: 9 additions & 7 deletions ecs-init/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package cache

import (
"bufio"
"context"
"crypto/md5"
"fmt"
"io"
Expand All @@ -25,8 +26,8 @@ import (

"github.com/aws/amazon-ecs-agent/ecs-init/config"

"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
log "github.com/cihub/seelog"
"github.com/pkg/errors"
)
Expand Down Expand Up @@ -76,7 +77,7 @@ func NewDownloader() (*Downloader, error) {
if config.RunningInExternal() {
downloader.metadata = &blackholeInstanceMetadata{}
} else {
sessionInstance, err := session.NewSession()
cfg, err := awsconfig.LoadDefaultConfig(context.TODO())
if err != nil {
// metadata client is only used for retrieving the user's region.
// If it cannot be initialized, the region field is populated with the default value to prevent future
Expand All @@ -85,7 +86,7 @@ func NewDownloader() (*Downloader, error) {
err, config.DefaultRegionName)
downloader.region = config.DefaultRegionName
} else {
downloader.metadata = ec2metadata.New(sessionInstance)
downloader.metadata = imds.NewFromConfig(cfg)
}
}

Expand Down Expand Up @@ -181,13 +182,14 @@ func (d *Downloader) getRegion() string {
return d.region
}

region, err := d.metadata.Region()
output, err := d.metadata.GetRegion(context.TODO(), &imds.GetRegionInput{})
if err != nil {
log.Warnf("Could not retrieve the region from EC2 Instance Metadata. Error: %s", err.Error())
region = defaultRegion
d.region = defaultRegion
return d.region
}
d.region = region

d.region = output.Region
return d.region
}

Expand Down
32 changes: 17 additions & 15 deletions ecs-init/cache/dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,23 @@ package cache
//go:generate mockgen.sh cache $GOFILE

import (
"context"
"io"
"os"
"path/filepath"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
log "github.com/cihub/seelog"
"github.com/pkg/errors"
)

// s3API captures the only method used from the s3 package
type s3API interface {
Download(w io.WriterAt, input *s3.GetObjectInput, options ...func(*s3manager.Downloader)) (n int64, err error)
Download(ctx context.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*manager.Downloader)) (int64, error)
}

// s3BucketDownloader wraps a bucket together with a downloader that can download from it
Expand All @@ -47,16 +48,17 @@ type s3BucketDownloader struct {
}

func newS3BucketDownloader(region, bucketName string) (*s3BucketDownloader, error) {
session, err := session.NewSession(&aws.Config{
Credentials: credentials.AnonymousCredentials,
Region: aws.String(region),
})
cfg, err := config.LoadDefaultConfig(
context.TODO(),
config.WithCredentialsProvider((aws.AnonymousCredentials{})),
config.WithRegion((region)),
)
if err != nil {
return nil, errors.Wrapf(err, "failed to initialize downloader in region %s", region)
}

s3BucketDownloader := &s3BucketDownloader{
client: s3manager.NewDownloader(session),
client: manager.NewDownloader(s3.NewFromConfig(cfg)),
bucket: bucketName,
region: region,
}
Expand All @@ -77,7 +79,7 @@ func (bd *s3BucketDownloader) download(fileName, cacheDir string, fs fileSystem)
}
}()

_, err = bd.client.Download(file, &s3.GetObjectInput{
_, err = bd.client.Download(context.TODO(), file, &s3.GetObjectInput{
Bucket: aws.String(bd.bucket),
Key: aws.String(fileName),
})
Expand Down Expand Up @@ -137,14 +139,14 @@ type fileSizeInfo interface {
}

type instanceMetadata interface {
Region() (string, error)
GetRegion(ctx context.Context, input *imds.GetRegionInput, opts ...func(*imds.Options)) (*imds.GetRegionOutput, error)
}

type blackholeInstanceMetadata struct {
}

func (b *blackholeInstanceMetadata) Region() (string, error) {
return "", errors.New("blackholed")
func (b *blackholeInstanceMetadata) GetRegion(ctx context.Context, input *imds.GetRegionInput, opts ...func(*imds.Options)) (*imds.GetRegionOutput, error) {
return nil, errors.New("blackholed")
}

// standardFS delegates to the package-level functions
Expand Down
Loading

0 comments on commit b7e0b28

Please sign in to comment.