Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-17259: [C++] Do not use shared_ptr<DataType> with kernel function signatures, do less copying of shared_ptrs #13753

Closed
wants to merge 11 commits into from
50 changes: 47 additions & 3 deletions cpp/src/arrow/compute/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,42 @@ std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id) {
return std::make_shared<SameTypeIdMatcher>(type_id);
}

/// \brief TypeMatcher accepting list types whose value type has a given type id.
///
/// A DataType matches when its own id equals `accepted_list_id_` and the id of
/// its value type equals `accepted_id_`.
class ListOfMatcher : public TypeMatcher {
 public:
  /// \param accepted_id type id required of the list's value type
  /// \param accepted_list_id type id required of the list type itself
  explicit ListOfMatcher(Type::type accepted_id, Type::type accepted_list_id)
      : accepted_id_(accepted_id), accepted_list_id_(accepted_list_id) {}

  /// \brief Return true if `type` has id `accepted_list_id_` and its value
  /// type has id `accepted_id_`.
  bool Matches(const DataType& type) const override {
    if (type.id() != accepted_list_id_) return false;
    // NOTE(review): this cast is only valid when accepted_list_id_ identifies a
    // type derived from BaseListType -- confirm callers never pass another id.
    return checked_cast<const BaseListType&>(type).value_type()->id() == accepted_id_;
  }

  /// \brief Human-readable description of the matcher.
  ///
  /// Only the value type id is rendered; the list type id is not part of the
  /// string.
  std::string ToString() const override {
    std::stringstream ss;
    ss << "list of Type::" << ::arrow::internal::ToString(accepted_id_);
    return ss.str();
  }

  /// \brief Return true if `other` is a ListOfMatcher accepting exactly the
  /// same types as this one.
  bool Equals(const TypeMatcher& other) const override {
    if (this == &other) {
      return true;
    }
    const auto* casted = dynamic_cast<const ListOfMatcher*>(&other);
    if (casted == nullptr) {
      return false;
    }
    // Both ids take part in Matches(), so both must agree for two matchers to
    // be interchangeable. Comparing only the value type id would make matchers
    // for different list type ids compare equal.
    return accepted_id_ == casted->accepted_id_ &&
           accepted_list_id_ == casted->accepted_list_id_;
  }

 private:
  // Type id required of the list's value type.
  Type::type accepted_id_;
  // Type id required of the list type itself.
  Type::type accepted_list_id_;
};

// Build a matcher for types with id `list_type_id` whose value type has id
// `type_id`, returned through the TypeMatcher base interface.
std::shared_ptr<TypeMatcher> ListOf(Type::type type_id, Type::type list_type_id) {
  std::shared_ptr<TypeMatcher> matcher =
      std::make_shared<ListOfMatcher>(type_id, list_type_id);
  return matcher;
}

