-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(csi): support multiple model registries #508
Changes from 2 commits
1f3cb5d
1278ab9
1a6a412
4d9780f
d057a31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,7 @@ jobs: | |
uses: helm/[email protected] | ||
with: | ||
node_image: "kindest/node:v1.27.11" | ||
config: ./csi/test/kind_config.yaml | ||
|
||
- name: Install kustomize | ||
run: ./csi/scripts/install_kustomize.sh | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
package constants | ||
|
||
import kserve "github.com/kserve/kserve/pkg/agent/storage" | ||
|
||
const MR kserve.Protocol = "model-registry://" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package modelregistry | ||
|
||
import ( | ||
"context" | ||
"log" | ||
"strings" | ||
|
||
"github.com/kubeflow/model-registry/csi/pkg/constants" | ||
"github.com/kubeflow/model-registry/pkg/openapi" | ||
) | ||
|
||
func NewAPIClient(cfg *openapi.Configuration, storageUri string) *openapi.APIClient { | ||
client := openapi.NewAPIClient(cfg) | ||
|
||
// Parse the URI to retrieve the needed information to query model registry (modelArtifact) | ||
mrUri := strings.TrimPrefix(storageUri, string(constants.MR)) | ||
|
||
tokens := strings.SplitN(mrUri, "/", 3) | ||
|
||
if len(tokens) < 2 { | ||
return client | ||
} | ||
|
||
newCfg := openapi.NewConfiguration() | ||
newCfg.Host = tokens[0] | ||
newCfg.Scheme = cfg.Scheme | ||
|
||
newClient := openapi.NewAPIClient(newCfg) | ||
|
||
if len(tokens) == 2 { | ||
// Check if the model registry service is available | ||
_, _, err := newClient.ModelRegistryServiceAPI.GetRegisteredModels(context.Background()).Execute() | ||
if err != nil { | ||
log.Printf("Falling back to base url %s for model registry service", cfg.Host) | ||
|
||
return client | ||
} | ||
} | ||
|
||
return newClient | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,102 +2,97 @@ package storage | |
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"log" | ||
"regexp" | ||
"strings" | ||
|
||
kserve "github.com/kserve/kserve/pkg/agent/storage" | ||
"github.com/kubeflow/model-registry/csi/pkg/constants" | ||
"github.com/kubeflow/model-registry/pkg/openapi" | ||
) | ||
|
||
const MR kserve.Protocol = "model-registry://" | ||
var ( | ||
_ kserve.Provider = (*ModelRegistryProvider)(nil) | ||
ErrInvalidMRURI = errors.New("invalid model registry URI, use like model-registry://{dnsName}/{registeredModelName}/{versionName}") | ||
ErrNoVersionAssociated = errors.New("no versions associated to registered model") | ||
ErrNoArtifactAssociated = errors.New("no artifacts associated to model version") | ||
ErrNoModelArtifact = errors.New("no model artifact found for model version") | ||
ErrModelArtifactEmptyURI = errors.New("model artifact has empty URI") | ||
ErrNoStorageURI = errors.New("there is no storageUri supplied") | ||
ErrNoProtocolInSTorageURI = errors.New("there is no protocol specified for the storageUri") | ||
ErrProtocolNotSupported = errors.New("protocol not supported for storageUri") | ||
ErrFetchingModelVersion = errors.New("error fetching model version") | ||
ErrFetchingModelVersions = errors.New("error fetching model versions") | ||
) | ||
|
||
type ModelRegistryProvider struct { | ||
Client *openapi.APIClient | ||
Providers map[kserve.Protocol]kserve.Provider | ||
} | ||
|
||
func NewModelRegistryProvider(cfg *openapi.Configuration) (*ModelRegistryProvider, error) { | ||
client := openapi.NewAPIClient(cfg) | ||
|
||
func NewModelRegistryProvider(client *openapi.APIClient) (*ModelRegistryProvider, error) { | ||
return &ModelRegistryProvider{ | ||
Client: client, | ||
Providers: map[kserve.Protocol]kserve.Provider{}, | ||
}, nil | ||
} | ||
|
||
var _ kserve.Provider = (*ModelRegistryProvider)(nil) | ||
|
||
// storageUri formatted like model-registry://{registeredModelName}/{versionName} | ||
// storageUri formatted like model-registry://{modelRegistryUrl}/{registeredModelName}/{versionName} | ||
func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, storageUri string) error { | ||
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s", modelName, storageUri, modelDir) | ||
|
||
// Parse the URI to retrieve the needed information to query model registry (modelArtifact) | ||
mrUri := strings.TrimPrefix(storageUri, string(MR)) | ||
tokens := strings.SplitN(mrUri, "/", 2) | ||
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s", | ||
modelName, | ||
storageUri, | ||
modelDir, | ||
) | ||
|
||
if len(tokens) == 0 || len(tokens) > 2 { | ||
return fmt.Errorf("invalid model registry URI, use like model-registry://{registeredModelName}/{versionName}") | ||
registeredModelName, versionName, err := p.parseModelVersion(storageUri) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
registeredModelName := tokens[0] | ||
var versionName *string | ||
if len(tokens) == 2 { | ||
versionName = &tokens[1] | ||
} | ||
log.Printf("Fetching model: registeredModelName=%s, versionName=%v", registeredModelName, versionName) | ||
|
||
// Fetch the registered model | ||
model, _, err := p.Client.ModelRegistryServiceAPI.FindRegisteredModel(context.Background()).Name(registeredModelName).Execute() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// Fetch model version by name or latest if not specified | ||
var version *openapi.ModelVersion | ||
if versionName != nil { | ||
version, _, err = p.Client.ModelRegistryServiceAPI.FindModelVersion(context.Background()).Name(*versionName).ParentResourceId(*model.Id).Execute() | ||
if err != nil { | ||
return err | ||
} | ||
} else { | ||
versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id). | ||
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). | ||
SortOrder(openapi.SORTORDER_DESC). | ||
Execute() | ||
if err != nil { | ||
return err | ||
} | ||
log.Printf("Fetching model version: model=%v", model) | ||
|
||
if versions.Size == 0 { | ||
return fmt.Errorf("no versions associated to registered model %s", registeredModelName) | ||
} | ||
version = &versions.Items[0] | ||
// Fetch model version by name or latest if not specified | ||
version, err := p.fetchModelVersion(versionName, registeredModelName, model) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
log.Printf("Fetching model artifacts: version=%v", version) | ||
|
||
artifacts, _, err := p.Client.ModelRegistryServiceAPI.GetModelVersionArtifacts(context.Background(), *version.Id). | ||
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). | ||
// OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). not supported | ||
SortOrder(openapi.SORTORDER_DESC). | ||
Execute() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
if artifacts.Size == 0 { | ||
return fmt.Errorf("no artifacts associated to model version %s", *version.Id) | ||
return fmt.Errorf("%w %s", ErrNoArtifactAssociated, *version.Id) | ||
} | ||
|
||
modelArtifact := artifacts.Items[0].ModelArtifact | ||
if modelArtifact == nil { | ||
return fmt.Errorf("no model artifact found for model version %s", *version.Id) | ||
return fmt.Errorf("%w %s", ErrNoModelArtifact, *version.Id) | ||
} | ||
|
||
// Call appropriate kserve provider based on the indexed model artifact URI | ||
if modelArtifact.Uri == nil { | ||
return fmt.Errorf("model artifact %s has empty URI", *modelArtifact.Id) | ||
return fmt.Errorf("%w %s", ErrModelArtifactEmptyURI, *modelArtifact.Id) | ||
} | ||
|
||
protocol, err := extractProtocol(*modelArtifact.Uri) | ||
protocol, err := p.extractProtocol(*modelArtifact.Uri) | ||
if err != nil { | ||
return err | ||
} | ||
|
@@ -110,19 +105,84 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, | |
return provider.DownloadModel(modelDir, "", *modelArtifact.Uri) | ||
} | ||
|
||
func extractProtocol(storageURI string) (kserve.Protocol, error) { | ||
// Possible URIs: | ||
// (1) model-registry://{modelName} | ||
// (2) model-registry://{modelName}/{modelVersion} | ||
// (3) model-registry://{modelRegistryUrl}/{modelName} | ||
// (4) model-registry://{modelRegistryUrl}/{modelName}/{modelVersion} | ||
func (p *ModelRegistryProvider) parseModelVersion(storageUri string) (string, *string, error) { | ||
var versionName *string | ||
|
||
// Parse the URI to retrieve the needed information to query model registry (modelArtifact) | ||
mrUri := strings.TrimPrefix(storageUri, string(constants.MR)) | ||
|
||
tokens := strings.SplitN(mrUri, "/", 3) | ||
|
||
if len(tokens) == 0 || len(tokens) > 3 { | ||
return "", nil, ErrInvalidMRURI | ||
} | ||
|
||
// Check if the first token is the host and remove it so that we reduce cases (3) and (4) to (1) and (2) | ||
if len(tokens) >= 2 && p.Client.GetConfig().Host == tokens[0] { | ||
tokens = tokens[1:] | ||
} | ||
|
||
Comment on lines
+132
to
+136
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Here (line 129) you have done all the storageURI parsing to determine in which case you are.
This way, we're being explicit on how we interpret what we have received based on the documentation provided. wdyt? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with you @tarilabs @Al-Pragliola , I think that providing a more meaningful and explicit log message would be enough for this use case! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added log.Printf("Parsed storageUri=%s as: modelRegistryUrl=%s, registeredModelName=%s, versionName=%v",
storageUri,
p.Client.GetConfig().Host,
registeredModelName,
versionName,
) |
||
registeredModelName := tokens[0] | ||
|
||
if len(tokens) == 2 { | ||
versionName = &tokens[1] | ||
} | ||
|
||
return registeredModelName, versionName, nil | ||
} | ||
|
||
func (p *ModelRegistryProvider) fetchModelVersion( | ||
versionName *string, | ||
registeredModelName string, | ||
model *openapi.RegisteredModel, | ||
) (*openapi.ModelVersion, error) { | ||
if versionName != nil { | ||
version, _, err := p.Client.ModelRegistryServiceAPI. | ||
FindModelVersion(context.Background()). | ||
Name(*versionName). | ||
ParentResourceId(*model.Id). | ||
Execute() | ||
if err != nil { | ||
return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersion, err) | ||
} | ||
|
||
return version, nil | ||
} | ||
|
||
versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id). | ||
// OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). not supported | ||
SortOrder(openapi.SORTORDER_DESC). | ||
Execute() | ||
if err != nil { | ||
return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersions, err) | ||
} | ||
|
||
if versions.Size == 0 { | ||
return nil, fmt.Errorf("%w %s", ErrNoVersionAssociated, registeredModelName) | ||
} | ||
|
||
return &versions.Items[0], nil | ||
} | ||
|
||
func (*ModelRegistryProvider) extractProtocol(storageURI string) (kserve.Protocol, error) { | ||
if storageURI == "" { | ||
return "", fmt.Errorf("there is no storageUri supplied") | ||
return "", ErrNoStorageURI | ||
} | ||
|
||
if !regexp.MustCompile("\\w+?://").MatchString(storageURI) { | ||
return "", fmt.Errorf("there is no protocol specified for the storageUri") | ||
if !regexp.MustCompile(`\w+?://`).MatchString(storageURI) { | ||
return "", ErrNoProtocolInSTorageURI | ||
} | ||
|
||
for _, prefix := range kserve.SupportedProtocols { | ||
if strings.HasPrefix(storageURI, string(prefix)) { | ||
return prefix, nil | ||
} | ||
} | ||
return "", fmt.Errorf("protocol not supported for storageUri") | ||
|
||
return "", ErrProtocolNotSupported | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not completely sure this is actually supported, because if you have just 2 tokens (after the trim) it will return the default apiClient (i.e., using the default env model registry).
I think here the problem would be, how do you know whether the user is providing option 2 or 3?
I tried running
make SOURCE_URI=model-registry://model-registry-url/model DEST_PATH=./ run
and in factmodel-registry-url
is interpreted as being the registered model, which is actually not what the user was trying to do.I think that we have two options here:
model-registry-url
is provided, users cannot omit theversion
model-registry-url
, e.g.,model-registry://{modelRegistryUrl}:{modelName}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The way I approached this problem can be found here:
https://github.com/Al-Pragliola/model-registry/blob/feat/multi-mr-registries-csi-support/csi/pkg/modelregistry/api_client.go#L24
We try to reach (token 0) as a model registry and on failure we assume that it is a model name and not a valid mr url (falling back to the default url from env var), it's not perfect but I think it might be a fair compromise
by doing this in the function from this comment:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry missed that, I tried and I can confirm that if the "host" is reachable it works as expected:
With a model registry running at
localhost:8080
.My main concern with this assumption is that, if there is a model registry running but for any reason it is not accessible/reachable we are going to make the wrong assumption that it is a registeredModel name , right? And the error would be misleading to the user as it will find in the logs.
What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep that's true, the only hint that user have from the logs is this:
We can improve this message telling the user that it failed to reach
url from uri
and it's going to usefallback url
I wanted this to be as retrocompatible as possible, otherwise there are many alternatives, like a query parameter approach:
model-registry://url?modelName=x&modelVersion=y
model-registry://?modelName=x&modelVersion=y
or the other delimiter you mentioned.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added comment below for line 129 my2c
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw forgot to mention; options 3-4 also consistent with standard URLs and also KServe examples
much appreciated @Al-Pragliola