Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(csi): support multiple model registries #508

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/csi-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
uses: helm/[email protected]
with:
node_image: "kindest/node:v1.27.11"
config: ./csi/test/kind_config.yaml

- name: Install kustomize
run: ./csi/scripts/install_kustomize.sh
Expand Down
9 changes: 6 additions & 3 deletions csi/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ sequenceDiagram
U->>+MR: Register ML Model
MR-->>-U: Indexed Model
U->>U: Create InferenceService CR
Note right of U: The InferenceService should<br/>point to the model registry<br/>indexed model, e.g.,:<br/> model-registry://<model>/<version>
Note right of U: The InferenceService should<br/>point to the model registry<br/>indexed model, e.g.,:<br/> model-registry://<model-registry-url>/<model>/<version>
KC->>KC: React to InferenceService creation
KC->>+MD: Create Model Deployment
MD->>+MRSI: Initialization (Download Model)
Expand Down Expand Up @@ -66,14 +66,17 @@ Which wil create the executable under `bin/mr-storage-initializer`.

You can run `main.go` (without building the executable) by running:
```bash
./bin/mr-storage-initializer "model-registry://model/version" "./"
./bin/mr-storage-initializer "model-registry://model-registry-url/model/version" "./"
```

or directly running the `main.go` skipping the previous step:
```bash
make SOURCE_URI=model-registry://model/version DEST_PATH=./ run
make SOURCE_URI=model-registry://model-registry-url/model/version DEST_PATH=./ run
```

> [!NOTE]
> `model-registry-url` is optional, if not provided the value of `MODEL_REGISTRY_BASE_URL` env variable will be used.

> [!NOTE]
> A Model Registry service should be up and running at `localhost:8080`.

Expand Down
6 changes: 5 additions & 1 deletion csi/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"log"
"os"

