diff --git a/CHANGELOG.md b/CHANGELOG.md index 000b110..2b64048 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## v2.2.2 + +Internal refactoring + +### Added + +- Visualize parameter bag + +### Fixed + +- Visualize type detection + ## v2.2.1 ### Fixed diff --git a/di/container.go b/di/container.go index d5fe8a3..d28ac2a 100644 --- a/di/container.go +++ b/di/container.go @@ -35,27 +35,28 @@ type ProvideParams struct { // Provide adds constructor into container with parameters. func (c *Container) Provide(params ProvideParams) { - prov := provider(newConstructorProvider(params.Name, params.Provider)) - k := prov.resultKey() - - if c.exists(k) { - panicf("The `%s` type already exists in container", prov.resultKey()) + p := provider(newProviderConstructor(params.Name, params.Provider)) + if c.exists(p.Key()) { + panicf("The `%s` type already exists in container", p.Key()) } - if !params.IsPrototype { - prov = asSingleton(prov) + p = asSingleton(p) } - - c.addProvider(prov) - c.provideEmbedParameters(prov) - + // add provider to graph + c.add(p) + // parse embed parameters + for _, parameter := range p.ParameterList() { + if parameter.embed { + c.add(newProviderEmbed(parameter)) + } + } + // provide parameter bag if len(params.Parameters) != 0 { - parameterBugProvider := createParameterBugProvider(k, params.Parameters) - c.addProvider(parameterBugProvider) + c.add(createParameterBugProvider(p.Key(), params.Parameters)) } - + // process interfaces for _, iface := range params.Interfaces { - c.processProviderInterface(prov, iface) + c.processProviderInterface(p, iface) } } @@ -68,19 +69,14 @@ func (c *Container) Compile() { return c }, }) - c.Provide(ProvideParams{ Provider: func() *Graph { return &Graph{graph: c.graph.DOTGraph()} }, }) - - for _, key := range c.all() { - // register provider parameters - provider, _ := c.provider(key) - c.registerParameters(provider) + for _, p := range c.all() { + c.registerProviderParameters(p) } - _, err := c.graph.DFSSort() if err != nil { switch err { @@ -90,7 +86,6 @@ func (c *Container) Compile() { panic(err.Error()) } } - c.compiled = true } @@ -105,21 +100,25 @@ func (c *Container) Extract(params ExtractParams) error { if !c.compiled { return fmt.Errorf("container not compiled") } - if params.Target == nil { return fmt.Errorf("extract target must be a pointer, got `nil`") } - if !reflection.IsPtr(params.Target) { return fmt.Errorf("extract target must be a pointer, got `%s`", reflect.TypeOf(params.Target)) } - - key := key{ - name: params.Name, - typ: reflect.TypeOf(params.Target).Elem(), + typ := reflect.TypeOf(params.Target) + param := parameter{ + name: params.Name, + res: typ.Elem(), + embed: isEmbedParameter(typ), } - - return key.extract(c, params.Target) + value, err := param.ResolveValue(c) + if err != nil { + return err + } + targetValue := reflect.ValueOf(params.Target).Elem() + targetValue.Set(value) + return nil } // InvokeParams @@ -133,38 +132,26 @@ func (c *Container) Invoke(params InvokeParams) error { if !c.compiled { return fmt.Errorf("container not compiled") } - invoker, err := newInvoker(params.Fn) if err != nil { return err } - return invoker.Invoke(c) } // Cleanup func (c *Container) Cleanup() { - for _, key := range c.all() { - provider, _ := c.provider(key) - if cleanup, ok := provider.(cleanup); ok { + for _, p := range c.all() { + if cleanup, ok := p.(cleanup); ok { cleanup.cleanup() } } } -// addProvider -func (c *Container) addProvider(p provider) { - c.graph.AddNode(p.resultKey()) - c.providers[p.resultKey()] = p -} - -// provideEmbedParameters -func (c *Container) provideEmbedParameters(p provider) { - for _, parameter := range p.parameters() { - if parameter.embed { - c.addProvider(newEmbedProvider(parameter)) - } - } +// add +func (c *Container) add(p provider) { + c.graph.AddNode(p.Key()) + c.providers[p.Key()] = p } // exists checks that key registered in container graph. @@ -177,64 +164,73 @@ func (c *Container) provider(k key) (provider, bool) { if !c.exists(k) { return nil, false } - return c.providers[k], true } // all return all container keys. -func (c *Container) all() []key { - var keys []key - - for _, node := range c.graph.Nodes() { - keys = append(keys, node.(key)) +func (c *Container) all() []provider { + var providers []provider + for _, k := range c.graph.Nodes() { + p, _ := c.provider(k.(key)) + providers = append(providers, p) } - - return keys + return providers } // processProviderInterface represents instances as interfaces and groups. func (c *Container) processProviderInterface(provider provider, as interface{}) { - // create interface from embedParamProvider - iface := newInterfaceProvider(provider, as) - ifaceKey := iface.resultKey() - - if c.graph.NodeExists(ifaceKey) { + // create interface from provider + iface := newProviderInterface(provider, as) + if c.graph.NodeExists(iface.Key()) { // if iface already exists, restrict interface resolving - c.providers[ifaceKey] = iface.Multiple() + c.providers[iface.Key()] = newProviderStub(iface.Key(), "have several implementations") } else { // add interface node - c.graph.AddNode(ifaceKey) - c.providers[ifaceKey] = iface + c.graph.AddNode(iface.Key()) + c.providers[iface.Key()] = iface } - // create group - group := newGroupProvider(ifaceKey) - groupKey := group.resultKey() - + group := newGroupProvider(iface.Key()) // check exists - if c.graph.NodeExists(groupKey) { + if c.exists(group.Key()) { // if exists use existing group - group = c.providers[groupKey].(*interfaceGroup) + group = c.providers[group.Key()].(*interfaceGroup) } else { // else add new group to graph - c.graph.AddNode(groupKey) - c.providers[groupKey] = group + c.graph.AddNode(group.Key()) + c.providers[group.Key()] = group } - // add embedParamProvider ifaceKey into group - group.Add(provider.resultKey()) + group.Add(provider.Key()) } -// registerParameters registers provider parameters in a dependency graph. -func (c *Container) registerParameters(provider provider) { - for _, parameter := range provider.parameters() { - _, exists := c.provider(parameter.resultKey()) +// registerProviderParameters registers provider parameters in a dependency graph. +func (c *Container) registerProviderParameters(p provider) { + for _, param := range p.ParameterList() { + paramProvider, exists := c.resolveParameterProvider(param) if exists { - c.graph.AddEdge(parameter.resultKey(), provider.resultKey()) + c.graph.AddEdge(paramProvider.Key(), p.Key()) + continue } + if !exists && !param.optional { + panicf("%s: dependency %s not exists in container", p.Key(), param) + } + } +} - if !exists && !parameter.optional { - panicf("%s: dependency %s not exists in container", provider.resultKey(), parameter.resultKey()) +// resolveParameterProvider lookup provider by parameter. +func (c *Container) resolveParameterProvider(param parameter) (provider, bool) { + for _, pt := range providerLookupSequence { + k := key{ + name: param.name, + res: param.res, + typ: pt, + } + provider, exists := c.provider(k) + if !exists { + continue } + return provider, true } + return nil, false } diff --git a/di/container_test.go b/di/container_test.go index b622817..db6bc63 100644 --- a/di/container_test.go +++ b/di/container_test.go @@ -139,7 +139,7 @@ func TestContainerExtractErrors(t *testing.T) { c.MustProvide(ditest.NewBar) c.MustCompile() var bar *ditest.Bar - c.MustExtractError(&bar, "*ditest.Bar: *ditest.Foo: internal error") + c.MustExtractError(&bar, "*ditest.Foo: internal error") }) t.Run("extract interface with multiple implementations cause error", func(t *testing.T) { @@ -150,7 +150,7 @@ func TestContainerExtractErrors(t *testing.T) { c.MustCompile() var extracted ditest.Fooer - c.MustExtractError(&extracted, "ditest.Fooer: ditest.Fooer have sereral implementations") + c.MustExtractError(&extracted, "ditest.Fooer: have several implementations") }) } @@ -519,7 +519,7 @@ func TestContainerCleanup(t *testing.T) { } func TestContainer_GraphVisualizing(t *testing.T) { - t.Run("", func(t *testing.T) { + t.Run("graph", func(t *testing.T) { c := NewTestContainer(t) c.MustProvide(ditest.NewLogger) @@ -534,18 +534,22 @@ func TestContainer_GraphVisualizing(t *testing.T) { Target: &graph, })) + fmt.Println(graph.String()) + require.Equal(t, `digraph { subgraph cluster_s3 { ID = "cluster_s3"; bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded"; - n8[color="#46494C",fontcolor="white",fontname="COURIER",label="*di.Graph",shape="box",style="filled"]; + n10[color="#46494C",fontcolor="white",fontname="COURIER",label="*di.Graph",shape="box",style="filled"]; + n9[color="#46494C",fontcolor="white",fontname="COURIER",label="di.Extractor",shape="box",style="filled"]; }subgraph cluster_s2 { ID = "cluster_s2"; bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded"; - n5[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AccountController",shape="box",style="filled"]; - n7[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AuthController",shape="box",style="filled"]; - n6[color="#E54B4B",fontcolor="white",fontname="COURIER",label="[]ditest.Controller",shape="doubleoctagon",style="filled"]; + n6[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AccountController",shape="box",style="filled"]; + n8[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AuthController",shape="box",style="filled"]; + n7[color="#E54B4B",fontcolor="white",fontname="COURIER",label="[]ditest.Controller",shape="doubleoctagon",style="filled"]; + n4[color="#E5984B",fontcolor="white",fontname="COURIER",label="ditest.RouterParams",shape="box",style="filled"]; }subgraph cluster_s0 { ID = "cluster_s0"; @@ -557,18 +561,19 @@ func TestContainer_GraphVisualizing(t *testing.T) { bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded"; n3[color="#46494C",fontcolor="white",fontname="COURIER",label="*http.ServeMux",shape="box",style="filled"]; n2[color="#46494C",fontcolor="white",fontname="COURIER",label="*http.Server",shape="box",style="filled"]; - n4[color="#2589BD",fontcolor="white",fontname="COURIER",label="http.Handler",style="filled"]; + n5[color="#2589BD",fontcolor="white",fontname="COURIER",label="http.Handler",style="filled"]; }splines="ortho"; - n5->n6[color="#949494"]; - n7->n6[color="#949494"]; - n3->n4[color="#949494"]; + n6->n7[color="#949494"]; + n8->n7[color="#949494"]; + n3->n5[color="#949494"]; n1->n2[color="#949494"]; n1->n3[color="#949494"]; - n1->n5[color="#949494"]; - n1->n7[color="#949494"]; - n6->n3[color="#949494"]; - n4->n2[color="#949494"]; + n1->n6[color="#949494"]; + n1->n8[color="#949494"]; + n7->n4[color="#949494"]; + n4->n3[color="#949494"]; + n5->n2[color="#949494"]; }`, graph.String()) }) diff --git a/di/errors.go b/di/errors.go index d76af65..2a9a39c 100644 --- a/di/errors.go +++ b/di/errors.go @@ -2,11 +2,21 @@ package di import "fmt" -// ErrProviderNotFound -type ErrProviderNotFound struct { - k key +// ErrParameterProvideFailed +type ErrParameterProvideFailed struct { + k key + err error } -func (e ErrProviderNotFound) Error() string { - return fmt.Sprintf("not exists in container") +func (e ErrParameterProvideFailed) Error() string { + return fmt.Sprintf("%s: %s", e.k, e.err) +} + +// ErrParameterProviderNotFound +type ErrParameterProviderNotFound struct { + param parameter +} + +func (e ErrParameterProviderNotFound) Error() string { + return fmt.Sprintf("%s: not exists in container", e.param) } diff --git a/di/internal/dag/output.go b/di/internal/dag/output.go index 7b25263..727e362 100644 --- a/di/internal/dag/output.go +++ b/di/internal/dag/output.go @@ -10,7 +10,7 @@ import ( type NodeVisualizer interface { Visualize(node *dot.Node) SubGraph() string - IsPrimary() bool + IsAlwaysVisible() bool } // DOTGraph returns a textual representation of the graph in the DOT graph @@ -22,21 +22,21 @@ func (g *DirectedGraph) DOTGraph() *dot.Graph { subgraphs := make(map[string]*dot.Graph) itemsByNode := make(map[Node]dot.Node) for _, node := range g.Nodes() { - visualizer := node.(NodeVisualizer) + nv := node.(NodeVisualizer) - if !g.HasOutgoingEdges(node) && !visualizer.IsPrimary() { + if !g.HasOutgoingEdges(node) && !nv.IsAlwaysVisible() { continue } name := fmt.Sprintf("%s", node) - subgraph, ok := subgraphs[visualizer.SubGraph()] + subgraph, ok := subgraphs[nv.SubGraph()] if !ok { - subgraph = root.Subgraph(visualizer.SubGraph(), dot.ClusterOption{}) - subgraphs[visualizer.SubGraph()] = subgraph + subgraph = root.Subgraph(nv.SubGraph(), dot.ClusterOption{}) + subgraphs[nv.SubGraph()] = subgraph applySubGraphStyle(subgraph) } item := subgraph.Node(name) - visualizer.Visualize(&item) + nv.Visualize(&item) itemsByNode[node] = item } diff --git a/di/internal/ditest/full.go b/di/internal/ditest/full.go index d89f3f7..c25bd7c 100644 --- a/di/internal/ditest/full.go +++ b/di/internal/ditest/full.go @@ -4,6 +4,8 @@ import ( "log" "net/http" "os" + + "github.com/defval/inject/v2/di" ) // NewLogger @@ -22,14 +24,20 @@ func NewServer(logger *log.Logger, handler http.Handler) *http.Server { } } +// RouterParams +type RouterParams struct { + di.Parameter + Controllers []Controller `di:"optional"` +} + // NewRouter -func NewRouter(logger *log.Logger, controllers []Controller) *http.ServeMux { +func NewRouter(logger *log.Logger, params RouterParams) *http.ServeMux { logger.Println("Create router!") defer logger.Println("Router created!") mux := &http.ServeMux{} - for _, ctrl := range controllers { + for _, ctrl := range params.Controllers { ctrl.RegisterRoutes(mux) } diff --git a/di/invoker.go b/di/invoker.go index fba4b8e..b796933 100644 --- a/di/invoker.go +++ b/di/invoker.go @@ -15,6 +15,16 @@ const ( invokerError // func (deps) error {} ) +func determineInvokerType(fn *reflection.Func) (invokerType, error) { + if fn.NumOut() == 0 { + return invokerStd, nil + } + if fn.NumOut() == 1 && reflection.IsError(fn.Out(0)) { + return invokerError, nil + } + return invokerUnknown, fmt.Errorf("the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `%s`", fn.Type) +} + type invoker struct { typ invokerType fn *reflection.Func @@ -24,17 +34,14 @@ func newInvoker(fn interface{}) (*invoker, error) { if fn == nil { return nil, fmt.Errorf("the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `%s`", "nil") } - if !reflection.IsFunc(fn) { return nil, fmt.Errorf("the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `%s`", reflect.ValueOf(fn).Type()) } - ifn := reflection.InspectFunction(fn) typ, err := determineInvokerType(ifn) if err != nil { return nil, err } - return &invoker{ typ: typ, fn: reflection.InspectFunction(fn), @@ -43,52 +50,30 @@ func newInvoker(fn interface{}) (*invoker, error) { func (i *invoker) Invoke(c *Container) error { plist := i.parameters() - values, err := plist.resolve(c) + values, err := plist.ResolveValues(c) if err != nil { return fmt.Errorf("could not resolve invoke parameters: %s", err) } - results := i.fn.Call(values) - if len(results) == 0 { return nil } - if results[0].Interface() == nil { return nil } - return results[0].Interface().(error) } func (i *invoker) parameters() parameterList { - var list parameterList - + var plist parameterList for j := 0; j < i.fn.NumIn(); j++ { ptype := i.fn.In(j) - p := parameter{ - key: key{ - typ: ptype, - }, + res: ptype, optional: false, embed: isEmbedParameter(ptype), } - - list = append(list, p) - } - - return list -} - -func determineInvokerType(fn *reflection.Func) (invokerType, error) { - if fn.NumOut() == 0 { - return invokerStd, nil - } - - if fn.NumOut() == 1 && reflection.IsError(fn.Out(0)) { - return invokerError, nil + plist = append(plist, p) } - - return invokerUnknown, fmt.Errorf("the invoke function must be a function like `func([dep1, dep2, ...]) [error]`, got `%s`", fn.Type) + return plist } diff --git a/di/key.go b/di/key.go index 2da024f..f0a3cf7 100644 --- a/di/key.go +++ b/di/key.go @@ -7,45 +7,34 @@ import ( "github.com/emicklei/dot" ) -// resultKey is a id of represented instance in container +// key is a id of provider in container type key struct { name string - typ reflect.Type + res reflect.Type + typ providerType } // String represent resultKey as string. func (k key) String() string { if k.name == "" { - return fmt.Sprintf("%s", k.typ) + return fmt.Sprintf("%s", k.res) } - - return fmt.Sprintf("%s[%s]", k.typ, k.name) -} - -// resultKey -func (k key) resultKey() key { - return k + return fmt.Sprintf("%s[%s]", k.res, k.name) } -// IsPrimary -func (k key) IsPrimary() bool { - if k.typ.Kind() == reflect.Slice { - return false - } - if k.typ.Kind() == reflect.Interface { - return false - } - return true +// IsAlwaysVisible +func (k key) IsAlwaysVisible() bool { + return k.typ == ptConstructor } // Package func (k key) SubGraph() string { var pkg string - switch k.typ.Kind() { + switch k.res.Kind() { case reflect.Slice, reflect.Ptr: - pkg = k.typ.Elem().PkgPath() + pkg = k.res.Elem().PkgPath() default: - pkg = k.typ.PkgPath() + pkg = k.res.PkgPath() } return pkg @@ -57,42 +46,17 @@ func (k key) Visualize(node *dot.Node) { node.Attr("fontname", "COURIER") node.Attr("style", "filled") node.Attr("fontcolor", "white") - if k.typ.Kind() == reflect.Slice { + switch k.typ { + case ptConstructor: + node.Attr("shape", "box") + node.Attr("color", "#46494C") + case ptGroup: node.Attr("shape", "doubleoctagon") node.Attr("color", "#E54B4B") - return - } - if k.typ.Kind() == reflect.Interface { + case ptInterface: node.Attr("color", "#2589BD") - return - } - node.Attr("color", "#46494C") - node.Box() -} - -func (k key) resolve(c *Container) (reflect.Value, error) { - provider, exists := c.provider(k) - if !exists { - return reflect.Value{}, ErrProviderNotFound{k: k} + case ptEmbedParameter: + node.Attr("shape", "box") + node.Attr("color", "#E5984B") } - - values, err := provider.parameters().resolve(c) - if err != nil { - return reflect.Value{}, err - } - - return provider.provide(values...) -} - -// Extract extracts instance by resultKey from container into target. -func (k key) extract(c *Container, target interface{}) error { - value, err := k.resolve(c) - if err != nil { - return fmt.Errorf("%s: %s", k, err) - } - - targetValue := reflect.ValueOf(target).Elem() - targetValue.Set(value) - - return nil } diff --git a/di/parameter.go b/di/parameter.go index 63c0ca4..ad4dfe8 100644 --- a/di/parameter.go +++ b/di/parameter.go @@ -1,7 +1,6 @@ package di import ( - "fmt" "reflect" ) @@ -12,41 +11,57 @@ type Parameter struct { // parameterRequired type parameter struct { - key + name string + res reflect.Type optional bool embed bool } -func (p parameter) resolve(c *Container) (reflect.Value, error) { - value, err := p.key.resolve(c) - if _, notFound := err.(ErrProviderNotFound); notFound && p.optional { - // create empty instance of type - return reflect.New(p.typ).Elem(), nil +func (p parameter) String() string { + return key{name: p.name, res: p.res}.String() +} + +// ResolveProvider resolves parameter provider +func (p parameter) ResolveProvider(c *Container) (provider, bool) { + for _, pt := range providerLookupSequence { + k := key{ + name: p.name, + res: p.res, + typ: pt, + } + provider, exists := c.provider(k) + if !exists { + continue + } + return provider, true } + return nil, false +} +func (p parameter) ResolveValue(c *Container) (reflect.Value, error) { + provider, exists := p.ResolveProvider(c) + if !exists && p.optional { + return reflect.New(p.res).Elem(), nil + } + if !exists { + return reflect.Value{}, ErrParameterProviderNotFound{param: p} + } + pl := provider.ParameterList() + values, err := pl.ResolveValues(c) if err != nil { return reflect.Value{}, err } + value, err := provider.Provide(values...) + if err != nil { + return value, ErrParameterProvideFailed{k: provider.Key(), err: err} + } return value, nil } -// parameterList -type parameterList []parameter - -// resolve loads all parameters presented in parameter list. -func (pl parameterList) resolve(c *Container) ([]reflect.Value, error) { - var values []reflect.Value - for _, p := range pl { - pvalue, err := p.resolve(c) - if err != nil { - return nil, fmt.Errorf("%s: %s", p.resultKey(), err) - } - - values = append(values, pvalue) - } - - return values, nil +// isEmbedParameter +func isEmbedParameter(typ reflect.Type) bool { + return typ.Kind() == reflect.Struct && typ.Implements(parameterInterface) } // internalParameter @@ -54,9 +69,5 @@ type internalParameter interface { isDependencyInjectionParameter() } +// parameterInterface var parameterInterface = reflect.TypeOf(new(internalParameter)).Elem() - -// isEmbedParameter -func isEmbedParameter(typ reflect.Type) bool { - return typ.Kind() == reflect.Struct && typ.Implements(parameterInterface) -} diff --git a/di/parameter_bag.go b/di/parameter_bag.go index 7496ab0..febecc6 100644 --- a/di/parameter_bag.go +++ b/di/parameter_bag.go @@ -91,7 +91,7 @@ func (b ParameterBag) RequireFloat64(key string) float64 { // createParameterBugProvider func createParameterBugProvider(key key, parameters ParameterBag) provider { - return newConstructorProvider(key.String(), func() ParameterBag { return parameters }) + return newProviderConstructor(key.String(), func() ParameterBag { return parameters }) } var parameterBagType = reflect.TypeOf(ParameterBag{}) diff --git a/di/parameter_list.go b/di/parameter_list.go new file mode 100644 index 0000000..e16442c --- /dev/null +++ b/di/parameter_list.go @@ -0,0 +1,20 @@ +package di + +import "reflect" + +// parameterList +type parameterList []parameter + +// ResolveValues loads all parameters presented in parameter list. +func (pl parameterList) ResolveValues(c *Container) ([]reflect.Value, error) { + var values []reflect.Value + for _, p := range pl { + value, err := p.ResolveValue(c) + if err != nil { + return nil, err + } + values = append(values, value) + } + + return values, nil +} diff --git a/di/provider.go b/di/provider.go index 2e55179..8c081cf 100644 --- a/di/provider.go +++ b/di/provider.go @@ -2,11 +2,28 @@ package di import "reflect" +// provider lookup sequence +var providerLookupSequence = []providerType{ptConstructor, ptInterface, ptGroup, ptEmbedParameter} + +// providerType +type providerType int + +const ( + ptUnknown providerType = iota + ptConstructor + ptInterface + ptGroup + ptEmbedParameter +) + // provider type provider interface { - resultKey() key - parameters() parameterList - provide(parameters ...reflect.Value) (reflect.Value, error) + // The identity of result type. + Key() key + // ParameterList returns array of dependencies. + ParameterList() parameterList + // Provide provides value from provided parameters. + Provide(values ...reflect.Value) (reflect.Value, error) } // cleanup diff --git a/di/provider_ctor.go b/di/provider_ctor.go index 8c759af..9829e09 100644 --- a/di/provider_ctor.go +++ b/di/provider_ctor.go @@ -18,84 +18,70 @@ const ( ctorCleanupError // (deps) (result, cleanup, error) ) -// newConstructorProvider -func newConstructorProvider(name string, ctor interface{}) *constructorProvider { +// newProviderConstructor +func newProviderConstructor(name string, ctor interface{}) *providerConstructor { if ctor == nil { panicf("The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `%s`", "nil") } - if !reflection.IsFunc(ctor) { panicf("The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `%s`", reflect.ValueOf(ctor).Type()) } - fn := reflection.InspectFunction(ctor) ctorType := determineCtorType(fn) - - return &constructorProvider{ + return &providerConstructor{ name: name, ctor: fn, ctorType: ctorType, } } -// constructorProvider -type constructorProvider struct { +// providerConstructor +type providerConstructor struct { name string ctor *reflection.Func ctorType ctorType clean *reflection.Func } -// resultKey returns constructor result type resultKey. -func (c constructorProvider) resultKey() key { +func (c providerConstructor) Key() key { return key{ name: c.name, - typ: c.ctor.Out(0), + res: c.ctor.Out(0), + typ: ptConstructor, } } -// parameters returns constructor parameters -func (c constructorProvider) parameters() parameterList { - var list parameterList - +func (c providerConstructor) ParameterList() parameterList { + var plist parameterList for i := 0; i < c.ctor.NumIn(); i++ { ptype := c.ctor.In(i) - var pname string - + var name string if ptype == parameterBagType { - pname = c.resultKey().String() + name = c.Key().String() } - p := parameter{ - key: key{ - name: pname, - typ: ptype, - }, + name: name, + res: ptype, optional: false, embed: isEmbedParameter(ptype), } - - list = append(list, p) + plist = append(plist, p) } - - return list + return plist } // Provide -func (c *constructorProvider) provide(parameters ...reflect.Value) (reflect.Value, error) { +func (c *providerConstructor) Provide(parameters ...reflect.Value) (reflect.Value, error) { out := c.ctor.Call(parameters) - switch c.ctorType { case ctorStd: return out[0], nil case ctorError: instance := out[0] err := out[1] - if err.IsNil() { return instance, nil } - return instance, err.Interface().(error) case ctorCleanup: c.saveCleanup(out[1]) @@ -104,25 +90,21 @@ func (c *constructorProvider) provide(parameters ...reflect.Value) (reflect.Valu instance := out[0] cleanup := out[1] err := out[2] - c.saveCleanup(cleanup) - if err.IsNil() { return instance, nil } - return instance, err.Interface().(error) } - return reflect.Value{}, errors.New("you found a bug, please create new issue for " + "this: https://github.com/defval/inject/issues/new") } -func (c *constructorProvider) saveCleanup(value reflect.Value) { +func (c *providerConstructor) saveCleanup(value reflect.Value) { c.clean = reflection.InspectFunction(value.Interface()) } -func (c *constructorProvider) cleanup() { +func (c *providerConstructor) cleanup() { if c.clean != nil && c.clean.IsValid() { c.clean.Call([]reflect.Value{}) } @@ -133,20 +115,16 @@ func determineCtorType(fn *reflection.Func) ctorType { if fn.NumOut() == 1 { return ctorStd } - if fn.NumOut() == 2 { if reflection.IsError(fn.Out(1)) { return ctorError } - if reflection.IsCleanup(fn.Out(1)) { return ctorCleanup } } - if fn.NumOut() == 3 && reflection.IsCleanup(fn.Out(1)) && reflection.IsError(fn.Out(2)) { return ctorCleanupError } - panic(fmt.Sprintf("The constructor must be a function like `func([dep1, dep2, ...]) (, [cleanup, error])`, got `%s`", fn.Name)) } diff --git a/di/provider_embed.go b/di/provider_embed.go index ca4d511..1a2f267 100644 --- a/di/provider_embed.go +++ b/di/provider_embed.go @@ -6,101 +6,91 @@ import ( ) // createStructProvider -func newEmbedProvider(p parameter) *embedParamProvider { - result := p.resultKey() - +func newProviderEmbed(p parameter) *providerEmbed { var embedType reflect.Type - if result.typ.Kind() == reflect.Ptr { - embedType = result.typ.Elem() + if p.res.Kind() == reflect.Ptr { + embedType = p.res.Elem() } else { - embedType = result.typ + embedType = p.res } - return &embedParamProvider{ - key: result, + return &providerEmbed{ + key: key{ + name: p.name, + res: p.res, + typ: ptEmbedParameter, + }, embedType: embedType, embedValue: reflect.New(embedType).Elem(), } } -type embedParamProvider struct { +type providerEmbed struct { key key embedType reflect.Type embedValue reflect.Value } -func (s *embedParamProvider) resultKey() key { - return s.key +func (p *providerEmbed) Key() key { + return p.key } -func (s *embedParamProvider) parameters() parameterList { - var pl parameterList - - for i := 0; i < s.embedType.NumField(); i++ { - name, optional, isDependency := s.inspectFieldTag(i) +func (p *providerEmbed) ParameterList() parameterList { + var plist parameterList + for i := 0; i < p.embedType.NumField(); i++ { + name, optional, isDependency := p.inspectFieldTag(i) if !isDependency { continue } - - // parameter field - pField := s.embedType.Field(i) - - pl = append(pl, parameter{ - key: key{ - name: name, - typ: pField.Type, - }, + field := p.embedType.Field(i) + plist = append(plist, parameter{ + name: name, + res: field.Type, optional: optional, - embed: isEmbedParameter(pField.Type), + embed: isEmbedParameter(field.Type), }) } - - return pl + return plist } -func (s *embedParamProvider) inspectFieldTag(num int) (name string, optional bool, isDependency bool) { - tag, tagExists := s.embedType.Field(num).Tag.Lookup("di") - canSet := s.embedValue.Field(num).CanSet() - if !tagExists || !canSet { - return "", false, false +func (p *providerEmbed) Provide(parameters ...reflect.Value) (reflect.Value, error) { + for i, offset := 0, 0; i < p.embedType.NumField(); i++ { + _, _, isDependency := p.inspectFieldTag(i) + if !isDependency { + offset++ + continue + } + + p.embedValue.Field(i).Set(parameters[i-offset]) } - name, optional = s.parseTag(tag) + return p.embedValue, nil +} +func (p *providerEmbed) inspectFieldTag(num int) (name string, optional bool, isDependency bool) { + fieldType := p.embedType.Field(num) + fieldValue := p.embedValue.Field(num) + tag, tagExists := fieldType.Tag.Lookup("di") + if !tagExists || !fieldValue.CanSet() { + return "", false, false + } + name, optional = p.parseTag(tag) return name, optional, true } -func (s *embedParamProvider) parseTag(tag string) (name string, optional bool) { +func (p *providerEmbed) parseTag(tag string) (name string, optional bool) { options := strings.Split(tag, ",") if len(options) == 0 { return "", false } - if len(options) == 1 && options[0] == "optional" { return "", true } - if len(options) == 1 { return options[0], false } - if len(options) == 2 && options[1] == "optional" { return options[0], true } - panic("incorrect di tag") } - -func (s *embedParamProvider) provide(parameters ...reflect.Value) (reflect.Value, error) { - for i, offset := 0, 0; i < s.embedType.NumField(); i++ { - _, _, isDependency := s.inspectFieldTag(i) - if !isDependency { - offset++ - continue - } - - s.embedValue.Field(i).Set(parameters[i-offset]) - } - - return s.embedValue, nil -} diff --git a/di/provider_group.go b/di/provider_group.go index 1bb3a26..3e7bae5 100644 --- a/di/provider_group.go +++ b/di/provider_group.go @@ -7,7 +7,8 @@ import ( // newGroupProvider creates new group from provided resultKey. func newGroupProvider(k key) *interfaceGroup { ifaceKey := key{ - typ: reflect.SliceOf(k.typ), + res: reflect.SliceOf(k.res), + typ: ptGroup, } return &interfaceGroup{ @@ -25,24 +26,25 @@ type interfaceGroup struct { // Add func (i *interfaceGroup) Add(k key) { i.pl = append(i.pl, parameter{ - key: k, + name: k.name, + res: k.res, optional: false, embed: false, }) } // resultKey -func (i interfaceGroup) resultKey() key { +func (i interfaceGroup) Key() key { return i.result } // parameters -func (i interfaceGroup) parameters() parameterList { +func (i interfaceGroup) ParameterList() parameterList { return i.pl } // Provide -func (i interfaceGroup) provide(parameters ...reflect.Value) (reflect.Value, error) { - group := reflect.New(i.result.typ).Elem() +func (i interfaceGroup) Provide(parameters ...reflect.Value) (reflect.Value, error) { + group := reflect.New(i.result.res).Elem() return reflect.Append(group, parameters...), nil } diff --git a/di/provider_iface.go b/di/provider_iface.go index 5d94890..a40f072 100644 --- a/di/provider_iface.go +++ b/di/provider_iface.go @@ -1,71 +1,48 @@ package di import ( - "fmt" "reflect" "github.com/defval/inject/v2/di/internal/reflection" ) -// newInterfaceProvider -func newInterfaceProvider(provider provider, as interface{}) *interfaceProvider { +// newProviderInterface +func newProviderInterface(provider provider, as interface{}) *providerInterface { iface := reflection.InspectInterfacePtr(as) - - if !provider.resultKey().typ.Implements(iface.Type) { - panicf("%s not implement %s", provider.resultKey(), iface.Type) + if !provider.Key().res.Implements(iface.Type) { + panicf("%s not implement %s", provider.Key(), iface.Type) } - - return &interfaceProvider{ - result: key{ - name: provider.resultKey().name, - typ: iface.Type, + return &providerInterface{ + res: key{ + name: provider.Key().name, + res: iface.Type, + typ: ptInterface, }, - implementation: provider, + provider: provider, } } -// interfaceProvider -type interfaceProvider struct { - result key - implementation provider +// providerInterface +type providerInterface struct { + res key + provider provider } -func (i *interfaceProvider) resultKey() key { - return i.result +func (i *providerInterface) Key() key { + return i.res } -func (i *interfaceProvider) parameters() parameterList { - var list parameterList - list = append(list, parameter{ - key: i.implementation.resultKey(), +func (i *providerInterface) ParameterList() parameterList { + var plist parameterList + plist = append(plist, parameter{ + name: i.provider.Key().name, + res: i.provider.Key().res, optional: false, embed: false, }) - - return list + return plist } -func (i *interfaceProvider) provide(parameters ...reflect.Value) (reflect.Value, error) { +func (i *providerInterface) Provide(parameters ...reflect.Value) (reflect.Value, error) { return parameters[0], nil } - -func (i *interfaceProvider) Multiple() *multipleInterfaceProvider { - return &multipleInterfaceProvider{result: i.result} -} - -// multipleInterfaceProvider -type multipleInterfaceProvider struct { - result key -} - -func (m *multipleInterfaceProvider) resultKey() key { - return m.result -} - -func (m *multipleInterfaceProvider) parameters() parameterList { - return parameterList{} -} - -func (m *multipleInterfaceProvider) provide(parameters ...reflect.Value) (reflect.Value, error) { - return reflect.Value{}, fmt.Errorf("%s have sereral implementations", m.result.typ) -} diff --git a/di/provider_stub.go b/di/provider_stub.go new file mode 100644 index 0000000..78403f7 --- /dev/null +++ b/di/provider_stub.go @@ -0,0 +1,29 @@ +package di + +import ( + "fmt" + "reflect" +) + +// providerStub +type providerStub struct { + msg string + res key +} + +// newProviderStub +func newProviderStub(k key, msg string) *providerStub { + return &providerStub{res: k, msg: msg} +} + +func (m *providerStub) Key() key { + return m.res +} + +func (m *providerStub) ParameterList() parameterList { + return parameterList{} +} + +func (m *providerStub) Provide(_ ...reflect.Value) (reflect.Value, error) { + return reflect.Value{}, fmt.Errorf(m.msg) +} diff --git a/di/singleton.go b/di/singleton.go index 56d3d32..2e11f4f 100644 --- a/di/singleton.go +++ b/di/singleton.go @@ -16,18 +16,15 @@ type singletonWrapper struct { } // Provide -func (s *singletonWrapper) provide(parameters ...reflect.Value) (reflect.Value, error) { +func (s *singletonWrapper) Provide(parameters ...reflect.Value) (reflect.Value, error) { if s.value.IsValid() { return s.value, nil } - - value, err := s.provider.provide(parameters...) + value, err := s.provider.Provide(parameters...) if err != nil { return reflect.Value{}, err } - s.value = value - return value, nil }