Skip to content

Commit

Permalink
Merge pull request #40 from bytedance/generic
Browse files Browse the repository at this point in the history
fix: wrong parameter value when using To/When on generic functions
  • Loading branch information
Sychorius authored Oct 23, 2023
2 parents 5cea56e + 47be19d commit 58ee2a5
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 76 deletions.
2 changes: 1 addition & 1 deletion internal/monkey/fn/copy_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func Copy(targetPtr, oriFn interface{}) {
targetType := targetVal.Type().Elem()
tool.Assert(targetType.Kind() == reflect.Func, "'%v' is not a function pointer", targetPtr)
oriVal := reflect.ValueOf(oriFn)
tool.Assert(tool.CheckFuncArgs(targetType, oriVal.Type(), 0), "target and ori not match")
tool.Assert(tool.CheckFuncArgs(targetType, oriVal.Type(), 0, 0), "target and ori not match")

oriAddr := oriVal.Pointer()
tool.DebugPrintf("Copy: copy start for %v\n", runtime.FuncForPC(oriAddr).Name())
Expand Down
2 changes: 0 additions & 2 deletions internal/monkey/patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ func (p *Patch) Unpatch() {
func PatchValue(target, hook, proxy reflect.Value, unsafe, generic bool) *Patch {
tool.Assert(hook.Kind() == reflect.Func, "'%s' is not a function", hook.Kind())
tool.Assert(proxy.Kind() == reflect.Ptr, "'%v' is not a function pointer", proxy.Kind())
tool.Assert(hook.Type() == target.Type(), "'%v' and '%s' mismatch", hook.Type(), target.Type())
tool.Assert(proxy.Elem().Type() == target.Type(), "'*%v' and '%s' mismatch", proxy.Elem().Type(), target.Type())

targetAddr := target.Pointer()
if generic {
Expand Down
7 changes: 0 additions & 7 deletions internal/tool/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@ import (
"reflect"
)

func ReflectCallWithShiftOne(f reflect.Value, args []reflect.Value, shift bool) []reflect.Value {
if shift {
return ReflectCall(f, args[1:])
}
return ReflectCall(f, args)
}

func ReflectCall(f reflect.Value, args []reflect.Value) []reflect.Value {
if f.Type().IsVariadic() {
newArgs := make([]reflect.Value, 0)
Expand Down
8 changes: 4 additions & 4 deletions internal/tool/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func CheckReturnType(fn interface{}, results ...interface{}) {
}
}

func CheckFuncArgs(a, b reflect.Type, shift int) bool {
if a.NumIn() == b.NumIn()+shift {
for i := shift; i < a.NumIn(); i++ {
if a.In(i) != b.In(i-shift) {
func CheckFuncArgs(a, b reflect.Type, shiftA, shiftB int) bool {
if a.NumIn()-shiftA == b.NumIn()-shiftB {
for indexA, indexB := shiftA, shiftB; indexA < a.NumIn(); indexA, indexB = indexA+1, indexB+1 {
if a.In(indexA) != b.In(indexB) {
return false
}
}
Expand Down
105 changes: 65 additions & 40 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,35 @@ const (
)

type Mocker struct {
target reflect.Value // 目标函数
hook reflect.Value // mock函数
proxy interface{} // mock之后,原函数地址
target reflect.Value // mock target value
hook reflect.Value // mock hook
proxy interface{} // proxy function to origin
times int64
mockTimes int64
patch *monkey.Patch
lock sync.Mutex
isPatched bool
builder *MockBuilder

outerCaller tool.CallerInfo // Mocker 的外部调用位置
outerCaller tool.CallerInfo
}

type MockBuilder struct {
target interface{} // 目标函数
// hook interface{} // mock函数
proxyCaller interface{} // mock之后,原函数地址
// when interface{} // 条件函数
conditions []*mockCondition // 条件转移
target interface{} // mock target
proxyCaller interface{} // origin function caller hook
conditions []*mockCondition // mock conditions
filterGoroutine FilterGoroutineType
gId int64
unsafe bool
generic bool
}

// Mock mocks target function
//
// If target is a generic method or method of generic types, you need add a genericOpt, like this:
//
// func f[int, float64](x int, y T1) T2
// Mock(f[int, float64], OptGeneric)
func Mock(target interface{}, opt ...optionFn) *MockBuilder {
tool.AssertFunc(target)

Expand All @@ -79,11 +83,38 @@ func MockUnsafe(target interface{}) *MockBuilder {
return Mock(target, OptUnsafe)
}

func (builder *MockBuilder) hookType() reflect.Type {
targetType := reflect.TypeOf(builder.target)
if builder.generic {
targetIn := []reflect.Type{genericInfoType}
for i := 0; i < targetType.NumIn(); i++ {
targetIn = append(targetIn, targetType.In(i))
}
targetOut := []reflect.Type{}
for i := 0; i < targetType.NumOut(); i++ {
targetOut = append(targetOut, targetType.Out(i))
}
return reflect.FuncOf(targetIn, targetOut, targetType.IsVariadic())
}
return targetType
}

func (builder *MockBuilder) resetCondition() *MockBuilder {
builder.conditions = []*mockCondition{builder.newCondition()} // at least 1 condition is needed
return builder
}

// Origin add an origin hook which can be used to call un-mocked origin function
//
// For example:
//
// origin := Fun // only need the same type
// mock := func(p string) string {
// return origin(p + "mocked")
// }
// mock2 := Mock(Fun).To(mock).Origin(&origin).Build()
//
// Origin only works when call origin hook directly, target will still be mocked in recursive call
func (builder *MockBuilder) Origin(funcPtr interface{}) *MockBuilder {
tool.Assert(builder.proxyCaller == nil, "re-set builder origin")
return builder.origin(funcPtr)
Expand Down Expand Up @@ -187,15 +218,15 @@ func (builder *MockBuilder) Build() *Mocker {
return &mocker
}

func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool {
func (mocker *Mocker) missReceiver(target reflect.Type, hook interface{}) bool {
hType := reflect.TypeOf(hook)
tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind())
tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook)
// has receiver
if tool.CheckFuncArgs(target, hType, 0) {
if tool.CheckFuncArgs(target, hType, 0, 0) {
return false
}
if tool.CheckFuncArgs(target, hType, 1) {
if tool.CheckFuncArgs(target, hType, 1, 0) {
return true
}
tool.Assert(false, "target:%v, hook:%v args not match", target, hook)
Expand All @@ -205,40 +236,36 @@ func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool
func (mocker *Mocker) buildHook() {
proxySetter := mocker.buildProxy()

origin := reflect.ValueOf(mocker.proxy).Elem()
originExec := func(args []reflect.Value) []reflect.Value {
return tool.ReflectCall(origin, args)
return tool.ReflectCall(reflect.ValueOf(mocker.proxy).Elem(), args)
}

match := []func(args []reflect.Value) bool{}
exec := []func(args []reflect.Value) []reflect.Value{}

for _, condition := range mocker.builder.conditions {
when := condition.when
hook := condition.hook

if when == nil {
for i := range mocker.builder.conditions {
condition := mocker.builder.conditions[i]
if condition.when == nil {
// when condition is not set, just go into hook exec
match = append(match, func(args []reflect.Value) bool { return true })
} else {
missWhenReceiver := mocker.checkReceiver(mocker.target.Type(), when)
match = append(match, func(args []reflect.Value) bool {
return tool.ReflectCallWithShiftOne(reflect.ValueOf(when), args, missWhenReceiver)[0].Bool()
return tool.ReflectCall(reflect.ValueOf(condition.when), args)[0].Bool()
})
}

if hook == nil {
if condition.hook == nil {
// hook condition is not set, just go into original exec
exec = append(exec, originExec)
} else {
missHookReceiver := mocker.checkReceiver(mocker.target.Type(), hook)
exec = append(exec, func(args []reflect.Value) []reflect.Value {
mocker.mock()
return tool.ReflectCallWithShiftOne(reflect.ValueOf(hook), args, missHookReceiver)
return tool.ReflectCall(reflect.ValueOf(condition.hook), args)
})
}
}

mockerHook := reflect.MakeFunc(mocker.target.Type(), func(args []reflect.Value) []reflect.Value {
mockerHook := reflect.MakeFunc(mocker.builder.hookType(), func(args []reflect.Value) []reflect.Value {
proxySetter(args) // 设置origin调用proxy

mocker.access()
Expand Down Expand Up @@ -267,29 +294,27 @@ func (mocker *Mocker) buildHook() {
mocker.hook = mockerHook
}

// buildProx create a proxyCaller which could call origin directly
func (mocker *Mocker) buildProxy() func(args []reflect.Value) {
proxy := reflect.New(mocker.target.Type())
proxy := reflect.New(mocker.builder.hookType())

proxyCallerSetter := func(args []reflect.Value) {}
missProxyReceiver := false
if mocker.builder.proxyCaller != nil {
pVal := reflect.ValueOf(mocker.builder.proxyCaller)
tool.Assert(pVal.Kind() == reflect.Ptr && pVal.Elem().Kind() == reflect.Func, "origin receiver must be a function pointer")
pElem := pVal.Elem()
missProxyReceiver = mocker.checkReceiver(mocker.target.Type(), pElem.Interface())

if missProxyReceiver {
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), append(args[0:1], innerArgs...))
}))
}
} else {
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), innerArgs)
}))
}
shift := 0
if mocker.builder.generic {
shift += 1
}
if mocker.missReceiver(mocker.target.Type(), pElem.Interface()) {
shift += 1
}
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), append(args[0:shift], innerArgs...))
}))
}
}
mocker.proxy = proxy.Interface()
Expand Down
85 changes: 68 additions & 17 deletions mock_condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,19 @@ func (m *mockCondition) SetWhenForce(when interface{}) {
tool.Assert(wVal.Type().NumOut() == 1, "when func ret value not bool")
out1 := wVal.Type().Out(0)
tool.Assert(out1.Kind() == reflect.Bool, "when func ret value not bool")
checkReceiver(reflect.TypeOf(m.builder.target), when) // inputs must be in same or has an extra self receiver
m.when = when

hookType := m.builder.hookType()
inTypes := []reflect.Type{}
for i := 0; i < hookType.NumIn(); i++ {
inTypes = append(inTypes, hookType.In(i))
}

hasGeneric, hasReceiver := m.checkGenericAndReceiver(wVal.Type())
whenType := reflect.FuncOf(inTypes, []reflect.Type{out1}, hookType.IsVariadic())
m.when = reflect.MakeFunc(whenType, func(args []reflect.Value) (results []reflect.Value) {
results = tool.ReflectCall(wVal, m.adaptArgsForReflectCall(args, hasGeneric, hasReceiver))
return
}).Interface()
}

func (m *mockCondition) SetReturn(results ...interface{}) {
Expand All @@ -61,15 +72,15 @@ func (m *mockCondition) SetReturnForce(results ...interface{}) {
}
}

targetType := reflect.TypeOf(m.builder.target)
m.hook = reflect.MakeFunc(targetType, func(args []reflect.Value) []reflect.Value {
hookType := m.builder.hookType()
m.hook = reflect.MakeFunc(hookType, func(_ []reflect.Value) []reflect.Value {
results := getResult()
tool.CheckReturnType(m.builder.target, results...)
valueResults := make([]reflect.Value, 0)
for i, result := range results {
rValue := reflect.Zero(targetType.Out(i))
rValue := reflect.Zero(hookType.Out(i))
if result != nil {
rValue = reflect.ValueOf(result).Convert(targetType.Out(i))
rValue = reflect.ValueOf(result).Convert(hookType.Out(i))
}
valueResults = append(valueResults, rValue)
}
Expand All @@ -85,20 +96,60 @@ func (m *mockCondition) SetTo(to interface{}) {
func (m *mockCondition) SetToForce(to interface{}) {
hType := reflect.TypeOf(to)
tool.Assert(hType.Kind() == reflect.Func, "to a is not a func")
m.hook = to
hasGeneric, hasReceiver := m.checkGenericAndReceiver(hType)
tool.Assert(m.builder.generic || !hasGeneric, "non-generic function should not have 'GenericInfo' as first argument")
m.hook = reflect.MakeFunc(m.builder.hookType(), func(args []reflect.Value) (results []reflect.Value) {
results = tool.ReflectCall(reflect.ValueOf(to), m.adaptArgsForReflectCall(args, hasGeneric, hasReceiver))
return
}).Interface()
}

func checkReceiver(target reflect.Type, hook interface{}) bool {
hType := reflect.TypeOf(hook)
tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind())
tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook)
// checkGenericAndReceiver check if typ has GenericsInfo and selfReceiver as argument
//
// The hook function will looks like func(_ GenericInfo, self *struct, arg0 int ...)
// When we use 'When' or 'To', our input hook function will looks like:
// 1. func(arg0 int ...)
// 2. func(info GenericInfo, arg0 int ...)
// 3. func(self *struct, arg0 int ...)
// 4. func(info GenericInfo, self *struct, arg0 int ...)
//
// All above input hooks are legal, but we need to make an adaptation when calling then
func (m *mockCondition) checkGenericAndReceiver(typ reflect.Type) (bool, bool) {
targetType := reflect.TypeOf(m.builder.target)
tool.Assert(typ.Kind() == reflect.Func, "Param(%v) a is not a func", typ.Kind())
tool.Assert(targetType.IsVariadic() == typ.IsVariadic(), "target:%v, hook:%v args not match", targetType, typ)

shiftTyp := 0
if typ.NumIn() > 0 && typ.In(0) == genericInfoType {
shiftTyp = 1
}

// has receiver
if tool.CheckFuncArgs(target, hType, 0) {
return false
if tool.CheckFuncArgs(targetType, typ, 0, shiftTyp) {
return shiftTyp == 1, true
}

if tool.CheckFuncArgs(targetType, typ, 1, shiftTyp) {
return shiftTyp == 1, false
}
tool.Assert(false, "target:%v, hook:%v args not match", targetType, typ)
return false, false
}

// adaptArgsForReflectCall makes an adaption for reflect call
//
// see (*mockCondition).checkGenericAndReceiver for more info
func (m *mockCondition) adaptArgsForReflectCall(args []reflect.Value, hasGeneric, hasReceiver bool) []reflect.Value {
adaption := []reflect.Value{}
if m.builder.generic {
if hasGeneric {
adaption = append(adaption, args[0])
}
args = args[1:]
}
if tool.CheckFuncArgs(target, hType, 1) {
return true
if !hasReceiver {
args = args[1:]
}
tool.Assert(false, "target:%v, hook:%v args not match", target, hook)
return false
adaption = append(adaption, args...)
return adaption
}
33 changes: 33 additions & 0 deletions mock_generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,39 @@

package mockey

import (
"reflect"
"unsafe"
)

// MockGeneric mocks generic function
//
// Target must be generic method or method of generic types
func MockGeneric(target interface{}) *MockBuilder {
return Mock(target, OptGeneric)
}

type GenericInfo uintptr

var genericInfoType = reflect.TypeOf(GenericInfo(0))

func (g GenericInfo) Equal(other GenericInfo) bool {
return g == other
}

// UsedParamType get the type of used parameter in generic function/struct
//
// For example: assume we have generic function "f[int, float64](x int, y T1) T2" and derived type f[int, float64]:
//
// UsedParamType(0) == reflect.TypeOf(int(0))
// UsedParamType(1) == reflect.TypeOf(float64(0))
//
// If index n is out of range, or the derived types have more complex structure(for example: define an generic struct
// in a generic function using generic types, unused parameterized type etc.), this function may return unexpected value
// or cause unrecoverable runtime error . So it is NOT RECOMMENDED to use this function unless you actually knows what
// you are doing.
func (g GenericInfo) UsedParamType(n uintptr) reflect.Type {
var vt interface{}
*(*uintptr)(unsafe.Pointer(&vt)) = *(*uintptr)(unsafe.Pointer(uintptr(g) + 8*n))
return reflect.TypeOf(vt)
}
Loading

0 comments on commit 58ee2a5

Please sign in to comment.