diff --git a/src/bigtrace/orchestrator/BUILD.gn b/src/bigtrace/orchestrator/BUILD.gn index 8a0aff61fe..16c34dbaa6 100644 --- a/src/bigtrace/orchestrator/BUILD.gn +++ b/src/bigtrace/orchestrator/BUILD.gn @@ -24,6 +24,10 @@ if (enable_perfetto_grpc) { "orchestrator_impl.cc", "orchestrator_impl.h", "orchestrator_main.cc", + "resizable_task_pool.cc", + "resizable_task_pool.h", + "trace_address_pool.cc", + "trace_address_pool.h", ] deps = [ "../../../gn:default_deps", diff --git a/src/bigtrace/orchestrator/orchestrator_impl.cc b/src/bigtrace/orchestrator/orchestrator_impl.cc index 21392bd988..8a151ef2b4 100644 --- a/src/bigtrace/orchestrator/orchestrator_impl.cc +++ b/src/bigtrace/orchestrator/orchestrator_impl.cc @@ -15,115 +15,195 @@ */ #include +#include #include #include +#include +#include + #include "perfetto/base/logging.h" +#include "perfetto/base/time.h" +#include "perfetto/ext/base/utils.h" +#include "protos/perfetto/bigtrace/orchestrator.pb.h" #include "src/bigtrace/orchestrator/orchestrator_impl.h" +#include "src/bigtrace/orchestrator/resizable_task_pool.h" +#include "src/bigtrace/orchestrator/trace_address_pool.h" namespace perfetto::bigtrace { - namespace { -const uint32_t kBufferPushDelay = 100; +const uint32_t kBufferPushDelayMicroseconds = 100; + +grpc::Status ExecuteQueryOnTrace( + std::string sql_query, + std::string trace, + grpc::Status& query_status, + std::mutex& worker_lock, + std::vector& response_buffer, + std::unique_ptr& stub, + ThreadWithContext* contextual_thread) { + protos::BigtraceQueryTraceArgs trace_args; + protos::BigtraceQueryTraceResponse trace_response; + + trace_args.set_sql_query(sql_query); + trace_args.set_trace(trace); + grpc::Status status = stub->QueryTrace( + contextual_thread->client_context.get(), trace_args, &trace_response); + + if (!status.ok()) { + { + std::lock_guard status_guard(worker_lock); + // We check and only update the query status if it was not already errored + // to avoid unnecessary updates. + if (query_status.ok()) { + query_status = status; + } + } + + return status; + } + + protos::BigtraceQueryResponse response; + response.set_trace(trace_response.trace()); + for (const protos::QueryResult& query_result : trace_response.result()) { + response.add_result()->CopyFrom(query_result); + } + std::lock_guard buffer_guard(worker_lock); + response_buffer.emplace_back(std::move(response)); + + return grpc::Status::OK; } +void ThreadRunLoop(ThreadWithContext* contextual_thread, + TraceAddressPool& address_pool, + std::string sql_query, + grpc::Status& query_status, + std::mutex& worker_lock, + std::vector& response_buffer, + std::unique_ptr& stub) { + for (;;) { + auto maybe_trace_address = address_pool.Pop(); + if (!maybe_trace_address) { + return; + } + + // The ordering of this context swap followed by the check on thread + // cancellation is essential and should not be changed to avoid a race where + // a request to cancel a thread is sent, followed by a context swap, causing + // the cancel to not be caught and the execution of the loop body to + // continue. + contextual_thread->client_context = std::make_unique(); + + if (contextual_thread->IsCancelled()) { + address_pool.MarkCancelled(std::move(*maybe_trace_address)); + return; + } + + grpc::Status status = ExecuteQueryOnTrace( + sql_query, *maybe_trace_address, query_status, worker_lock, + response_buffer, stub, contextual_thread); + + if (!status.ok()) { + if (status.error_code() == grpc::StatusCode::CANCELLED) { + address_pool.MarkCancelled(std::move(*maybe_trace_address)); + } + return; + } + } +} + +} // namespace + OrchestratorImpl::OrchestratorImpl( std::unique_ptr stub, - uint32_t pool_size) - : stub_(std::move(stub)), - pool_(std::make_unique(pool_size)), - semaphore_(pool_size) {} + uint32_t max_query_concurrency) + : stub_(std::move(stub)), max_query_concurrency_(max_query_concurrency) {} grpc::Status OrchestratorImpl::Query( grpc::ServerContext*, const protos::BigtraceQueryArgs* args, grpc::ServerWriter* writer) { grpc::Status query_status; - std::mutex status_lock; + std::mutex worker_lock; const std::string& sql_query = args->sql_query(); + std::vector traces(args->traces().begin(), args->traces().end()); std::vector response_buffer; uint64_t trace_count = static_cast(args->traces_size()); - std::thread push_response_buffer_thread([&]() { - uint64_t pushed_response_count = 0; - for (;;) { - { - std::lock_guard status_guard(status_lock); - if (pushed_response_count == trace_count || !query_status.ok()) { - break; - } - } - std::this_thread::sleep_for(std::chrono::milliseconds(kBufferPushDelay)); - if (response_buffer.empty()) { - continue; - } - std::vector buffer; - { - std::lock_guard buffer_guard(buffer_lock_); - buffer = std::move(response_buffer); - response_buffer.clear(); - } - for (protos::BigtraceQueryResponse& response : buffer) { - writer->Write(std::move(response)); - } - pushed_response_count += buffer.size(); - } + TraceAddressPool address_pool(std::move(traces)); + + // Update the query count on start and end ensuring that the query count is + // always decremented whenever the function is exited. + { + std::lock_guard lk(query_count_mutex_); + query_count_++; + } + auto query_count_decrement = base::OnScopeExit([&]() { + std::lock_guard lk(query_count_mutex_); + query_count_--; }); - for (const std::string& trace : args->traces()) { + ResizableTaskPool task_pool([&](ThreadWithContext* new_contextual_thread) { + ThreadRunLoop(new_contextual_thread, address_pool, sql_query, query_status, + worker_lock, response_buffer, stub_); + }); + + uint64_t pushed_response_count = 0; + uint32_t last_query_count = 0; + uint32_t current_query_count = 0; + + for (;;) { + { + std::lock_guard lk(query_count_mutex_); + current_query_count = query_count_; + } + + PERFETTO_CHECK(current_query_count != 0); + + // Update the number of threads to the lower of {the remaining number of + // traces} and the {maximum concurrency divided by the number of active + // queries}. This ensures that at most |max_query_concurrency_| calls to the + // backend are outstanding at any one point. + if (last_query_count != current_query_count) { + auto new_size = + std::min(std::max(address_pool.RemainingCount(), 1u), + max_query_concurrency_ / current_query_count); + task_pool.Resize(new_size); + last_query_count = current_query_count; + } + + // Exit the loop when either all responses have been successfully completed + // or if there is an error. { - std::lock_guard status_guard(status_lock); - if (!query_status.ok()) { + std::lock_guard status_guard(worker_lock); + if (pushed_response_count == trace_count || !query_status.ok()) { break; } } - semaphore_.Acquire(); - pool_->PostTask([&]() { - grpc::ClientContext client_context; - protos::BigtraceQueryTraceArgs trace_args; - protos::BigtraceQueryTraceResponse trace_response; - - trace_args.set_sql_query(sql_query); - trace_args.set_trace(trace); - grpc::Status status = - stub_->QueryTrace(&client_context, trace_args, &trace_response); - if (!status.ok()) { - PERFETTO_ELOG("QueryTrace returned an error status %s", - status.error_message().c_str()); - { - std::lock_guard status_guard(status_lock); - query_status = status; - } - } else { - protos::BigtraceQueryResponse response; - response.set_trace(trace_response.trace()); - for (const protos::QueryResult& query_result : - trace_response.result()) { - response.add_result()->CopyFrom(query_result); - } - std::lock_guard buffer_guard(buffer_lock_); - response_buffer.emplace_back(std::move(response)); - } - semaphore_.Release(); - }); - } - push_response_buffer_thread.join(); - return query_status; -} -void OrchestratorImpl::Semaphore::Acquire() { - std::unique_lock lk(mutex_); - while (!count_) { - cv_.wait(lk); + // A buffer is used to periodically make writes to the client instead of + // writing every individual response in order to reduce contention on the + // writer. + base::SleepMicroseconds(kBufferPushDelayMicroseconds); + if (response_buffer.empty()) { + continue; + } + std::vector buffer; + { + std::lock_guard buffer_guard(worker_lock); + buffer = std::move(response_buffer); + response_buffer.clear(); + } + for (protos::BigtraceQueryResponse& response : buffer) { + writer->Write(std::move(response)); + } + pushed_response_count += buffer.size(); } - --count_; -} -void OrchestratorImpl::Semaphore::Release() { - std::lock_guard lk(mutex_); - ++count_; - cv_.notify_one(); + task_pool.JoinAll(); + + return query_status; } } // namespace perfetto::bigtrace diff --git a/src/bigtrace/orchestrator/orchestrator_impl.h b/src/bigtrace/orchestrator/orchestrator_impl.h index 71baa295c0..f74251e578 100644 --- a/src/bigtrace/orchestrator/orchestrator_impl.h +++ b/src/bigtrace/orchestrator/orchestrator_impl.h @@ -17,39 +17,35 @@ #ifndef SRC_BIGTRACE_ORCHESTRATOR_ORCHESTRATOR_IMPL_H_ #define SRC_BIGTRACE_ORCHESTRATOR_ORCHESTRATOR_IMPL_H_ +#include +#include +#include +#include #include "perfetto/ext/base/threading/thread_pool.h" #include "protos/perfetto/bigtrace/orchestrator.grpc.pb.h" #include "protos/perfetto/bigtrace/worker.grpc.pb.h" namespace perfetto::bigtrace { +namespace { +const uint64_t kDefaultMaxQueryConcurrency = 8; +} // namespace class OrchestratorImpl final : public protos::BigtraceOrchestrator::Service { public: explicit OrchestratorImpl(std::unique_ptr stub, - uint32_t pool_size); + uint32_t max_query_concurrency); + grpc::Status Query( grpc::ServerContext*, const protos::BigtraceQueryArgs* args, grpc::ServerWriter* writer) override; private: - class Semaphore { - public: - explicit Semaphore(uint32_t count) : count_(count) {} - void Acquire(); - void Release(); - - private: - std::mutex mutex_; - std::condition_variable cv_; - uint32_t count_; - }; std::unique_ptr stub_; std::unique_ptr pool_; - std::mutex buffer_lock_; - // Used to interleave requests to the Orchestrator to distribute jobs more - // fairly - Semaphore semaphore_; + uint32_t max_query_concurrency_ = kDefaultMaxQueryConcurrency; + uint32_t query_count_ = 0; + std::mutex query_count_mutex_; }; } // namespace perfetto::bigtrace diff --git a/src/bigtrace/orchestrator/orchestrator_main.cc b/src/bigtrace/orchestrator/orchestrator_main.cc index 38d1b297d9..8870965304 100644 --- a/src/bigtrace/orchestrator/orchestrator_main.cc +++ b/src/bigtrace/orchestrator/orchestrator_main.cc @@ -71,9 +71,8 @@ Usage: %s [OPTIONS] -w -p -n EXCLUSIVELY) -r --name_resolution_scheme SCHEME Specify the name resolution scheme for gRPC (e.g. ipv4:, dns://) - -t -thread_pool_size POOL_SIZE Specify the size of the thread pool - which determines number of concurrent - gRPCs from the Orchestrator + -t -max_query_concurrency Specify the number of concurrent + MAX_QUERY_CONCURRENCY queries/gRPCs from the Orchestrator )", argv[0]); } @@ -157,7 +156,6 @@ base::Status OrchestratorMain(int argc, char** argv) { std::string worker_address_list = options->worker_address_list; uint64_t worker_count = options->worker_count; - // TODO(ivankc) Replace with DNS resolver std::string target_address = options->name_resolution_scheme.empty() ? "ipv4:" : options->name_resolution_scheme; diff --git a/src/bigtrace/orchestrator/resizable_task_pool.cc b/src/bigtrace/orchestrator/resizable_task_pool.cc new file mode 100644 index 0000000000..8d0aca6df6 --- /dev/null +++ b/src/bigtrace/orchestrator/resizable_task_pool.cc @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/bigtrace/orchestrator/resizable_task_pool.h" + +namespace perfetto::bigtrace { + +ResizableTaskPool::ResizableTaskPool(std::function fn) + : fn_(std::move(fn)) {} + +// Resizes the number of threads in the task pool to |new_size| +// +// This works by performing one of two possible actions: +// 1) When the number of threads is reduced, the excess are cancelled and joined +// 2) When the number of threads is increased, new threads are created and +// started +void ResizableTaskPool::Resize(uint32_t new_size) { + if (size_t old_size = contextual_threads_.size(); new_size < old_size) { + for (size_t i = new_size; i < old_size; ++i) { + contextual_threads_[i]->Cancel(); + } + for (size_t i = new_size; i < old_size; ++i) { + contextual_threads_[i]->thread.join(); + } + contextual_threads_.resize(new_size); + } else { + contextual_threads_.resize(new_size); + for (size_t i = old_size; i < new_size; ++i) { + contextual_threads_[i] = std::make_unique(fn_); + } + } +} + +// Joins all threads in the task pool +void ResizableTaskPool::JoinAll() { + for (auto& contextual_thread : contextual_threads_) { + contextual_thread->thread.join(); + } +} + +} // namespace perfetto::bigtrace diff --git a/src/bigtrace/orchestrator/resizable_task_pool.h b/src/bigtrace/orchestrator/resizable_task_pool.h new file mode 100644 index 0000000000..32b1516727 --- /dev/null +++ b/src/bigtrace/orchestrator/resizable_task_pool.h @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_BIGTRACE_ORCHESTRATOR_RESIZABLE_TASK_POOL_H_ +#define SRC_BIGTRACE_ORCHESTRATOR_RESIZABLE_TASK_POOL_H_ + +#include +#include +#include + +#include + +namespace perfetto::bigtrace { + +// This struct maps a thread to a context in order to allow for the cancellation +// of the thread's current gRPC call through ClientContext's TryCancel +struct ThreadWithContext { + explicit ThreadWithContext(std::function fn) + : thread(fn, this) {} + + // Cancels the gRPC call through ClientContext as well as signalling a stop to + // the thread + void Cancel() { + client_context->TryCancel(); + std::lock_guard lk(mutex); + is_thread_cancelled = true; + } + + // Returns whether the thread has been cancelled + bool IsCancelled() { + std::lock_guard lk(mutex); + return is_thread_cancelled; + } + + std::mutex mutex; + std::unique_ptr client_context; + std::thread thread; + bool is_thread_cancelled = false; +}; + +// This pool manages a set of running tasks for a given query, and provides the +// ability to resize in order to fairly distribute an equal number of workers +// for each user through preemption +class ResizableTaskPool { + public: + explicit ResizableTaskPool(std::function fn); + void Resize(uint32_t new_size); + void JoinAll(); + + private: + std::function fn_; + std::vector> contextual_threads_; +}; +} // namespace perfetto::bigtrace + +#endif // SRC_BIGTRACE_ORCHESTRATOR_RESIZABLE_TASK_POOL_H_ diff --git a/src/bigtrace/orchestrator/trace_address_pool.cc b/src/bigtrace/orchestrator/trace_address_pool.cc new file mode 100644 index 0000000000..33ee31162d --- /dev/null +++ b/src/bigtrace/orchestrator/trace_address_pool.cc @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/bigtrace/orchestrator/trace_address_pool.h" +#include "perfetto/base/logging.h" + +namespace perfetto::bigtrace { + +TraceAddressPool::TraceAddressPool( + const std::vector& trace_addresses) + : trace_addresses_(trace_addresses) {} + +// Pops a trace address from the pool, blocking if necessary +// +// Returns a nullopt if the pool is empty +std::optional TraceAddressPool::Pop() { + std::lock_guard trace_addresses_guard(trace_addresses_lock_); + if (trace_addresses_.size() == 0) { + return std::nullopt; + } + std::string trace_address = trace_addresses_.back(); + trace_addresses_.pop_back(); + running_queries_++; + return trace_address; +} + +// Marks a trace address as cancelled +// +// Returns cancelled trace addresses to the pool for future calls to |Pop| +void TraceAddressPool::MarkCancelled(std::string trace_address) { + std::lock_guard guard(trace_addresses_lock_); + PERFETTO_CHECK(running_queries_-- > 0); + trace_addresses_.push_back(std::move(trace_address)); +} + +// Returns the number of remaining trace addresses which require processing +uint32_t TraceAddressPool::RemainingCount() { + std::lock_guard guard(trace_addresses_lock_); + return static_cast(trace_addresses_.size()) + running_queries_; +} + +} // namespace perfetto::bigtrace diff --git a/src/bigtrace/orchestrator/trace_address_pool.h b/src/bigtrace/orchestrator/trace_address_pool.h new file mode 100644 index 0000000000..d175ef7982 --- /dev/null +++ b/src/bigtrace/orchestrator/trace_address_pool.h @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_BIGTRACE_ORCHESTRATOR_TRACE_ADDRESS_POOL_H_ +#define SRC_BIGTRACE_ORCHESTRATOR_TRACE_ADDRESS_POOL_H_ + +#include +#include +#include + +namespace perfetto::bigtrace { + +// This pool contains all trace addresses of a given query and facilitates a +// thread safe way of popping traces and returning them to the pool if the query +// is cancelled +class TraceAddressPool { + public: + explicit TraceAddressPool(const std::vector& trace_addresses); + std::optional Pop(); + void MarkCancelled(std::string trace_address); + uint32_t RemainingCount(); + + private: + std::vector trace_addresses_; + std::mutex trace_addresses_lock_; + uint32_t running_queries_ = 0; +}; + +} // namespace perfetto::bigtrace + +#endif // SRC_BIGTRACE_ORCHESTRATOR_TRACE_ADDRESS_POOL_H_ diff --git a/src/bigtrace/worker/worker_impl.cc b/src/bigtrace/worker/worker_impl.cc index 43729db602..7c09984b84 100644 --- a/src/bigtrace/worker/worker_impl.cc +++ b/src/bigtrace/worker/worker_impl.cc @@ -14,24 +14,43 @@ * limitations under the License. */ -#include "src/bigtrace/worker/worker_impl.h" +#include +#include + +#include +#include + +#include "perfetto/base/time.h" #include "perfetto/ext/trace_processor/rpc/query_result_serializer.h" #include "perfetto/trace_processor/trace_processor.h" -#include "src/bigtrace/worker/repository_policies/gcs_trace_processor_loader.h" -#include "src/bigtrace/worker/repository_policies/local_trace_processor_loader.h" +#include "src/bigtrace/worker/worker_impl.h" namespace perfetto::bigtrace { grpc::Status WorkerImpl::QueryTrace( - grpc::ServerContext*, + grpc::ServerContext* server_context, const protos::BigtraceQueryTraceArgs* args, protos::BigtraceQueryTraceResponse* response) { + std::mutex mutex; + bool is_thread_done = false; + std::string args_trace = args->trace(); + if (args_trace.empty()) { + return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Empty trace name is not valid"); + } + + if (args_trace[0] != '/') { + return grpc::Status( + grpc::StatusCode::INVALID_ARGUMENT, + "Trace path must contain and begin with / for the prefix"); + } + std::string prefix = args_trace.substr(0, args_trace.find("/", 1)); if (registry_.find(prefix) == registry_.end()) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, - "Invalid path prefix specified"); + "Path prefix does not exist in registry"); } if (prefix.length() == args_trace.length()) { @@ -50,21 +69,49 @@ grpc::Status WorkerImpl::QueryTrace( } std::unique_ptr tp = std::move(*tp_or); + std::optional iterator; - auto iter = tp->ExecuteQuery(args->sql_query()); - trace_processor::QueryResultSerializer serializer = - trace_processor::QueryResultSerializer(std::move(iter)); + std::thread execute_query_thread([&]() { + iterator = tp->ExecuteQuery(args->sql_query()); + std::lock_guard lk(mutex); + is_thread_done = true; + }); - std::vector serialized; - for (bool has_more = true; has_more;) { - serialized.clear(); - has_more = serializer.Serialize(&serialized); - response->add_result()->ParseFromArray(serialized.data(), - static_cast(serialized.size())); - } - response->set_trace(args->trace()); + for (;;) { + if (server_context->IsCancelled()) { + // If the thread is cancelled, we need to propagate the information to the + // trace processor thread and we do this by attempting to interrupt the + // trace processor every 10ms until the trace processor thread returns. + // + // A loop is necessary here because, due to scheduling delay, it is + // possible we are cancelled before trace processor even started running. + // InterruptQuery is ignored if it happens before entering TraceProcessor + // which can cause the query to not be interrupted at all. + while (!execute_query_thread.joinable()) { + base::SleepMicroseconds(10000); + tp->InterruptQuery(); + } + execute_query_thread.join(); + return grpc::Status::CANCELLED; + } + + std::lock_guard lk(mutex); + if (is_thread_done) { + execute_query_thread.join(); + trace_processor::QueryResultSerializer serializer = + trace_processor::QueryResultSerializer(*std::move(iterator)); - return grpc::Status::OK; + std::vector serialized; + for (bool has_more = true; has_more;) { + serialized.clear(); + has_more = serializer.Serialize(&serialized); + response->add_result()->ParseFromArray( + serialized.data(), static_cast(serialized.size())); + } + response->set_trace(args->trace()); + return grpc::Status::OK; + } + } } } // namespace perfetto::bigtrace