"github.com/kubeflow/model-registry/csi/pkg/modelregistry"
"github.com/kubeflow/model-registry/csi/pkg/storage"
"github.com/kubeflow/model-registry/pkg/openapi"
)
Expand Down Expand Up @@ -38,7 +39,10 @@ func main() {
cfg := openapi.NewConfiguration()
cfg.Host = baseUrl
cfg.Scheme = scheme
provider, err := storage.NewModelRegistryProvider(cfg)

apiClient := modelregistry.NewAPIClient(cfg, sourceUri)

provider, err := storage.NewModelRegistryProvider(apiClient)
if err != nil {
log.Fatalf("Error initiliazing model registry provider: %v", err)
}
Expand Down
5 changes: 5 additions & 0 deletions csi/pkg/constants/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package constants

import kserve "github.com/kserve/kserve/pkg/agent/storage"

const MR kserve.Protocol = "model-registry://"
41 changes: 41 additions & 0 deletions csi/pkg/modelregistry/api_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package modelregistry

import (
"context"
"log"
"strings"

"github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)

func NewAPIClient(cfg *openapi.Configuration, storageUri string) *openapi.APIClient {
client := openapi.NewAPIClient(cfg)

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(constants.MR))

tokens := strings.SplitN(mrUri, "/", 3)

if len(tokens) < 2 {
return client
}

newCfg := openapi.NewConfiguration()
newCfg.Host = tokens[0]
newCfg.Scheme = cfg.Scheme

newClient := openapi.NewAPIClient(newCfg)

if len(tokens) == 2 {
// Check if the model registry service is available
_, _, err := newClient.ModelRegistryServiceAPI.GetRegisteredModels(context.Background()).Execute()
if err != nil {
log.Printf("Falling back to base url %s for model registry service", cfg.Host)

return client
}
}

return newClient
}
156 changes: 108 additions & 48 deletions csi/pkg/storage/modelregistry_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,102 +2,97 @@ package storage

import (
"context"
"errors"
"fmt"
"log"
"regexp"
"strings"

kserve "github.com/kserve/kserve/pkg/agent/storage"
"github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)

const MR kserve.Protocol = "model-registry://"
var (
_ kserve.Provider = (*ModelRegistryProvider)(nil)
ErrInvalidMRURI = errors.New("invalid model registry URI, use like model-registry://{dnsName}/{registeredModelName}/{versionName}")
ErrNoVersionAssociated = errors.New("no versions associated to registered model")
ErrNoArtifactAssociated = errors.New("no artifacts associated to model version")
ErrNoModelArtifact = errors.New("no model artifact found for model version")
ErrModelArtifactEmptyURI = errors.New("model artifact has empty URI")
ErrNoStorageURI = errors.New("there is no storageUri supplied")
ErrNoProtocolInSTorageURI = errors.New("there is no protocol specified for the storageUri")
ErrProtocolNotSupported = errors.New("protocol not supported for storageUri")
ErrFetchingModelVersion = errors.New("error fetching model version")
ErrFetchingModelVersions = errors.New("error fetching model versions")
)

type ModelRegistryProvider struct {
Client *openapi.APIClient
Providers map[kserve.Protocol]kserve.Provider
}

func NewModelRegistryProvider(cfg *openapi.Configuration) (*ModelRegistryProvider, error) {
client := openapi.NewAPIClient(cfg)

func NewModelRegistryProvider(client *openapi.APIClient) (*ModelRegistryProvider, error) {
return &ModelRegistryProvider{
Client: client,
Providers: map[kserve.Protocol]kserve.Provider{},
}, nil
}

var _ kserve.Provider = (*ModelRegistryProvider)(nil)

// storageUri formatted like model-registry://{registeredModelName}/{versionName}
// storageUri formatted like model-registry://{modelRegistryUrl}/{registeredModelName}/{versionName}
func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, storageUri string) error {
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s", modelName, storageUri, modelDir)

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(MR))
tokens := strings.SplitN(mrUri, "/", 2)
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s",
modelName,
storageUri,
modelDir,
)

if len(tokens) == 0 || len(tokens) > 2 {
return fmt.Errorf("invalid model registry URI, use like model-registry://{registeredModelName}/{versionName}")
registeredModelName, versionName, err := p.parseModelVersion(storageUri)
if err != nil {
return err
}

registeredModelName := tokens[0]
var versionName *string
if len(tokens) == 2 {
versionName = &tokens[1]
}
log.Printf("Fetching model: registeredModelName=%s, versionName=%v", registeredModelName, versionName)

// Fetch the registered model
model, _, err := p.Client.ModelRegistryServiceAPI.FindRegisteredModel(context.Background()).Name(registeredModelName).Execute()
if err != nil {
return err
}

// Fetch model version by name or latest if not specified
var version *openapi.ModelVersion
if versionName != nil {
version, _, err = p.Client.ModelRegistryServiceAPI.FindModelVersion(context.Background()).Name(*versionName).ParentResourceId(*model.Id).Execute()
if err != nil {
return err
}
} else {
versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return err
}
log.Printf("Fetching model version: model=%v", model)

if versions.Size == 0 {
return fmt.Errorf("no versions associated to registered model %s", registeredModelName)
}
version = &versions.Items[0]
// Fetch model version by name or latest if not specified
version, err := p.fetchModelVersion(versionName, registeredModelName, model)
if err != nil {
return err
}

log.Printf("Fetching model artifacts: version=%v", version)

artifacts, _, err := p.Client.ModelRegistryServiceAPI.GetModelVersionArtifacts(context.Background(), *version.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
// OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). not supported
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return err
}

if artifacts.Size == 0 {
return fmt.Errorf("no artifacts associated to model version %s", *version.Id)
return fmt.Errorf("%w %s", ErrNoArtifactAssociated, *version.Id)
}

modelArtifact := artifacts.Items[0].ModelArtifact
if modelArtifact == nil {
return fmt.Errorf("no model artifact found for model version %s", *version.Id)
return fmt.Errorf("%w %s", ErrNoModelArtifact, *version.Id)
}

// Call appropriate kserve provider based on the indexed model artifact URI
if modelArtifact.Uri == nil {
return fmt.Errorf("model artifact %s has empty URI", *modelArtifact.Id)
return fmt.Errorf("%w %s", ErrModelArtifactEmptyURI, *modelArtifact.Id)
}

protocol, err := extractProtocol(*modelArtifact.Uri)
protocol, err := p.extractProtocol(*modelArtifact.Uri)
if err != nil {
return err
}
Expand All @@ -110,19 +105,84 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
return provider.DownloadModel(modelDir, "", *modelArtifact.Uri)
}

func extractProtocol(storageURI string) (kserve.Protocol, error) {
// Possible URIs:
// (1) model-registry://{modelName}
// (2) model-registry://{modelName}/{modelVersion}
// (3) model-registry://{modelRegistryUrl}/{modelName}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not completely sure this is actually supported, because if you have just 2 tokens (after the trim) it will return the default apiClient (i.e., using the default env model registry).

I think here the problem would be, how do you know whether the user is providing option 2 or 3?

I tried running make SOURCE_URI=model-registry://model-registry-url/model DEST_PATH=./ run and in fact model-registry-url is interpreted as being the registered model, which is actually not what the user was trying to do.

I think that we have two options here:

  1. If the model-registry-url is provided, users cannot omit the version
  2. We use a different delimiter for the model-registry-url, e.g., model-registry://{modelRegistryUrl}:{modelName}

Copy link
Contributor Author

@Al-Pragliola Al-Pragliola Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way I approached this problem can be found here:

https://github.com/Al-Pragliola/model-registry/blob/feat/multi-mr-registries-csi-support/csi/pkg/modelregistry/api_client.go#L24

We try to reach (token 0) as a model registry and on failure we assume that it is a model name and not a valid mr url (falling back to the default url from env var), it's not perfect but I think it might be a fair compromise

by doing this in the function from this comment:

	// Check if the first token is the host and remove it so that we reduce cases (3) and (4) to (1) and (2)
	if len(tokens) >= 2 && p.Client.GetConfig().Host == tokens[0] {
		tokens = tokens[1:]
	}
  • case (1) stays the same
  • case (2) stays the same
  • case (3) by removing token[0] becomes (1)
  • case (4) by removing token[0] becomes (2)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry missed that, I tried and I can confirm that if the "host" is reachable it works as expected:

$ make SOURCE_URI=model-registry://localhost:8080/mymodel DEST_PATH=./ run                                                                                             [17:59:56]
"/usr/bin/go" fmt ./...
"/usr/bin/go" vet ./...
"/usr/bin/go" run ./main.go model-registry://localhost:8080/mymodel ./
2024/10/24 18:00:02 Initializing, args: src_uri [model-registry://localhost:8080/mymodel] dest_path[ [./]
2024/10/24 18:00:02 Download model indexed in model registry: modelName=, storageUri=model-registry://localhost:8080/mymodel, modelDir=./
2024/10/24 18:00:02 Fetching model: registeredModelName=mymodel, versionName=<nil>
2024/10/24 18:00:02 404 Not Found
exit status 1
make: *** [Makefile:59: run] Error 1

With a model registry running at localhost:8080.

My main concern with this assumption is that, if there is a model registry running but for any reason it is not accessible/reachable we are going to make the wrong assumption that it is a registeredModel name , right? And the error would be misleading to the user as it will find in the logs.

What do you think?

Copy link
Contributor Author

@Al-Pragliola Al-Pragliola Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep that's true, the only hint that user have from the logs is this:

log.Printf("Falling back to base url %s for model registry service", cfg.Host)

We can improve this message telling the user that it failed to reach url from uri and it's going to use fallback url


I wanted this to be as retrocompatible as possible, otherwise there are many alternatives, like a query parameter approach:

model-registry://url?modelName=x&modelVersion=y
model-registry://?modelName=x&modelVersion=y

or the other delimiter you mentioned.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comment below for line 129 my2c

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw forgot to mention; options 3-4 also consistent with standard URLs and also KServe examples

    - regex: "https://(.+?).blob.core.windows.net/(.+)"
    - regex: "https://(.+?).file.core.windows.net/(.+)"

much appreciated @Al-Pragliola

// (4) model-registry://{modelRegistryUrl}/{modelName}/{modelVersion}
func (p *ModelRegistryProvider) parseModelVersion(storageUri string) (string, *string, error) {
var versionName *string

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(constants.MR))

tokens := strings.SplitN(mrUri, "/", 3)

if len(tokens) == 0 || len(tokens) > 3 {
return "", nil, ErrInvalidMRURI
}

// Check if the first token is the host and remove it so that we reduce cases (3) and (4) to (1) and (2)
if len(tokens) >= 2 && p.Client.GetConfig().Host == tokens[0] {
tokens = tokens[1:]
}

Comment on lines +132 to +136
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My main concern with this assumption is that, if there is a model registry running but for any reason it is not accessible/reachable we are going to make the wrong assumption that it is a registeredModel name , right? And the error would be misleading to the user as it will find in the logs.

We can improve this message telling the user that it failed to reach url from uri and it's going to use fallback url

Here (line 129) you have done all the storageURI parsing to determine in which case you are.
To me, add a Log info here that shows ~circa

Parsed storageUri=... as: modelRegistryUrl=... modelName=... modelVersion=...

This way, we're being explicit on how we interpret what we have received based on the documentation provided. wdyt?

Copy link
Member

@lampajr lampajr Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you @tarilabs @Al-Pragliola , I think that providing a more meaningful and explicit log message would be enough for this use case!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

	log.Printf("Parsed storageUri=%s as: modelRegistryUrl=%s, registeredModelName=%s, versionName=%v",
		storageUri,
		p.Client.GetConfig().Host,
		registeredModelName,
		versionName,
	)

registeredModelName := tokens[0]

if len(tokens) == 2 {
versionName = &tokens[1]
}

return registeredModelName, versionName, nil
}

func (p *ModelRegistryProvider) fetchModelVersion(
versionName *string,
registeredModelName string,
model *openapi.RegisteredModel,
) (*openapi.ModelVersion, error) {
if versionName != nil {
version, _, err := p.Client.ModelRegistryServiceAPI.
FindModelVersion(context.Background()).
Name(*versionName).
ParentResourceId(*model.Id).
Execute()
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersion, err)
}

return version, nil
}

versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
// OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). not supported
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersions, err)
}

