diff --git a/internal/core/converter/tensorflow/option.go b/internal/core/converter/tensorflow/option.go index a017dec386..a8156252e5 100644 --- a/internal/core/converter/tensorflow/option.go +++ b/internal/core/converter/tensorflow/option.go @@ -17,6 +17,7 @@ // Package tensorflow provides implementation of Go API for extract data to vector package tensorflow +// Option is tensorflow configure. type Option func(*tensorflow) var ( @@ -27,28 +28,46 @@ var ( } ) +// WithSessionOptions returns Option that sets options. func WithSessionOptions(opts *SessionOptions) Option { return func(t *tensorflow) { - t.options = opts + if opts != nil { + t.options = opts + } } } +// WithSessionTarget returns Option that sets target. func WithSessionTarget(tgt string) Option { return func(t *tensorflow) { - if len(tgt) != 0 { - t.sessionTarget = tgt + if tgt != "" { + if t.options == nil { + t.options = &SessionOptions{ + Target: tgt, + } + } else { + t.options.Target = tgt + } } } } +// WithSessionConfig returns Option that sets config. func WithSessionConfig(cfg []byte) Option { return func(t *tensorflow) { if cfg != nil { - t.sessionConfig = cfg + if t.options == nil { + t.options = &SessionOptions{ + Config: cfg, + } + } else { + t.options.Config = cfg + } } } } +// WithOperations returns Option that sets operations. func WithOperations(opes ...*Operation) Option { return func(t *tensorflow) { if opes != nil { @@ -61,14 +80,16 @@ func WithOperations(opes ...*Operation) Option { } } +// WithExportPath returns Option that sets exportDir. func WithExportPath(path string) Option { return func(t *tensorflow) { - if len(path) != 0 { + if path != "" { t.exportDir = path } } } +// WithTags returns Option that sets tags. func WithTags(tags ...string) Option { return func(t *tensorflow) { if tags != nil { @@ -81,12 +102,14 @@ func WithTags(tags ...string) Option { } } +// WithFeed returns Option that sets feeds. func WithFeed(operationName string, outputIndex int) Option { return func(t *tensorflow) { t.feeds = append(t.feeds, OutputSpec{operationName, outputIndex}) } } +// WithFeeds returns Option that sets feeds. func WithFeeds(operationNames []string, outputIndexes []int) Option { return func(t *tensorflow) { if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) { @@ -97,12 +120,14 @@ func WithFeeds(operationNames []string, outputIndexes []int) Option { } } +// WithFetch returns Option that sets fetches. func WithFetch(operationName string, outputIndex int) Option { return func(t *tensorflow) { t.fetches = append(t.fetches, OutputSpec{operationName, outputIndex}) } } +// WithFetches returns Option that sets fetches. func WithFetches(operationNames []string, outputIndexes []int) Option { return func(t *tensorflow) { if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) { @@ -113,6 +138,7 @@ func WithFetches(operationNames []string, outputIndexes []int) Option { } } +// WithNdim returns Option that sets ndim. func WithNdim(ndim uint8) Option { return func(t *tensorflow) { t.ndim = ndim diff --git a/internal/core/converter/tensorflow/option_test.go b/internal/core/converter/tensorflow/option_test.go index 3da2be475e..0c522faf4e 100644 --- a/internal/core/converter/tensorflow/option_test.go +++ b/internal/core/converter/tensorflow/option_test.go @@ -18,83 +18,55 @@ package tensorflow import ( + "reflect" "testing" + "github.com/vdaas/vald/internal/errors" "go.uber.org/goleak" ) func TestWithSessionOptions(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { opts *SessionOptions } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - opts: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - opts: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set nothing when opts is nil", + want: want{ + obj: new(T), + }, + }, + { + name: "set success when opts is not nil", + args: args{ + opts: new(SessionOptions), + }, + want: want{ + obj: &T{ + options: new(SessionOptions), + }, + }, + }, } for _, test := range tests { @@ -107,107 +79,63 @@ func TestWithSessionOptions(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithSessionOptions(test.args.opts) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithSessionOptions(test.args.opts) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithSessionOptions(test.args.opts) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithSessionTarget(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { tgt string } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - tgt: "", - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - tgt: "", - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set nothing when tgt is empty", + want: want{ + obj: new(T), + }, + }, + { + name: "set success when tfg is `test`", + args: args{ + tgt: "test", + }, + want: want{ + obj: &T{ + options: &SessionOptions{ + Target: "test", + }, + }, + }, + }, } for _, test := range tests { @@ -220,107 +148,63 @@ func TestWithSessionTarget(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithSessionTarget(test.args.tgt) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithSessionTarget(test.args.tgt) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithSessionTarget(test.args.tgt) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithSessionConfig(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { cfg []byte } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - cfg: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - cfg: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set nothing when cfg is nil", + want: want{ + obj: new(T), + }, + }, + { + name: "set success when cfg is []byte{}", + args: args{ + cfg: []byte{}, + }, + want: want{ + obj: &T{ + options: &SessionOptions{ + Config: []byte{}, + }, + }, + }, + }, } for _, test := range tests { @@ -333,107 +217,79 @@ func TestWithSessionConfig(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithSessionConfig(test.args.cfg) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithSessionConfig(test.args.cfg) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithSessionConfig(test.args.cfg) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithOperations(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { opes []*Operation } + type fields struct { + opes []*Operation + } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + fields fields + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - opes: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - opes: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set nothing when opes is nil", + want: want{ + obj: new(T), + }, + }, + { + name: "set success when opes is not nil and operations field is not nil", + args: args{ + opes: []*Operation{}, + }, + fields: fields{ + opes: []*Operation{}, + }, + want: want{ + obj: &T{ + operations: []*Operation{}, + }, + }, + }, + { + name: "set success when opes is not nil and operations field is nil", + args: args{ + opes: []*Operation{}, + }, + want: want{ + obj: &T{ + operations: []*Operation{}, + }, + }, + }, } for _, test := range tests { @@ -446,107 +302,63 @@ func TestWithOperations(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithOperations(test.args.opes...) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithOperations(test.args.opes...) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithOperations(test.args.opes...) + obj := &T{ + operations: test.fields.opes, + } + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithExportPath(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { path string } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - path: "", - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - path: "", - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set nothing when path is empty", + want: want{ + obj: new(T), + }, + }, + { + name: "set success when path is `test`", + args: args{ + path: "test", + }, + want: want{ + obj: &T{ + exportDir: "test", + }, + }, + }, } for _, test := range tests { @@ -559,107 +371,90 @@ func TestWithExportPath(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithExportPath(test.args.path) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithExportPath(test.args.path) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithExportPath(test.args.path) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithTags(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { tags []string } + type fields struct { + tags []string + } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + fields fields + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - tags: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - tags: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set nothing when tags is nil", + want: want{ + obj: new(T), + }, + }, + { + name: "set success when tags is not nil and tags field is not nil", + args: args{ + tags: []string{ + "test", + }, + }, + fields: fields{ + tags: []string{ + "test", + }, + }, + want: want{ + obj: &T{ + tags: []string{ + "test", + "test", + }, + }, + }, + }, + { + name: "set success when tags is not nil and tags field is nil", + args: args{ + tags: []string{ + "test", + }, + }, + want: want{ + obj: &T{ + tags: []string{ + "test", + }, + }, + }, + }, } for _, test := range tests { @@ -672,110 +467,64 @@ func TestWithTags(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithTags(test.args.tags...) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithTags(test.args.tags...) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithTags(test.args.tags...) + obj := &T{ + tags: test.fields.tags, + } + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithFeed(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { operationName string outputIndex int } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - operationName: "", - outputIndex: 0, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - operationName: "", - outputIndex: 0, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set success when operationName is `test` and outputIndex is 0", + args: args{ + operationName: "test", + outputIndex: 0, + }, + want: want{ + obj: &T{ + feeds: []OutputSpec{ + { + operationName: "test", + outputIndex: 0, + }, + }, + }, + }, + }, } for _, test := range tests { @@ -788,110 +537,100 @@ func TestWithFeed(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithFeed(test.args.operationName, test.args.outputIndex) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithFeed(test.args.operationName, test.args.outputIndex) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithFeed(test.args.operationName, test.args.outputIndex) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithFeeds(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { operationNames []string outputIndexes []int } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - operationNames: nil, - outputIndexes: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - operationNames: nil, - outputIndexes: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set success when operationNames is []string{`test`} and outputIndexes is []int{0}", + args: args{ + operationNames: []string{ + "test", + }, + outputIndexes: []int{ + 0, + }, + }, + want: want{ + obj: &T{ + feeds: []OutputSpec{ + { + operationName: "test", + outputIndex: 0, + }, + }, + }, + }, + }, + { + name: "set nothing when operationNames is nil", + args: args{ + outputIndexes: []int{ + 0, + }, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "set nothing when outputIndexes is nil", + args: args{ + operationNames: []string{ + "test", + }, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "set nothing when length of operationName and outputIndexes are different", + args: args{ + operationNames: []string{}, + outputIndexes: []int{ + 0, + }, + }, + want: want{ + obj: new(T), + }, + }, } for _, test := range tests { @@ -904,110 +643,62 @@ func TestWithFeeds(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithFeeds(test.args.operationNames, test.args.outputIndexes) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithFeeds(test.args.operationNames, test.args.outputIndexes) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithFeeds(test.args.operationNames, test.args.outputIndexes) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithFetch(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { operationName string outputIndex int } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - operationName: "", - outputIndex: 0, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - operationName: "", - outputIndex: 0, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set success when operationName is `test` and outputIndex is 0", + args: args{ + operationName: "test", + outputIndex: 0, + }, + want: want{ + obj: &T{ + fetches: []OutputSpec{ + { + operationName: "test", + outputIndex: 0, + }, + }, + }, + }, + }, } for _, test := range tests { @@ -1020,110 +711,100 @@ func TestWithFetch(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithFetch(test.args.operationName, test.args.outputIndex) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithFetch(test.args.operationName, test.args.outputIndex) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithFetch(test.args.operationName, test.args.outputIndex) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithFetches(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { operationNames []string outputIndexes []int } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - operationNames: nil, - outputIndexes: nil, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - operationNames: nil, - outputIndexes: nil, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set success when operationNames is []string{`test`} and outputIndexes is []int{0}", + args: args{ + operationNames: []string{ + "test", + }, + outputIndexes: []int{ + 0, + }, + }, + want: want{ + obj: &T{ + fetches: []OutputSpec{ + { + operationName: "test", + outputIndex: 0, + }, + }, + }, + }, + }, + { + name: "set nothing when operationNames is nil", + args: args{ + outputIndexes: []int{ + 0, + }, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "set nothing when outputIndexs is nil", + args: args{ + operationNames: []string{ + "test", + }, + }, + want: want{ + obj: new(T), + }, + }, + { + name: "set nothing when length of operationNames and outputIndexs are different", + args: args{ + operationNames: []string{}, + outputIndexes: []int{ + 0, + }, + }, + want: want{ + obj: new(T), + }, + }, } for _, test := range tests { @@ -1136,107 +817,55 @@ func TestWithFetches(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithFetches(test.args.operationNames, test.args.outputIndexes) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithFetches(test.args.operationNames, test.args.outputIndexes) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithFetches(test.args.operationNames, test.args.outputIndexes) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } func TestWithNdim(t *testing.T) { - type T = interface{} + type T = tensorflow type args struct { ndim uint8 } type want struct { obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error } type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error + name string + args args + want want + checkFunc func(want, *T) error beforeFunc func(args) afterFunc func(args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.c) - } - return nil - } - */ + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - ndim: 0, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - ndim: 0, - }, - want: want { - obj: new(T), - }, - } - }(), - */ + { + name: "set success when ndim is 1", + args: args{ + ndim: 1, + }, + want: want{ + obj: &T{ + ndim: 1, + }, + }, + }, } for _, test := range tests { @@ -1249,31 +878,15 @@ func TestWithNdim(t *testing.T) { defer test.afterFunc(test.args) } - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithNdim(test.args.ndim) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithNdim(test.args.ndim) - obj := new(T) - got(obj) - if err := test.checkFunc(tt.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithNdim(test.args.ndim) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } }) } } diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index 799a1268f6..6f90c12b7a 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -18,63 +18,75 @@ package tensorflow import ( + "io" + tf "github.com/tensorflow/tensorflow/tensorflow/go" "github.com/vdaas/vald/internal/errors" ) +// SessionOptions is a type alias for tensorflow.SessionOptions. type SessionOptions = tf.SessionOptions + +// Operation is a type alias for tensorflow.Operation. type Operation = tf.Operation +// Closer is a type alias io.Closer +type Closer = io.Closer + +// TF represents a tensorflow interface. type TF interface { GetVector(inputs ...string) ([]float64, error) GetValue(inputs ...string) (interface{}, error) GetValues(inputs ...string) (values []interface{}, err error) - Close() error + Closer +} + +type session interface { + Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*Operation) ([]*tf.Tensor, error) + Closer } type tensorflow struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - sessionTarget string - sessionConfig []byte - options *SessionOptions - graph *tf.Graph - session *tf.Session - ndim uint8 + exportDir string + tags []string + feeds []OutputSpec + fetches []OutputSpec + operations []*Operation + options *SessionOptions + graph *tf.Graph + session session + ndim uint8 } +// OutputSpec is the specification of an feed/fetch. type OutputSpec struct { operationName string outputIndex int } const ( - TwoDim uint8 = iota + 2 - ThreeDim + twoDim uint8 = iota + 2 + threeDim ) +var loadFunc = tf.LoadSavedModel + +// New load a tensorlfow model and returns a new tensorflow struct. func New(opts ...Option) (TF, error) { t := new(tensorflow) + for _, opt := range append(defaultOpts, opts...) { opt(t) } - if t.options == nil && (len(t.sessionTarget) != 0 || t.sessionConfig != nil) { - t.options = &tf.SessionOptions{ - Target: t.sessionTarget, - Config: t.sessionConfig, - } - } - - model, err := tf.LoadSavedModel(t.exportDir, t.tags, t.options) + model, err := loadFunc(t.exportDir, t.tags, t.options) if err != nil { return nil, err } + t.graph = model.Graph t.session = model.Session + return t, nil } @@ -88,11 +100,13 @@ func (t *tensorflow) run(inputs ...string) ([]*tf.Tensor, error) { } feeds := make(map[tf.Output]*tf.Tensor, len(inputs)) + for i, val := range inputs { inputTensor, err := tf.NewTensor(val) if err != nil { return nil, err } + feeds[t.graph.Operation(t.feeds[i].operationName).Output(t.feeds[i].outputIndex)] = inputTensor } @@ -109,38 +123,41 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) { if err != nil { return nil, err } - if tensors == nil || tensors[0] == nil || tensors[0].Value() == nil { + + if len(tensors) == 0 || tensors[0] == nil || tensors[0].Value() == nil { return nil, errors.ErrNilTensorTF(tensors) } switch t.ndim { - case TwoDim: + case twoDim: value, ok := tensors[0].Value().([][]float64) if ok { if value == nil { return nil, errors.ErrNilTensorValueTF(value) } + return value[0], nil - } else { - return nil, errors.ErrFailedToCastTF(tensors[0].Value()) } - case ThreeDim: + + return nil, errors.ErrFailedToCastTF(tensors[0].Value()) + case threeDim: value, ok := tensors[0].Value().([][][]float64) if ok { - if value == nil || value[0] == nil { + if len(value) == 0 || value[0] == nil { return nil, errors.ErrNilTensorValueTF(value) } + return value[0][0], nil - } else { - return nil, errors.ErrFailedToCastTF(tensors[0].Value()) } + + return nil, errors.ErrFailedToCastTF(tensors[0].Value()) default: value, ok := tensors[0].Value().([]float64) if ok { return value, nil - } else { - return nil, errors.ErrFailedToCastTF(tensors[0].Value()) } + + return nil, errors.ErrFailedToCastTF(tensors[0].Value()) } } @@ -149,9 +166,11 @@ func (t *tensorflow) GetValue(inputs ...string) (interface{}, error) { if err != nil { return nil, err } - if tensors == nil || tensors[0] == nil { + + if len(tensors) == 0 || tensors[0] == nil { return nil, errors.ErrNilTensorTF(tensors) } + return tensors[0].Value(), nil } @@ -160,9 +179,11 @@ func (t *tensorflow) GetValues(inputs ...string) (values []interface{}, err erro if err != nil { return nil, err } + values = make([]interface{}, 0, len(tensors)) for _, tensor := range tensors { values = append(values, tensor.Value()) } + return values, nil } diff --git a/internal/core/converter/tensorflow/tensorflow_mock_test.go b/internal/core/converter/tensorflow/tensorflow_mock_test.go new file mode 100644 index 0000000000..b9b6e4bc85 --- /dev/null +++ b/internal/core/converter/tensorflow/tensorflow_mock_test.go @@ -0,0 +1,35 @@ +// +// Copyright (C) 2019-2020 Vdaas.org Vald team ( kpango, rinx, kmrmt ) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package tensorflow provides implementation of Go API for extract data to vector +package tensorflow + +import ( + tf "github.com/tensorflow/tensorflow/tensorflow/go" +) + +type mockSession struct { + RunFunc func(map[tf.Output]*tf.Tensor, []tf.Output, []*Operation) ([]*tf.Tensor, error) + CloseFunc func() error +} + +func (m *mockSession) Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*Operation) ([]*tf.Tensor, error) { + return m.RunFunc(feeds, fetches, operations) +} + +func (m *mockSession) Close() error { + return m.CloseFunc() +} diff --git a/internal/core/converter/tensorflow/tensorflow_test.go b/internal/core/converter/tensorflow/tensorflow_test.go index 58b9f67118..8954b1b687 100644 --- a/internal/core/converter/tensorflow/tensorflow_test.go +++ b/internal/core/converter/tensorflow/tensorflow_test.go @@ -23,7 +23,6 @@ import ( tf "github.com/tensorflow/tensorflow/tensorflow/go" "github.com/vdaas/vald/internal/errors" - "go.uber.org/goleak" ) @@ -53,31 +52,59 @@ func TestNew(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - opts: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - opts: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "returns (t, nil) when opts is nil", + want: want{ + want: &tensorflow{ + session: (&tf.SavedModel{}).Session, + }, + }, + beforeFunc: func(args args) { + defaultOpts = []Option{} + loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { + return &tf.SavedModel{}, nil + } + }, + }, + { + name: "returns (t, nil) when args is not nil", + args: args{ + opts: []Option{ + WithSessionTarget("test"), + WithSessionConfig([]byte{}), + WithNdim(1), + }, + }, + want: want{ + want: &tensorflow{ + options: &tf.SessionOptions{ + Target: "test", + Config: []byte{}, + }, + graph: nil, + session: (&tf.SavedModel{}).Session, + ndim: 1, + }, + }, + beforeFunc: func(args args) { + defaultOpts = []Option{} + loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { + return &tf.SavedModel{}, nil + } + }, + }, + { + name: "returns (nil, error) when loadFunc function returns error", + want: want{ + err: errors.New("load error"), + }, + beforeFunc: func(args args) { + defaultOpts = []Option{} + loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { + return nil, errors.New("load error") + } + }, + }, } for _, test := range tests { @@ -97,24 +124,13 @@ func TestNew(t *testing.T) { if err := test.checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } - }) } } func Test_tensorflow_Close(t *testing.T) { type fields struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - sessionTarget string - sessionConfig []byte - options *SessionOptions - graph *tf.Graph - session *tf.Session - ndim uint8 + session session } type want struct { err error @@ -134,51 +150,29 @@ func Test_tensorflow_Close(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return nil", + fields: fields{ + session: &mockSession{ + CloseFunc: func() error { + return nil + }, + }, + }, + }, + { + name: "return error", + fields: fields{ + session: &mockSession{ + CloseFunc: func() error { + return errors.New("fail") + }, + }, + }, + want: want{ + err: errors.New("fail"), + }, + }, } for _, test := range tests { @@ -194,24 +188,13 @@ func Test_tensorflow_Close(t *testing.T) { test.checkFunc = defaultCheckFunc } t := &tensorflow{ - exportDir: test.fields.exportDir, - tags: test.fields.tags, - feeds: test.fields.feeds, - fetches: test.fields.fetches, - operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, - options: test.fields.options, - graph: test.fields.graph, - session: test.fields.session, - ndim: test.fields.ndim, + session: test.fields.session, } err := t.Close() if err := test.checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } - }) } } @@ -221,17 +204,9 @@ func Test_tensorflow_run(t *testing.T) { inputs []string } type fields struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - sessionTarget string - sessionConfig []byte - options *SessionOptions - graph *tf.Graph - session *tf.Session - ndim uint8 + feeds []OutputSpec + graph *tf.Graph + session session } type want struct { want []*tf.Tensor @@ -256,57 +231,69 @@ func Test_tensorflow_run(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "returns ([], nil) when inputs is nil", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{}, nil + }, + }, + }, + want: want{ + want: []*tf.Tensor{}, + }, + }, + { + name: "returns ([], nil) when inputs is []string{`test`}", + args: args{ + inputs: []string{ + "test", + }, + }, + fields: fields{ + feeds: []OutputSpec{ + { + operationName: "test", + outputIndex: 0, + }, + }, + graph: tf.NewGraph(), + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{}, nil + }, + }, + }, + want: want{ + want: []*tf.Tensor{}, + }, + }, + { + name: "returns (nil, error) when length of inputs and feeds field are different", + args: args{ + inputs: []string{ + "", + }, + }, + want: want{ + err: errors.ErrInputLength(1, 0), + }, + }, + { + name: "returns (nil, error) when Run function returns (nil, error)", + fields: fields{ + graph: tf.NewGraph(), + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + }, + want: want{ + err: errors.New("session.Run() error"), + }, + }, } for _, test := range tests { @@ -322,24 +309,15 @@ func Test_tensorflow_run(t *testing.T) { test.checkFunc = defaultCheckFunc } t := &tensorflow{ - exportDir: test.fields.exportDir, - tags: test.fields.tags, - feeds: test.fields.feeds, - fetches: test.fields.fetches, - operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, - options: test.fields.options, - graph: test.fields.graph, - session: test.fields.session, - ndim: test.fields.ndim, + feeds: test.fields.feeds, + graph: test.fields.graph, + session: test.fields.session, } got, err := t.run(test.args.inputs...) if err := test.checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } - }) } } @@ -349,17 +327,8 @@ func Test_tensorflow_GetVector(t *testing.T) { inputs []string } type fields struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - sessionTarget string - sessionConfig []byte - options *SessionOptions - graph *tf.Graph - session *tf.Session - ndim uint8 + session session + ndim uint8 } type want struct { want []float64 @@ -384,57 +353,181 @@ func Test_tensorflow_GetVector(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "returns (vector, nil) when run function returns (tensors, nil) and ndim is default", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor([]float64{ + 1, + 2, + 3, + }) + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + }, + want: want{ + want: []float64{ + 1, + 2, + 3, + }, + }, + }, + { + name: "returns (vector, nil) when run function returns (tensors, nil) and ndim is 2", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor([][]float64{ + { + 1, + 2, + 3, + }, + }) + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 2, + }, + want: want{ + want: []float64{ + 1, + 2, + 3, + }, + }, + }, + { + name: "returns (vector, nil) when run function returns (tensors, nil) and ndim is 3", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor([][][]float64{ + { + { + 1, + 2, + 3, + }, + }, + }) + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 3, + }, + want: want{ + want: []float64{ + 1, + 2, + 3, + }, + }, + }, + { + name: "returns (nil, error) when run function returns (nil, error)", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + }, + want: want{ + err: errors.New("session.Run() error"), + }, + }, + { + name: "returns (nil, error) when tensors returned by the run funcion is nil", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, nil + }, + }, + }, + want: want{ + err: errors.ErrNilTensorTF([]*tf.Tensor{}), + }, + }, + { + name: "returns (nil, error) when element of tensors returned by the run funcion is nil", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{nil}, nil + }, + }, + }, + want: want{ + err: errors.ErrNilTensorTF([]*tf.Tensor{nil}), + }, + }, + { + name: "returns (nil, error) when ndim is `TwoDim` and returns error of `ErrFailedToCastTF`", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 2, + }, + want: want{ + err: errors.ErrFailedToCastTF("test"), + }, + }, + { + name: "returns (nil, error) when ndim is `ThreeDim` and returns error of `ErrFailedToCastTF`", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 3, + }, + want: want{ + err: errors.ErrFailedToCastTF("test"), + }, + }, + { + name: "returns (nil, error) when ndim is `default` and returns error of `ErrFailedToCastTF`", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + }, + want: want{ + err: errors.ErrFailedToCastTF("test"), + }, + }, } for _, test := range tests { @@ -450,24 +543,14 @@ func Test_tensorflow_GetVector(t *testing.T) { test.checkFunc = defaultCheckFunc } t := &tensorflow{ - exportDir: test.fields.exportDir, - tags: test.fields.tags, - feeds: test.fields.feeds, - fetches: test.fields.fetches, - operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, - options: test.fields.options, - graph: test.fields.graph, - session: test.fields.session, - ndim: test.fields.ndim, + session: test.fields.session, + ndim: test.fields.ndim, } got, err := t.GetVector(test.args.inputs...) if err := test.checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } - }) } } @@ -477,17 +560,7 @@ func Test_tensorflow_GetValue(t *testing.T) { inputs []string } type fields struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - sessionTarget string - sessionConfig []byte - options *SessionOptions - graph *tf.Graph - session *tf.Session - ndim uint8 + session session } type want struct { want interface{} @@ -512,57 +585,62 @@ func Test_tensorflow_GetValue(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "returns (value, nil) when run function returns (tensors, nil)", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + }, + want: want{ + want: "test", + }, + }, + { + name: "returns (nil, error) when run function returns (nil, error)", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + }, + want: want{ + err: errors.New("session.Run() error"), + }, + }, + { + name: "returns (nil, error) when tensors returned by the run funcion is nil", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, nil + }, + }, + }, + want: want{ + err: errors.ErrNilTensorTF([]*tf.Tensor{}), + }, + }, + { + name: "returns (nil, error) when element of tensors returned by the run funcion is nil", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{nil}, nil + }, + }, + }, + want: want{ + err: errors.ErrNilTensorTF([]*tf.Tensor{nil}), + }, + }, } for _, test := range tests { @@ -578,24 +656,13 @@ func Test_tensorflow_GetValue(t *testing.T) { test.checkFunc = defaultCheckFunc } t := &tensorflow{ - exportDir: test.fields.exportDir, - tags: test.fields.tags, - feeds: test.fields.feeds, - fetches: test.fields.fetches, - operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, - options: test.fields.options, - graph: test.fields.graph, - session: test.fields.session, - ndim: test.fields.ndim, + session: test.fields.session, } got, err := t.GetValue(test.args.inputs...) if err := test.checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } - }) } } @@ -605,17 +672,7 @@ func Test_tensorflow_GetValues(t *testing.T) { inputs []string } type fields struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - sessionTarget string - sessionConfig []byte - options *SessionOptions - graph *tf.Graph - session *tf.Session - ndim uint8 + session session } type want struct { wantValues []interface{} @@ -640,57 +697,39 @@ func Test_tensorflow_GetValues(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return (values, nil) when run function returns (tensors, nil)", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor, tensor}, nil + }, + }, + }, + want: want{ + wantValues: []interface{}{ + "test", + "test", + }, + }, + }, + { + name: "returns (nil, error) when run function returns (nil, error)", + fields: fields{ + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + }, + want: want{ + err: errors.New("session.Run() error"), + }, + }, } for _, test := range tests { @@ -706,24 +745,13 @@ func Test_tensorflow_GetValues(t *testing.T) { test.checkFunc = defaultCheckFunc } t := &tensorflow{ - exportDir: test.fields.exportDir, - tags: test.fields.tags, - feeds: test.fields.feeds, - fetches: test.fields.fetches, - operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, - options: test.fields.options, - graph: test.fields.graph, - session: test.fields.session, - ndim: test.fields.ndim, + session: test.fields.session, } gotValues, err := t.GetValues(test.args.inputs...) if err := test.checkFunc(test.want, gotValues, err); err != nil { tt.Errorf("error = %v", err) } - }) } }