Skip to content

Commit

Permalink
feat: fetch and create registered models
Browse files Browse the repository at this point in the history
In this PR:
- GET /v1/model-registry/{model_registry_id}/registered_models
- POST /v1/model-registry/{model_registry_id}/registered_models
- HTTP client

Signed-off-by: Eder Ignatowicz <[email protected]>
  • Loading branch information
ederign committed Jun 12, 2024
1 parent fd1d3c1 commit 2addd32
Show file tree
Hide file tree
Showing 21 changed files with 771 additions and 54 deletions.
4 changes: 4 additions & 0 deletions clients/ui/bff/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ COPY api/ api/
COPY config/ config/
COPY data/ data/
COPY integrations/ integrations/
COPY internals/ internals/
COPY validation/ validation/



# Build the Go application
RUN CGO_ENABLED=0 GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH} go build -a -o bff ./cmd/main.go
Expand Down
36 changes: 32 additions & 4 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,41 @@ TBD

### Endpoints

| URL Pattern | Handler | Action |
|-------------------------|----------------------|-------------------------------|
| GET /v1/healthcheck | HealthcheckHandler | Show application information. |
| GET /v1/model-registry/ | ModelRegistryHandler | Get all model registries, |
| URL Pattern | Handler | Action |
|---------------------------------------------------------------|-------------------------|----------------------------------------------|
| GET /v1/healthcheck | HealthcheckHandler | Show application information. |
| GET /v1/model-registry/ | ModelRegistryHandler | Get all model registries, |
| GET /v1/model-registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Gets a list of all RegisteredModel entities. |
| POST /v1/model-registry/{model_registry_id}/registered_models | RegisteredModelsHandler | Create a RegisteredModel entity. |

### Sample local calls
```
# GET /v1/healthcheck
curl -i localhost:4000/api/v1/healthcheck/
```
```
# GET /v1/model-registry/
curl -i localhost:4000/api/v1/model-registry/
```
```
# GET /v1/model-registry/{model_registry_id}/registered_models
curl -i localhost:4000/api/v1/model-registry/model-registry/registered_models
```
```
#POST /v1/model-registry/{model_registry_id}/registered_models
curl -i -X POST "http://localhost:4000/api/v1/model-registry/model-registry/registered_models" \
-H "Content-Type: application/json" \
-d '{
"customProperties": {
"my-label9": {
"metadataType": "MetadataStringValue",
"string_value": "val"
}
},
"description": "bella description",
"externalId": "9927",
"name": "bella",
"owner": "eder",
"state": "LIVE"
}'
```
13 changes: 9 additions & 4 deletions clients/ui/bff/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ import (
)

const (
Version = "1.0.0"
HealthCheckPath = "/api/v1/healthcheck/"
ModelRegistry = "/api/v1/model-registry/"
Version = "1.0.0"
PathPrefix = "/api/v1"
ModelRegistryId = "model_registry_id"
HealthCheckPath = PathPrefix + "/healthcheck/"
ModelRegistry = PathPrefix + "/model-registry/"
RegisteredModelsPath = ModelRegistry + ":" + ModelRegistryId + "/registered_models"
)

