
Commit

Support arbitrary array layout order in AsyncWriteArray
This makes it easier to match the on-disk layout without the need for an
external IndexTransform to permute the dimensions.

PiperOrigin-RevId: 695974602
Change-Id: I03c64a7209ff430363c2f49a07d85079beebee50
jbms authored and copybara-github committed Nov 13, 2024
1 parent d779cd5 commit 5f24912
Showing 9 changed files with 160 additions and 145 deletions.
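As an illustrative sketch of what the change enables (not part of the commit; the chunk shape, fill value, and function name below are invented), the new `Spec` constructor accepts a layout order, so chunk buffers allocated through `Spec::AllocateArray` can match a column-major on-disk layout without an external `IndexTransform`:

#include "tensorstore/array.h"
#include "tensorstore/box.h"
#include "tensorstore/contiguous_layout.h"
#include "tensorstore/internal/async_write_array.h"

using ::tensorstore::Box;
using ::tensorstore::MakeArray;
using ::tensorstore::internal::AsyncWriteArray;

// Sketch: a Spec whose newly allocated chunk arrays use Fortran
// (column-major) order, i.e. the layout permutation {1, 0} for rank 2.
AsyncWriteArray::Spec MakeColumnMajorSpec() {
  // Hypothetical 2x3 chunk with an all-zero fill value.
  auto fill_value = MakeArray<int32_t>({{0, 0, 0}, {0, 0, 0}});
  // `fortran_order` is converted to the permutation {1, 0} and stored in
  // `layout_order_buffer`, so `Spec::AllocateArray` produces column-major
  // strides.
  return AsyncWriteArray::Spec(fill_value, Box<>({0, 0}, {2, 3}),
                               tensorstore::fortran_order);
}
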
5 changes: 0 additions & 5 deletions tensorstore/internal/BUILD
@@ -877,7 +877,6 @@ tensorstore_cc_library(
":nditerable",
":nditerable_buffer_management",
":nditerable_transformed_array",
":nditerable_util",
":unowned_to_shared",
"//tensorstore:array",
"//tensorstore:box",
@@ -888,7 +887,6 @@ tensorstore_cc_library(
"//tensorstore:rank",
"//tensorstore:strided_layout",
"//tensorstore/index_space:index_transform",
"//tensorstore/index_space:transformed_array",
"//tensorstore/util:byte_strided_pointer",
"//tensorstore/util:element_pointer",
"//tensorstore/util:iterate",
@@ -1018,13 +1016,10 @@ tensorstore_cc_library(
"//tensorstore:array",
"//tensorstore:box",
"//tensorstore:contiguous_layout",
"//tensorstore:data_type",
"//tensorstore:index",
"//tensorstore:rank",
"//tensorstore:strided_layout",
"//tensorstore/index_space:index_transform",
"//tensorstore/index_space:transformed_array",
"//tensorstore/util:element_pointer",
"//tensorstore/util:result",
"//tensorstore/util:span",
"//tensorstore/util:status",
18 changes: 10 additions & 8 deletions tensorstore/internal/async_write_array.cc
@@ -92,6 +92,12 @@ Result<NDIterable::Ptr> AsyncWriteArray::Spec::GetReadNDIterable(
arena);
}

SharedArray<void> AsyncWriteArray::Spec::AllocateArray(
span<const Index> shape) const {
return tensorstore::AllocateArray(shape, layout_order(), default_init,
this->dtype());
}

AsyncWriteArray::MaskedArray::MaskedArray(DimensionIndex rank) : mask(rank) {}

void AsyncWriteArray::MaskedArray::WriteFillValue(const Spec& spec,
@@ -211,7 +217,7 @@ size_t AsyncWriteArray::MaskedArray::EstimateSizeInBytes(
if (array.valid()) {
total += GetByteExtent(array);
}
if (mask.mask_array) {
if (mask.mask_array.valid()) {
const Index num_elements = ProductOfExtents(shape);
total += num_elements * sizeof(bool);
}
@@ -220,9 +226,7 @@ size_t AsyncWriteArray::MaskedArray::EstimateSizeInBytes(

void AsyncWriteArray::MaskedArray::EnsureWritable(const Spec& spec) {
assert(array.valid());
auto new_array =
tensorstore::AllocateArray(array.shape(), tensorstore::c_order,
tensorstore::default_init, spec.dtype());
auto new_array = spec.AllocateArray(array.shape());
CopyArray(array, new_array);
array = std::move(new_array);
array_capabilities = kMutableArray;
@@ -234,9 +238,7 @@ AsyncWriteArray::MaskedArray::GetWritableTransformedArray(
// TODO(jbms): Could avoid copies when the output range of `chunk_transform`
// is known to fully cover `domain`.
if (!array.valid()) {
this->array =
tensorstore::AllocateArray(domain.shape(), tensorstore::c_order,
tensorstore::default_init, spec.dtype());
this->array = spec.AllocateArray(domain.shape());
array_capabilities = kMutableArray;
if (IsFullyOverwritten(spec, domain)) {
// Previously, there was no data array allocated for the array but it
@@ -276,7 +278,7 @@ Result<NDIterable::Ptr> AsyncWriteArray::MaskedArray::BeginWrite(
void AsyncWriteArray::MaskedArray::EndWrite(
const Spec& spec, BoxView<> domain, IndexTransformView<> chunk_transform,
Arena* arena) {
WriteToMask(&mask, domain, chunk_transform, arena);
WriteToMask(&mask, domain, chunk_transform, spec.layout_order(), arena);
}

void AsyncWriteArray::MaskedArray::Clear() {
34 changes: 34 additions & 0 deletions tensorstore/internal/async_write_array.h
@@ -28,6 +28,7 @@
#include "absl/status/status.h"
#include "tensorstore/array.h"
#include "tensorstore/box.h"
#include "tensorstore/contiguous_layout.h"
#include "tensorstore/data_type.h"
#include "tensorstore/index.h"
#include "tensorstore/index_space/index_transform.h"
@@ -65,6 +66,18 @@ struct AsyncWriteArray {
read_generation(std::move(other.read_generation)) {}

struct Spec {
Spec() = default;

template <typename LayoutOrder = ContiguousLayoutOrder,
typename = std::enable_if_t<IsContiguousLayoutOrder<LayoutOrder>>>
explicit Spec(SharedOffsetArray<const void> overall_fill_value,
Box<> valid_data_bounds, LayoutOrder order = c_order)
: overall_fill_value(std::move(overall_fill_value)),
valid_data_bounds(std::move(valid_data_bounds)) {
ConvertToContiguousLayoutPermutation(
order, span(layout_order_buffer, this->rank()));
}

/// The overall fill value. Every individual chunk must be contained within
/// `overall_fill_value.domain()`.
SharedOffsetArray<const void> overall_fill_value;
@@ -76,6 +89,18 @@
/// preserved.
Box<> valid_data_bounds;

/// Buffer containing permutation specifying the storage order to use when
/// allocating a new array.
///
/// Note that this order is used when allocating a new array but does not
/// apply when the zero-copy `WriteArray` method is called.
///
/// Only the first `rank()` elements are meaningful.
///
/// For example, ``0, 1, 2`` denotes C order for rank 3, while ``2, 1, 0``
/// denotes F order.
DimensionIndex layout_order_buffer[kMaxRank];

/// If `true`, indicates that the array should be stored even if it equals
/// the fill value. By default (when set to `false`), when preparing a
/// writeback snapshot, if the value of the array is equal to the fill
@@ -102,6 +127,11 @@ struct AsyncWriteArray {
/// `domain`, translated to have a zero origin.
SharedArrayView<const void> GetFillValueForDomain(BoxView<> domain) const;

/// Storage order to use when allocating a new array.
ContiguousLayoutPermutation<> layout_order() const {
return ContiguousLayoutPermutation<>(span(layout_order_buffer, rank()));
}

/// Returns an `NDIterable` that may be used for reading the specified
/// `array`, using the specified `chunk_transform`.
///
Expand All @@ -120,6 +150,10 @@ struct AsyncWriteArray {
if (!valid) return 0;
return ProductOfExtents(shape) * dtype()->size;
}

/// Allocates an array of the specified `shape`, for `this->dtype()` and
/// `this->layout_order`.
SharedArray<void> AllocateArray(span<const Index> shape) const;
};

/// Return type of `GetArrayForWriteback`.
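The permutation convention documented above (``0, 1, 2`` is C order for rank 3, ``2, 1, 0`` is F order) lists dimensions from outermost to innermost. A hedged sketch of using it directly with the `AllocateArray` overload that accepts a `ContiguousLayoutPermutation` (the same overload `Spec::AllocateArray` and the masked_array.cc changes below rely on); the shape and permutation here are arbitrary examples:

#include "tensorstore/array.h"
#include "tensorstore/contiguous_layout.h"
#include "tensorstore/index.h"
#include "tensorstore/util/span.h"

void LayoutPermutationExample() {
  const tensorstore::Index shape[] = {2, 3, 4};
  // Permutation is listed from outermost (slowest varying) to innermost
  // (fastest varying): dimension 2 is outermost, dimension 1 innermost.
  const tensorstore::DimensionIndex order[] = {2, 0, 1};
  auto array = tensorstore::AllocateArray<int32_t>(
      tensorstore::span<const tensorstore::Index>(shape),
      tensorstore::ContiguousLayoutPermutation<>(order),
      tensorstore::value_init);
  // array.byte_strides() now assigns dimension 1 the smallest stride, so the
  // in-memory layout can be chosen to match the on-disk chunk layout.
}
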
2 changes: 1 addition & 1 deletion tensorstore/internal/async_write_array_test.cc
@@ -210,7 +210,7 @@ TEST(MaskedArrayTest, Basic) {
EXPECT_EQ(MakeArray<int32_t>({{9, 0, 0}, {0, 7, 8}}),
write_state.shared_array_view(spec));
EXPECT_EQ(MakeArray<bool>({{1, 0, 0}, {0, 1, 1}}),
tensorstore::Array(write_state.mask.mask_array.get(), {2, 3}));
write_state.mask.mask_array);
EXPECT_FALSE(write_state.IsUnmodified());
EXPECT_FALSE(write_state.IsFullyOverwritten(spec, domain));
// Both data array and mask array have been allocated.
133 changes: 61 additions & 72 deletions tensorstore/internal/masked_array.cc
@@ -27,15 +27,11 @@
#include "tensorstore/index.h"
#include "tensorstore/index_interval.h"
#include "tensorstore/index_space/index_transform.h"
#include "tensorstore/index_space/transformed_array.h"
#include "tensorstore/internal/arena.h"
#include "tensorstore/internal/elementwise_function.h"
#include "tensorstore/internal/integer_overflow.h"
#include "tensorstore/internal/memory.h"
#include "tensorstore/internal/nditerable.h"
#include "tensorstore/internal/nditerable_buffer_management.h"
#include "tensorstore/internal/nditerable_transformed_array.h"
#include "tensorstore/internal/nditerable_util.h"
#include "tensorstore/internal/unowned_to_shared.h"
#include "tensorstore/rank.h"
#include "tensorstore/strided_layout.h"
@@ -123,7 +119,7 @@ Index GetRelativeOffset(tensorstore::span<const Index> base,

void RemoveMaskArrayIfNotNeeded(MaskData* mask) {
if (mask->num_masked_elements == mask->region.num_elements()) {
mask->mask_array.reset();
mask->mask_array.element_pointer() = {};
}
}
} // namespace
@@ -132,84 +128,83 @@ MaskData::MaskData(DimensionIndex rank) : region(rank) {
region.Fill(IndexInterval::UncheckedSized(0, 0));
}

std::unique_ptr<bool[], FreeDeleter> CreateMaskArray(
BoxView<> box, BoxView<> mask_region,
tensorstore::span<const Index> byte_strides) {
std::unique_ptr<bool[], FreeDeleter> result(
static_cast<bool*>(std::calloc(box.num_elements(), sizeof(bool))));
ByteStridedPointer<bool> start = result.get();
start += GetRelativeOffset(box.origin(), mask_region.origin(), byte_strides);
SharedArray<bool> CreateMaskArray(BoxView<> box, BoxView<> mask_region,
ContiguousLayoutPermutation<> layout_order) {
auto array = AllocateArray<bool>(box.shape(), layout_order, value_init);
ByteStridedPointer<bool> start = array.data();
start += GetRelativeOffset(box.origin(), mask_region.origin(),
array.byte_strides());
internal::IterateOverArrays(
internal::SimpleElementwiseFunction<SetMask(bool), void*>{},
/*arg=*/nullptr,
/*constraints=*/skip_repeated_elements,
ArrayView<bool>(start.get(),
StridedLayoutView<>(mask_region.shape(), byte_strides)));
return result;
ArrayView<bool>(start.get(), StridedLayoutView<>(mask_region.shape(),
array.byte_strides())));
return array;
}

void CreateMaskArrayFromRegion(BoxView<> box, MaskData* mask,
tensorstore::span<const Index> byte_strides) {
ContiguousLayoutPermutation<> layout_order) {
assert(mask->num_masked_elements == mask->region.num_elements());
mask->mask_array = CreateMaskArray(box, mask->region, byte_strides);
assert(layout_order.size() == mask->region.rank());
mask->mask_array = CreateMaskArray(box, mask->region, layout_order);
}

void UnionMasks(BoxView<> box, MaskData* mask_a, MaskData* mask_b) {
void UnionMasks(BoxView<> box, MaskData* mask_a, MaskData* mask_b,
ContiguousLayoutPermutation<> layout_order) {
assert(mask_a != mask_b); // May work but not supported.
if (mask_a->num_masked_elements == 0) {
std::swap(*mask_a, *mask_b);
return;
} else if (mask_b->num_masked_elements == 0) {
return;
}
const DimensionIndex rank = box.rank();
assert(mask_a->region.rank() == rank);
assert(mask_b->region.rank() == rank);
assert(mask_a->region.rank() == box.rank());
assert(mask_b->region.rank() == box.rank());

if (mask_a->mask_array && mask_b->mask_array) {
const Index size = box.num_elements();
mask_a->num_masked_elements = 0;
for (Index i = 0; i < size; ++i) {
if ((mask_a->mask_array[i] |= mask_b->mask_array[i])) {
++mask_a->num_masked_elements;
}
}
if (mask_a->mask_array.valid() && mask_b->mask_array.valid()) {
Index num_masked_elements = 0;
IterateOverArrays(
[&](bool* a, bool* b) {
if ((*a |= *b) == true) {
++num_masked_elements;
}
},
/*constraints=*/{}, mask_a->mask_array, mask_b->mask_array);
mask_a->num_masked_elements = num_masked_elements;
Hull(mask_a->region, mask_b->region, mask_a->region);
RemoveMaskArrayIfNotNeeded(mask_a);
return;
}

if (!mask_a->mask_array && !mask_b->mask_array) {
if (!mask_a->mask_array.valid() && !mask_b->mask_array.valid()) {
if (IsHullEqualToUnion(mask_a->region, mask_b->region)) {
// The combined mask can be specified by the region alone.
Hull(mask_a->region, mask_b->region, mask_a->region);
mask_a->num_masked_elements = mask_a->region.num_elements();
return;
}
} else if (!mask_a->mask_array) {
} else if (!mask_a->mask_array.valid()) {
std::swap(*mask_a, *mask_b);
}

Index byte_strides[kMaxRank]; // Only first `rank` elements are used.
const tensorstore::span<Index> byte_strides_span(&byte_strides[0], rank);
ComputeStrides(ContiguousLayoutOrder::c, sizeof(bool), box.shape(),
byte_strides_span);
if (!mask_a->mask_array) {
CreateMaskArrayFromRegion(box, mask_a, byte_strides_span);
if (!mask_a->mask_array.valid()) {
CreateMaskArrayFromRegion(box, mask_a, layout_order);
}

// Copy in mask_b.
ByteStridedPointer<bool> start = mask_a->mask_array.get();
ByteStridedPointer<bool> start = mask_a->mask_array.data();
start += GetRelativeOffset(box.origin(), mask_b->region.origin(),
byte_strides_span);
mask_a->mask_array.byte_strides());
IterateOverArrays(
[&](bool* ptr) {
if (!*ptr) ++mask_a->num_masked_elements;
*ptr = true;
},
/*constraints=*/{},
ArrayView<bool>(start.get(), StridedLayoutView<>(mask_b->region.shape(),
byte_strides_span)));
ArrayView<bool>(start.get(),
StridedLayoutView<>(mask_b->region.shape(),
mask_a->mask_array.byte_strides())));
Hull(mask_a->region, mask_b->region, mask_a->region);
RemoveMaskArrayIfNotNeeded(mask_a);
}
@@ -229,29 +224,30 @@ void RebaseMaskedArray(BoxView<> box, ArrayView<const void> source,
assert(success);
return;
}
Index mask_byte_strides_storage[kMaxRank];
const tensorstore::span<Index> mask_byte_strides(
&mask_byte_strides_storage[0], box.rank());
ComputeStrides(ContiguousLayoutOrder::c, sizeof(bool), box.shape(),
mask_byte_strides);
std::unique_ptr<bool[], FreeDeleter> mask_owner;
bool* mask_array_ptr;
if (!mask.mask_array) {
mask_owner = CreateMaskArray(box, mask.region, mask_byte_strides);
mask_array_ptr = mask_owner.get();

// Materialize mask array.
ArrayView<bool> mask_array_view;
SharedArray<bool> mask_array;
if (mask.mask_array.valid()) {
mask_array_view = mask.mask_array;
} else {
mask_array_ptr = mask.mask_array.get();
DimensionIndex layout_order[kMaxRank];
tensorstore::span<DimensionIndex> layout_order_span(layout_order,
dest.rank());
SetPermutationFromStrides(dest.byte_strides(), layout_order_span);
mask_array = CreateMaskArray(
box, mask.region, ContiguousLayoutPermutation<>(layout_order_span));
mask_array_view = mask_array;
}
ArrayView<const bool> mask_array(
mask_array_ptr, StridedLayoutView<>(box.shape(), mask_byte_strides));
[[maybe_unused]] const auto success = internal::IterateOverArrays(
{&dtype->copy_assign_unmasked, /*context=*/nullptr},
/*arg=*/nullptr, skip_repeated_elements, source, dest, mask_array);
/*arg=*/nullptr, skip_repeated_elements, source, dest, mask_array_view);
assert(success);
}

void WriteToMask(MaskData* mask, BoxView<> output_box,
IndexTransformView<> input_to_output, Arena* arena) {
IndexTransformView<> input_to_output,
ContiguousLayoutPermutation<> layout_order, Arena* arena) {
assert(input_to_output.output_rank() == output_box.rank());

if (input_to_output.domain().box().is_empty()) {
@@ -265,35 +261,28 @@ void WriteToMask(MaskData* mask, BoxView<> output_box,
GetOutputRange(input_to_output, output_range).value();
Intersect(output_range, output_box, output_range);

Index mask_byte_strides_storage[kMaxRank];
const tensorstore::span<Index> mask_byte_strides(
&mask_byte_strides_storage[0], output_rank);
ComputeStrides(ContiguousLayoutOrder::c, sizeof(bool), output_box.shape(),
mask_byte_strides);
StridedLayoutView<dynamic_rank, offset_origin> mask_layout(output_box,
mask_byte_strides);

const bool use_mask_array =
output_box.rank() != 0 &&
mask->num_masked_elements != output_box.num_elements() &&
(static_cast<bool>(mask->mask_array) ||
(mask->mask_array.valid() ||
(!Contains(mask->region, output_range) &&
(!range_is_exact || !IsHullEqualToUnion(mask->region, output_range))));
if (use_mask_array && !mask->mask_array) {
CreateMaskArrayFromRegion(output_box, mask, mask_byte_strides);
if (use_mask_array && !mask->mask_array.valid()) {
CreateMaskArrayFromRegion(output_box, mask, layout_order);
}
Hull(mask->region, output_range, mask->region);

if (use_mask_array) {
// Cannot fail, because `input_to_output` must have already been validated.
StridedLayoutView<dynamic_rank, offset_origin> mask_layout(
output_box, mask->mask_array.byte_strides());
auto mask_iterable =
GetTransformedArrayNDIterable(
ArrayView<Shared<bool>, dynamic_rank, offset_origin>(
AddByteOffset(
SharedElementPointer<bool>(
UnownedToShared(mask->mask_array.get())),
-IndexInnerProduct(output_box.origin(),
tensorstore::span(mask_byte_strides))),
AddByteOffset(SharedElementPointer<bool>(
UnownedToShared(mask->mask_array.data())),
-IndexInnerProduct(output_box.origin(),
mask_layout.byte_strides())),
mask_layout),
input_to_output, arena)
.value();
(Diffs for the remaining 4 changed files are not shown.)
