diff --git a/scheduler/cmd/agent/main.go b/scheduler/cmd/agent/main.go index 8a919b396c..c9151b9e3f 100644 --- a/scheduler/cmd/agent/main.go +++ b/scheduler/cmd/agent/main.go @@ -182,7 +182,13 @@ func main() { }() defer func() { _ = promMetrics.Stop() }() - rpHTTP := agent.NewReverseHTTPProxy(logger, uint(cli.ReverseProxyHttpPort), promMetrics) + rpHTTP := agent.NewReverseHTTPProxy( + logger, + cli.InferenceHost, + uint(cli.InferenceHttpPort), + uint(cli.ReverseProxyHttpPort), + promMetrics, + ) defer func() { _ = rpHTTP.Stop() }() rpGRPC := agent.NewReverseGRPCProxy( diff --git a/scheduler/pkg/agent/rproxy.go b/scheduler/pkg/agent/rproxy.go index bac5dad99c..a756bc1a94 100644 --- a/scheduler/pkg/agent/rproxy.go +++ b/scheduler/pkg/agent/rproxy.go @@ -3,8 +3,10 @@ package agent import ( "context" "fmt" + "net" "net/http" "net/http/httputil" + "net/url" "regexp" "strconv" "sync" @@ -28,13 +30,15 @@ const ( ) type reverseHTTPProxy struct { - stateManager *LocalStateManager - logger log.FieldLogger - server *http.Server - serverReady bool - port uint - mu sync.RWMutex - metrics metrics.MetricsHandler + stateManager *LocalStateManager + logger log.FieldLogger + server *http.Server + serverReady bool + backendHTTPServerHost string + backendHTTPServerPort uint + servicePort uint + mu sync.RWMutex + metrics metrics.MetricsHandler } // need to rewrite the host of the outbound request with the host of the incoming request @@ -78,7 +82,7 @@ func (rp *reverseHTTPProxy) Start() error { return fmt.Errorf("State not set, aborting") } - backend := rp.stateManager.GetBackEndPath() + backend := rp.getBackEndPath() proxy := httputil.NewSingleHostReverseProxy(backend) proxy.Transport = &http.Transport{ MaxIdleConns: maxIdleConnsHTTP, @@ -87,8 +91,8 @@ func (rp *reverseHTTPProxy) Start() error { MaxConnsPerHost: maxConnsPerHostHTTP, IdleConnTimeout: idleConnTimeoutSeconds * time.Second, } - rp.logger.Infof("Start reverse proxy on port %d for %s", rp.port, backend) - rp.server = &http.Server{Addr: ":" + strconv.Itoa(int(rp.port)), Handler: rp.addHandlers(proxy)} + rp.logger.Infof("Start reverse proxy on port %d for %s", rp.servicePort, backend) + rp.server = &http.Server{Addr: ":" + strconv.Itoa(int(rp.servicePort)), Handler: rp.addHandlers(proxy)} // TODO: check for errors? we rely for now on Ready go func() { rp.mu.Lock() @@ -103,6 +107,14 @@ func (rp *reverseHTTPProxy) Start() error { return nil } +func (rp *reverseHTTPProxy) getBackEndPath() *url.URL { + return &url.URL{ + Scheme: "http", + Host: net.JoinHostPort(rp.backendHTTPServerHost, strconv.Itoa(int(rp.backendHTTPServerPort))), + Path: "/", + } +} + func (rp *reverseHTTPProxy) Stop() error { // Shutdown is graceful rp.mu.Lock() @@ -128,14 +140,18 @@ func (rp *reverseHTTPProxy) Name() string { func NewReverseHTTPProxy( logger log.FieldLogger, - port uint, + backendHTTPServerHost string, + backendHTTPServerPort uint, + servicePort uint, metrics metrics.MetricsHandler, ) *reverseHTTPProxy { rp := reverseHTTPProxy{ - logger: logger.WithField("Source", "HTTPProxy"), - port: port, - metrics: metrics, + logger: logger.WithField("Source", "HTTPProxy"), + backendHTTPServerHost: backendHTTPServerHost, + backendHTTPServerPort: backendHTTPServerPort, + servicePort: servicePort, + metrics: metrics, } return &rp diff --git a/scheduler/pkg/agent/rproxy_test.go b/scheduler/pkg/agent/rproxy_test.go index 2aa5df43a2..c3bd04a61e 100644 --- a/scheduler/pkg/agent/rproxy_test.go +++ b/scheduler/pkg/agent/rproxy_test.go @@ -158,7 +158,12 @@ func (f fakeMetricsHandler) UnaryServerInterceptor() func(ctx context.Context, r func setupReverseProxy(logger log.FieldLogger, numModels int, modelPrefix string, rpPort int) *reverseHTTPProxy { v2Client := NewV2Client("localhost", backEndServerPort, logger, false) localCacheManager := setupLocalTestManager(numModels, modelPrefix, v2Client, numModels-2, 1) - rp := NewReverseHTTPProxy(logger, uint(rpPort), fakeMetricsHandler{}) + rp := NewReverseHTTPProxy( + logger, + "localhost", + uint(backEndServerPort), + uint(rpPort), + fakeMetricsHandler{}) rp.SetState(localCacheManager) return rp } @@ -210,7 +215,7 @@ func TestReverseProxySmoke(t *testing.T) { _, _ = rpHTTP.stateManager.modelVersions.addModelVersion( getDummyModelDetails(test.modelToLoad, uint64(1), uint32(1))) - // make a dummy predict call with any model name + // make a dummy predict call with any model name, URL does not matter, only headers inferV2Path := "/v2/models/RANDOM/infer" url := "http://localhost:" + strconv.Itoa(rpPort) + inferV2Path req, err := http.NewRequest(http.MethodPost, url, nil) diff --git a/scheduler/pkg/agent/state_manager.go b/scheduler/pkg/agent/state_manager.go index 1881db5629..8c9e51b659 100644 --- a/scheduler/pkg/agent/state_manager.go +++ b/scheduler/pkg/agent/state_manager.go @@ -2,7 +2,6 @@ package agent import ( "fmt" - "net/url" "sync" "github.com/seldonio/seldon-core/scheduler/apis/mlops/agent" @@ -30,10 +29,6 @@ type LocalStateManager struct { metrics metrics.MetricsHandler } -func (manager *LocalStateManager) GetBackEndPath() *url.URL { - return manager.v2Client.getUrl("/") -} - // this should be called from control plane (if directly) // the load request will always come with versioned model name (only one version) func (manager *LocalStateManager) LoadModelVersion(modelVersionDetails *agent.ModelVersion) error {