diff --git a/tester/tester.go b/tester/tester.go index f5e944b0..471ee4af 100644 --- a/tester/tester.go +++ b/tester/tester.go @@ -541,6 +541,33 @@ func (wt *workflowTester[TResult]) sendEvent(wfi *core.WorkflowInstance, event * w.pendingEvents = append(w.pendingEvents, event) } +// CancelWorkflow cancels the workflow under test. +func (wt *workflowTester[TResult]) CancelWorkflow() { + _ = wt.CancelWorkflowInstance(wt.wfi) +} + +// CancelWorkflowInstance cancels the given workflow instance. +func (wt *workflowTester[TResult]) CancelWorkflowInstance(wfi *core.WorkflowInstance) error { + if wt.getWorkflow(wfi) == nil { + return backend.ErrInstanceNotFound + } + + wt.callbacks <- func() *history.WorkflowEvent { + e := history.NewPendingEvent( + wt.clock.Now(), + history.EventType_WorkflowExecutionCanceled, + &history.ExecutionCanceledAttributes{}, + ) + + return &history.WorkflowEvent{ + WorkflowInstance: wfi, + HistoryEvent: e, + } + } + + return nil +} + // SignalWorkflow sends a signal to the workflow under test. func (wt *workflowTester[TResult]) SignalWorkflow(name string, value any) { wt.SignalWorkflowInstance(wt.wfi, name, value) diff --git a/tester/tester_subworkflow_test.go b/tester/tester_subworkflow_test.go index 489ab685..72168c03 100644 --- a/tester/tester_subworkflow_test.go +++ b/tester/tester_subworkflow_test.go @@ -3,6 +3,7 @@ package tester import ( "context" "errors" + "fmt" "testing" "time" @@ -109,6 +110,47 @@ func Test_SubWorkflow_Mocked_Failure(t *testing.T) { tester.AssertExpectations(t) } +func Test_SubWorkflow_Cancel(t *testing.T) { + subWorkflow := func(ctx workflow.Context) error { + _, _ = ctx.Done().Receive(ctx) + return ctx.Err() + } + + workflowWithSub := func(ctx workflow.Context) error { + _, err := workflow.CreateSubWorkflowInstance[any]( + ctx, + workflow.DefaultSubWorkflowOptions, + subWorkflow, + ).Get(ctx) + if err != nil { + return fmt.Errorf("subworkflow: %w", err) + } + + return nil + } + + tester := NewWorkflowTester[string](workflowWithSub) + tester.Registry().RegisterWorkflow(subWorkflow) + + var subWorkflowInstance *core.WorkflowInstance + + tester.ListenSubWorkflow(func(instance *core.WorkflowInstance, _ string) { + subWorkflowInstance = instance + }) + + tester.ScheduleCallback(time.Millisecond, func() { + require.NoError(t, tester.CancelWorkflowInstance(subWorkflowInstance)) + }) + + tester.Execute(context.Background()) + + require.True(t, tester.WorkflowFinished()) + + _, err := tester.WorkflowResult() + require.EqualError(t, err, "subworkflow: context canceled") + tester.AssertExpectations(t) +} + func Test_SubWorkflow_Signals(t *testing.T) { subWorkflow := func(ctx workflow.Context, input string) (string, error) { c := workflow.NewSignalChannel[string](ctx, "subworkflow-signal") diff --git a/tester/tester_test.go b/tester/tester_test.go index 3aa745ab..f67ab0fd 100644 --- a/tester/tester_test.go +++ b/tester/tester_test.go @@ -161,6 +161,22 @@ func activityLongRunning(ctx context.Context) (int, error) { return 42, nil } +func Test_CancelWorkflow(t *testing.T) { + tester := NewWorkflowTester[any](func(ctx workflow.Context) error { + _, _ = ctx.Done().Receive(ctx) + return ctx.Err() + }) + tester.ScheduleCallback(time.Duration(time.Second), func() { + tester.CancelWorkflow() + }) + + tester.Execute(context.Background()) + + require.True(t, tester.WorkflowFinished()) + _, err := tester.WorkflowResult() + require.EqualError(t, err, "context canceled") +} + func Test_Signals(t *testing.T) { tester := NewWorkflowTester[string](workflowSignal) tester.ScheduleCallback(time.Duration(5*time.Second), func() {