From 5f24912998794fec738203404ef168ca0e6c744b Mon Sep 17 00:00:00 2001
From: Jeremy Maitin-Shepard
Date: Tue, 12 Nov 2024 21:22:31 -0800
Subject: [PATCH] Support arbitrary array layout order in AsyncWriteArray

This makes it easier to match the on-disk layout without the need for
an external IndexTransform to permute the dimensions.

PiperOrigin-RevId: 695974602
Change-Id: I03c64a7209ff430363c2f49a07d85079beebee50
---
 tensorstore/internal/BUILD                    |   5 -
 tensorstore/internal/async_write_array.cc     |  18 +--
 tensorstore/internal/async_write_array.h      |  34 +++++
 .../internal/async_write_array_test.cc        |   2 +-
 tensorstore/internal/masked_array.cc          | 133 ++++++++----------
 tensorstore/internal/masked_array.h           |  22 +--
 tensorstore/internal/masked_array_test.cc     |  44 +++---
 tensorstore/internal/masked_array_testutil.cc |  37 ++---
 tensorstore/internal/masked_array_testutil.h  |  10 +-
 9 files changed, 160 insertions(+), 145 deletions(-)

diff --git a/tensorstore/internal/BUILD b/tensorstore/internal/BUILD
index e28de5985..f61b2dbf2 100644
--- a/tensorstore/internal/BUILD
+++ b/tensorstore/internal/BUILD
@@ -877,7 +877,6 @@ tensorstore_cc_library(
         ":nditerable",
         ":nditerable_buffer_management",
         ":nditerable_transformed_array",
-        ":nditerable_util",
        ":unowned_to_shared",
        "//tensorstore:array",
        "//tensorstore:box",
@@ -888,7 +887,6 @@ tensorstore_cc_library(
        "//tensorstore:rank",
        "//tensorstore:strided_layout",
        "//tensorstore/index_space:index_transform",
-        "//tensorstore/index_space:transformed_array",
        "//tensorstore/util:byte_strided_pointer",
        "//tensorstore/util:element_pointer",
        "//tensorstore/util:iterate",
@@ -1018,13 +1016,10 @@ tensorstore_cc_library(
        "//tensorstore:array",
        "//tensorstore:box",
        "//tensorstore:contiguous_layout",
-        "//tensorstore:data_type",
        "//tensorstore:index",
        "//tensorstore:rank",
-        "//tensorstore:strided_layout",
        "//tensorstore/index_space:index_transform",
        "//tensorstore/index_space:transformed_array",
-        "//tensorstore/util:element_pointer",
        "//tensorstore/util:result",
        "//tensorstore/util:span",
        "//tensorstore/util:status",
diff --git a/tensorstore/internal/async_write_array.cc b/tensorstore/internal/async_write_array.cc
index fcf3b60af..5e9a66db2 100644
--- a/tensorstore/internal/async_write_array.cc
+++ b/tensorstore/internal/async_write_array.cc
@@ -92,6 +92,12 @@ Result AsyncWriteArray::Spec::GetReadNDIterable(
       arena);
 }

+SharedArray AsyncWriteArray::Spec::AllocateArray(
+    span shape) const {
+  return tensorstore::AllocateArray(shape, layout_order(), default_init,
+                                    this->dtype());
+}
+
 AsyncWriteArray::MaskedArray::MaskedArray(DimensionIndex rank) : mask(rank) {}

 void AsyncWriteArray::MaskedArray::WriteFillValue(const Spec& spec,
@@ -211,7 +217,7 @@ size_t AsyncWriteArray::MaskedArray::EstimateSizeInBytes(
   if (array.valid()) {
     total += GetByteExtent(array);
   }
-  if (mask.mask_array) {
+  if (mask.mask_array.valid()) {
     const Index num_elements = ProductOfExtents(shape);
     total += num_elements * sizeof(bool);
   }
@@ -220,9 +226,7 @@ size_t AsyncWriteArray::MaskedArray::EstimateSizeInBytes(

 void AsyncWriteArray::MaskedArray::EnsureWritable(const Spec& spec) {
   assert(array.valid());
-  auto new_array =
-      tensorstore::AllocateArray(array.shape(), tensorstore::c_order,
-                                 tensorstore::default_init, spec.dtype());
+  auto new_array = spec.AllocateArray(array.shape());
   CopyArray(array, new_array);
   array = std::move(new_array);
   array_capabilities = kMutableArray;
@@ -234,9 +238,7 @@ AsyncWriteArray::MaskedArray::GetWritableTransformedArray(
   // TODO(jbms): Could avoid copies when the output range of `chunk_transform`
   // is known to fully cover `domain`.
   if (!array.valid()) {
-    this->array =
-        tensorstore::AllocateArray(domain.shape(), tensorstore::c_order,
-                                   tensorstore::default_init, spec.dtype());
+    this->array = spec.AllocateArray(domain.shape());
     array_capabilities = kMutableArray;
     if (IsFullyOverwritten(spec, domain)) {
       // Previously, there was no data array allocated for the array but it
@@ -276,7 +278,7 @@ Result AsyncWriteArray::MaskedArray::BeginWrite(
 void AsyncWriteArray::MaskedArray::EndWrite(
     const Spec& spec, BoxView<> domain, IndexTransformView<> chunk_transform,
     Arena* arena) {
-  WriteToMask(&mask, domain, chunk_transform, arena);
+  WriteToMask(&mask, domain, chunk_transform, spec.layout_order(), arena);
 }

 void AsyncWriteArray::MaskedArray::Clear() {
diff --git a/tensorstore/internal/async_write_array.h b/tensorstore/internal/async_write_array.h
index 69b21d803..4988e8238 100644
--- a/tensorstore/internal/async_write_array.h
+++ b/tensorstore/internal/async_write_array.h
@@ -28,6 +28,7 @@
 #include "absl/status/status.h"
 #include "tensorstore/array.h"
 #include "tensorstore/box.h"
+#include "tensorstore/contiguous_layout.h"
 #include "tensorstore/data_type.h"
 #include "tensorstore/index.h"
 #include "tensorstore/index_space/index_transform.h"
@@ -65,6 +66,18 @@ struct AsyncWriteArray {
         read_generation(std::move(other.read_generation)) {}

   struct Spec {
+    Spec() = default;
+
+    template >>
+    explicit Spec(SharedOffsetArray overall_fill_value,
+                  Box<> valid_data_bounds, LayoutOrder order = c_order)
+        : overall_fill_value(std::move(overall_fill_value)),
+          valid_data_bounds(std::move(valid_data_bounds)) {
+      ConvertToContiguousLayoutPermutation(
+          order, span(layout_order_buffer, this->rank()));
+    }
+
    /// The overall fill value. Every individual chunk must be contained within
    /// `overall_fill_value.domain()`.
    SharedOffsetArray overall_fill_value;
@@ -76,6 +89,18 @@ struct AsyncWriteArray {
    /// preserved.
    Box<> valid_data_bounds;

+    /// Buffer containing permutation specifying the storage order to use when
+    /// allocating a new array.
+    ///
+    /// Note that this order is used when allocating a new array but does not
+    /// apply when the zero-copy `WriteArray` method is called.
+    ///
+    /// Only the first `rank()` elements are meaningful.
+    ///
+    /// For example, ``0, 1, 2`` denotes C order for rank 3, while ``2, 1, 0``
+    /// denotes F order.
+    DimensionIndex layout_order_buffer[kMaxRank];
+
    /// If `true`, indicates that the array should be stored even if it equals
    /// the fill value. By default (when set to `false`), when preparing a
    /// writeback snapshot, if the value of the array is equal to the fill
@@ -102,6 +127,11 @@ struct AsyncWriteArray {
    /// `domain`, translated to have a zero origin.
    SharedArrayView GetFillValueForDomain(BoxView<> domain) const;

+    /// Storage order to use when allocating a new array.
+    ContiguousLayoutPermutation<> layout_order() const {
+      return ContiguousLayoutPermutation<>(span(layout_order_buffer, rank()));
+    }
+
    /// Returns an `NDIterable` that may be used for reading the specified
    /// `array`, using the specified `chunk_transform`.
    ///
@@ -120,6 +150,10 @@ struct AsyncWriteArray {
      if (!valid) return 0;
      return ProductOfExtents(shape) * dtype()->size;
    }
+
+    /// Allocates an array of the specified `shape`, for `this->dtype()` and
+    /// `this->layout_order()`.
+    SharedArray AllocateArray(span shape) const;
  };

  /// Return type of `GetArrayForWriteback`.
diff --git a/tensorstore/internal/async_write_array_test.cc b/tensorstore/internal/async_write_array_test.cc
index c6174685a..7dd5e06b5 100644
--- a/tensorstore/internal/async_write_array_test.cc
+++ b/tensorstore/internal/async_write_array_test.cc
@@ -210,7 +210,7 @@ TEST(MaskedArrayTest, Basic) {
   EXPECT_EQ(MakeArray({{9, 0, 0}, {0, 7, 8}}),
             write_state.shared_array_view(spec));
   EXPECT_EQ(MakeArray({{1, 0, 0}, {0, 1, 1}}),
-            tensorstore::Array(write_state.mask.mask_array.get(), {2, 3}));
+            write_state.mask.mask_array);
   EXPECT_FALSE(write_state.IsUnmodified());
   EXPECT_FALSE(write_state.IsFullyOverwritten(spec, domain));
   // Both data array and mask array have been allocated.
diff --git a/tensorstore/internal/masked_array.cc b/tensorstore/internal/masked_array.cc
index 4969e4138..0d7804f7c 100644
--- a/tensorstore/internal/masked_array.cc
+++ b/tensorstore/internal/masked_array.cc
@@ -27,15 +27,11 @@
 #include "tensorstore/index.h"
 #include "tensorstore/index_interval.h"
 #include "tensorstore/index_space/index_transform.h"
-#include "tensorstore/index_space/transformed_array.h"
 #include "tensorstore/internal/arena.h"
 #include "tensorstore/internal/elementwise_function.h"
 #include "tensorstore/internal/integer_overflow.h"
-#include "tensorstore/internal/memory.h"
-#include "tensorstore/internal/nditerable.h"
 #include "tensorstore/internal/nditerable_buffer_management.h"
 #include "tensorstore/internal/nditerable_transformed_array.h"
-#include "tensorstore/internal/nditerable_util.h"
 #include "tensorstore/internal/unowned_to_shared.h"
 #include "tensorstore/rank.h"
 #include "tensorstore/strided_layout.h"
@@ -123,7 +119,7 @@ Index GetRelativeOffset(tensorstore::span base,
 void RemoveMaskArrayIfNotNeeded(MaskData* mask) {
   if (mask->num_masked_elements == mask->region.num_elements()) {
-    mask->mask_array.reset();
+    mask->mask_array.element_pointer() = {};
   }
 }
 }  // namespace
@@ -132,29 +128,30 @@ MaskData::MaskData(DimensionIndex rank) : region(rank) {
   region.Fill(IndexInterval::UncheckedSized(0, 0));
 }

-std::unique_ptr CreateMaskArray(
-    BoxView<> box, BoxView<> mask_region,
-    tensorstore::span byte_strides) {
-  std::unique_ptr result(
-      static_cast(std::calloc(box.num_elements(), sizeof(bool))));
-  ByteStridedPointer start = result.get();
-  start += GetRelativeOffset(box.origin(), mask_region.origin(), byte_strides);
+SharedArray CreateMaskArray(BoxView<> box, BoxView<> mask_region,
+                            ContiguousLayoutPermutation<> layout_order) {
+  auto array = AllocateArray(box.shape(), layout_order, value_init);
+  ByteStridedPointer start = array.data();
+  start += GetRelativeOffset(box.origin(), mask_region.origin(),
+                             array.byte_strides());
   internal::IterateOverArrays(
       internal::SimpleElementwiseFunction{},
       /*arg=*/nullptr,
       /*constraints=*/skip_repeated_elements,
-      ArrayView(start.get(),
-                StridedLayoutView<>(mask_region.shape(), byte_strides)));
-  return result;
+      ArrayView(start.get(), StridedLayoutView<>(mask_region.shape(),
+                                                 array.byte_strides())));
+  return array;
 }

 void CreateMaskArrayFromRegion(BoxView<> box, MaskData* mask,
-                               tensorstore::span byte_strides) {
+                               ContiguousLayoutPermutation<> layout_order) {
   assert(mask->num_masked_elements == mask->region.num_elements());
-  mask->mask_array = CreateMaskArray(box, mask->region, byte_strides);
+  assert(layout_order.size() == mask->region.rank());
+  mask->mask_array = CreateMaskArray(box, mask->region, layout_order);
 }

-void UnionMasks(BoxView<> box, MaskData* mask_a, MaskData* mask_b) {
+void UnionMasks(BoxView<> box, MaskData* mask_a, MaskData* mask_b,
+                ContiguousLayoutPermutation<> layout_order) {
   assert(mask_a != mask_b);  // May work but not supported.
   if (mask_a->num_masked_elements == 0) {
     std::swap(*mask_a, *mask_b);
@@ -162,54 +159,52 @@ void UnionMasks(BoxView<> box, MaskData* mask_a, MaskData* mask_b) {
   } else if (mask_b->num_masked_elements == 0) {
     return;
   }
-  const DimensionIndex rank = box.rank();
-  assert(mask_a->region.rank() == rank);
-  assert(mask_b->region.rank() == rank);
+  assert(mask_a->region.rank() == box.rank());
+  assert(mask_b->region.rank() == box.rank());

-  if (mask_a->mask_array && mask_b->mask_array) {
-    const Index size = box.num_elements();
-    mask_a->num_masked_elements = 0;
-    for (Index i = 0; i < size; ++i) {
-      if ((mask_a->mask_array[i] |= mask_b->mask_array[i])) {
-        ++mask_a->num_masked_elements;
-      }
-    }
+  if (mask_a->mask_array.valid() && mask_b->mask_array.valid()) {
+    Index num_masked_elements = 0;
+    IterateOverArrays(
+        [&](bool* a, bool* b) {
+          if ((*a |= *b) == true) {
+            ++num_masked_elements;
+          }
+        },
+        /*constraints=*/{}, mask_a->mask_array, mask_b->mask_array);
+    mask_a->num_masked_elements = num_masked_elements;
     Hull(mask_a->region, mask_b->region, mask_a->region);
     RemoveMaskArrayIfNotNeeded(mask_a);
     return;
   }
-  if (!mask_a->mask_array && !mask_b->mask_array) {
+  if (!mask_a->mask_array.valid() && !mask_b->mask_array.valid()) {
     if (IsHullEqualToUnion(mask_a->region, mask_b->region)) {
       // The combined mask can be specified by the region alone.
       Hull(mask_a->region, mask_b->region, mask_a->region);
       mask_a->num_masked_elements = mask_a->region.num_elements();
       return;
     }
-  } else if (!mask_a->mask_array) {
+  } else if (!mask_a->mask_array.valid()) {
     std::swap(*mask_a, *mask_b);
   }
-  Index byte_strides[kMaxRank];  // Only first `rank` elements are used.
-  const tensorstore::span byte_strides_span(&byte_strides[0], rank);
-  ComputeStrides(ContiguousLayoutOrder::c, sizeof(bool), box.shape(),
-                 byte_strides_span);
-  if (!mask_a->mask_array) {
-    CreateMaskArrayFromRegion(box, mask_a, byte_strides_span);
+  if (!mask_a->mask_array.valid()) {
+    CreateMaskArrayFromRegion(box, mask_a, layout_order);
   }
   // Copy in mask_b.
-  ByteStridedPointer start = mask_a->mask_array.get();
+  ByteStridedPointer start = mask_a->mask_array.data();
   start += GetRelativeOffset(box.origin(), mask_b->region.origin(),
-                             byte_strides_span);
+                             mask_a->mask_array.byte_strides());
   IterateOverArrays(
       [&](bool* ptr) {
         if (!*ptr) ++mask_a->num_masked_elements;
         *ptr = true;
       },
       /*constraints=*/{},
-      ArrayView(start.get(), StridedLayoutView<>(mask_b->region.shape(),
-                                                 byte_strides_span)));
+      ArrayView(start.get(),
                StridedLayoutView<>(mask_b->region.shape(),
+                                    mask_a->mask_array.byte_strides())));
   Hull(mask_a->region, mask_b->region, mask_a->region);
   RemoveMaskArrayIfNotNeeded(mask_a);
 }
@@ -229,29 +224,30 @@ void RebaseMaskedArray(BoxView<> box, ArrayView source,
     assert(success);
     return;
   }
-  Index mask_byte_strides_storage[kMaxRank];
-  const tensorstore::span mask_byte_strides(
-      &mask_byte_strides_storage[0], box.rank());
-  ComputeStrides(ContiguousLayoutOrder::c, sizeof(bool), box.shape(),
-                 mask_byte_strides);
-  std::unique_ptr mask_owner;
-  bool* mask_array_ptr;
-  if (!mask.mask_array) {
-    mask_owner = CreateMaskArray(box, mask.region, mask_byte_strides);
-    mask_array_ptr = mask_owner.get();
+
+  // Materialize mask array.
+  ArrayView mask_array_view;
+  SharedArray mask_array;
+  if (mask.mask_array.valid()) {
+    mask_array_view = mask.mask_array;
   } else {
-    mask_array_ptr = mask.mask_array.get();
+    DimensionIndex layout_order[kMaxRank];
+    tensorstore::span layout_order_span(layout_order,
+                                        dest.rank());
+    SetPermutationFromStrides(dest.byte_strides(), layout_order_span);
+    mask_array = CreateMaskArray(
+        box, mask.region, ContiguousLayoutPermutation<>(layout_order_span));
+    mask_array_view = mask_array;
   }
-  ArrayView mask_array(
-      mask_array_ptr, StridedLayoutView<>(box.shape(), mask_byte_strides));
   [[maybe_unused]] const auto success = internal::IterateOverArrays(
       {&dtype->copy_assign_unmasked, /*context=*/nullptr},
-      /*arg=*/nullptr, skip_repeated_elements, source, dest, mask_array);
+      /*arg=*/nullptr, skip_repeated_elements, source, dest, mask_array_view);
   assert(success);
 }

 void WriteToMask(MaskData* mask, BoxView<> output_box,
-                 IndexTransformView<> input_to_output, Arena* arena) {
+                 IndexTransformView<> input_to_output,
+                 ContiguousLayoutPermutation<> layout_order, Arena* arena) {
   assert(input_to_output.output_rank() == output_box.rank());

   if (input_to_output.domain().box().is_empty()) {
@@ -265,35 +261,28 @@ void WriteToMask(MaskData* mask, BoxView<> output_box,
       GetOutputRange(input_to_output, output_range).value();
   Intersect(output_range, output_box, output_range);

-  Index mask_byte_strides_storage[kMaxRank];
-  const tensorstore::span mask_byte_strides(
-      &mask_byte_strides_storage[0], output_rank);
-  ComputeStrides(ContiguousLayoutOrder::c, sizeof(bool), output_box.shape(),
-                 mask_byte_strides);
-  StridedLayoutView mask_layout(output_box,
-                                mask_byte_strides);
-
   const bool use_mask_array =
       output_box.rank() != 0 &&
       mask->num_masked_elements != output_box.num_elements() &&
-      (static_cast(mask->mask_array) ||
+      (mask->mask_array.valid() ||
       (!Contains(mask->region, output_range) &&
        (!range_is_exact || !IsHullEqualToUnion(mask->region, output_range))));
-  if (use_mask_array && !mask->mask_array) {
-    CreateMaskArrayFromRegion(output_box, mask, mask_byte_strides);
+  if (use_mask_array && !mask->mask_array.valid()) {
+    CreateMaskArrayFromRegion(output_box, mask, layout_order);
   }
   Hull(mask->region, output_range, mask->region);

   if (use_mask_array) {
     // Cannot fail, because `input_to_output` must have already been validated.
+    StridedLayoutView mask_layout(
+        output_box, mask->mask_array.byte_strides());
     auto mask_iterable =
         GetTransformedArrayNDIterable(
             ArrayView, dynamic_rank, offset_origin>(
-                AddByteOffset(
-                    SharedElementPointer(
-                        UnownedToShared(mask->mask_array.get())),
-                    -IndexInnerProduct(output_box.origin(),
-                                       tensorstore::span(mask_byte_strides))),
+                AddByteOffset(SharedElementPointer(
+                                  UnownedToShared(mask->mask_array.data())),
+                              -IndexInnerProduct(output_box.origin(),
+                                                 mask_layout.byte_strides())),
                 mask_layout),
             input_to_output, arena)
             .value();
diff --git a/tensorstore/internal/masked_array.h b/tensorstore/internal/masked_array.h
index a57ccdc12..c880a0a27 100644
--- a/tensorstore/internal/masked_array.h
+++ b/tensorstore/internal/masked_array.h
@@ -53,18 +53,18 @@ struct MaskData {

   void Reset() {
     num_masked_elements = 0;
-    mask_array.reset();
+    mask_array.element_pointer() = {};
     region.Fill(IndexInterval::UncheckedSized(0, 0));
   }

-  /// If not `nullptr`, stores a mask array of size `mask_box.shape()` in C
-  /// order, where all elements outside `region` are `false`. If `nullptr`,
+  /// If `mask_array.valid()`, stores a mask array of size `mask_box.shape()`,
+  /// where all elements outside `region` are `false`. If `!mask_array.valid()`,
   /// indicates that all elements within `region` are masked.
-  std::unique_ptr mask_array;
+  SharedArray mask_array;

   /// Number of `true` values in `mask_array`, or `region.num_elements()` if
-  /// `mask_array` is `nullptr`. As a special case, if `region.rank() == 0`,
-  /// `num_masked_elements` may equal `0` even if `mask_array` is `nullptr` to
+  /// `!mask_array.valid()`. As a special case, if `region.rank() == 0`,
+  /// `num_masked_elements` may equal `0` even if `!mask_array.valid()` to
   /// indicate that the singleton element is not included in the mask.
   Index num_masked_elements = 0;

@@ -79,9 +79,12 @@ struct MaskData {
 /// \param output_box Domain of the `mask`.
 /// \param input_to_output Transform that specifies the mapping to `output_box`.
 ///     Must be valid.
+/// \param layout_order Permutation of length `output_box.rank()` specifying
+///     the layout order for any newly-allocated mask array.
 /// \param arena Allocation arena that may be used.
 void WriteToMask(MaskData* mask, BoxView<> output_box,
-                 IndexTransformView<> input_to_output, Arena* arena);
+                 IndexTransformView<> input_to_output,
+                 ContiguousLayoutPermutation<> layout_order, Arena* arena);

 /// Copies unmasked elements from `source_data` to `data_ptr`.
 ///
@@ -100,7 +103,10 @@ void RebaseMaskedArray(BoxView<> box, ArrayView source,
 /// May modify `*mask_b`.
 ///
 /// \param box The region over which the two masks are defined.
-void UnionMasks(BoxView<> box, MaskData* mask_a, MaskData* mask_b);
+/// \param layout_order Permutation of length `box.rank()` specifying the
+///     layout order for any newly-allocated mask array.
+void UnionMasks(BoxView<> box, MaskData* mask_a, MaskData* mask_b,
+                ContiguousLayoutPermutation<> layout_order);

 }  // namespace internal
 }  // namespace tensorstore
diff --git a/tensorstore/internal/masked_array_test.cc b/tensorstore/internal/masked_array_test.cc
index 1e87c819c..339f6c2bf 100644
--- a/tensorstore/internal/masked_array_test.cc
+++ b/tensorstore/internal/masked_array_test.cc
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include
 #include
 #include

@@ -67,25 +68,28 @@ using ::tensorstore::internal::SimpleElementwiseFunction;
 /// and a StridedLayout representing the mask array layout.
 class MaskedArrayTester {
  public:
-  explicit MaskedArrayTester(BoxView<> box)
-      : box_(box),
-        mask_(box.rank()),
-        mask_layout_zero_origin_(tensorstore::ContiguousLayoutOrder::c,
-                                 sizeof(bool), box.shape()) {}
-
-  ArrayView mask_array() const {
-    if (!mask_.mask_array) return {};
-    return ArrayView(mask_.mask_array.get(),
-                     mask_layout_zero_origin_);
+  template
+  explicit MaskedArrayTester(BoxView<> box,
+                             LayoutOrder layout_order = tensorstore::c_order)
+      : box_(box), mask_(box.rank()) {
+    layout_order_.resize(box.rank());
+    tensorstore::ConvertToContiguousLayoutPermutation(
+        layout_order, tensorstore::span(layout_order_));
   }

+  ArrayView mask_array() const { return mask_.mask_array; }
+
   Index num_masked_elements() const { return mask_.num_masked_elements; }
   BoxView<> mask_region() const { return mask_.region; }
   const MaskData& mask() const { return mask_; }
   BoxView<> domain() const { return box_; }

+  tensorstore::ContiguousLayoutPermutation<> layout_order() const {
+    return tensorstore::ContiguousLayoutPermutation<>(layout_order_);
+  }
+
   void Combine(MaskedArrayTester&& other) {
-    UnionMasks(box_, &mask_, &other.mask_);
+    UnionMasks(box_, &mask_, &other.mask_, layout_order());
   }

   void Reset() { mask_.Reset(); }
@@ -93,7 +97,7 @@ class MaskedArrayTester {
  protected:
   Box<> box_;
   MaskData mask_;
-  StridedLayout<> mask_layout_zero_origin_;
+  std::vector layout_order_;
 };

 /// Extends MaskedArrayTester to also include an array of type T defined over
@@ -103,12 +107,13 @@ class MaskedArrayTester {
 template
 class MaskedArrayWriteTester : public MaskedArrayTester {
  public:
-  explicit MaskedArrayWriteTester(BoxView<> box)
-      : MaskedArrayTester(box),
-        dest_(tensorstore::AllocateArray(box, tensorstore::c_order,
+  template
+  explicit MaskedArrayWriteTester(
+      BoxView<> box, LayoutOrder layout_order = tensorstore::c_order)
+      : MaskedArrayTester(box, layout_order),
+        dest_(tensorstore::AllocateArray(box, layout_order,
                                          tensorstore::value_init)),
-        dest_layout_zero_origin_(tensorstore::ContiguousLayoutOrder::c,
-                                 sizeof(T), box.shape()) {}
+        dest_layout_zero_origin_(dest_.shape(), dest_.byte_strides()) {}

   template
   absl::Status Write(IndexTransformView<> dest_transform,
     ElementCopyFunction copy_function =
         SimpleElementwiseFunction(const T, T), void*>();
-    return WriteToMaskedArray(dest_.byte_strided_origin_pointer().get(), &mask_,
-                              dest_.domain(), dest_transform, source,
+    return WriteToMaskedArray(dest_, &mask_, dest_transform, source,
                               {&copy_function, &copy_func});
   }

@@ -150,7 +154,7 @@ class MaskedArrayWriteTester : public MaskedArrayTester {

 TEST(MaskDataTest, Construct) {
   MaskData mask(3);
-  EXPECT_FALSE(mask.mask_array);
+  EXPECT_FALSE(mask.mask_array.valid());
   EXPECT_EQ(0, mask.num_masked_elements);
   EXPECT_EQ(0, mask.region.num_elements());
 }
diff --git a/tensorstore/internal/masked_array_testutil.cc b/tensorstore/internal/masked_array_testutil.cc
index 44de43c85..c3500e88a 100644
--- a/tensorstore/internal/masked_array_testutil.cc
+++ b/tensorstore/internal/masked_array_testutil.cc
@@ -21,7 +21,6 @@
 #include "tensorstore/array.h"
 #include "tensorstore/box.h"
 #include "tensorstore/contiguous_layout.h"
-#include "tensorstore/data_type.h"
 #include "tensorstore/index.h"
 #include "tensorstore/index_space/index_transform.h"
 #include "tensorstore/index_space/transformed_array.h"
@@ -33,8 +32,6 @@
 #include "tensorstore/internal/nditerable_elementwise_input_transform.h"
 #include "tensorstore/internal/nditerable_transformed_array.h"
#include "tensorstore/rank.h" -#include "tensorstore/strided_layout.h" -#include "tensorstore/util/element_pointer.h" #include "tensorstore/util/result.h" #include "tensorstore/util/span.h" #include "tensorstore/util/status.h" @@ -42,37 +39,28 @@ namespace tensorstore { namespace internal { -absl::Status WriteToMaskedArray(ElementPointer output_ptr, MaskData* mask, - BoxView<> output_box, +absl::Status WriteToMaskedArray(SharedOffsetArray output, MaskData* mask, IndexTransformView<> input_to_output, const NDIterable& source, Arena* arena) { - const DimensionIndex output_rank = output_box.rank(); - Index data_byte_strides_storage[kMaxRank]; - const tensorstore::span data_byte_strides( - &data_byte_strides_storage[0], output_rank); - ComputeStrides(ContiguousLayoutOrder::c, output_ptr.dtype()->size, - output_box.shape(), data_byte_strides); + const DimensionIndex output_rank = output.rank(); TENSORSTORE_ASSIGN_OR_RETURN( auto dest_iterable, - GetTransformedArrayNDIterable( - {UnownedToShared(AddByteOffset( - output_ptr, - -IndexInnerProduct(output_box.origin(), - tensorstore::span(data_byte_strides)))), - StridedLayoutView(output_box, - data_byte_strides)}, - input_to_output, arena)); + GetTransformedArrayNDIterable(output, input_to_output, arena)); TENSORSTORE_RETURN_IF_ERROR(NDIterableCopier(source, *dest_iterable, input_to_output.input_shape(), arena) .Copy()); - WriteToMask(mask, output_box, input_to_output, arena); + DimensionIndex layout_order[kMaxRank]; + tensorstore::span layout_order_span(layout_order, + output_rank); + SetPermutationFromStrides(output.byte_strides(), layout_order_span); + WriteToMask(mask, output.domain(), input_to_output, + ContiguousLayoutPermutation<>(layout_order_span), arena); return absl::OkStatus(); } -absl::Status WriteToMaskedArray(ElementPointer output_ptr, MaskData* mask, - BoxView<> output_box, +absl::Status WriteToMaskedArray(SharedOffsetArray output, MaskData* mask, IndexTransformView<> input_to_output, TransformedArray source, ElementCopyFunction::Closure copy_function) { @@ -87,9 +75,8 @@ absl::Status WriteToMaskedArray(ElementPointer output_ptr, MaskData* mask, auto source_iterable, GetTransformedArrayNDIterable(UnownedToShared(source), &arena)); auto transformed_source_iterable = GetElementwiseInputTransformNDIterable( - {{std::move(source_iterable)}}, output_ptr.dtype(), copy_function, - &arena); - return WriteToMaskedArray(output_ptr, mask, output_box, input_to_output, + {{std::move(source_iterable)}}, output.dtype(), copy_function, &arena); + return WriteToMaskedArray(std::move(output), mask, input_to_output, *transformed_source_iterable, &arena); } diff --git a/tensorstore/internal/masked_array_testutil.h b/tensorstore/internal/masked_array_testutil.h index 1c923dc84..ae0690c6c 100644 --- a/tensorstore/internal/masked_array_testutil.h +++ b/tensorstore/internal/masked_array_testutil.h @@ -16,12 +16,12 @@ #define TENSORSTORE_INTERNAL_MASKED_ARRAY_TESTUTIL_H_ #include "absl/status/status.h" +#include "tensorstore/array.h" #include "tensorstore/box.h" #include "tensorstore/index_space/index_transform.h" #include "tensorstore/index_space/transformed_array.h" #include "tensorstore/internal/element_copy_function.h" #include "tensorstore/internal/masked_array.h" -#include "tensorstore/util/element_pointer.h" #include "tensorstore/util/status.h" namespace tensorstore { @@ -30,9 +30,8 @@ namespace internal { /// Copies the contents of `source` to an "output" array, and updates `*mask` to /// include all positions that were modified. 
 ///
-/// \param output_ptr[out] Pointer to the origin (not the zero position) of a
-///     C-order contiguous "output" array with domain `output_box`.
-/// \param mask[in,out] Non-null pointer to mask with domain `output_box`.
+/// \param output[out] Output array.
+/// \param mask[in,out] Non-null pointer to mask with domain `output.domain()`.
 /// \param input_to_output Transform to apply to the "output" array. Must be
 ///     valid.
 /// \param source Source array to copy to the transformed output array.
@@ -47,8 +46,7 @@ namespace internal {
 ///     out-of-bounds index.
 /// \error `absl::StatusCode::kInvalidArgument` if integer overflow occurs
 ///     computing output indices.
-absl::Status WriteToMaskedArray(ElementPointer output_ptr, MaskData* mask,
-                                BoxView<> output_box,
+absl::Status WriteToMaskedArray(SharedOffsetArray output, MaskData* mask,
                                 IndexTransformView<> input_to_output,
                                 TransformedArray source,
                                 ElementCopyFunction::Closure copy_function);
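
Usage note (editorial addition, not part of the patch): the sketch below shows how the layout-order support added above might be exercised, assuming the new `Spec` constructor and `Spec::AllocateArray` method from this diff. The fill value, chunk bounds, shape, and the helper names `MakeFortranOrderSpec` and `AllocateChunkBuffer` are hypothetical and for illustration only.

#include <utility>

#include "tensorstore/array.h"
#include "tensorstore/box.h"
#include "tensorstore/contiguous_layout.h"
#include "tensorstore/index.h"
#include "tensorstore/internal/async_write_array.h"

namespace {

using ::tensorstore::Box;
using ::tensorstore::fortran_order;
using ::tensorstore::Index;
using ::tensorstore::MakeOffsetArray;
using ::tensorstore::internal::AsyncWriteArray;

// Illustrative sketch: build a Spec whose newly allocated chunk buffers use
// Fortran (column-major) order, e.g. to match an on-disk layout, instead of
// the previously hard-coded C order.
AsyncWriteArray::Spec MakeFortranOrderSpec() {
  // Rank-2 zero fill value covering the hypothetical chunk domain
  // [0, 2) x [0, 3).
  auto fill_value = MakeOffsetArray({0, 0}, {{0, 0, 0}, {0, 0, 0}});
  return AsyncWriteArray::Spec(std::move(fill_value), Box<>({0, 0}, {2, 3}),
                               fortran_order);
}

// Buffers allocated through the spec use the requested permutation.
void AllocateChunkBuffer() {
  AsyncWriteArray::Spec spec = MakeFortranOrderSpec();
  const Index shape[] = {2, 3};
  auto chunk = spec.AllocateArray(shape);  // column-major byte strides
  (void)chunk;
}

}  // namespace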