type App struct {
Expand All @@ -25,7 +28,7 @@ type App struct {
}

func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) {
k8sClient, err := integrations.NewKubernetesClient()
k8sClient, err := integrations.NewKubernetesClient(logger)
if err != nil {
return nil, fmt.Errorf("failed to create Kubernetes client: %w", err)
}
Expand All @@ -46,6 +49,8 @@ func (app *App) Routes() http.Handler {

// HTTP client routes
router.GET(HealthCheckPath, app.HealthcheckHandler)
router.GET(RegisteredModelsPath, app.AttachRESTClient(app.GetRegisteredModelsHandler))
router.POST(RegisteredModelsPath, app.AttachRESTClient(app.CreateRegisteredModelHandler))

// Kubernetes client routes
router.GET(ModelRegistry, app.ModelRegistryHandler)
Expand Down
23 changes: 12 additions & 11 deletions clients/ui/bff/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"encoding/json"
"fmt"
"github.com/kubeflow/model-registry/ui/bff/integrations"
"net/http"
"strconv"
)
Expand All @@ -27,17 +28,17 @@ func (app *App) LogError(r *http.Request, err error) {
}

func (app *App) badRequestResponse(w http.ResponseWriter, r *http.Request, err error) {
httpError := &HTTPError{
httpError := &integrations.HTTPError{
StatusCode: http.StatusBadRequest,
ErrorResponse: ErrorResponse{
ErrorResponse: integrations.ErrorResponse{
Code: strconv.Itoa(http.StatusBadRequest),
Message: err.Error(),
},
}
app.errorResponse(w, r, httpError)
}

func (app *App) errorResponse(w http.ResponseWriter, r *http.Request, error *HTTPError) {
func (app *App) errorResponse(w http.ResponseWriter, r *http.Request, error *integrations.HTTPError) {

env := Envelope{"error": error}

Expand All @@ -52,9 +53,9 @@ func (app *App) errorResponse(w http.ResponseWriter, r *http.Request, error *HTT
func (app *App) serverErrorResponse(w http.ResponseWriter, r *http.Request, err error) {
app.LogError(r, err)

httpError := &HTTPError{
httpError := &integrations.HTTPError{
StatusCode: http.StatusInternalServerError,
ErrorResponse: ErrorResponse{
ErrorResponse: integrations.ErrorResponse{
Code: strconv.Itoa(http.StatusInternalServerError),
Message: "the server encountered a problem and could not process your request",
},
Expand All @@ -64,9 +65,9 @@ func (app *App) serverErrorResponse(w http.ResponseWriter, r *http.Request, err

func (app *App) notFoundResponse(w http.ResponseWriter, r *http.Request) {

httpError := &HTTPError{
httpError := &integrations.HTTPError{
StatusCode: http.StatusNotFound,
ErrorResponse: ErrorResponse{
ErrorResponse: integrations.ErrorResponse{
Code: strconv.Itoa(http.StatusNotFound),
Message: "the requested resource could not be found",
},
Expand All @@ -76,9 +77,9 @@ func (app *App) notFoundResponse(w http.ResponseWriter, r *http.Request) {

func (app *App) methodNotAllowedResponse(w http.ResponseWriter, r *http.Request) {

httpError := &HTTPError{
httpError := &integrations.HTTPError{
StatusCode: http.StatusMethodNotAllowed,
ErrorResponse: ErrorResponse{
ErrorResponse: integrations.ErrorResponse{
Code: strconv.Itoa(http.StatusMethodNotAllowed),
Message: fmt.Sprintf("the %s method is not supported for this resource", r.Method),
},
Expand All @@ -92,9 +93,9 @@ func (app *App) failedValidationResponse(w http.ResponseWriter, r *http.Request,
if err != nil {
message = []byte("{}")
}
httpError := &HTTPError{
httpError := &integrations.HTTPError{
StatusCode: http.StatusUnprocessableEntity,
ErrorResponse: ErrorResponse{
ErrorResponse: integrations.ErrorResponse{
Code: strconv.Itoa(http.StatusUnprocessableEntity),
Message: string(message),
},
Expand Down
61 changes: 61 additions & 0 deletions clients/ui/bff/api/middleware.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
package api

import (
"context"
"fmt"
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/ui/bff/integrations"
"k8s.io/client-go/rest"
"net/http"
)

type contextKey string

const httpClientKey contextKey = "httpClientKey"

func (app *App) RecoverPanic(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
Expand All @@ -27,3 +35,56 @@ func (app *App) enableCORS(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}

func (app *App) AttachRESTClient(handler func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {

modelRegistryID := ps.ByName(ModelRegistryId)

modelRegistryBaseURL, err := resolveModelRegistryURL(modelRegistryID, app.kubernetesClient)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("failed to resolve model registry base URL): %v", err))
return
}
var bearerToken string
bearerToken, err = resolveBearerToken(app.kubernetesClient)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("failed to resolve BearerToken): %v", err))
return
}

client, err := integrations.NewHTTPClient(modelRegistryBaseURL, bearerToken)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("failed to create Kubernetes client: %v", err))
return
}
ctx := context.WithValue(r.Context(), httpClientKey, client)
handler(w, r.WithContext(ctx), ps)
}
}

func resolveBearerToken(k8s integrations.KubernetesClientInterface) (string, error) {
var bearerToken string
_, err := rest.InClusterConfig()
if err == nil {
//in cluster
//TODO (eder) load bearerToken probably from x-forwarded-access-bearerToken
return "", fmt.Errorf("failed to create Rest client (not implemented yet - inside cluster): %v", err)
} else {
//off cluster (development)
bearerToken, err = k8s.BearerToken()
if err != nil {
return "", fmt.Errorf("failed to fetch BearerToken in development mode: %v", err)
}
}
return bearerToken, err
}

func resolveModelRegistryURL(id string, client integrations.KubernetesClientInterface) (string, error) {
serviceDetails, err := client.GetServiceDetailsByName(id)
if err != nil {
return "", err
}
url := fmt.Sprintf("http://%s:%d/api/model_registry/v1alpha3", serviceDetails.ClusterIP, serviceDetails.HTTPPort)
return url, nil
}
8 changes: 3 additions & 5 deletions clients/ui/bff/api/model_registry_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"encoding/json"
"fmt"
"github.com/kubeflow/model-registry/ui/bff/data"
"github.com/kubeflow/model-registry/ui/bff/internals/mocks"
"github.com/stretchr/testify/assert"
Expand All @@ -14,7 +13,7 @@ import (

func TestModelRegistryHandler(t *testing.T) {
mockK8sClient := new(mocks.KubernetesClientMock)
mockK8sClient.On("FetchServiceNamesByComponent", "model-registry-server").Return([]string{"model-registry-dora", "model-registry-bella"}, nil)
mockK8sClient.On("GetServiceNames").Return(mockK8sClient.MockServiceNames(), nil)

testApp := App{
kubernetesClient: mockK8sClient,
Expand All @@ -31,7 +30,6 @@ func TestModelRegistryHandler(t *testing.T) {
defer rs.Body.Close()
body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
fmt.Println(string(body))
var modelRegistryRes Envelope
err = json.Unmarshal(body, &modelRegistryRes)
assert.NoError(t, err)
Expand All @@ -48,8 +46,8 @@ func TestModelRegistryHandler(t *testing.T) {

var expected = Envelope{
"model_registry": []data.ModelRegistryModel{
{Name: "model-registry-dora"},
{Name: "model-registry-bella"},
{Name: mockK8sClient.MockServiceNames()[0]},
{Name: mockK8sClient.MockServiceNames()[1]},
},
}

Expand Down
85 changes: 85 additions & 0 deletions clients/ui/bff/api/registered_models_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package api

import (
"encoding/json"
"errors"
"fmt"
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/kubeflow/model-registry/ui/bff/data"
"github.com/kubeflow/model-registry/ui/bff/integrations"
"github.com/kubeflow/model-registry/ui/bff/validation"
"net/http"
)

func (app *App) GetRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//TODO (ederign) implement pagination
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

modelList, err := data.FetchAllRegisteredModels(client)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}

modelRegistryRes := Envelope{
"registered_models": modelList,
}

err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}

func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

var model openapi.RegisteredModel
if err := json.NewDecoder(r.Body).Decode(&model); err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error()))
return
}

if err := validation.ValidateRegisteredModel(model); err != nil {
app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error()))
return
}

jsonData, err := json.Marshal(model)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("error marshaling model to JSON: %w", err))
return
}

createdModel, err := data.CreateRegisteredModel(client, jsonData)
if err != nil {
var httpErr *integrations.HTTPError
if errors.As(err, &httpErr) {
app.errorResponse(w, r, httpErr)
} else {
app.serverErrorResponse(w, r, err)
}
return
}

if createdModel == nil {
app.serverErrorResponse(w, r, fmt.Errorf("created model is nil"))
return
}

w.Header().Set("Location", fmt.Sprintf("%s/%s", RegisteredModelsPath, *createdModel.Id))
err = app.WriteJSON(w, http.StatusCreated, createdModel, nil)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("error writing JSON"))
return
}
}
2 changes: 1 addition & 1 deletion clients/ui/bff/data/model_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type ModelRegistryModel struct {

func (m ModelRegistryModel) FetchAllModelRegistry(client k8s.KubernetesClientInterface) ([]ModelRegistryModel, error) {

resources, err := client.FetchServiceNamesByComponent(k8s.ModelRegistryServiceComponentSelector)
resources, err := client.GetServiceNames()
if err != nil {
return nil, fmt.Errorf("error fetching model registries: %w", err)
}
Expand Down
Loading

0 comments on commit 2addd32

Please sign in to comment.