Skip to content

Commit

Permalink
kubeflow: make MLMD type names (and prefix) pluggable
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Mortari <[email protected]>
  • Loading branch information
tarilabs committed Feb 26, 2024
1 parent edefcc4 commit c544791
Show file tree
Hide file tree
Showing 13 changed files with 203 additions and 186 deletions.
4 changes: 2 additions & 2 deletions clients/python/src/model_registry/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def get_proto_type_name(cls) -> str:
"""Name of the proto type.
Returns:
Name of the class prefixed with `kfmr.`
Name of the class prefixed with `kf.`
"""
return f"kfmr.{cls.__name__}"
return f"kf.{cls.__name__}"

@property
@abstractmethod
Expand Down
5 changes: 3 additions & 2 deletions cmd/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ func runProxyServer(cmd *cobra.Command, args []string) error {
defer conn.Close()
glog.Infof("connected to MLMD server")

_, err = mlmdtypes.CreateMLMDTypes(conn)
mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()
_, err = mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig)
if err != nil {
return fmt.Errorf("error creating MLMD types: %v", err)
}
service, err := core.NewModelRegistryService(conn)
service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig)
if err != nil {
return fmt.Errorf("error creating core service: %v", err)
}
Expand Down
12 changes: 0 additions & 12 deletions internal/constants/constants.go

This file was deleted.

20 changes: 10 additions & 10 deletions internal/converter/mlmd_converter_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"strings"
"testing"

"github.com/kubeflow/model-registry/internal/constants"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/ml_metadata/proto"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -208,7 +208,7 @@ func TestMapRegisteredModelType(t *testing.T) {

typeName := MapRegisteredModelType(&openapi.RegisteredModel{})
assertion.NotNil(typeName)
assertion.Equal(constants.RegisteredModelTypeName, *typeName)
assertion.Equal(defaults.RegisteredModelTypeName, *typeName)
}

