Skip to content

Commit

Permalink
feat: RAG engine validation (kaito-project#691)
Browse files Browse the repository at this point in the history
**Reason for Change**:
Validation of RAG engine creation

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:

---------

Signed-off-by: Bangqi Zhu <[email protected]>
Co-authored-by: Bangqi Zhu <[email protected]>
  • Loading branch information
bangqipropel and Bangqi Zhu authored Nov 12, 2024
1 parent 1c6eb2e commit f3ef4c8
Show file tree
Hide file tree
Showing 2 changed files with 347 additions and 0 deletions.
90 changes: 90 additions & 0 deletions api/v1alpha1/ragengine_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@ package v1alpha1
import (
"context"
"fmt"
"net/url"
"os"
"regexp"
"strings"

"github.com/kaito-project/kaito/pkg/utils"
"github.com/kaito-project/kaito/pkg/utils/consts"
admissionregistrationv1 "k8s.io/api/admissionregistration/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/klog/v2"
"knative.dev/pkg/apis"
)
Expand All @@ -32,5 +39,88 @@ func (w *RAGEngine) validateCreate() (errs *apis.FieldError) {
if w.Spec.InferenceService == nil {
errs = errs.Also(apis.ErrGeneric("InferenceService must be specified", ""))
}
errs = errs.Also(w.Spec.InferenceService.validateCreate())
if w.Spec.Embedding == nil {
errs = errs.Also(apis.ErrGeneric("Embedding must be specified", ""))
return errs
}
if w.Spec.Embedding.Local == nil && w.Spec.Embedding.Remote == nil {
errs = errs.Also(apis.ErrGeneric("Either remote embedding or local embedding must be specified, not neither", ""))
}
if w.Spec.Embedding.Local != nil && w.Spec.Embedding.Remote != nil {
errs = errs.Also(apis.ErrGeneric("Either remote embedding or local embedding must be specified, but not both", ""))
}
errs = errs.Also(w.Spec.Compute.validateRAGCreate())
if w.Spec.Embedding.Local != nil {
w.Spec.Embedding.Local.validateCreate().ViaField("embedding")
}
if w.Spec.Embedding.Remote != nil {
w.Spec.Embedding.Remote.validateCreate().ViaField("embedding")
}

return errs
}

// validateRAGCreate validates the compute resource section of a RAGEngine at
// creation time.
//
// It reports an error when:
//   - the receiver is nil (no compute section was supplied);
//   - the SKU handler cannot be obtained from utils.GetSKUHandler;
//   - the instance type is not in the handler's GPU config table and is not
//     accepted by the prefix fallback (CLOUD_PROVIDER equal to
//     consts.AzureCloudName and a name starting with N_SERIES_PREFIX or
//     D_SERIES_PREFIX);
//   - the label selector cannot be converted by metav1.LabelSelectorAsMap.
//
// A nil result means the spec passed every check.
func (r *ResourceSpec) validateRAGCreate() (errs *apis.FieldError) {
	// Guard against a missing compute section: without this the field access
	// below dereferences a nil pointer and panics instead of producing a
	// validation error.
	if r == nil {
		return errs.Also(apis.ErrGeneric("Compute must be specified", ""))
	}

	instanceType := string(r.InstanceType)

	skuHandler, err := utils.GetSKUHandler()
	if err != nil {
		// Without a handler neither the SKU table nor the supported-SKU list
		// is available, so stop here.
		return errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get SKU handler: %v", err), "instanceType"))
	}

	if _, known := skuHandler.GetGPUConfigs()[instanceType]; !known {
		// Instance types outside the GPU config table are only tolerated on
		// Azure, and only when the name carries one of the series prefixes.
		onAzure := os.Getenv("CLOUD_PROVIDER") == consts.AzureCloudName
		seriesMatch := strings.HasPrefix(instanceType, N_SERIES_PREFIX) ||
			strings.HasPrefix(instanceType, D_SERIES_PREFIX)
		if !onAzure || !seriesMatch {
			errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, skuHandler.GetSupportedSKUs()), "instanceType"))
		}
	}

	// Validate labelSelector
	if _, err := metav1.LabelSelectorAsMap(r.LabelSelector); err != nil {
		errs = errs.Also(apis.ErrInvalidValue(err.Error(), "labelSelector"))
	}

	return errs
}

// localEmbeddingImageRE matches image references shaped like
// "<path containing at least one '/'>:<tag>", where neither the final path
// segment nor the tag contains ':' or '/'. It is compiled once at package
// initialization rather than on every validation call.
var localEmbeddingImageRE = regexp.MustCompile(`^(.+/[^:/]+):([^:/]+)$`)

