Skip to content

Commit

Permalink
tensorflow savedmodel warmup
Browse files Browse the repository at this point in the history
Signed-off-by: datelier <[email protected]>
  • Loading branch information
datelier authored and actions-user committed Jul 6, 2020
1 parent 94f6f1f commit 68beb90
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 20 deletions.
13 changes: 13 additions & 0 deletions internal/core/converter/tensorflow/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,19 @@ func WithFetches(operationNames []string, outputIndexes []int) Option {
}
}

// WithWarmupInputs returns Option that sets warmupInputs.
func WithWarmupInputs(warmupInputs ...string) Option {
return func(t *tensorflow) {
if warmupInputs != nil {
if t.warmupInputs != nil {
t.warmupInputs = append(t.warmupInputs, warmupInputs...)
} else {
t.warmupInputs = warmupInputs
}
}
}
}

// WithNdim returns Option that sets ndim.
func WithNdim(ndim uint8) Option {
return func(t *tensorflow) {
Expand Down
98 changes: 98 additions & 0 deletions internal/core/converter/tensorflow/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,104 @@ func TestWithFetches(t *testing.T) {
}
}

func TestWithWarmupInputs(t *testing.T) {
type T = tensorflow
type args struct {
warmupInputs []string
}
type fields struct {
warmupInputs []string
}
type want struct {
obj *T
}
type test struct {
name string
args args
want want
fields fields
checkFunc func(want, *T) error
beforeFunc func(args)
afterFunc func(args)
}

defaultCheckFunc := func(w want, obj *T) error {
if !reflect.DeepEqual(obj, w.obj) {
return errors.Errorf("got = %v, want %v", obj, w.obj)
}
return nil
}

tests := []test{
{
name: "set nothing when warmupInputs is nil",
want: want{
obj: new(T),
},
},
{
name: "set success when warmupInputs is not nil and warmupInputs field is not nil",
args: args{
warmupInputs: []string{
"test",
},
},
fields: fields{
warmupInputs: []string{
"test",
},
},
want: want{
obj: &T{
warmupInputs: []string{
"test",
"test",
},
},
},
},
{
name: "set success when warmupInputs is not nil and warmupInputs field is nil",
args: args{
warmupInputs: []string{
"test",
},
},
want: want{
obj: &T{
warmupInputs: []string{
"test",
},
},
},
},
}

for _, test := range tests {
t.Run(test.name, func(tt *testing.T) {
defer goleak.VerifyNone(t)
if test.beforeFunc != nil {
test.beforeFunc(test.args)
}
if test.afterFunc != nil {
defer test.afterFunc(test.args)
}

if test.checkFunc == nil {
test.checkFunc = defaultCheckFunc
}
got := WithWarmupInputs(test.args.warmupInputs...)
obj := &T{
warmupInputs: test.fields.warmupInputs,
}
got(obj)
if err := test.checkFunc(test.want, obj); err != nil {
tt.Errorf("error = %v", err)
}
})
}
}

func TestWithNdim(t *testing.T) {
type T = tensorflow
type args struct {
Expand Down
43 changes: 29 additions & 14 deletions internal/core/converter/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type SessionOptions = tf.SessionOptions
// Operation is a type alias for tensorflow.Operation.
type Operation = tf.Operation

// Closer is a type alias io.Closer
// Closer is a type alias io.Closer.
type Closer = io.Closer

// TF represents a tensorflow interface.
Expand All @@ -47,15 +47,16 @@ type session interface {
}

type tensorflow struct {
exportDir string
tags []string
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
options *SessionOptions
graph *tf.Graph
session session
ndim uint8
exportDir string
tags []string
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
options *SessionOptions
graph *tf.Graph
session session
warmupInputs []string
ndim uint8
}

// OutputSpec is the specification of an feed/fetch.
Expand All @@ -69,7 +70,17 @@ const (
threeDim
)

var loadFunc = tf.LoadSavedModel
var loadFunc = func(t *tensorflow) error {
model, err := tf.LoadSavedModel(t.exportDir, t.tags, t.options)
if err != nil {
return err
}

t.graph = model.Graph
t.session = model.Session

return nil
}

// New load a tensorlfow model and returns a new tensorflow struct.
func New(opts ...Option) (TF, error) {
Expand All @@ -79,13 +90,17 @@ func New(opts ...Option) (TF, error) {
opt(t)
}

model, err := loadFunc(t.exportDir, t.tags, t.options)
err := loadFunc(t)
if err != nil {
return nil, err
}

t.graph = model.Graph
t.session = model.Session
if t.warmupInputs != nil {
_, err := t.run(t.warmupInputs...)
if err != nil {
return nil, err
}
}

return t, nil
}
Expand Down
86 changes: 80 additions & 6 deletions internal/core/converter/tensorflow/tensorflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ func TestNew(t *testing.T) {
beforeFunc func(args)
afterFunc func(args)
}
graph := tf.NewGraph()
session := &mockSession{
RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) {
return []*tf.Tensor{}, nil
},
}
defaultCheckFunc := func(w want, got TF, err error) error {
if !errors.Is(err, w.err) {
return errors.Errorf("got error = %v, want %v", err, w.err)
Expand All @@ -61,8 +67,10 @@ func TestNew(t *testing.T) {
},
beforeFunc: func(args args) {
defaultOpts = []Option{}
loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) {
return &tf.SavedModel{}, nil
loadFunc = func(t *tensorflow) error {
t.graph = nil
t.session = (&tf.SavedModel{}).Session
return nil
}
},
},
Expand All @@ -88,8 +96,57 @@ func TestNew(t *testing.T) {
},
beforeFunc: func(args args) {
defaultOpts = []Option{}
loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) {
return &tf.SavedModel{}, nil
loadFunc = func(t *tensorflow) error {
t.graph = nil
t.session = (&tf.SavedModel{}).Session
return nil
}
},
},
{
name: "returns (t, nil) when args and warmupInputs are not nil",
args: args{
opts: []Option{
WithFeed("test", 0),
WithFetch("test", 0),
WithSessionTarget("test"),
WithSessionConfig([]byte{}),
WithWarmupInputs("test"),
WithNdim(1),
},
},
want: want{
want: &tensorflow{
feeds: []OutputSpec{
{
operationName: "test",
outputIndex: 0,
},
},
fetches: []OutputSpec{
{
operationName: "test",
outputIndex: 0,
},
},
options: &tf.SessionOptions{
Target: "test",
Config: []byte{},
},
graph: graph,
session: session,
warmupInputs: []string{
"test",
},
ndim: 1,
},
},
beforeFunc: func(args args) {
defaultOpts = []Option{}
loadFunc = func(t *tensorflow) error {
t.graph = graph
t.session = session
return nil
}
},
},
Expand All @@ -100,8 +157,25 @@ func TestNew(t *testing.T) {
},
beforeFunc: func(args args) {
defaultOpts = []Option{}
loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) {
return nil, errors.New("load error")
loadFunc = func(t *tensorflow) error {
return errors.New("load error")
}
},
},
{
name: "returns (nil, error) when warmup error",
args: args{
opts: []Option{
WithWarmupInputs("test"),
},
},
want: want{
err: errors.ErrInputLength(1, 0),
},
beforeFunc: func(args args) {
defaultOpts = []Option{}
loadFunc = func(t *tensorflow) error {
return nil
}
},
},
Expand Down

0 comments on commit 68beb90

Please sign in to comment.