Skip to content

Commit

Permalink
Merge pull request #144 from magnatelee/triangular
Browse files Browse the repository at this point in the history
`numpy.tril` and `numpy.triu`
  • Loading branch information
magnatelee authored Dec 6, 2021
2 parents 4f11216 + de9b81c commit 5757fd9
Show file tree
Hide file tree
Showing 12 changed files with 455 additions and 25 deletions.
1 change: 1 addition & 0 deletions cunumeric/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class CuNumericOpCode(IntEnum):
SCALAR_UNARY_RED = _cunumeric.CUNUMERIC_SCALAR_UNARY_RED
TILE = _cunumeric.CUNUMERIC_TILE
TRANSPOSE = _cunumeric.CUNUMERIC_TRANSPOSE
TRILU = _cunumeric.CUNUMERIC_TRILU
UNARY_OP = _cunumeric.CUNUMERIC_UNARY_OP
UNARY_RED = _cunumeric.CUNUMERIC_UNARY_RED
WHERE = _cunumeric.CUNUMERIC_WHERE
Expand Down
22 changes: 20 additions & 2 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,24 @@ def transpose(self, rhs, axes, stacklevel=0, callsite=None):
assert lhs_array.ndim == len(axes)
lhs_array.base = rhs_array.base.transpose(axes)

@profile
@auto_convert([1])
@shadow_debug("trilu", [1])
def trilu(self, rhs, k, lower, stacklevel=0, callsite=None):
    """Launch the TRILU task that fills this array from ``rhs``.

    Parameters
    ----------
    rhs : source thunk; it is passed through ``_broadcast`` with this
        array's shape before being handed to the task.
    k : diagonal offset, sent to the task as a 32-bit integer scalar.
    lower : sent to the task as a boolean scalar (the callers in this
        package pass True for ``tril`` and False for ``triu``).
    stacklevel, callsite : not read by this body.
    """
    output = self.base
    # Give the source the same shape as the destination so that the two
    # stores can be aligned below.
    source = rhs._broadcast(output.shape)

    launcher = self.context.create_task(CuNumericOpCode.TRILU)
    launcher.add_output(output)
    launcher.add_input(source)

    # Scalar arguments are pushed in this order: `lower` first, then `k`.
    launcher.add_scalar_arg(lower, bool)
    launcher.add_scalar_arg(k, ty.int32)

    launcher.add_alignment(output, source)
    launcher.execute()

