Skip to content

Commit

Permalink
Add SourcesToFilter support for network-blackhole-port fault (aws#4408)
Browse files Browse the repository at this point in the history
  • Loading branch information
amogh09 authored Oct 23, 2024
1 parent 1e6b153 commit 75ee48f
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 26 deletions.
2 changes: 0 additions & 2 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3806,8 +3806,6 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) {
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
}
tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody)
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 20 additions & 8 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,19 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht
if err != nil {
return
}

// Validate the fault request
err = validateRequest(w, request, requestType)
if err != nil {
return
}

if aws.StringValue(request.TrafficType) == types.TrafficTypeEgress &&
aws.Uint16Value(request.Port) == tmds.PortForTasks {
// Add TMDS IP to SouresToFilter so that access to TMDS is not blocked for the task
request.AddSourceToFilterIfNotAlready(tmds.IPForTasks)
}

// Obtain the task metadata via the endpoint container ID
taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r)
if err != nil {
Expand Down Expand Up @@ -154,7 +161,8 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht
insertTable = "OUTPUT"
}

_, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName,
_, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol),
port, aws.StringValueSlice(request.SourcesToFilter), chainName,
networkMode, networkNSPath, insertTable, taskArn)
if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) {
statusCode = http.StatusInternalServerError
Expand Down Expand Up @@ -187,7 +195,10 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht
// 2. Creates a new chain via `iptables -N <chain>` (the chain name is in the form of "<trafficType>-<protocol>-<port>")
// 3. Appends a new rule to the newly created chain via `iptables -A <chain> -p <protocol> --dport <port> -j DROP`
// 4. Inserts the newly created chain into the built-in INPUT/OUTPUT table
func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, port, chain, networkMode, netNs, insertTable, taskArn string) (string, error) {
func (h *FaultHandler) startNetworkBlackholePort(
ctx context.Context, protocol, port string, sourcesToFilter []string,
chain, networkMode, netNs, insertTable, taskArn string,
) (string, error) {
running, cmdOutput, err := h.checkNetworkBlackHolePort(ctx, protocol, port, chain, networkMode, netNs, taskArn)
if err != nil {
return cmdOutput, err
Expand Down Expand Up @@ -246,12 +257,13 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol,
return "", nil
}

// Add a rule to accept all traffic to TMDS
protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks,
acceptTarget)
if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil {
return out, err
for _, sourceToFilter := range sourcesToFilter {
filterRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
requestTimeoutSeconds, chain, protocol, sourceToFilter, port,
acceptTarget)
if out, err := execRuleChangeCommand(filterRuleCmdString); err != nil {
return out, err
}
}

// Add a rule to drop all traffic to the port that the fault targets
Expand Down
136 changes: 132 additions & 4 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
mock_state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/mocks"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig"
mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks"
"github.com/aws/aws-sdk-go/aws"

"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
Expand Down Expand Up @@ -521,8 +522,6 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
Expand Down Expand Up @@ -556,8 +555,6 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
Expand Down Expand Up @@ -663,6 +660,137 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
)
},
},
{
name: "SourcesToFilter validation failure",
expectedStatusCode: 400,
requestBody: map[string]interface{}{
"Port": port,
"Protocol": protocol,
"TrafficType": trafficType,
"SourcesToFilter": aws.StringSlice([]string{"1.2.3.4", "bad"}),
},
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value bad for parameter SourcesToFilter"),
},
{
name: "TMDS IP is added to SourcesToFilter if needed",
requestBody: map[string]interface{}{
"Port": 80,
"Protocol": protocol,
"TrafficType": "egress",
},
expectedStatusCode: 200,
expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"),
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
Return(happyTaskResponse, nil).
Times(1)
},
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
gomock.InOrder(
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "egress-tcp-80",
"-p", "tcp", "-d", "169.254.170.2", "--dport", "80", "-j", "ACCEPT",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "egress-tcp-80",
"-p", "tcp", "-d", "0.0.0.0/0", "--dport", "80", "-j", "DROP",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
{
name: "Sources to filter are filtered",
requestBody: map[string]interface{}{
"Port": 443,
"Protocol": "udp",
"TrafficType": "ingress",
"SourcesToFilter": []string{"1.2.3.4/20", "8.8.8.8"},
},
expectedStatusCode: 200,
expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"),
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
Return(happyTaskResponse, nil).
Times(1)
},
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
gomock.InOrder(
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
"-p", "udp", "-d", "1.2.3.4/20", "--dport", "443", "-j", "ACCEPT",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
"-p", "udp", "-d", "8.8.8.8", "--dport", "443", "-j", "ACCEPT",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
"-p", "udp", "-d", "0.0.0.0/0", "--dport", "443", "-j", "DROP",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
{
name: "Error when filtering a source",
expectedStatusCode: 500,
requestBody: map[string]interface{}{
"Port": 443,
"Protocol": "udp",
"TrafficType": "ingress",
"SourcesToFilter": []string{"1.2.3.4/20"},
},
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
Return(happyTaskResponse, nil).
Times(1)
},
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
gomock.InOrder(
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(),
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
"-p", "udp", "-d", "1.2.3.4/20", "--dport", "443", "-j", "ACCEPT",
).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")),
)
},
},
}

return append(tcs, commonTcs...)
Expand Down
Loading

0 comments on commit 75ee48f

Please sign in to comment.