Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MatrixRef to take a sub-matrix from an existing Matrix #934

Merged
merged 19 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions include/dlaf/matrix/distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ class Distribution {

Distribution& operator=(Distribution&& rhs) noexcept;

/// Constructs a sub-distribution based on the given distribution @p dist with
/// an @p offset and @p size.
///
/// @param[in] dist is the input distribution,
/// @param[in] offset is the offset of the new distribution relative to the input distribution,
/// @param[in] size is the size of the new distribution relative to the offset,
/// @pre origin.isValid()
/// @pre size.isValid()
/// @pre origin + size <= dist.size()
Distribution(Distribution dist, const GlobalElementIndex& offset, const GlobalElementSize& size);
msimberg marked this conversation as resolved.
Show resolved Hide resolved

bool operator==(const Distribution& rhs) const noexcept {
return size_ == rhs.size_ && local_size_ == rhs.local_size_ && tile_size_ == rhs.tile_size_ &&
block_size_ == rhs.block_size_ && global_nr_tiles_ == rhs.global_nr_tiles_ &&
Expand Down Expand Up @@ -490,6 +501,30 @@ class Distribution {
localElementDistanceFromLocalTile<Coord::Col>(begin.col(), end.col())};
}

/// Returns the tile index in the current distribution corresponding to a tile index @p sub_index in a
/// sub-distribution (defined by @p sub_offset and @p sub_distribution)
GlobalTileIndex globalTileIndexFromSubDistribution(const GlobalElementIndex& sub_offset,
const Distribution& sub_distribution,
const GlobalTileIndex& sub_index) const noexcept {
DLAF_ASSERT(sub_index.isIn(sub_distribution.nrTiles()), sub_index, sub_distribution.nrTiles());
DLAF_ASSERT(isCompatibleSubDistribution(sub_offset, sub_distribution), "");
const GlobalTileIndex tile_offset = globalTileIndex(sub_offset);
return tile_offset + common::sizeFromOrigin(sub_index);
}

/// Returns the element offset within the tile in the current distribution corresponding to a tile
/// index @p sub_index in a sub-distribution (defined by @p sub_offset and @p sub_distribution)
TileElementIndex tileElementOffsetFromSubDistribution(
msimberg marked this conversation as resolved.
Show resolved Hide resolved
const GlobalElementIndex& sub_offset, const Distribution& sub_distribution,
const GlobalTileIndex& sub_index) const noexcept {
DLAF_ASSERT(sub_index.isIn(sub_distribution.nrTiles()), sub_index, sub_distribution.nrTiles());
DLAF_ASSERT(isCompatibleSubDistribution(sub_offset, sub_distribution), "");
return {
sub_index.row() == 0 ? tileElementFromGlobalElement<Coord::Row>(sub_offset.row()) : 0,
sub_index.col() == 0 ? tileElementFromGlobalElement<Coord::Col>(sub_offset.col()) : 0,
};
}

private:
/// @pre block_size_, and tile_size_ are already set correctly.
template <Coord rc>
Expand Down Expand Up @@ -564,6 +599,25 @@ class Distribution {
/// @post offset_.row() < block_size_.rows() && offset_.col() < block_size_.cols()
void normalizeSourceRankAndOffset() noexcept;

/// Checks if another distribution is a compatible sub-distribution of the current distribution.
///
/// Compatible means that the block size, tile size, rank index, and grid size are equal.
/// Sub-distribution means that the source rank index of the sub-distribution is the rank index
/// of the tile at sub_offset in the current distribution. Additionally, the size and offset of
/// the sub-distribution must be within the size of the current distribution.
bool isCompatibleSubDistribution(const GlobalElementIndex& sub_offset,
const Distribution& sub_distribution) const noexcept {
const bool compatibleGrid = blockSize() == sub_distribution.blockSize() &&
baseTileSize() == sub_distribution.baseTileSize() &&
rankIndex() == sub_distribution.rankIndex() &&
commGridSize() == sub_distribution.commGridSize();
const bool compatibleSourceRankIndex =
rankGlobalTile(globalTileIndex(sub_offset)) == sub_distribution.sourceRankIndex();
const bool compatibleSize = sub_offset.row() + sub_distribution.size().rows() <= size().rows() &&
sub_offset.col() + sub_distribution.size().cols() <= size().cols();
return compatibleGrid && compatibleSourceRankIndex && compatibleSize;
}

/// Sets default values.
///
/// offset_ = {0, 0}
Expand Down
175 changes: 175 additions & 0 deletions include/dlaf/matrix/matrix_ref.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
//
// Distributed Linear Algebra with Future (DLAF)
//
// Copyright (c) 2018-2023, ETH Zurich
// All rights reserved.
//
// Please, refer to the LICENSE file in the root directory.
// SPDX-License-Identifier: BSD-3-Clause
//

#pragma once

/// @file

#include <dlaf/matrix/distribution.h>
#include <dlaf/matrix/matrix.h>
#include <dlaf/matrix/matrix_base.h>
#include <dlaf/matrix/tile.h>
#include <dlaf/types.h>

namespace dlaf::matrix {
/// A @c MatrixRef represents a sub-matrix of a @c Matrix.
///
/// The class has reference semantics, meaning accesses to a @c MatrixRef and
/// it's corresponding @c Matrix are interleaved if calls to read/readwrite are
/// interleaved. Access to a @c MatrixRef and its corresponding @c Matrix is not
/// thread-safe. A @c MatrixRef must outlive its corresponding @c Matrix.
template <class T, Device D>
class MatrixRef;

msimberg marked this conversation as resolved.
Show resolved Hide resolved
template <class T, Device D>
class MatrixRef<const T, D> : public internal::MatrixBase {
public:
static constexpr Device device = D;

using ElementType = T;
using TileType = Tile<ElementType, D>;
using ConstTileType = Tile<const ElementType, D>;
using TileDataType = internal::TileData<ElementType, D>;
using ReadOnlySenderType = ReadOnlyTileSender<T, D>;

/// Create a sub-matrix of @p mat with an @p offset and @p size.
///
/// @param[in] mat is the input matrix,
/// @param[in] offset is the offset of the new matrix relative to the input matrix,
/// @param[in] size is the size of the new matrix relative to the offset,
/// @pre origin.isValid()
/// @pre size.isValid()
/// @pre origin + size <= mat.size()
MatrixRef(Matrix<const T, D>& mat, const GlobalElementIndex& offset, const GlobalElementSize& size)
: internal::MatrixBase(Distribution(mat.distribution(), offset, size)), mat_const_(mat),
offset_(offset) {}

// TODO: default, copy, move construction?
// - default: no, don't want empty MatrixRef
// - copy: implementable, still refer to the original matrix
// - move: implement as copy, i.e. still refer to original matrix?
MatrixRef() = delete;
msimberg marked this conversation as resolved.
Show resolved Hide resolved

// TODO: Do we need access to the original matrix? e.g:
// Matrix& get() const noexcept { return mat_const_; }
msimberg marked this conversation as resolved.
Show resolved Hide resolved

/// Returns a read-only sender of the Tile with local index @p index.
///
/// @pre index.isIn(distribution().localNrTiles()).
ReadOnlySenderType read(const LocalTileIndex& index) noexcept {
// Note: this forwards to the overload with GlobalTileIndex which will
// handle taking a subtile if needed
return read(distribution().globalTileIndex(index));
}

/// Returns a read-only sender of the Tile with global index @p index.
///
/// @pre the global tile is stored in the current process,
/// @pre index.isIn(globalNrTiles()).
ReadOnlySenderType read(const GlobalTileIndex& index) {
DLAF_ASSERT(index.isIn(distribution().nrTiles()), index, distribution().nrTiles());

const auto parent_index(
mat_const_.distribution().globalTileIndexFromSubDistribution(offset_, distribution(), index));
auto tile_sender = mat_const_.read(parent_index);

const auto parent_dist = mat_const_.distribution();
const auto parent_tile_size = parent_dist.tileSize(parent_index);
const auto tile_size = tileSize(index);

// If the corresponding tile in the parent distribution is exactly the same
// size as the tile in the sub-distribution, we don't need to take a subtile
// and can return the tile sender directly. This avoids unnecessary wrapping.
if (parent_tile_size == tile_size) {
return tile_sender;
}

// Otherwise we have to extract a subtile from the tile in the parent
// distribution.
const auto ij_tile =
parent_dist.tileElementOffsetFromSubDistribution(offset_, distribution(), index);
return splitTile(std::move(tile_sender), SubTileSpec{ij_tile, tile_size});
}

private:
Matrix<const T, D>& mat_const_;

protected:
GlobalElementIndex offset_;
};

template <class T, Device D>
class MatrixRef : public MatrixRef<const T, D> {
public:
static constexpr Device device = D;

using ElementType = T;
using TileType = Tile<ElementType, D>;
using ConstTileType = Tile<const ElementType, D>;
using TileDataType = internal::TileData<ElementType, D>;
using ReadWriteSenderType = ReadWriteTileSender<T, D>;

/// Create a sub-matrix of @p mat with an @p offset and @p size.
///
/// @param[in] mat is the input matrix,
/// @param[in] offset is the offset of the new matrix relative to the input matrix,
/// @param[in] size is the size of the new matrix relative to the offset,
/// @pre origin.isValid()
/// @pre size.isValid()
/// @pre origin + size <= mat.size()
MatrixRef(Matrix<T, D>& mat, const GlobalElementIndex& offset, const GlobalElementSize& size)
: MatrixRef<const T, D>(mat, offset, size), mat_(mat) {}

// TODO: default, copy, move construction?
MatrixRef() = delete;

/// Returns a sender of the Tile with local index @p index.
///
/// @pre index.isIn(distribution().localNrTiles()).
ReadWriteSenderType readwrite(const LocalTileIndex& index) noexcept {
// Note: this forwards to the overload with GlobalTileIndex which will
// handle taking a subtile if needed
return readwrite(this->distribution().globalTileIndex(index));
}

/// Returns a sender of the Tile with global index @p index.
///
/// @pre the global tile is stored in the current process,
/// @pre index.isIn(globalNrTiles()).
ReadWriteSenderType readwrite(const GlobalTileIndex& index) {
DLAF_ASSERT(index.isIn(this->distribution().nrTiles()), index, this->distribution().nrTiles());

const auto parent_index(
mat_.distribution().globalTileIndexFromSubDistribution(offset_, this->distribution(), index));
auto tile_sender = mat_.readwrite(parent_index);

const auto parent_dist = mat_.distribution();
const auto parent_tile_size = parent_dist.tileSize(parent_index);
const auto tile_size = this->tileSize(index);

// If the corresponding tile in the parent distribution is exactly the same
// size as the tile in the sub-distribution, we don't need to take a subtile
// and can return the tile sender directly. This avoids unnecessary wrapping.
if (parent_tile_size == tile_size) {
return tile_sender;
}

// Otherwise we have to extract a subtile from the tile in the parent
// distribution.
const auto ij_tile =
parent_dist.tileElementOffsetFromSubDistribution(offset_, this->distribution(), index);
return splitTile(std::move(tile_sender), SubTileSpec{ij_tile, tile_size});
}

private:
Matrix<T, D>& mat_;
using MatrixRef<const T, D>::offset_;
};
}
msimberg marked this conversation as resolved.
Show resolved Hide resolved
14 changes: 14 additions & 0 deletions src/matrix/distribution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,20 @@ Distribution& Distribution::operator=(Distribution&& rhs) noexcept {
return *this;
}

Distribution::Distribution(Distribution rhs, const GlobalElementIndex& sub_offset,
const GlobalElementSize& size)
: Distribution(std::move(rhs)) {
DLAF_ASSERT(sub_offset.isValid(), sub_offset);
DLAF_ASSERT(size.isValid(), size);
DLAF_ASSERT(sub_offset.row() + size.rows() <= size_.rows(), sub_offset, size_);
DLAF_ASSERT(sub_offset.col() + size.cols() <= size_.cols(), sub_offset, size_);

offset_ = offset_ + sizeFromOrigin(sub_offset);
size_ = size;

computeGlobalAndLocalNrTilesAndLocalSize();
}

void Distribution::computeGlobalSizeForNonDistr() noexcept {
size_ = GlobalElementSize(local_size_.rows(), local_size_.cols());
}
Expand Down
8 changes: 8 additions & 0 deletions test/unit/matrix/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ DLAF_addTest(
MPIRANKS 6
)

DLAF_addTest(
test_matrix_ref
SOURCES test_matrix_ref.cpp
LIBRARIES dlaf.core
USE_MAIN MPIPIKA
MPIRANKS 6
)

DLAF_addTest(
test_panel
SOURCES test_panel.cpp
Expand Down
75 changes: 75 additions & 0 deletions test/unit/matrix/test_distribution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,3 +633,78 @@ TEST(DistributionTest, LocalElementDistanceFromGlobalTile) {
obj.localElementDistanceFromGlobalTile(test.global_tile_begin, test.global_tile_end));
}
}

struct ParametersSubDistribution {
// Distribution settings
GlobalElementSize size;
TileElementSize block_size;
comm::Index2D rank;
comm::Size2D grid_size;
comm::Index2D src_rank;
GlobalElementIndex offset;
// Sub-distribution settings
GlobalElementIndex sub_offset;
GlobalElementSize sub_size;
// Valid indices
GlobalElementIndex global_element;
GlobalTileIndex global_tile;
comm::Index2D rank_tile;
std::array<SizeType, 2> local_tile; // can be an invalid LocalTileIndex
};

const std::vector<ParametersSubDistribution> tests_sub_distribution = {
// {size, block_size, rank, grid_size, src_rank, offset, sub_offset, sub_size,
// global_element, global_tile, rank_tile, local_tile}
// Empty distribution
{{0, 0}, {2, 5}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
{{0, 0}, {2, 5}, {0, 0}, {1, 1}, {0, 0}, {4, 8}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
// Empty sub-distribution
{{3, 4}, {2, 2}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
{{3, 4}, {2, 2}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {2, 3}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
{{5, 9}, {3, 2}, {1, 1}, {2, 4}, {0, 2}, {1, 1}, {4, 5}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
// Sub-distribution == distribution
{{3, 4}, {2, 2}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {0, 0}, {3, 4}, {1, 3}, {0, 1}, {0, 0}, {0, 1}},
{{5, 9}, {3, 2}, {1, 1}, {2, 4}, {0, 2}, {1, 1}, {0, 0}, {5, 9}, {1, 3}, {0, 2}, {0, 0}, {-1, -1}},
// clang-format off
{{123, 59}, {32, 16}, {3, 3}, {5, 7}, {3, 1}, {1, 1}, {0, 0}, {123, 59}, {30, 30}, {0, 1}, {3, 2}, {0, -1}},
// clang-format on
// Other sub-distributions
{{3, 4}, {2, 2}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {1, 2}, {2, 1}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
{{3, 4}, {2, 2}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {1, 2}, {2, 1}, {1, 0}, {1, 0}, {0, 0}, {1, 0}},
{{5, 9}, {3, 2}, {1, 1}, {2, 4}, {0, 2}, {1, 1}, {3, 4}, {2, 3}, {0, 0}, {0, 0}, {1, 0}, {0, -1}},
{{5, 9}, {3, 2}, {1, 1}, {2, 4}, {0, 2}, {1, 1}, {3, 4}, {2, 3}, {1, 2}, {0, 1}, {1, 1}, {0, 0}},
// clang-format off
{{123, 59}, {32, 16}, {3, 3}, {5, 7}, {3, 1}, {1, 1}, {50, 17}, {40, 20}, {20, 10}, {1, 0}, {0, 2}, {-1, -1}},
// clang-format on
};

TEST(DistributionTest, SubDistribution) {
for (const auto& test : tests_sub_distribution) {
Distribution dist(test.size, test.block_size, test.grid_size, test.rank, test.src_rank, test.offset);
Distribution sub_dist(dist, test.sub_offset, test.sub_size);

EXPECT_EQ(sub_dist.size(), test.sub_size);

EXPECT_EQ(sub_dist.blockSize(), dist.blockSize());
EXPECT_EQ(sub_dist.baseTileSize(), dist.baseTileSize());
EXPECT_EQ(sub_dist.rankIndex(), dist.rankIndex());
EXPECT_EQ(sub_dist.commGridSize(), dist.commGridSize());

EXPECT_LE(sub_dist.localSize().rows(), dist.localSize().rows());
EXPECT_LE(sub_dist.localSize().cols(), dist.localSize().cols());
EXPECT_LE(sub_dist.localNrTiles().rows(), dist.localNrTiles().rows());
EXPECT_LE(sub_dist.localNrTiles().cols(), dist.localNrTiles().cols());
EXPECT_LE(sub_dist.nrTiles().rows(), dist.nrTiles().rows());
EXPECT_LE(sub_dist.nrTiles().cols(), dist.nrTiles().cols());

if (!test.sub_size.isEmpty()) {
EXPECT_EQ(sub_dist.globalTileIndex(test.global_element), test.global_tile);
EXPECT_EQ(sub_dist.rankGlobalTile(sub_dist.globalTileIndex(test.global_element)), test.rank_tile);

EXPECT_EQ(sub_dist.localTileFromGlobalElement<Coord::Row>(test.global_element.get<Coord::Row>()),
test.local_tile[0]);
EXPECT_EQ(sub_dist.localTileFromGlobalElement<Coord::Col>(test.global_element.get<Coord::Col>()),
test.local_tile[1]);
}
}
}
Loading