diff --git a/c_glib/test/flight/test-command-descriptor.rb b/c_glib/test/flight/test-command-descriptor.rb index 316973287f08f..8fcf2d65fccdb 100644 --- a/c_glib/test/flight/test-command-descriptor.rb +++ b/c_glib/test/flight/test-command-descriptor.rb @@ -22,7 +22,7 @@ def setup def test_to_s descriptor = ArrowFlight::CommandDescriptor.new("command") - assert_equal("FlightDescriptor", + assert_equal("", descriptor.to_s) end diff --git a/c_glib/test/flight/test-path-descriptor.rb b/c_glib/test/flight/test-path-descriptor.rb index 441fc7bb04387..ccee99cfbf711 100644 --- a/c_glib/test/flight/test-path-descriptor.rb +++ b/c_glib/test/flight/test-path-descriptor.rb @@ -22,7 +22,7 @@ def setup def test_to_s descriptor = ArrowFlight::PathDescriptor.new(["a", "b", "c"]) - assert_equal("FlightDescriptor", + assert_equal("", descriptor.to_s) end diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index 9818cb207987c..3fef31b0ea462 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -18,6 +18,8 @@ // ---------------------------------------------------------------------- // Tests for Flight which don't actually spin up a client/server +#include + #include #include @@ -40,40 +42,250 @@ namespace pb = arrow::flight::protocol; // ---------------------------------------------------------------------- // Core Flight types -TEST(FlightTypes, FlightDescriptor) { - auto a = FlightDescriptor::Command("select * from table"); - auto b = FlightDescriptor::Command("select * from table"); - auto c = FlightDescriptor::Command("select foo from table"); - auto d = FlightDescriptor::Path({"foo", "bar"}); - auto e = FlightDescriptor::Path({"foo", "baz"}); - auto f = FlightDescriptor::Path({"foo", "baz"}); - - ASSERT_EQ(a.ToString(), "FlightDescriptor"); - ASSERT_EQ(d.ToString(), "FlightDescriptor"); - ASSERT_TRUE(a.Equals(b)); - ASSERT_FALSE(a.Equals(c)); - ASSERT_FALSE(a.Equals(d)); - ASSERT_FALSE(d.Equals(e)); - ASSERT_TRUE(e.Equals(f)); -} +template +void TestRoundtrip(const std::vector& values, + const std::vector& reprs) { + for (size_t i = 0; i < values.size(); i++) { + ARROW_SCOPED_TRACE("LHS = ", values[i].ToString()); + for (size_t j = 0; j < values.size(); j++) { + ARROW_SCOPED_TRACE("RHS = ", values[j].ToString()); + if (i == j) { + EXPECT_EQ(values[i], values[j]); + EXPECT_TRUE(values[i].Equals(values[j])); + } else { + EXPECT_NE(values[i], values[j]); + EXPECT_FALSE(values[i].Equals(values[j])); + } + } + EXPECT_EQ(values[i].ToString(), reprs[i]); + + ASSERT_OK_AND_ASSIGN(std::string serialized, values[i].SerializeToString()); + ASSERT_OK_AND_ASSIGN(auto deserialized, FlightType::Deserialize(serialized)); + if constexpr (std::is_same_v) { + EXPECT_EQ(values[i], *deserialized); + } else { + EXPECT_EQ(values[i], deserialized); + } // This tests the internal protobuf types which don't get exported in the Flight DLL. #ifndef _WIN32 -TEST(FlightTypes, FlightDescriptorToFromProto) { - FlightDescriptor descr_test; - pb::FlightDescriptor pb_descr; - - FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}}; - ASSERT_OK(internal::ToProto(descr1, &pb_descr)); - ASSERT_OK(internal::FromProto(pb_descr, &descr_test)); - ASSERT_EQ(descr1, descr_test); - - FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}}; - ASSERT_OK(internal::ToProto(descr2, &pb_descr)); - ASSERT_OK(internal::FromProto(pb_descr, &descr_test)); - ASSERT_EQ(descr2, descr_test); -} + PbType pb_value; + ASSERT_OK(internal::ToProto(values[i], &pb_value)); + + if constexpr (std::is_same_v) { + FlightInfo::Data data; + ASSERT_OK(internal::FromProto(pb_value, &data)); + FlightInfo value(std::move(data)); + EXPECT_EQ(values[i], value); + } else if constexpr (std::is_same_v) { + std::string data; + ASSERT_OK(internal::FromProto(pb_value, &data)); + SchemaResult value(std::move(data)); + EXPECT_EQ(values[i], value); + } else { + FlightType value; + ASSERT_OK(internal::FromProto(pb_value, &value)); + EXPECT_EQ(values[i], value); + } #endif + } +} + +TEST(FlightTypes, Action) { + std::vector values = { + {"type", Buffer::FromString("")}, + {"type", Buffer::FromString("foo")}, + {"type", Buffer::FromString("bar")}, + }; + std::vector reprs = { + "", + "", + "", + }; + + ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); + + // This doesn't roundtrip since we don't differentiate between no + // body and empty body on the wire + Action action{"", nullptr}; + ASSERT_EQ("", action.ToString()); + ASSERT_NE(values[0], action); + ASSERT_EQ(action, action); +} + +TEST(FlightTypes, ActionType) { + std::vector values = { + {"", ""}, + {"type", ""}, + {"type", "descr"}, + {"", "descr"}, + }; + std::vector reprs = { + "", + "", + "", + "", + }; + + ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); +} + +TEST(FlightTypes, BasicAuth) { + std::vector values = { + {"", ""}, + {"user", ""}, + {"", "pass"}, + {"user", "pass"}, + }; + std::vector reprs = { + "", + "", + "", + "", + }; + + ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); +} + +TEST(FlightTypes, Criteria) { + std::vector values = {{""}, {"criteria"}}; + std::vector reprs = {"", + ""}; + ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); +} + +TEST(FlightTypes, FlightDescriptor) { + std::vector values = { + FlightDescriptor::Command(""), + FlightDescriptor::Command("\x01"), + FlightDescriptor::Command("select * from table"), + FlightDescriptor::Command("select foo from table"), + FlightDescriptor::Path({}), + FlightDescriptor::Path({"foo", "baz"}), + }; + std::vector reprs = { + "", + "", + "", + "", + "", + "", + }; + + ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); +} + +TEST(FlightTypes, FlightEndpoint) { + ASSERT_OK_AND_ASSIGN(auto location1, Location::ForGrpcTcp("localhost", 1024)); + ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTls("localhost", 1024)); + std::vector values = { + {{""}, {}}, + {{"foo"}, {}}, + {{"bar"}, {}}, + {{"foo"}, {location1}}, + {{"bar"}, {location1}}, + {{"foo"}, {location2}}, + {{"foo"}, {location1, location2}}, + }; + std::vector reprs = { + " locations=[]>", + " locations=[]>", + " locations=[]>", + " locations=" + "[grpc+tcp://localhost:1024]>", + " locations=" + "[grpc+tcp://localhost:1024]>", + " locations=" + "[grpc+tls://localhost:1024]>", + " locations=" + "[grpc+tcp://localhost:1024, grpc+tls://localhost:1024]>", + }; + + ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); +} + +TEST(FlightTypes, FlightInfo) { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 1234)); + Schema schema1({field("ints", int64())}); + Schema schema2({}); + auto desc1 = FlightDescriptor::Command("foo"); + auto desc2 = FlightDescriptor::Command("bar"); + auto endpoint1 = FlightEndpoint{Ticket{"foo"}, {}}; + auto endpoint2 = FlightEndpoint{Ticket{"foo"}, {location}}; + std::vector values = { + MakeFlightInfo(schema1, desc1, {}, -1, -1), + MakeFlightInfo(schema1, desc2, {}, -1, -1), + MakeFlightInfo(schema2, desc1, {}, -1, -1), + MakeFlightInfo(schema1, desc1, {endpoint1}, -1, 42), + MakeFlightInfo(schema1, desc2, {endpoint1, endpoint2}, 64, -1), + }; + std::vector reprs = { + " " + "endpoints=[] total_records=-1 total_bytes=-1>", + " " + "endpoints=[] total_records=-1 total_bytes=-1>", + " " + "endpoints=[] total_records=-1 total_bytes=-1>", + " " + "endpoints=[ locations=[]>] " + "total_records=-1 total_bytes=42>", + " " + "endpoints=[ locations=[]>, " + " locations=" + "[grpc+tcp://localhost:1234]>] total_records=64 total_bytes=-1>", + }; + + ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); +} + +TEST(FlightTypes, Result) { + std::vector values = { + {Buffer::FromString("")}, + {Buffer::FromString("foo")}, + {Buffer::FromString("bar")}, + }; + std::vector reprs = { + "", + "", + "", + }; + + ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); + + // This doesn't roundtrip since we don't differentiate between no + // body and empty body on the wire + Result result{nullptr}; + ASSERT_EQ("", result.ToString()); + ASSERT_NE(values[0], result); + ASSERT_EQ(result, result); +} + +TEST(FlightTypes, SchemaResult) { + ASSERT_OK_AND_ASSIGN(auto value1, SchemaResult::Make(Schema({}))); + ASSERT_OK_AND_ASSIGN(auto value2, SchemaResult::Make(Schema({field("foo", int64())}))); + std::vector values = {*value1, *value2}; + std::vector reprs = { + "", + "", + }; + + ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); +} + +TEST(FlightTypes, Ticket) { + std::vector values = { + {""}, + {"foo"}, + {"bar"}, + }; + std::vector reprs = { + "", + "", + "", + }; + + ASSERT_NO_FATAL_FAILURE(TestRoundtrip(values, reprs)); +} // ARROW-6017: we should be able to construct locations for unknown // schemes @@ -82,89 +294,6 @@ TEST(FlightTypes, LocationUnknownScheme) { ASSERT_OK(Location::Parse("https://example.com/foo")); } -TEST(FlightTypes, RoundTripTypes) { - ActionType action_type{"action-type1", "action-type1-description"}; - ASSERT_OK_AND_ASSIGN(std::string action_type_serialized, - action_type.SerializeToString()); - ASSERT_OK_AND_ASSIGN(ActionType action_type_deserialized, - ActionType::Deserialize(action_type_serialized)); - ASSERT_EQ(action_type, action_type_deserialized); - - Criteria criteria{"criteria1"}; - ASSERT_OK_AND_ASSIGN(std::string criteria_serialized, criteria.SerializeToString()); - ASSERT_OK_AND_ASSIGN(Criteria criteria_deserialized, - Criteria::Deserialize(criteria_serialized)); - ASSERT_EQ(criteria, criteria_deserialized); - - Action action{"action1", Buffer::FromString("action1-content")}; - ASSERT_OK_AND_ASSIGN(std::string action_serialized, action.SerializeToString()); - ASSERT_OK_AND_ASSIGN(Action action_deserialized, - Action::Deserialize(action_serialized)); - ASSERT_EQ(action, action_deserialized); - - Result result{Buffer::FromString("result1-content")}; - ASSERT_OK_AND_ASSIGN(std::string result_serialized, result.SerializeToString()); - ASSERT_OK_AND_ASSIGN(Result result_deserialized, - Result::Deserialize(result_serialized)); - ASSERT_EQ(result, result_deserialized); - - BasicAuth basic_auth{"username1", "password1"}; - ASSERT_OK_AND_ASSIGN(std::string basic_auth_serialized, basic_auth.SerializeToString()); - ASSERT_OK_AND_ASSIGN(BasicAuth basic_auth_deserialized, - BasicAuth::Deserialize(basic_auth_serialized)); - ASSERT_EQ(basic_auth, basic_auth_deserialized); - - SchemaResult schema_result{"schema_result1"}; - ASSERT_OK_AND_ASSIGN(std::string schema_result_serialized, - schema_result.SerializeToString()); - ASSERT_OK_AND_ASSIGN(SchemaResult schema_result_deserialized, - SchemaResult::Deserialize(schema_result_serialized)); - ASSERT_EQ(schema_result, schema_result_deserialized); - - Ticket ticket{"foo"}; - ASSERT_OK_AND_ASSIGN(std::string ticket_serialized, ticket.SerializeToString()); - ASSERT_OK_AND_ASSIGN(Ticket ticket_deserialized, - Ticket::Deserialize(ticket_serialized)); - ASSERT_EQ(ticket, ticket_deserialized); - - FlightDescriptor desc = FlightDescriptor::Command("select * from foo;"); - ASSERT_OK_AND_ASSIGN(std::string desc_serialized, desc.SerializeToString()); - ASSERT_OK_AND_ASSIGN(FlightDescriptor desc_deserialized, - FlightDescriptor::Deserialize(desc_serialized)); - ASSERT_EQ(desc, desc_deserialized); - - desc = FlightDescriptor::Path({"a", "b", "test.arrow"}); - ASSERT_OK_AND_ASSIGN(desc_serialized, desc.SerializeToString()); - ASSERT_OK_AND_ASSIGN(desc_deserialized, FlightDescriptor::Deserialize(desc_serialized)); - ASSERT_EQ(desc, desc_deserialized); - - FlightInfo::Data data; - std::shared_ptr schema = - arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()), - field("d", int64())}); - ASSERT_OK_AND_ASSIGN(auto location1, Location::ForGrpcTcp("localhost", 10010)); - ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTls("localhost", 10010)); - ASSERT_OK_AND_ASSIGN(auto location3, Location::ForGrpcUnix("/tmp/test.sock")); - std::vector endpoints{FlightEndpoint{ticket, {location1, location2}}, - FlightEndpoint{ticket, {location3}}}; - ASSERT_OK(MakeFlightInfo(*schema, desc, endpoints, -1, -1, &data)); - auto info = std::make_unique(data); - ASSERT_OK_AND_ASSIGN(std::string info_serialized, info->SerializeToString()); - ASSERT_OK_AND_ASSIGN(std::unique_ptr info_deserialized, - FlightInfo::Deserialize(info_serialized)); - ASSERT_EQ(info->descriptor(), info_deserialized->descriptor()); - ASSERT_EQ(info->endpoints(), info_deserialized->endpoints()); - ASSERT_EQ(info->total_records(), info_deserialized->total_records()); - ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes()); - - FlightEndpoint flight_endpoint{ticket, {location1, location2}}; - ASSERT_OK_AND_ASSIGN(std::string flight_endpoint_serialized, - flight_endpoint.SerializeToString()); - ASSERT_OK_AND_ASSIGN(FlightEndpoint flight_endpoint_deserialized, - FlightEndpoint::Deserialize(flight_endpoint_serialized)); - ASSERT_EQ(flight_endpoint, flight_endpoint_deserialized); -} - TEST(FlightTypes, RoundtripStatus) { // Make sure status codes round trip through our conversions diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index eb6ac92b83624..db3e8b150e048 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -195,10 +195,8 @@ class FlightPerfServer : public FlightServerBase { uint64_t total_records = perf_request.stream_count() * perf_request.records_per_stream(); - FlightInfo::Data data; - RETURN_NOT_OK( - MakeFlightInfo(*perf_schema_, request, endpoints, total_records, -1, &data)); - *info = std::make_unique(data); + *info = std::make_unique( + MakeFlightInfo(*perf_schema_, request, endpoints, total_records, -1)); return Status::OK(); } diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index a478aed998fed..0d6c28b296847 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -510,14 +510,12 @@ std::unique_ptr ExampleTestServer() { return std::make_unique(); } -Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, - const std::vector& endpoints, int64_t total_records, - int64_t total_bytes, FlightInfo::Data* out) { - out->descriptor = descriptor; - out->endpoints = endpoints; - out->total_records = total_records; - out->total_bytes = total_bytes; - return internal::SchemaToString(schema, &out->schema); +FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, + const std::vector& endpoints, + int64_t total_records, int64_t total_bytes) { + EXPECT_OK_AND_ASSIGN(auto info, FlightInfo::Make(schema, descriptor, endpoints, + total_records, total_bytes)); + return info; } NumberingStream::NumberingStream(std::unique_ptr stream) @@ -585,8 +583,6 @@ std::vector ExampleFlightInfo() { Location location4 = *Location::ForGrpcTcp("foo4.bar.com", 12345); Location location5 = *Location::ForGrpcTcp("foo5.bar.com", 12345); - FlightInfo::Data flight1, flight2, flight3, flight4; - FlightEndpoint endpoint1({{"ticket-ints-1"}, {location1}}); FlightEndpoint endpoint2({{"ticket-ints-2"}, {location2}}); FlightEndpoint endpoint3({{"ticket-cmd"}, {location3}}); @@ -603,13 +599,12 @@ std::vector ExampleFlightInfo() { auto schema3 = ExampleDictSchema(); auto schema4 = ExampleFloatSchema(); - ARROW_EXPECT_OK( - MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, 100000, &flight1)); - ARROW_EXPECT_OK(MakeFlightInfo(*schema2, descr2, {endpoint3}, 1000, 100000, &flight2)); - ARROW_EXPECT_OK(MakeFlightInfo(*schema3, descr3, {endpoint4}, -1, -1, &flight3)); - ARROW_EXPECT_OK(MakeFlightInfo(*schema4, descr4, {endpoint5}, 1000, 100000, &flight4)); - return {FlightInfo(flight1), FlightInfo(flight2), FlightInfo(flight3), - FlightInfo(flight4)}; + return { + MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, 100000), + MakeFlightInfo(*schema2, descr2, {endpoint3}, 1000, 100000), + MakeFlightInfo(*schema3, descr3, {endpoint4}, -1, -1), + MakeFlightInfo(*schema4, descr4, {endpoint5}, 1000, 100000), + }; } Status ExampleIntBatches(RecordBatchVector* out) { diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h index c0c6d7514e169..679e04fa1b14f 100644 --- a/cpp/src/arrow/flight/test_util.h +++ b/cpp/src/arrow/flight/test_util.h @@ -190,9 +190,9 @@ ARROW_FLIGHT_EXPORT std::vector ExampleActionTypes(); ARROW_FLIGHT_EXPORT -Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, - const std::vector& endpoints, int64_t total_records, - int64_t total_bytes, FlightInfo::Data* out); +FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, + const std::vector& endpoints, + int64_t total_records, int64_t total_bytes); // ---------------------------------------------------------------------- // A pair of authentication handlers that check for a predefined password diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index a09f09ff9dbe1..b051ec7081a09 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -29,6 +29,7 @@ #include "arrow/ipc/reader.h" #include "arrow/status.h" #include "arrow/table.h" +#include "arrow/util/string_builder.h" #include "arrow/util/uri.h" namespace arrow { @@ -105,11 +106,11 @@ bool FlightDescriptor::Equals(const FlightDescriptor& other) const { std::string FlightDescriptor::ToString() const { std::stringstream ss; - ss << "FlightDescriptor<"; + ss << ""; +} + bool SchemaResult::Equals(const SchemaResult& other) const { return raw_schema_ == other.raw_schema_; } @@ -225,6 +230,12 @@ Status FlightDescriptor::Deserialize(const std::string& serialized, return Deserialize(serialized).Value(out); } +std::string Ticket::ToString() const { + std::stringstream ss; + ss << ""; + return ss.str(); +} + bool Ticket::Equals(const Ticket& other) const { return ticket == other.ticket; } arrow::Result Ticket::SerializeToString() const { @@ -326,6 +337,36 @@ Status FlightInfo::Deserialize(const std::string& serialized, return Deserialize(serialized).Value(out); } +std::string FlightInfo::ToString() const { + std::stringstream ss; + ss << ""; + return ss.str(); +} + bool FlightEndpoint::Equals(const FlightEndpoint& other) const { return ticket == other.ticket && locations == other.locations; } @@ -423,6 +478,11 @@ arrow::Result FlightEndpoint::Deserialize(std::string_view seria return out; } +std::string ActionType::ToString() const { + return arrow::util::StringBuilder(""); +} + bool ActionType::Equals(const ActionType& other) const { return type == other.type && description == other.description; } @@ -453,6 +513,10 @@ arrow::Result ActionType::Deserialize(std::string_view serialized) { return out; } +std::string Criteria::ToString() const { + return arrow::util::StringBuilder(""); +} + bool Criteria::Equals(const Criteria& other) const { return expression == other.expression; } @@ -483,6 +547,19 @@ arrow::Result Criteria::Deserialize(std::string_view serialized) { return out; } +std::string Action::ToString() const { + std::stringstream ss; + ss << "size() << " bytes)"; + } else { + ss << "(nullptr)"; + } + ss << '>'; + return ss.str(); +} + bool Action::Equals(const Action& other) const { return (type == other.type) && ((body == other.body) || (body && other.body && body->Equals(*other.body))); @@ -514,6 +591,17 @@ arrow::Result Action::Deserialize(std::string_view serialized) { return out; } +std::string Result::ToString() const { + std::stringstream ss; + ss << "size() << " bytes)>"; + } else { + ss << "(nullptr)>"; + } + return ss.str(); +} + bool Result::Equals(const Result& other) const { return (body == other.body) || (body && other.body && body->Equals(*other.body)); } @@ -638,6 +726,11 @@ arrow::Result> SimpleResultStream::Next() { return std::make_unique(std::move(results_[position_++])); } +std::string BasicAuth::ToString() const { + return arrow::util::StringBuilder(""); +} + bool BasicAuth::Equals(const BasicAuth& other) const { return (username == other.username) && (password == other.password); } diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 6957c5992a328..39353bcb9977a 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -140,6 +140,7 @@ struct ARROW_FLIGHT_EXPORT ActionType { /// \brief A human-readable description of the action. std::string description; + std::string ToString() const; bool Equals(const ActionType& other) const; friend bool operator==(const ActionType& left, const ActionType& right) { @@ -161,6 +162,7 @@ struct ARROW_FLIGHT_EXPORT Criteria { /// Opaque criteria expression, dependent on server implementation std::string expression; + std::string ToString() const; bool Equals(const Criteria& other) const; friend bool operator==(const Criteria& left, const Criteria& right) { @@ -185,6 +187,7 @@ struct ARROW_FLIGHT_EXPORT Action { /// The action content as a Buffer std::shared_ptr body; + std::string ToString() const; bool Equals(const Action& other) const; friend bool operator==(const Action& left, const Action& right) { @@ -205,6 +208,7 @@ struct ARROW_FLIGHT_EXPORT Action { struct ARROW_FLIGHT_EXPORT Result { std::shared_ptr body; + std::string ToString() const; bool Equals(const Result& other) const; friend bool operator==(const Result& left, const Result& right) { @@ -226,6 +230,7 @@ struct ARROW_FLIGHT_EXPORT BasicAuth { std::string username; std::string password; + std::string ToString() const; bool Equals(const BasicAuth& other) const; friend bool operator==(const BasicAuth& left, const BasicAuth& right) { @@ -312,6 +317,7 @@ struct ARROW_FLIGHT_EXPORT FlightDescriptor { struct ARROW_FLIGHT_EXPORT Ticket { std::string ticket; + std::string ToString() const; bool Equals(const Ticket& other) const; friend bool operator==(const Ticket& left, const Ticket& right) { @@ -429,6 +435,7 @@ struct ARROW_FLIGHT_EXPORT FlightEndpoint { /// generated std::vector locations; + std::string ToString() const; bool Equals(const FlightEndpoint& other) const; friend bool operator==(const FlightEndpoint& left, const FlightEndpoint& right) { @@ -469,7 +476,7 @@ struct ARROW_FLIGHT_EXPORT SchemaResult { /// \brief return schema /// \param[in,out] dictionary_memo for dictionary bookkeeping, will /// be modified - /// \return Arrrow result with the reconstructed Schema + /// \return Arrow result with the reconstructed Schema arrow::Result> GetSchema( ipc::DictionaryMemo* dictionary_memo) const; @@ -479,6 +486,7 @@ struct ARROW_FLIGHT_EXPORT SchemaResult { const std::string& serialized_schema() const { return raw_schema_; } + std::string ToString() const; bool Equals(const SchemaResult& other) const; friend bool operator==(const SchemaResult& left, const SchemaResult& right) { @@ -510,9 +518,7 @@ class ARROW_FLIGHT_EXPORT FlightInfo { int64_t total_bytes; }; - explicit FlightInfo(const Data& data) : data_(data), reconstructed_schema_(false) {} - explicit FlightInfo(Data&& data) - : data_(std::move(data)), reconstructed_schema_(false) {} + explicit FlightInfo(Data data) : data_(std::move(data)), reconstructed_schema_(false) {} /// \brief Factory method to construct a FlightInfo. static arrow::Result Make(const Schema& schema, @@ -568,6 +574,20 @@ class ARROW_FLIGHT_EXPORT FlightInfo { static Status Deserialize(const std::string& serialized, std::unique_ptr* out); + std::string ToString() const; + + /// Compare two FlightInfo for equality. This will compare the + /// serialized schema representations, NOT the logical equality of + /// the schemas. + bool Equals(const FlightInfo& other) const; + + friend bool operator==(const FlightInfo& left, const FlightInfo& right) { + return left.Equals(right); + } + friend bool operator!=(const FlightInfo& left, const FlightInfo& right) { + return !(left == right); + } + private: Data data_; mutable std::shared_ptr schema_; diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 7feee8cf7b443..db40d35f5ca67 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -314,6 +314,10 @@ cdef class Action(_Weakrefable): def __eq__(self, Action other): return self.action == other.action + def __repr__(self): + return (f"") + _ActionType = collections.namedtuple('_ActionType', ['type', 'description']) @@ -377,6 +381,9 @@ cdef class Result(_Weakrefable): def __eq__(self, Result other): return deref(self.result.get()) == deref(other.result.get()) + def __repr__(self): + return f"" + cdef class BasicAuth(_Weakrefable): """A container for basic auth.""" @@ -420,6 +427,10 @@ cdef class BasicAuth(_Weakrefable): def __eq__(self, BasicAuth other): return deref(self.basic_auth.get()) == deref(other.basic_auth.get()) + def __repr__(self): + return (f"") + class DescriptorType(enum.Enum): """ @@ -537,11 +548,11 @@ cdef class FlightDescriptor(_Weakrefable): def __repr__(self): if self.descriptor_type == DescriptorType.PATH: - return "".format(self.path) + return f"" elif self.descriptor_type == DescriptorType.CMD: - return "".format(self.command) + return f"" else: - return "".format(self.descriptor_type) + return f"" @staticmethod cdef CFlightDescriptor unwrap(descriptor) except *: @@ -581,14 +592,14 @@ cdef class Ticket(_Weakrefable): """A ticket for requesting a Flight stream.""" cdef: - CTicket ticket + CTicket c_ticket def __init__(self, ticket): - self.ticket.ticket = tobytes(ticket) + self.c_ticket.ticket = tobytes(ticket) @property def ticket(self): - return self.ticket.ticket + return self.c_ticket.ticket def serialize(self): """Get the wire-format representation of this type. @@ -597,7 +608,7 @@ cdef class Ticket(_Weakrefable): services) that may want to return Flight types. """ - return GetResultValue(self.ticket.SerializeToString()) + return GetResultValue(self.c_ticket.SerializeToString()) @classmethod def deserialize(cls, serialized): @@ -608,15 +619,15 @@ cdef class Ticket(_Weakrefable): """ cdef Ticket ticket = Ticket.__new__(Ticket) - ticket.ticket = GetResultValue( + ticket.c_ticket = GetResultValue( CTicket.Deserialize(tobytes(serialized))) return ticket def __eq__(self, Ticket other): - return self.ticket == other.ticket + return self.c_ticket == other.c_ticket def __repr__(self): - return ''.format(self.ticket.ticket) + return f"" cdef class Location(_Weakrefable): @@ -628,7 +639,7 @@ cdef class Location(_Weakrefable): check_flight_status(CLocation.Parse(tobytes(uri)).Value(&self.location)) def __repr__(self): - return ''.format(self.location.ToString()) + return f'' @property def uri(self): @@ -762,15 +773,15 @@ cdef class FlightEndpoint(_Weakrefable): return endpoint def __repr__(self): - return "".format( - self.ticket, self.locations) + return (f"") def __eq__(self, FlightEndpoint other): return self.endpoint == other.endpoint cdef class SchemaResult(_Weakrefable): - """A result from a getschema request. Holding a schema""" + """The serialized schema returned from a GetSchema request.""" cdef: unique_ptr[CSchemaResult] result @@ -821,6 +832,9 @@ cdef class SchemaResult(_Weakrefable): def __eq__(self, SchemaResult other): return deref(self.result.get()) == deref(other.result.get()) + def __repr__(self): + return f"" + cdef class FlightInfo(_Weakrefable): """A description of a Flight stream.""" @@ -926,6 +940,16 @@ cdef class FlightInfo(_Weakrefable): CFlightInfo.Deserialize(tobytes(serialized)))) return info + def __eq__(self, FlightInfo other): + return deref(self.info.get()) == deref(other.info.get()) + + def __repr__(self): + return (f"") + cdef class FlightStreamChunk(_Weakrefable): """A RecordBatch with application metadata on the side.""" @@ -1538,7 +1562,7 @@ cdef class FlightClient(_Weakrefable): with nogil: check_flight_status( self.client.get().DoGet( - deref(c_options), ticket.ticket).Value(&reader)) + deref(c_options), ticket.c_ticket).Value(&reader)) result = FlightStreamReader() result.reader.reset(reader.release()) return result diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 3301c1b6360b2..34ba809438e2c 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -39,6 +39,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: shared_ptr[CBuffer] body bint operator==(CAction) CResult[c_string] SerializeToString() + c_string ToString() @staticmethod CResult[CAction] Deserialize(const c_string& serialized) @@ -49,6 +50,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: shared_ptr[CBuffer] body bint operator==(CFlightResult) CResult[c_string] SerializeToString() + c_string ToString() @staticmethod CResult[CFlightResult] Deserialize(const c_string& serialized) @@ -61,6 +63,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: c_string password bint operator==(CBasicAuth) CResult[c_string] SerializeToString() + c_string ToString() @staticmethod CResult[CBasicAuth] Deserialize(const c_string& serialized) @@ -85,6 +88,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: vector[c_string] path bint operator==(CFlightDescriptor) CResult[c_string] SerializeToString() + c_string ToString() @staticmethod CResult[CFlightDescriptor] Deserialize(const c_string& serialized) @@ -94,6 +98,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: c_string ticket bint operator==(CTicket) CResult[c_string] SerializeToString() + c_string ToString() @staticmethod CResult[CTicket] Deserialize(const c_string& serialized) @@ -132,6 +137,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: bint operator==(CFlightEndpoint) CResult[c_string] SerializeToString() + c_string ToString() @staticmethod CResult[CFlightEndpoint] Deserialize(const c_string& serialized) @@ -144,6 +150,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CFlightDescriptor& descriptor() const vector[CFlightEndpoint]& endpoints() CResult[c_string] SerializeToString() + c_string ToString() + bint operator==(CFlightInfo) @staticmethod CResult[unique_ptr[CFlightInfo]] Deserialize( @@ -155,6 +163,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CResult[shared_ptr[CSchema]] GetSchema(CDictionaryMemo* memo) bint operator==(CSchemaResult) CResult[c_string] SerializeToString() + c_string ToString() @staticmethod CResult[CSchemaResult] Deserialize(const c_string& serialized) diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 28ace4f93e3d3..dafa9e6011ce6 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -870,6 +870,79 @@ def do_exchange(self, context, descriptor, reader, writer): writer.write_metadata(self._metadata) +def test_repr(): + action_repr = "" + action_type_repr = "ActionType(type='foo', description='bar')" + basic_auth_repr = "" + descriptor_repr = "" + endpoint_repr = (" " + "locations=[]>") + info_repr = ( + " " + "endpoints=[] " + "total_records=-1 " + "total_bytes=-1>") + location_repr = "" + result_repr = "" + schema_result_repr = "" + ticket_repr = "" + + assert repr(flight.Action("foo", b"")) == action_repr + assert repr(flight.ActionType("foo", "bar")) == action_type_repr + assert repr(flight.BasicAuth("user", "pass")) == basic_auth_repr + assert repr(flight.FlightDescriptor.for_command("foo")) == descriptor_repr + assert repr(flight.FlightEndpoint(b"foo", [])) == endpoint_repr + info = flight.FlightInfo( + pa.schema([]), flight.FlightDescriptor.for_path(), [], -1, -1) + assert repr(info) == info_repr + assert repr(flight.Location("grpc+tcp://localhost:1234")) == location_repr + assert repr(flight.Result(b"foo")) == result_repr + assert repr(flight.SchemaResult(pa.schema([]))) == schema_result_repr + assert repr(flight.SchemaResult(pa.schema([("int", "int64")]))) == \ + "" + assert repr(flight.Ticket(b"foo")) == ticket_repr + + with pytest.raises(TypeError): + flight.Action("foo", None) + + +def test_eq(): + items = [ + lambda: (flight.Action("foo", b""), flight.Action("foo", b"bar")), + lambda: (flight.ActionType("foo", "bar"), + flight.ActionType("foo", "baz")), + lambda: (flight.BasicAuth("user", "pass"), + flight.BasicAuth("user2", "pass")), + lambda: (flight.FlightDescriptor.for_command("foo"), + flight.FlightDescriptor.for_path("foo")), + lambda: (flight.FlightEndpoint(b"foo", []), + flight.FlightEndpoint(b"", [])), + lambda: ( + flight.FlightInfo( + pa.schema([]), + flight.FlightDescriptor.for_path(), [], -1, -1), + flight.FlightInfo( + pa.schema([]), + flight.FlightDescriptor.for_command(b"foo"), [], -1, 42)), + lambda: (flight.Location("grpc+tcp://localhost:1234"), + flight.Location("grpc+tls://localhost:1234")), + lambda: (flight.Result(b"foo"), flight.Result(b"bar")), + lambda: (flight.SchemaResult(pa.schema([])), + flight.SchemaResult(pa.schema([("ints", pa.int64())]))), + lambda: (flight.Ticket(b""), flight.Ticket(b"foo")), + ] + + for gen in items: + lhs1, rhs1 = gen() + lhs2, rhs2 = gen() + assert lhs1 == lhs2 + assert rhs1 == rhs2 + assert lhs1 != rhs1 + + def test_flight_server_location_argument(): locations = [ None,