From 9a80e3fa7d2469b9e8fac5502de2ef1dac2c0fe4 Mon Sep 17 00:00:00 2001 From: Jared Ticotin Date: Fri, 13 Dec 2024 14:05:59 -0500 Subject: [PATCH] handle interfaces with generics --- arguments/parser.go | 7 +- fixtures/genericinterface/genericinterface.go | 21 ++ .../fake_generic_interface.go | 245 ++++++++++++++++++ .../fake_generic_interface2.go | 245 ++++++++++++++++++ generator/fake.go | 29 ++- generator/interface_template.go | 20 +- generator/loader.go | 39 ++- 7 files changed, 580 insertions(+), 26 deletions(-) create mode 100644 fixtures/genericinterface/genericinterface.go create mode 100644 fixtures/genericinterface/genericinterfacefakes/fake_generic_interface.go create mode 100644 fixtures/genericinterface/genericinterfacefakes/fake_generic_interface2.go diff --git a/arguments/parser.go b/arguments/parser.go index 557d495..edc32b4 100644 --- a/arguments/parser.go +++ b/arguments/parser.go @@ -91,7 +91,7 @@ func New(args []string, workingDir string, evaler Evaler, stater Stater) (*Parse } func (a *ParsedArguments) PrettyPrint() { - b, _ := json.Marshal(a) + b, _ := json.MarshalIndent(a, "", " ") fmt.Println(string(b)) } @@ -105,6 +105,7 @@ func (a *ParsedArguments) parseInterfaceName(packageMode bool, args []string) { a.InterfaceName = fullyQualifiedInterface[len(fullyQualifiedInterface)-1] } else { a.InterfaceName = args[1] + a.InterfaceName = strings.Split(a.InterfaceName, "[")[0] } } @@ -141,7 +142,8 @@ func (a *ParsedArguments) parseOutputPath(packageMode bool, workingDir string, o if strings.HasSuffix(outputPath, ".go") { outputPathIsFilename = true } - snakeCaseName := strings.ToLower(camelRegexp.ReplaceAllString(a.FakeImplName, "${1}_${2}")) + fakeImplName := strings.Split(a.FakeImplName, "[")[0] + snakeCaseName := strings.ToLower(camelRegexp.ReplaceAllString(fakeImplName, "${1}_${2}")) if outputPath != "" { if !filepath.IsAbs(outputPath) { @@ -187,6 +189,7 @@ func (a *ParsedArguments) parsePackagePath(packageMode bool, args []string) { a.PackagePath = strings.Join(fullyQualifiedInterface[:len(fullyQualifiedInterface)-1], ".") } else { a.InterfaceName = args[1] + a.InterfaceName = strings.Split(a.InterfaceName, "[")[0] } if a.PackagePath == "" { diff --git a/fixtures/genericinterface/genericinterface.go b/fixtures/genericinterface/genericinterface.go new file mode 100644 index 0000000..8addfc6 --- /dev/null +++ b/fixtures/genericinterface/genericinterface.go @@ -0,0 +1,21 @@ +package genericinterface + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +type CustomType any + +//counterfeiter:generate . GenericInterface[T CustomType] +type GenericInterface[T CustomType] interface { + ReturnT() T + TakeT(T) + TakeAndReturnT(T) T + DoSomething() +} + +//counterfeiter:generate . GenericInterface2 +type GenericInterface2[T CustomType] interface { + ReturnT() T + TakeT(T) + TakeAndReturnT(T) T + DoSomething() +} diff --git a/fixtures/genericinterface/genericinterfacefakes/fake_generic_interface.go b/fixtures/genericinterface/genericinterfacefakes/fake_generic_interface.go new file mode 100644 index 0000000..16e9476 --- /dev/null +++ b/fixtures/genericinterface/genericinterfacefakes/fake_generic_interface.go @@ -0,0 +1,245 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package genericinterfacefakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/genericinterface" +) + +type FakeGenericInterface[T genericinterface.CustomType] struct { + DoSomethingStub func() + doSomethingMutex sync.RWMutex + doSomethingArgsForCall []struct { + } + ReturnTStub func() T + returnTMutex sync.RWMutex + returnTArgsForCall []struct { + } + returnTReturns struct { + result1 T + } + returnTReturnsOnCall map[int]struct { + result1 T + } + TakeAndReturnTStub func(T) T + takeAndReturnTMutex sync.RWMutex + takeAndReturnTArgsForCall []struct { + arg1 T + } + takeAndReturnTReturns struct { + result1 T + } + takeAndReturnTReturnsOnCall map[int]struct { + result1 T + } + TakeTStub func(T) + takeTMutex sync.RWMutex + takeTArgsForCall []struct { + arg1 T + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeGenericInterface[T]) DoSomething() { + fake.doSomethingMutex.Lock() + fake.doSomethingArgsForCall = append(fake.doSomethingArgsForCall, struct { + }{}) + stub := fake.DoSomethingStub + fake.recordInvocation("DoSomething", []interface{}{}) + fake.doSomethingMutex.Unlock() + if stub != nil { + fake.DoSomethingStub() + } +} + +func (fake *FakeGenericInterface[T]) DoSomethingCallCount() int { + fake.doSomethingMutex.RLock() + defer fake.doSomethingMutex.RUnlock() + return len(fake.doSomethingArgsForCall) +} + +func (fake *FakeGenericInterface[T]) DoSomethingCalls(stub func()) { + fake.doSomethingMutex.Lock() + defer fake.doSomethingMutex.Unlock() + fake.DoSomethingStub = stub +} + +func (fake *FakeGenericInterface[T]) ReturnT() T { + fake.returnTMutex.Lock() + ret, specificReturn := fake.returnTReturnsOnCall[len(fake.returnTArgsForCall)] + fake.returnTArgsForCall = append(fake.returnTArgsForCall, struct { + }{}) + stub := fake.ReturnTStub + fakeReturns := fake.returnTReturns + fake.recordInvocation("ReturnT", []interface{}{}) + fake.returnTMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeGenericInterface[T]) ReturnTCallCount() int { + fake.returnTMutex.RLock() + defer fake.returnTMutex.RUnlock() + return len(fake.returnTArgsForCall) +} + +func (fake *FakeGenericInterface[T]) ReturnTCalls(stub func() T) { + fake.returnTMutex.Lock() + defer fake.returnTMutex.Unlock() + fake.ReturnTStub = stub +} + +func (fake *FakeGenericInterface[T]) ReturnTReturns(result1 T) { + fake.returnTMutex.Lock() + defer fake.returnTMutex.Unlock() + fake.ReturnTStub = nil + fake.returnTReturns = struct { + result1 T + }{result1} +} + +func (fake *FakeGenericInterface[T]) ReturnTReturnsOnCall(i int, result1 T) { + fake.returnTMutex.Lock() + defer fake.returnTMutex.Unlock() + fake.ReturnTStub = nil + if fake.returnTReturnsOnCall == nil { + fake.returnTReturnsOnCall = make(map[int]struct { + result1 T + }) + } + fake.returnTReturnsOnCall[i] = struct { + result1 T + }{result1} +} + +func (fake *FakeGenericInterface[T]) TakeAndReturnT(arg1 T) T { + fake.takeAndReturnTMutex.Lock() + ret, specificReturn := fake.takeAndReturnTReturnsOnCall[len(fake.takeAndReturnTArgsForCall)] + fake.takeAndReturnTArgsForCall = append(fake.takeAndReturnTArgsForCall, struct { + arg1 T + }{arg1}) + stub := fake.TakeAndReturnTStub + fakeReturns := fake.takeAndReturnTReturns + fake.recordInvocation("TakeAndReturnT", []interface{}{arg1}) + fake.takeAndReturnTMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeGenericInterface[T]) TakeAndReturnTCallCount() int { + fake.takeAndReturnTMutex.RLock() + defer fake.takeAndReturnTMutex.RUnlock() + return len(fake.takeAndReturnTArgsForCall) +} + +func (fake *FakeGenericInterface[T]) TakeAndReturnTCalls(stub func(T) T) { + fake.takeAndReturnTMutex.Lock() + defer fake.takeAndReturnTMutex.Unlock() + fake.TakeAndReturnTStub = stub +} + +func (fake *FakeGenericInterface[T]) TakeAndReturnTArgsForCall(i int) T { + fake.takeAndReturnTMutex.RLock() + defer fake.takeAndReturnTMutex.RUnlock() + argsForCall := fake.takeAndReturnTArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeGenericInterface[T]) TakeAndReturnTReturns(result1 T) { + fake.takeAndReturnTMutex.Lock() + defer fake.takeAndReturnTMutex.Unlock() + fake.TakeAndReturnTStub = nil + fake.takeAndReturnTReturns = struct { + result1 T + }{result1} +} + +func (fake *FakeGenericInterface[T]) TakeAndReturnTReturnsOnCall(i int, result1 T) { + fake.takeAndReturnTMutex.Lock() + defer fake.takeAndReturnTMutex.Unlock() + fake.TakeAndReturnTStub = nil + if fake.takeAndReturnTReturnsOnCall == nil { + fake.takeAndReturnTReturnsOnCall = make(map[int]struct { + result1 T + }) + } + fake.takeAndReturnTReturnsOnCall[i] = struct { + result1 T + }{result1} +} + +func (fake *FakeGenericInterface[T]) TakeT(arg1 T) { + fake.takeTMutex.Lock() + fake.takeTArgsForCall = append(fake.takeTArgsForCall, struct { + arg1 T + }{arg1}) + stub := fake.TakeTStub + fake.recordInvocation("TakeT", []interface{}{arg1}) + fake.takeTMutex.Unlock() + if stub != nil { + fake.TakeTStub(arg1) + } +} + +func (fake *FakeGenericInterface[T]) TakeTCallCount() int { + fake.takeTMutex.RLock() + defer fake.takeTMutex.RUnlock() + return len(fake.takeTArgsForCall) +} + +func (fake *FakeGenericInterface[T]) TakeTCalls(stub func(T)) { + fake.takeTMutex.Lock() + defer fake.takeTMutex.Unlock() + fake.TakeTStub = stub +} + +func (fake *FakeGenericInterface[T]) TakeTArgsForCall(i int) T { + fake.takeTMutex.RLock() + defer fake.takeTMutex.RUnlock() + argsForCall := fake.takeTArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeGenericInterface[T]) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.doSomethingMutex.RLock() + defer fake.doSomethingMutex.RUnlock() + fake.returnTMutex.RLock() + defer fake.returnTMutex.RUnlock() + fake.takeAndReturnTMutex.RLock() + defer fake.takeAndReturnTMutex.RUnlock() + fake.takeTMutex.RLock() + defer fake.takeTMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeGenericInterface[T]) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ genericinterface.GenericInterface[genericinterface.CustomType] = new(FakeGenericInterface[genericinterface.CustomType]) diff --git a/fixtures/genericinterface/genericinterfacefakes/fake_generic_interface2.go b/fixtures/genericinterface/genericinterfacefakes/fake_generic_interface2.go new file mode 100644 index 0000000..082e7a6 --- /dev/null +++ b/fixtures/genericinterface/genericinterfacefakes/fake_generic_interface2.go @@ -0,0 +1,245 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package genericinterfacefakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/genericinterface" +) + +type FakeGenericInterface2[T genericinterface.CustomType] struct { + DoSomethingStub func() + doSomethingMutex sync.RWMutex + doSomethingArgsForCall []struct { + } + ReturnTStub func() T + returnTMutex sync.RWMutex + returnTArgsForCall []struct { + } + returnTReturns struct { + result1 T + } + returnTReturnsOnCall map[int]struct { + result1 T + } + TakeAndReturnTStub func(T) T + takeAndReturnTMutex sync.RWMutex + takeAndReturnTArgsForCall []struct { + arg1 T + } + takeAndReturnTReturns struct { + result1 T + } + takeAndReturnTReturnsOnCall map[int]struct { + result1 T + } + TakeTStub func(T) + takeTMutex sync.RWMutex + takeTArgsForCall []struct { + arg1 T + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeGenericInterface2[T]) DoSomething() { + fake.doSomethingMutex.Lock() + fake.doSomethingArgsForCall = append(fake.doSomethingArgsForCall, struct { + }{}) + stub := fake.DoSomethingStub + fake.recordInvocation("DoSomething", []interface{}{}) + fake.doSomethingMutex.Unlock() + if stub != nil { + fake.DoSomethingStub() + } +} + +func (fake *FakeGenericInterface2[T]) DoSomethingCallCount() int { + fake.doSomethingMutex.RLock() + defer fake.doSomethingMutex.RUnlock() + return len(fake.doSomethingArgsForCall) +} + +func (fake *FakeGenericInterface2[T]) DoSomethingCalls(stub func()) { + fake.doSomethingMutex.Lock() + defer fake.doSomethingMutex.Unlock() + fake.DoSomethingStub = stub +} + +func (fake *FakeGenericInterface2[T]) ReturnT() T { + fake.returnTMutex.Lock() + ret, specificReturn := fake.returnTReturnsOnCall[len(fake.returnTArgsForCall)] + fake.returnTArgsForCall = append(fake.returnTArgsForCall, struct { + }{}) + stub := fake.ReturnTStub + fakeReturns := fake.returnTReturns + fake.recordInvocation("ReturnT", []interface{}{}) + fake.returnTMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeGenericInterface2[T]) ReturnTCallCount() int { + fake.returnTMutex.RLock() + defer fake.returnTMutex.RUnlock() + return len(fake.returnTArgsForCall) +} + +func (fake *FakeGenericInterface2[T]) ReturnTCalls(stub func() T) { + fake.returnTMutex.Lock() + defer fake.returnTMutex.Unlock() + fake.ReturnTStub = stub +} + +func (fake *FakeGenericInterface2[T]) ReturnTReturns(result1 T) { + fake.returnTMutex.Lock() + defer fake.returnTMutex.Unlock() + fake.ReturnTStub = nil + fake.returnTReturns = struct { + result1 T + }{result1} +} + +func (fake *FakeGenericInterface2[T]) ReturnTReturnsOnCall(i int, result1 T) { + fake.returnTMutex.Lock() + defer fake.returnTMutex.Unlock() + fake.ReturnTStub = nil + if fake.returnTReturnsOnCall == nil { + fake.returnTReturnsOnCall = make(map[int]struct { + result1 T + }) + } + fake.returnTReturnsOnCall[i] = struct { + result1 T + }{result1} +} + +func (fake *FakeGenericInterface2[T]) TakeAndReturnT(arg1 T) T { + fake.takeAndReturnTMutex.Lock() + ret, specificReturn := fake.takeAndReturnTReturnsOnCall[len(fake.takeAndReturnTArgsForCall)] + fake.takeAndReturnTArgsForCall = append(fake.takeAndReturnTArgsForCall, struct { + arg1 T + }{arg1}) + stub := fake.TakeAndReturnTStub + fakeReturns := fake.takeAndReturnTReturns + fake.recordInvocation("TakeAndReturnT", []interface{}{arg1}) + fake.takeAndReturnTMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeGenericInterface2[T]) TakeAndReturnTCallCount() int { + fake.takeAndReturnTMutex.RLock() + defer fake.takeAndReturnTMutex.RUnlock() + return len(fake.takeAndReturnTArgsForCall) +} + +func (fake *FakeGenericInterface2[T]) TakeAndReturnTCalls(stub func(T) T) { + fake.takeAndReturnTMutex.Lock() + defer fake.takeAndReturnTMutex.Unlock() + fake.TakeAndReturnTStub = stub +} + +func (fake *FakeGenericInterface2[T]) TakeAndReturnTArgsForCall(i int) T { + fake.takeAndReturnTMutex.RLock() + defer fake.takeAndReturnTMutex.RUnlock() + argsForCall := fake.takeAndReturnTArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeGenericInterface2[T]) TakeAndReturnTReturns(result1 T) { + fake.takeAndReturnTMutex.Lock() + defer fake.takeAndReturnTMutex.Unlock() + fake.TakeAndReturnTStub = nil + fake.takeAndReturnTReturns = struct { + result1 T + }{result1} +} + +func (fake *FakeGenericInterface2[T]) TakeAndReturnTReturnsOnCall(i int, result1 T) { + fake.takeAndReturnTMutex.Lock() + defer fake.takeAndReturnTMutex.Unlock() + fake.TakeAndReturnTStub = nil + if fake.takeAndReturnTReturnsOnCall == nil { + fake.takeAndReturnTReturnsOnCall = make(map[int]struct { + result1 T + }) + } + fake.takeAndReturnTReturnsOnCall[i] = struct { + result1 T + }{result1} +} + +func (fake *FakeGenericInterface2[T]) TakeT(arg1 T) { + fake.takeTMutex.Lock() + fake.takeTArgsForCall = append(fake.takeTArgsForCall, struct { + arg1 T + }{arg1}) + stub := fake.TakeTStub + fake.recordInvocation("TakeT", []interface{}{arg1}) + fake.takeTMutex.Unlock() + if stub != nil { + fake.TakeTStub(arg1) + } +} + +func (fake *FakeGenericInterface2[T]) TakeTCallCount() int { + fake.takeTMutex.RLock() + defer fake.takeTMutex.RUnlock() + return len(fake.takeTArgsForCall) +} + +func (fake *FakeGenericInterface2[T]) TakeTCalls(stub func(T)) { + fake.takeTMutex.Lock() + defer fake.takeTMutex.Unlock() + fake.TakeTStub = stub +} + +func (fake *FakeGenericInterface2[T]) TakeTArgsForCall(i int) T { + fake.takeTMutex.RLock() + defer fake.takeTMutex.RUnlock() + argsForCall := fake.takeTArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeGenericInterface2[T]) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.doSomethingMutex.RLock() + defer fake.doSomethingMutex.RUnlock() + fake.returnTMutex.RLock() + defer fake.returnTMutex.RUnlock() + fake.takeAndReturnTMutex.RLock() + defer fake.takeAndReturnTMutex.RUnlock() + fake.takeTMutex.RLock() + defer fake.takeTMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeGenericInterface2[T]) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ genericinterface.GenericInterface2[genericinterface.CustomType] = new(FakeGenericInterface2[genericinterface.CustomType]) diff --git a/generator/fake.go b/generator/fake.go index bf88f18..5c2a6fd 100644 --- a/generator/fake.go +++ b/generator/fake.go @@ -25,19 +25,22 @@ const ( // Fake is used to generate a Fake implementation of an interface. type Fake struct { - Packages []*packages.Package - Package *packages.Package - Target *types.TypeName - Mode FakeMode - DestinationPackage string - Name string - TargetAlias string - TargetName string - TargetPackage string - Imports Imports - Methods []Method - Function Method - Header string + Packages []*packages.Package + Package *packages.Package + Target *types.TypeName + Mode FakeMode + DestinationPackage string + Name string + GenericTypeParametersAndConstraints string + GenericTypeParameters string + GenericTypeConstraints string + TargetAlias string + TargetName string + TargetPackage string + Imports Imports + Methods []Method + Function Method + Header string } // Method is a method of the interface. diff --git a/generator/interface_template.go b/generator/interface_template.go index 1a8fde9..3be9c1a 100644 --- a/generator/interface_template.go +++ b/generator/interface_template.go @@ -27,7 +27,7 @@ import ( {{- end}} ) -type {{.Name}} struct { +type {{.Name}}{{.GenericTypeParametersAndConstraints}} struct { {{- range .Methods}} {{.Name}}Stub func({{.Params.AsArgs}}) {{.Returns.AsReturnSignature}} {{UnExport .Name}}Mutex sync.RWMutex @@ -54,7 +54,7 @@ type {{.Name}} struct { } {{range .Methods -}} -func (fake *{{$.Name}}) {{.Name}}({{.Params.AsNamedArgsWithTypes}}) {{.Returns.AsReturnSignature}} { +func (fake *{{$.Name}}{{$.GenericTypeParameters}}) {{.Name}}({{.Params.AsNamedArgsWithTypes}}) {{.Returns.AsReturnSignature}} { {{- range .Params.Slices}} var {{UnExport .Name}}Copy {{.Type}} if {{UnExport .Name}} != nil { @@ -90,20 +90,20 @@ func (fake *{{$.Name}}) {{.Name}}({{.Params.AsNamedArgsWithTypes}}) {{.Returns.A {{- end}} } -func (fake *{{$.Name}}) {{Title .Name}}CallCount() int { +func (fake *{{$.Name}}{{$.GenericTypeParameters}}) {{Title .Name}}CallCount() int { fake.{{UnExport .Name}}Mutex.RLock() defer fake.{{UnExport .Name}}Mutex.RUnlock() return len(fake.{{UnExport .Name}}ArgsForCall) } -func (fake *{{$.Name}}) {{Title .Name}}Calls(stub func({{.Params.AsArgs}}) {{.Returns.AsReturnSignature}}) { +func (fake *{{$.Name}}{{$.GenericTypeParameters}}) {{Title .Name}}Calls(stub func({{.Params.AsArgs}}) {{.Returns.AsReturnSignature}}) { fake.{{UnExport .Name}}Mutex.Lock() defer fake.{{UnExport .Name}}Mutex.Unlock() fake.{{.Name}}Stub = stub } {{if .Params.HasLength -}} -func (fake *{{$.Name}}) {{Title .Name}}ArgsForCall(i int) {{.Params.AsReturnSignature}} { +func (fake *{{$.Name}}{{$.GenericTypeParameters}}) {{Title .Name}}ArgsForCall(i int) {{.Params.AsReturnSignature}} { fake.{{UnExport .Name}}Mutex.RLock() defer fake.{{UnExport .Name}}Mutex.RUnlock() argsForCall := fake.{{UnExport .Name}}ArgsForCall[i] @@ -112,7 +112,7 @@ func (fake *{{$.Name}}) {{Title .Name}}ArgsForCall(i int) {{.Params.AsReturnSign {{- end}} {{if .Returns.HasLength -}} -func (fake *{{$.Name}}) {{Title .Name}}Returns({{.Returns.AsNamedArgsWithTypes}}) { +func (fake *{{$.Name}}{{$.GenericTypeParameters}}) {{Title .Name}}Returns({{.Returns.AsNamedArgsWithTypes}}) { fake.{{UnExport .Name}}Mutex.Lock() defer fake.{{UnExport .Name}}Mutex.Unlock() fake.{{.Name}}Stub = nil @@ -123,7 +123,7 @@ func (fake *{{$.Name}}) {{Title .Name}}Returns({{.Returns.AsNamedArgsWithTypes}} }{ {{- .Returns.AsNamedArgs -}} } } -func (fake *{{$.Name}}) {{Title .Name}}ReturnsOnCall(i int, {{.Returns.AsNamedArgsWithTypes}}) { +func (fake *{{$.Name}}{{$.GenericTypeParameters}}) {{Title .Name}}ReturnsOnCall(i int, {{.Returns.AsNamedArgsWithTypes}}) { fake.{{UnExport .Name}}Mutex.Lock() defer fake.{{UnExport .Name}}Mutex.Unlock() fake.{{.Name}}Stub = nil @@ -144,7 +144,7 @@ func (fake *{{$.Name}}) {{Title .Name}}ReturnsOnCall(i int, {{.Returns.AsNamedAr {{end -}} {{end}} -func (fake *{{.Name}}) Invocations() map[string][][]interface{} { +func (fake *{{.Name}}{{$.GenericTypeParameters}}) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() {{- range .Methods}} @@ -158,7 +158,7 @@ func (fake *{{.Name}}) Invocations() map[string][][]interface{} { return copiedInvocations } -func (fake *{{.Name}}) recordInvocation(key string, args []interface{}) { +func (fake *{{.Name}}{{$.GenericTypeParameters}}) recordInvocation(key string, args []interface{}) { fake.invocationsMutex.Lock() defer fake.invocationsMutex.Unlock() if fake.invocations == nil { @@ -171,6 +171,6 @@ func (fake *{{.Name}}) recordInvocation(key string, args []interface{}) { } {{if IsExported .TargetName -}} -var _ {{.TargetAlias}}.{{.TargetName}} = new({{.Name}}) +var _ {{.TargetAlias}}.{{.TargetName}}{{.GenericTypeConstraints}} = new({{.Name}}{{.GenericTypeConstraints}}) {{- end}} ` diff --git a/generator/loader.go b/generator/loader.go index 4ffb873..648ba49 100644 --- a/generator/loader.go +++ b/generator/loader.go @@ -57,9 +57,32 @@ func (f *Fake) loadPackages(c Cacher, workingDir string) error { return nil } +func (f *Fake) getGenericTypeData(typeName *types.TypeName) (paramName string, constraintName string, found bool) { + if named, ok := typeName.Type().(*types.Named); ok { + if _, ok := named.Underlying().(*types.Interface); ok { + typeParams := named.TypeParams() + if typeParams.Len() > 0 { + for i := 0; i < typeParams.Len(); i++ { + param := typeParams.At(i) + paramName = param.Obj().Name() + constraint := param.Constraint() + constraintSections := strings.Split(constraint.String(), "/") + constraintName = constraintSections[len(constraintSections)-1] + found = true + return + } + } + } + } + return +} + func (f *Fake) findPackage() error { var target *types.TypeName var pkg *packages.Package + genericTypeParametersAndConstraints := []string{} + genericTypeConstraints := []string{} + genericTypeParameters := []string{} for i := range f.Packages { if f.Packages[i].Types == nil || f.Packages[i].Types.Scope() == nil { continue @@ -72,6 +95,15 @@ func (f *Fake) findPackage() error { raw := pkg.Types.Scope().Lookup(f.TargetName) if raw != nil { if typeName, ok := raw.(*types.TypeName); ok { + if paramName, constraintName, found := f.getGenericTypeData(typeName); found { + genericTypeParameters = append(genericTypeParameters, paramName) + genericTypeConstraints = append(genericTypeConstraints, constraintName) + genericTypeParametersAndConstraints = append( + genericTypeParametersAndConstraints, + fmt.Sprintf("%s %s", paramName, constraintName), + ) + } + target = typeName break } @@ -89,6 +121,11 @@ func (f *Fake) findPackage() error { f.Target = target f.Package = pkg f.TargetPackage = imports.VendorlessPath(pkg.PkgPath) + if len(genericTypeParameters) > 0 { + f.GenericTypeParametersAndConstraints = fmt.Sprintf("[%s]", strings.Join(genericTypeParametersAndConstraints, ", ")) + f.GenericTypeParameters = fmt.Sprintf("[%s]", strings.Join(genericTypeParameters, ", ")) + f.GenericTypeConstraints = fmt.Sprintf("[%s]", strings.Join(genericTypeConstraints, ", ")) + } t := f.Imports.Add(pkg.Name, f.TargetPackage) f.TargetAlias = t.Alias if f.Mode != Package { @@ -97,7 +134,7 @@ func (f *Fake) findPackage() error { if f.Mode == InterfaceOrFunction { if !f.IsInterface() && !f.IsFunction() { - return fmt.Errorf("cannot generate an fake for %s because it is not an interface or function", f.TargetName) + return fmt.Errorf("cannot generate a fake for %s because it is not an interface or function", f.TargetName) } }