From 044658c8c85ab63398df8fd7f5f62fb003fabf4f Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Wed, 20 May 2020 15:46:22 +0900 Subject: [PATCH] remove sessionTarget, sessionConfig Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> --- internal/core/converter/tensorflow/option.go | 20 ++++++- .../core/converter/tensorflow/option_test.go | 8 ++- .../core/converter/tensorflow/tensorflow.go | 9 --- .../converter/tensorflow/tensorflow_test.go | 60 ------------------- 4 files changed, 23 insertions(+), 74 deletions(-) diff --git a/internal/core/converter/tensorflow/option.go b/internal/core/converter/tensorflow/option.go index a017dec3868..2b5df9a1bc8 100644 --- a/internal/core/converter/tensorflow/option.go +++ b/internal/core/converter/tensorflow/option.go @@ -29,14 +29,22 @@ var ( func WithSessionOptions(opts *SessionOptions) Option { return func(t *tensorflow) { - t.options = opts + if opts != nil { + t.options = opts + } } } func WithSessionTarget(tgt string) Option { return func(t *tensorflow) { if len(tgt) != 0 { - t.sessionTarget = tgt + if t.options == nil { + t.options = &SessionOptions{ + Target: tgt, + } + } else { + t.options.Target = tgt + } } } } @@ -44,7 +52,13 @@ func WithSessionTarget(tgt string) Option { func WithSessionConfig(cfg []byte) Option { return func(t *tensorflow) { if cfg != nil { - t.sessionConfig = cfg + if t.options == nil { + t.options = &SessionOptions{ + Config: cfg, + } + } else { + t.options.Config = cfg + } } } } diff --git a/internal/core/converter/tensorflow/option_test.go b/internal/core/converter/tensorflow/option_test.go index d7a8f1e3c07..d2ba3ae76b8 100644 --- a/internal/core/converter/tensorflow/option_test.go +++ b/internal/core/converter/tensorflow/option_test.go @@ -136,7 +136,9 @@ func TestWithSessionTarget(t *testing.T) { }, want: want{ obj: &T{ - sessionTarget: "test", + options: &SessionOptions{ + Target: "test", + }, }, }, }, @@ -206,7 +208,9 @@ func TestWithSessionConfig(t *testing.T) { }, want: want{ obj: &T{ - sessionConfig: []byte{0}, + options: &SessionOptions{ + Config: []byte{0}, + }, }, }, }, diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index b2d66e24257..bbbad54a025 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -47,8 +47,6 @@ type tensorflow struct { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -75,13 +73,6 @@ func New(opts ...Option) (TF, error) { opt(t) } - if t.options == nil && (t.sessionTarget == "" || t.sessionConfig != nil) { - t.options = &tf.SessionOptions{ - Target: t.sessionTarget, - Config: t.sessionConfig, - } - } - model, err := loadFunc(t.exportDir, t.tags, t.options) if err != nil { return nil, err diff --git a/internal/core/converter/tensorflow/tensorflow_test.go b/internal/core/converter/tensorflow/tensorflow_test.go index 515a6a6dd55..49a4ed99ccc 100644 --- a/internal/core/converter/tensorflow/tensorflow_test.go +++ b/internal/core/converter/tensorflow/tensorflow_test.go @@ -83,8 +83,6 @@ func TestNew(t *testing.T) { }, want: want{ want: &tensorflow{ - sessionTarget: "test", - sessionConfig: []byte{}, options: &tf.SessionOptions{ Target: "test", Config: []byte{}, @@ -151,8 +149,6 @@ func Test_tensorflow_Close(t *testing.T) { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -184,8 +180,6 @@ func Test_tensorflow_Close(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -208,8 +202,6 @@ func Test_tensorflow_Close(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -244,8 +236,6 @@ func Test_tensorflow_Close(t *testing.T) { feeds: test.fields.feeds, fetches: test.fields.fetches, operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, options: test.fields.options, graph: test.fields.graph, session: test.fields.session, @@ -271,8 +261,6 @@ func Test_tensorflow_run(t *testing.T) { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -312,8 +300,6 @@ func Test_tensorflow_run(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -347,8 +333,6 @@ func Test_tensorflow_run(t *testing.T) { }, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: tf.NewGraph(), session: &mockSession{ @@ -377,8 +361,6 @@ func Test_tensorflow_run(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: nil, @@ -400,8 +382,6 @@ func Test_tensorflow_run(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: tf.NewGraph(), session: &mockSession{ @@ -436,8 +416,6 @@ func Test_tensorflow_run(t *testing.T) { feeds: test.fields.feeds, fetches: test.fields.fetches, operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, options: test.fields.options, graph: test.fields.graph, session: test.fields.session, @@ -463,8 +441,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -504,8 +480,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -536,8 +510,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -564,8 +536,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -592,8 +562,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -620,8 +588,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -652,8 +618,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -684,8 +648,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -725,8 +687,6 @@ func Test_tensorflow_GetVector(t *testing.T) { feeds: test.fields.feeds, fetches: test.fields.fetches, operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, options: test.fields.options, graph: test.fields.graph, session: test.fields.session, @@ -752,8 +712,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -793,8 +751,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -825,8 +781,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -853,8 +807,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -881,8 +833,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -918,8 +868,6 @@ func Test_tensorflow_GetValue(t *testing.T) { feeds: test.fields.feeds, fetches: test.fields.fetches, operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, options: test.fields.options, graph: test.fields.graph, session: test.fields.session, @@ -945,8 +893,6 @@ func Test_tensorflow_GetValues(t *testing.T) { feeds []OutputSpec fetches []OutputSpec operations []*Operation - sessionTarget string - sessionConfig []byte options *SessionOptions graph *tf.Graph session session @@ -986,8 +932,6 @@ func Test_tensorflow_GetValues(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -1018,8 +962,6 @@ func Test_tensorflow_GetValues(t *testing.T) { feeds: nil, fetches: nil, operations: nil, - sessionTarget: "", - sessionConfig: nil, options: nil, graph: nil, session: &mockSession{ @@ -1055,8 +997,6 @@ func Test_tensorflow_GetValues(t *testing.T) { feeds: test.fields.feeds, fetches: test.fields.fetches, operations: test.fields.operations, - sessionTarget: test.fields.sessionTarget, - sessionConfig: test.fields.sessionConfig, options: test.fields.options, graph: test.fields.graph, session: test.fields.session,