diff --git a/google/cloud/spanner/client.cc b/google/cloud/spanner/client.cc index 904542ef2a5a5..cdfe95d28659d 100644 --- a/google/cloud/spanner/client.cc +++ b/google/cloud/spanner/client.cc @@ -357,7 +357,7 @@ std::shared_ptr MakeConnection(spanner::Database const& db, return spanner_internal::CreateDefaultSpannerStub(db, auth, opts, id++); }); return std::make_shared( - std::move(db), std::move(background), std::move(stubs), opts); + std::move(db), std::move(background), std::move(stubs), std::move(opts)); } std::shared_ptr MakeConnection( diff --git a/google/cloud/spanner/client.h b/google/cloud/spanner/client.h index 200a770489d59..3b38da8ae4530 100644 --- a/google/cloud/spanner/client.h +++ b/google/cloud/spanner/client.h @@ -133,7 +133,7 @@ class Client { */ explicit Client(std::shared_ptr conn, Options opts = {}) : conn_(std::move(conn)), - opts_(spanner_internal::DefaultOptions(std::move(opts))) {} + opts_(internal::MergeOptions(std::move(opts), conn_->options())) {} //@{ /// Backwards compatibility for `ClientOptions`. diff --git a/google/cloud/spanner/client_test.cc b/google/cloud/spanner/client_test.cc index d990347133fb7..fa74db45b4c02 100644 --- a/google/cloud/spanner/client_test.cc +++ b/google/cloud/spanner/client_test.cc @@ -14,6 +14,7 @@ #include "google/cloud/spanner/client.h" #include "google/cloud/spanner/connection.h" +#include "google/cloud/spanner/internal/defaults.h" #include "google/cloud/spanner/mocks/mock_spanner_connection.h" #include "google/cloud/spanner/mutations.h" #include "google/cloud/spanner/results.h" @@ -28,6 +29,7 @@ #include #include #include +#include #include namespace google { @@ -46,6 +48,7 @@ using ::google::protobuf::TextFormat; using ::testing::ByMove; using ::testing::DoAll; using ::testing::ElementsAre; +using ::testing::Eq; using ::testing::HasSubstr; using ::testing::Return; using ::testing::SaveArg; @@ -389,7 +392,17 @@ TEST(ClientTest, MakeConnectionOptionalArguments) { EXPECT_NE(conn, nullptr); conn = MakeConnection(db, Options{}); - EXPECT_NE(conn, nullptr); + ASSERT_NE(conn, nullptr); + ASSERT_TRUE(conn->options().has()); + EXPECT_EQ(conn->options().get(), + spanner_internal::DefaultOptions().get()); + + conn = MakeConnection(db, Options{}.set("endpoint")); + ASSERT_NE(conn, nullptr); + ASSERT_TRUE(conn->options().has()); + EXPECT_NE(conn->options().get(), + spanner_internal::DefaultOptions().get()); + EXPECT_EQ(conn->options().get(), "endpoint"); } TEST(ClientTest, CommitMutatorSuccess) { @@ -1192,6 +1205,71 @@ TEST(ClientTest, QueryOptionsOverlayPrecedence) { } } +struct StringOption { + using Type = std::string; +}; + +TEST(ClientTest, UsesConnectionOptions) { + auto conn = std::make_shared(); + auto txn = MakeReadWriteTransaction(); + + EXPECT_CALL(*conn, options).WillOnce([] { + return Options{}.set("connection"); + }); + EXPECT_CALL(*conn, Rollback) + .WillOnce([txn](Connection::RollbackParams const& params) { + auto const& options = internal::CurrentOptions(); + EXPECT_THAT(options.get(), Eq("connection")); + EXPECT_THAT(params.transaction, Eq(txn)); + return Status(); + }); + + Client client(conn, Options{}); + auto rollback = client.Rollback(txn, Options{}); + EXPECT_STATUS_OK(rollback); +} + +TEST(ClientTest, UsesClientOptions) { + auto conn = std::make_shared(); + auto txn = MakeReadWriteTransaction(); + + EXPECT_CALL(*conn, options).WillOnce([] { + return Options{}.set("connection"); + }); + EXPECT_CALL(*conn, Rollback) + .WillOnce([txn](Connection::RollbackParams const& params) { + auto const& options = internal::CurrentOptions(); + EXPECT_THAT(options.get(), Eq("client")); + EXPECT_THAT(params.transaction, Eq(txn)); + return Status(); + }); + + Client client(conn, Options{}.set("client")); + auto rollback = client.Rollback(txn, Options{}); + EXPECT_STATUS_OK(rollback); +} + +TEST(ClientTest, UsesOperationOptions) { + auto conn = std::make_shared(); + auto txn = MakeReadWriteTransaction(); + + EXPECT_CALL(*conn, options).WillOnce([] { + return Options{}.set("connection"); + }); + EXPECT_CALL(*conn, Rollback) + .WillOnce([txn](Connection::RollbackParams const& params) { + auto const& options = internal::CurrentOptions(); + EXPECT_THAT(options.get(), Eq("operation")); + EXPECT_THAT(params.transaction, Eq(txn)); + return Status(); + }); + + Client client(conn, Options{}.set("client")); + auto rollback = + client.Rollback(txn, Options{}.set("operation")); + EXPECT_STATUS_OK(rollback); +} + } // namespace GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END } // namespace spanner diff --git a/google/cloud/spanner/connection.h b/google/cloud/spanner/connection.h index dfe3eed7922ca..c3d5403d07f82 100644 --- a/google/cloud/spanner/connection.h +++ b/google/cloud/spanner/connection.h @@ -128,6 +128,9 @@ class Connection { }; //@} + /// Returns the options used by the Connection. + virtual Options options() { return Options{}; } + /// Defines the interface for `Client::Read()` virtual RowStream Read(ReadParams) = 0; diff --git a/google/cloud/spanner/internal/connection_impl.cc b/google/cloud/spanner/internal/connection_impl.cc index b238ba0d43666..af0b293a4ba25 100644 --- a/google/cloud/spanner/internal/connection_impl.cc +++ b/google/cloud/spanner/internal/connection_impl.cc @@ -103,18 +103,12 @@ Status MissingTransactionStatus(std::string const& operation) { ConnectionImpl::ConnectionImpl( spanner::Database db, std::unique_ptr background_threads, - std::vector> stubs, Options const& opts) + std::vector> stubs, Options opts) : db_(std::move(db)), - retry_policy_prototype_( - opts.get()->clone()), - backoff_policy_prototype_( - opts.get()->clone()), background_threads_(std::move(background_threads)), + opts_(internal::MergeOptions(std::move(opts), Connection::options())), session_pool_(MakeSessionPool(db_, std::move(stubs), - background_threads_->cq(), opts)), - rpc_stream_tracing_enabled_(internal::Contains( - opts.get(), "rpc-streams")), - tracing_options_(opts.get()) {} + background_threads_->cq(), opts_)) {} spanner::RowStream ConnectionImpl::Read(ReadParams params) { return Visit(std::move(params.transaction), @@ -327,6 +321,31 @@ class StreamingPartitionedDmlResult { std::unique_ptr source_; }; +std::shared_ptr const& +ConnectionImpl::RetryPolicyPrototype() const { + // TODO(#7690): Base this on internal::CurrentOptions(). + return opts_.get(); +} + +std::shared_ptr const& +ConnectionImpl::BackoffPolicyPrototype() const { + // TODO(#7690): Base this on internal::CurrentOptions(). + return opts_.get(); +} + +bool ConnectionImpl::RpcStreamTracingEnabled() const { + // TODO(#7690): Base this on internal::CurrentOptions() if we want to + // allow per-operation options to influence "rpc-streams" tracing. + return internal::Contains(opts_.get(), + "rpc-streams"); +} + +TracingOptions const& ConnectionImpl::RpcTracingOptions() const { + // TODO(#7690): Base this on internal::CurrentOptions() if we want to + // allow per-operation options to influence "rpc-streams" tracing. + return opts_.get(); +} + /** * Helper function that ensures `session` holds a valid `Session`, or returns * an error if `session` is empty and no `Session` can be allocated. @@ -366,7 +385,7 @@ StatusOr ConnectionImpl::BeginTransaction( auto stub = session_pool_->GetStub(*session); auto response = RetryLoop( - retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), + RetryPolicyPrototype()->clone(), BackoffPolicyPrototype()->clone(), Idempotency::kIdempotent, [&stub](grpc::ClientContext& context, spanner_proto::BeginTransactionRequest const& request) { @@ -417,8 +436,8 @@ spanner::RowStream ConnectionImpl::ReadImpl( // Capture a copy of `stub` to ensure the `shared_ptr<>` remains valid through // the lifetime of the lambda. auto stub = session_pool_->GetStub(*session); - auto const tracing_enabled = rpc_stream_tracing_enabled_; - auto const tracing_options = tracing_options_; + auto const tracing_enabled = RpcStreamTracingEnabled(); + auto const& tracing_options = RpcTracingOptions(); auto factory = [stub, &request, tracing_enabled, tracing_options](std::string const& resume_token) mutable { request.set_resume_token(resume_token); @@ -434,8 +453,8 @@ spanner::RowStream ConnectionImpl::ReadImpl( }; for (;;) { auto rpc = absl::make_unique( - factory, Idempotency::kIdempotent, retry_policy_prototype_->clone(), - backoff_policy_prototype_->clone()); + factory, Idempotency::kIdempotent, RetryPolicyPrototype()->clone(), + BackoffPolicyPrototype()->clone()); auto reader = PartialResultSetSource::Create(std::move(rpc)); if (s->has_begin()) { if (reader.ok()) { @@ -496,7 +515,7 @@ StatusOr> ConnectionImpl::PartitionReadImpl( auto stub = session_pool_->GetStub(*session); for (;;) { auto response = RetryLoop( - retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), + RetryPolicyPrototype()->clone(), BackoffPolicyPrototype()->clone(), Idempotency::kIdempotent, [&stub](grpc::ClientContext& context, spanner_proto::PartitionReadRequest const& request) { @@ -627,12 +646,12 @@ ResultType ConnectionImpl::CommonQueryImpl( // through the lifetime of the lambda. Note that the local variables are a // reference to avoid increasing refcounts twice, but the capture is by value. auto stub = session_pool_->GetStub(*session); - auto const& retry_policy = retry_policy_prototype_; - auto const& backoff_policy = backoff_policy_prototype_; - auto const tracing_enabled = rpc_stream_tracing_enabled_; - auto const tracing_options = tracing_options_; + auto const& retry_policy_prototype = RetryPolicyPrototype(); + auto const& backoff_policy_prototype = BackoffPolicyPrototype(); + auto const tracing_enabled = RpcStreamTracingEnabled(); + auto const& tracing_options = RpcTracingOptions(); auto retry_resume_fn = - [stub, retry_policy, backoff_policy, tracing_enabled, + [stub, retry_policy_prototype, backoff_policy_prototype, tracing_enabled, tracing_options](spanner_proto::ExecuteSqlRequest& request) mutable -> StatusOr> { auto factory = [stub, request, tracing_enabled, @@ -649,8 +668,8 @@ ResultType ConnectionImpl::CommonQueryImpl( return reader; }; auto rpc = absl::make_unique( - std::move(factory), Idempotency::kIdempotent, retry_policy->clone(), - backoff_policy->clone()); + std::move(factory), Idempotency::kIdempotent, + retry_policy_prototype->clone(), backoff_policy_prototype->clone()); return PartialResultSetSource::Create(std::move(rpc)); }; @@ -699,15 +718,15 @@ StatusOr ConnectionImpl::CommonDmlImpl( // through the lifetime of the lambda. Note that the local variables are a // reference to avoid increasing refcounts twice, but the capture is by value. auto stub = session_pool_->GetStub(*session); - auto const& retry_policy = retry_policy_prototype_; - auto const& backoff_policy = backoff_policy_prototype_; + auto const& retry_policy_prototype = RetryPolicyPrototype(); + auto const& backoff_policy_prototype = BackoffPolicyPrototype(); auto retry_resume_fn = - [function_name, stub, retry_policy, backoff_policy, + [function_name, stub, retry_policy_prototype, backoff_policy_prototype, session](spanner_proto::ExecuteSqlRequest& request) mutable -> StatusOr> { StatusOr response = RetryLoop( - retry_policy->clone(), backoff_policy->clone(), + retry_policy_prototype->clone(), backoff_policy_prototype->clone(), Idempotency::kIdempotent, [stub](grpc::ClientContext& context, spanner_proto::ExecuteSqlRequest const& request) { @@ -783,7 +802,7 @@ ConnectionImpl::PartitionQueryImpl( auto stub = session_pool_->GetStub(*session); for (;;) { auto response = RetryLoop( - retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), + RetryPolicyPrototype()->clone(), BackoffPolicyPrototype()->clone(), Idempotency::kIdempotent, [&stub](grpc::ClientContext& context, spanner_proto::PartitionQueryRequest const& request) { @@ -857,7 +876,7 @@ StatusOr ConnectionImpl::ExecuteBatchDmlImpl( auto stub = session_pool_->GetStub(*session); for (;;) { auto response = RetryLoop( - retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), + RetryPolicyPrototype()->clone(), BackoffPolicyPrototype()->clone(), Idempotency::kIdempotent, [&stub](grpc::ClientContext& context, spanner_proto::ExecuteBatchDmlRequest const& request) { @@ -979,7 +998,7 @@ StatusOr ConnectionImpl::CommitImpl( auto stub = session_pool_->GetStub(*session); auto response = RetryLoop( - retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), + RetryPolicyPrototype()->clone(), BackoffPolicyPrototype()->clone(), Idempotency::kIdempotent, [&stub](grpc::ClientContext& context, spanner_proto::CommitRequest const& request) { @@ -1041,7 +1060,7 @@ Status ConnectionImpl::RollbackImpl( request.set_transaction_id(s->id()); auto stub = session_pool_->GetStub(*session); auto status = RetryLoop( - retry_policy_prototype_->clone(), backoff_policy_prototype_->clone(), + RetryPolicyPrototype()->clone(), BackoffPolicyPrototype()->clone(), Idempotency::kIdempotent, [&stub](grpc::ClientContext& context, spanner_proto::RollbackRequest const& request) { diff --git a/google/cloud/spanner/internal/connection_impl.h b/google/cloud/spanner/internal/connection_impl.h index 252c10cba5e6f..7a08393959584 100644 --- a/google/cloud/spanner/internal/connection_impl.h +++ b/google/cloud/spanner/internal/connection_impl.h @@ -48,8 +48,9 @@ class ConnectionImpl : public spanner::Connection { public: ConnectionImpl(spanner::Database db, std::unique_ptr background_threads, - std::vector> stubs, - Options const& opts); + std::vector> stubs, Options opts); + + Options options() override { return opts_; } spanner::RowStream Read(ReadParams) override; StatusOr> PartitionRead( @@ -69,6 +70,11 @@ class ConnectionImpl : public spanner::Connection { Status Rollback(RollbackParams) override; private: + std::shared_ptr const& RetryPolicyPrototype() const; + std::shared_ptr const& BackoffPolicyPrototype() const; + bool RpcStreamTracingEnabled() const; + TracingOptions const& RpcTracingOptions() const; + Status PrepareSession(SessionHolder& session, bool dissociate_from_pool = false); @@ -163,12 +169,9 @@ class ConnectionImpl : public spanner::Connection { google::spanner::v1::ExecuteSqlRequest::QueryMode query_mode); spanner::Database db_; - std::shared_ptr retry_policy_prototype_; - std::shared_ptr backoff_policy_prototype_; std::unique_ptr background_threads_; + Options opts_; std::shared_ptr session_pool_; - bool rpc_stream_tracing_enabled_ = false; - TracingOptions tracing_options_; }; GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END diff --git a/google/cloud/spanner/mocks/mock_spanner_connection.h b/google/cloud/spanner/mocks/mock_spanner_connection.h index 7fd7aceec430f..6492527847e4e 100644 --- a/google/cloud/spanner/mocks/mock_spanner_connection.h +++ b/google/cloud/spanner/mocks/mock_spanner_connection.h @@ -42,6 +42,7 @@ GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_BEGIN */ class MockConnection : public spanner::Connection { public: + MOCK_METHOD(Options, options, (), (override)); MOCK_METHOD(spanner::RowStream, Read, (ReadParams), (override)); MOCK_METHOD(StatusOr>, PartitionRead, (PartitionReadParams), (override));