Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tensorflow test #378

Merged
merged 14 commits into from
Jun 4, 2020
36 changes: 31 additions & 5 deletions internal/core/converter/tensorflow/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
// Package tensorflow provides implementation of Go API for extract data to vector
package tensorflow

// Option is tensorflow configure.
type Option func(*tensorflow)

var (
Expand All @@ -27,28 +28,46 @@ var (
}
)

// WithSessionOptions returns Option that sets options.
func WithSessionOptions(opts *SessionOptions) Option {
return func(t *tensorflow) {
t.options = opts
if opts != nil {
t.options = opts
}
}
}

// WithSessionTarget returns Option that sets target.
func WithSessionTarget(tgt string) Option {
return func(t *tensorflow) {
if len(tgt) != 0 {
t.sessionTarget = tgt
if tgt != "" {
if t.options == nil {
Copy link
Collaborator

@hlts2 hlts2 May 25, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's good! 👍

t.options = &SessionOptions{
Target: tgt,
}
} else {
t.options.Target = tgt
}
}
}
}

// WithSessionConfig returns Option that sets config.
func WithSessionConfig(cfg []byte) Option {
return func(t *tensorflow) {
if cfg != nil {
t.sessionConfig = cfg
if t.options == nil {
t.options = &SessionOptions{
Config: cfg,
}
} else {
t.options.Config = cfg
}
}
}
}

// WithOperations returns Option that sets operations.
func WithOperations(opes ...*Operation) Option {
return func(t *tensorflow) {
if opes != nil {
Expand All @@ -61,14 +80,16 @@ func WithOperations(opes ...*Operation) Option {
}
}

// WithExportPath returns Option that sets exportDir.
func WithExportPath(path string) Option {
return func(t *tensorflow) {
if len(path) != 0 {
if path != "" {
t.exportDir = path
}
}
}

// WithTags returns Option that sets tags.
func WithTags(tags ...string) Option {
return func(t *tensorflow) {
if tags != nil {
Expand All @@ -81,12 +102,14 @@ func WithTags(tags ...string) Option {
}
}

// WithFeed returns Option that sets feeds.
func WithFeed(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
t.feeds = append(t.feeds, OutputSpec{operationName, outputIndex})
}
}

// WithFeeds returns Option that sets feeds.
func WithFeeds(operationNames []string, outputIndexes []int) Option {
return func(t *tensorflow) {
if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) {
Expand All @@ -97,12 +120,14 @@ func WithFeeds(operationNames []string, outputIndexes []int) Option {
}
}

// WithFetch returns Option that sets fetches.
func WithFetch(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
t.fetches = append(t.fetches, OutputSpec{operationName, outputIndex})
}
}

// WithFetches returns Option that sets fetches.
func WithFetches(operationNames []string, outputIndexes []int) Option {
return func(t *tensorflow) {
if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) {
Expand All @@ -113,6 +138,7 @@ func WithFetches(operationNames []string, outputIndexes []int) Option {
}
}

// WithNdim returns Option that sets ndim.
func WithNdim(ndim uint8) Option {
return func(t *tensorflow) {
t.ndim = ndim
Expand Down
Loading