Skip to content

Commit

Permalink
Updates the PR to use attribute instead of Env Variable
Browse files Browse the repository at this point in the history
-Originally AVRO_PARSER_NUM_MINIBATCH was set as an environmental
variable.  Because tensorflow-io rarely uses env vars to fine tune
kernal ops this was changed to an attribute. See comment here:
tensorflow#1283 (comment)
  • Loading branch information
i-ony committed Mar 5, 2021
1 parent 4da7742 commit 6ae9f50
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
16 changes: 12 additions & 4 deletions tensorflow_io/core/kernels/avro/parse_avro_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ Status ParseAvro(const AvroParserConfig& config,

// This parameter affects performance in a big and data-dependent way.
const size_t kMiniBatchSizeBytes = 50000;
size_t avro_num_minibatches_;

// Calculate number of minibatches.
// In main regime make each minibatch around kMiniBatchSizeBytes bytes.
Expand All @@ -206,10 +207,9 @@ Status ParseAvro(const AvroParserConfig& config,
minibatch_bytes = 0;
}
}
if (const char* n_minibatches =
std::getenv("AVRO_PARSER_NUM_MINIBATCHES")) {
VLOG(5) << "Overriding num_minibatches with " << n_minibatches;
result = std::stoi(n_minibatches);
if (avro_num_minibatches_) {
VLOG(5) << "Overriding num_minibatches with " << avro_num_minibatches_;
result = avro_num_minibatches_;
}
// This is to ensure users can control the num minibatches all the way down
// to size of 1(no parallelism).
Expand Down Expand Up @@ -406,6 +406,8 @@ class ParseAvroOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_types", &sparse_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_types", &dense_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_shapes", &dense_shapes_));
OP_REQUIRES_OK(
ctx, ctx->GetAttr("avro_num_minibatches", &avro_num_minibatches_));

OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_keys", &sparse_keys_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_keys", &dense_keys_));
Expand All @@ -419,6 +421,11 @@ class ParseAvroOp : public OpKernel {
dense_shapes_[d].dims() > 1 && dense_shapes_[d].dim_size(0) == -1;
}

// Check that avro_num_minibatches is positive
OP_REQUIRES(ctx, avro_num_minibatches_ >= 0,
errors::InvalidArgument("Need avro_num_minibatches >= 0, got ",
avro_num_minibatches_));

string reader_schema_str;
OP_REQUIRES_OK(ctx, ctx->GetAttr("reader_schema", &reader_schema_str));

Expand Down Expand Up @@ -513,6 +520,7 @@ class ParseAvroOp : public OpKernel {
avro::ValidSchema reader_schema_;
size_t num_dense_;
size_t num_sparse_;
int64 avro_num_minibatches_;

private:
std::vector<std::pair<string, DataType>> CreateKeysAndTypes() {
Expand Down
4 changes: 4 additions & 0 deletions tensorflow_io/core/ops/avro_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ REGISTER_OP("IO>ParseAvro")
.Output("sparse_values: sparse_types")
.Output("sparse_shapes: num_sparse * int64")
.Output("dense_values: dense_types")
.Attr("avro_num_minibatches: int >= 0")
.Attr("num_sparse: int >= 0")
.Attr("reader_schema: string")
.Attr("sparse_keys: list(string) >= 0")
Expand All @@ -94,6 +95,7 @@ REGISTER_OP("IO>ParseAvro")
.SetShapeFn([](shape_inference::InferenceContext* c) {
size_t num_dense;
size_t num_sparse;
int64 avro_num_minibatches;
int64 num_sparse_from_user;
std::vector<DataType> sparse_types;
std::vector<DataType> dense_types;
Expand All @@ -106,6 +108,8 @@ REGISTER_OP("IO>ParseAvro")
TF_RETURN_IF_ERROR(c->GetAttr("sparse_types", &sparse_types));
TF_RETURN_IF_ERROR(c->GetAttr("dense_types", &dense_types));
TF_RETURN_IF_ERROR(c->GetAttr("dense_shapes", &dense_shapes));
TF_RETURN_IF_ERROR(
c->GetAttr("avro_num_minibatches", &avro_num_minibatches));

TF_RETURN_IF_ERROR(c->GetAttr("sparse_keys", &sparse_keys));
TF_RETURN_IF_ERROR(c->GetAttr("sparse_ranks", &sparse_ranks));
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_io/core/python/experimental/parse_avro_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _parse_avro(
dense_defaults=None,
dense_shapes=None,
name=None,
avro_num_minibatches=0,
):
"""Parses Avro records.
Expand Down Expand Up @@ -196,6 +197,7 @@ def _parse_avro(
dense_keys=dense_keys,
dense_shapes=dense_shapes,
name=name,
avro_num_minibatches=avro_num_minibatches,
)

(sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs
Expand Down

1 comment on commit 6ae9f50

@i-ony
Copy link
Owner Author

@i-ony i-ony commented on 6ae9f50 Mar 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passes internal avro benchmark tests on Linux

Please sign in to comment.