diff --git a/cunumeric/config.py b/cunumeric/config.py index 101b677cc..33753eaa1 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -99,6 +99,7 @@ class CuNumericOpCode(IntEnum): SCALAR_UNARY_RED = _cunumeric.CUNUMERIC_SCALAR_UNARY_RED TILE = _cunumeric.CUNUMERIC_TILE TRANSPOSE = _cunumeric.CUNUMERIC_TRANSPOSE + TRILU = _cunumeric.CUNUMERIC_TRILU UNARY_OP = _cunumeric.CUNUMERIC_UNARY_OP UNARY_RED = _cunumeric.CUNUMERIC_UNARY_RED WHERE = _cunumeric.CUNUMERIC_WHERE diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index fd1ad2623..0fb2230bb 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -1277,6 +1277,24 @@ def transpose(self, rhs, axes, stacklevel=0, callsite=None): assert lhs_array.ndim == len(axes) lhs_array.base = rhs_array.base.transpose(axes) + @profile + @auto_convert([1]) + @shadow_debug("trilu", [1]) + def trilu(self, rhs, k, lower, stacklevel=0, callsite=None): + lhs = self.base + rhs = rhs._broadcast(lhs.shape) + + task = self.context.create_task(CuNumericOpCode.TRILU) + + task.add_output(lhs) + task.add_input(rhs) + task.add_scalar_arg(lower, bool) + task.add_scalar_arg(k, ty.int32) + + task.add_alignment(lhs, rhs) + + task.execute() + @profile @auto_convert([1]) @shadow_debug("flip", [1]) @@ -1571,10 +1589,10 @@ def binary_reduction( # Populate the Legate launcher if op == BinaryOpCode.NOT_EQUAL: redop = ReductionOp.ADD - self.fill(np.array(False)) + self.fill(np.array(False), stacklevel=stacklevel + 1) else: redop = ReductionOp.MUL - self.fill(np.array(True)) + self.fill(np.array(True), stacklevel=stacklevel + 1) task = self.context.create_task(CuNumericOpCode.BINARY_RED) task.add_reduction(lhs, redop) task.add_input(rhs1) diff --git a/cunumeric/eager.py b/cunumeric/eager.py index a69093bb4..30b6ed39b 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -1049,3 +1049,17 @@ def where(self, rhs1, rhs2, rhs3, stacklevel): else: self.array[:] = np.where(rhs1.array, rhs2.array, rhs3.array) self.runtime.profile_callsite(stacklevel + 1, False) + + def trilu(self, rhs, k, lower, stacklevel): + if self.shadow: + rhs = self.runtime.to_eager_array(rhs, stacklevel=stacklevel + 1) + elif self.deferred is None: + self.check_eager_args(stacklevel + 1, rhs) + if self.deferred is not None: + self.deferred.trilu(rhs, k, lower, stacklevel=stacklevel + 1) + else: + if lower: + self.array[:] = np.tril(rhs.array, k) + else: + self.array[:] = np.triu(rhs.array, k) + self.runtime.profile_callsite(stacklevel + 1, False) diff --git a/cunumeric/module.py b/cunumeric/module.py index 939b3666c..cce1b3acb 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -430,6 +430,28 @@ def copy(a): return result +@copy_docstring(np.tril) +def tril(m, k=0): + return trilu(m, k, True) + + +@copy_docstring(np.triu) +def triu(m, k=0): + return trilu(m, k, False) + + +def trilu(m, k, lower, stacklevel=2): + array = ndarray.convert_to_cunumeric_ndarray(m) + if array.ndim < 1: + raise TypeError("Array must be at least 1-D") + shape = m.shape if m.ndim >= 2 else m.shape * 2 + result = ndarray( + shape, dtype=array.dtype, stacklevel=stacklevel + 1, inputs=(array,) + ) + result._thunk.trilu(array._thunk, k, lower, stacklevel=2) + return result + + # ### ARRAY MANIPULATION ROUTINES # Changing array shape diff --git a/src/cunumeric.mk b/src/cunumeric.mk index d42a6ae6f..e4d616776 100644 --- a/src/cunumeric.mk +++ b/src/cunumeric.mk @@ -33,6 +33,7 @@ GEN_CPU_SRC += cunumeric/ternary/where.cc \ cunumeric/matrix/dot.cc \ cunumeric/matrix/tile.cc \ cunumeric/matrix/transpose.cc \ + cunumeric/matrix/trilu.cc \ cunumeric/matrix/util.cc \ cunumeric/random/rand.cc \ cunumeric/search/nonzero.cc \ @@ -59,6 +60,7 @@ GEN_CPU_SRC += cunumeric/ternary/where_omp.cc \ cunumeric/matrix/dot_omp.cc \ cunumeric/matrix/tile_omp.cc \ cunumeric/matrix/transpose_omp.cc \ + cunumeric/matrix/trilu_omp.cc \ cunumeric/matrix/util_omp.cc \ cunumeric/random/rand_omp.cc \ cunumeric/search/nonzero_omp.cc \ @@ -93,6 +95,7 @@ GEN_GPU_SRC += cunumeric/ternary/where.cu \ cunumeric/matrix/dot.cu \ cunumeric/matrix/tile.cu \ cunumeric/matrix/transpose.cu \ + cunumeric/matrix/trilu.cu \ cunumeric/random/rand.cu \ cunumeric/search/nonzero.cu \ cunumeric/stat/bincount.cu \ diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h index c0d111d57..41cc6baee 100644 --- a/src/cunumeric/cunumeric_c.h +++ b/src/cunumeric/cunumeric_c.h @@ -21,29 +21,31 @@ // Match these to CuNumericOpCode in cunumeric/config.py enum CuNumericOpCode { - CUNUMERIC_ARANGE = 1, - CUNUMERIC_BINARY_OP = 2, - CUNUMERIC_BINARY_RED = 3, - CUNUMERIC_BINCOUNT = 4, - CUNUMERIC_CONVERT = 5, - CUNUMERIC_CONVOLVE = 6, - CUNUMERIC_DIAG = 7, - CUNUMERIC_DOT = 8, - CUNUMERIC_EYE = 9, - CUNUMERIC_FILL = 10, - CUNUMERIC_FLIP = 11, - CUNUMERIC_MATMUL = 12, - CUNUMERIC_MATVECMUL = 13, - CUNUMERIC_NONZERO = 14, - CUNUMERIC_RAND = 15, - CUNUMERIC_READ = 16, - CUNUMERIC_SCALAR_UNARY_RED = 17, - CUNUMERIC_TILE = 18, - CUNUMERIC_TRANSPOSE = 19, - CUNUMERIC_UNARY_OP = 20, - CUNUMERIC_UNARY_RED = 21, - CUNUMERIC_WHERE = 22, - CUNUMERIC_WRITE = 23, + _CUNUMERIC_OP_CODE_BASE = 0, + CUNUMERIC_ARANGE, + CUNUMERIC_BINARY_OP, + CUNUMERIC_BINARY_RED, + CUNUMERIC_BINCOUNT, + CUNUMERIC_CONVERT, + CUNUMERIC_CONVOLVE, + CUNUMERIC_DIAG, + CUNUMERIC_DOT, + CUNUMERIC_EYE, + CUNUMERIC_FILL, + CUNUMERIC_FLIP, + CUNUMERIC_MATMUL, + CUNUMERIC_MATVECMUL, + CUNUMERIC_NONZERO, + CUNUMERIC_RAND, + CUNUMERIC_READ, + CUNUMERIC_SCALAR_UNARY_RED, + CUNUMERIC_TILE, + CUNUMERIC_TRANSPOSE, + CUNUMERIC_TRILU, + CUNUMERIC_UNARY_OP, + CUNUMERIC_UNARY_RED, + CUNUMERIC_WHERE, + CUNUMERIC_WRITE, }; // Match these to CuNumericRedopCode in cunumeric/config.py diff --git a/src/cunumeric/matrix/trilu.cc b/src/cunumeric/matrix/trilu.cc new file mode 100644 index 000000000..ccfe2c7bb --- /dev/null +++ b/src/cunumeric/matrix/trilu.cc @@ -0,0 +1,65 @@ +/* Copyright 2021 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/matrix/trilu.h" +#include "cunumeric/matrix/trilu_template.inl" + +namespace cunumeric { + +using namespace Legion; +using namespace legate; + +template +struct TriluImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& out, + const AccessorRO& in, + const Pitches& pitches, + const Point& lo, + size_t volume, + int32_t k) const + { + if (LOWER) + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, lo); + if (p[DIM - 2] + k >= p[DIM - 1]) + out[p] = in[p]; + else + out[p] = 0; + } + else + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, lo); + if (p[DIM - 2] + k <= p[DIM - 1]) + out[p] = in[p]; + else + out[p] = 0; + } + } +}; + +/*static*/ void TriluTask::cpu_variant(TaskContext& context) +{ + trilu_template(context); +} + +namespace // unnamed +{ +static void __attribute__((constructor)) register_tasks(void) { TriluTask::register_variants(); } +} // namespace + +} // namespace cunumeric diff --git a/src/cunumeric/matrix/trilu.cu b/src/cunumeric/matrix/trilu.cu new file mode 100644 index 000000000..111a5c7de --- /dev/null +++ b/src/cunumeric/matrix/trilu.cu @@ -0,0 +1,75 @@ +/* Copyright 2021 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/matrix/trilu.h" +#include "cunumeric/matrix/trilu_template.inl" + +#include "cunumeric/cuda_help.h" + +namespace cunumeric { + +using namespace Legion; +using namespace legate; + +template +static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + trilu_kernel(AccessorWO out, + AccessorRO in, + Pitches pitches, + Point lo, + size_t volume, + int32_t k) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= volume) return; + + if (LOWER) { + auto p = pitches.unflatten(idx, lo); + if (p[DIM - 2] + k >= p[DIM - 1]) + out[p] = in[p]; + else + out[p] = 0; + } else { + auto p = pitches.unflatten(idx, lo); + if (p[DIM - 2] + k <= p[DIM - 1]) + out[p] = in[p]; + else + out[p] = 0; + } +} + +template +struct TriluImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& out, + const AccessorRO& in, + const Pitches& pitches, + const Point& lo, + size_t volume, + int32_t k) const + { + const size_t blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + trilu_kernel<<>>(out, in, pitches, lo, volume, k); + } +}; + +/*static*/ void TriluTask::gpu_variant(TaskContext& context) +{ + trilu_template(context); +} + +} // namespace cunumeric diff --git a/src/cunumeric/matrix/trilu.h b/src/cunumeric/matrix/trilu.h new file mode 100644 index 000000000..7e87656b5 --- /dev/null +++ b/src/cunumeric/matrix/trilu.h @@ -0,0 +1,44 @@ +/* Copyright 2021 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#pragma once + +#include "cunumeric/cunumeric.h" + +namespace cunumeric { + +struct TriluArgs { + bool lower; + int32_t k; + const Array& output; + const Array& input; +}; + +class TriluTask : public CuNumericTask { + public: + static const int TASK_ID = CUNUMERIC_TRILU; + + public: + static void cpu_variant(legate::TaskContext& context); +#ifdef LEGATE_USE_OPENMP + static void omp_variant(legate::TaskContext& context); +#endif +#ifdef LEGATE_USE_CUDA + static void gpu_variant(legate::TaskContext& context); +#endif +}; + +} // namespace cunumeric diff --git a/src/cunumeric/matrix/trilu_omp.cc b/src/cunumeric/matrix/trilu_omp.cc new file mode 100644 index 000000000..4935131e7 --- /dev/null +++ b/src/cunumeric/matrix/trilu_omp.cc @@ -0,0 +1,62 @@ +/* Copyright 2021 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/matrix/trilu.h" +#include "cunumeric/matrix/trilu_template.inl" + +namespace cunumeric { + +using namespace Legion; +using namespace legate; + +template +struct TriluImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& out, + const AccessorRO& in, + const Pitches& pitches, + const Point& lo, + size_t volume, + int32_t k) const + { + if (LOWER) +#pragma omp parallel for schedule(static) + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, lo); + if (p[DIM - 2] + k >= p[DIM - 1]) + out[p] = in[p]; + else + out[p] = 0; + } + else +#pragma omp parallel for schedule(static) + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, lo); + if (p[DIM - 2] + k <= p[DIM - 1]) + out[p] = in[p]; + else + out[p] = 0; + } + } +}; + +/*static*/ void TriluTask::omp_variant(TaskContext& context) +{ + trilu_template(context); +} + +} // namespace cunumeric diff --git a/src/cunumeric/matrix/trilu_template.inl b/src/cunumeric/matrix/trilu_template.inl new file mode 100644 index 000000000..989a43a57 --- /dev/null +++ b/src/cunumeric/matrix/trilu_template.inl @@ -0,0 +1,68 @@ +/* Copyright 2021 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/pitches.h" + +namespace cunumeric { + +using namespace Legion; +using namespace legate; + +template +struct TriluImplBody; + +template +struct TriluImpl { + template = 2)>* = nullptr> + void operator()(TriluArgs& args) const + { + using VAL = legate_type_of; + + auto shape = args.output.shape(); + + Pitches pitches; + size_t volume = pitches.flatten(shape); + + if (volume == 0) return; + + auto out = args.output.write_accessor(shape); + auto in = args.input.read_accessor(shape); + if (args.lower) + TriluImplBody()(out, in, pitches, shape.lo, volume, args.k); + else + TriluImplBody()(out, in, pitches, shape.lo, volume, args.k); + } + + template * = nullptr> + void operator()(TriluArgs& args) const + { + assert(false); + } +}; + +template +static void trilu_template(TaskContext& context) +{ + auto& scalars = context.scalars(); + auto lower = scalars[0].value(); + auto k = scalars[1].value(); + auto& input = context.inputs()[0]; + auto& output = context.outputs()[0]; + TriluArgs args{lower, k, output, input}; + double_dispatch(args.output.dim(), args.output.code(), TriluImpl{}, args); +} + +} // namespace cunumeric diff --git a/tests/trilu.py b/tests/trilu.py new file mode 100644 index 000000000..a938c12b7 --- /dev/null +++ b/tests/trilu.py @@ -0,0 +1,56 @@ +# Copyright 2021 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np + +import cunumeric as num + + +def test(): + for f in ["tril", "triu"]: + num_f = getattr(num, f) + np_f = getattr(np, f) + for k in [0, -1, 1, -2, 2]: + print(f"{f}(k={k})") + a = num.array( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + [17, 18, 19, 20], + ] + ) + an = np.array( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + [17, 18, 19, 20], + ] + ) + + b = num_f(a, k=k) + bn = np_f(an, k=k) + assert num.array_equal(b, bn) + + b = num_f(a[0, :], k=k) + bn = np_f(an[0, :], k=k) + assert num.array_equal(b, bn) + + +if __name__ == "__main__": + test()