Skip to content

Commit

Permalink
tensorflow savedmodel warmup (#539)
Browse files Browse the repository at this point in the history
* tensorflow savedmodel warmup

Signed-off-by: datelier <[email protected]>

* fix warmup

Signed-off-by: datelier <[email protected]>

* fix DeepSource issue: Empty string test can be improved

Signed-off-by: datelier <[email protected]>

* fix test checkFunc

Signed-off-by: datelier <[email protected]>

Co-authored-by: Yusuke Kato <[email protected]>
  • Loading branch information
datelier and Yusuke Kato authored Jul 23, 2020
1 parent dc6a44e commit 6d215f6
Show file tree
Hide file tree
Showing 4 changed files with 424 additions and 48 deletions.
33 changes: 30 additions & 3 deletions internal/core/converter/tensorflow/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@
// Package tensorflow provides implementation of Go API for extract data to vector
package tensorflow

import (
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

// Option is tensorflow configure.
type Option func(*tensorflow)

var (
defaultOpts = []Option{
WithOperations(), // set to default
WithSessionOptions(nil), // set to default
WithNdim(0), // set to default
withLoadFunc(tf.LoadSavedModel), // set to default
WithOperations(), // set to default
WithSessionOptions(nil), // set to default
WithNdim(0), // set to default
}
)

Expand Down Expand Up @@ -102,6 +107,15 @@ func WithTags(tags ...string) Option {
}
}

func withLoadFunc(
loadFunc func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error)) Option {
return func(t *tensorflow) {
if loadFunc != nil {
t.loadFunc = loadFunc
}
}
}

// WithFeed returns Option that sets feeds.
func WithFeed(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
Expand Down Expand Up @@ -138,6 +152,19 @@ func WithFetches(operationNames []string, outputIndexes []int) Option {
}
}

// WithWarmupInputs returns Option that sets warmupInputs.
func WithWarmupInputs(warmupInputs ...string) Option {
return func(t *tensorflow) {
if warmupInputs != nil {
if t.warmupInputs != nil {
t.warmupInputs = append(t.warmupInputs, warmupInputs...)
} else {
t.warmupInputs = warmupInputs
}
}
}
}

// WithNdim returns Option that sets ndim.
func WithNdim(ndim uint8) Option {
return func(t *tensorflow) {
Expand Down
206 changes: 195 additions & 11 deletions internal/core/converter/tensorflow/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import (
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
"github.com/vdaas/vald/internal/errors"
"go.uber.org/goleak"
)
Expand Down Expand Up @@ -71,7 +74,7 @@ func TestWithSessionOptions(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -140,7 +143,7 @@ func TestWithSessionTarget(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -209,7 +212,7 @@ func TestWithSessionConfig(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -294,7 +297,7 @@ func TestWithOperations(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -363,7 +366,7 @@ func TestWithExportPath(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -459,7 +462,7 @@ func TestWithTags(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand All @@ -482,6 +485,89 @@ func TestWithTags(t *testing.T) {
}
}

func TestWithLoadFunc(t *testing.T) {
type T = tensorflow
type args struct {
loadFunc func(string, []string, *SessionOptions) (*tf.SavedModel, error)
}
type want struct {
obj *T
}
type test struct {
name string
args args
want want
checkFunc func(want, *T) error
beforeFunc func(args)
afterFunc func(args)
}

defaultCheckFunc := func(w want, obj *T) error {
opts := []cmp.Option{
cmp.AllowUnexported(tensorflow{}),
cmp.AllowUnexported(OutputSpec{}),
cmpopts.IgnoreFields(tensorflow{}, "loadFunc"),
cmp.Comparer(func(want, obj T) bool {
p1 := reflect.ValueOf(want).FieldByName("loadFunc").Pointer()
p2 := reflect.ValueOf(obj).FieldByName("loadFunc").Pointer()
return p1 == p2
}),
}
if diff := cmp.Diff(w.obj, obj, opts...); diff != "" {
return errors.Errorf("err: %s", diff)
}
return nil
}

loadFunc := func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) {
return nil, nil
}
tests := []test{
{
name: "set success when loadFunc is not nil",
args: args{
loadFunc: loadFunc,
},
want: want{
obj: &T{
loadFunc: loadFunc,
},
},
},
{
name: "do nothing when loadFunc is nil",
args: args{
loadFunc: nil,
},
want: want{
obj: &T{},
},
},
}

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
if test.afterFunc != nil {
defer test.afterFunc(test.args)
}

if test.checkFunc == nil {
test.checkFunc = defaultCheckFunc
}
got := withLoadFunc(test.args.loadFunc)
obj := new(T)
got(obj)
if err := test.checkFunc(test.want, obj); err != nil {
tt.Errorf("error = %v", err)
}
})
}
}

func TestWithFeed(t *testing.T) {
type T = tensorflow
type args struct {
Expand Down Expand Up @@ -529,7 +615,7 @@ func TestWithFeed(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -635,7 +721,7 @@ func TestWithFeeds(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -703,7 +789,7 @@ func TestWithFetch(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down Expand Up @@ -809,7 +895,7 @@ func TestWithFetches(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand All @@ -830,6 +916,104 @@ func TestWithFetches(t *testing.T) {
}
}

func TestWithWarmupInputs(t *testing.T) {
type T = tensorflow
type args struct {
warmupInputs []string
}
type fields struct {
warmupInputs []string
}
type want struct {
obj *T
}
type test struct {
name string
args args
want want
fields fields
checkFunc func(want, *T) error
beforeFunc func(args)
afterFunc func(args)
}

defaultCheckFunc := func(w want, obj *T) error {
if !reflect.DeepEqual(obj, w.obj) {
return errors.Errorf("got = %v, want %v", obj, w.obj)
}
return nil
}

tests := []test{
{
name: "set nothing when warmupInputs is nil",
want: want{
obj: new(T),
},
},
{
name: "set success when warmupInputs is not nil and warmupInputs field is not nil",
args: args{
warmupInputs: []string{
"test",
},
},
fields: fields{
warmupInputs: []string{
"test",
},
},
want: want{
obj: &T{
warmupInputs: []string{
"test",
"test",
},
},
},
},
{
name: "set success when warmupInputs is not nil and warmupInputs field is nil",
args: args{
warmupInputs: []string{
"test",
},
},
want: want{
obj: &T{
warmupInputs: []string{
"test",
},
},
},
},
}

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
if test.afterFunc != nil {
defer test.afterFunc(test.args)
}

if test.checkFunc == nil {
test.checkFunc = defaultCheckFunc
}
got := WithWarmupInputs(test.args.warmupInputs...)
obj := &T{
warmupInputs: test.fields.warmupInputs,
}
got(obj)
if err := test.checkFunc(test.want, obj); err != nil {
tt.Errorf("error = %v", err)
}
})
}
}

func TestWithNdim(t *testing.T) {
type T = tensorflow
type args struct {
Expand Down Expand Up @@ -870,7 +1054,7 @@ func TestWithNdim(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
defer goleak.VerifyNone(tt)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
Expand Down
Loading

0 comments on commit 6d215f6

Please sign in to comment.