[NNC] Lowering function generates the output buffer with the specified stride (pytorch#76529)

Summary:
Pass stride information to the lowering function so that it generates the output buffer with the proper memory layout.
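
For illustration (a sketch against the public ATen API, not part of this diff): the same sizes yield different strides depending on the requested memory format, which is exactly the information the lowering functions now receive.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Contiguous (NCHW) strides are reverse cumulative products of the sizes:
  // sizes {2, 3, 4, 5} -> strides {3*4*5, 4*5, 5, 1} = {60, 20, 5, 1}.
  auto contig = at::empty({2, 3, 4, 5});
  std::cout << contig.strides() << "\n"; // [60, 20, 5, 1]

  // Channels-last (NHWC) makes the channel dimension the fastest-moving one,
  // so the same sizes yield strides {60, 1, 15, 3}.
  auto cl = at::empty(
      {2, 3, 4, 5},
      at::TensorOptions().memory_format(at::MemoryFormat::ChannelsLast));
  std::cout << cl.strides() << "\n"; // [60, 1, 15, 3]
}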

Pull Request resolved: pytorch#76529

Reviewed By: ZolotukhinM

Differential Revision: D36116712

Pulled By: IvanKobzarev

fbshipit-source-id: d3901f756b3710ecce172d6db3ecb0b7c12fb929
(cherry picked from commit b6cd53c)
EikanWang authored and pytorchmergebot committed May 4, 2022
1 parent 0878ba4 commit 429a80d
Showing 30 changed files with 687 additions and 83 deletions.
35 changes: 28 additions & 7 deletions aten/src/ATen/core/jit_type.h
@@ -786,15 +786,36 @@ struct TORCH_API TensorType : public SharedType {

static const TypeKind Kind = TypeKind::TensorType;

-  static std::vector<int64_t> contiguousStridesOf(at::IntArrayRef sizes) {
-    std::vector<int64_t> strides(sizes.size());
-    if (sizes.empty()) // zero-dim case
-      return strides;
-    strides.back() = 1;
-    for (size_t i = strides.size() - 1; i > 0; i--) {
-      strides[i - 1] = strides[i] * sizes[i];
-    }
-    return strides;
+  static std::vector<int64_t> contiguousStridesOf(
+      at::IntArrayRef sizes,
+      at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
+    auto contiguous_fn = [](const at::IntArrayRef& sizes,
+                            const std::vector<int64_t>& dim_order) {
+      std::vector<int64_t> strides(sizes.size());
+      if (sizes.empty()) // zero-dim case
+        return strides;
+
+      strides[dim_order[0]] = 1;
+      for (size_t i = 1; i < dim_order.size(); i++) {
+        auto cur_dim = dim_order[i];
+        auto pre_dim = dim_order[i - 1];
+        strides[cur_dim] = strides[pre_dim] * sizes[pre_dim];
+      }
+      return strides;
+    };
+
+    std::vector<int64_t> dim_order(sizes.size());
+    if (memory_format == MemoryFormat::ChannelsLast) {
+      dim_order = {1, 3, 2, 0};
+    } else if (memory_format == MemoryFormat::ChannelsLast3d) {
+      dim_order = {1, 4, 3, 2, 0};
+    } else {
+      auto ndims = sizes.size();
+      for (size_t i = 0; i < ndims; i++) {
+        dim_order[i] = ndims - i - 1; // Reverse
+      }
+    }
+    return contiguous_fn(sizes, dim_order);
}

private:
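As a reading aid for the hunk above, here is a minimal standalone sketch of the dim_order scheme (plain C++; the helper name stridesFor is hypothetical): dim_order lists dimension indices from fastest- to slowest-moving, and each stride is the running product of the sizes visited so far.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the contiguous_fn lambda above.
std::vector<int64_t> stridesFor(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& dim_order) {
  std::vector<int64_t> strides(sizes.size());
  if (sizes.empty()) // zero-dim case
    return strides;
  strides[dim_order[0]] = 1; // innermost dimension moves fastest
  for (size_t i = 1; i < dim_order.size(); i++) {
    strides[dim_order[i]] = strides[dim_order[i - 1]] * sizes[dim_order[i - 1]];
  }
  return strides;
}

// stridesFor({2, 3, 4, 5}, {3, 2, 1, 0}) -> {60, 20, 5, 1}  (contiguous)
// stridesFor({2, 3, 4, 5}, {1, 3, 2, 0}) -> {60, 1, 15, 3}  (channels-last)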
12 changes: 6 additions & 6 deletions benchmarks/cpp/tensorexpr/bench_reduce.cpp
@@ -392,8 +392,8 @@ BENCHMARK_DEFINE_F(Reduce1D, Op)(benchmark::State& state) {
const int kChunkSize = 8;

te::BufHandle a("A", {M}, te::kFloat);
-  te::Tensor b =
-      te::computeSum({a, te::IntList({0}), false}, {}, at::kFloat, at::kCPU);
+  te::Tensor b = te::computeSum(
+      {a, te::IntList({0}), false}, {}, {}, at::kFloat, at::kCPU);
te::LoopNest nest({b});

auto loops = nest.getLoopStmtsFor(b);
@@ -456,8 +456,8 @@ BENCHMARK_REGISTER_F(Reduce2DCol, Torch)
BENCHMARK_DEFINE_F(Reduce2DCol, OpSchedule)(benchmark::State& state) {
constexpr int kCacheSize = 1 << 12;
te::BufHandle a("A", {M, N}, te::kFloat);
-  te::Tensor b =
-      te::computeSum({a, te::IntList({0}), false}, {N}, at::kFloat, at::kCPU);
+  te::Tensor b = te::computeSum(
+      {a, te::IntList({0}), false}, {N}, {1}, at::kFloat, at::kCPU);
te::LoopNest nest({b});

auto sch = state.range(2);
@@ -565,8 +565,8 @@ BENCHMARK_REGISTER_F(Reduce2DRow, Hand)->Args({1 << 18, 1 << 6});
BENCHMARK_DEFINE_F(Reduce2DRow, OpSchedule)(benchmark::State& state) {
constexpr int kChunkSize = 8;
te::BufHandle a("A", {M, N}, te::kFloat);
-  te::Tensor b =
-      te::computeSum({a, te::IntList({1}), false}, {M}, at::kFloat, at::kCPU);
+  te::Tensor b = te::computeSum(
+      {a, te::IntList({1}), false}, {M}, {1}, at::kFloat, at::kCPU);
te::LoopNest nest({b});

auto sch = state.range(2);
1 change: 1 addition & 0 deletions test/cpp/tensorexpr/test_external_calls.cpp
@@ -944,6 +944,7 @@ TEST(ExternalCall, JitCustomFusionOp) {
[external_func_name](
const std::vector<torch::jit::tensorexpr::ArgValue>& inputs,
const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape,
const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides,
const c10::optional<torch::jit::tensorexpr::ScalarType>& output_type,
at::Device device) {
auto output_dtype = Dtype(*output_type);
2 changes: 2 additions & 0 deletions test/cpp/tensorexpr/test_kernel.cpp
@@ -1598,12 +1598,14 @@ TEST_F(Kernel, CodegenInspection) {
Tensor lowerNanToNum(
const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const std::vector<ExprHandle>& outputStrides,
const c10::optional<ScalarType>& outputType,
at::Device device) {
auto input_buf = c10::get<BufHandle>(inputs[0]);
auto e = Compute(
"custom_nan_to_num",
outputShape,
outputStrides,
[&](const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
auto load = input_buf.load(indices);
37 changes: 36 additions & 1 deletion test/cpp/tensorexpr/test_ops.cpp
@@ -1,5 +1,6 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>
#include <torch/torch.h>
@@ -29,7 +30,10 @@ TEST(Ops, Sum) {
const auto& outShape = outputShapes[idx];

BufHandle a("a", {M, N}, kFloat);
-    Tensor b = computeSum({a, dims, false}, outShape, c10::kFloat, at::kCPU);
+    std::vector<ExprHandle> outStrides =
+        c10::fmap<ExprHandle>(make_contiguous_strides(outShape));
+    Tensor b = computeSum(
+        {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
auto cg = compile({a}, {b});

auto at = at::arange(M * N, at::kFloat).view({M, N});
@@ -41,3 +45,34 @@
ASSERT_TRUE(at::allclose(bt, ref));
}
}

TEST(Ops, ChannelsLastSum) {
constexpr int A = 2;
constexpr int B = 3;
constexpr int C = 4;
constexpr int D = 5;
constexpr int E = 6;
std::vector<IntList> testDims = {{0}, {1}, {0, 1}};

std::vector<std::vector<ExprHandle>> outputShapes = {
{B, C, D, E}, {A, C, D, E}, {C, D, E}};
for (unsigned idx = 0; idx < testDims.size(); idx++) {
const auto& dims = testDims[idx];
const auto& outShape = outputShapes[idx];

BufHandle a("a", {A, B, C, D, E}, kFloat);
std::vector<ExprHandle> outStrides =
c10::fmap<ExprHandle>(make_channels_last_strides(outShape));
Tensor b = computeSum(
{a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
auto cg = compile({a}, {b});

auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E});
auto ref = at::sum(at, dims);
auto bt = at::empty_like(ref);

cg->call({at.data_ptr<float>(), bt.data_ptr<float>()});

ASSERT_TRUE(at::allclose(bt, ref));
}
}
2 changes: 1 addition & 1 deletion test/test_tensorexpr_pybind.py
@@ -348,7 +348,7 @@ def f(a):
"""
graph = torch._C.parse_ir(graph_str)

-    def my_custom_lowering(inputs, out_shape, out_type, device):
+    def my_custom_lowering(inputs, out_shape, out_stride, out_type, device):
def compute(idxs):
load = inputs[0].as_buf().load(idxs)
return te.ifThenElse(
5 changes: 5 additions & 0 deletions torch/csrc/jit/tensorexpr/expr.cpp
@@ -477,6 +477,11 @@ bool Buf::is_contiguous(at::MemoryFormat memory_format) const {
return false;
dim_order = {1, 4, 3, 2, 0};
} else {
if (dims_.empty()) {
// Scalar tensor
TORCH_CHECK(strides_.empty());
return true; // Align with the isContiguous logic in the kernel.cpp
}
for (size_t i = 0; i < ndims; i++) {
dim_order[i] = ndims - i - 1; // Reverse
}
5 changes: 2 additions & 3 deletions torch/csrc/jit/tensorexpr/expr.h
@@ -186,9 +186,9 @@ class TORCH_API Var : public ExprNode<Var> {
std::string name_hint_;
};

-std::vector<ExprPtr> make_contiguous_strides(
+TORCH_API std::vector<ExprPtr> make_contiguous_strides(
     const std::vector<ExprHandle>& dims);
-std::vector<ExprPtr> make_channels_last_strides(
+TORCH_API std::vector<ExprPtr> make_channels_last_strides(
     const std::vector<ExprHandle>& dims);

class TORCH_API Buf : public ExprNode<Buf> {
@@ -324,7 +324,6 @@ class TORCH_API Buf : public ExprNode<Buf> {
bool is_cont_with(int cur_dim, int adjacent_dim) const;
bool is_stride_one(int cur_dim) const;

- private:
VarPtr base_handle_;
std::vector<ExprPtr> dims_;
std::vector<ExprPtr> strides_;
93 changes: 79 additions & 14 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -208,7 +208,9 @@ std::vector<int64_t> _pair_int(IValue v) {
}
}

-static bool isContiguous(const torch::jit::Value* v) {
+static bool isContiguous(
+    const torch::jit::Value* v,
+    at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) {
auto const& tt = v->type()->cast<TensorType>();
if (!tt) {
return false;
@@ -221,6 +223,14 @@ static bool isContiguous(const torch::jit::Value* v) {
if (!sizes || !strides) {
return false;
}

// Check dimension size first
int ndims = (*sizes).size();
if ((memory_format == at::MemoryFormat::ChannelsLast && ndims != 4) ||
(memory_format == at::MemoryFormat::ChannelsLast3d && ndims != 5)) {
return false;
}

return *strides == TensorType::contiguousStridesOf(*sizes);
}

@@ -475,8 +485,38 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
hasRandom_ = true;
}

// Check if the tensor is a contiguous tensor
bool is_contiguous = false;
// Check if the tensor is a channels-last contiguous tensor
bool is_channels_last_contiguous = false;
for (auto input : inputs) {
if (input->type()->kind() != TypeKind::TensorType)
continue;

TORCH_CHECK(bufs_.count(input) > 0);
auto buf_ = bufs_.at(input);

auto _is_contiguous = buf_->is_contiguous();
if (_is_contiguous) {
is_contiguous |= _is_contiguous;
} else {
is_channels_last_contiguous |=
(buf_->is_contiguous(at::MemoryFormat::ChannelsLast) ||
buf_->is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
buf_->is_channels_last_1d_contiguous());
}
}

auto outputType = findDtypeForValue(v);
std::vector<ExprHandle> outputShape = sizesForValue(v);
std::vector<ExprHandle> outputStrides;
if (is_channels_last_contiguous && (!is_contiguous)) {
outputStrides =
c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
} else {
// Default
outputStrides = c10::fmap<ExprHandle>(make_contiguous_strides(outputShape));
}

std::vector<ArgValue> argInputs;
if (op == prim::ConstantChunk) {
Expand Down Expand Up @@ -521,12 +561,14 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
}

if (NNCLoweringFunction custom_lowering = getCustomLoweringFor(op)) {
-    return custom_lowering(argInputs, outputShape, outputType, device_);
+    return custom_lowering(
+        argInputs, outputShape, outputStrides, outputType, device_);
}
if (v->node()->maybeSchema()) {
if (NNCLoweringFunction lowering =
getStandardLoweringFor(c10::toString(v->node()->schema()))) {
-      return lowering(argInputs, outputShape, outputType, device_);
+      return lowering(
+          argInputs, outputShape, outputStrides, outputType, device_);
}
}
std::string msg = std::string("Unhandled node kind (in computeValue): ") +
@@ -995,28 +1037,53 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
auto const& outputs = input->owningGraph()->outputs();
std::unordered_set<const Value*> outputs_set(outputs.begin(), outputs.end());

auto is_concrete_cont = [](const torch::jit::Value* input) {
if (input->isCompleteTensor()) {
return isContiguous(input);
} else {
return false;
}
};

auto is_symbolic_cont = [](std::vector<torch::jit::StrideInput> desc) {
if (desc.size() == 1) {
return desc[0] == torch::jit::StrideInput::TENSOR_CONT;
} else {
return false;
}
};

Tensor result(nullptr, nullptr);
switch (t->kind()) {
case TypeKind::TensorType: {
auto tt = input->type()->cast<TensorType>();
-      bool contiguous_concrete_tensor =
-          (input->isCompleteTensor() && isContiguous(input));
-      bool contiguous_sym_tensor = false;
+      bool contiguous_concrete_tensor = is_concrete_cont(input);
+      bool contiguous_symbolic_tensor = false;
       if (has_symbolic_shapes_) {
         auto desc = getSymbolicInputStrideDesc(input);
-        contiguous_sym_tensor =
-            desc.size() == 1 && desc[0] == torch::jit::StrideInput::TENSOR_CONT;
+        contiguous_symbolic_tensor = is_symbolic_cont(desc);
}

// Get input size and strides
auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
auto inputTensorStrides = getInputStrides(input, size_handles);

// We don't need to copy the input if:
// 1) it is not an output AND
// 2) it is contiguous
-      bool contiguous = contiguous_concrete_tensor || contiguous_sym_tensor;
+      bool contiguous =
+          contiguous_concrete_tensor || contiguous_symbolic_tensor;
if (!outputs_set.count(input) && contiguous) {
BufHandle inBuffer(
"t" + input_name_map_[input],
sizesFromSymbolicShape(tt->symbolic_sizes()),
inputTensorStrides,
ToDtype(static_cast<ScalarType>(*tt->scalarType())));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
inBuffer.node()->is_contiguous() ||
inBuffer.node()->is_channels_last_1d_contiguous() ||
inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast) ||
inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast3d));
bufs_.emplace(input, inBuffer.node());
bufferArgs_.emplace_back(inBuffer);
break;
@@ -1025,8 +1092,6 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
// if the input isn't contiguous or is an output,
// write strided input into contiguous buffer that is
// then used in all further compute
-      auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
-      auto inputTensorStrides = getInputStrides(input, size_handles);
ExprHandle flat_size = 1;
for (size_t i = 0; i < size_handles.size(); ++i) {
auto size = size_handles[i];
@@ -1168,11 +1233,11 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
"Ouput tensor has no corresponding bufs in the fuser."));
BufPtr buf = bufs_.at(v);
   // output is contiguous, no work to do
-  if (tensorOutputStrideDesc_[v->offset()] ==
-      torch::jit::StrideInput::TENSOR_CONT) {
+  auto stride_desc = tensorOutputStrideDesc_[v->offset()];
+  if (stride_desc == torch::jit::StrideInput::TENSOR_CONT) {
     return Tensor(buf, nullptr);
     ;
   }

TORCH_INTERNAL_ASSERT(
tensorOutputStrideDesc_[v->offset()] ==
torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
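Taken together, the computeValue() changes above implement a simple layout-propagation rule: the output buffer gets channels-last strides only when at least one input is channels-last contiguous and no input is plainly contiguous; otherwise it falls back to contiguous strides. A condensed restatement of that decision (a hypothetical free function, not the actual member code):

// Mirrors the is_contiguous / is_channels_last_contiguous flags that
// computeValue() accumulates over a node's tensor inputs.
bool useChannelsLastOutput(bool any_contiguous_input,
                           bool any_channels_last_input) {
  // Channels-last wins only when no input forces a contiguous layout;
  // the output strides then come from make_channels_last_strides(),
  // otherwise from make_contiguous_strides().
  return any_channels_last_input && !any_contiguous_input;
}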