TensorFlow: upstream changes to git
Change 109195845
	Fix TensorFlow to build against Bazel 0.1.2rc2

	Two things are currently broken with TensorFlow and Bazel 0.1.2:
	  - Bazel now uses sandboxing by default on Linux and has fixed it for cc_* rules.
	    Undeclared headers are not mounted in the sandbox, which makes several cc_* rules
	    fail.
	  - Bazel now enforces strict header checking, and some targets were missing
	    headers even though the headers were mounted in the sandbox. This change
	    adds a "strict_headers" target that globs every header of the core
	    library and adds it to the `tf_cc_tests` targets.
Change 109162708
	Fix various website issues
	- Fix headline in os_setup.md
	- Fix #anchor links
Change 109162129
	Fix numbers in mnist tutorial, fixes tensorflow#362
Change 109158967
	Fix typo in word2vec tutorial, fixes tensorflow#347
Change 109151855
	Fix tile and its gradient for scalars on GPUs

	Eigen doesn't handle scalars on GPUs in all cases.  Fortunately, both
	tile and its gradient are the identity for scalars, so we can just copy
	the input to the output.

	Fixes tensorflow#391.
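
	A minimal Python sketch of the identity property the fix relies on
	(illustrative only, not part of the commit; assumes tile accepts an
	empty `multiples` vector for 0-D inputs, as the fix implies):

	    import tensorflow as tf

	    # Tiling a 0-D tensor repeats nothing, so the op is the identity
	    # and its gradient is a plain copy of the incoming gradient.
	    x = tf.constant(3.0)
	    y = tf.tile(x, [])  # scalar in, scalar out
	    g = tf.gradients(y, x)[0]

	    with tf.Session() as sess:
	        print(sess.run([y, g]))  # [3.0, 1.0]
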
Change 109140763
	Support int32 and int64 in tf.random_uniform

	This requires a new RandomUniformInt op on the C++ side since the op needs
	to know minval and maxval.

	Fixes tensorflow#364.
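
	A hedged Python sketch of the resulting user-facing behavior
	(illustrative only, not part of the commit; assumes a build that
	includes this change):

	    import tensorflow as tf

	    # Integer output dtypes now route to the new RandomUniformInt
	    # kernel, which needs concrete bounds, so maxval is required
	    # for integer output.
	    samples = tf.random_uniform([5], minval=0, maxval=10, dtype=tf.int32)

	    with tf.Session() as sess:
	        print(sess.run(samples))  # e.g. [3 7 0 9 2], drawn from [0, 10)
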
Change 109140738
	Fix spacing in docs.
Change 109140030
	Fix content nav to not hide the bottom 100 or so px.
Change 109139967
	Add license files to TensorBoard files, fix mnist_with_summaries test
Change 109138333
	Fix typos in docstring
Change 109138098
	Fix some missing resources in the website.

	Fixes tensorflow#366.
Change 109123771
	Make sparse_to_dense's default_value default to 0

	Nearly all uses of sparse_to_dense use 0 as the default.  The
	same goes for sparse_tensor_to_dense.
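
	A hedged Python sketch of the new default (illustrative only, not
	part of the commit; assumes a build that includes this change):

	    import tensorflow as tf

	    # default_value may now be omitted; it falls back to 0.
	    dense = tf.sparse_to_dense(sparse_indices=[0, 2], output_shape=[4],
	                               sparse_values=[7, 9])

	    with tf.Session() as sess:
	        print(sess.run(dense))  # [7 0 9 0]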

Base CL: 109198336
keveman committed Dec 2, 2015
1 parent f586a5e commit fa095c5
Showing 27 changed files with 715 additions and 273 deletions.
2 changes: 1 addition & 1 deletion WORKSPACE
@@ -32,7 +32,7 @@ bind(
git_repository(
name = "re2",
remote = "https://github.com/google/re2.git",
tag = "2015-07-01",
commit = "791beff",
)

new_http_archive(
14 changes: 13 additions & 1 deletion tensorflow/core/BUILD
@@ -101,7 +101,10 @@ tf_cuda_library(
"**/*main.cc",
],
),
hdrs = glob(["public/**/*.h"]),
hdrs = glob([
"public/**/*.h",
"util/device_name_utils.h",
]),
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
@@ -345,6 +348,12 @@ cc_library(
alwayslink = 1,
)

+ # This is to work around strict header checks
+ cc_library(
+ name = "strict_headers",
+ hdrs = glob(["**/*.h"]),
+ )

# Low level library tests
tf_cc_tests(
tests = glob(
@@ -356,6 +365,7 @@ tf_cc_tests(
),
deps = [
":lib",
":strict_headers",
":test_main",
],
)
@@ -404,6 +414,7 @@ tf_cc_tests(
":direct_session",
":kernels",
":lib",
":strict_headers",
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
@@ -424,6 +435,7 @@
deps = [
":direct_session",
":kernels",
":strict_headers",
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
151 changes: 116 additions & 35 deletions tensorflow/core/kernels/random_op.cc
@@ -46,7 +46,7 @@ template <typename Device, class Distribution>
struct FillPhiloxRandom {
typedef typename Distribution::ResultElementType T;
void operator()(OpKernelContext*, const Device&, random::PhiloxRandom gen,
- T* data, int64 size) {
+ T* data, int64 size, Distribution dist) {
LOG(FATAL) << "Default FillPhiloxRandom should not be executed.";
}
};
@@ -57,7 +57,8 @@ template <class Distribution>
struct FillPhiloxRandom<GPUDevice, Distribution> {
typedef typename Distribution::ResultElementType T;
void operator()(OpKernelContext* ctx, const GPUDevice&,
- random::PhiloxRandom gen, T* data, int64 size);
+ random::PhiloxRandom gen, T* data, int64 size,
+ Distribution dist);
};

#endif
@@ -72,8 +73,7 @@ template <class Distribution>
struct FillPhiloxRandomTask<Distribution, false> {
typedef typename Distribution::ResultElementType T;
static void Run(random::PhiloxRandom gen, T* data, int64 size,
- int64 start_group, int64 limit_group) {
- Distribution dist;
+ int64 start_group, int64 limit_group, Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;

gen.Skip(start_group);
@@ -96,19 +96,18 @@ struct FillPhiloxRandomTask<Distribution, false> {
}
};

- // Specialization for distribution that takes a varaiable number of samples for
+ // Specialization for distribution that takes a variable number of samples for
// each output. This will be slower due to the generality.
template <class Distribution>
struct FillPhiloxRandomTask<Distribution, true> {
typedef typename Distribution::ResultElementType T;
static const int64 kReservedSamplesPerOutput = 256;

static void Run(random::PhiloxRandom base_gen, T* data, int64 size,
- int64 start_group, int64 limit_group) {
+ int64 start_group, int64 limit_group, Distribution dist) {
using random::PhiloxRandom;
using random::SingleSampleAdapter;

- Distribution dist;
const int kGroupSize = Distribution::kResultElementCount;

static const int kGeneratorSkipPerOutputGroup =
@@ -153,7 +152,8 @@ template <class Distribution>
struct FillPhiloxRandom<CPUDevice, Distribution> {
typedef typename Distribution::ResultElementType T;
void operator()(OpKernelContext* context, const CPUDevice&,
- random::PhiloxRandom gen, T* data, int64 size) {
+ random::PhiloxRandom gen, T* data, int64 size,
+ Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;

auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
@@ -164,17 +164,49 @@ struct FillPhiloxRandom<CPUDevice, Distribution> {
// sub-linear. Too many threads causes a much worse overall performance.
int num_workers = 6;
Shard(num_workers, worker_threads.workers, total_group_count, kGroupSize,
[&gen, data, size](int64 start_group, int64 limit_group) {
[&gen, data, size, dist](int64 start_group, int64 limit_group) {
FillPhiloxRandomTask<
Distribution,
Distribution::kVariableSamplesPerOutput>::Run(gen, data, size,
start_group,
- limit_group);
+ limit_group,
+ dist);
});
}
};
} // namespace functor

+ namespace {
+
+ static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
+ int index, Tensor** output) {
+ if (!TensorShapeUtils::IsLegacyVector(shape.shape())) {
+ return errors::InvalidArgument(
+ "shape must be a vector of {int32,int64}, got shape ",
+ shape.shape().ShortDebugString());
+ }
+ if (shape.dtype() == DataType::DT_INT32) {
+ auto vec = shape.flat<int32>();
+ TF_RETURN_IF_ERROR(ctx->allocate_output(
+ index, TensorShapeUtils::MakeShape(vec.data(), vec.size()), output));
+ } else if (shape.dtype() == DataType::DT_INT64) {
+ auto vec = shape.flat<int64>();
+ TF_RETURN_IF_ERROR(ctx->allocate_output(
+ index, TensorShapeUtils::MakeShape(vec.data(), vec.size()), output));
+ } else {
+ return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
+ }
+ return Status::OK();
+ }
+
+ // Reserve enough random samples in the generator for the given output count.
+ // Note that the 256 multiplier is repeated above; do not change it just here.
+ static random::PhiloxRandom ReserveRandomOutputs(GuardedPhiloxRandom& generator,
+ int64 output_count) {
+ int64 conservative_sample_count = output_count << 8;
+ return generator.ReserveSamples128(conservative_sample_count);
+ }

// For now, use the same interface as RandomOp, so we can choose either one
// at the run-time.
template <typename Device, class Distribution>
@@ -186,41 +218,65 @@ class PhiloxRandomOp : public OpKernel {
}

void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
OP_REQUIRES(
ctx, TensorShapeUtils::IsLegacyVector(input.shape()),
errors::InvalidArgument("shape must be a vector of {int32,int64}."));
Tensor* output = nullptr;
if (input.dtype() == DataType::DT_INT32) {
auto vec = input.flat<int32>();
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShapeUtils::MakeShape(
vec.data(), vec.size()),
&output));
} else if (input.dtype() == DataType::DT_INT64) {
auto vec = input.flat<int64>();
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShapeUtils::MakeShape(
vec.data(), vec.size()),
&output));
} else {
OP_REQUIRES(ctx, false, errors::InvalidArgument(
"shape must be a vector of {int32,int64}."));
}
const Tensor& shape = ctx->input(0);
Tensor* output;
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
auto output_flat = output->flat<T>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(),
- ReserveRandomOutputs(output->flat<T>().size()),
- output->flat<T>().data(), output->flat<T>().size());
+ ReserveRandomOutputs(generator_, output_flat.size()),
+ output_flat.data(), output_flat.size(), Distribution());
}

private:
GuardedPhiloxRandom generator_;
+ };
+
+ template <typename Device, class IntType>
+ class RandomUniformIntOp : public OpKernel {
+ public:
+ explicit RandomUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, generator_.Init(ctx));
+ }

- // Reserve enough random samples in the generator for the given output count.
- random::PhiloxRandom ReserveRandomOutputs(int64 output_count) {
- int64 conservative_sample_count = output_count << 8;
- return generator_.ReserveSamples128(conservative_sample_count);
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& shape = ctx->input(0);
+ const Tensor& minval = ctx->input(1);
+ const Tensor& maxval = ctx->input(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
+ errors::InvalidArgument("minval must be 0-D, got shape ",
+ minval.shape().ShortDebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
+ errors::InvalidArgument("maxval must be 0-D, got shape ",
+ maxval.shape().ShortDebugString()));
+
+ // Verify that minval < maxval
+ IntType lo = minval.scalar<IntType>()();
+ IntType hi = maxval.scalar<IntType>()();
+ OP_REQUIRES(
+ ctx, lo < hi,
+ errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
+
+ // Build distribution
+ typedef random::UniformDistribution<random::PhiloxRandom, IntType>
+ Distribution;
+ Distribution dist(lo, hi);
+
+ Tensor* output;
+ OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
+ auto output_flat = output->flat<IntType>();
+ functor::FillPhiloxRandom<Device, Distribution>()(
+ ctx, ctx->eigen_device<Device>(),
+ ReserveRandomOutputs(generator_, output_flat.size()),
+ output_flat.data(), output_flat.size(), dist);
}

private:
GuardedPhiloxRandom generator_;
};

+ } // namespace

#define REGISTER(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("RandomUniform") \
Expand All @@ -246,10 +302,22 @@ class PhiloxRandomOp : public OpKernel {
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)

+ #define REGISTER_INT(IntType) \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<IntType>("Tout"), \
+ RandomUniformIntOp<CPUDevice, IntType>);

REGISTER(float);
REGISTER(double);
+ REGISTER_INT(int32);
+ REGISTER_INT(int64);

#undef REGISTER
+ #undef REGISTER_INT

#if GOOGLE_CUDA

@@ -281,10 +349,23 @@ REGISTER(double);
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)

+ #define REGISTER_INT(IntType) \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<int32>("T") \
+ .TypeConstraint<IntType>("Tout"), \
+ RandomUniformIntOp<GPUDevice, IntType>);

REGISTER(float);
REGISTER(double);
+ REGISTER_INT(int32);
+ REGISTER_INT(int64);

#undef REGISTER
+ #undef REGISTER_INT

#endif // GOOGLE_CUDA

19 changes: 11 additions & 8 deletions tensorflow/core/kernels/random_op_gpu.cu.cc
@@ -42,8 +42,8 @@ struct FillPhiloxRandomKernel;
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
typedef typename Distribution::ResultElementType T;
- PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size) {
- Distribution dist;
+ PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size,
+ Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;

const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
@@ -74,7 +74,7 @@ template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
typedef typename Distribution::ResultElementType T;
PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data,
- int64 size) {
+ int64 size, Distribution dist) {
using random::PhiloxRandom;
using random::SingleSampleAdapter;

@@ -88,7 +88,6 @@ struct FillPhiloxRandomKernel<Distribution, true> {
const int32 total_thread_count = gridDim.x * blockDim.x;
int64 group_index = thread_id;
int64 offset = group_index * kGroupSize;
- Distribution dist;

while (offset < size) {
// Since each output takes a variable number of samples, we need to
@@ -118,10 +117,10 @@ template <class Distribution>
__global__ void __launch_bounds__(1024)
FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
typename Distribution::ResultElementType* data,
- int64 size) {
+ int64 size, Distribution dist) {
FillPhiloxRandomKernel<Distribution,
Distribution::kVariableSamplesPerOutput>()
- .Run(base_gen, data, size);
+ .Run(base_gen, data, size, dist);
}

// Partial specialization for GPU
@@ -130,15 +129,15 @@ struct FillPhiloxRandom<GPUDevice, Distribution> {
typedef typename Distribution::ResultElementType T;
typedef GPUDevice Device;
void operator()(OpKernelContext*, const Device& d, random::PhiloxRandom gen,
- T* data, int64 size) {
+ T* data, int64 size, Distribution dist) {
const int32 block_size = d.maxCudaThreadsPerBlock();
const int32 num_blocks =
(d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
block_size;

FillPhiloxRandomKernelLaunch<
Distribution><<<num_blocks, block_size, 0, d.stream()>>>(gen, data,
- size);
+ size, dist);
}
};

@@ -149,6 +148,10 @@ template struct FillPhiloxRandom<
GPUDevice, random::UniformDistribution<random::PhiloxRandom, float> >;
template struct FillPhiloxRandom<
GPUDevice, random::UniformDistribution<random::PhiloxRandom, double> >;
+ template struct FillPhiloxRandom<
+ GPUDevice, random::UniformDistribution<random::PhiloxRandom, int32> >;
+ template struct FillPhiloxRandom<
+ GPUDevice, random::UniformDistribution<random::PhiloxRandom, int64> >;
template struct FillPhiloxRandom<
GPUDevice, random::NormalDistribution<random::PhiloxRandom, float> >;
template struct FillPhiloxRandom<
(diffs for the remaining 23 changed files not shown)
