Skip to content
This repository has been archived by the owner on Jan 18, 2021. It is now read-only.

Commit

Permalink
Change key resolving mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxim Bovtunov committed Dec 19, 2019
1 parent 63b44df commit 00bf22f
Show file tree
Hide file tree
Showing 18 changed files with 381 additions and 380 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v2.2.2

Internal refactoring

### Added

- Visualize parameter bag

### Fixed

- Visualize type detection

## v2.2.1

### Fixed
Expand Down
160 changes: 78 additions & 82 deletions di/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,28 @@ type ProvideParams struct {

// Provide adds constructor into container with parameters.
func (c *Container) Provide(params ProvideParams) {
prov := provider(newConstructorProvider(params.Name, params.Provider))
k := prov.resultKey()

if c.exists(k) {
panicf("The `%s` type already exists in container", prov.resultKey())
p := provider(newProviderConstructor(params.Name, params.Provider))
if c.exists(p.Key()) {
panicf("The `%s` type already exists in container", p.Key())
}

if !params.IsPrototype {
prov = asSingleton(prov)
p = asSingleton(p)
}

c.addProvider(prov)
c.provideEmbedParameters(prov)

// add provider to graph
c.add(p)
// parse embed parameters
for _, parameter := range p.ParameterList() {
if parameter.embed {
c.add(newProviderEmbed(parameter))
}
}
// provide parameter bag
if len(params.Parameters) != 0 {
parameterBugProvider := createParameterBugProvider(k, params.Parameters)
c.addProvider(parameterBugProvider)
c.add(createParameterBugProvider(p.Key(), params.Parameters))
}

// process interfaces
for _, iface := range params.Interfaces {
c.processProviderInterface(prov, iface)
c.processProviderInterface(p, iface)
}
}

Expand All @@ -68,19 +69,14 @@ func (c *Container) Compile() {
return c
},
})

c.Provide(ProvideParams{
Provider: func() *Graph {
return &Graph{graph: c.graph.DOTGraph()}
},
})