if versions.Size == 0 {
return nil, fmt.Errorf("%w %s", ErrNoVersionAssociated, registeredModelName)
}

return &versions.Items[0], nil
}

func (*ModelRegistryProvider) extractProtocol(storageURI string) (kserve.Protocol, error) {
if storageURI == "" {
return "", fmt.Errorf("there is no storageUri supplied")
return "", ErrNoStorageURI
}

if !regexp.MustCompile("\\w+?://").MatchString(storageURI) {
return "", fmt.Errorf("there is no protocol specified for the storageUri")
if !regexp.MustCompile(`\w+?://`).MatchString(storageURI) {
return "", ErrNoProtocolInSTorageURI
}

for _, prefix := range kserve.SupportedProtocols {
if strings.HasPrefix(storageURI, string(prefix)) {
return prefix, nil
}
}
return "", fmt.Errorf("protocol not supported for storageUri")

return "", ErrProtocolNotSupported
}
8 changes: 5 additions & 3 deletions csi/scripts/install_modelregistry.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ if ! kubectl get namespace "$namespace" &> /dev/null; then
fi
# Apply model-registry kustomize manifests
echo Using model registry image: $image
cd $MR_ROOT/manifests/kustomize/base && kustomize edit set image kubeflow/model-registry:latest=${image} && cd -
cd $MR_ROOT/manifests/kustomize/base && kustomize edit set image kubeflow/model-registry:latest=${image} && \
kustomize edit set namespace $namespace && cd -
cd $MR_ROOT/manifests/kustomize/overlays/db && kustomize edit set namespace $namespace && cd -
kubectl -n $namespace apply -k "$MR_ROOT/manifests/kustomize/overlays/db"

# Wait for model registry deployment
modelregistry=$(kubectl get pod -n kubeflow --selector="component=model-registry-server" --output jsonpath='{.items[0].metadata.name}')
kubectl wait --for=condition=Ready pod/$modelregistry -n $namespace --timeout=6m
modelregistry=$(kubectl get pod -n $namespace --selector="component=model-registry-server" --output jsonpath='{.items[0].metadata.name}')
kubectl wait --for=condition=Ready pod/$modelregistry -n $namespace --timeout=6m
Loading