Skip to content

Commit

Permalink
Save changes test stil fail
Browse files Browse the repository at this point in the history
  • Loading branch information
Drew Hubley committed Jul 11, 2024
1 parent 3a8c71a commit dd8dbda
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 74 deletions.
50 changes: 45 additions & 5 deletions include/xtensor/xstrided_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,35 @@ namespace xt

using inner_storage_type = typename base_type::inner_storage_type;
using storage_type = typename base_type::storage_type;
using linear_iterator = typename storage_type::iterator;
using const_linear_iterator = typename storage_type::const_iterator;
using reverse_linear_iterator = std::reverse_iterator<linear_iterator>;
using const_reverse_linear_iterator = std::reverse_iterator<const_linear_iterator>;

template <class C, class = void_t<>>
struct get_linear_iterator : std::false_type
{
using iterator = typename C::iterator;
};

template<typename C>
struct get_linear_iterator<C, void_t<decltype(std::declval<C>().linear_begin())>> : std::true_type
{
using iterator = typename C::linear_iterator;
};

template <class C, class = void_t<>>
struct get_const_linear_iterator : std::false_type
{
using iterator = typename C::const_iterator;
};

template<typename C>
struct get_const_linear_iterator<C, void_t<decltype(std::declval<C>().linear_cbegin())>> : std::true_type
{
using iterator = typename C::const_linear_iterator;
};

using linear_iterator = typename get_linear_iterator<storage_type>::iterator;
using const_linear_iterator = typename get_const_linear_iterator<storage_type>::iterator;
using reverse_linear_iterator = std::reverse_iterator<typename get_linear_iterator<storage_type>::iterator>;
using const_reverse_linear_iterator = std::reverse_iterator<typename get_const_linear_iterator<storage_type>::iterator>;

using iterable_base = select_iterable_base_t<L, xexpression_type::static_layout, self_type>;
using inner_shape_type = typename base_type::inner_shape_type;
Expand Down Expand Up @@ -222,6 +247,9 @@ namespace xt
const_linear_iterator linear_begin() const;
const_linear_iterator linear_end() const;
const_linear_iterator linear_cbegin() const;
const_linear_iterator linear_cbegin(std::true_type) const;
const_linear_iterator linear_cbegin(std::false_type) const;

const_linear_iterator linear_cend() const;

reverse_linear_iterator linear_rbegin();
Expand Down Expand Up @@ -509,11 +537,23 @@ namespace xt
}

template <class CT, class S, layout_type L, class FST>
inline auto xstrided_view<CT, S, L, FST>::linear_cbegin() const -> const_linear_iterator
inline auto xstrided_view<CT, S, L, FST>::linear_cbegin(std::true_type) const -> const_linear_iterator
{
return this->storage().linear_cbegin() + static_cast<std::ptrdiff_t>(data_offset());
}

template <class CT, class S, layout_type L, class FST>
inline auto xstrided_view<CT, S, L, FST>::linear_cbegin(std::false_type) const -> const_linear_iterator
{
return this->storage().cbegin() + static_cast<std::ptrdiff_t>(data_offset());
}

template <class CT, class S, layout_type L, class FST>
inline auto xstrided_view<CT, S, L, FST>::linear_cbegin() const -> const_linear_iterator
{
return linear_cbegin(get_const_linear_iterator<storage_type>());
}

template <class CT, class S, layout_type L, class FST>
inline auto xstrided_view<CT, S, L, FST>::linear_cend() const -> const_linear_iterator
{
Expand Down
97 changes: 28 additions & 69 deletions include/xtensor/xstrided_view_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ namespace xt
using reverse_iterator = decltype(std::declval<std::remove_reference_t<CT>>().template rbegin<L>());
using const_reverse_iterator = decltype(std::declval<std::decay_t<CT>>().template crbegin<L>());


explicit flat_expression_adaptor(CT* e);

template <class FST>
Expand Down Expand Up @@ -76,7 +77,7 @@ namespace xt
};

