Skip to content

Commit

Permalink
fvt: add basic load and infer test for TorchServe runtime
Browse files Browse the repository at this point in the history
Signed-off-by: Rafael Vasquez <[email protected]>
  • Loading branch information
rafvasq committed Jan 9, 2023
1 parent ab63270 commit 43d1c75
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 3 deletions.
15 changes: 14 additions & 1 deletion fvt/fvtclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
Expand Down Expand Up @@ -54,6 +54,7 @@ import (

inference "github.com/kserve/modelmesh-serving/fvt/generated"
tfsapi "github.com/kserve/modelmesh-serving/fvt/generated/tensorflow_serving/apis"
torchserveapi "github.com/kserve/modelmesh-serving/fvt/generated/torchserve/apis"
)

const predictorTimeout = time.Second * 120
Expand Down Expand Up @@ -480,6 +481,18 @@ func (fvt *FVTClient) RunTfsInference(req *tfsapi.PredictRequest) (*tfsapi.Predi
return grpcClient.Predict(ctx, req)
}

func (fvt *FVTClient) RunTorchserveInference(req *torchserveapi.PredictionsRequest) (*torchserveapi.PredictionResponse, error) {
if fvt.grpcConn == nil {
return nil, errors.New("you must connect to model mesh before running an inference")
}

grpcClient := torchserveapi.NewInferenceAPIsServiceClient(fvt.grpcConn)

ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
return grpcClient.Predictions(ctx, req)
}

func (fvt *FVTClient) ConnectToModelServing(connectionType ModelServingConnectionType) error {
if fvt.grpcPortForward == nil {
podName := fvt.GetRandomReadyRuntimePodNameFromEndpoints()
Expand Down
18 changes: 17 additions & 1 deletion fvt/inference.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
Expand All @@ -17,6 +17,7 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"io/ioutil"
"math"
"os"

Expand All @@ -29,6 +30,7 @@ import (

tfsframework "github.com/kserve/modelmesh-serving/fvt/generated/tensorflow/core/framework"
tfsapi "github.com/kserve/modelmesh-serving/fvt/generated/tensorflow_serving/apis"
torchserveapi "github.com/kserve/modelmesh-serving/fvt/generated/torchserve/apis"
)

// Used for checking if floats are sufficiently close enough.
Expand Down Expand Up @@ -103,6 +105,20 @@ func ExpectSuccessfulInference_openvinoMnistTFSPredict(predictorName string) {
Expect(err).ToNot(HaveOccurred())
}

func ExpectSuccessfulInference_torchserveMARPredict(predictorName string) {
imageBytes, err := ioutil.ReadFile(TestDataPath("0.png"))
Expect(err).ToNot(HaveOccurred())

inferRequest := &torchserveapi.PredictionsRequest{
ModelName: predictorName,
Input: map[string][]byte{"data": imageBytes},
}

inferResponse, err := FVTClientInstance.RunTorchserveInference(inferRequest)
Expect(err).ToNot(HaveOccurred())
Expect(inferResponse).ToNot(BeNil())
}

// PyTorch CIFAR
// COS path: fvt/pytorch/pytorch-cifar
func ExpectSuccessfulInference_pytorchCifar(predictorName string) {
Expand Down
34 changes: 33 additions & 1 deletion fvt/predictor/predictor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
Expand Down Expand Up @@ -112,6 +112,14 @@ var predictorsArray = []FVTPredictor{
differentPredictorName: "xgboost",
differentPredictorFilename: "xgboost-predictor.yaml",
},
{
predictorName: "pytorch-mar",
predictorFilename: "pytorch-mar-predictor.yaml",
currentModelPath: "fvt/pytorch/pytorch-mar/mnist.mar",
updatedModelPath: "fvt/pytorch/pytorch-mar-dup/mnist.mar",
differentPredictorName: "pytorch",
differentPredictorFilename: "pytorch-predictor.yaml",
},
}

var _ = Describe("Predictor", func() {
Expand Down Expand Up @@ -576,6 +584,30 @@ var _ = Describe("Predictor", func() {
})
})

var _ = Describe("TorchServe Inference", Ordered, func() {
var torchservePredictorObject *unstructured.Unstructured
var torchservePredictorName string

BeforeAll(func() {
// load the test predictor object
torchservePredictorObject = NewPredictorForFVT("pytorch-mar-predictor.yaml")
torchservePredictorName = torchservePredictorObject.GetName()

CreatePredictorAndWaitAndExpectLoaded(torchservePredictorObject)

err := FVTClientInstance.ConnectToModelServing(Insecure)
Expect(err).ToNot(HaveOccurred())
})

AfterAll(func() {
FVTClientInstance.DeletePredictor(torchservePredictorName)
})

It("should successfully run an inference", func() {
ExpectSuccessfulInference_torchserveMARPredict(torchservePredictorName)
})
})

var _ = Describe("MLServer inference", Ordered, func() {
var mlsPredictorObject *unstructured.Unstructured
var mlsPredictorName string
Expand Down
Binary file added fvt/testdata/0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 24 additions & 0 deletions fvt/testdata/predictors/pytorch-mar-predictor.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2022 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
apiVersion: serving.kserve.io/v1alpha1
kind: Predictor
metadata:
name: pytorch-mar-predictor
spec:
modelType:
name: pytorch-mar
path: fvt/pytorch/pytorch-mar/mnist.mar
storage:
s3:
secretKey: localMinIO

0 comments on commit 43d1c75

Please sign in to comment.