func TestMapModelVersionProperties(t *testing.T) {
Expand Down Expand Up @@ -236,7 +236,7 @@ func TestMapModelVersionType(t *testing.T) {

typeName := MapModelVersionType(&openapi.ModelVersion{})
assertion.NotNil(typeName)
assertion.Equal(constants.ModelVersionTypeName, *typeName)
assertion.Equal(defaults.ModelVersionTypeName, *typeName)
}

func TestMapModelVersionName(t *testing.T) {
Expand Down Expand Up @@ -287,7 +287,7 @@ func TestMapModelArtifactType(t *testing.T) {

typeName := MapModelArtifactType(&openapi.ModelArtifact{})
assertion.NotNil(typeName)
assertion.Equal(constants.ModelArtifactTypeName, *typeName)
assertion.Equal(defaults.ModelArtifactTypeName, *typeName)
}

func TestMapModelArtifactName(t *testing.T) {
Expand Down Expand Up @@ -346,7 +346,7 @@ func TestMapDocArtifactType(t *testing.T) {

typeName := MapModelArtifactType(&openapi.ModelArtifact{})
assertion.NotNil(typeName)
assertion.Equal(constants.ModelArtifactTypeName, *typeName)
assertion.Equal(defaults.ModelArtifactTypeName, *typeName)
}

func TestMapDocArtifactName(t *testing.T) {
Expand Down Expand Up @@ -577,13 +577,13 @@ func TestMapArtifactType(t *testing.T) {
assertion := setup(t)

artifactType, err := MapArtifactType(&proto.Artifact{
Type: of(constants.ModelArtifactTypeName),
Type: of(defaults.ModelArtifactTypeName),
})
assertion.Nil(err)
assertion.Equal("model-artifact", artifactType)

artifactType, err = MapArtifactType(&proto.Artifact{
Type: of(constants.DocArtifactTypeName),
Type: of(defaults.DocArtifactTypeName),
})
assertion.Nil(err)
assertion.Equal("doc-artifact", artifactType)
Expand Down Expand Up @@ -659,15 +659,15 @@ func TestMapServingEnvironmentType(t *testing.T) {

typeName := MapServingEnvironmentType(&openapi.ServingEnvironment{})
assertion.NotNil(typeName)
assertion.Equal(constants.ServingEnvironmentTypeName, *typeName)
assertion.Equal(defaults.ServingEnvironmentTypeName, *typeName)
}

func TestMapInferenceServiceType(t *testing.T) {
assertion := setup(t)

typeName := MapInferenceServiceType(&openapi.InferenceService{})
assertion.NotNil(typeName)
assertion.Equal(constants.InferenceServiceTypeName, *typeName)
assertion.Equal(defaults.InferenceServiceTypeName, *typeName)
}

func TestMapInferenceServiceProperties(t *testing.T) {
Expand Down Expand Up @@ -710,7 +710,7 @@ func TestMapServeModelType(t *testing.T) {

typeName := MapServeModelType(&openapi.ServeModel{})
assertion.NotNil(typeName)
assertion.Equal(constants.ServeModelTypeName, *typeName)
assertion.Equal(defaults.ServeModelTypeName, *typeName)
}

func TestMapServeModelProperties(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions internal/converter/mlmd_openapi_converter_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"fmt"
"strings"

"github.com/kubeflow/model-registry/internal/constants"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/ml_metadata/proto"
"github.com/kubeflow/model-registry/pkg/openapi"
)
Expand Down Expand Up @@ -87,9 +87,9 @@ func MapArtifactType(source *proto.Artifact) (string, error) {
return "", fmt.Errorf("artifact type is nil")
}
switch *source.Type {
case constants.ModelArtifactTypeName:
case defaults.ModelArtifactTypeName:
return "model-artifact", nil
case constants.DocArtifactTypeName:
case defaults.DocArtifactTypeName:
return "doc-artifact", nil
default:
return "", fmt.Errorf("invalid artifact type found: %v", source.Type)
Expand Down
16 changes: 8 additions & 8 deletions internal/converter/openapi_mlmd_converter_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strconv"

"github.com/google/uuid"
"github.com/kubeflow/model-registry/internal/constants"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/ml_metadata/proto"
"github.com/kubeflow/model-registry/pkg/openapi"
"google.golang.org/protobuf/types/known/structpb"
Expand Down Expand Up @@ -144,7 +144,7 @@ func MapRegisteredModelProperties(source *openapi.RegisteredModel) (map[string]*

// MapRegisteredModelType return RegisteredModel corresponding MLMD context type
func MapRegisteredModelType(_ *openapi.RegisteredModel) *string {
return of(constants.RegisteredModelTypeName)
return of(defaults.RegisteredModelTypeName)
}

// MODEL VERSION
Expand Down Expand Up @@ -194,7 +194,7 @@ func MapModelVersionProperties(source *OpenAPIModelWrapper[openapi.ModelVersion]

// MapModelVersionType return ModelVersion corresponding MLMD context type
func MapModelVersionType(_ *openapi.ModelVersion) *string {
return of(constants.ModelVersionTypeName)
return of(defaults.ModelVersionTypeName)
}

// MapModelVersionName maps the user-provided name into MLMD one, i.e., prefixing it with
Expand Down Expand Up @@ -222,7 +222,7 @@ func MapOpenAPIArtifactState(source *openapi.ArtifactState) (*proto.Artifact_Sta

// get DocArtifact MLMD type name
func MapDocArtifactType(_ *openapi.DocArtifact) *string {
return of(constants.DocArtifactTypeName)
return of(defaults.DocArtifactTypeName)
}

func MapDocArtifactProperties(source *openapi.DocArtifact) (map[string]*proto.Value, error) {
Expand Down Expand Up @@ -307,7 +307,7 @@ func MapModelArtifactProperties(source *openapi.ModelArtifact) (map[string]*prot

// MapModelArtifactType return ModelArtifact corresponding MLMD context type
func MapModelArtifactType(_ *openapi.ModelArtifact) *string {
return of(constants.ModelArtifactTypeName)
return of(defaults.ModelArtifactTypeName)
}

// MapModelArtifactName maps the user-provided name into MLMD one, i.e., prefixing it with
Expand All @@ -328,7 +328,7 @@ func MapModelArtifactName(source *OpenAPIModelWrapper[openapi.ModelArtifact]) *s

// MapServingEnvironmentType return ServingEnvironment corresponding MLMD context type
func MapServingEnvironmentType(_ *openapi.ServingEnvironment) *string {
return of(constants.ServingEnvironmentTypeName)
return of(defaults.ServingEnvironmentTypeName)
}

// MapServingEnvironmentProperties maps ServingEnvironment fields to specific MLMD properties
Expand All @@ -350,7 +350,7 @@ func MapServingEnvironmentProperties(source *openapi.ServingEnvironment) (map[st

// MapInferenceServiceType return InferenceService corresponding MLMD context type
func MapInferenceServiceType(_ *openapi.InferenceService) *string {
return of(constants.InferenceServiceTypeName)
return of(defaults.InferenceServiceTypeName)
}

// MapInferenceServiceProperties maps InferenceService fields to specific MLMD properties
Expand Down Expand Up @@ -436,7 +436,7 @@ func MapInferenceServiceName(source *OpenAPIModelWrapper[openapi.InferenceServic

// MapServeModelType return ServeModel corresponding MLMD context type
func MapServeModelType(_ *openapi.ServeModel) *string {
return of(constants.ServeModelTypeName)
return of(defaults.ServeModelTypeName)
}

// MapServeModelProperties maps ServeModel fields to specific MLMD properties
Expand Down
12 changes: 12 additions & 0 deletions internal/defaults/defaults.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package defaults

// MLMD type names
const (
RegisteredModelTypeName = "kf.RegisteredModel"
ModelVersionTypeName = "kf.ModelVersion"
ModelArtifactTypeName = "kf.ModelArtifact"
DocArtifactTypeName = "kf.DocArtifact"
ServingEnvironmentTypeName = "kf.ServingEnvironment"
InferenceServiceTypeName = "kf.InferenceService"
ServeModelTypeName = "kf.ServeModel"
)
34 changes: 17 additions & 17 deletions internal/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package mapper
import (
"fmt"

"github.com/kubeflow/model-registry/internal/constants"
"github.com/kubeflow/model-registry/internal/converter"
"github.com/kubeflow/model-registry/internal/converter/generated"
"github.com/kubeflow/model-registry/internal/defaults"
"github.com/kubeflow/model-registry/internal/ml_metadata/proto"
"github.com/kubeflow/model-registry/pkg/openapi"
)
Expand All @@ -28,14 +28,14 @@ func NewMapper(mlmdTypes map[string]int64) *Mapper {

func (m *Mapper) MapFromRegisteredModel(registeredModel *openapi.RegisteredModel) (*proto.Context, error) {
return m.OpenAPIConverter.ConvertRegisteredModel(&converter.OpenAPIModelWrapper[openapi.RegisteredModel]{
TypeId: m.MLMDTypes[constants.RegisteredModelTypeName],
TypeId: m.MLMDTypes[defaults.RegisteredModelTypeName],
Model: registeredModel,
})
}

func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registeredModelId string, registeredModelName *string) (*proto.Context, error) {
return m.OpenAPIConverter.ConvertModelVersion(&converter.OpenAPIModelWrapper[openapi.ModelVersion]{
TypeId: m.MLMDTypes[constants.ModelVersionTypeName],
TypeId: m.MLMDTypes[defaults.ModelVersionTypeName],
Model: modelVersion,
ParentResourceId: &registeredModelId,
ModelName: registeredModelName,
Expand All @@ -44,15 +44,15 @@ func (m *Mapper) MapFromModelVersion(modelVersion *openapi.ModelVersion, registe

func (m *Mapper) MapFromModelArtifact(modelArtifact *openapi.ModelArtifact, modelVersionId *string) (*proto.Artifact, error) {
return m.OpenAPIConverter.ConvertModelArtifact(&converter.OpenAPIModelWrapper[openapi.ModelArtifact]{
TypeId: m.MLMDTypes[constants.ModelArtifactTypeName],
TypeId: m.MLMDTypes[defaults.ModelArtifactTypeName],
Model: modelArtifact,
ParentResourceId: modelVersionId,
})
}

func (m *Mapper) MapFromDocArtifact(docArtifact *openapi.DocArtifact, modelVersionId *string) (*proto.Artifact, error) {
return m.OpenAPIConverter.ConvertDocArtifact(&converter.OpenAPIModelWrapper[openapi.DocArtifact]{
TypeId: m.MLMDTypes[constants.DocArtifactTypeName],
TypeId: m.MLMDTypes[defaults.DocArtifactTypeName],
Model: docArtifact,
ParentResourceId: modelVersionId,
})
Expand Down Expand Up @@ -89,22 +89,22 @@ func (m *Mapper) MapFromModelArtifacts(modelArtifacts []openapi.ModelArtifact, m

func (m *Mapper) MapFromServingEnvironment(servingEnvironment *openapi.ServingEnvironment) (*proto.Context, error) {
return m.OpenAPIConverter.ConvertServingEnvironment(&converter.OpenAPIModelWrapper[openapi.ServingEnvironment]{
TypeId: m.MLMDTypes[constants.ServingEnvironmentTypeName],
TypeId: m.MLMDTypes[defaults.ServingEnvironmentTypeName],
Model: servingEnvironment,
})
}

func (m *Mapper) MapFromInferenceService(inferenceService *openapi.InferenceService, servingEnvironmentId string) (*proto.Context, error) {
return m.OpenAPIConverter.ConvertInferenceService(&converter.OpenAPIModelWrapper[openapi.InferenceService]{
TypeId: m.MLMDTypes[constants.InferenceServiceTypeName],
TypeId: m.MLMDTypes[defaults.InferenceServiceTypeName],
Model: inferenceService,
ParentResourceId: &servingEnvironmentId,
})
}

func (m *Mapper) MapFromServeModel(serveModel *openapi.ServeModel, inferenceServiceId string) (*proto.Execution, error) {
return m.OpenAPIConverter.ConvertServeModel(&converter.OpenAPIModelWrapper[openapi.ServeModel]{
TypeId: m.MLMDTypes[constants.ServeModelTypeName],
TypeId: m.MLMDTypes[defaults.ServeModelTypeName],
Model: serveModel,
ParentResourceId: &inferenceServiceId,
})
Expand All @@ -113,19 +113,19 @@ func (m *Mapper) MapFromServeModel(serveModel *openapi.ServeModel, inferenceServ
// Utilities for MLMD --> OpenAPI mapping, make use of generated Converters

func (m *Mapper) MapToRegisteredModel(ctx *proto.Context) (*openapi.RegisteredModel, error) {
return mapTo(ctx, m.MLMDTypes, constants.RegisteredModelTypeName, m.MLMDConverter.ConvertRegisteredModel)
return mapTo(ctx, m.MLMDTypes, defaults.RegisteredModelTypeName, m.MLMDConverter.ConvertRegisteredModel)
}

func (m *Mapper) MapToModelVersion(ctx *proto.Context) (*openapi.ModelVersion, error) {
return mapTo(ctx, m.MLMDTypes, constants.ModelVersionTypeName, m.MLMDConverter.ConvertModelVersion)
return mapTo(ctx, m.MLMDTypes, defaults.ModelVersionTypeName, m.MLMDConverter.ConvertModelVersion)
}

func (m *Mapper) MapToModelArtifact(art *proto.Artifact) (*openapi.ModelArtifact, error) {
return mapTo(art, m.MLMDTypes, constants.ModelArtifactTypeName, m.MLMDConverter.ConvertModelArtifact)
return mapTo(art, m.MLMDTypes, defaults.ModelArtifactTypeName, m.MLMDConverter.ConvertModelArtifact)
}

func (m *Mapper) MapToDocArtifact(art *proto.Artifact) (*openapi.DocArtifact, error) {
return mapTo(art, m.MLMDTypes, constants.DocArtifactTypeName, m.MLMDConverter.ConvertDocArtifact)
return mapTo(art, m.MLMDTypes, defaults.DocArtifactTypeName, m.MLMDConverter.ConvertDocArtifact)
}

func (m *Mapper) MapToArtifact(art *proto.Artifact) (*openapi.Artifact, error) {
Expand All @@ -136,12 +136,12 @@ func (m *Mapper) MapToArtifact(art *proto.Artifact) (*openapi.Artifact, error) {
return nil, fmt.Errorf("invalid artifact type, can't map from nil")
}
switch art.GetType() {
case constants.ModelArtifactTypeName:
case defaults.ModelArtifactTypeName:
ma, err := m.MapToModelArtifact(art)
return &openapi.Artifact{
ModelArtifact: ma,
}, err
case constants.DocArtifactTypeName:
case defaults.DocArtifactTypeName:
da, err := m.MapToDocArtifact(art)
return &openapi.Artifact{
DocArtifact: da,
Expand All @@ -152,15 +152,15 @@ func (m *Mapper) MapToArtifact(art *proto.Artifact) (*openapi.Artifact, error) {
}

func (m *Mapper) MapToServingEnvironment(ctx *proto.Context) (*openapi.ServingEnvironment, error) {
return mapTo(ctx, m.MLMDTypes, constants.ServingEnvironmentTypeName, m.MLMDConverter.ConvertServingEnvironment)
return mapTo(ctx, m.MLMDTypes, defaults.ServingEnvironmentTypeName, m.MLMDConverter.ConvertServingEnvironment)
}

func (m *Mapper) MapToInferenceService(ctx *proto.Context) (*openapi.InferenceService, error) {
return mapTo(ctx, m.MLMDTypes, constants.InferenceServiceTypeName, m.MLMDConverter.ConvertInferenceService)
return mapTo(ctx, m.MLMDTypes, defaults.InferenceServiceTypeName, m.MLMDConverter.ConvertInferenceService)
}

func (m *Mapper) MapToServeModel(ex *proto.Execution) (*openapi.ServeModel, error) {
return mapTo(ex, m.MLMDTypes, constants.ServeModelTypeName, m.MLMDConverter.ConvertServeModel)
return mapTo(ex, m.MLMDTypes, defaults.ServeModelTypeName, m.MLMDConverter.ConvertServeModel)
}

type getTypeIder interface {
Expand Down
Loading

0 comments on commit c544791

Please sign in to comment.