// validateCreate validates a local embedding spec at creation time.
//
// Exactly one of Image and ModelID must be set. When Image is set it must
// match localEmbeddingImageRE and then pass utils.ExtractAndValidateRepoName;
// the repo-name check is skipped if the format check already failed.
// A nil result means the spec is valid.
func (e *LocalEmbeddingSpec) validateCreate() (errs *apis.FieldError) {
	switch {
	case e.Image == "" && e.ModelID == "":
		errs = errs.Also(apis.ErrGeneric("Either image or modelID must be specified, not neither", ""))
	case e.Image != "" && e.ModelID != "":
		errs = errs.Also(apis.ErrGeneric("Either image or modelID must be specified, but not both", ""))
	}

	if e.Image == "" {
		return errs
	}

	if !localEmbeddingImageRE.MatchString(e.Image) {
		return errs.Also(apis.ErrInvalidValue("Invalid image format, require full input image URL", "Image"))
	}

	// Only reached when the image has the expected <repo>:<tag> shape.
	if err := utils.ExtractAndValidateRepoName(e.Image); err != nil {
		errs = errs.Also(apis.ErrInvalidValue(err.Error(), "Image"))
	}
	return errs
}

// validateCreate validates a remote embedding spec at creation time.
//
// The only check is that URL is accepted by url.ParseRequestURI; a parse
// failure is returned as a generic field error. A nil result means the URL
// parsed successfully.
func (e *RemoteEmbeddingSpec) validateCreate() (errs *apis.FieldError) {
	if _, parseErr := url.ParseRequestURI(e.URL); parseErr != nil {
		msg := fmt.Sprintf("URL input error: %v", parseErr)
		return errs.Also(apis.ErrGeneric(msg, "remote url"))
	}
	return nil
}

// validateCreate validates an inference service spec at creation time.
//
// The only check is that URL is accepted by url.ParseRequestURI; a parse
// failure is returned as a generic field error. A nil result means there is
// nothing to report.
//
// A nil receiver is tolerated and yields no error: RAGEngine.validateCreate
// calls this method even when the InferenceService section is absent (it
// reports the missing section itself), so dereferencing e unconditionally
// would panic in that path.
func (e *InferenceServiceSpec) validateCreate() (errs *apis.FieldError) {
	if e == nil {
		return nil
	}
	if _, err := url.ParseRequestURI(e.URL); err != nil {
		errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("URL input error: %v", err), "remote url"))
	}
	return errs
}
257 changes: 257 additions & 0 deletions api/v1alpha1/ragengine_validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package v1alpha1

import (
"os"
"strings"
"testing"

"github.com/kaito-project/kaito/pkg/utils/consts"
)

func TestRAGEngineValidateCreate(t *testing.T) {
tests := []struct {
name string
ragEngine *RAGEngine
wantErr bool
errField string
}{
{
name: "Both Local and Remote Embedding specified",
ragEngine: &RAGEngine{
Spec: &RAGEngineSpec{
Compute: &ResourceSpec{
InstanceType: "Standard_NC12s_v3",
},
InferenceService: &InferenceServiceSpec{URL: "http://example.com"},
Embedding: &EmbeddingSpec{
Local: &LocalEmbeddingSpec{
ModelID: "BAAI/bge-small-en-v1.5",
},
Remote: &RemoteEmbeddingSpec{URL: "http://remote-embedding.com"},
},
},
},
wantErr: true,
errField: "Either remote embedding or local embedding must be specified, but not both",
},
{
name: "Embedding not specified",
ragEngine: &RAGEngine{
Spec: &RAGEngineSpec{
Compute: &ResourceSpec{
InstanceType: "Standard_NC12s_v3",
},
InferenceService: &InferenceServiceSpec{URL: "http://example.com"},
},
},
wantErr: true,
errField: "Embedding must be specified",
},
{
name: "None of Local and Remote Embedding specified",
ragEngine: &RAGEngine{
Spec: &RAGEngineSpec{
Compute: &ResourceSpec{
InstanceType: "Standard_NC12s_v3",
},
InferenceService: &InferenceServiceSpec{URL: "http://example.com"},
Embedding: &EmbeddingSpec{},
},
},
wantErr: true,
errField: "Either remote embedding or local embedding must be specified, not neither",
},
{
name: "Only Local Embedding specified",
ragEngine: &RAGEngine{
Spec: &RAGEngineSpec{
Compute: &ResourceSpec{
InstanceType: "Standard_NC12s_v3",
},
InferenceService: &InferenceServiceSpec{URL: "http://example.com"},
Embedding: &EmbeddingSpec{
Local: &LocalEmbeddingSpec{
ModelID: "BAAI/bge-small-en-v1.5",
},
},
},
},
wantErr: false,
},
{
name: "Only Remote Embedding specified",
ragEngine: &RAGEngine{
Spec: &RAGEngineSpec{
Compute: &ResourceSpec{
InstanceType: "Standard_NC12s_v3",
},
InferenceService: &InferenceServiceSpec{URL: "http://example.com"},
Embedding: &EmbeddingSpec{
Remote: &RemoteEmbeddingSpec{URL: "http://remote-embedding.com"},
},
},
},
wantErr: false,
},
}
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.ragEngine.validateCreate()
hasErr := err != nil

if hasErr != tt.wantErr {
t.Errorf("validateCreate() error = %v, wantErr %v", err, tt.wantErr)
}

if hasErr && tt.errField != "" && !strings.Contains(err.Error(), tt.errField) {
t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, err.Error())
}
})
}
}

