Skip to content

Commit

Permalink
Support map value groups
Browse files Browse the repository at this point in the history
This revision allows dig to specify value groups of map type.

For example:

```
type Params struct {
	dig.In

	Things      []int          `group:"foogroup"`
	MapOfThings map[string]int `group:"foogroup"`
}
type Result struct {
	dig.Out

	Int1 int `name:"foo1" group:"foogroup"`
	Int2 int `name:"foo2" group:"foogroup"`
	Int3 int `name:"foo3" group:"foogroup"`
}

c.Provide(func() Result {
		return Result{Int1: 1, Int2: 2, Int3: 3}
	})

c.Invoke(func(p Params) {
})
```

p.Things will be a value group slice as per usual, containing the
elements {1,2,3} in an arbitrary order.

p.MapOfThings will be a key-value pairing of
{"foo1":1, "foo2":2, "foo3":3}.
  • Loading branch information
jquirke committed Mar 19, 2023
1 parent e781757 commit 5a24c26
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 31 deletions.
155 changes: 154 additions & 1 deletion dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,27 @@ func TestGroups(t *testing.T) {
})
})

t.Run("provide multiple with the same name and group but different type", func(t *testing.T) {
c := digtest.New(t)
type A struct{}
type B struct{}
type ret1 struct {
dig.Out
*A `name:"foo" group:"foos"`
}
type ret2 struct {
dig.Out
*B `name:"foo" group:"foos"`
}
c.RequireProvide(func() ret1 {
return ret1{A: &A{}}
})

c.RequireProvide(func() ret2 {
return ret2{B: &B{}}
})
})

t.Run("different types may be grouped", func(t *testing.T) {
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))

Expand Down Expand Up @@ -1745,6 +1766,118 @@ func TestGroups(t *testing.T) {
assert.ElementsMatch(t, []string{"a"}, param.Value)
})
})
/* map tests */
t.Run("empty map received without provides", func(t *testing.T) {
c := digtest.New(t)

type in struct {
dig.In

Values map[string]int `group:"foo"`
}

c.RequireInvoke(func(i in) {
require.Empty(t, i.Values)
})
})

t.Run("map value group using dig.Name and dig.Group", func(t *testing.T) {
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))

c.RequireProvide(func() int {
return 1
}, dig.Name("value1"), dig.Group("val"))
c.RequireProvide(func() int {
return 2
}, dig.Name("value2"), dig.Group("val"))
c.RequireProvide(func() int {
return 3
}, dig.Name("value3"), dig.Group("val"))

type in struct {
dig.In

Value1 int `name:"value1"`
Value2 int `name:"value2"`
Value3 int `name:"value3"`
Values []int `group:"val"`
ValueMap map[string]int `group:"val"`
}

c.RequireInvoke(func(i in) {
assert.Equal(t, []int{2, 3, 1}, i.Values)
assert.Equal(t, i.ValueMap["value1"], 1)
assert.Equal(t, i.ValueMap["value2"], 2)
assert.Equal(t, i.ValueMap["value3"], 3)
assert.Equal(t, i.Value1, 1)
assert.Equal(t, i.Value2, 2)
assert.Equal(t, i.Value3, 3)
})
})
t.Run("values are provided, map and name and slice", func(t *testing.T) {
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))
type out struct {
dig.Out

Value1 int `name:"value1" group:"val"`
Value2 int `name:"value2" group:"val"`
Value3 int `name:"value3" group:"val"`
}

c.RequireProvide(func() out {
return out{Value1: 1, Value2: 2, Value3: 3}
})

type in struct {
dig.In

Value1 int `name:"value1"`
Value2 int `name:"value2"`
Value3 int `name:"value3"`
Values []int `group:"val"`
ValueMap map[string]int `group:"val"`
}

c.RequireInvoke(func(i in) {
assert.Equal(t, []int{2, 3, 1}, i.Values)
assert.Equal(t, i.ValueMap["value1"], 1)
assert.Equal(t, i.ValueMap["value2"], 2)
assert.Equal(t, i.ValueMap["value3"], 3)
assert.Equal(t, i.Value1, 1)
assert.Equal(t, i.Value2, 2)
assert.Equal(t, i.Value3, 3)
})
})

t.Run("Every item used in a map must have a named key", func(t *testing.T) {
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))

type out struct {
dig.Out

Value1 int `name:"value1" group:"val"`
Value2 int `name:"value2" group:"val"`
Value3 int `group:"val"`
}

c.RequireProvide(func() out {
return out{Value1: 1, Value2: 2, Value3: 3}
})

type in struct {
dig.In

ValueMap map[string]int `group:"val"`
}
var called = false
err := c.Invoke(func(i in) { called = true })
dig.AssertErrorMatches(t, err,
`could not build arguments for function "go.uber.org/dig_test".TestGroups\S+`,
`dig_test.go:\d+`, // file:line
`every entry in a map value groups must have a name, group "val" is missing a name`)
assert.False(t, called, "shouldn't call invoked function when deps aren't available")
})

}

