From 3d57c1406f3f3f8e109a0464c9c0f4caba4811a6 Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Tue, 28 Apr 2020 10:51:34 +0900 Subject: [PATCH 01/13] Add tensorflow test code Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- .../core/converter/tensorflow/tensorflow.go | 17 +- .../tensorflow/tensorflow_mock_test.go | 35 + .../converter/tensorflow/tensorflow_test.go | 903 ++++++++++++------ 3 files changed, 674 insertions(+), 281 deletions(-) create mode 100644 internal/core/converter/tensorflow/tensorflow_mock_test.go diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index 799a1268f6..841b2e97d3 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -29,6 +29,15 @@ type TF interface { GetVector(inputs ...string) ([]float64, error) GetValue(inputs ...string) (interface{}, error) GetValues(inputs ...string) (values []interface{}, err error) + Closer +} + +type session interface { + Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*Operation) ([]*tf.Tensor, error) + Closer +} + +type Closer interface { Close() error } @@ -42,7 +51,7 @@ type tensorflow struct { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } @@ -56,6 +65,10 @@ const ( ThreeDim ) +var loadFunc = func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) { + return tf.LoadSavedModel(exportDir, tags, options) +} + func New(opts ...Option) (TF, error) { t := new(tensorflow) for _, opt := range append(defaultOpts, opts...) { @@ -69,7 +82,7 @@ func New(opts ...Option) (TF, error) { } } - model, err := tf.LoadSavedModel(t.exportDir, t.tags, t.options) + model, err := loadFunc(t.exportDir, t.tags, t.options) if err != nil { return nil, err } diff --git a/internal/core/converter/tensorflow/tensorflow_mock_test.go b/internal/core/converter/tensorflow/tensorflow_mock_test.go new file mode 100644 index 0000000000..b9b6e4bc85 --- /dev/null +++ b/internal/core/converter/tensorflow/tensorflow_mock_test.go @@ -0,0 +1,35 @@ +// +// Copyright (C) 2019-2020 Vdaas.org Vald team ( kpango, rinx, kmrmt ) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package tensorflow provides implementation of Go API for extract data to vector +package tensorflow + +import ( + tf "github.com/tensorflow/tensorflow/tensorflow/go" +) + +type mockSession struct { + RunFunc func(map[tf.Output]*tf.Tensor, []tf.Output, []*Operation) ([]*tf.Tensor, error) + CloseFunc func() error +} + +func (m *mockSession) Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*Operation) ([]*tf.Tensor, error) { + return m.RunFunc(feeds, fetches, operations) +} + +func (m *mockSession) Close() error { + return m.CloseFunc() +} diff --git a/internal/core/converter/tensorflow/tensorflow_test.go b/internal/core/converter/tensorflow/tensorflow_test.go index 58b9f67118..8a5fb7dfad 100644 --- a/internal/core/converter/tensorflow/tensorflow_test.go +++ b/internal/core/converter/tensorflow/tensorflow_test.go @@ -53,31 +53,74 @@ func TestNew(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - opts: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - opts: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return (t, nil): default options", + args: args{ + opts: nil, + }, + want: want{ + want: &tensorflow{ + graph: nil, + session: (&tf.SavedModel{}).Session, + }, + err: nil, + }, + checkFunc: defaultCheckFunc, + beforeFunc: func(args args) { + defaultOpts = []Option{} + loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { + return &tf.SavedModel{}, nil + } + }, + }, + { + name: "return (t, nil): args options", + args: args{ + opts: []Option{ + WithSessionTarget("test"), + WithSessionConfig([]byte{}), + WithNdim(1), + }, + }, + want: want{ + want: &tensorflow{ + sessionTarget: "test", + sessionConfig: []byte{}, + options: &tf.SessionOptions{ + Target: "test", + Config: []byte{}, + }, + graph: nil, + session: (&tf.SavedModel{}).Session, + ndim: 1, + }, + err: nil, + }, + checkFunc: defaultCheckFunc, + beforeFunc: func(args args) { + defaultOpts = []Option{} + loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { + return &tf.SavedModel{}, nil + } + }, + }, + { + name: "return (nil, error)", + args: args{ + nil, + }, + want: want{ + want: nil, + err: errors.New("load error"), + }, + checkFunc: defaultCheckFunc, + beforeFunc: func(args args) { + defaultOpts = []Option{} + loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { + return nil, errors.New("load error") + } + }, + }, } for _, test := range tests { @@ -113,7 +156,7 @@ func Test_tensorflow_Close(t *testing.T) { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } type want struct { @@ -134,51 +177,54 @@ func Test_tensorflow_Close(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return nil", + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + CloseFunc: func() error { + return nil + }, + }, + ndim: 0, + }, + want: want{ + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "return error", + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + CloseFunc: func() error { + return errors.New("fail") + }, + }, + ndim: 0, + }, + want: want{ + err: errors.New("fail"), + }, + checkFunc: defaultCheckFunc, + }, } for _, test := range tests { @@ -230,7 +276,7 @@ func Test_tensorflow_run(t *testing.T) { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } type want struct { @@ -256,57 +302,121 @@ func Test_tensorflow_run(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return ([], nil): inputs=nil", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: []*tf.Tensor{}, + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "return ([], nil): inputs={\"test\"}", + args: args{ + inputs: []string{ + "test", + }, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: []OutputSpec{ + OutputSpec{ + operationName: "test", + outputIndex: 0, + }, + }, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: tf.NewGraph(), + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: []*tf.Tensor{}, + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "length error", + args: args{ + inputs: []string{ + "", + }, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: nil, + ndim: 0, + }, + want: want{ + err: errors.ErrInputLength(1, 0), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "session.Run() error", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: tf.NewGraph(), + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + ndim: 0, + }, + want: want{ + err: errors.New("session.Run() error"), + }, + checkFunc: defaultCheckFunc, + }, } for _, test := range tests { @@ -358,7 +468,7 @@ func Test_tensorflow_GetVector(t *testing.T) { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } type want struct { @@ -384,57 +494,218 @@ func Test_tensorflow_GetVector(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return (vector, nil)", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor([]float64{1, 2, 3}) + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: []float64{1, 2, 3}, + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "run() error", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.New("session.Run() error"), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "nil tensor error: run() return nil", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, nil + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.ErrNilTensorTF([]*tf.Tensor{}), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "nil tensor error: run() return [nil]", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{nil}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.ErrNilTensorTF([]*tf.Tensor{nil}), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "failed to cast error: ndim=TwoDim", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 2, + }, + want: want{ + want: nil, + err: errors.ErrFailedToCastTF("test"), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "failed to cast error: ndim=ThreeDim", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 3, + }, + want: want{ + want: nil, + err: errors.ErrFailedToCastTF("test"), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "failed to cast error: ndim=default", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.ErrFailedToCastTF("test"), + }, + checkFunc: defaultCheckFunc, + }, } for _, test := range tests { @@ -486,7 +757,7 @@ func Test_tensorflow_GetValue(t *testing.T) { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } type want struct { @@ -512,57 +783,122 @@ func Test_tensorflow_GetValue(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return (value, nil)", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: "test", + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "run() error", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.New("session.Run() error"), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "nil tensor error: run() return nil", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, nil + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.ErrNilTensorTF([]*tf.Tensor{}), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "nil tensor error: run() return [nil]", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{nil}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.ErrNilTensorTF([]*tf.Tensor{nil}), + }, + checkFunc: defaultCheckFunc, + }, } for _, test := range tests { @@ -614,7 +950,7 @@ func Test_tensorflow_GetValues(t *testing.T) { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } type want struct { @@ -640,57 +976,66 @@ func Test_tensorflow_GetValues(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return (values, nil)", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor, tensor}, nil + }, + }, + ndim: 0, + }, + want: want{ + wantValues: []interface{}{"test", "test"}, + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "run() error", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + ndim: 0, + }, + want: want{ + wantValues: nil, + err: errors.New("session.Run() error"), + }, + checkFunc: defaultCheckFunc, + }, } for _, test := range tests { From 867bcc3437bf7e97a490a4a3bedd39411f77b30d Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Thu, 14 May 2020 13:46:59 +0900 Subject: [PATCH 02/13] Add tensorflow option test code Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- .../core/converter/tensorflow/option_test.go | 1450 ++++++----------- 1 file changed, 529 insertions(+), 921 deletions(-) diff --git a/internal/core/converter/tensorflow/option_test.go b/internal/core/converter/tensorflow/option_test.go index 3da2be475e..d7a8f1e3c0 100644 --- a/internal/core/converter/tensorflow/option_test.go +++ b/internal/core/converter/tensorflow/option_test.go @@ -18,83 +18,58 @@ package tensorflow import ( + "reflect" "testing" + "github.com/vdaas/vald/internal/errors" "go.uber.org/goleak" ) func TestWithSessionOptions(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { opts *SessionOptions } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - opts: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - opts: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set default", + args: args{ + opts: nil, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "set value", + args: args{ + opts: &SessionOptions{}, + }, + want: want{ + obj: &T{ + options: &SessionOptions{}, + }, + }, + }, } for _, test := range tests { @@ -107,107 +82,64 @@ func TestWithSessionOptions(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithSessionOptions(test.args.opts) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithSessionOptions(test.args.opts) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithSessionOptions(test.args.opts) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithSessionTarget(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { tgt string } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - tgt: "", - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - tgt: "", - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set default", + args: args{ + tgt: "", + }, + want: want{ + obj: new(T), + }, + }, + { + name: "set value", + args: args{ + tgt: "test", + }, + want: want{ + obj: &T{ + sessionTarget: "test", + }, + }, + }, } for _, test := range tests { @@ -220,107 +152,64 @@ func TestWithSessionTarget(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithSessionTarget(test.args.tgt) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithSessionTarget(test.args.tgt) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithSessionTarget(test.args.tgt) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithSessionConfig(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { cfg []byte } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - cfg: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - cfg: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set default", + args: args{ + cfg: nil, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "set value", + args: args{ + cfg: []byte{0}, + }, + want: want{ + obj: &T{ + sessionConfig: []byte{0}, + }, + }, + }, } for _, test := range tests { @@ -333,107 +222,81 @@ func TestWithSessionConfig(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithSessionConfig(test.args.cfg) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithSessionConfig(test.args.cfg) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithSessionConfig(test.args.cfg) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithOperations(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { opes []*Operation } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + obj *T + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - opes: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - opes: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set default", + args: args{ + opes: nil, + }, + want: want{ + obj: new(T), + }, + obj: new(T), + }, + { + name: "set value: tensorflow.operations != nil", + args: args{ + opes: []*Operation{}, + }, + want: want{ + obj: &T{ + operations: []*Operation{}, + }, + }, + obj: &T{ + operations: []*Operation{}, + }, + }, + { + name: "set value: tensorflow.operations == nil", + args: args{ + opes: []*Operation{}, + }, + want: want{ + obj: &T{ + operations: []*Operation{}, + }, + }, + obj: new(T), + }, } for _, test := range tests { @@ -446,107 +309,63 @@ func TestWithOperations(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithOperations(test.args.opes...) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithOperations(test.args.opes...) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithOperations(test.args.opes...) + got(test.obj) + if err := test.checkFunc(test.want, test.obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithExportPath(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { path string } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - path: "", - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - path: "", - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set default", + args: args{ + path: "", + }, + want: want{ + obj: new(T), + }, + }, + { + name: "set value", + args: args{ + path: "test", + }, + want: want{ + obj: &T{ + exportDir: "test", + }, + }, + }, } for _, test := range tests { @@ -559,107 +378,88 @@ func TestWithExportPath(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithExportPath(test.args.path) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithExportPath(test.args.path) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithExportPath(test.args.path) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithTags(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { tags []string } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + obj *T + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - tags: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - tags: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set default", + args: args{ + tags: nil, + }, + want: want{ + obj: new(T), + }, + obj: new(T), + }, + { + name: "set value: tensorflow.tags != nil", + args: args{ + tags: []string{"test"}, + }, + want: want{ + obj: &T{ + tags: []string{ + "test", + "test", + }, + }, + }, + obj: &T{ + tags: []string{ + "test", + }, + }, + }, + { + name: "set value: tensorflow.tags == nil", + args: args{ + tags: []string{"test"}, + }, + want: want{ + obj: &T{ + tags: []string{ + "test", + }, + }, + }, + obj: new(T), + }, } for _, test := range tests { @@ -672,110 +472,61 @@ func TestWithTags(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithTags(test.args.tags...) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithTags(test.args.tags...) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithTags(test.args.tags...) + got(test.obj) + if err := test.checkFunc(test.want, test.obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithFeed(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { operationName string outputIndex int } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - operationName: "", - outputIndex: 0, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - operationName: "", - outputIndex: 0, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set value", + args: args{ + operationName: "test", + outputIndex: 0, + }, + want: want{ + obj: &T{ + feeds: []OutputSpec{ + OutputSpec{ + operationName: "test", + outputIndex: 0, + }, + }, + }, + }, + }, } for _, test := range tests { @@ -788,110 +539,92 @@ func TestWithFeed(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithFeed(test.args.operationName, test.args.outputIndex) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithFeed(test.args.operationName, test.args.outputIndex) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithFeed(test.args.operationName, test.args.outputIndex) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithFeeds(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { operationNames []string outputIndexes []int } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - operationNames: nil, - outputIndexes: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - operationNames: nil, - outputIndexes: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set value", + args: args{ + operationNames: []string{"test"}, + outputIndexes: []int{0}, + }, + want: want{ + obj: &T{ + feeds: []OutputSpec{ + OutputSpec{ + operationName: "test", + outputIndex: 0, + }, + }, + }, + }, + }, + { + name: "operationNames == nil", + args: args{ + operationNames: nil, + outputIndexes: []int{0}, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "outputIndexes == nil", + args: args{ + operationNames: []string{"test"}, + outputIndexes: nil, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "operationName length != outputIndexes length", + args: args{ + operationNames: []string{}, + outputIndexes: []int{0}, + }, + want: want{ + obj: new(T), + }, + }, } for _, test := range tests { @@ -904,110 +637,62 @@ func TestWithFeeds(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithFeeds(test.args.operationNames, test.args.outputIndexes) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithFeeds(test.args.operationNames, test.args.outputIndexes) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithFeeds(test.args.operationNames, test.args.outputIndexes) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithFetch(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { operationName string outputIndex int } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - operationName: "", - outputIndex: 0, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - operationName: "", - outputIndex: 0, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set value", + args: args{ + operationName: "test", + outputIndex: 0, + }, + want: want{ + obj: &T{ + fetches: []OutputSpec{ + OutputSpec{ + operationName: "test", + outputIndex: 0, + }, + }, + }, + }, + }, } for _, test := range tests { @@ -1020,110 +705,92 @@ func TestWithFetch(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithFetch(test.args.operationName, test.args.outputIndex) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithFetch(test.args.operationName, test.args.outputIndex) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithFetch(test.args.operationName, test.args.outputIndex) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithFetches(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { operationNames []string outputIndexes []int } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - operationNames: nil, - outputIndexes: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - operationNames: nil, - outputIndexes: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set value", + args: args{ + operationNames: []string{"test"}, + outputIndexes: []int{0}, + }, + want: want{ + obj: &T{ + fetches: []OutputSpec{ + OutputSpec{ + operationName: "test", + outputIndex: 0, + }, + }, + }, + }, + }, + { + name: "operationName == nil", + args: args{ + operationNames: nil, + outputIndexes: []int{0}, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "outputIndexs == nil", + args: args{ + operationNames: []string{"test"}, + outputIndexes: nil, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "operationNames length != outputIndexs length", + args: args{ + operationNames: []string{}, + outputIndexes: []int{0}, + }, + want: want{ + obj: new(T), + }, + }, } for _, test := range tests { @@ -1136,107 +803,64 @@ func TestWithFetches(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithFetches(test.args.operationNames, test.args.outputIndexes) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithFetches(test.args.operationNames, test.args.outputIndexes) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithFetches(test.args.operationNames, test.args.outputIndexes) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithNdim(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { ndim uint8 } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - ndim: 0, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - ndim: 0, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set defalut", + args: args{ + ndim: 0, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "set value", + args: args{ + ndim: 1, + }, + want: want{ + obj: &T{ + ndim: 1, + }, + }, + }, } for _, test := range tests { @@ -1249,31 +873,15 @@ func TestWithNdim(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithNdim(test.args.ndim) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithNdim(test.args.ndim) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithNdim(test.args.ndim) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } From a0bb13f153368ffb6ffe40f81faef15d3dffecf7 Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Thu, 14 May 2020 13:50:46 +0900 Subject: [PATCH 03/13] fix test name Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- internal/core/converter/tensorflow/tensorflow_test.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/internal/core/converter/tensorflow/tensorflow_test.go b/internal/core/converter/tensorflow/tensorflow_test.go index 8a5fb7dfad..515a6a6dd5 100644 --- a/internal/core/converter/tensorflow/tensorflow_test.go +++ b/internal/core/converter/tensorflow/tensorflow_test.go @@ -23,7 +23,6 @@ import ( tf "github.com/tensorflow/tensorflow/tensorflow/go" "github.com/vdaas/vald/internal/errors" - "go.uber.org/goleak" ) @@ -303,7 +302,7 @@ func Test_tensorflow_run(t *testing.T) { } tests := []test{ { - name: "return ([], nil): inputs=nil", + name: "return ([], nil): inputs == nil", args: args{ inputs: nil, }, @@ -331,7 +330,7 @@ func Test_tensorflow_run(t *testing.T) { checkFunc: defaultCheckFunc, }, { - name: "return ([], nil): inputs={\"test\"}", + name: "return ([], nil): inputs == {\"test\"}", args: args{ inputs: []string{ "test", @@ -611,7 +610,7 @@ func Test_tensorflow_GetVector(t *testing.T) { checkFunc: defaultCheckFunc, }, { - name: "failed to cast error: ndim=TwoDim", + name: "failed to cast error: ndim == TwoDim", args: args{ inputs: nil, }, @@ -643,7 +642,7 @@ func Test_tensorflow_GetVector(t *testing.T) { checkFunc: defaultCheckFunc, }, { - name: "failed to cast error: ndim=ThreeDim", + name: "failed to cast error: ndim == ThreeDim", args: args{ inputs: nil, }, @@ -675,7 +674,7 @@ func Test_tensorflow_GetVector(t *testing.T) { checkFunc: defaultCheckFunc, }, { - name: "failed to cast error: ndim=default", + name: "failed to cast error: ndim == default", args: args{ inputs: nil, }, From f12578a66e5aec44b15a77bff876f1c872a2ed5b Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Mon, 18 May 2020 14:14:41 +0900 Subject: [PATCH 04/13] fix DeepSource issue: Empty string test can be improved Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- internal/core/converter/tensorflow/tensorflow.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index 841b2e97d3..28e34311e0 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -75,7 +75,7 @@ func New(opts ...Option) (TF, error) { opt(t) } - if t.options == nil && (len(t.sessionTarget) != 0 || t.sessionConfig != nil) { + if t.options == nil && (t.sessionTarget == "" || t.sessionConfig != nil) { t.options = &tf.SessionOptions{ Target: t.sessionTarget, Config: t.sessionConfig, From c512b49bf14de356d14c8886fc03b24b8e08ec10 Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Mon, 18 May 2020 14:27:00 +0900 Subject: [PATCH 05/13] fix Deepsource issue: Incomplete condition detected Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- internal/core/converter/tensorflow/tensorflow.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index 28e34311e0..b2d66e2425 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -122,7 +122,7 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) { if err != nil { return nil, err } - if tensors == nil || tensors[0] == nil || tensors[0].Value() == nil { + if len(tensors) == 0 || tensors[0] == nil || tensors[0].Value() == nil { return nil, errors.ErrNilTensorTF(tensors) } @@ -140,7 +140,7 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) { case ThreeDim: value, ok := tensors[0].Value().([][][]float64) if ok { - if value == nil || value[0] == nil { + if len(value) == 0 || value[0] == nil { return nil, errors.ErrNilTensorValueTF(value) } return value[0][0], nil @@ -162,7 +162,7 @@ func (t *tensorflow) GetValue(inputs ...string) (interface{}, error) { if err != nil { return nil, err } - if tensors == nil || tensors[0] == nil { + if len(tensors) == 0 || tensors[0] == nil { return nil, errors.ErrNilTensorTF(tensors) } return tensors[0].Value(), nil From ed7c5f1190eaca84c27d08d7661b225bda65bdad Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Wed, 20 May 2020 15:46:22 +0900 Subject: [PATCH 06/13] remove sessionTarget, sessionConfig Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- internal/core/converter/tensorflow/option.go | 20 ++++++- .../core/converter/tensorflow/option_test.go | 8 ++- .../core/converter/tensorflow/tensorflow.go | 9 --- .../converter/tensorflow/tensorflow_test.go | 60 ------------------- 4 files changed, 23 insertions(+), 74 deletions(-) diff --git a/internal/core/converter/tensorflow/option.go b/internal/core/converter/tensorflow/option.go index a017dec386..2b5df9a1bc 100644 --- a/internal/core/converter/tensorflow/option.go +++ b/internal/core/converter/tensorflow/option.go @@ -29,14 +29,22 @@ var ( func WithSessionOptions(opts *SessionOptions) Option { return func(t *tensorflow) { - t.options = opts + if opts != nil { + t.options = opts + } } } func WithSessionTarget(tgt string) Option { return func(t *tensorflow) { if len(tgt) != 0 { - t.sessionTarget = tgt + if t.options == nil { + t.options = &SessionOptions{ + Target: tgt, + } + } else { + t.options.Target = tgt + } } } } @@ -44,7 +52,13 @@ func WithSessionTarget(tgt string) Option { func WithSessionConfig(cfg []byte) Option { return func(t *tensorflow) { if cfg != nil { - t.sessionConfig = cfg + if t.options == nil { + t.options = &SessionOptions{ + Config: cfg, + } + } else { + t.options.Config = cfg + } } } } diff --git a/internal/core/converter/tensorflow/option_test.go b/internal/core/converter/tensorflow/option_test.go index d7a8f1e3c0..d2ba3ae76b 100644 --- a/internal/core/converter/tensorflow/option_test.go +++ b/internal/core/converter/tensorflow/option_test.go @@ -136,7 +136,9 @@ func TestWithSessionTarget(t *testing.T) { }, want: want{ obj: &T{ - sessionTarget: "test", + options: &SessionOptions{ + Target: "test", + }, }, }, }, @@ -206,7 +208,9 @@ func TestWithSessionConfig(t *testing.T) { }, want: want{ obj: &T{ - sessionConfig: []byte{0}, + options: &SessionOptions{ + Config: []byte{0}, + }, }, }, }, diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index b2d66e2425..bbbad54a02 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -47,8 +47,6 @@ type tensorflow struct { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -75,13 +73,6 @@ func New(opts ...Option) (TF, error) { opt(t) } - if t.options == nil && (t.sessionTarget == "" || t.sessionConfig != nil) { - t.options = &tf.SessionOptions{ - Target: t.sessionTarget, - Config: t.sessionConfig, - } - } - model, err := loadFunc(t.exportDir, t.tags, t.options) if err != nil { return nil, err diff --git a/internal/core/converter/tensorflow/tensorflow_test.go b/internal/core/converter/tensorflow/tensorflow_test.go index 515a6a6dd5..49a4ed99cc 100644 --- a/internal/core/converter/tensorflow/tensorflow_test.go +++ b/internal/core/converter/tensorflow/tensorflow_test.go @@ -83,8 +83,6 @@ func TestNew(t *testing.T) { }, want: want{ want: &tensorflow{ - sessionTarget: "test", - sessionConfig: []byte{}, options: &tf.SessionOptions{ Target: "test", Config: []byte{}, @@ -151,8 +149,6 @@ func Test_tensorflow_Close(t *testing.T) { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -184,8 +180,6 @@ func Test_tensorflow_Close(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -208,8 +202,6 @@ func Test_tensorflow_Close(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -244,8 +236,6 @@ func Test_tensorflow_Close(t *testing.T) { feeds: test.fields.feeds, fetches: test.fields.fetches, operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, options: test.fields.options, graph: test.fields.graph, session: test.fields.session, @@ -271,8 +261,6 @@ func Test_tensorflow_run(t *testing.T) { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -312,8 +300,6 @@ func Test_tensorflow_run(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -347,8 +333,6 @@ func Test_tensorflow_run(t *testing.T) { }, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: tf.NewGraph(), session: &mockSession{ @@ -377,8 +361,6 @@ func Test_tensorflow_run(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: nil, @@ -400,8 +382,6 @@ func Test_tensorflow_run(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: tf.NewGraph(), session: &mockSession{ @@ -436,8 +416,6 @@ func Test_tensorflow_run(t *testing.T) { feeds: test.fields.feeds, fetches: test.fields.fetches, operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, options: test.fields.options, graph: test.fields.graph, session: test.fields.session, @@ -463,8 +441,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -504,8 +480,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -536,8 +510,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -564,8 +536,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -592,8 +562,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -620,8 +588,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -652,8 +618,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -684,8 +648,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -725,8 +687,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: test.fields.feeds, fetches: test.fields.fetches, operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, options: test.fields.options, graph: test.fields.graph, session: test.fields.session, @@ -752,8 +712,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -793,8 +751,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -825,8 +781,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -853,8 +807,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -881,8 +833,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -918,8 +868,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds: test.fields.feeds, fetches: test.fields.fetches, operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, options: test.fields.options, graph: test.fields.graph, session: test.fields.session, @@ -945,8 +893,6 @@ func Test_tensorflow_GetValues(t *testing.T) { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -986,8 +932,6 @@ func Test_tensorflow_GetValues(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -1018,8 +962,6 @@ func Test_tensorflow_GetValues(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -1055,8 +997,6 @@ func Test_tensorflow_GetValues(t *testing.T) { feeds: test.fields.feeds, fetches: test.fields.fetches, operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, options: test.fields.options, graph: test.fields.graph, session: test.fields.session, From ffcb2afebb7c1684bfa69bbfcede66835f7f9e02 Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Wed, 20 May 2020 15:55:47 +0900 Subject: [PATCH 07/13] fix DeepSource issue: Empty string test can be improved Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- internal/core/converter/tensorflow/option.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/core/converter/tensorflow/option.go b/internal/core/converter/tensorflow/option.go index 2b5df9a1bc..32aeb670a1 100644 --- a/internal/core/converter/tensorflow/option.go +++ b/internal/core/converter/tensorflow/option.go @@ -37,7 +37,7 @@ func WithSessionOptions(opts *SessionOptions) Option { func WithSessionTarget(tgt string) Option { return func(t *tensorflow) { - if len(tgt) != 0 { + if tgt != "" { if t.options == nil { t.options = &SessionOptions{ Target: tgt, @@ -77,7 +77,7 @@ func WithOperations(opes ...*Operation) Option { func WithExportPath(path string) Option { return func(t *tensorflow) { - if len(path) != 0 { + if path != "" { t.exportDir = path } } From 5f837f0720a93f729169328da6489726009b8d5e Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Thu, 21 May 2020 13:55:47 +0900 Subject: [PATCH 08/13] fix test case based on review Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- .../core/converter/tensorflow/option_test.go | 181 +++---- .../converter/tensorflow/tensorflow_test.go | 488 +++++------------- 2 files changed, 210 insertions(+), 459 deletions(-) diff --git a/internal/core/converter/tensorflow/option_test.go b/internal/core/converter/tensorflow/option_test.go index d2ba3ae76b..dc9072c675 100644 --- a/internal/core/converter/tensorflow/option_test.go +++ b/internal/core/converter/tensorflow/option_test.go @@ -51,22 +51,19 @@ func TestWithSessionOptions(t *testing.T) { tests := []test{ { - name: "set default", - args: args{ - opts: nil, - }, + name: "set nothing when opts is nil", want: want{ obj: new(T), }, }, { - name: "set value", + name: "set success when opts is not nil", args: args{ - opts: &SessionOptions{}, + opts: new(SessionOptions), }, want: want{ obj: &T{ - options: &SessionOptions{}, + options: new(SessionOptions), }, }, }, @@ -121,16 +118,13 @@ func TestWithSessionTarget(t *testing.T) { tests := []test{ { - name: "set default", - args: args{ - tgt: "", - }, + name: "set nothing when tgt is empty", want: want{ obj: new(T), }, }, { - name: "set value", + name: "set success when tfg is `test`", args: args{ tgt: "test", }, @@ -193,23 +187,20 @@ func TestWithSessionConfig(t *testing.T) { tests := []test{ { - name: "set default", - args: args{ - cfg: nil, - }, + name: "set nothing when cfg is nil", want: want{ obj: new(T), }, }, { - name: "set value", + name: "set success when cfg is []byte{}", args: args{ - cfg: []byte{0}, + cfg: []byte{}, }, want: want{ obj: &T{ options: &SessionOptions{ - Config: []byte{0}, + Config: []byte{}, }, }, }, @@ -244,14 +235,17 @@ func TestWithOperations(t *testing.T) { type args struct { opes []*Operation } + type fields struct { + opes []*Operation + } type want struct { obj *T } type test struct { name string args args + fields fields want want - obj *T checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) @@ -266,31 +260,27 @@ func TestWithOperations(t *testing.T) { tests := []test{ { - name: "set default", - args: args{ - opes: nil, - }, + name: "set nothing when opes is nil", want: want{ obj: new(T), }, - obj: new(T), }, { - name: "set value: tensorflow.operations != nil", + name: "set success when opes is not nil and operations field is not nil", args: args{ opes: []*Operation{}, }, + fields: fields{ + opes: []*Operation{}, + }, want: want{ obj: &T{ operations: []*Operation{}, }, }, - obj: &T{ - operations: []*Operation{}, - }, }, { - name: "set value: tensorflow.operations == nil", + name: "set success when opes is not nil and operations field is nil", args: args{ opes: []*Operation{}, }, @@ -299,7 +289,6 @@ func TestWithOperations(t *testing.T) { operations: []*Operation{}, }, }, - obj: new(T), }, } @@ -317,8 +306,11 @@ func TestWithOperations(t *testing.T) { test.checkFunc = defaultCheckFunc } got := WithOperations(test.args.opes...) - got(test.obj) - if err := test.checkFunc(test.want, test.obj); err != nil { + obj := &T{ + operations: test.fields.opes, + } + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { tt.Errorf("error = %v", err) } }) @@ -351,16 +343,13 @@ func TestWithExportPath(t *testing.T) { tests := []test{ { - name: "set default", - args: args{ - path: "", - }, + name: "set nothing when path is empty", want: want{ obj: new(T), }, }, { - name: "set value", + name: "set success when path is `test`", args: args{ path: "test", }, @@ -400,6 +389,9 @@ func TestWithTags(t *testing.T) { type args struct { tags []string } + type fields struct { + tags []string + } type want struct { obj *T } @@ -407,7 +399,7 @@ func TestWithTags(t *testing.T) { name string args args want want - obj *T + fields fields checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) @@ -422,19 +414,22 @@ func TestWithTags(t *testing.T) { tests := []test{ { - name: "set default", - args: args{ - tags: nil, - }, + name: "set nothing when tags is nil", want: want{ obj: new(T), }, - obj: new(T), }, { - name: "set value: tensorflow.tags != nil", + name: "set success when tags is not nil and tags field is not nil", args: args{ - tags: []string{"test"}, + tags: []string{ + "test", + }, + }, + fields: fields{ + tags: []string{ + "test", + }, }, want: want{ obj: &T{ @@ -444,16 +439,13 @@ func TestWithTags(t *testing.T) { }, }, }, - obj: &T{ - tags: []string{ - "test", - }, - }, }, { - name: "set value: tensorflow.tags == nil", + name: "set success when tags is not nil and tags field is nil", args: args{ - tags: []string{"test"}, + tags: []string{ + "test", + }, }, want: want{ obj: &T{ @@ -462,7 +454,6 @@ func TestWithTags(t *testing.T) { }, }, }, - obj: new(T), }, } @@ -480,8 +471,11 @@ func TestWithTags(t *testing.T) { test.checkFunc = defaultCheckFunc } got := WithTags(test.args.tags...) - got(test.obj) - if err := test.checkFunc(test.want, test.obj); err != nil { + obj := &T{ + tags: test.fields.tags, + } + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { tt.Errorf("error = %v", err) } }) @@ -515,7 +509,7 @@ func TestWithFeed(t *testing.T) { tests := []test{ { - name: "set value", + name: "set success when operationName is `test` and outputIndex is 0", args: args{ operationName: "test", outputIndex: 0, @@ -583,10 +577,14 @@ func TestWithFeeds(t *testing.T) { tests := []test{ { - name: "set value", + name: "set success when operationNames is []string{`test`} and outputIndexes is []int{0}", args: args{ - operationNames: []string{"test"}, - outputIndexes: []int{0}, + operationNames: []string{ + "test", + }, + outputIndexes: []int{ + 0, + }, }, want: want{ obj: &T{ @@ -600,30 +598,34 @@ func TestWithFeeds(t *testing.T) { }, }, { - name: "operationNames == nil", + name: "set nothing when operationNames is nil", args: args{ - operationNames: nil, - outputIndexes: []int{0}, + outputIndexes: []int{ + 0, + }, }, want: want{ obj: new(T), }, }, { - name: "outputIndexes == nil", + name: "set nothing when outputIndexes is nil", args: args{ - operationNames: []string{"test"}, - outputIndexes: nil, + operationNames: []string{ + "test", + }, }, want: want{ obj: new(T), }, }, { - name: "operationName length != outputIndexes length", + name: "set nothing when length of operationName and outputIndexes are different", args: args{ operationNames: []string{}, - outputIndexes: []int{0}, + outputIndexes: []int{ + 0, + }, }, want: want{ obj: new(T), @@ -681,7 +683,7 @@ func TestWithFetch(t *testing.T) { tests := []test{ { - name: "set value", + name: "set success when operationName is `test` and outputIndex is 0", args: args{ operationName: "test", outputIndex: 0, @@ -749,10 +751,14 @@ func TestWithFetches(t *testing.T) { tests := []test{ { - name: "set value", + name: "set success when operationNames is []string{`test`} and outputIndexes is []int{0}", args: args{ - operationNames: []string{"test"}, - outputIndexes: []int{0}, + operationNames: []string{ + "test", + }, + outputIndexes: []int{ + 0, + }, }, want: want{ obj: &T{ @@ -766,30 +772,34 @@ func TestWithFetches(t *testing.T) { }, }, { - name: "operationName == nil", + name: "set nothing when operationNames is nil", args: args{ - operationNames: nil, - outputIndexes: []int{0}, + outputIndexes: []int{ + 0, + }, }, want: want{ obj: new(T), }, }, { - name: "outputIndexs == nil", + name: "set nothing when outputIndexs is nil", args: args{ - operationNames: []string{"test"}, - outputIndexes: nil, + operationNames: []string{ + "test", + }, }, want: want{ obj: new(T), }, }, { - name: "operationNames length != outputIndexs length", + name: "set nothing when length of operationNames and outputIndexs are different", args: args{ operationNames: []string{}, - outputIndexes: []int{0}, + outputIndexes: []int{ + 0, + }, }, want: want{ obj: new(T), @@ -846,16 +856,7 @@ func TestWithNdim(t *testing.T) { tests := []test{ { - name: "set defalut", - args: args{ - ndim: 0, - }, - want: want{ - obj: new(T), - }, - }, - { - name: "set value", + name: "set success when ndim is 1", args: args{ ndim: 1, }, diff --git a/internal/core/converter/tensorflow/tensorflow_test.go b/internal/core/converter/tensorflow/tensorflow_test.go index 49a4ed99cc..3a3e62bc41 100644 --- a/internal/core/converter/tensorflow/tensorflow_test.go +++ b/internal/core/converter/tensorflow/tensorflow_test.go @@ -53,18 +53,12 @@ func TestNew(t *testing.T) { } tests := []test{ { - name: "return (t, nil): default options", - args: args{ - opts: nil, - }, + name: "returns (t, nil) when opts is nil", want: want{ want: &tensorflow{ - graph: nil, session: (&tf.SavedModel{}).Session, }, - err: nil, }, - checkFunc: defaultCheckFunc, beforeFunc: func(args args) { defaultOpts = []Option{} loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { @@ -73,7 +67,7 @@ func TestNew(t *testing.T) { }, }, { - name: "return (t, nil): args options", + name: "returns (t, nil) when args is not nil", args: args{ opts: []Option{ WithSessionTarget("test"), @@ -91,9 +85,7 @@ func TestNew(t *testing.T) { session: (&tf.SavedModel{}).Session, ndim: 1, }, - err: nil, }, - checkFunc: defaultCheckFunc, beforeFunc: func(args args) { defaultOpts = []Option{} loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { @@ -102,15 +94,10 @@ func TestNew(t *testing.T) { }, }, { - name: "return (nil, error)", - args: args{ - nil, - }, + name: "returns (nil, error) when loadFunc function returns error", want: want{ - want: nil, - err: errors.New("load error"), + err: errors.New("load error"), }, - checkFunc: defaultCheckFunc, beforeFunc: func(args args) { defaultOpts = []Option{} loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { @@ -144,15 +131,7 @@ func TestNew(t *testing.T) { func Test_tensorflow_Close(t *testing.T) { type fields struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - options *SessionOptions - graph *tf.Graph - session session - ndim uint8 + session session } type want struct { err error @@ -175,46 +154,25 @@ func Test_tensorflow_Close(t *testing.T) { { name: "return nil", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ CloseFunc: func() error { return nil }, }, - ndim: 0, }, - want: want{ - err: nil, - }, - checkFunc: defaultCheckFunc, }, { name: "return error", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ CloseFunc: func() error { return errors.New("fail") }, }, - ndim: 0, }, want: want{ err: errors.New("fail"), }, - checkFunc: defaultCheckFunc, }, } @@ -231,15 +189,7 @@ func Test_tensorflow_Close(t *testing.T) { test.checkFunc = defaultCheckFunc } t := &tensorflow{ - exportDir: test.fields.exportDir, - tags: test.fields.tags, - feeds: test.fields.feeds, - fetches: test.fields.fetches, - operations: test.fields.operations, - options: test.fields.options, - graph: test.fields.graph, - session: test.fields.session, - ndim: test.fields.ndim, + session: test.fields.session, } err := t.Close() @@ -256,15 +206,9 @@ func Test_tensorflow_run(t *testing.T) { inputs []string } type fields struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - options *SessionOptions - graph *tf.Graph - session session - ndim uint8 + feeds []OutputSpec + graph *tf.Graph + session session } type want struct { want []*tf.Tensor @@ -290,111 +234,67 @@ func Test_tensorflow_run(t *testing.T) { } tests := []test{ { - name: "return ([], nil): inputs == nil", - args: args{ - inputs: nil, - }, + name: "returns ([], nil) when inputs is nil", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { return []*tf.Tensor{}, nil }, }, - ndim: 0, }, want: want{ want: []*tf.Tensor{}, - err: nil, }, - checkFunc: defaultCheckFunc, }, { - name: "return ([], nil): inputs == {\"test\"}", + name: "returns ([], nil) when inputs is []string{`test`}", args: args{ inputs: []string{ "test", }, }, fields: fields{ - exportDir: "", - tags: nil, feeds: []OutputSpec{ OutputSpec{ operationName: "test", outputIndex: 0, }, }, - fetches: nil, - operations: nil, - options: nil, - graph: tf.NewGraph(), + graph: tf.NewGraph(), session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { return []*tf.Tensor{}, nil }, }, - ndim: 0, }, want: want{ want: []*tf.Tensor{}, - err: nil, }, - checkFunc: defaultCheckFunc, }, { - name: "length error", + name: "returns (nil, error) when length of inputs and feeds field are different", args: args{ inputs: []string{ "", }, }, - fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, want: want{ err: errors.ErrInputLength(1, 0), }, - checkFunc: defaultCheckFunc, }, { - name: "session.Run() error", - args: args{ - inputs: nil, - }, + name: "returns (nil, error) when Run function returns (nil, error)", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: tf.NewGraph(), + graph: tf.NewGraph(), session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { return nil, errors.New("session.Run() error") }, }, - ndim: 0, }, want: want{ err: errors.New("session.Run() error"), }, - checkFunc: defaultCheckFunc, }, } @@ -411,15 +311,9 @@ func Test_tensorflow_run(t *testing.T) { test.checkFunc = defaultCheckFunc } t := &tensorflow{ - exportDir: test.fields.exportDir, - tags: test.fields.tags, - feeds: test.fields.feeds, - fetches: test.fields.fetches, - operations: test.fields.operations, - options: test.fields.options, - graph: test.fields.graph, - session: test.fields.session, - ndim: test.fields.ndim, + feeds: test.fields.feeds, + graph: test.fields.graph, + session: test.fields.session, } got, err := t.run(test.args.inputs...) @@ -436,15 +330,8 @@ func Test_tensorflow_GetVector(t *testing.T) { inputs []string } type fields struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - options *SessionOptions - graph *tf.Graph - session session - ndim uint8 + session session + ndim uint8 } type want struct { want []float64 @@ -470,126 +357,130 @@ func Test_tensorflow_GetVector(t *testing.T) { } tests := []test{ { - name: "return (vector, nil)", - args: args{ - inputs: nil, + name: "returns (vector, nil) when run function returns (tensors, nil) and ndim is default", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor([]float64{ + 1, + 2, + 3, + }) + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, }, + want: want{ + want: []float64{ + 1, + 2, + 3, + }, + }, + }, + { + name: "returns (vector, nil) when run function returns (tensors, nil) and ndim is 2", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { - tensor, err := tf.NewTensor([]float64{1, 2, 3}) + tensor, err := tf.NewTensor([][]float64{ + []float64{ + 1, + 2, + 3, + }, + }) if err != nil { return nil, errors.New("NewTensor error") } return []*tf.Tensor{tensor}, nil }, }, - ndim: 0, + ndim: 2, }, want: want{ - want: []float64{1, 2, 3}, - err: nil, + want: []float64{ + 1, + 2, + 3, + }, }, - checkFunc: defaultCheckFunc, }, { - name: "run() error", - args: args{ - inputs: nil, + name: "returns (vector, nil) when run function returns (tensors, nil) and ndim is 3", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor([][][]float64{ + [][]float64{ + []float64{ + 1, + 2, + 3, + }, + }, + }) + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 3, + }, + want: want{ + want: []float64{ + 1, + 2, + 3, + }, }, + }, + { + name: "returns (nil, error) when run function returns (nil, error)", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { return nil, errors.New("session.Run() error") }, }, - ndim: 0, }, want: want{ - want: nil, - err: errors.New("session.Run() error"), + err: errors.New("session.Run() error"), }, - checkFunc: defaultCheckFunc, }, { - name: "nil tensor error: run() return nil", - args: args{ - inputs: nil, - }, + name: "returns (nil, error) when tensors returned by the run funcion is nil", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { return nil, nil }, }, - ndim: 0, }, want: want{ - want: nil, - err: errors.ErrNilTensorTF([]*tf.Tensor{}), + err: errors.ErrNilTensorTF([]*tf.Tensor{}), }, - checkFunc: defaultCheckFunc, }, { - name: "nil tensor error: run() return [nil]", - args: args{ - inputs: nil, - }, + name: "returns (nil, error) when element of tensors returned by the run funcion is nil", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { return []*tf.Tensor{nil}, nil }, }, - ndim: 0, }, want: want{ - want: nil, - err: errors.ErrNilTensorTF([]*tf.Tensor{nil}), + err: errors.ErrNilTensorTF([]*tf.Tensor{nil}), }, - checkFunc: defaultCheckFunc, }, { - name: "failed to cast error: ndim == TwoDim", - args: args{ - inputs: nil, - }, + name: "returns (nil, error) when ndim is `TwoDim` and returns error of `ErrFailedToCastTF`", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { tensor, err := tf.NewTensor("test") @@ -602,24 +493,12 @@ func Test_tensorflow_GetVector(t *testing.T) { ndim: 2, }, want: want{ - want: nil, - err: errors.ErrFailedToCastTF("test"), + err: errors.ErrFailedToCastTF("test"), }, - checkFunc: defaultCheckFunc, }, { - name: "failed to cast error: ndim == ThreeDim", - args: args{ - inputs: nil, - }, + name: "returns (nil, error) when ndim is `ThreeDim` and returns error of `ErrFailedToCastTF`", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { tensor, err := tf.NewTensor("test") @@ -632,24 +511,12 @@ func Test_tensorflow_GetVector(t *testing.T) { ndim: 3, }, want: want{ - want: nil, - err: errors.ErrFailedToCastTF("test"), + err: errors.ErrFailedToCastTF("test"), }, - checkFunc: defaultCheckFunc, }, { - name: "failed to cast error: ndim == default", - args: args{ - inputs: nil, - }, + name: "returns (nil, error) when ndim is `default` and returns error of `ErrFailedToCastTF`", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { tensor, err := tf.NewTensor("test") @@ -659,13 +526,10 @@ func Test_tensorflow_GetVector(t *testing.T) { return []*tf.Tensor{tensor}, nil }, }, - ndim: 0, }, want: want{ - want: nil, - err: errors.ErrFailedToCastTF("test"), + err: errors.ErrFailedToCastTF("test"), }, - checkFunc: defaultCheckFunc, }, } @@ -682,15 +546,8 @@ func Test_tensorflow_GetVector(t *testing.T) { test.checkFunc = defaultCheckFunc } t := &tensorflow{ - exportDir: test.fields.exportDir, - tags: test.fields.tags, - feeds: test.fields.feeds, - fetches: test.fields.fetches, - operations: test.fields.operations, - options: test.fields.options, - graph: test.fields.graph, - session: test.fields.session, - ndim: test.fields.ndim, + session: test.fields.session, + ndim: test.fields.ndim, } got, err := t.GetVector(test.args.inputs...) @@ -707,15 +564,7 @@ func Test_tensorflow_GetValue(t *testing.T) { inputs []string } type fields struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - options *SessionOptions - graph *tf.Graph - session session - ndim uint8 + session session } type want struct { want interface{} @@ -741,18 +590,8 @@ func Test_tensorflow_GetValue(t *testing.T) { } tests := []test{ { - name: "return (value, nil)", - args: args{ - inputs: nil, - }, + name: "returns (value, nil) when run function returns (tensors, nil)", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { tensor, err := tf.NewTensor("test") @@ -762,91 +601,49 @@ func Test_tensorflow_GetValue(t *testing.T) { return []*tf.Tensor{tensor}, nil }, }, - ndim: 0, }, want: want{ want: "test", - err: nil, }, - checkFunc: defaultCheckFunc, }, { - name: "run() error", - args: args{ - inputs: nil, - }, + name: "returns (nil, error) when run function returns (nil, error)", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { return nil, errors.New("session.Run() error") }, }, - ndim: 0, }, want: want{ - want: nil, - err: errors.New("session.Run() error"), + err: errors.New("session.Run() error"), }, - checkFunc: defaultCheckFunc, }, { - name: "nil tensor error: run() return nil", - args: args{ - inputs: nil, - }, + name: "returns (nil, error) when tensors returned by the run funcion is nil", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { return nil, nil }, }, - ndim: 0, }, want: want{ - want: nil, - err: errors.ErrNilTensorTF([]*tf.Tensor{}), + err: errors.ErrNilTensorTF([]*tf.Tensor{}), }, - checkFunc: defaultCheckFunc, }, { - name: "nil tensor error: run() return [nil]", - args: args{ - inputs: nil, - }, + name: "returns (nil, error) when element of tensors returned by the run funcion is nil", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { return []*tf.Tensor{nil}, nil }, }, - ndim: 0, }, want: want{ - want: nil, - err: errors.ErrNilTensorTF([]*tf.Tensor{nil}), + err: errors.ErrNilTensorTF([]*tf.Tensor{nil}), }, - checkFunc: defaultCheckFunc, }, } @@ -863,15 +660,7 @@ func Test_tensorflow_GetValue(t *testing.T) { test.checkFunc = defaultCheckFunc } t := &tensorflow{ - exportDir: test.fields.exportDir, - tags: test.fields.tags, - feeds: test.fields.feeds, - fetches: test.fields.fetches, - operations: test.fields.operations, - options: test.fields.options, - graph: test.fields.graph, - session: test.fields.session, - ndim: test.fields.ndim, + session: test.fields.session, } got, err := t.GetValue(test.args.inputs...) @@ -888,15 +677,7 @@ func Test_tensorflow_GetValues(t *testing.T) { inputs []string } type fields struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - options *SessionOptions - graph *tf.Graph - session session - ndim uint8 + session session } type want struct { wantValues []interface{} @@ -922,18 +703,8 @@ func Test_tensorflow_GetValues(t *testing.T) { } tests := []test{ { - name: "return (values, nil)", - args: args{ - inputs: nil, - }, + name: "return (values, nil) when run function returns (tensors, nil)", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { tensor, err := tf.NewTensor("test") @@ -943,39 +714,26 @@ func Test_tensorflow_GetValues(t *testing.T) { return []*tf.Tensor{tensor, tensor}, nil }, }, - ndim: 0, }, want: want{ - wantValues: []interface{}{"test", "test"}, - err: nil, + wantValues: []interface{}{ + "test", + "test", + }, }, - checkFunc: defaultCheckFunc, }, { - name: "run() error", - args: args{ - inputs: nil, - }, + name: "returns (nil, error) when run function returns (nil, error)", fields: fields{ - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - options: nil, - graph: nil, session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { return nil, errors.New("session.Run() error") }, }, - ndim: 0, }, want: want{ - wantValues: nil, - err: errors.New("session.Run() error"), + err: errors.New("session.Run() error"), }, - checkFunc: defaultCheckFunc, }, } @@ -992,15 +750,7 @@ func Test_tensorflow_GetValues(t *testing.T) { test.checkFunc = defaultCheckFunc } t := &tensorflow{ - exportDir: test.fields.exportDir, - tags: test.fields.tags, - feeds: test.fields.feeds, - fetches: test.fields.fetches, - operations: test.fields.operations, - options: test.fields.options, - graph: test.fields.graph, - session: test.fields.session, - ndim: test.fields.ndim, + session: test.fields.session, } gotValues, err := t.GetValues(test.args.inputs...) From ae114521c3fc51d7e4c17b0f55446aa198d1f832 Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Mon, 25 May 2020 13:18:48 +0900 Subject: [PATCH 09/13] gofmt Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- .../core/converter/tensorflow/option_test.go | 8 ++++---- .../core/converter/tensorflow/tensorflow.go | 18 +++++++++--------- .../converter/tensorflow/tensorflow_test.go | 8 ++++---- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/internal/core/converter/tensorflow/option_test.go b/internal/core/converter/tensorflow/option_test.go index dc9072c675..0c522faf4e 100644 --- a/internal/core/converter/tensorflow/option_test.go +++ b/internal/core/converter/tensorflow/option_test.go @@ -517,7 +517,7 @@ func TestWithFeed(t *testing.T) { want: want{ obj: &T{ feeds: []OutputSpec{ - OutputSpec{ + { operationName: "test", outputIndex: 0, }, @@ -589,7 +589,7 @@ func TestWithFeeds(t *testing.T) { want: want{ obj: &T{ feeds: []OutputSpec{ - OutputSpec{ + { operationName: "test", outputIndex: 0, }, @@ -691,7 +691,7 @@ func TestWithFetch(t *testing.T) { want: want{ obj: &T{ fetches: []OutputSpec{ - OutputSpec{ + { operationName: "test", outputIndex: 0, }, @@ -763,7 +763,7 @@ func TestWithFetches(t *testing.T) { want: want{ obj: &T{ fetches: []OutputSpec{ - OutputSpec{ + { operationName: "test", outputIndex: 0, }, diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index bbbad54a02..728cc2e1dc 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -42,15 +42,15 @@ type Closer interface { } type tensorflow struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - options *SessionOptions - graph *tf.Graph - session session - ndim uint8 + exportDir string + tags []string + feeds []OutputSpec + fetches []OutputSpec + operations []*Operation + options *SessionOptions + graph *tf.Graph + session session + ndim uint8 } type OutputSpec struct { diff --git a/internal/core/converter/tensorflow/tensorflow_test.go b/internal/core/converter/tensorflow/tensorflow_test.go index 3a3e62bc41..f2f52b6cfb 100644 --- a/internal/core/converter/tensorflow/tensorflow_test.go +++ b/internal/core/converter/tensorflow/tensorflow_test.go @@ -255,7 +255,7 @@ func Test_tensorflow_run(t *testing.T) { }, fields: fields{ feeds: []OutputSpec{ - OutputSpec{ + { operationName: "test", outputIndex: 0, }, @@ -387,7 +387,7 @@ func Test_tensorflow_GetVector(t *testing.T) { session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { tensor, err := tf.NewTensor([][]float64{ - []float64{ + { 1, 2, 3, @@ -415,8 +415,8 @@ func Test_tensorflow_GetVector(t *testing.T) { session: &mockSession{ RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { tensor, err := tf.NewTensor([][][]float64{ - [][]float64{ - []float64{ + { + { 1, 2, 3, From 5ab7953fae21069c3d2a08c0cab8fad298c6aaca Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Mon, 25 May 2020 14:32:48 +0900 Subject: [PATCH 10/13] fix golangci-lint issue Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- .../core/converter/tensorflow/tensorflow.go | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index 728cc2e1dc..6a41145e67 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -69,6 +69,7 @@ var loadFunc = func(exportDir string, tags []string, options *SessionOptions) (* func New(opts ...Option) (TF, error) { t := new(tensorflow) + for _, opt := range append(defaultOpts, opts...) { opt(t) } @@ -77,8 +78,10 @@ func New(opts ...Option) (TF, error) { if err != nil { return nil, err } + t.graph = model.Graph t.session = model.Session + return t, nil } @@ -92,11 +95,13 @@ func (t *tensorflow) run(inputs ...string) ([]*tf.Tensor, error) { } feeds := make(map[tf.Output]*tf.Tensor, len(inputs)) + for i, val := range inputs { inputTensor, err := tf.NewTensor(val) if err != nil { return nil, err } + feeds[t.graph.Operation(t.feeds[i].operationName).Output(t.feeds[i].outputIndex)] = inputTensor } @@ -113,6 +118,7 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) { if err != nil { return nil, err } + if len(tensors) == 0 || tensors[0] == nil || tensors[0].Value() == nil { return nil, errors.ErrNilTensorTF(tensors) } @@ -124,27 +130,29 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) { if value == nil { return nil, errors.ErrNilTensorValueTF(value) } + return value[0], nil - } else { - return nil, errors.ErrFailedToCastTF(tensors[0].Value()) } + + return nil, errors.ErrFailedToCastTF(tensors[0].Value()) case ThreeDim: value, ok := tensors[0].Value().([][][]float64) if ok { if len(value) == 0 || value[0] == nil { return nil, errors.ErrNilTensorValueTF(value) } + return value[0][0], nil - } else { - return nil, errors.ErrFailedToCastTF(tensors[0].Value()) } + + return nil, errors.ErrFailedToCastTF(tensors[0].Value()) default: value, ok := tensors[0].Value().([]float64) if ok { return value, nil - } else { - return nil, errors.ErrFailedToCastTF(tensors[0].Value()) } + + return nil, errors.ErrFailedToCastTF(tensors[0].Value()) } } @@ -153,9 +161,11 @@ func (t *tensorflow) GetValue(inputs ...string) (interface{}, error) { if err != nil { return nil, err } + if len(tensors) == 0 || tensors[0] == nil { return nil, errors.ErrNilTensorTF(tensors) } + return tensors[0].Value(), nil } @@ -164,9 +174,11 @@ func (t *tensorflow) GetValues(inputs ...string) (values []interface{}, err erro if err != nil { return nil, err } + values = make([]interface{}, 0, len(tensors)) for _, tensor := range tensors { values = append(values, tensor.Value()) } + return values, nil } From bb536a89e0aa50419c189784828b0354da2bf766 Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Mon, 25 May 2020 16:30:47 +0900 Subject: [PATCH 11/13] fix golangci-lint issue Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- internal/core/converter/tensorflow/option.go | 12 ++++++++++++ internal/core/converter/tensorflow/tensorflow.go | 15 +++++++++++---- .../core/converter/tensorflow/tensorflow_test.go | 6 ------ 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/internal/core/converter/tensorflow/option.go b/internal/core/converter/tensorflow/option.go index 32aeb670a1..a8156252e5 100644 --- a/internal/core/converter/tensorflow/option.go +++ b/internal/core/converter/tensorflow/option.go @@ -17,6 +17,7 @@ // Package tensorflow provides implementation of Go API for extract data to vector package tensorflow +// Option is tensorflow configure. type Option func(*tensorflow) var ( @@ -27,6 +28,7 @@ var ( } ) +// WithSessionOptions returns Option that sets options. func WithSessionOptions(opts *SessionOptions) Option { return func(t *tensorflow) { if opts != nil { @@ -35,6 +37,7 @@ func WithSessionOptions(opts *SessionOptions) Option { } } +// WithSessionTarget returns Option that sets target. func WithSessionTarget(tgt string) Option { return func(t *tensorflow) { if tgt != "" { @@ -49,6 +52,7 @@ func WithSessionTarget(tgt string) Option { } } +// WithSessionConfig returns Option that sets config. func WithSessionConfig(cfg []byte) Option { return func(t *tensorflow) { if cfg != nil { @@ -63,6 +67,7 @@ func WithSessionConfig(cfg []byte) Option { } } +// WithOperations returns Option that sets operations. func WithOperations(opes ...*Operation) Option { return func(t *tensorflow) { if opes != nil { @@ -75,6 +80,7 @@ func WithOperations(opes ...*Operation) Option { } } +// WithExportPath returns Option that sets exportDir. func WithExportPath(path string) Option { return func(t *tensorflow) { if path != "" { @@ -83,6 +89,7 @@ func WithExportPath(path string) Option { } } +// WithTags returns Option that sets tags. func WithTags(tags ...string) Option { return func(t *tensorflow) { if tags != nil { @@ -95,12 +102,14 @@ func WithTags(tags ...string) Option { } } +// WithFeed returns Option that sets feeds. func WithFeed(operationName string, outputIndex int) Option { return func(t *tensorflow) { t.feeds = append(t.feeds, OutputSpec{operationName, outputIndex}) } } +// WithFeeds returns Option that sets feeds. func WithFeeds(operationNames []string, outputIndexes []int) Option { return func(t *tensorflow) { if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) { @@ -111,12 +120,14 @@ func WithFeeds(operationNames []string, outputIndexes []int) Option { } } +// WithFetch returns Option that sets fetches. func WithFetch(operationName string, outputIndex int) Option { return func(t *tensorflow) { t.fetches = append(t.fetches, OutputSpec{operationName, outputIndex}) } } +// WithFetches returns Option that sets fetches. func WithFetches(operationNames []string, outputIndexes []int) Option { return func(t *tensorflow) { if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) { @@ -127,6 +138,7 @@ func WithFetches(operationNames []string, outputIndexes []int) Option { } } +// WithNdim returns Option that sets ndim. func WithNdim(ndim uint8) Option { return func(t *tensorflow) { t.ndim = ndim diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index 6a41145e67..a9a7fee5ed 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -22,9 +22,13 @@ import ( "github.com/vdaas/vald/internal/errors" ) +// SessionOptions is a type alias for tensorflow.SessionOptions. type SessionOptions = tf.SessionOptions + +// Operation is a type alias for tensorflow.Operation. type Operation = tf.Operation +// TF represents a tensorflow interface. type TF interface { GetVector(inputs ...string) ([]float64, error) GetValue(inputs ...string) (interface{}, error) @@ -37,6 +41,7 @@ type session interface { Closer } +// Closer close a tensorflow.Session. type Closer interface { Close() error } @@ -53,20 +58,22 @@ type tensorflow struct { ndim uint8 } +// OutputSpec is the specification of an feed/fetch. type OutputSpec struct { operationName string outputIndex int } const ( - TwoDim uint8 = iota + 2 - ThreeDim + twoDim uint8 = iota + 2 + threeDim ) var loadFunc = func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) { return tf.LoadSavedModel(exportDir, tags, options) } +// New load a tensorlfow model and returns a new tensorflow struct. func New(opts ...Option) (TF, error) { t := new(tensorflow) @@ -124,7 +131,7 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) { } switch t.ndim { - case TwoDim: + case twoDim: value, ok := tensors[0].Value().([][]float64) if ok { if value == nil { @@ -135,7 +142,7 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) { } return nil, errors.ErrFailedToCastTF(tensors[0].Value()) - case ThreeDim: + case threeDim: value, ok := tensors[0].Value().([][][]float64) if ok { if len(value) == 0 || value[0] == nil { diff --git a/internal/core/converter/tensorflow/tensorflow_test.go b/internal/core/converter/tensorflow/tensorflow_test.go index f2f52b6cfb..8954b1b687 100644 --- a/internal/core/converter/tensorflow/tensorflow_test.go +++ b/internal/core/converter/tensorflow/tensorflow_test.go @@ -124,7 +124,6 @@ func TestNew(t *testing.T) { if err := test.checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } - }) } } @@ -196,7 +195,6 @@ func Test_tensorflow_Close(t *testing.T) { if err := test.checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } - }) } } @@ -320,7 +318,6 @@ func Test_tensorflow_run(t *testing.T) { if err := test.checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } - }) } } @@ -554,7 +551,6 @@ func Test_tensorflow_GetVector(t *testing.T) { if err := test.checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } - }) } } @@ -667,7 +663,6 @@ func Test_tensorflow_GetValue(t *testing.T) { if err := test.checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } - }) } } @@ -757,7 +752,6 @@ func Test_tensorflow_GetValues(t *testing.T) { if err := test.checkFunc(test.want, gotValues, err); err != nil { tt.Errorf("error = %v", err) } - }) } } From 0df8eb0b5eb331f28812f635a666a912b774b33f Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Wed, 3 Jun 2020 13:20:57 +0900 Subject: [PATCH 12/13] delete Closer interface Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- internal/core/converter/tensorflow/tensorflow.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index a9a7fee5ed..dbee2c356c 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -18,6 +18,8 @@ package tensorflow import ( + "io" + tf "github.com/tensorflow/tensorflow/tensorflow/go" "github.com/vdaas/vald/internal/errors" ) @@ -28,6 +30,9 @@ type SessionOptions = tf.SessionOptions // Operation is a type alias for tensorflow.Operation. type Operation = tf.Operation +// Closer is a type alias io.Closer +type Closer = io.Closer + // TF represents a tensorflow interface. type TF interface { GetVector(inputs ...string) ([]float64, error) @@ -41,11 +46,6 @@ type session interface { Closer } -// Closer close a tensorflow.Session. -type Closer interface { - Close() error -} - type tensorflow struct { exportDir string tags []string From 057e0bd1e476dc249e39a39c8f7b0e7611941b29 Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Wed, 3 Jun 2020 14:11:34 +0900 Subject: [PATCH 13/13] fix DeepSource issue: Function literal can be simplified Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- internal/core/converter/tensorflow/tensorflow.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index dbee2c356c..6f90c12b7a 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -69,9 +69,7 @@ const ( threeDim ) -var loadFunc = func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) { - return tf.LoadSavedModel(exportDir, tags, options) -} +var loadFunc = tf.LoadSavedModel // New load a tensorlfow model and returns a new tensorflow struct. func New(opts ...Option) (TF, error) {