// TestLocalEmbeddingValidateCreate checks LocalEmbeddingSpec.validateCreate
// against the image/modelID combinations it is expected to accept or reject.
func TestLocalEmbeddingValidateCreate(t *testing.T) {
	type scenario struct {
		label   string
		spec    *LocalEmbeddingSpec
		wantErr bool
		// wantMsg is a substring expected in the error text; empty skips the
		// message check.
		wantMsg string
	}

	scenarios := []scenario{
		{
			label:   "Neither Image nor ModelID specified",
			spec:    &LocalEmbeddingSpec{},
			wantErr: true,
			wantMsg: "Either image or modelID must be specified, not neither",
		},
		{
			label:   "Both Image and ModelID specified",
			spec:    &LocalEmbeddingSpec{Image: "image-path", ModelID: "model-id"},
			wantErr: true,
			wantMsg: "Either image or modelID must be specified, but not both",
		},
		{
			label:   "Invalid Image Format",
			spec:    &LocalEmbeddingSpec{Image: "invalid-image-format"},
			wantErr: true,
			wantMsg: "Invalid image format",
		},
		{
			label:   "Valid Image Specified",
			spec:    &LocalEmbeddingSpec{Image: "myrepo/myimage:tag"},
			wantErr: false,
		},
		{
			label:   "Valid ModelID Specified",
			spec:    &LocalEmbeddingSpec{ModelID: "valid-model-id"},
			wantErr: false,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.label, func(t *testing.T) {
			got := sc.spec.validateCreate()

			if failed := got != nil; failed != sc.wantErr {
				t.Errorf("validateCreate() error = %v, wantErr %v", got, sc.wantErr)
			}

			// The message check only applies when an error came back and
			// the scenario names an expected substring.
			if got == nil || sc.wantMsg == "" {
				return
			}
			if msg := got.Error(); !strings.Contains(msg, sc.wantMsg) {
				t.Errorf("validateCreate() expected error to contain %s, but got %s", sc.wantMsg, msg)
			}
		})
	}
}

// TestRemoteEmbeddingValidateCreate checks that
// RemoteEmbeddingSpec.validateCreate rejects an unparsable URL and accepts a
// well-formed one.
func TestRemoteEmbeddingValidateCreate(t *testing.T) {
	type scenario struct {
		label   string
		spec    *RemoteEmbeddingSpec
		wantErr bool
		// wantMsg is a substring expected in the error text; empty skips the
		// message check.
		wantMsg string
	}

	scenarios := []scenario{
		{
			label:   "Invalid URL Specified",
			spec:    &RemoteEmbeddingSpec{URL: "invalid-url"},
			wantErr: true,
			wantMsg: "URL input error",
		},
		{
			label:   "Valid URL Specified",
			spec:    &RemoteEmbeddingSpec{URL: "http://example.com"},
			wantErr: false,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.label, func(t *testing.T) {
			got := sc.spec.validateCreate()

			if failed := got != nil; failed != sc.wantErr {
				t.Errorf("validateCreate() error = %v, wantErr %v", got, sc.wantErr)
			}

			// The message check only applies when an error came back and
			// the scenario names an expected substring.
			if got == nil || sc.wantMsg == "" {
				return
			}
			if msg := got.Error(); !strings.Contains(msg, sc.wantMsg) {
				t.Errorf("validateCreate() expected error to contain %s, but got %s", sc.wantMsg, msg)
			}
		})
	}
}

// TestInferenceServiceValidateCreate checks that
// InferenceServiceSpec.validateCreate rejects an unparsable URL and accepts a
// well-formed one.
func TestInferenceServiceValidateCreate(t *testing.T) {
	type scenario struct {
		label   string
		spec    *InferenceServiceSpec
		wantErr bool
		// wantMsg is a substring expected in the error text; empty skips the
		// message check.
		wantMsg string
	}

	scenarios := []scenario{
		{
			label:   "Invalid URL Specified",
			spec:    &InferenceServiceSpec{URL: "invalid-url"},
			wantErr: true,
			wantMsg: "URL input error",
		},
		{
			label:   "Valid URL Specified",
			spec:    &InferenceServiceSpec{URL: "http://example.com"},
			wantErr: false,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.label, func(t *testing.T) {
			got := sc.spec.validateCreate()

			if failed := got != nil; failed != sc.wantErr {
				t.Errorf("validateCreate() error = %v, wantErr %v", got, sc.wantErr)
			}

			// The message check only applies when an error came back and
			// the scenario names an expected substring.
			if got == nil || sc.wantMsg == "" {
				return
			}
			if msg := got.Error(); !strings.Contains(msg, sc.wantMsg) {
				t.Errorf("validateCreate() expected error to contain %s, but got %s", sc.wantMsg, msg)
			}
		})
	}
}

0 comments on commit f3ef4c8

Please sign in to comment.