diff --git a/convey.go b/convey.go index 5c8f307..7e7bd3e 100644 --- a/convey.go +++ b/convey.go @@ -61,3 +61,33 @@ func removeFromGlobal(mocker mockerInstance) { tool.DebugPrintf("%v removed\n", mocker.key()) delete(gMocker[len(gMocker)-1], mocker.key()) } + +// Unpatch all mocks in current 'PatchConvey' context +// +// If the caller is out of 'PatchConvey', it will unpatch all mocks +// +// For example: +// +// Test1(t) { +// Mock(a).Build() +// Mock(b).Build() +// +// // a and b will be unpatched +// UnpatchAll() +// } +// }) +// +// Test2(t) { +// Mock(a).Build() +// PatchConvey(t,func(){ +// Mock(b).Build() +// +// // only b will be unpatched +// UnpatchAll() +// } +// }) +func UnPatchAll() { + for _, mocker := range gMocker[len(gMocker)-1] { + mocker.unPatch() + } +} diff --git a/convey_test.go b/convey_test.go index fbdfbe9..e5a65bc 100644 --- a/convey_test.go +++ b/convey_test.go @@ -98,3 +98,46 @@ func TestConvey(t *testing.T) { }) }) } + +func TestUnpatchAll_Convey(t *testing.T) { + fn1 := func() string { + return "fn1" + } + fn2 := func() string { + return "fn2" + } + fn3 := func() string { + return "fn3" + } + + Mock(fn1).Return("mocked").Build() + if fn1() != "mocked" { + t.Error("mock fn1 failed") + } + + PatchConvey("UnpatchAll_Convey", t, func() { + Mock(fn2).Return("mocked").Build() + Mock(fn3).Return("mocked").Build() + convey.So(fn1(), convey.ShouldEqual, "mocked") + convey.So(fn2(), convey.ShouldEqual, "mocked") + convey.So(fn3(), convey.ShouldEqual, "mocked") + + UnPatchAll() + + convey.So(fn1(), convey.ShouldEqual, "mocked") + convey.So(fn2(), convey.ShouldEqual, "fn2") + convey.So(fn3(), convey.ShouldEqual, "fn3") + }) + + r1, r2, r3 := fn1(), fn2(), fn3() + if r1 != "mocked" || r2 != "fn2" || r3 != "fn3" { + t.Error("mock failed", r1, r2, r3) + } + + UnPatchAll() + + r1, r2, r3 = fn1(), fn2(), fn3() + if r1 != "fn1" || r2 != "fn2" || r3 != "fn3" { + t.Error("mock failed", r1, r2, r3) + } +}