From 9d547bfc2cd8e553099e9c0f5ee359a158c71e64 Mon Sep 17 00:00:00 2001 From: Ivan Styazhkin Date: Thu, 29 Feb 2024 06:21:39 -0800 Subject: [PATCH] Capture OS signals while steps executed and perform cleanups as graceful shutdown Summary: Here is my ~~first~~ new attempt to use channels to keep track of step results, errors or OS signals. # Design overview We could use channels to communicate results and errors produced from TTP steps executed. Having this, we might also check additional channel to check if an OS signal was received. This requires changes in the following parts: 0. Steps execution 0. Step results retrieving 0. Clean up execution for steps completed only 0. Sub-TTP processing 0. Proper setup of execution context object per sub-ttp 0. Postponing clean up execution till all steps of root TTP are completed or an error encountered Resolves https://github.com/facebookincubator/TTPForge/issues/476 Reviewed By: d3sch41n Differential Revision: D54117726 fbshipit-source-id: aed5abe03534679f0c74c542e0607607ac149a45 --- cmd/run.go | 10 +- cmd/run_test.go | 3 +- pkg/blocks/basicstep_test.go | 4 +- pkg/blocks/context.go | 19 +++- pkg/blocks/loader.go | 15 ++- pkg/blocks/printstr.go | 5 + pkg/blocks/requirements_test.go | 4 +- pkg/blocks/signal_handler.go | 61 ++++++++++ pkg/blocks/step.go | 43 ++++++- pkg/blocks/step_test.go | 4 +- pkg/blocks/subttp.go | 23 ++-- pkg/blocks/subttp_test.go | 8 +- pkg/blocks/subttpcleanup.go | 2 +- pkg/blocks/ttps.go | 192 +++++++++++++++++++------------- pkg/blocks/ttps_test.go | 9 +- 15 files changed, 285 insertions(+), 117 deletions(-) create mode 100644 pkg/blocks/signal_handler.go diff --git a/cmd/run.go b/cmd/run.go index 4c441ed9..92ca244f 100755 --- a/cmd/run.go +++ b/cmd/run.go @@ -64,7 +64,15 @@ func buildRunCommand(cfg *Config) *cobra.Command { return nil } - if _, err := ttp.Execute(execCtx); err != nil { + runErr := ttp.Execute(*execCtx) + // Run clean up always + cleanupErr := ttp.RunCleanup(*execCtx) + + if cleanupErr != nil { + logging.L().Warnf("Failed to run cleanup: %v", cleanupErr) + } + + if runErr != nil { return fmt.Errorf("failed to run TTP at %v: %v", ttpAbsPath, err) } return nil diff --git a/cmd/run_test.go b/cmd/run_test.go index 36a9c4ee..64709a13 100644 --- a/cmd/run_test.go +++ b/cmd/run_test.go @@ -58,7 +58,6 @@ func checkRunCmdTestCase(t *testing.T, tc runCmdTestCase) { } require.NoError(t, err) assert.Equal(t, tc.expectedStdout, stdoutBuf.String()) - } func TestRun(t *testing.T) { @@ -95,7 +94,7 @@ func TestRun(t *testing.T) { }, { name: "subttp-cleanup", - description: "verify that execution of a subTTP with cleanup succeeds", + description: "when one of subTTP causes failures then cleanups executed in right order", args: []string{ "-c", testConfigFilePath, diff --git a/pkg/blocks/basicstep_test.go b/pkg/blocks/basicstep_test.go index 4d727aa4..356c395c 100755 --- a/pkg/blocks/basicstep_test.go +++ b/pkg/blocks/basicstep_test.go @@ -84,10 +84,10 @@ outputs: filters: - json_path: foo.bar` var s BasicStep - var execCtx TTPExecutionContext + execCtx := NewTTPExecutionContext() err := yaml.Unmarshal([]byte(content), &s) require.NoError(t, err) - err = s.Validate(TTPExecutionContext{}) + err = s.Validate(execCtx) require.NoError(t, err) // execute and check result diff --git a/pkg/blocks/context.go b/pkg/blocks/context.go index 9b562fae..ec127006 100644 --- a/pkg/blocks/context.go +++ b/pkg/blocks/context.go @@ -43,9 +43,22 @@ type TTPExecutionConfig struct { // TTPExecutionContext - holds config and context for the currently executing TTP type TTPExecutionContext struct { - Cfg TTPExecutionConfig - WorkDir string - StepResults *StepResultsRecord + Cfg TTPExecutionConfig + WorkDir string + StepResults *StepResultsRecord + actionResultsChan chan *ActResult + errorsChan chan error + shutdownChan chan bool +} + +// NewTTPExecutionContext creates a new TTPExecutionContext with empty config and created channels +func NewTTPExecutionContext() TTPExecutionContext { + return TTPExecutionContext{ + StepResults: NewStepResultsRecord(), + actionResultsChan: make(chan *ActResult, 1), + errorsChan: make(chan error, 1), + shutdownChan: SetupSignalHandler(), + } } // ExpandVariables takes a string containing the following types of variables diff --git a/pkg/blocks/loader.go b/pkg/blocks/loader.go index 66aec585..f2755971 100755 --- a/pkg/blocks/loader.go +++ b/pkg/blocks/loader.go @@ -149,16 +149,21 @@ func LoadTTP(ttpFilePath string, fsys afero.Fs, execCfg *TTPExecutionConfig, arg } ttp.WorkDir = wd } - execCtx := &TTPExecutionContext{ - Cfg: *execCfg, - WorkDir: ttp.WorkDir, + + execCtx := TTPExecutionContext{ + Cfg: *execCfg, + WorkDir: ttp.WorkDir, + StepResults: NewStepResultsRecord(), + actionResultsChan: make(chan *ActResult, 1), + errorsChan: make(chan error, 1), + shutdownChan: SetupSignalHandler(), } - err = ttp.Validate(*execCtx) + err = ttp.Validate(execCtx) if err != nil { return nil, nil, err } - return ttp, execCtx, nil + return ttp, &execCtx, nil } func readTTPBytes(ttpFilePath string, system afero.Fs) ([]byte, error) { diff --git a/pkg/blocks/printstr.go b/pkg/blocks/printstr.go index fcee1a62..1e4d582a 100755 --- a/pkg/blocks/printstr.go +++ b/pkg/blocks/printstr.go @@ -32,6 +32,11 @@ type PrintStrAction struct { Message string `yaml:"print_str,omitempty"` } +// NewPrintStrAction creates a new PrintStrAction. +func NewPrintStrAction() *PrintStrAction { + return &PrintStrAction{} +} + // IsNil checks if the step is nil or empty and returns a boolean value. func (s *PrintStrAction) IsNil() bool { switch { diff --git a/pkg/blocks/requirements_test.go b/pkg/blocks/requirements_test.go index e368c54e..e0417545 100644 --- a/pkg/blocks/requirements_test.go +++ b/pkg/blocks/requirements_test.go @@ -81,7 +81,7 @@ steps: var ttp TTP err := yaml.Unmarshal([]byte(tc.content), &ttp) require.NoError(t, err) - var ctx TTPExecutionContext + ctx := NewTTPExecutionContext() err = ttp.Validate(ctx) if tc.expectValidateError { require.Error(t, err) @@ -89,7 +89,7 @@ steps: } require.NoError(t, err) - _, err = ttp.Execute(&ctx) + err = ttp.Execute(ctx) if tc.expectExecuteError { assert.Error(t, err) return diff --git a/pkg/blocks/signal_handler.go b/pkg/blocks/signal_handler.go new file mode 100644 index 00000000..7cfaed10 --- /dev/null +++ b/pkg/blocks/signal_handler.go @@ -0,0 +1,61 @@ +/* +Copyright © 2023-present, Meta Platforms, Inc. and affiliates +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +package blocks + +import ( + "os" + "os/signal" + "sync" + "syscall" + + "github.com/facebookincubator/ttpforge/pkg/logging" +) + +var signalHandlerInstalled bool +var signalHandlerLock = sync.Mutex{} +var shutdownChan chan bool + +// SetupSignalHandler sets up SIGINT and SIGTERM handlers for graceful shutdown +func SetupSignalHandler() chan bool { + // setup signal handling only once + signalHandlerLock.Lock() + if signalHandlerInstalled { + signalHandlerLock.Unlock() + return shutdownChan + } + sigs := make(chan os.Signal, 2) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + shutdownChan = make(chan bool, 1) + signalHandlerInstalled = true + signalHandlerLock.Unlock() + + go func() { + var sig os.Signal + var counter int + for { + sig = <-sigs + logging.L().Infof("[%v] Received signal %v, shutting down now", counter, sig) + shutdownChan <- true + counter++ + } + }() + + return shutdownChan +} diff --git a/pkg/blocks/step.go b/pkg/blocks/step.go index e56ef317..b58957ec 100644 --- a/pkg/blocks/step.go +++ b/pkg/blocks/step.go @@ -25,6 +25,7 @@ import ( "github.com/facebookincubator/ttpforge/pkg/checks" "github.com/facebookincubator/ttpforge/pkg/logging" + "github.com/spf13/afero" "gopkg.in/yaml.v3" ) @@ -151,13 +152,21 @@ func (s *Step) UnmarshalYAML(node *yaml.Node) error { return nil } -// Execute runs the action associated with this step +// Execute runs the action associated with this step and sends result/error to channels of the context func (s *Step) Execute(execCtx TTPExecutionContext) (*ActResult, error) { desc := s.action.GetDescription() if desc != "" { logging.L().Infof("Description: %v", desc) } - return s.action.Execute(execCtx) + result, err := s.action.Execute(execCtx) + if err != nil { + logging.L().Errorf("Failed to execute step %v: %v", s.Name, err) + execCtx.errorsChan <- err + } else { + logging.L().Debugf("Successfully executed step %v", s.Name) + execCtx.actionResultsChan <- result + } + return result, err } // Cleanup runs the cleanup action associated with this step @@ -190,7 +199,17 @@ func (s *Step) Validate(execCtx TTPExecutionContext) error { // ParseAction decodes an action (from step or cleanup) in YAML // format into the appropriate struct func (s *Step) ParseAction(node *yaml.Node) (Action, error) { - actionCandidates := []Action{NewBasicStep(), NewFileStep(), NewSubTTPStep(), NewEditStep(), NewFetchURIStep(), NewCreateFileStep(), NewCopyPathStep(), NewRemovePathAction(), &PrintStrAction{}} + actionCandidates := []Action{ + NewBasicStep(), + NewFileStep(), + NewSubTTPStep(), + NewEditStep(), + NewFetchURIStep(), + NewCreateFileStep(), + NewCopyPathStep(), + NewRemovePathAction(), + NewPrintStrAction(), + } var action Action for _, actionType := range actionCandidates { err := node.Decode(actionType) @@ -214,3 +233,21 @@ func (s *Step) ParseAction(node *yaml.Node) (Action, error) { } return action, nil } + +// VerifyChecks runs all checks and returns an error if any of them fail +func (s *Step) VerifyChecks() error { + if len(s.Checks) == 0 { + logging.L().Debugf("No checks defined for step %v", s.Name) + return nil + } + verificationCtx := checks.VerificationContext{ + FileSystem: afero.NewOsFs(), + } + for checkIdx, check := range s.Checks { + if err := check.Verify(verificationCtx); err != nil { + return fmt.Errorf("Success check %d of step %q failed: %w", checkIdx+1, s.Name, err) + } + logging.L().Debugf("Success check %d (%q) of step %q PASSED", checkIdx+1, check.Msg, s.Name) + } + return nil +} diff --git a/pkg/blocks/step_test.go b/pkg/blocks/step_test.go index ff876dcc..af4d0a9e 100644 --- a/pkg/blocks/step_test.go +++ b/pkg/blocks/step_test.go @@ -112,7 +112,7 @@ cleanup: default`, for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var s Step - var execCtx TTPExecutionContext + execCtx := NewTTPExecutionContext() // parse the step err := yaml.Unmarshal([]byte(tc.content), &s) @@ -191,7 +191,7 @@ cleanup: for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var s Step - var execCtx TTPExecutionContext + execCtx := NewTTPExecutionContext() // hack to get a valid temporary path without creating it tmpFile, err := os.CreateTemp("", "ttpforge-test-cleanup-default") diff --git a/pkg/blocks/subttp.go b/pkg/blocks/subttp.go index cef6220f..4ba490db 100644 --- a/pkg/blocks/subttp.go +++ b/pkg/blocks/subttp.go @@ -32,9 +32,8 @@ type SubTTPStep struct { TtpRef string `yaml:"ttp"` Args map[string]string `yaml:"args"` - ttp *TTP - subExecCtx TTPExecutionContext - firstStepToCleanupIdx int + ttp *TTP + subExecCtx *TTPExecutionContext } // NewSubTTPStep creates a new SubTTPStep and returns a pointer to it. @@ -78,19 +77,17 @@ func (s *SubTTPStep) processSubTTPArgs(execCtx TTPExecutionContext) ([]string, e } // Execute runs each step of the TTP file associated with the SubTTPStep -// and manages the outputs and cleanup steps. func (s *SubTTPStep) Execute(execCtx TTPExecutionContext) (*ActResult, error) { logging.L().Infof("[*] Executing Sub TTP: %s", s.TtpRef) - execResults, firstStepToCleanupIdx, runErr := s.ttp.RunSteps(&execCtx) - s.firstStepToCleanupIdx = firstStepToCleanupIdx + runErr := s.ttp.RunSteps(*s.subExecCtx) if runErr != nil { - return nil, runErr + return &ActResult{}, runErr } logging.L().Info("[*] Completed SubTTP - No Errors :)") - // just a little annoying plumbing due to subtle type differences0 + // just a little annoying plumbing due to subtle type differences var actResults []*ActResult - for _, execResult := range execResults.ByIndex { + for _, execResult := range s.subExecCtx.StepResults.ByIndex { actResults = append(actResults, &execResult.ActResult) } return aggregateResults(actResults), nil @@ -100,7 +97,7 @@ func (s *SubTTPStep) Execute(execCtx TTPExecutionContext) (*ActResult, error) { // and validates the contained steps. func (s *SubTTPStep) loadSubTTP(execCtx TTPExecutionContext) error { repo := execCtx.Cfg.Repo - subTTPAbsPath, err := execCtx.Cfg.Repo.FindTTP(s.TtpRef) + subTTPAbsPath, err := repo.FindTTP(s.TtpRef) if err != nil { return err } @@ -110,11 +107,13 @@ func (s *SubTTPStep) loadSubTTP(execCtx TTPExecutionContext) error { return err } - ttps, _, err := LoadTTP(subTTPAbsPath, repo.GetFs(), &s.subExecCtx.Cfg, subArgsKv) + ttps, ctx, err := LoadTTP(subTTPAbsPath, repo.GetFs(), &execCtx.Cfg, subArgsKv) if err != nil { return err } s.ttp = ttps + s.subExecCtx = ctx + return nil } @@ -144,5 +143,5 @@ func (s *SubTTPStep) Validate(execCtx TTPExecutionContext) error { return err } - return s.ttp.Validate(execCtx) + return s.ttp.Validate(*s.subExecCtx) } diff --git a/pkg/blocks/subttp_test.go b/pkg/blocks/subttp_test.go index 2ac633ce..1da59c6e 100755 --- a/pkg/blocks/subttp_test.go +++ b/pkg/blocks/subttp_test.go @@ -125,11 +125,11 @@ ttp: with/cleanup.yaml`, repo, err := tc.spec.Load(tc.fsys, "") require.NoError(t, err) - execCtx := TTPExecutionContext{ - Cfg: TTPExecutionConfig{ - Repo: repo, - }, + execCtx := NewTTPExecutionContext() + execCtx.Cfg = TTPExecutionConfig{ + Repo: repo, } + err = step.Validate(execCtx) require.NoError(t, err, "step failed to validate") diff --git a/pkg/blocks/subttpcleanup.go b/pkg/blocks/subttpcleanup.go index 30b57bd9..2675a754 100644 --- a/pkg/blocks/subttpcleanup.go +++ b/pkg/blocks/subttpcleanup.go @@ -28,7 +28,7 @@ type subTTPCleanupAction struct { // Execute will cleanup the subTTP starting from the last successful step func (a *subTTPCleanupAction) Execute(execCtx TTPExecutionContext) (*ActResult, error) { - cleanupResults, err := a.step.ttp.startCleanupAtStepIdx(a.step.firstStepToCleanupIdx, &execCtx) + cleanupResults, err := a.step.ttp.startCleanupForCompletedSteps(*a.step.subExecCtx) if err != nil { return nil, err } diff --git a/pkg/blocks/ttps.go b/pkg/blocks/ttps.go index 72d39875..84ab3bb8 100755 --- a/pkg/blocks/ttps.go +++ b/pkg/blocks/ttps.go @@ -29,7 +29,6 @@ import ( "github.com/facebookincubator/ttpforge/pkg/checks" "github.com/facebookincubator/ttpforge/pkg/logging" "github.com/facebookincubator/ttpforge/pkg/platforms" - "github.com/spf13/afero" "gopkg.in/yaml.v3" ) @@ -149,6 +148,7 @@ func (t *TTP) chdir() (func(), error) { // note: t.WorkDir may not be set in tests but should // be set when actually using `ttpforge run` if t.WorkDir == "" { + logging.L().Info("Not changing working directory in tests") return func() {}, nil } origDir, err := os.Getwd() @@ -165,6 +165,17 @@ func (t *TTP) chdir() (func(), error) { }, nil } +// verify that we actually meet the necessary requirements to execute this TTP +func (t *TTP) verifyPlatform() error { + verificationCtx := checks.VerificationContext{ + Platform: platforms.Spec{ + OS: runtime.GOOS, + Arch: runtime.GOARCH, + }, + } + return t.Requirements.Verify(verificationCtx) +} + // Execute executes all of the steps in the given TTP, // then runs cleanup if appropriate // @@ -176,45 +187,14 @@ func (t *TTP) chdir() (func(), error) { // // *StepResultsRecord: A StepResultsRecord containing the results of each step. // error: An error if any of the steps fail to execute. -func (t *TTP) Execute(execCtx *TTPExecutionContext) (*StepResultsRecord, error) { +func (t *TTP) Execute(execCtx TTPExecutionContext) error { logging.L().Infof("RUNNING TTP: %v", t.Name) - // verify that we actually meet the necessary requirements to execute this TTP - verificationCtx := checks.VerificationContext{ - Platform: platforms.Spec{ - OS: runtime.GOOS, - Arch: runtime.GOARCH, - }, - } - if err := t.Requirements.Verify(verificationCtx); err != nil { - return nil, fmt.Errorf("TTP requirements not met: %w", err) + if err := t.verifyPlatform(); err != nil { + return fmt.Errorf("TTP requirements not met: %w", err) } - stepResults, firstStepToCleanupIdx, runErr := t.RunSteps(execCtx) - logging.DividerThin() - if runErr != nil { - // we need to run cleanup so we don't return here - logging.L().Errorf("[*] Error executing TTP: %v", runErr) - } else { - logging.L().Info("TTP Completed Successfully! ✅") - } - if !execCtx.Cfg.NoCleanup { - if execCtx.Cfg.CleanupDelaySeconds > 0 { - logging.L().Infof("[*] Sleeping for Requested Cleanup Delay of %v Seconds", execCtx.Cfg.CleanupDelaySeconds) - time.Sleep(time.Duration(execCtx.Cfg.CleanupDelaySeconds) * time.Second) - } - cleanupResults, err := t.startCleanupAtStepIdx(firstStepToCleanupIdx, execCtx) - if err != nil { - return nil, err - } - // since ByIndex and ByName both contain pointers to - // the same underlying struct, this will update both - for cleanupIdx, cleanupResult := range cleanupResults { - execCtx.StepResults.ByIndex[cleanupIdx].Cleanup = cleanupResult - } - } - // still pass up the run error after our cleanup - return stepResults, runErr + return t.RunSteps(execCtx) } // RunSteps executes all of the steps in the given TTP. @@ -225,69 +205,126 @@ func (t *TTP) Execute(execCtx *TTPExecutionContext) (*StepResultsRecord, error) // // **Returns:** // -// *StepResultsRecord: A StepResultsRecord containing the results of each step. -// int: the index of the step where cleanup should start (usually the last successful step) // error: An error if any of the steps fail to execute. -func (t *TTP) RunSteps(execCtx *TTPExecutionContext) (*StepResultsRecord, int, error) { +func (t *TTP) RunSteps(execCtx TTPExecutionContext) error { // go to the configuration directory for this TTP changeBack, err := t.chdir() if err != nil { - return nil, -1, err + return err } defer changeBack() + var stepError error + var verifyError error + var shutdownFlag bool + // actually run all the steps - stepResults := NewStepResultsRecord() - execCtx.StepResults = stepResults - firstStepToCleanupIdx := -1 for stepIdx, step := range t.Steps { - stepCopy := step logging.DividerThin() logging.L().Infof("Executing Step #%d: %q", stepIdx+1, step.Name) - // core execution - run the step action - stepResult, err := stepCopy.Execute(*execCtx) + go func(step Step) { + _, err := step.Execute(execCtx) + if err != nil { + // This error was logged by the step itself + logging.L().Debugf("Error executing step %s: %v", step.Name, err) + } + }(step) - // this part is tricky - SubTTP steps - // must be cleaned up even on failure - // (because substeps may have succeeded) - // so in those cases, we need to save the result - // even if nil - if err != nil { + // await one of three outcomes: + // 1. step execution successful + // 2. step execution failed + // 3. shutdown signal received + select { + case stepResult := <-execCtx.actionResultsChan: + // step execution successful - record results + execResult := &ExecutionResult{ + ActResult: *stepResult, + } + execCtx.StepResults.ByName[step.Name] = execResult + execCtx.StepResults.ByIndex = append(execCtx.StepResults.ByIndex, execResult) + + case stepError = <-execCtx.errorsChan: + // this part is tricky - SubTTP steps + // must be cleaned up even on failure + // (because substeps may have succeeded) + // so in those cases, we need to save the result + // even if nil if step.ShouldCleanupOnFailure() { logging.L().Infof("[+] Cleaning up failed step %s", step.Name) logging.L().Infof("[+] Full Cleanup will Run Afterward") - _, cleanupErr := step.Cleanup(*execCtx) + _, cleanupErr := step.Cleanup(execCtx) if cleanupErr != nil { - logging.L().Errorf("error cleaning up failed step %v: %v", step.Name, err) + logging.L().Errorf("Error cleaning up failed step %v: %v", step.Name, cleanupErr) } } - return nil, firstStepToCleanupIdx, err + + case shutdownFlag = <-execCtx.shutdownChan: + // TODO[nesusvet]: We should propagate signal to child processes if any + logging.L().Warn("Shutting down due to signal received") } // if the user specified custom success checks, run them now - verificationCtx := checks.VerificationContext{ - FileSystem: afero.NewOsFs(), - } - for checkIdx, check := range step.Checks { - if err := check.Verify(verificationCtx); err != nil { - return nil, firstStepToCleanupIdx, fmt.Errorf("success check %d of step %q failed: %w", checkIdx+1, step.Name, err) - } - logging.L().Debugf("Success check %d (%q) of step %q PASSED", checkIdx+1, check.Msg, step.Name) - } + verifyError = step.VerifyChecks() - // step execution successful - record results - firstStepToCleanupIdx++ - execResult := &ExecutionResult{ - ActResult: *stepResult, + if stepError != nil || verifyError != nil || shutdownFlag { + logging.L().Debug("[*] Stopping TTP Early") + break } - stepResults.ByName[step.Name] = execResult - stepResults.ByIndex = append(stepResults.ByIndex, execResult) } - return stepResults, firstStepToCleanupIdx, nil + + logging.DividerThin() + if stepError != nil { + logging.L().Errorf("[*] Error executing TTP: %v", stepError) + return stepError + } + if verifyError != nil { + logging.L().Errorf("[*] Error verifying TTP: %v", verifyError) + return verifyError + } + if shutdownFlag { + return fmt.Errorf("[*] Shutting Down now") + } + + logging.L().Info("All steps completed successfully! ✅") + return nil +} + +// RunCleanup executes all required cleanup for steps in the given TTP. +// +// **Parameters:** +// +// execCtx: The current TTPExecutionContext +// +// **Returns:** +// +// error: An error if any of the clean ups fail to execute. +func (t *TTP) RunCleanup(execCtx TTPExecutionContext) error { + if execCtx.Cfg.NoCleanup { + logging.L().Info("[*] Skipping Cleanup as requested by Config") + return nil + } + + if execCtx.Cfg.CleanupDelaySeconds > 0 { + logging.L().Infof("[*] Sleeping for Requested Cleanup Delay of %v Seconds", execCtx.Cfg.CleanupDelaySeconds) + time.Sleep(time.Duration(execCtx.Cfg.CleanupDelaySeconds) * time.Second) + } + + // TODO[nesusvet]: We also should catch signals in clean ups + cleanupResults, err := t.startCleanupForCompletedSteps(execCtx) + if err != nil { + return err + } + // since ByIndex and ByName both contain pointers to + // the same underlying struct, this will update both + for cleanupIdx, cleanupResult := range cleanupResults { + execCtx.StepResults.ByIndex[cleanupIdx].Cleanup = cleanupResult + } + + return nil } -func (t *TTP) startCleanupAtStepIdx(firstStepToCleanupIdx int, execCtx *TTPExecutionContext) ([]*ActResult, error) { +func (t *TTP) startCleanupForCompletedSteps(execCtx TTPExecutionContext) ([]*ActResult, error) { // go to the configuration directory for this TTP changeBack, err := t.chdir() if err != nil { @@ -296,15 +333,16 @@ func (t *TTP) startCleanupAtStepIdx(firstStepToCleanupIdx int, execCtx *TTPExecu defer changeBack() logging.DividerThick() - logging.L().Infof("CLEANING UP TTP: %q", t.Name) - var cleanupResults []*ActResult - for cleanupIdx := firstStepToCleanupIdx; cleanupIdx >= 0; cleanupIdx-- { + n := len(execCtx.StepResults.ByIndex) + logging.L().Infof("CLEANING UP %v steps of TTP: %q", n, t.Name) + cleanupResults := make([]*ActResult, n) + for cleanupIdx := n - 1; cleanupIdx >= 0; cleanupIdx-- { stepToCleanup := t.Steps[cleanupIdx] logging.DividerThin() logging.L().Infof("Cleaning Up Step #%d: %q", cleanupIdx+1, stepToCleanup.Name) - cleanupResult, err := stepToCleanup.Cleanup(*execCtx) + cleanupResult, err := stepToCleanup.Cleanup(execCtx) // must be careful to put these in step order, not in execution (reverse) order - cleanupResults = append([]*ActResult{cleanupResult}, cleanupResults...) + cleanupResults[cleanupIdx] = cleanupResult if err != nil { logging.L().Errorf("error cleaning up step: %v", err) logging.L().Errorf("will continue to try to cleanup other steps") diff --git a/pkg/blocks/ttps_test.go b/pkg/blocks/ttps_test.go index dc5d3988..dffed897 100755 --- a/pkg/blocks/ttps_test.go +++ b/pkg/blocks/ttps_test.go @@ -329,18 +329,20 @@ steps: return } + execCtx := NewTTPExecutionContext() // validate the TTP - err = ttp.Validate(TTPExecutionContext{}) + err = ttp.Validate(execCtx) require.NoError(t, err) // run it - stepResults, err := ttp.Execute(&TTPExecutionContext{}) + err = ttp.Execute(execCtx) if tc.wantError { require.Error(t, err) return } require.NoError(t, err) + stepResults := execCtx.StepResults for index, output := range tc.expectedByIndexOut { require.Equal(t, output, stepResults.ByIndex[index].Stdout) } @@ -391,7 +393,8 @@ mitre: var ttp TTP err := yaml.Unmarshal([]byte(tc.content), &ttp) require.NoError(t, err) - err = ttp.Validate(TTPExecutionContext{}) + execCtx := NewTTPExecutionContext() + err = ttp.Validate(execCtx) if tc.wantError { assert.Error(t, err) } else {