Skip to content

Commit

Permalink
tensorflow test (#378)
Browse files Browse the repository at this point in the history
* Add tensorflow test code

Signed-off-by: datelier <[email protected]>

* Add tensorflow option test code

Signed-off-by: datelier <[email protected]>

* fix test name

Signed-off-by: datelier <[email protected]>

* fix DeepSource issue: Empty string test can be improved

Signed-off-by: datelier <[email protected]>

* fix Deepsource issue: Incomplete condition detected

Signed-off-by: datelier <[email protected]>

* remove sessionTarget, sessionConfig

Signed-off-by: datelier <[email protected]>

* fix DeepSource issue: Empty string test can be improved

Signed-off-by: datelier <[email protected]>

* fix test case based on review

Signed-off-by: datelier <[email protected]>

* gofmt

Signed-off-by: datelier <[email protected]>

* fix golangci-lint issue

Signed-off-by: datelier <[email protected]>

* fix golangci-lint issue

Signed-off-by: datelier <[email protected]>

* delete Closer interface

Signed-off-by: datelier <[email protected]>

* fix DeepSource issue: Function literal can be simplified

Signed-off-by: datelier <[email protected]>

Co-authored-by: Yusuke Kato <[email protected]>
  • Loading branch information
datelier and Yusuke Kato authored Jun 4, 2020
1 parent 48e92aa commit 6034bb8
Show file tree
Hide file tree
Showing 5 changed files with 1,073 additions and 1,350 deletions.
36 changes: 31 additions & 5 deletions internal/core/converter/tensorflow/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
// Package tensorflow provides implementation of Go API for extract data to vector
package tensorflow

// Option is tensorflow configure.
type Option func(*tensorflow)

var (
Expand All @@ -27,28 +28,46 @@ var (
}
)

// WithSessionOptions returns Option that sets options.
func WithSessionOptions(opts *SessionOptions) Option {
return func(t *tensorflow) {
t.options = opts
if opts != nil {
t.options = opts
}
}
}

// WithSessionTarget returns Option that sets target.
func WithSessionTarget(tgt string) Option {
return func(t *tensorflow) {
if len(tgt) != 0 {
t.sessionTarget = tgt
if tgt != "" {
if t.options == nil {
t.options = &SessionOptions{
Target: tgt,
}
} else {
t.options.Target = tgt
}
}
}
}

// WithSessionConfig returns Option that sets config.
func WithSessionConfig(cfg []byte) Option {
return func(t *tensorflow) {
if cfg != nil {
t.sessionConfig = cfg
if t.options == nil {
t.options = &SessionOptions{
Config: cfg,
}
} else {
t.options.Config = cfg
}
}
}
}

// WithOperations returns Option that sets operations.
func WithOperations(opes ...*Operation) Option {
return func(t *tensorflow) {
if opes != nil {
Expand All @@ -61,14 +80,16 @@ func WithOperations(opes ...*Operation) Option {
}
}

// WithExportPath returns Option that sets exportDir.
func WithExportPath(path string) Option {
return func(t *tensorflow) {
if len(path) != 0 {
if path != "" {
t.exportDir = path
}
}
}

// WithTags returns Option that sets tags.
func WithTags(tags ...string) Option {
return func(t *tensorflow) {
if tags != nil {
Expand All @@ -81,12 +102,14 @@ func WithTags(tags ...string) Option {
}
}

// WithFeed returns Option that sets feeds.
func WithFeed(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
t.feeds = append(t.feeds, OutputSpec{operationName, outputIndex})
}
}

// WithFeeds returns Option that sets feeds.
func WithFeeds(operationNames []string, outputIndexes []int) Option {
return func(t *tensorflow) {
if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) {
Expand All @@ -97,12 +120,14 @@ func WithFeeds(operationNames []string, outputIndexes []int) Option {
}
}

// WithFetch returns Option that sets fetches.
func WithFetch(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
t.fetches = append(t.fetches, OutputSpec{operationName, outputIndex})
}
}

// WithFetches returns Option that sets fetches.
func WithFetches(operationNames []string, outputIndexes []int) Option {
return func(t *tensorflow) {
if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) {
Expand All @@ -113,6 +138,7 @@ func WithFetches(operationNames []string, outputIndexes []int) Option {
}
}

// WithNdim returns Option that sets ndim.
func WithNdim(ndim uint8) Option {
return func(t *tensorflow) {
t.ndim = ndim
Expand Down
Loading

0 comments on commit 6034bb8

Please sign in to comment.