diff --git a/go.mod b/go.mod index 40b505f..a5a2d22 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,3 @@ require ( github.com/smartystreets/goconvey v1.6.4 golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff ) - -require ( - github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect - github.com/jtolds/gls v4.20.0+incompatible // indirect - github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect -) diff --git a/internal/unsafereflect/name_above_1_17.go b/internal/unsafereflect/name_above_1_17.go new file mode 100644 index 0000000..27f8cde --- /dev/null +++ b/internal/unsafereflect/name_above_1_17.go @@ -0,0 +1,53 @@ +//go:build go1.17 +// +build go1.17 + +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unsafereflect + +import "unsafe" + +// name is an encoded type name with optional extra data. +type name struct { + bytes *byte +} + +func (n name) data(off int, whySafe string) *byte { + return (*byte)(add(unsafe.Pointer(n.bytes), uintptr(off), whySafe)) +} + +func (n name) readVarint(off int) (int, int) { + v := 0 + for i := 0; ; i++ { + x := *n.data(off+i, "read varint") + v += int(x&0x7f) << (7 * i) + if x&0x80 == 0 { + return i + 1, v + } + } +} + +func (n name) name() (s string) { + if n.bytes == nil { + return + } + i, l := n.readVarint(1) + hdr := (*_String)(unsafe.Pointer(&s)) + hdr.Data = unsafe.Pointer(n.data(1+i, "non-empty string")) + hdr.Len = l + return +} diff --git a/internal/unsafereflect/name_below_1_17.go b/internal/unsafereflect/name_below_1_17.go new file mode 100644 index 0000000..9096f20 --- /dev/null +++ b/internal/unsafereflect/name_below_1_17.go @@ -0,0 +1,41 @@ +//go:build !go1.17 +// +build !go1.17 + +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unsafereflect + +import ( + "unsafe" +) + +// name is an encoded type name with optional extra data. +type name struct { + bytes *byte +} + +func (n name) name() (s string) { + if n.bytes == nil { + return + } + b := (*[4]byte)(unsafe.Pointer(n.bytes)) + + hdr := (*_String)(unsafe.Pointer(&s)) + hdr.Data = unsafe.Pointer(&b[3]) + hdr.Len = int(b[1])<<8 | int(b[2]) + return s +} diff --git a/internal/unsafereflect/type.go b/internal/unsafereflect/type.go new file mode 100644 index 0000000..af1209d --- /dev/null +++ b/internal/unsafereflect/type.go @@ -0,0 +1,255 @@ +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Unsafe reflect package for mockey, copy most code from go/src/reflect/type.go, + * allow to export the address of private member methods. + */ + +package unsafereflect + +import ( + "reflect" + "unsafe" +) + +func MethodByName(target interface{}, name string) (fn unsafe.Pointer, ok bool) { + r := castRType(target) + rt := toRType(r) + if r.Kind() == reflect.Interface { + return funcPointer(r.MethodByName(name)) + } + + for _, p := range rt.methods() { + if rt.nameOff(p.name).name() == name { + return rt.Method(p), true + } + } + return nil, false +} + +// rtype is the common implementation of most values. +// It is embedded in other struct types. +// +// rtype must be kept in sync with src/runtime/type.go:/^type._type. +type rtype struct { + size uintptr + ptrdata uintptr // number of bytes in the type that can contain pointers + hash uint32 // hash of type; avoids computation in hash tables + tflag tflag // extra type information flags + align uint8 // alignment of variable with this type + fieldAlign uint8 // alignment of struct field with this type + kind uint8 // enumeration for C + // function for comparing objects of this type + // (ptr to object A, ptr to object B) -> ==? + equal func(unsafe.Pointer, unsafe.Pointer) bool + gcdata *byte // garbage collection data + str nameOff // string form + ptrToThis typeOff // type for pointer to this type, may be zero +} + +func castRType(val interface{}) reflect.Type { + if rTypeVal, ok := val.(reflect.Type); ok { + return rTypeVal + } + return reflect.TypeOf(val) +} + +func toRType(t reflect.Type) *rtype { + i := *(*funcValue)(unsafe.Pointer(&t)) + r := (*rtype)(i.p) + return r +} + +type funcValue struct { + _ uintptr + p unsafe.Pointer +} + +func funcPointer(v reflect.Method, ok bool) (unsafe.Pointer, bool) { + return (*funcValue)(unsafe.Pointer(&v.Func)).p, ok +} + +func (t *rtype) Method(p method) (fn unsafe.Pointer) { + tfn := t.textOff(p.tfn) + fn = unsafe.Pointer(&tfn) + return +} + +const kindMask = (1 << 5) - 1 + +func (t *rtype) Kind() reflect.Kind { return reflect.Kind(t.kind & kindMask) } + +type tflag uint8 +type nameOff int32 // offset to a name +type typeOff int32 // offset to an *rtype +type textOff int32 // offset from top of text section + +// resolveNameOff resolves a name offset from a base pointer. +// The (*rtype).nameOff method is a convenience wrapper for this function. +// Implemented in the runtime package. +// +//go:linkname resolveNameOff reflect.resolveNameOff +func resolveNameOff(unsafe.Pointer, int32) unsafe.Pointer + +func (t *rtype) nameOff(off nameOff) name { + return name{(*byte)(resolveNameOff(unsafe.Pointer(t), int32(off)))} +} + +// resolveTextOff resolves a function pointer offset from a base type. +// The (*rtype).textOff method is a convenience wrapper for this function. +// Implemented in the runtime package. +// +//go:linkname resolveTextOff reflect.resolveTextOff +func resolveTextOff(unsafe.Pointer, int32) unsafe.Pointer + +func (t *rtype) textOff(off textOff) unsafe.Pointer { + return resolveTextOff(unsafe.Pointer(t), int32(off)) +} + +const tflagUncommon tflag = 1 << 0 + +// uncommonType is present only for defined types or types with methods +type uncommonType struct { + pkgPath nameOff // import path; empty for built-in types like int, string + mcount uint16 // number of methods + xcount uint16 // number of exported methods + moff uint32 // offset from this uncommontype to [mcount]method + _ uint32 // unused +} + +// ptrType represents a pointer type. +type ptrType struct { + rtype + elem *rtype // pointer element (pointed at) type +} + +// funcType represents a function type. +type funcType struct { + rtype + inCount uint16 + outCount uint16 // top bit is set if last input parameter is ... +} + +func (t *funcType) in() []*rtype { + uadd := unsafe.Sizeof(*t) + if t.tflag&tflagUncommon != 0 { + uadd += unsafe.Sizeof(uncommonType{}) + } + if t.inCount == 0 { + return nil + } + return (*[1 << 20]*rtype)(add(unsafe.Pointer(t), uadd, "t.inCount > 0"))[:t.inCount:t.inCount] +} + +func (t *funcType) out() []*rtype { + uadd := unsafe.Sizeof(*t) + if t.tflag&tflagUncommon != 0 { + uadd += unsafe.Sizeof(uncommonType{}) + } + outCount := t.outCount & (1<<15 - 1) + if outCount == 0 { + return nil + } + return (*[1 << 20]*rtype)(add(unsafe.Pointer(t), uadd, "outCount > 0"))[t.inCount : t.inCount+outCount : t.inCount+outCount] +} + +func (t *rtype) IsVariadic() bool { + tt := (*funcType)(unsafe.Pointer(t)) + return tt.outCount&(1<<15) != 0 +} + +func add(p unsafe.Pointer, x uintptr, whySafe string) unsafe.Pointer { + return unsafe.Pointer(uintptr(p) + x) +} + +// interfaceType represents an interface type. +type interfaceType struct { + rtype + pkgPath name // import path + methods []imethod // sorted by hash +} + +type imethod struct { + name nameOff // name of method + typ typeOff // .(*FuncType) underneath +} + +func (t *rtype) methods() []method { + if t.tflag&tflagUncommon == 0 { + return nil + } + switch t.Kind() { + case reflect.Ptr: + type u struct { + ptrType + u uncommonType + } + return (*u)(unsafe.Pointer(t)).u.methods() + case reflect.Func: + type u struct { + funcType + u uncommonType + } + return (*u)(unsafe.Pointer(t)).u.methods() + case reflect.Interface: + type u struct { + interfaceType + u uncommonType + } + return (*u)(unsafe.Pointer(t)).u.methods() + case reflect.Struct: + type u struct { + structType + u uncommonType + } + return (*u)(unsafe.Pointer(t)).u.methods() + default: + return nil + } +} + +// Method on non-interface type +type method struct { + name nameOff // name of method + mtyp typeOff // method type (without receiver), not valid for private methods + ifn textOff // fn used in interface call (one-word receiver) + tfn textOff // fn used for normal method call +} + +func (t *uncommonType) methods() []method { + if t.mcount == 0 { + return nil + } + return (*[1 << 16]method)(add(unsafe.Pointer(t), uintptr(t.moff), "t.mcount > 0"))[:t.mcount:t.mcount] +} + +// Struct field +type structField struct { + name name // name is always non-empty + typ *rtype // type of field + offset uintptr // byte offset of field +} + +// structType +type structType struct { + rtype + pkgPath name + fields []structField // sorted by offset +} + +type _String struct { + Data unsafe.Pointer + Len int +} diff --git a/internal/unsafereflect/type_above_1_17_test.go b/internal/unsafereflect/type_above_1_17_test.go new file mode 100644 index 0000000..7f5d459 --- /dev/null +++ b/internal/unsafereflect/type_above_1_17_test.go @@ -0,0 +1,63 @@ +//go:build go1.17 +// +build go1.17 + +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unsafereflect_test + +import ( + "crypto/sha256" + "hash" + "reflect" + "testing" + "unsafe" + + "github.com/bytedance/mockey" + "github.com/bytedance/mockey/internal/tool" + "github.com/bytedance/mockey/internal/unsafereflect" +) + +func TestMethodByNameV17(t *testing.T) { + // private structure private method: *sha256.digest.checkSum + tfn, ok := unsafereflect.MethodByName(sha256.New(), "checkSum") + tool.Assert(ok, "private member of private structure is allowed") + // type of `func(*sha256.digest, []byte) [32]byte` + pFn := unsafe.Pointer(&tfn) + + mockey.PatchConvey("InterfaceFuncReturn", t, func() { + fn := *(*func(hash.Hash) [sha256.Size]byte)(pFn) + // Interface to fit the function shape is allowed here + mockey.Mock(fn).Return([sha256.Size]byte{1: 1}).Build() + rets := sha256.New().Sum(nil) + want := make([]byte, sha256.Size) + want[1] = 1 + tool.Assert(reflect.DeepEqual(want, rets), "the method should be mocked") + }) + + mockey.PatchConvey("InterfaceFuncTo", t, func() { + fn := *(*func(hash.Hash) [sha256.Size]byte)(pFn) + // Interface to fit the function shape is allowed here, + // since the receiver's type is interface, To API can be used here + mockey.Mock(fn).To(func(hash.Hash) [sha256.Size]byte { + return [sha256.Size]byte{1: 1} + }).Build() + rets := sha256.New().Sum(nil) + want := make([]byte, sha256.Size) + want[1] = 1 + tool.Assert(reflect.DeepEqual(want, rets), "the method should be mocked") + }) +} diff --git a/internal/unsafereflect/type_test.go b/internal/unsafereflect/type_test.go new file mode 100644 index 0000000..1cf8f12 --- /dev/null +++ b/internal/unsafereflect/type_test.go @@ -0,0 +1,51 @@ +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unsafereflect_test + +import ( + "crypto/sha256" + "reflect" + "testing" + "unsafe" + + "github.com/bytedance/mockey" + "github.com/bytedance/mockey/internal/tool" + "github.com/bytedance/mockey/internal/unsafereflect" +) + +func TestMethodByName(t *testing.T) { + // private structure private method: *sha256.digest.checkSum + tfn, ok := unsafereflect.MethodByName(sha256.New(), "checkSum") + tool.Assert(ok, "private member of private structure is allowed") + // type of `func(*sha256.digest, []byte) [32]byte` + pFn := unsafe.Pointer(&tfn) + + mockey.PatchConvey("ReflectFuncReturn", t, func() { + f := reflect.FuncOf([]reflect.Type{reflect.TypeOf(sha256.New())}, + []reflect.Type{reflect.TypeOf([sha256.Size]byte{})}, false) + fn := reflect.NewAt(f, pFn).Elem().Interface() + // Such function cannot be exported as `(*sha256.digest).checkSum`, + // since the receiver's type is *sha256.digest, only Return API can be used + mockey.Mock(fn).Return([sha256.Size]byte{1: 1}).Build() + rets := sha256.New().Sum(nil) + want := make([]byte, sha256.Size) + want[1] = 1 + tool.Assert(reflect.DeepEqual(want, rets), "the method should be mocked") + }) + + // See also TestMethodByNameV17 while go version above 1.17 +} diff --git a/utils_above_1_18.go b/utils_above_1_18.go new file mode 100644 index 0000000..d4acc29 --- /dev/null +++ b/utils_above_1_18.go @@ -0,0 +1,47 @@ +//go:build go1.18 +// +build go1.18 + +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mockey + +import ( + "unsafe" + + "github.com/bytedance/mockey/internal/tool" + "github.com/bytedance/mockey/internal/unsafereflect" +) + +// GetPrivateMemberMethod resolve a method from an instance, include private method. +// +// F must fit the shape of specific method, include receiver as the first argument. +// Especially, the receiver can be replaced as interface when F is declaring, +// this will be very useful when receiver type is not exported for other packages. +// +// for example: +// +// GetPrivateMemberMethod[func(*bytes.Buffer) bool](&bytes.Buffer{}, "empty") +// GetPrivateMemberMethod[func(hash.Hash, []byte) [sha256.Size]byte](sha256.New(), "checkSum") +func GetPrivateMemberMethod[F interface{}](instance interface{}, methodName string) interface{} { + tfn, ok := unsafereflect.MethodByName(instance, methodName) + if !ok { + tool.Assert(false, "can't reflect instance method :%v", methodName) + return nil + } + // return with (unsafe) function type cast + return *(*F)(unsafe.Pointer(&tfn)) +} diff --git a/utils_above_1_18_test.go b/utils_above_1_18_test.go new file mode 100644 index 0000000..56a2757 --- /dev/null +++ b/utils_above_1_18_test.go @@ -0,0 +1,68 @@ +//go:build go1.18 +// +build go1.18 + +/* + * Copyright 2023 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mockey + +import ( + "bytes" + "testing" + + "github.com/bytedance/mockey/internal/tool" + "github.com/smartystreets/goconvey/convey" +) + +func TestGetPrivateMemberMethod(t *testing.T) { + PatchConvey("FakeMethod", t, func() { + convey.So(func() { + GetPrivateMemberMethod[func()](&bytes.Buffer{}, "FakeMethod") + }, convey.ShouldPanicWith, "can't reflect instance method :FakeMethod") + }) + + PatchConvey("OriginalGetMethod", t, func() { + convey.So(func() { + GetMethod(&bytes.Buffer{}, "empty") + }, convey.ShouldPanicWith, "can't reflect instance method :empty") + }) + + PatchConvey("ExportFunc", t, func() { + convey.So(func() { + exportedFunc := GetPrivateMemberMethod[func(*bytes.Buffer) int](&bytes.Buffer{}, "Len") + var mocked bool + Mock(exportedFunc).To(func(buffer *bytes.Buffer) int { + mocked = true + return 0 + }).Build() + _ = new(bytes.Buffer).Len() + tool.Assert(mocked, "function should be mocked") + }, convey.ShouldNotPanic) + }) + + PatchConvey("PrivateFunc", t, func() { + convey.So(func() { + privateFunc := GetPrivateMemberMethod[func(*bytes.Buffer) bool](&bytes.Buffer{}, "empty") + var mocked bool + Mock(privateFunc).To(func(buffer *bytes.Buffer) bool { + mocked = true + return true + }).Build() + _, _ = new(bytes.Buffer).ReadByte() + tool.Assert(mocked, "function should be mocked") + }, convey.ShouldNotPanic) + }) +}