diff --git a/modules/rpc/thrift_test.go b/modules/rpc/thrift_test.go index 1d6002397..b149c41e1 100644 --- a/modules/rpc/thrift_test.go +++ b/modules/rpc/thrift_test.go @@ -83,40 +83,6 @@ modules: assert.Equal(t, 2, len(dispatcher.Inbounds())) } -func TestThriftModuleSeparateGraph_OK(t *testing.T) { - t.Parallel() - di := dig.New() - snowflake := ThriftModule(okCreate, modules.WithRoles("rescue"), withGraph(di)) - - cfg := []byte(` -modules: - rpc: - inbounds: - - tchannel: - port: 0 -`) - - mci := service.ModuleCreateInfo{ - Name: "RPC", - Host: testHost{ - Host: service.NopHost(), - config: config.NewYAMLProviderFromBytes(cfg), - }, - Items: make(map[string]interface{}), - } - - special, err := snowflake(mci) - require.NoError(t, err) - assert.NotEmpty(t, special) - - testInitRunModule(t, special[0], mci) - - // Dispatcher must be resolved in the new graph - var dispatcher *yarpc.Dispatcher - assert.NoError(t, di.Resolve(&dispatcher)) - assert.Equal(t, 1, len(dispatcher.Inbounds())) -} - func TestThriftModule_BadOptions(t *testing.T) { modCreate := ThriftModule(okCreate, errorOption) _, err := modCreate(mch()) diff --git a/modules/rpc/yarpc.go b/modules/rpc/yarpc.go index 323c7a5d1..0a83a3d4d 100644 --- a/modules/rpc/yarpc.go +++ b/modules/rpc/yarpc.go @@ -33,6 +33,7 @@ import ( "go.uber.org/fx/ulog" errs "github.com/pkg/errors" + "go.uber.org/fx/dig" "go.uber.org/yarpc" "go.uber.org/yarpc/api/middleware" "go.uber.org/yarpc/api/transport" @@ -268,21 +269,19 @@ func newYARPCModule( module.config.inboundMiddleware = inboundMiddlewareFromCreateInfo(mi) module.config.onewayInboundMiddleware = onewayInboundMiddlewareFromCreateInfo(mi) - di := graphFromCreateInfo(mi) - // Try to resolve a controller first // TODO(alsam) use dig options when available, because we can overwrite the controller in case of multiple // modules registering a controller. - if err := di.Resolve(&module.controller); err != nil { + if err := dig.Resolve(&module.controller); err != nil { // Try to register it then module.controller = &dispatcherController{} - if errCr := di.Register(module.controller); errCr != nil { + if errCr := dig.Register(module.controller); errCr != nil { return nil, errs.Wrap(errCr, "can't register a dispatcher controller") } // Register dispatcher - if err := di.Register(&module.controller.dispatcher); err != nil { + if err := dig.Register(&module.controller.dispatcher); err != nil { return nil, errs.Wrap(err, "unable to register the dispatcher") } } diff --git a/modules/rpc/yarpc_options.go b/modules/rpc/yarpc_options.go index eb3d98602..c620177a1 100644 --- a/modules/rpc/yarpc_options.go +++ b/modules/rpc/yarpc_options.go @@ -21,7 +21,6 @@ package rpc import ( - "go.uber.org/fx/dig" "go.uber.org/fx/modules" "go.uber.org/fx/service" "go.uber.org/yarpc/api/middleware" @@ -73,20 +72,3 @@ func onewayInboundMiddlewareFromCreateInfo(mci service.ModuleCreateInfo) []middl // Intentionally panic if programmer adds non-middleware slice to the data return items.([]middleware.OnewayInbound) } - -func withGraph(graph dig.Graph) modules.Option { - return func(mci *service.ModuleCreateInfo) error { - mci.Items[_graphInterceptorKey] = graph - return nil - } -} - -func graphFromCreateInfo(mci service.ModuleCreateInfo) dig.Graph { - g, ok := mci.Items[_graphInterceptorKey] - if !ok { - return dig.DefaultGraph() - } - - // Intentionally panic if someone adds non-graph to Items. - return g.(dig.Graph) -} diff --git a/modules/rpc/yarpc_options_test.go b/modules/rpc/yarpc_options_test.go index 54c07fc5c..b301e967f 100644 --- a/modules/rpc/yarpc_options_test.go +++ b/modules/rpc/yarpc_options_test.go @@ -27,7 +27,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/fx/dig" "go.uber.org/yarpc/api/middleware" ) @@ -73,15 +72,3 @@ func TestWithOnewayInboundMiddleware_PanicsBadData(t *testing.T) { opt(mc) }) } - -func TestWithGraph_OK(t *testing.T) { - graph := dig.New() - opt := withGraph(graph) - mc := &service.ModuleCreateInfo{ - Items: make(map[string]interface{}), - } - - assert.Equal(t, dig.DefaultGraph(), graphFromCreateInfo(*mc)) - require.NoError(t, opt(mc)) - assert.Equal(t, graph, graphFromCreateInfo(*mc)) -}