Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bff): create endpoint to list all model versions #707

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ curl -i -H "kubeflow-userid: [email protected]" -X PATCH "http://localhost:4000/a
}}'
```
```
# GET /api/v1/model_registry/{model_registry_id}/model_versions
curl -i -H "kubeflow-userid: [email protected]" "http://localhost:4000/api/v1/model_registry/model-registry/model_versions?namespace=kubeflow"
```
```
# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}
curl -i -H "kubeflow-userid: [email protected]" "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1?namespace=kubeflow"
```
Expand Down
3 changes: 2 additions & 1 deletion clients/ui/bff/internal/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ func (app *App) Routes() http.Handler {
apiRouter.PATCH(RegisteredModelPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.UpdateRegisteredModelHandler))))
apiRouter.GET(RegisteredModelVersionsPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelVersionsForRegisteredModelHandler))))
apiRouter.POST(RegisteredModelVersionsPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelVersionForRegisteredModelHandler))))
apiRouter.GET(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient((app.GetModelVersionHandler)))))
apiRouter.POST(ModelVersionListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelVersionHandler))))
apiRouter.GET(ModelVersionListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelVersionHandler))))
apiRouter.GET(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetModelVersionHandler))))
apiRouter.PATCH(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.UpdateModelVersionHandler))))
apiRouter.GET(ModelVersionArtifactListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelArtifactsByModelVersionHandler))))
apiRouter.POST(ModelVersionArtifactListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelArtifactByModelVersionHandler))))
Expand Down
24 changes: 24 additions & 0 deletions clients/ui/bff/internal/api/model_versions_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,30 @@ type ModelVersionUpdateEnvelope Envelope[*openapi.ModelVersionUpdate, None]
type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None]
type ModelArtifactEnvelope Envelope[*openapi.ModelArtifact, None]

func (app *App) GetAllModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

versionList, err := app.repositories.ModelRegistryClient.GetAllModelVersions(client)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}

responseBody := ModelVersionListEnvelope{
Data: versionList,
}

err = app.WriteJSON(w, http.StatusOK, responseBody, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}

}

func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface)
if !ok {
Expand Down
12 changes: 12 additions & 0 deletions clients/ui/bff/internal/api/model_versions_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ import (
var _ = Describe("TestGetModelVersionHandler", func() {
Context("testing Model Version Handler", Ordered, func() {

It("should retrieve all model versions", func() {
By("fetching all model versions")
data := mocks.GetModelVersionListMock()
expected := ModelVersionListEnvelope{Data: &data}
actual, rs, err := setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions?namespace=kubeflow", nil, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow")
Expect(err).NotTo(HaveOccurred())
By("should match the expected model versions")
Expect(rs.StatusCode).To(Equal(http.StatusOK))
Expect(actual.Data.Size).To(Equal(expected.Data.Size))
Expect(actual.Data.Items).To(Equal(expected.Data.Items))
})

It("should retrieve a model version", func() {
By("fetching a model version")
data := mocks.GetModelVersionMocks()[0]
Expand Down
2 changes: 1 addition & 1 deletion clients/ui/bff/internal/api/registered_models_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWrit
return
}

versionList, err := app.repositories.ModelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId), r.URL.Query())
versionList, err := app.repositories.ModelRegistryClient.GetAllModelVersionsForRegisteredModel(client, ps.ByName(RegisteredModelId), r.URL.Query())

if err != nil {
app.serverErrorResponse(w, r, err)
Expand Down
7 changes: 6 additions & 1 deletion clients/ui/bff/internal/mocks/model_registry_client_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ func (m *ModelRegistryClientMock) UpdateRegisteredModel(_ integrations.HTTPClien
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface) (*openapi.ModelVersionList, error) {
mockData := GetModelVersionListMock()
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetModelVersion(_ integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) {
if id == "3" {
mockData := GetModelVersionMocks()[2]
Expand All @@ -61,7 +66,7 @@ func (m *ModelRegistryClientMock) UpdateModelVersion(_ integrations.HTTPClientIn
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) {
func (m *ModelRegistryClientMock) GetAllModelVersionsForRegisteredModel(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) {
mockData := GetModelVersionListMock()
return &mockData, nil
}
Expand Down
16 changes: 16 additions & 0 deletions clients/ui/bff/internal/repositories/model_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const modelVersionPath = "/model_versions"
const artifactsByModelVersionPath = "/artifacts"

type ModelVersionInterface interface {
GetAllModelVersions(client integrations.HTTPClientInterface) (*openapi.ModelVersionList, error)
GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error)
CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error)
UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
Expand All @@ -24,6 +25,21 @@ type ModelVersion struct {
ModelVersionInterface
}

func (v ModelVersion) GetAllModelVersions(client integrations.HTTPClientInterface) (*openapi.ModelVersionList, error) {
response, err := client.GET(modelVersionPath)

if err != nil {
return nil, fmt.Errorf("error fetching model versions: %w", err)
}

var models openapi.ModelVersionList
if err := json.Unmarshal(response, &models); err != nil {
return nil, fmt.Errorf("error decoding response data: %w", err)
}

return &models, nil
}

func (v ModelVersion) GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) {
path, err := url.JoinPath(modelVersionPath, id)
if err != nil {
Expand Down
26 changes: 26 additions & 0 deletions clients/ui/bff/internal/repositories/model_version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,32 @@ func TestGetModelVersion(t *testing.T) {
mockClient.AssertExpectations(t)
}

func TestGetAllModelVersions(t *testing.T) {
_ = gofakeit.Seed(0)

expected := mocks.GenerateMockModelVersionList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

modelVersion := ModelVersion{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", modelVersionPath).Return(mockData, nil)

actual, err := modelVersion.GetAllModelVersions(mockClient)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.NextPageToken, actual.NextPageToken)
assert.Equal(t, expected.PageSize, actual.PageSize)
assert.Equal(t, expected.Size, actual.Size)
assert.Equal(t, len(expected.Items), len(actual.Items))

mockClient.AssertExpectations(t)
}

func TestCreateModelVersion(t *testing.T) {
_ = gofakeit.Seed(0)

Expand Down
4 changes: 2 additions & 2 deletions clients/ui/bff/internal/repositories/registered_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type RegisteredModelInterface interface {
CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error)
GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error)
UpdateRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.RegisteredModel, error)
GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error)
GetAllModelVersionsForRegisteredModel(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error)
CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
}

Expand Down Expand Up @@ -94,7 +94,7 @@ func (m RegisteredModel) UpdateRegisteredModel(client integrations.HTTPClientInt
return &model, nil
}

func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) {
func (m RegisteredModel) GetAllModelVersionsForRegisteredModel(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) {
path, err := url.JoinPath(registeredModelPath, id, versionsPath)

if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions clients/ui/bff/internal/repositories/registered_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func TestUpdateRegisteredModel(t *testing.T) {
mockClient.AssertExpectations(t)
}

func TestGetAllModelVersions(t *testing.T) {
func TestGetAllModelVersionsByRegisteredModel(t *testing.T) {
_ = gofakeit.Seed(0)

expected := mocks.GenerateMockModelVersionList()
Expand All @@ -149,7 +149,7 @@ func TestGetAllModelVersions(t *testing.T) {
assert.NoError(t, err)
mockClient.On("GET", path).Return(mockData, nil)

actual, err := registeredModel.GetAllModelVersions(mockClient, "1", nil)
actual, err := registeredModel.GetAllModelVersionsForRegisteredModel(mockClient, "1", nil)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.NoError(t, err)
Expand Down Expand Up @@ -180,7 +180,7 @@ func TestGetAllModelVersionsWithPageParams(t *testing.T) {

mockClient.On("GET", reqUrl).Return(mockData, nil)

actual, err := registeredModel.GetAllModelVersions(mockClient, "1", pageValues)
actual, err := registeredModel.GetAllModelVersionsForRegisteredModel(mockClient, "1", pageValues)
assert.NoError(t, err)
assert.NotNil(t, actual)

Expand Down
Loading