Skip to content

Commit

Permalink
Add implementation for fault injection allowlist IP addresses (aws#4373)
Browse files Browse the repository at this point in the history
Co-authored-by: Tianze Shan <[email protected]>
  • Loading branch information
tshan2001 and Tianze Shan authored Oct 2, 2024
1 parent 6296dfc commit 8431d7f
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 121 deletions.
8 changes: 4 additions & 4 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3854,8 +3854,8 @@ func TestRegisterStartLatencyFaultHandler(t *testing.T) {
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD),
mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil),
)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(5).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil)
}
tcs := generateCommonNetworkFaultInjectionTestCases("start latency", "running", setExecExpectations, happyNetworkLatencyReqBody)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.LatencyFaultType, faulttype.StartNetworkFaultPostfix))
Expand Down Expand Up @@ -3898,8 +3898,8 @@ func TestRegisterStartPacketLossFaultHandler(t *testing.T) {
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD),
mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil),
)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(5).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil)
}
tcs := generateCommonNetworkFaultInjectionTestCases("start packet loss", "running", setExecExpectations, happyNetworkPacketLossReqBody)
testRegisterFaultHandler(t, tcs, faulthandler.NetworkFaultPath(faulttype.PacketLossFaultType, faulttype.StartNetworkFaultPostfix))
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

80 changes: 36 additions & 44 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ const (
tcAddQdiscRootCommandString = "tc qdisc add dev %s root handle 1: prio priomap 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2"
tcAddQdiscLatencyCommandString = "tc qdisc add dev %s parent 1:1 handle 10: netem delay %dms %dms"
tcAddQdiscLossCommandString = "tc qdisc add dev %s parent 1:1 handle 10: netem loss %d%%"
tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 1 u32 match ip dst %s flowid 1:1"
tcAllowlistIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 1 u32 match ip dst %s flowid 1:3"
tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 2 u32 match ip dst %s flowid 1:1"
tcDeleteQdiscParentCommandString = "tc qdisc del dev %s parent 1:1 handle 10:"
tcDeleteFilterCommandString = "tc filter del dev %s prio 1"
tcDeleteQdiscRootCommandString = "tc qdisc del dev %s root handle 1: prio"
)

Expand Down Expand Up @@ -1231,19 +1231,13 @@ func (h *FaultHandler) startNetworkLatencyFault(ctx context.Context, taskMetadat
field.CommandOutput: string(cmdOutput[:]),
})
// After creating the queueing discipline, create filters to associate the IPs in the request with the handle.
for _, ip := range request.Sources {
tcAddFilterForIPCommandComposed := nsenterPrefix + fmt.Sprintf(tcAddFilterForIPCommandString, interfaceName, *ip)
cmdList = strings.Split(tcAddFilterForIPCommandComposed, " ")
_, err = h.runExecCommand(ctx, cmdList)
if err != nil {
logger.Error("Command execution failed", logger.Fields{
field.CommandString: tcAddFilterForIPCommandComposed,
field.Error: err,
field.CommandOutput: string(cmdOutput[:]),
field.TaskARN: taskMetadata.TaskARN,
})
return err
}
// First, redirect the allowlisted IP addresses to band 1:3, where there are no network impairments.
if err := h.addIPAddressesToFilter(ctx, request.SourcesToFilter, taskMetadata, nsenterPrefix, tcAllowlistIPCommandString, interfaceName); err != nil {
return err
}
// After processing the allowlisted ips, associate the ip addresses in Sources with the qdisc.
if err := h.addIPAddressesToFilter(ctx, request.Sources, taskMetadata, nsenterPrefix, tcAddFilterForIPCommandString, interfaceName); err != nil {
return err
}

return nil
Expand Down Expand Up @@ -1296,19 +1290,13 @@ func (h *FaultHandler) startNetworkPacketLossFault(ctx context.Context, taskMeta
field.CommandOutput: string(cmdOutput[:]),
})
// After creating the queueing discipline, create filters to associate the IPs in the request with the handle.
for _, ip := range request.Sources {
tcAddFilterForIPCommandComposed := nsenterPrefix + fmt.Sprintf(tcAddFilterForIPCommandString, interfaceName, *ip)
cmdList = strings.Split(tcAddFilterForIPCommandComposed, " ")
_, err = h.runExecCommand(ctx, cmdList)
if err != nil {
logger.Error("Command execution failed", logger.Fields{
field.CommandString: tcAddFilterForIPCommandComposed,
field.Error: err,
field.CommandOutput: string(cmdOutput[:]),
field.TaskARN: taskMetadata.TaskARN,
})
return err
}
// First, redirect the allowlisted IP addresses to band 1:3, where there are no network impairments.
if err := h.addIPAddressesToFilter(ctx, request.SourcesToFilter, taskMetadata, nsenterPrefix, tcAllowlistIPCommandString, interfaceName); err != nil {
return err
}
// After processing the allowlisted ips, associate the ip addresses in Sources with the qdisc.
if err := h.addIPAddressesToFilter(ctx, request.Sources, taskMetadata, nsenterPrefix, tcAddFilterForIPCommandString, interfaceName); err != nil {
return err
}

return nil
Expand Down Expand Up @@ -1345,22 +1333,6 @@ func (h *FaultHandler) stopTCFault(ctx context.Context, taskMetadata *state.Task
field.CommandString: tcDeleteQdiscParentCommandComposed,
field.CommandOutput: string(cmdOutput[:]),
})
tcDeleteFilterCommandComposed := nsenterPrefix + fmt.Sprintf(tcDeleteFilterCommandString, interfaceName)
cmdList = strings.Split(tcDeleteFilterCommandComposed, " ")
cmdOutput, err = h.runExecCommand(ctx, cmdList)
if err != nil {
logger.Error("Command execution failed", logger.Fields{
field.CommandString: tcDeleteFilterCommandComposed,
field.Error: err,
field.CommandOutput: string(cmdOutput[:]),
field.TaskARN: taskMetadata.TaskARN,
})
return err
}
logger.Info("Command execution completed", logger.Fields{
field.CommandString: tcDeleteFilterCommandComposed,
field.CommandOutput: string(cmdOutput[:]),
})
tcDeleteQdiscRootCommandComposed := nsenterPrefix + fmt.Sprintf(tcDeleteQdiscRootCommandString, interfaceName)
cmdList = strings.Split(tcDeleteQdiscRootCommandComposed, " ")
_, err = h.runExecCommand(ctx, cmdList)
Expand Down Expand Up @@ -1463,6 +1435,26 @@ func checkPacketLossFault(outputUnmarshalled []map[string]interface{}) (bool, er
return false, nil
}

