Skip to content

Commit

Permalink
Merge pull request #166 from wu-cl/master
Browse files Browse the repository at this point in the history
feat: add origin wrapper to execute functions before mock
  • Loading branch information
agiledragon authored Jun 11, 2024
2 parents d56c682 + 1047ffc commit 248313f
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
17 changes: 16 additions & 1 deletion patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

type Patches struct {
originals map[uintptr][]byte
targets map[uintptr]uintptr
values map[reflect.Value]reflect.Value
valueHolders map[reflect.Value]reflect.Value
}
Expand Down Expand Up @@ -70,13 +71,25 @@ func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches {
}

func create() *Patches {
return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)}
return &Patches{originals: make(map[uintptr][]byte), targets: map[uintptr]uintptr{},
values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)}
}

func NewPatches() *Patches {
return create()
}

func (this *Patches) Origin(fn func()) {
for target, bytes := range this.originals {
modifyBinary(target, bytes)
}
fn()
for target, targetPtr := range this.targets {
code := buildJmpDirective(targetPtr)
modifyBinary(target, code)
}
}

func (this *Patches) ApplyFunc(target, double interface{}) *Patches {
t := reflect.ValueOf(target)
d := reflect.ValueOf(double)
Expand Down Expand Up @@ -214,6 +227,7 @@ func (this *Patches) ApplyCore(target, double reflect.Value) *Patches {
if _, ok := this.originals[assTarget]; !ok {
this.originals[assTarget] = original
}
this.targets[assTarget] = uintptr(getPointer(double))
this.valueHolders[double] = double
return this
}
Expand All @@ -227,6 +241,7 @@ func (this *Patches) ApplyCoreOnlyForPrivateMethod(target unsafe.Pointer, double
if _, ok := this.originals[assTarget]; !ok {
this.originals[assTarget] = original
}
this.targets[assTarget] = uintptr(getPointer(double))
this.valueHolders[double] = double
return this
}
Expand Down
47 changes: 47 additions & 0 deletions test/apply_func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,32 @@ func TestApplyFunc(t *testing.T) {
So(output, ShouldEqual, outputExpect)
})

Convey("one func for succ with origin", func() {
patches := ApplyFunc(fake.Belong, func(_ string, _ []string) bool {
return false
})
defer patches.Reset()
output := fake.Belong("a", []string{"a", "b"})
So(output, ShouldEqual, false)
patches.Origin(func() {
output = fake.Belong("a", []string{"a", "b"})
})
So(output, ShouldEqual, true)
})

Convey("one func for succ with origin inside", func() {
var output bool
var patches *Patches
patches = ApplyFunc(fake.Belong, func(_ string, _ []string) bool {
patches.Origin(func() {
output = fake.Belong("a", []string{"a", "b"})
So(output, ShouldEqual, true)
})
return false
})
defer patches.Reset()
})

Convey("one func for fail", func() {
patches := ApplyFunc(fake.Exec, func(_ string, _ ...string) (string, error) {
return "", fake.ErrActual
Expand All @@ -51,6 +77,27 @@ func TestApplyFunc(t *testing.T) {
So(flag, ShouldBeTrue)
})

Convey("two funcs with origin", func() {
patches := ApplyFunc(fake.Exec, func(_ string, _ ...string) (string, error) {
return outputExpect, nil
})
defer patches.Reset()
patches.ApplyFunc(fake.Belong, func(_ string, _ []string) bool {
return true
})
output, err := fake.Exec("", "")
So(err, ShouldEqual, nil)
So(output, ShouldEqual, outputExpect)
flag := fake.Belong("", nil)
So(flag, ShouldBeTrue)

var outputBool bool
patches.Origin(func() {
outputBool = fake.Belong("c", []string{"a", "b"})
})
So(outputBool, ShouldEqual, false)
})

Convey("input and output param", func() {
patches := ApplyFunc(json.Unmarshal, func(data []byte, v interface{}) error {
if data == nil {
Expand Down
18 changes: 18 additions & 0 deletions test/apply_method_func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,24 @@ func TestApplyMethodFunc(t *testing.T) {
So(len(slice), ShouldEqual, 0)
})

Convey("for origin", func() {
patches := ApplyMethodFunc(s, "Add", func(_ int) error {
return nil
})
defer patches.Reset()

var err error
patches.Origin(func() {
err = slice.Add(1)
So(err, ShouldEqual, nil)
err = slice.Add(1)
So(err, ShouldEqual, fake.ErrElemExsit)
err = slice.Remove(1)
So(err, ShouldEqual, nil)
})
So(len(slice), ShouldEqual, 0)
})

Convey("for already exist", func() {
err := slice.Add(2)
So(err, ShouldEqual, nil)
Expand Down

0 comments on commit 248313f

Please sign in to comment.