template <typename ArrowType>
class TimeUnitMatcher : public TypeMatcher {
using ThisType = TimeUnitMatcher<ArrowType>;
Expand Down Expand Up @@ -280,6 +316,10 @@ std::shared_ptr<TypeMatcher> FixedSizeBinaryLike() {
// ----------------------------------------------------------------------
// InputType

// Construct an InputType from a shared_ptr by delegating to the
// `const DataType*` constructor. Only the raw pointer obtained from
// type.get() is forwarded; no copy of the shared_ptr is taken here.
// NOTE(review): presumably the caller must keep `type` alive for as long as
// this InputType is in use -- confirm.
InputType::InputType(const std::shared_ptr<DataType>& type) : InputType(type.get()) {
// Checks that the id of the passed type is parameter-free.
// NOTE(review): `type` is dereferenced here, so a null shared_ptr is not
// tolerated when this check is active.
DCHECK(is_parameter_free(type->id()));
}

size_t InputType::Hash() const {
size_t result = kHashSeed;
hash_combine(result, static_cast<int>(kind_));
Expand Down Expand Up @@ -356,7 +396,7 @@ bool InputType::Matches(const Datum& value) const {
return Matches(*value.type());
}

const std::shared_ptr<DataType>& InputType::type() const {
const DataType* InputType::type() const {
DCHECK_EQ(InputType::EXACT_TYPE, kind_);
return type_;
}
Expand All @@ -369,16 +409,20 @@ const TypeMatcher& InputType::type_matcher() const {
// ----------------------------------------------------------------------
// OutputType

// Construct an OutputType from a shared_ptr by delegating to the
// `const DataType*` constructor. Only the raw pointer obtained from
// type.get() is forwarded; no copy of the shared_ptr is taken here.
// NOTE(review): presumably the caller must keep `type` alive for as long as
// this OutputType is in use -- confirm.
OutputType::OutputType(const std::shared_ptr<DataType>& type) : OutputType(type.get()) {
// Checks that the id of the passed type is parameter-free.
// NOTE(review): `type` is dereferenced here, so a null shared_ptr is not
// tolerated when this check is active.
DCHECK(is_parameter_free(type->id()));
}

Result<TypeHolder> OutputType::Resolve(KernelContext* ctx,
const std::vector<TypeHolder>& types) const {
if (kind_ == OutputType::FIXED) {
return type_.get();
return type_;
} else {
return resolver_(ctx, types);
}
}

const std::shared_ptr<DataType>& OutputType::type() const {
const DataType* OutputType::type() const {
DCHECK_EQ(FIXED, kind_);
return type_;
}
Expand Down
28 changes: 18 additions & 10 deletions cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ ARROW_EXPORT std::shared_ptr<TypeMatcher> FixedSizeBinaryLike();
// Type)
ARROW_EXPORT std::shared_ptr<TypeMatcher> Primitive();

ARROW_EXPORT std::shared_ptr<TypeMatcher> ListOf(Type::type type_id,
Type::type list_type = Type::LIST);

} // namespace match

/// \brief An object used for type-checking arguments to be passed to a kernel
Expand All @@ -166,8 +169,10 @@ class ARROW_EXPORT InputType {
InputType() : kind_(ANY_TYPE) {}

/// \brief Accept an exact value type.
InputType(std::shared_ptr<DataType> type) // NOLINT implicit construction
: kind_(EXACT_TYPE), type_(std::move(type)) {}
InputType(const DataType* type) // NOLINT implicit construction
: kind_(EXACT_TYPE), type_(type) {}

InputType(const std::shared_ptr<DataType>& type); // NOLINT implicit construction

/// \brief Use the passed TypeMatcher to type check.
InputType(std::shared_ptr<TypeMatcher> type_matcher) // NOLINT implicit construction
Expand Down Expand Up @@ -216,7 +221,7 @@ class ARROW_EXPORT InputType {
/// \brief For InputType::EXACT_TYPE kind, the exact type that this InputType
/// must match. Otherwise this function should not be used and will assert in
/// debug builds.
const std::shared_ptr<DataType>& type() const;
const DataType* type() const;

/// \brief For InputType::USE_TYPE_MATCHER, the TypeMatcher to be used for
/// checking the type of a value. Otherwise this function should not be used
Expand All @@ -232,14 +237,14 @@ class ARROW_EXPORT InputType {

void MoveInto(InputType&& other) {
this->kind_ = other.kind_;
this->type_ = std::move(other.type_);
this->type_ = other.type_;
this->type_matcher_ = std::move(other.type_matcher_);
}

Kind kind_;

// For EXACT_TYPE Kind
std::shared_ptr<DataType> type_;
const DataType* type_ = NULLPTR;

// For USE_TYPE_MATCHER Kind
std::shared_ptr<TypeMatcher> type_matcher_;
Expand All @@ -261,8 +266,11 @@ class ARROW_EXPORT OutputType {
using Resolver = Result<TypeHolder> (*)(KernelContext*, const std::vector<TypeHolder>&);

/// \brief Output an exact type
OutputType(std::shared_ptr<DataType> type) // NOLINT implicit construction
: kind_(FIXED), type_(std::move(type)) {}
OutputType(const DataType* type) // NOLINT implicit construction
: kind_(FIXED), type_(type) {}

/// \brief Output an exact type
OutputType(const std::shared_ptr<DataType>& type); // NOLINT implicit construction

/// \brief Output a computed type depending on actual input types
OutputType(Resolver resolver) // NOLINT implicit construction
Expand All @@ -276,7 +284,7 @@ class ARROW_EXPORT OutputType {

OutputType(OutputType&& other) {
this->kind_ = other.kind_;
this->type_ = std::move(other.type_);
this->type_ = other.type_;
this->resolver_ = other.resolver_;
}

Expand All @@ -290,7 +298,7 @@ class ARROW_EXPORT OutputType {
const std::vector<TypeHolder>& args) const;

/// \brief The exact output value type for the FIXED kind.
const std::shared_ptr<DataType>& type() const;
const DataType* type() const;

/// \brief For use with COMPUTED resolution strategy. It may be more
/// convenient to invoke this with OutputType::Resolve returned from this
Expand All @@ -308,7 +316,7 @@ class ARROW_EXPORT OutputType {
ResolveKind kind_;

// For FIXED resolution
std::shared_ptr<DataType> type_;
const DataType* type_ = NULLPTR;

// For COMPUTED resolution
Resolver resolver_ = NULLPTR;
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/arrow/compute/kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,11 @@ TEST(InputType, Equals) {
ASSERT_NE(InputType(int8()), InputType(Type::INT32));

// Check that field metadata excluded from equality checks
InputType t9 = list(
auto ty9 = list(
field("item", utf8(), /*nullable=*/true, key_value_metadata({"foo"}, {"bar"})));
InputType t10 = list(field("item", utf8()));
auto ty10 = list(field("item", utf8()));
InputType t9 = ty9.get();
InputType t10 = ty10.get();
ASSERT_TRUE(t9.Equals(t10));
}

Expand Down
Loading