// --- END OF END TO END TESTS
Expand Down Expand Up @@ -2753,7 +2886,27 @@ func testProvideFailures(t *testing.T, dryRun bool) {
)
})

t.Run("provide multiple instances with the same name but different group", func(t *testing.T) {
t.Run("provide multiple instances with the same name and same group using options", func(t *testing.T) {
c := digtest.New(t, dig.DryRun(dryRun))
type A struct{}

c.RequireProvide(func() *A {
return &A{}
}, dig.Group("foos"), dig.Name("foo"))

err := c.Provide(func() *A {
return &A{}
}, dig.Group("foos"), dig.Name("foo"))
require.Error(t, err, "expected error on the second provide")
dig.AssertErrorMatches(t, err,
`cannot provide function "go.uber.org/dig_test".testProvideFailures\S+`,
`dig_test.go:\d+`, // file:line
`cannot provide \*dig_test.A\[name="foo"\] from \[1\]:`,
`already provided by "go.uber.org/dig_test".testProvideFailures\S+`,
)
})

t.Run("provide multiple instances with the same name and type but different group", func(t *testing.T) {
c := digtest.New(t, dig.DryRun(dryRun))
type A struct{}
type ret1 struct {
Expand Down
4 changes: 2 additions & 2 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type graphNode struct {
}

// graphHolder is the dependency graph of the container.
// It saves constructorNodes and paramGroupedSlice (value groups)
// It saves constructorNodes and paramGroupedCollection (value groups)
// as nodes in the graph.
// It implements the graph interface defined by internal/graph.
// It has 1-1 correspondence with the Scope whose graph it represents.
Expand Down Expand Up @@ -68,7 +68,7 @@ func (gh *graphHolder) EdgesFrom(u int) []int {
for _, param := range w.paramList.Params {
orders = append(orders, getParamOrder(gh, param)...)
}
case *paramGroupedSlice:
case *paramGroupedCollection:
providers := gh.s.getAllGroupProviders(w.Group, w.Type.Elem())
for _, provider := range providers {
orders = append(orders, provider.Order(gh.s))
Expand Down
69 changes: 46 additions & 23 deletions param.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ import (
// paramSingle An explicitly requested type.
// paramObject dig.In struct where each field in the struct can be another
// param.
// paramGroupedSlice
// A slice consuming a value group. This will receive all
// paramGroupedCollection
// A slice or map consuming a value group. This will receive all
// values produced with a `group:".."` tag with the same name
// as a slice.
// as a slice or map. For a map, every value produced with the
// same group name MUST have a name which will form the map key.

type param interface {
fmt.Stringer

Expand All @@ -59,7 +61,7 @@ var (
_ param = paramSingle{}
_ param = paramObject{}
_ param = paramList{}
_ param = paramGroupedSlice{}
_ param = paramGroupedCollection{}
)

// newParam builds a param from the given type. If the provided type is a
Expand Down Expand Up @@ -342,7 +344,7 @@ func getParamOrder(gh *graphHolder, param param) []int {
for _, provider := range providers {
orders = append(orders, provider.Order(gh.s))
}
case paramGroupedSlice:
case paramGroupedCollection:
// value group parameters have nodes of their own.
// We can directly return that here.
orders = append(orders, p.orders[gh.s])
Expand Down Expand Up @@ -401,7 +403,7 @@ func (po paramObject) Build(c containerStore) (reflect.Value, error) {
var softGroupsQueue []paramObjectField
var fields []paramObjectField
for _, f := range po.Fields {
if p, ok := f.Param.(paramGroupedSlice); ok && p.Soft {
if p, ok := f.Param.(paramGroupedCollection); ok && p.Soft {
softGroupsQueue = append(softGroupsQueue, f)
continue
}
Expand Down Expand Up @@ -451,7 +453,7 @@ func newParamObjectField(idx int, f reflect.StructField, c containerStore) (para

case f.Tag.Get(_groupTag) != "":
var err error
p, err = newParamGroupedSlice(f, c)
p, err = newParamGroupedCollection(f, c)
if err != nil {
return pof, err
}
Expand Down Expand Up @@ -488,29 +490,31 @@ func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) {
return v, nil
}

// paramGroupedSlice is a param which produces a slice of values with the same
// paramGroupedCollection is a param which produces a slice or map of values with the same
// group name.
type paramGroupedSlice struct {
type paramGroupedCollection struct {
// Name of the group as specified in the `group:".."` tag.
Group string

// Type of the slice.
// Type of the map or slice.
Type reflect.Type

// Soft is used to denote a soft dependency between this param and its
// constructors, if it's true its constructors are only called if they
// provide another value requested in the graph
Soft bool

isMap bool
orders map[*Scope]int
}

func (pt paramGroupedSlice) String() string {
func (pt paramGroupedCollection) String() string {
// io.Reader[group="foo"] refers to a group of io.Readers called 'foo'
return fmt.Sprintf("%v[group=%q]", pt.Type.Elem(), pt.Group)
// JQTODO, different string for map
}

func (pt paramGroupedSlice) DotParam() []*dot.Param {
func (pt paramGroupedCollection) DotParam() []*dot.Param {
return []*dot.Param{
{
Node: &dot.Node{
Expand All @@ -521,28 +525,31 @@ func (pt paramGroupedSlice) DotParam() []*dot.Param {
}
}

// newParamGroupedSlice builds a paramGroupedSlice from the provided type with
// newParamGroupedCollection builds a paramGroupedCollection from the provided type with
// the given name.
//
// The type MUST be a slice type.
func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGroupedSlice, error) {
// The type MUST be a slice or map[string]T type.
func newParamGroupedCollection(f reflect.StructField, c containerStore) (paramGroupedCollection, error) {
g, err := parseGroupString(f.Tag.Get(_groupTag))
if err != nil {
return paramGroupedSlice{}, err
return paramGroupedCollection{}, err
}
pg := paramGroupedSlice{
isMap := f.Type.Kind() == reflect.Map && f.Type.Key().Kind() == reflect.String
isSlice := f.Type.Kind() == reflect.Slice
pg := paramGroupedCollection{
Group: g.Name,
Type: f.Type,
isMap: isMap,
orders: make(map[*Scope]int),
Soft: g.Soft,
}

name := f.Tag.Get(_nameTag)
optional, _ := isFieldOptional(f)
switch {
case f.Type.Kind() != reflect.Slice:
case !isMap && !isSlice:
return pg, newErrInvalidInput(
fmt.Sprintf("value groups may be consumed as slices only: field %q (%v) is not a slice", f.Name, f.Type), nil)
fmt.Sprintf("value groups may be consumed as slices or string-keyed maps only: field %q (%v) is not a slice or string-keyed map", f.Name, f.Type), nil)
case g.Flatten:
return pg, newErrInvalidInput(
fmt.Sprintf("cannot use flatten in parameter value groups: field %q (%v) specifies flatten", f.Name, f.Type), nil)
Expand All @@ -560,7 +567,7 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped
// any of the parent Scopes. In the case where there are multiple scopes that
// are decorating the same type, the closest scope in effect will be replacing
// any decorated value groups provided in further scopes.
func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, bool) {
func (pt paramGroupedCollection) getDecoratedValues(c containerStore) (reflect.Value, bool) {
for _, c := range c.storesToRoot() {
if items, ok := c.getDecoratedValueGroup(pt.Group, pt.Type); ok {
return items, true
Expand All @@ -575,7 +582,7 @@ func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value,
// The order in which the decorators are invoked is from the top level scope to
// the current scope, to account for decorators that decorate values that were
// already decorated.
func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error {
func (pt paramGroupedCollection) callGroupDecorators(c containerStore) error {
stores := c.storesToRoot()
for i := len(stores) - 1; i >= 0; i-- {
c := stores[i]
Expand All @@ -600,7 +607,7 @@ func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error {
// search the given container and its parent for matching group providers and
// call them to commit values. If an error is encountered, return the number
// of providers called and a non-nil error from the first provided.
func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) {
func (pt paramGroupedCollection) callGroupProviders(c containerStore) (int, error) {
itemCount := 0
for _, c := range c.storesToRoot() {
providers := c.getGroupProviders(pt.Group, pt.Type.Elem())
Expand All @@ -618,7 +625,7 @@ func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) {
return itemCount, nil
}

func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) {
func (pt paramGroupedCollection) Build(c containerStore) (reflect.Value, error) {
// do not call this if we are already inside a decorator since
// it will result in an infinite recursion. (i.e. decorate -> params.BuildList() -> Decorate -> params.BuildList...)
// this is safe since a value can be decorated at most once in a given scope.
Expand All @@ -644,6 +651,22 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) {
}

stores := c.storesToRoot()
if pt.isMap {
result := reflect.MakeMapWithSize(pt.Type, itemCount)
for _, c := range stores {
kgvs := c.getValueGroup(pt.Group, pt.Type.Elem())
for _, kgv := range kgvs {
if kgv.key == "" {
return _noValue, newErrInvalidInput(
fmt.Sprintf("every entry in a map value groups must have a name, group \"%v\" is missing a name", pt.Group),
nil,
)
}
result.SetMapIndex(reflect.ValueOf(kgv.key), kgv.value)
}
}
return result, nil
}
result := reflect.MakeSlice(pt.Type, 0, itemCount)
for _, c := range stores {
kgvs := c.getValueGroup(pt.Group, pt.Type.Elem())
Expand Down
Loading

0 comments on commit 5a24c26

Please sign in to comment.