template <class CT, layout_type L>
class linear_flat_expression_adaptor
class linear_flat_expression_adaptor : public flat_expression_adaptor<CT, L>
{
public:

Expand All @@ -92,28 +93,24 @@ namespace xt
typename xexpression_type::const_reference,
typename xexpression_type::reference>;

using iterator = decltype(std::declval<std::remove_reference_t<CT>>().linear_begin());
using const_iterator = decltype(std::declval<std::decay_t<CT>>().linear_cbegin());
using reverse_iterator = decltype(std::declval<std::remove_reference_t<CT>>().linear_rbegin());
using const_reverse_iterator = decltype(std::declval<std::decay_t<CT>>().linear_crbegin());

using linear_iterator = decltype(std::declval<std::remove_reference_t<CT>>().linear_begin());
using const_linear_iterator = decltype(std::declval<std::decay_t<CT>>().linear_cbegin());
using reverse_linear_iterator = decltype(std::declval<std::remove_reference_t<CT>>().linear_rbegin());
using const_reverse_linear_iterator = decltype(std::declval<std::decay_t<CT>>().linear_crbegin());


explicit linear_flat_expression_adaptor(CT* e);

template <class FST>
linear_flat_expression_adaptor(CT* e, FST&& strides);

void update_pointer(CT* ptr) const;

size_type size() const;
reference operator[](size_type idx);
const_reference operator[](size_type idx) const;

iterator begin();
iterator end();
const_iterator begin() const;
const_iterator end() const;
const_iterator cbegin() const;
const_iterator cend() const;
linear_iterator linear_begin();
linear_iterator linear_end();
const_linear_iterator linear_begin() const;
const_linear_iterator linear_end() const;
const_linear_iterator linear_cbegin() const;
const_linear_iterator linear_cend() const;

private:

Expand Down Expand Up @@ -307,7 +304,7 @@ namespace xt
template <class CT, layout_type L>
struct flat_adaptor_getter
{
using type = std::conditional_t<detail::has_linear_iterator<std::remove_reference_t<CT>>::value,
using type = std::conditional_t<detail::has_linear_iterator<std::remove_reference_t<CT>>::value && L != xt::layout_type::dynamic,
linear_flat_expression_adaptor<std::remove_reference_t<CT>, L>,
flat_expression_adaptor<std::remove_reference_t<CT>, L>>;
using reference = std::add_lvalue_reference_t<CT>;
Expand Down Expand Up @@ -638,7 +635,7 @@ namespace xt
*/
template <class D>
inline auto xstrided_view_base<D>::storage() const noexcept -> const storage_type&
{
{
return m_storage;
}

Expand Down Expand Up @@ -721,7 +718,7 @@ namespace xt
template <class O>
inline bool xstrided_view_base<D>::has_linear_assign(const O& str) const noexcept
{
return detail::has_linear_iterator<xexpression_type>::value && str.size() == strides().size()
return has_data_interface<xexpression_type>::value && str.size() == strides().size()
&& std::equal(str.cbegin(), str.cend(), strides().begin());
}

Expand Down Expand Up @@ -855,95 +852,57 @@ namespace xt

template <class CT, layout_type L>
inline linear_flat_expression_adaptor<CT, L>::linear_flat_expression_adaptor(CT* e)
: m_e(e)
:
flat_expression_adaptor<CT,L>(e),
m_e(e)
{
resize_container(get_index(), m_e->dimension());
resize_container(m_strides, m_e->dimension());
m_size = compute_size(m_e->shape());
compute_strides(m_e->shape(), L, m_strides);
}

template <class CT, layout_type L>
template <class FST>
inline linear_flat_expression_adaptor<CT, L>::linear_flat_expression_adaptor(CT* e, FST&& strides)
: m_e(e)
: flat_expression_adaptor<CT, L>(e, strides)
, m_e(e)
, m_strides(xtl::forward_sequence<inner_strides_type, FST>(strides))
{
resize_container(get_index(), m_e->dimension());
m_size = m_e->size();
}

template <class CT, layout_type L>
inline void linear_flat_expression_adaptor<CT, L>::update_pointer(CT* ptr) const
{
m_e = ptr;
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::size() const -> size_type
{
return m_size;
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::operator[](size_type idx) -> reference
{
auto i = static_cast<typename index_type::value_type>(idx);
get_index() = detail::unravel_noexcept(i, m_strides, L);
return m_e->element(get_index().cbegin(), get_index().cend());
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::operator[](size_type idx) const -> const_reference
{
auto i = static_cast<typename index_type::value_type>(idx);
get_index() = detail::unravel_noexcept(i, m_strides, L);
return m_e->element(get_index().cbegin(), get_index().cend());
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::begin() -> iterator
inline auto linear_flat_expression_adaptor<CT, L>::linear_begin() -> linear_iterator
{
return m_e->linear_begin();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::end() -> iterator
inline auto linear_flat_expression_adaptor<CT, L>::linear_end() -> linear_iterator
{
return m_e->linear_end();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::begin() const -> const_iterator
inline auto linear_flat_expression_adaptor<CT, L>::linear_begin() const -> const_linear_iterator
{
return m_e->linear_cbegin();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::end() const -> const_iterator
inline auto linear_flat_expression_adaptor<CT, L>::linear_end() const -> const_linear_iterator
{
return m_e->linear_cend();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::cbegin() const -> const_iterator
inline auto linear_flat_expression_adaptor<CT, L>::linear_cbegin() const -> const_linear_iterator
{
return m_e->linear_cbegin();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::cend() const -> const_iterator
inline auto linear_flat_expression_adaptor<CT, L>::linear_cend() const -> const_linear_iterator
{
return m_e->linear_cend();
}

template <class CT, layout_type L>
inline auto linear_flat_expression_adaptor<CT, L>::get_index() -> index_type&
{
thread_local static index_type index;
return index;
}

}

/**********************************
Expand Down

0 comments on commit dd8dbda

Please sign in to comment.