Skip to content

Commit

Permalink
feat: Support BYOC (#171)
Browse files Browse the repository at this point in the history
Signed-off-by: xieydd <[email protected]>
  • Loading branch information
xieydd authored Sep 14, 2023
1 parent faf075d commit f604025
Show file tree
Hide file tree
Showing 39 changed files with 1,598 additions and 66 deletions.
2 changes: 2 additions & 0 deletions agent/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
FROM ubuntu:22.04

LABEL maintainer="[email protected]"
RUN apt-get -qq update \
&& apt-get -qq install -y --no-install-recommends ca-certificates curl

COPY agent /usr/bin/agent
ENTRYPOINT ["/usr/bin/agent"]
3 changes: 2 additions & 1 deletion agent/api/types/const.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package types

const (
LabelNamespace = "modelz.tensorchord.ai/namespace"
LabelNamespace = "modelz.tensorchord.ai/namespace"
LabelServerResource = "ai.tensorchord.server-resource"
)
38 changes: 38 additions & 0 deletions agent/api/types/modelz_cloud.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package types

import "time"

const (
ClusterStatusInit = "init"
ClusterStatusActive = "active"
ClusterStatusUnknown = "unknown"
)

const (
DailEndPointSuffix = "/api/v1/clusteragent/connect"
)

type AgentToken struct {
UID string `json:"uid,omitempty"`
Token string `json:"token,omitempty"`
Type string `json:"type,omitempty"`
}

type ManagedCluster struct {
ID string `json:"id,omitempty"`
TokenID string `json:"token_id,omitempty"`
Version string `json:"version,omitempty"`
KubernetesVersion string `json:"kubernetes_version,omitempty"`
Platform string `json:"platform,omitempty"`
Status string `json:"status,omitempty"`
CreatedAt time.Time `json:"created_at,omitempty"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
Region string `json:"region,omitempty"`
ServerResources string `json:"server_resources,omitempty"`
}

type APIKeyMap map[string]string

type NamespaceList struct {
Items []string `json:"items,omitempty"`
}
7 changes: 6 additions & 1 deletion agent/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func defaultHTTPClient(host string) (*http.Client, error) {
_ = sockets.ConfigureTransport(transport, hostURL.Scheme, hostURL.Host)
return &http.Client{
Transport: transport,
CheckRedirect: CheckRedirect,
CheckRedirect: CheckRedirectKeepHeader,
}, nil
}

Expand All @@ -136,6 +136,11 @@ func CheckRedirect(req *http.Request, via []*http.Request) error {
return ErrRedirect
}

func CheckRedirectKeepHeader(req *http.Request, via []*http.Request) error {
req.Header = via[0].Header.Clone()
return nil
}

// DaemonHost returns the host address used by the client
func (cli *Client) DaemonHost() string {
return cli.host
Expand Down
27 changes: 16 additions & 11 deletions agent/client/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@ const defaultAddr = "0.0.0.0:8080"
const apiBasePath = ""

const (
gatewayInferControlPlanePath = "/system/inferences"
gatewayInferScaleControlPath = "/system/scale-inference"
gatewayInferInstanceControlPlanePath = "/system/inference/%s/instances"
gatewayInferInstanceExecControlPlanePath = "/system/inference/%s/instance/%s/exec"
gatewayServerControlPlanePath = "/system/servers"
gatewayServerLabelCreateControlPlanePath = "/system/server/%s/labels"
gatewayServerNodeDeleteControlPlanePath = "/system/server/%s/delete"
gatewayNamespaceControlPlanePath = "/system/namespaces"
gatewayBuildControlPlanePath = "/system/build"
gatewayBuildInstanceControlPlanePath = "/system/build/%s"
gatewayImageCacheControlPlanePath = "/system/image-cache"
gatewayInferControlPlanePath = "/system/inferences"
gatewayInferScaleControlPath = "/system/scale-inference"
gatewayInferInstanceControlPlanePath = "/system/inference/%s/instances"
gatewayInferInstanceExecControlPlanePath = "/system/inference/%s/instance/%s/exec"
gatewayServerControlPlanePath = "/system/servers"
gatewayServerLabelCreateControlPlanePath = "/system/server/%s/labels"
gatewayServerNodeDeleteControlPlanePath = "/system/server/%s/delete"
gatewayNamespaceControlPlanePath = "/system/namespaces"
gatewayBuildControlPlanePath = "/system/build"
gatewayBuildInstanceControlPlanePath = "/system/build/%s"
gatewayImageCacheControlPlanePath = "/system/image-cache"
modelzCloudClusterControlPlanePath = "/api/v1/users/%s/clusters/%s"
modelzCloudClusterWithUserControlPlanePath = "/api/v1/users/%s/clusters"
modelzCloudClusterAPIKeyControlPlanePath = "/api/v1/users/%s/clusters/%s/api_keys"
modelzCloudClusterNamespaceControlPlanePath = "/api/v1/users/%s/clusters/%s/namespaces"
modelzCloudClusterDeploymentControlPlanePath = "/api/v1/users/%s/clusters/%s/deployments/%s/agent"
)

const (
Expand Down
166 changes: 166 additions & 0 deletions agent/client/modelz_cloud.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package client

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"

"github.com/sirupsen/logrus"
"github.com/tensorchord/openmodelz/agent/api/types"
"github.com/tensorchord/openmodelz/agent/pkg/consts"
"k8s.io/apimachinery/pkg/util/wait"
)

func (cli *Client) WaitForAPIServerReady() error {
err := wait.PollImmediateWithContext(context.Background(), time.Second, consts.DefaultAPIServerReadyTimeout, func(ctx context.Context) (bool, error) {
err, healthStatus := cli.waitForAPIServerReady(ctx)
if err != nil || healthStatus != http.StatusOK {
logrus.Warn("APIServer isn't ready yet, Waiting a little while.")
return false, err
}
return true, nil
})
if err != nil {
return fmt.Errorf("failed to wait for apiserver ready, %v", err)
}
return nil
}

func (cli *Client) waitForAPIServerReady(ctx context.Context) (error, int) {
urlValues := url.Values{}
resp, err := cli.get(ctx, "/healthz", urlValues, nil)
if err != nil {
return wrapResponseError(err, resp, "check apiserver is ready", ""), resp.statusCode
}
defer ensureReaderClosed(resp)
return nil, resp.statusCode
}

func (cli *Client) RegisterAgent(ctx context.Context, token string, cluster types.ManagedCluster) (string, string, error) {
urlValues := url.Values{}
agentToken, err := ParseAgentToken(token)
if err != nil {
return "", "", err
}
urlPath := fmt.Sprintf(modelzCloudClusterWithUserControlPlanePath, agentToken.UID)
headers := make(map[string][]string)
headers["Authorization"] = []string{"Bearer " + agentToken.Token}

resp, err := cli.post(ctx, urlPath, urlValues, cluster, headers)
if err != nil {
return "", "", wrapResponseError(err, resp, "register agent to modelz cloud", agentToken.UID)
}
defer ensureReaderClosed(resp)

err = json.NewDecoder(resp.body).Decode(&cluster)
if err != nil {
return "", "", err
}
return cluster.ID, cluster.TokenID, nil
}

func (cli *Client) UpdateAgentStatus(ctx context.Context, apiServerReady <-chan struct{}, token string, cluster types.ManagedCluster) error {
<-apiServerReady
urlValues := url.Values{}
agentToken, err := ParseAgentToken(token)
if err != nil {
return err
}
urlPath := fmt.Sprintf(modelzCloudClusterControlPlanePath, agentToken.UID, cluster.ID)
headers := make(map[string][]string)
headers["Authorization"] = []string{"Bearer " + agentToken.Token}

resp, err := cli.put(ctx, urlPath, urlValues, cluster, headers)
if err != nil {
return wrapResponseError(err, resp, "update agent status to modelz cloud", agentToken.UID)
}
defer ensureReaderClosed(resp)

if resp.statusCode == 200 {
return nil
}
return fmt.Errorf("failed to update agent status to modelz cloud, status code: %d", resp.statusCode)
}

func (cli *Client) GetAPIKeys(ctx context.Context, apiServerReady <-chan struct{}, token string, cluster string) (types.APIKeyMap, error) {
<-apiServerReady
urlValues := url.Values{}
agentToken, err := ParseAgentToken(token)
keys := types.APIKeyMap{}
if err != nil {
return keys, err
}
headers := make(map[string][]string)
headers["Authorization"] = []string{"Bearer " + agentToken.Token}

urlPath := fmt.Sprintf(modelzCloudClusterAPIKeyControlPlanePath, agentToken.UID, cluster)
resp, err := cli.get(ctx, urlPath, urlValues, headers)
if err != nil {
return keys, wrapResponseError(err, resp, "get api keys from modelz cloud", agentToken.UID)
}
defer ensureReaderClosed(resp)

err = json.NewDecoder(resp.body).Decode(&keys)
if err != nil {
return keys, err
}
return keys, nil
}

func (cli *Client) GetNamespaces(ctx context.Context, apiServerReady <-chan struct{}, token string, cluster string) (types.NamespaceList, error) {
<-apiServerReady
urlValues := url.Values{}
agentToken, err := ParseAgentToken(token)
ns := types.NamespaceList{}
if err != nil {
return ns, err
}
urlValues.Add("login_name", agentToken.UID)
headers := make(map[string][]string)
headers["Authorization"] = []string{"Bearer " + agentToken.Token}

resp, err := cli.get(ctx, fmt.Sprintf(modelzCloudClusterNamespaceControlPlanePath, agentToken.UID, cluster), urlValues, headers)
if err != nil {
return ns, wrapResponseError(err, resp, "get namespaces from modelz cloud", agentToken.UID)
}
defer ensureReaderClosed(resp)

err = json.NewDecoder(resp.body).Decode(&ns)
if err != nil {
return ns, err
}

ns.Items = append(ns.Items, GetNamespaceByUserID(agentToken.UID))
return ns, nil
}

func (cli *Client) GetUIDFromDeploymentID(ctx context.Context, token string, cluster string, deployment string) (string, error) {
urlValues := url.Values{}
agentToken, err := ParseAgentToken(token)
if err != nil {
return "", err
}
headers := make(map[string][]string)
headers["Authorization"] = []string{"Bearer " + agentToken.Token}
urlPath := fmt.Sprintf(modelzCloudClusterDeploymentControlPlanePath, agentToken.UID, cluster, deployment)

resp, err := cli.get(ctx, urlPath, urlValues, headers)
if err != nil {
return "", err
}
defer ensureReaderClosed(resp)

var uid string
err = json.NewDecoder(resp.body).Decode(&uid)
if err != nil {
return "", err
}

if resp.statusCode == 200 {
return uid, nil
}
return "", fmt.Errorf("failed to get uid from deployment id, status code: %d", resp.statusCode)
}
46 changes: 46 additions & 0 deletions agent/client/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package client

import (
"fmt"
"strings"

"github.com/cockroachdb/errors"
"github.com/tensorchord/openmodelz/agent/api/types"
)

const (
DefaultPrefix = "modelz-"
)

func ParseAgentToken(token string) (types.AgentToken, error) {
agentToken := types.AgentToken{}
if token == "" {
return agentToken, errors.New("agent token is empty")
}

strings := strings.Split(token, ":")
if len(strings) != 3 {
return agentToken, errors.New("invalid agent token")
}
agentToken.Type = strings[0]
agentToken.UID = strings[1]
agentToken.Token = strings[2]

return agentToken, nil
}

func GetNamespaceByUserID(uid string) string {
return fmt.Sprintf("%s%s", DefaultPrefix, uid)
}

func GetUserIDFromNamespace(ns string) (string, error) {
if len(ns) < 8 {
return "", fmt.Errorf("namespace too short")
}

if ns[:len(DefaultPrefix)] != DefaultPrefix {
return "", fmt.Errorf("namespace does not start with ")
}

return ns[7:], nil
}
12 changes: 12 additions & 0 deletions agent/pkg/app/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func configFromCLI(c *cli.Context) config.Config {
cfg.Ingress.Domain = c.String(flagIngressDomain)
cfg.Ingress.AnyIPToDomain = c.Bool(flagIngressAnyIPToDomain)
cfg.Ingress.Namespace = c.String(flagIngressNamespace)
cfg.Ingress.TLSEnabled = c.Bool(flagIngressTLSEnabled)

// inference
cfg.Inference.LogTimeout = c.Duration(flagInferenceLogTimeout)
Expand Down Expand Up @@ -56,5 +57,16 @@ func configFromCLI(c *cli.Context) config.Config {
// postgres database
cfg.DB.EventEnabled = c.Bool(flagEventEnabled)
cfg.DB.URL = c.String(flagDBURL)

// modelz cloud
cfg.ModelZCloud.Enabled = c.Bool(flagModelZCloudEnabled)
cfg.ModelZCloud.URL = c.String(flagModelZCloudURL)
cfg.ModelZCloud.AgentToken = c.String(flagModelZCloudAgentToken)
cfg.ModelZCloud.HeartbeatInterval = c.Duration(flagModelZCloudAgentHeartbeatInterval)
cfg.ModelZCloud.Region = c.String(flagModelZCloudRegion)
cfg.ModelZCloud.UnifiedAPIKey = c.String(flagModelZCloudUnifiedAPIKey)
cfg.ModelZCloud.UpstreamTimeout = c.Duration(flagModelZCloudUpstreamTimeout)
cfg.ModelZCloud.MaxIdleConnections = c.Int(flagModelZCloudMaxIdleConnections)
cfg.ModelZCloud.MaxIdleConnectionsPerHost = c.Int(flagModelZCloudMaxIdleConnectionsPerHost)
return cfg
}
Loading

0 comments on commit f604025

Please sign in to comment.