for _, key := range c.all() {
// register provider parameters
provider, _ := c.provider(key)
c.registerParameters(provider)
for _, p := range c.all() {
c.registerProviderParameters(p)
}

_, err := c.graph.DFSSort()
if err != nil {
switch err {
Expand All @@ -90,7 +86,6 @@ func (c *Container) Compile() {
panic(err.Error())
}
}

c.compiled = true
}

Expand All @@ -105,21 +100,25 @@ func (c *Container) Extract(params ExtractParams) error {
if !c.compiled {
return fmt.Errorf("container not compiled")
}

if params.Target == nil {
return fmt.Errorf("extract target must be a pointer, got `nil`")
}

if !reflection.IsPtr(params.Target) {
return fmt.Errorf("extract target must be a pointer, got `%s`", reflect.TypeOf(params.Target))
}

key := key{
name: params.Name,
typ: reflect.TypeOf(params.Target).Elem(),
typ := reflect.TypeOf(params.Target)
param := parameter{
name: params.Name,
res: typ.Elem(),
embed: isEmbedParameter(typ),
}

return key.extract(c, params.Target)
value, err := param.ResolveValue(c)
if err != nil {
return err
}
targetValue := reflect.ValueOf(params.Target).Elem()
targetValue.Set(value)
return nil
}

// InvokeParams
Expand All @@ -133,38 +132,26 @@ func (c *Container) Invoke(params InvokeParams) error {
if !c.compiled {
return fmt.Errorf("container not compiled")
}

invoker, err := newInvoker(params.Fn)
if err != nil {
return err
}

return invoker.Invoke(c)
}

// Cleanup
func (c *Container) Cleanup() {
for _, key := range c.all() {
provider, _ := c.provider(key)
if cleanup, ok := provider.(cleanup); ok {
for _, p := range c.all() {
if cleanup, ok := p.(cleanup); ok {
cleanup.cleanup()
}
}
}

// addProvider
func (c *Container) addProvider(p provider) {
c.graph.AddNode(p.resultKey())
c.providers[p.resultKey()] = p
}

// provideEmbedParameters
func (c *Container) provideEmbedParameters(p provider) {
for _, parameter := range p.parameters() {
if parameter.embed {
c.addProvider(newEmbedProvider(parameter))
}
}
// add
func (c *Container) add(p provider) {
c.graph.AddNode(p.Key())
c.providers[p.Key()] = p
}

// exists checks that key registered in container graph.
Expand All @@ -177,64 +164,73 @@ func (c *Container) provider(k key) (provider, bool) {
if !c.exists(k) {
return nil, false
}

return c.providers[k], true
}

// all return all container keys.
func (c *Container) all() []key {
var keys []key

for _, node := range c.graph.Nodes() {
keys = append(keys, node.(key))
func (c *Container) all() []provider {
var providers []provider
for _, k := range c.graph.Nodes() {
p, _ := c.provider(k.(key))
providers = append(providers, p)
}

return keys
return providers
}

// processProviderInterface represents instances as interfaces and groups.
func (c *Container) processProviderInterface(provider provider, as interface{}) {
// create interface from embedParamProvider
iface := newInterfaceProvider(provider, as)
ifaceKey := iface.resultKey()

if c.graph.NodeExists(ifaceKey) {
// create interface from provider
iface := newProviderInterface(provider, as)
if c.graph.NodeExists(iface.Key()) {
// if iface already exists, restrict interface resolving
c.providers[ifaceKey] = iface.Multiple()
c.providers[iface.Key()] = newProviderStub(iface.Key(), "have several implementations")
} else {
// add interface node
c.graph.AddNode(ifaceKey)
c.providers[ifaceKey] = iface
c.graph.AddNode(iface.Key())
c.providers[iface.Key()] = iface
}

// create group
group := newGroupProvider(ifaceKey)
groupKey := group.resultKey()

group := newGroupProvider(iface.Key())
// check exists
if c.graph.NodeExists(groupKey) {
if c.exists(group.Key()) {
// if exists use existing group
group = c.providers[groupKey].(*interfaceGroup)
group = c.providers[group.Key()].(*interfaceGroup)
} else {
// else add new group to graph
c.graph.AddNode(groupKey)
c.providers[groupKey] = group
c.graph.AddNode(group.Key())
c.providers[group.Key()] = group
}

// add embedParamProvider ifaceKey into group
group.Add(provider.resultKey())
group.Add(provider.Key())
}

// registerParameters registers provider parameters in a dependency graph.
func (c *Container) registerParameters(provider provider) {
for _, parameter := range provider.parameters() {
_, exists := c.provider(parameter.resultKey())
// registerProviderParameters registers provider parameters in a dependency graph.
func (c *Container) registerProviderParameters(p provider) {
for _, param := range p.ParameterList() {
paramProvider, exists := c.resolveParameterProvider(param)
if exists {
c.graph.AddEdge(parameter.resultKey(), provider.resultKey())
c.graph.AddEdge(paramProvider.Key(), p.Key())
continue
}
if !exists && !param.optional {
panicf("%s: dependency %s not exists in container", p.Key(), param)
}
}
}

if !exists && !parameter.optional {
panicf("%s: dependency %s not exists in container", provider.resultKey(), parameter.resultKey())
// resolveParameterProvider lookup provider by parameter.
func (c *Container) resolveParameterProvider(param parameter) (provider, bool) {
for _, pt := range providerLookupSequence {
k := key{
name: param.name,
res: param.res,
typ: pt,
}
provider, exists := c.provider(k)
if !exists {
continue
}
return provider, true
}
return nil, false
}
35 changes: 20 additions & 15 deletions di/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func TestContainerExtractErrors(t *testing.T) {
c.MustProvide(ditest.NewBar)
c.MustCompile()
var bar *ditest.Bar
c.MustExtractError(&bar, "*ditest.Bar: *ditest.Foo: internal error")
c.MustExtractError(&bar, "*ditest.Foo: internal error")
})

t.Run("extract interface with multiple implementations cause error", func(t *testing.T) {
Expand All @@ -150,7 +150,7 @@ func TestContainerExtractErrors(t *testing.T) {
c.MustCompile()

var extracted ditest.Fooer
c.MustExtractError(&extracted, "ditest.Fooer: ditest.Fooer have sereral implementations")
c.MustExtractError(&extracted, "ditest.Fooer: have several implementations")
})
}

Expand Down Expand Up @@ -519,7 +519,7 @@ func TestContainerCleanup(t *testing.T) {
}

func TestContainer_GraphVisualizing(t *testing.T) {
t.Run("", func(t *testing.T) {
t.Run("graph", func(t *testing.T) {
c := NewTestContainer(t)

c.MustProvide(ditest.NewLogger)
Expand All @@ -534,18 +534,22 @@ func TestContainer_GraphVisualizing(t *testing.T) {
Target: &graph,
}))

fmt.Println(graph.String())

require.Equal(t, `digraph {
subgraph cluster_s3 {
ID = "cluster_s3";
bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded";
n8[color="#46494C",fontcolor="white",fontname="COURIER",label="*di.Graph",shape="box",style="filled"];
n10[color="#46494C",fontcolor="white",fontname="COURIER",label="*di.Graph",shape="box",style="filled"];
n9[color="#46494C",fontcolor="white",fontname="COURIER",label="di.Extractor",shape="box",style="filled"];
}subgraph cluster_s2 {
ID = "cluster_s2";
bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded";
n5[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AccountController",shape="box",style="filled"];
n7[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AuthController",shape="box",style="filled"];
n6[color="#E54B4B",fontcolor="white",fontname="COURIER",label="[]ditest.Controller",shape="doubleoctagon",style="filled"];
n6[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AccountController",shape="box",style="filled"];
n8[color="#46494C",fontcolor="white",fontname="COURIER",label="*ditest.AuthController",shape="box",style="filled"];
n7[color="#E54B4B",fontcolor="white",fontname="COURIER",label="[]ditest.Controller",shape="doubleoctagon",style="filled"];
n4[color="#E5984B",fontcolor="white",fontname="COURIER",label="ditest.RouterParams",shape="box",style="filled"];
}subgraph cluster_s0 {
ID = "cluster_s0";
Expand All @@ -557,18 +561,19 @@ func TestContainer_GraphVisualizing(t *testing.T) {
bgcolor="#E8E8E8";color="lightgrey";fontcolor="#46494C";fontname="COURIER";label="";style="rounded";
n3[color="#46494C",fontcolor="white",fontname="COURIER",label="*http.ServeMux",shape="box",style="filled"];
n2[color="#46494C",fontcolor="white",fontname="COURIER",label="*http.Server",shape="box",style="filled"];
n4[color="#2589BD",fontcolor="white",fontname="COURIER",label="http.Handler",style="filled"];
n5[color="#2589BD",fontcolor="white",fontname="COURIER",label="http.Handler",style="filled"];
}splines="ortho";
n5->n6[color="#949494"];
n7->n6[color="#949494"];
n3->n4[color="#949494"];
n6->n7[color="#949494"];
n8->n7[color="#949494"];
n3->n5[color="#949494"];
n1->n2[color="#949494"];
n1->n3[color="#949494"];
n1->n5[color="#949494"];
n1->n7[color="#949494"];
n6->n3[color="#949494"];
n4->n2[color="#949494"];
n1->n6[color="#949494"];
n1->n8[color="#949494"];
n7->n4[color="#949494"];
n4->n3[color="#949494"];
n5->n2[color="#949494"];
}`, graph.String())
})
Expand Down
20 changes: 15 additions & 5 deletions di/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,21 @@ package di

import "fmt"

// ErrProviderNotFound
type ErrProviderNotFound struct {
k key
// ErrParameterProvideFailed
type ErrParameterProvideFailed struct {
k key
err error
}

func (e ErrProviderNotFound) Error() string {
return fmt.Sprintf("not exists in container")
func (e ErrParameterProvideFailed) Error() string {
return fmt.Sprintf("%s: %s", e.k, e.err)
}

// ErrParameterProviderNotFound
type ErrParameterProviderNotFound struct {
param parameter
}

func (e ErrParameterProviderNotFound) Error() string {
return fmt.Sprintf("%s: not exists in container", e.param)
}
Loading

0 comments on commit 00bf22f

Please sign in to comment.