// addIPAddressesToFilter runs one command per entry of ipAddressList.
//
// For every entry, commandString is used as a format string that receives
// interfaceName and the entry's string value (in that order); the result is
// prefixed with nsenterPrefix, split on single spaces, and handed to
// runExecCommand.
//
// Entries are processed in slice order. On the first command that fails, the
// failure is logged together with the composed command, its output and the
// task ARN, and the error is returned as-is; the remaining entries are not
// processed. It returns nil when every command succeeds, which includes the
// case of an empty ipAddressList.
func (h *FaultHandler) addIPAddressesToFilter(
	ctx context.Context, ipAddressList []*string, taskMetadata *state.TaskResponse,
	nsenterPrefix, commandString, interfaceName string) error {
	for idx := range ipAddressList {
		// Build the full command line: namespace prefix followed by the formatted command.
		formatted := fmt.Sprintf(commandString, interfaceName, aws.StringValue(ipAddressList[idx]))
		fullCommand := nsenterPrefix + formatted

		output, execErr := h.runExecCommand(ctx, strings.Split(fullCommand, " "))
		if execErr == nil {
			continue
		}

		// Stop at the first failure so the caller can surface it.
		logger.Error("Command execution failed", logger.Fields{
			field.CommandString: fullCommand,
			field.Error:         execErr,
			field.CommandOutput: string(output),
			field.TaskARN:       taskMetadata.TaskARN,
		})
		return execErr
	}
	return nil
}

// runExecCommand wraps around the execwrapper, providing a convenient way of running any Linux command
// and getting the result in both stdout and stderr.
func (h *FaultHandler) runExecCommand(ctx context.Context, cmdList []string) ([]byte, error) {
Expand Down
28 changes: 15 additions & 13 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1216,8 +1216,8 @@ func generateStartNetworkLatencyTestCases() []networkFaultInjectionTestCase {
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD),
mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil),
)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(5).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil)
},
},
{
Expand Down Expand Up @@ -1259,6 +1259,7 @@ func generateStartNetworkLatencyTestCases() []networkFaultInjectionTestCase {
"DelayMilliseconds": delayMilliseconds,
"JitterMilliseconds": jitterMilliseconds,
"Sources": ipSources,
"SourcesToFilter": []string{},
"Unknown": "",
},
expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"),
Expand Down Expand Up @@ -1318,8 +1319,8 @@ func generateStopNetworkLatencyTestCases() []networkFaultInjectionTestCase {
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD),
mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLatencyFaultExistsCommandOutput), nil),
)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(3).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(3).Return([]byte(""), nil)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(2).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(2).Return([]byte(""), nil)
},
},
{
Expand Down Expand Up @@ -1423,7 +1424,7 @@ func generateCheckNetworkLatencyTestCases() []networkFaultInjectionTestCase {
},
},
{
name: "unknown-request-body-no-existing-fault",
name: "unknown-request-body-no-existing-fault-no-allowlist-filter",
expectedStatusCode: 200,
requestBody: map[string]interface{}{
"DelayMilliseconds": delayMilliseconds,
Expand Down Expand Up @@ -1764,8 +1765,8 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase {
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD),
mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcCommandEmptyOutput), nil),
)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(4).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(4).Return([]byte(tcCommandEmptyOutput), nil)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(5).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(5).Return([]byte(tcCommandEmptyOutput), nil)
},
},
{
Expand Down Expand Up @@ -1801,12 +1802,13 @@ func generateStartNetworkPacketLossTestCases() []networkFaultInjectionTestCase {
},
},
{
name: "unknown-request-body-no-existing-fault",
name: "unknown-request-body-no-existing-fault-no-allowlist-filter",
expectedStatusCode: 200,
requestBody: map[string]interface{}{
"LossPercent": lossPercent,
"Sources": ipSources,
"Unknown": "",
"LossPercent": lossPercent,
"Sources": ipSources,
"SourcesToFilter": []string{},
"Unknown": "",
},
expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"),
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
Expand Down Expand Up @@ -1881,8 +1883,8 @@ func generateStopNetworkPacketLossTestCases() []networkFaultInjectionTestCase {
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(mockCMD),
mockCMD.EXPECT().CombinedOutput().Times(1).Return([]byte(tcLossFaultExistsCommandOutput), nil),
)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(3).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(3).Return([]byte(""), nil)
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(2).Return(mockCMD)
mockCMD.EXPECT().CombinedOutput().Times(2).Return([]byte(""), nil)
},
},
{
Expand Down
Loading

0 comments on commit 8431d7f

Please sign in to comment.