@profile
@auto_convert([1])
@shadow_debug("flip", [1])
Expand Down Expand Up @@ -1571,10 +1589,10 @@ def binary_reduction(
# Populate the Legate launcher
if op == BinaryOpCode.NOT_EQUAL:
redop = ReductionOp.ADD
self.fill(np.array(False))
self.fill(np.array(False), stacklevel=stacklevel + 1)
else:
redop = ReductionOp.MUL
self.fill(np.array(True))
self.fill(np.array(True), stacklevel=stacklevel + 1)
task = self.context.create_task(CuNumericOpCode.BINARY_RED)
task.add_reduction(lhs, redop)
task.add_input(rhs1)
Expand Down
14 changes: 14 additions & 0 deletions cunumeric/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,3 +1049,17 @@ def where(self, rhs1, rhs2, rhs3, stacklevel):
else:
self.array[:] = np.where(rhs1.array, rhs2.array, rhs3.array)
self.runtime.profile_callsite(stacklevel + 1, False)

def trilu(self, rhs, k, lower, stacklevel):
    """Write the lower or upper triangle of ``rhs`` into this array.

    When a deferred counterpart exists the call is forwarded to it;
    otherwise the result is computed with NumPy (``np.tril`` when
    ``lower`` is truthy, ``np.triu`` otherwise) and stored in place in
    ``self.array``, after which the call site is profiled.
    """
    if self.shadow:
        rhs = self.runtime.to_eager_array(rhs, stacklevel=stacklevel + 1)
    elif self.deferred is None:
        self.check_eager_args(stacklevel + 1, rhs)

    # Deferred path: hand the whole operation over and stop here.
    if self.deferred is not None:
        self.deferred.trilu(rhs, k, lower, stacklevel=stacklevel + 1)
        return

    # Eager path: pick the NumPy routine once, then assign in place.
    triangle = np.tril if lower else np.triu
    self.array[:] = triangle(rhs.array, k)
    self.runtime.profile_callsite(stacklevel + 1, False)
22 changes: 22 additions & 0 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,28 @@ def copy(a):
return result


@copy_docstring(np.tril)
def tril(m, k=0):
    # Lower-triangle front end: all the work happens in the shared `trilu`
    # helper, selected here with lower=True.
    return trilu(m, k, lower=True)


@copy_docstring(np.triu)
def triu(m, k=0):
    # Upper-triangle front end: all the work happens in the shared `trilu`
    # helper, selected here with lower=False.
    return trilu(m, k, lower=False)


def trilu(m, k, lower, stacklevel=2):
    """Shared implementation behind ``tril`` and ``triu``.

    Parameters
    ----------
    m : array_like
        Input; converted with ``ndarray.convert_to_cunumeric_ndarray``.
    k : int
        Diagonal offset forwarded to the thunk.
    lower : bool
        Forwarded to the thunk; ``tril`` passes True, ``triu`` passes False.
    stacklevel : int, optional
        Call-stack depth used for diagnostics; incremented by one for
        every callee invoked from here.

    Returns
    -------
    ndarray
        A new array with the dtype of the converted input. Inputs with two
        or more dimensions keep their shape; a 1-D input of length ``n``
        produces an ``(n, n)`` result.

    Raises
    ------
    TypeError
        If the converted input has fewer than one dimension.
    """
    array = ndarray.convert_to_cunumeric_ndarray(m)
    if array.ndim < 1:
        raise TypeError("Array must be at least 1-D")
    # Query the converted array rather than the raw argument: `m` may be a
    # list, tuple or other array-like that has no `.shape` / `.ndim`.
    # For a 1-D input, repeating the shape tuple turns (n,) into (n, n).
    shape = array.shape if array.ndim >= 2 else array.shape * 2
    result = ndarray(
        shape, dtype=array.dtype, stacklevel=stacklevel + 1, inputs=(array,)
    )
    # Propagate the caller-supplied stack level instead of a fixed value so
    # this call agrees with the `ndarray(...)` construction above.
    result._thunk.trilu(array._thunk, k, lower, stacklevel=stacklevel + 1)
    return result


# ### ARRAY MANIPULATION ROUTINES

# Changing array shape
Expand Down
3 changes: 3 additions & 0 deletions src/cunumeric.mk
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ GEN_CPU_SRC += cunumeric/ternary/where.cc \
cunumeric/matrix/dot.cc \
cunumeric/matrix/tile.cc \
cunumeric/matrix/transpose.cc \
cunumeric/matrix/trilu.cc \
cunumeric/matrix/util.cc \
cunumeric/random/rand.cc \
cunumeric/search/nonzero.cc \
Expand All @@ -59,6 +60,7 @@ GEN_CPU_SRC += cunumeric/ternary/where_omp.cc \
cunumeric/matrix/dot_omp.cc \
cunumeric/matrix/tile_omp.cc \
cunumeric/matrix/transpose_omp.cc \
cunumeric/matrix/trilu_omp.cc \
cunumeric/matrix/util_omp.cc \
cunumeric/random/rand_omp.cc \
cunumeric/search/nonzero_omp.cc \
Expand Down Expand Up @@ -93,6 +95,7 @@ GEN_GPU_SRC += cunumeric/ternary/where.cu \
cunumeric/matrix/dot.cu \
cunumeric/matrix/tile.cu \
cunumeric/matrix/transpose.cu \
cunumeric/matrix/trilu.cu \
cunumeric/random/rand.cu \
cunumeric/search/nonzero.cu \
cunumeric/stat/bincount.cu \
Expand Down
48 changes: 25 additions & 23 deletions src/cunumeric/cunumeric_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,31 @@

// Match these to CuNumericOpCode in cunumeric/config.py
enum CuNumericOpCode {
CUNUMERIC_ARANGE = 1,
CUNUMERIC_BINARY_OP = 2,
CUNUMERIC_BINARY_RED = 3,
CUNUMERIC_BINCOUNT = 4,
CUNUMERIC_CONVERT = 5,
CUNUMERIC_CONVOLVE = 6,
CUNUMERIC_DIAG = 7,
CUNUMERIC_DOT = 8,
CUNUMERIC_EYE = 9,
CUNUMERIC_FILL = 10,
CUNUMERIC_FLIP = 11,
CUNUMERIC_MATMUL = 12,
CUNUMERIC_MATVECMUL = 13,
CUNUMERIC_NONZERO = 14,
CUNUMERIC_RAND = 15,
CUNUMERIC_READ = 16,
CUNUMERIC_SCALAR_UNARY_RED = 17,
CUNUMERIC_TILE = 18,
CUNUMERIC_TRANSPOSE = 19,
CUNUMERIC_UNARY_OP = 20,
CUNUMERIC_UNARY_RED = 21,
CUNUMERIC_WHERE = 22,
CUNUMERIC_WRITE = 23,
_CUNUMERIC_OP_CODE_BASE = 0,
CUNUMERIC_ARANGE,
CUNUMERIC_BINARY_OP,
CUNUMERIC_BINARY_RED,
CUNUMERIC_BINCOUNT,
CUNUMERIC_CONVERT,
CUNUMERIC_CONVOLVE,
CUNUMERIC_DIAG,
CUNUMERIC_DOT,
CUNUMERIC_EYE,
CUNUMERIC_FILL,
CUNUMERIC_FLIP,
CUNUMERIC_MATMUL,
CUNUMERIC_MATVECMUL,
CUNUMERIC_NONZERO,
CUNUMERIC_RAND,
CUNUMERIC_READ,
CUNUMERIC_SCALAR_UNARY_RED,
CUNUMERIC_TILE,
CUNUMERIC_TRANSPOSE,
CUNUMERIC_TRILU,
CUNUMERIC_UNARY_OP,
CUNUMERIC_UNARY_RED,
CUNUMERIC_WHERE,
CUNUMERIC_WRITE,
};

// Match these to CuNumericRedopCode in cunumeric/config.py
Expand Down
65 changes: 65 additions & 0 deletions src/cunumeric/matrix/trilu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/* Copyright 2021 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include "cunumeric/matrix/trilu.h"
#include "cunumeric/matrix/trilu_template.inl"

namespace cunumeric {

using namespace Legion;
using namespace legate;

// CPU implementation of the trilu body.
//
// Walks the `volume` elements in linear order, converts each linear index to a
// point with `pitches.unflatten(idx, lo)`, and compares the last two
// coordinates of that point:
//   LOWER == true : copy `in` to `out` where p[DIM-2] + k >= p[DIM-1]
//   LOWER == false: copy `in` to `out` where p[DIM-2] + k <= p[DIM-1]
// Every other element of `out` is set to 0.
template <LegateTypeCode CODE, int32_t DIM, bool LOWER>
struct TriluImplBody<VariantKind::CPU, CODE, DIM, LOWER> {
  using VAL = legate_type_of<CODE>;

  void operator()(const AccessorWO<VAL, DIM>& out,
                  const AccessorRO<VAL, DIM>& in,
                  const Pitches<DIM - 1>& pitches,
                  const Point<DIM>& lo,
                  size_t volume,
                  int32_t k) const
  {
    for (size_t idx = 0; idx < volume; ++idx) {
      auto point             = pitches.unflatten(idx, lo);
      const auto shifted_row = point[DIM - 2] + k;
      const auto column      = point[DIM - 1];

      // LOWER is a template constant, so only one side of this selection is
      // ever taken for a given instantiation.
      const bool keep = LOWER ? (shifted_row >= column) : (shifted_row <= column);

      if (keep)
        out[point] = in[point];
      else
        out[point] = 0;
    }
  }
};

// CPU entry point of TriluTask: forwards the task context to trilu_template
// instantiated for VariantKind::CPU.
/*static*/ void TriluTask::cpu_variant(TaskContext& context)
{
trilu_template<VariantKind::CPU>(context);
}

namespace // unnamed
{
// Marked with the GCC/Clang `constructor` attribute, so it is invoked
// automatically when this object is loaded, before main() runs. It calls
// TriluTask::register_variants(), which is defined outside this file
// (presumably registering the task's variants with the runtime).
static void __attribute__((constructor)) register_tasks(void) { TriluTask::register_variants(); }
} // namespace

} // namespace cunumeric
75 changes: 75 additions & 0 deletions src/cunumeric/matrix/trilu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/* Copyright 2021 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include "cunumeric/matrix/trilu.h"
#include "cunumeric/matrix/trilu_template.inl"

#include "cunumeric/cuda_help.h"

namespace cunumeric {

using namespace Legion;
using namespace legate;

// CUDA kernel for trilu: one thread per element.
//
// Each thread derives a linear index from its block/thread coordinates,
// converts it to a point with `pitches.unflatten(idx, lo)`, and compares the
// last two coordinates of that point:
//   LOWER == true : copy `in` to `out` where p[DIM-2] + k >= p[DIM-1]
//   LOWER == false: copy `in` to `out` where p[DIM-2] + k <= p[DIM-1]
// Every other element of `out` is set to 0. Threads whose index is not below
// `volume` exit without touching memory.
template <typename VAL, int32_t DIM, bool LOWER>
static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
  trilu_kernel(AccessorWO<VAL, DIM> out,
               AccessorRO<VAL, DIM> in,
               Pitches<DIM - 1> pitches,
               Point<DIM> lo,
               size_t volume,
               int32_t k)
{
  // Widen before multiplying: blockIdx.x and blockDim.x are 32-bit unsigned
  // values, so their product would wrap around for launches that cover 2^32
  // or more elements if it were computed before the conversion to size_t.
  const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx >= volume) return;

  auto p = pitches.unflatten(idx, lo);

  // LOWER is a template constant, so only one side of this selection is ever
  // taken for a given instantiation.
  const bool keep = LOWER ? (p[DIM - 2] + k >= p[DIM - 1]) : (p[DIM - 2] + k <= p[DIM - 1]);

  if (keep)
    out[p] = in[p];
  else
    out[p] = 0;
}

// GPU implementation of the trilu body: launches trilu_kernel with one thread
// per element of the `volume`-element region.
template <LegateTypeCode CODE, int32_t DIM, bool LOWER>
struct TriluImplBody<VariantKind::GPU, CODE, DIM, LOWER> {
  using VAL = legate_type_of<CODE>;

  void operator()(const AccessorWO<VAL, DIM>& out,
                  const AccessorRO<VAL, DIM>& in,
                  const Pitches<DIM - 1>& pitches,
                  const Point<DIM>& lo,
                  size_t volume,
                  int32_t k) const
  {
    // Ceiling division: enough blocks of THREADS_PER_BLOCK threads to cover
    // every element; surplus threads in the last block exit inside the kernel.
    const size_t num_blocks = (volume + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;

    trilu_kernel<VAL, DIM, LOWER>
      <<<num_blocks, THREADS_PER_BLOCK>>>(out, in, pitches, lo, volume, k);
  }
};

// GPU entry point of TriluTask: forwards the task context to trilu_template
// instantiated for VariantKind::GPU.
/*static*/ void TriluTask::gpu_variant(TaskContext& context)
{
trilu_template<VariantKind::GPU>(context);
}

} // namespace cunumeric
44 changes: 44 additions & 0 deletions src/cunumeric/matrix/trilu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/* Copyright 2021 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#pragma once

#include "cunumeric/cunumeric.h"

namespace cunumeric {

// Argument bundle for the trilu task.
// The arrays are held by const reference, so this struct does not own them and
// the referenced Array objects must outlive it.
struct TriluArgs {
// Presumably true for the lower triangle (tril) and false for the upper
// triangle (triu) -- confirm against trilu_template.inl.
bool lower;
// Diagonal offset.
int32_t k;
// Destination array.
const Array& output;
// Source array.
const Array& input;
};

// Task class for the trilu operation, identified by the CUNUMERIC_TRILU
// op code. Only declarations live here; each variant is defined in its own
// translation unit.
class TriluTask : public CuNumericTask<TriluTask> {
public:
// Task identifier; taken from the CuNumericOpCode enum.
static const int TASK_ID = CUNUMERIC_TRILU;

public:
// Always available.
static void cpu_variant(legate::TaskContext& context);
// Declared only when the build defines LEGATE_USE_OPENMP.
#ifdef LEGATE_USE_OPENMP
static void omp_variant(legate::TaskContext& context);
#endif
// Declared only when the build defines LEGATE_USE_CUDA.
#ifdef LEGATE_USE_CUDA
static void gpu_variant(legate::TaskContext& context);
#endif
};

} // namespace cunumeric
Loading

0 comments on commit 5757fd9

Please sign in to comment.