Skip to content

Commit

Permalink
remove sessionTarget, sessionConfig
Browse files Browse the repository at this point in the history
Signed-off-by: datelier <[email protected]>
  • Loading branch information
datelier authored and actions-user committed May 20, 2020
1 parent 71e7914 commit 044658c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 74 deletions.
20 changes: 17 additions & 3 deletions internal/core/converter/tensorflow/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,36 @@ var (

func WithSessionOptions(opts *SessionOptions) Option {
return func(t *tensorflow) {
t.options = opts
if opts != nil {
t.options = opts
}
}
}

func WithSessionTarget(tgt string) Option {
return func(t *tensorflow) {
if len(tgt) != 0 {
t.sessionTarget = tgt
if t.options == nil {
t.options = &SessionOptions{
Target: tgt,
}
} else {
t.options.Target = tgt
}
}
}
}

func WithSessionConfig(cfg []byte) Option {
return func(t *tensorflow) {
if cfg != nil {
t.sessionConfig = cfg
if t.options == nil {
t.options = &SessionOptions{
Config: cfg,
}
} else {
t.options.Config = cfg
}
}
}
}
Expand Down
8 changes: 6 additions & 2 deletions internal/core/converter/tensorflow/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ func TestWithSessionTarget(t *testing.T) {
},
want: want{
obj: &T{
sessionTarget: "test",
options: &SessionOptions{
Target: "test",
},
},
},
},
Expand Down Expand Up @@ -206,7 +208,9 @@ func TestWithSessionConfig(t *testing.T) {
},
want: want{
obj: &T{
sessionConfig: []byte{0},
options: &SessionOptions{
Config: []byte{0},
},
},
},
},
Expand Down
9 changes: 0 additions & 9 deletions internal/core/converter/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ type tensorflow struct {
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
sessionTarget string
sessionConfig []byte
options *SessionOptions
graph *tf.Graph
session session
Expand All @@ -75,13 +73,6 @@ func New(opts ...Option) (TF, error) {
opt(t)
}

if t.options == nil && (t.sessionTarget == "" || t.sessionConfig != nil) {
t.options = &tf.SessionOptions{
Target: t.sessionTarget,
Config: t.sessionConfig,
}
}

model, err := loadFunc(t.exportDir, t.tags, t.options)
if err != nil {
return nil, err
Expand Down
60 changes: 0 additions & 60 deletions internal/core/converter/tensorflow/tensorflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ func TestNew(t *testing.T) {
},
want: want{
want: &tensorflow{
sessionTarget: "test",
sessionConfig: []byte{},
options: &tf.SessionOptions{
Target: "test",
Config: []byte{},
Expand Down Expand Up @@ -151,8 +149,6 @@ func Test_tensorflow_Close(t *testing.T) {
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
sessionTarget string
sessionConfig []byte
options *SessionOptions
graph *tf.Graph
session session
Expand Down Expand Up @@ -184,8 +180,6 @@ func Test_tensorflow_Close(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand All @@ -208,8 +202,6 @@ func Test_tensorflow_Close(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand Down Expand Up @@ -244,8 +236,6 @@ func Test_tensorflow_Close(t *testing.T) {
feeds: test.fields.feeds,
fetches: test.fields.fetches,
operations: test.fields.operations,
sessionTarget: test.fields.sessionTarget,
sessionConfig: test.fields.sessionConfig,
options: test.fields.options,
graph: test.fields.graph,
session: test.fields.session,
Expand All @@ -271,8 +261,6 @@ func Test_tensorflow_run(t *testing.T) {
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
sessionTarget string
sessionConfig []byte
options *SessionOptions
graph *tf.Graph
session session
Expand Down Expand Up @@ -312,8 +300,6 @@ func Test_tensorflow_run(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand Down Expand Up @@ -347,8 +333,6 @@ func Test_tensorflow_run(t *testing.T) {
},
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: tf.NewGraph(),
session: &mockSession{
Expand Down Expand Up @@ -377,8 +361,6 @@ func Test_tensorflow_run(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: nil,
Expand All @@ -400,8 +382,6 @@ func Test_tensorflow_run(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: tf.NewGraph(),
session: &mockSession{
Expand Down Expand Up @@ -436,8 +416,6 @@ func Test_tensorflow_run(t *testing.T) {
feeds: test.fields.feeds,
fetches: test.fields.fetches,
operations: test.fields.operations,
sessionTarget: test.fields.sessionTarget,
sessionConfig: test.fields.sessionConfig,
options: test.fields.options,
graph: test.fields.graph,
session: test.fields.session,
Expand All @@ -463,8 +441,6 @@ func Test_tensorflow_GetVector(t *testing.T) {
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
sessionTarget string
sessionConfig []byte
options *SessionOptions
graph *tf.Graph
session session
Expand Down Expand Up @@ -504,8 +480,6 @@ func Test_tensorflow_GetVector(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand Down Expand Up @@ -536,8 +510,6 @@ func Test_tensorflow_GetVector(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand All @@ -564,8 +536,6 @@ func Test_tensorflow_GetVector(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand All @@ -592,8 +562,6 @@ func Test_tensorflow_GetVector(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand All @@ -620,8 +588,6 @@ func Test_tensorflow_GetVector(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand Down Expand Up @@ -652,8 +618,6 @@ func Test_tensorflow_GetVector(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand Down Expand Up @@ -684,8 +648,6 @@ func Test_tensorflow_GetVector(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand Down Expand Up @@ -725,8 +687,6 @@ func Test_tensorflow_GetVector(t *testing.T) {
feeds: test.fields.feeds,
fetches: test.fields.fetches,
operations: test.fields.operations,
sessionTarget: test.fields.sessionTarget,
sessionConfig: test.fields.sessionConfig,
options: test.fields.options,
graph: test.fields.graph,
session: test.fields.session,
Expand All @@ -752,8 +712,6 @@ func Test_tensorflow_GetValue(t *testing.T) {
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
sessionTarget string
sessionConfig []byte
options *SessionOptions
graph *tf.Graph
session session
Expand Down Expand Up @@ -793,8 +751,6 @@ func Test_tensorflow_GetValue(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand Down Expand Up @@ -825,8 +781,6 @@ func Test_tensorflow_GetValue(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand All @@ -853,8 +807,6 @@ func Test_tensorflow_GetValue(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand All @@ -881,8 +833,6 @@ func Test_tensorflow_GetValue(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand Down Expand Up @@ -918,8 +868,6 @@ func Test_tensorflow_GetValue(t *testing.T) {
feeds: test.fields.feeds,
fetches: test.fields.fetches,
operations: test.fields.operations,
sessionTarget: test.fields.sessionTarget,
sessionConfig: test.fields.sessionConfig,
options: test.fields.options,
graph: test.fields.graph,
session: test.fields.session,
Expand All @@ -945,8 +893,6 @@ func Test_tensorflow_GetValues(t *testing.T) {
feeds []OutputSpec
fetches []OutputSpec
operations []*Operation
sessionTarget string
sessionConfig []byte
options *SessionOptions
graph *tf.Graph
session session
Expand Down Expand Up @@ -986,8 +932,6 @@ func Test_tensorflow_GetValues(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand Down Expand Up @@ -1018,8 +962,6 @@ func Test_tensorflow_GetValues(t *testing.T) {
feeds: nil,
fetches: nil,
operations: nil,
sessionTarget: "",
sessionConfig: nil,
options: nil,
graph: nil,
session: &mockSession{
Expand Down Expand Up @@ -1055,8 +997,6 @@ func Test_tensorflow_GetValues(t *testing.T) {
feeds: test.fields.feeds,
fetches: test.fields.fetches,
operations: test.fields.operations,
sessionTarget: test.fields.sessionTarget,
sessionConfig: test.fields.sessionConfig,
options: test.fields.options,
graph: test.fields.graph,
session: test.fields.session,
Expand Down

0 comments on commit 044658c

Please sign in to comment.