From 993aa8fe183323c7c6b7195393ada0cae8762117 Mon Sep 17 00:00:00 2001
From: Dmitrii Ubskii <18616863+dmitrii-ubskii@users.noreply.github.com>
Date: Wed, 1 Mar 2023 13:56:50 +0000
Subject: [PATCH] Clientless runtime-agnostic architecture (#32)

## What is the goal of this PR?

We remove the requirement that typedb-client be used within a tokio runtime, making the library runtime-agnostic. We also remove the distinction between core and cluster, and replace the `Client` entry point with the more fundamental `Connection`.

## What is the motivation behind the changes?

### Encapsulate tokio runtime

This change serves three main purposes:

1. **Remove the requirement that the typedb-client library be run from a tokio runtime.** The gRPC crate we use, `tonic`, and its networking dependencies rely heavily on being used within a `tokio` runtime. That is a fairly big restriction to place on all user code. Now that all RPC interaction is hidden away in a background thread, the entire API exposed to the user can be fully runtime-agnostic and, potentially down the line, available in synchronous contexts as well.
2. **Eliminate session-close deadlocks in single-threaded runtimes.** We create our own _system_ thread, which spawns a tokio runtime and handles the RPC communication. This runs independently of the user-facing runtime and as such is not affected by it. Previously, if the user code happened to run in a single-threaded runtime, dropping the last handle to a session would send a `session_close()` request and block the only executor thread tokio had available, deadlocking the entire application. This happened because, to avoid sessions timing out on the server side, the session's drop method would block until it received a signal that the request had been sent. Because RPC communication is asynchronous, the request must be sent from a task other than the one performing the drop, but the drop itself is blocking and can never yield to that task.
3. **Encapsulate RPC implementation details.** Should we ever decide to move away from gRPC+protobuf, the change would be confined to a small set of files, namely the contents of `connection::network`. Communication under this new model uses Rust-native `Request` and `Response` enums, abstracting away the protocol structures.

### Dissolve Client into the underlying Connection

Requiring that a `session` may not outlive its `client`, and that a `transaction` may not outlive its `session`, meant that, to preserve consistency, we had to either extend the lifetimes of both `client` and `session`, or require the user to share the `client` handle between threads explicitly, even if they do not intend to use it beyond opening a `session`.

Removing the top-of-the-hierarchy `Client` type and replacing it with a primitive cloneable `Connection` allows us to partially invert the hierarchy, such that `DatabaseManager` and `Session` explicitly own the resources they rely upon. This is in line with how established Rust crates (e.g. `tonic`, `mysql`) treat connections.

For multithreading or concurrency, ownership of a `session` needs to be explicitly shared between the threads or tasks, whether via shared pointers (read: `Arc<_>`) or explicit scope bounds (such as `std::thread::scope()` in a synchronous version).

### Remove distinction between core and cluster

The shift from TypeDB Core + TypeDB Cluster to just TypeDB (Cluster / Cloud) by default is reflected in the architecture.
We now treat a core server as effectively a single-node cluster instance that lacks enterprise facilities (viz. user management). This change greatly improves user experience: all code written to interact with an open-source TypeDB instance is automatically valid for the production instance with a simple change in initialization. As a side effect, this also helps us ensure that all integration tests implemented for core are automatically implemented for cluster as well.

Merging core and cluster also vastly simplifies the internal structure of the library, as only a few places have to know which backend they are running against, specifically the portions that deal with authentication and, some day, user management.

## What are the changes implemented in this PR?

Major changes:

- New `connection` module:
  - `Connection` and `ServerConnection` conceptually correspond roughly to `ClusterRPC` and `ClusterServerRPC`: `Connection`'s only job is to manage the set of `ServerConnection`s, i.e. connections to individual nodes of the server;
  - `Connection`, created by the user, spawns a background single-threaded tokio runtime which houses all request handlers (see the sketch at the end of this description);
  - `ServerConnection` performs the actual message-passing between user code and its dedicated request dispatcher.
- Move the `common::rpc` module under `connection::network`:
  - move all protobuf serialization/deserialization into `connection::network::proto`, fully isolated from the rest of the crate;
  - provide native message enums intended for inter-thread communication (cf. `common::info` for crate-wide data structures);
  - merge `CoreRPC`, `ServerRPC`, and `ClusterServerRPC` into a single `RPCStub`;
  - add `RPCTransmitter`: a dispatcher agent meant to run in the background tokio runtime, which handles the communication with the server; its job is to:
    - listen for user requests over an inter-thread mpsc channel;
    - serialize the requests for tonic;
    - deserialize the responses;
    - send the responses back over the provided callback channel;
  - overhaul `TransactionRPC` into `TransactionTransmitter`, analogous to the `RPCTransmitter` above;
    - its listener loop buffers the requests into batches before dispatching them to the server;
    - because, unlike `RPCTransmitter`, `TransactionTransmitter` wraps a bidirectional stream, it also has an associated listener agent that handles the user callbacks and automatically requests stream continuation.
- Remove `connection::core` and merge `connection::{cluster, server}` and `query` into a single top-level `database` module:
  - `ServerDatabase` and `ServerSession` are now hidden implementation details that handle communication with an individual node;
  - `Client`, as mentioned, has been removed entirely.

Minor changes:

- remove the no longer needed `async_enum_dispatch` helper macro;
- restructure the `queries_...` tests into a single `queries` test module that handles both core and cluster connections using a helper permutation test macro;
- add a `runtimes` test module that ensures the API is async runtime-agnostic.

Closes #7, #16, #17, #20, #22, #30.
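Below is a minimal, self-contained sketch of the background-runtime pattern described under "Encapsulate tokio runtime": a dedicated system thread owns a current-thread tokio runtime and services requests arriving over a channel, while user-facing calls never need an ambient executor. All names here (`BackgroundRuntime`, `Request::ServersAll`, `request_blocking`) are illustrative assumptions rather than the actual typedb-client API; the real `RPCTransmitter` additionally serializes requests for tonic and deserializes the gRPC responses.

```rust
use std::thread;

use tokio::{
    runtime,
    sync::{mpsc, oneshot},
};

// Placeholder message enums standing in for the crate's native Request/Response types.
#[derive(Debug)]
enum Request {
    ServersAll,
}

#[derive(Debug)]
enum Response {
    ServersAll { servers: Vec<String> },
}

struct BackgroundRuntime {
    request_sink: mpsc::UnboundedSender<(Request, oneshot::Sender<Response>)>,
    // Keeping the handle around allows joining the system thread on shutdown if desired.
    _handle: thread::JoinHandle<()>,
}

impl BackgroundRuntime {
    fn new() -> Self {
        let (request_sink, mut request_source) =
            mpsc::unbounded_channel::<(Request, oneshot::Sender<Response>)>();

        // A dedicated system thread owns a current-thread tokio runtime, so user code
        // never has to run inside tokio itself.
        let _handle = thread::spawn(move || {
            let rt = runtime::Builder::new_current_thread().enable_all().build().unwrap();
            rt.block_on(async move {
                while let Some((request, callback)) = request_source.recv().await {
                    // In the real library, this is where the transmitter would perform
                    // the actual RPC; here we just fabricate a response.
                    let response = match request {
                        Request::ServersAll => Response::ServersAll { servers: vec![] },
                    };
                    let _ = callback.send(response);
                }
            });
        });

        Self { request_sink, _handle }
    }

    // Synchronous entry point for non-async callers.
    fn request_blocking(&self, request: Request) -> Response {
        let (callback, response) = oneshot::channel();
        self.request_sink.send((request, callback)).expect("background runtime has shut down");
        response.blocking_recv().expect("background runtime dropped the request")
    }

    // Async entry point: awaiting a tokio oneshot receiver does not require the caller
    // to be inside a tokio runtime, so this works under any executor.
    async fn request_async(&self, request: Request) -> Response {
        let (callback, response) = oneshot::channel();
        self.request_sink.send((request, callback)).expect("background runtime has shut down");
        response.await.expect("background runtime dropped the request")
    }
}

fn main() {
    let background = BackgroundRuntime::new();

    // Works from plain synchronous code...
    println!("{:?}", background.request_blocking(Request::ServersAll));

    // ...and from async code driven by a non-tokio executor.
    let response = futures::executor::block_on(background.request_async(Request::ServersAll));
    println!("{response:?}");
}
```

The deadlock from point 2 cannot occur in this arrangement: the request is serviced by our own system thread, so a blocking wait in user code never depends on a task that has to run on the caller's (possibly single-threaded) executor.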
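A second, purely illustrative sketch shows the ownership model described under "Dissolve Client into the underlying Connection": once a session owns the resources it depends on, sharing it across threads is an ordinary Rust choice between scoped borrows and `Arc`. The `Session` type and its `query` method below are stand-ins, not the typedb-client API.

```rust
use std::{sync::Arc, thread};

// Stand-in session type; per this PR, the real Session explicitly owns the resources
// it relies upon, which is what makes sharing it without a parent Client possible.
#[derive(Debug)]
struct Session {
    database: String,
}

impl Session {
    fn query(&self, query: &str) -> String {
        format!("[{}] {query}", self.database)
    }
}

fn main() {
    let session = Session { database: "social_network".to_owned() };

    // Option 1: explicit scope bounds. Borrowing is enough because the spawned
    // threads are guaranteed not to outlive the session.
    thread::scope(|s| {
        for i in 0..2 {
            let session = &session;
            s.spawn(move || println!("{}", session.query(&format!("match $x{i} isa thing;"))));
        }
    });

    // Option 2: shared ownership via Arc, for threads or tasks with unbounded lifetimes.
    let session = Arc::new(session);
    let handles: Vec<_> = (0..2)
        .map(|i| {
            let session = Arc::clone(&session);
            thread::spawn(move || println!("{}", session.query(&format!("match $y{i} isa thing;"))))
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}
```

Either way, nothing forces the user to keep a separate `Client` handle alive just to keep the session valid.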
--- .factory/automation.yml | 17 +- BUILD | 36 +- WORKSPACE | 6 +- dependencies/ide/{ => rust}/BUILD | 0 dependencies/ide/{ => rust}/sync.sh | 2 +- dependencies/vaticle/repositories.bzl | 8 +- rustfmt.toml | 1 + src/answer/concept_map.rs | 10 +- src/answer/numeric.rs | 16 - src/common/address.rs | 6 +- src/common/credential.rs | 83 +--- src/common/error.rs | 113 +++-- src/common/{macros.rs => id.rs} | 47 +- src/common/info.rs | 45 ++ src/common/mod.rs | 42 +- src/common/options.rs | 82 ++++ src/common/rpc/builder.rs | 312 ------------- src/common/rpc/channel.rs | 98 ---- src/common/rpc/cluster.rs | 280 ------------ src/common/rpc/core.rs | 224 --------- src/common/rpc/server.rs | 108 ----- src/common/rpc/transaction.rs | 430 ------------------ src/concept/mod.rs | 149 +----- src/connection/cluster/client.rs | 56 --- src/connection/cluster/database.rs | 247 ---------- src/connection/cluster/database_manager.rs | 94 ---- src/connection/cluster/session.rs | 93 ---- src/connection/connection.rs | 332 ++++++++++++++ src/connection/core/client.rs | 63 --- src/connection/core/database_manager.rs | 76 ---- src/connection/core/options.rs | 134 ------ src/connection/message.rs | 147 ++++++ src/connection/mod.rs | 11 +- src/connection/network/channel.rs | 137 ++++++ src/connection/{core => network}/mod.rs | 9 +- src/connection/network/proto/common.rs | 74 +++ src/connection/network/proto/concept.rs | 196 ++++++++ src/connection/network/proto/database.rs | 49 ++ src/connection/network/proto/message.rs | 370 +++++++++++++++ .../rpc => connection/network/proto}/mod.rs | 34 +- src/connection/network/stub.rs | 264 +++++++++++ .../{server => network/transmitter}/mod.rs | 6 +- .../network/transmitter/response_sink.rs | 68 +++ src/connection/network/transmitter/rpc.rs | 161 +++++++ .../network/transmitter/transaction.rs | 271 +++++++++++ src/connection/runtime.rs | 88 ++++ src/connection/server/database.rs | 68 --- src/connection/server/session.rs | 122 ----- src/connection/server/transaction.rs | 88 ---- src/connection/transaction_stream.rs | 153 +++++++ src/database/database.rs | 285 ++++++++++++ src/database/database_manager.rs | 79 ++++ src/{connection/cluster => database}/mod.rs | 7 +- src/database/query.rs | 98 ++++ src/database/session.rs | 129 ++++++ src/database/transaction.rs | 70 +++ src/lib.rs | 15 +- src/query/mod.rs | 188 -------- tests/BUILD | 38 +- tests/common.rs | 62 +++ tests/queries.rs | 332 ++++++++++++++ tests/queries_cluster.rs | 61 --- tests/queries_core.rs | 249 ---------- tests/runtimes.rs | 89 ++++ 64 files changed, 3841 insertions(+), 3387 deletions(-) rename dependencies/ide/{ => rust}/BUILD (100%) rename dependencies/ide/{ => rust}/sync.sh (94%) rename src/common/{macros.rs => id.rs} (56%) create mode 100644 src/common/info.rs create mode 100644 src/common/options.rs delete mode 100644 src/common/rpc/builder.rs delete mode 100644 src/common/rpc/channel.rs delete mode 100644 src/common/rpc/cluster.rs delete mode 100644 src/common/rpc/core.rs delete mode 100644 src/common/rpc/server.rs delete mode 100644 src/common/rpc/transaction.rs delete mode 100644 src/connection/cluster/client.rs delete mode 100644 src/connection/cluster/database.rs delete mode 100644 src/connection/cluster/database_manager.rs delete mode 100644 src/connection/cluster/session.rs create mode 100644 src/connection/connection.rs delete mode 100644 src/connection/core/client.rs delete mode 100644 src/connection/core/database_manager.rs delete mode 100644 src/connection/core/options.rs create mode 100644 
src/connection/message.rs create mode 100644 src/connection/network/channel.rs rename src/connection/{core => network}/mod.rs (86%) create mode 100644 src/connection/network/proto/common.rs create mode 100644 src/connection/network/proto/concept.rs create mode 100644 src/connection/network/proto/database.rs create mode 100644 src/connection/network/proto/message.rs rename src/{common/rpc => connection/network/proto}/mod.rs (67%) create mode 100644 src/connection/network/stub.rs rename src/connection/{server => network/transmitter}/mod.rs (87%) create mode 100644 src/connection/network/transmitter/response_sink.rs create mode 100644 src/connection/network/transmitter/rpc.rs create mode 100644 src/connection/network/transmitter/transaction.rs create mode 100644 src/connection/runtime.rs delete mode 100644 src/connection/server/database.rs delete mode 100644 src/connection/server/session.rs delete mode 100644 src/connection/server/transaction.rs create mode 100644 src/connection/transaction_stream.rs create mode 100644 src/database/database.rs create mode 100644 src/database/database_manager.rs rename src/{connection/cluster => database}/mod.rs (86%) create mode 100644 src/database/query.rs create mode 100644 src/database/session.rs create mode 100644 src/database/transaction.rs delete mode 100644 src/query/mod.rs create mode 100644 tests/common.rs create mode 100644 tests/queries.rs delete mode 100644 tests/queries_cluster.rs delete mode 100644 tests/queries_core.rs create mode 100644 tests/runtimes.rs diff --git a/.factory/automation.yml b/.factory/automation.yml index 507c3e93..898de4cf 100644 --- a/.factory/automation.yml +++ b/.factory/automation.yml @@ -23,7 +23,9 @@ config: version-candidate: VERSION dependencies: dependencies: [build] + typedb-common: [build] typedb-protocol: [build, release] + typeql: [build, release] build: quality: @@ -53,7 +55,7 @@ build: bazel run @vaticle_dependencies//distribution/artifact:create-netrc bazel build //... tools/start-core-server.sh - bazel test //tests:queries_core --test_arg=-- --test_arg=--test-threads=1 --test_output=streamed && export TEST_SUCCESS=0 || export TEST_SUCCESS=1 + bazel test //tests:queries --test_arg=-- --test_arg=core --test_arg=--test-threads=1 --test_output=streamed && export TEST_SUCCESS=0 || export TEST_SUCCESS=1 tools/stop-core-server.sh exit $TEST_SUCCESS test-integration-cluster: @@ -65,9 +67,20 @@ build: bazel run @vaticle_dependencies//distribution/artifact:create-netrc bazel build //... source tools/start-cluster-servers.sh # use source to receive export vars - bazel test //tests:queries_cluster --test_env=ROOT_CA=$ROOT_CA --test_arg=-- --test_arg=--test-threads=1 --test_output=streamed && export TEST_SUCCESS=0 || export TEST_SUCCESS=1 + bazel test //tests:queries --test_env=ROOT_CA=$ROOT_CA --test_arg=-- --test_arg=cluster --test_arg=--test-threads=1 --test_output=streamed && export TEST_SUCCESS=0 || export TEST_SUCCESS=1 tools/stop-cluster-servers.sh exit $TEST_SUCCESS + test-integration-runtimes: + image: vaticle-ubuntu-22.04 + command: | + export ARTIFACT_USERNAME=$REPO_VATICLE_USERNAME + export ARTIFACT_PASSWORD=$REPO_VATICLE_PASSWORD + bazel run @vaticle_dependencies//distribution/artifact:create-netrc + bazel build //... 
+ tools/start-core-server.sh + bazel test //tests:runtimes --test_arg=-- --test_arg=--test-threads=1 --test_output=streamed && export TEST_SUCCESS=0 || export TEST_SUCCESS=1 + tools/stop-core-server.sh + exit $TEST_SUCCESS deploy-crate-snapshot: filter: owner: vaticle diff --git a/BUILD b/BUILD index 5de007c8..c192ac19 100644 --- a/BUILD +++ b/BUILD @@ -31,45 +31,47 @@ load("//:deployment.bzl", deployment_github = "deployment") rust_library( name = "typedb_client", srcs = glob(["src/**/*.rs"]), + tags = ["crate-name=typedb-client"], deps = [ + "@crates//:chrono", + "@crates//:crossbeam", + "@crates//:futures", + "@crates//:http", + "@crates//:itertools", + "@crates//:log", + "@crates//:prost", + "@crates//:tokio", + "@crates//:tokio-stream", + "@crates//:tonic", + "@crates//:uuid", "@vaticle_typedb_protocol//grpc/rust:typedb_protocol", "@vaticle_typeql//rust:typeql_lang", - - "@vaticle_dependencies//library/crates:chrono", - "@vaticle_dependencies//library/crates:crossbeam", - "@vaticle_dependencies//library/crates:futures", - "@vaticle_dependencies//library/crates:log", - "@vaticle_dependencies//library/crates:prost", - "@vaticle_dependencies//library/crates:tokio", - "@vaticle_dependencies//library/crates:tonic", - "@vaticle_dependencies//library/crates:uuid", ], - tags = ["crate-name=typedb-client"], ) assemble_crate( name = "assemble_crate", - target = "typedb_client", description = "TypeDB Client API for Rust", homepage = "https://github.com/vaticle/typedb-client-rust", license = "Apache-2.0", repository = "https://github.com/vaticle/typedb-client-rust", + target = "typedb_client", ) deploy_crate( name = "deploy_crate", - target = ":assemble_crate", + release = deployment["crate.release"], snapshot = deployment["crate.snapshot"], - release = deployment["crate.release"] + target = ":assemble_crate", ) deploy_github( name = "deploy_github", draft = True, - title = "TypeDB Client Rust", - release_description = "//:RELEASE_TEMPLATE.md", organisation = deployment_github["github.organisation"], + release_description = "//:RELEASE_TEMPLATE.md", repository = deployment_github["github.repository"], + title = "TypeDB Client Rust", title_append_version = True, ) @@ -105,13 +107,13 @@ filegroup( rustfmt_test( name = "client_rustfmt_test", - targets = ["typedb_client"] + targets = ["typedb_client"], ) # CI targets that are not declared in any BUILD file, but are called externally filegroup( name = "ci", data = [ - "@vaticle_dependencies//ide/rust:sync" + "@vaticle_dependencies//tool/cargo:sync", ], ) diff --git a/WORKSPACE b/WORKSPACE index 86111ef2..dbb04ef5 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -49,8 +49,10 @@ load("@rules_rust//rust:repositories.bzl", "rules_rust_dependencies", "rust_regi rules_rust_dependencies() rust_register_toolchains(edition = "2021", include_rustc_srcs = True) -load("@vaticle_dependencies//library/crates:crates.bzl", "raze_fetch_remote_crates") -raze_fetch_remote_crates() +load("@vaticle_dependencies//library/crates:crates.bzl", "fetch_crates") +fetch_crates() +load("@crates//:defs.bzl", "crate_repositories") +crate_repositories() # Load //builder/python load("@vaticle_dependencies//builder/python:deps.bzl", python_deps = "deps") diff --git a/dependencies/ide/BUILD b/dependencies/ide/rust/BUILD similarity index 100% rename from dependencies/ide/BUILD rename to dependencies/ide/rust/BUILD diff --git a/dependencies/ide/sync.sh b/dependencies/ide/rust/sync.sh similarity index 94% rename from dependencies/ide/sync.sh rename to dependencies/ide/rust/sync.sh index 
f167a274..0603f4a8 100755 --- a/dependencies/ide/sync.sh +++ b/dependencies/ide/rust/sync.sh @@ -20,4 +20,4 @@ # under the License. # -bazel run @vaticle_dependencies//ide/rust:sync +bazel run @vaticle_dependencies//tool/cargo:sync diff --git a/dependencies/vaticle/repositories.bzl b/dependencies/vaticle/repositories.bzl index 3f1d763e..72a49da2 100644 --- a/dependencies/vaticle/repositories.bzl +++ b/dependencies/vaticle/repositories.bzl @@ -25,26 +25,26 @@ def vaticle_dependencies(): git_repository( name = "vaticle_dependencies", remote = "https://github.com/vaticle/dependencies", - commit = "d76a7b935cd6452615f78772539fbc2e1228f503", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies + commit = "76636b1672b04e9880439395b8913231724ae459", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies ) def vaticle_typedb_common(): git_repository( name = "vaticle_typedb_common", remote = "https://github.com/vaticle/typedb-common", - tag = "2.12.0" # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_typedb_common + commit = "aa03cb5f6a57ec2a51291b7a0510734ca1f41479" # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_typedb_common ) def vaticle_typedb_protocol(): git_repository( name = "vaticle_typedb_protocol", remote = "https://github.com/vaticle/typedb-protocol", - commit = "16d1fb6749c0fee85843ca67f470015dda9fc497", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies + commit = "b1c19e02054c1a1d354b42875e6ccd67602a546f", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies ) def vaticle_typeql(): git_repository( name = "vaticle_typeql", remote = "https://github.com/vaticle/typeql", - commit = "776643fb6f0c754730e55230733fd2326f32cd39", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies + commit = "7a63699b3879296ae3039577ba3f5220bbf6d33d", # sync-marker: do not remove this comment, this is used for sync-dependencies by @vaticle_dependencies ) diff --git a/rustfmt.toml b/rustfmt.toml index 1fe0cdaa..2bfc3860 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -23,3 +23,4 @@ imports_granularity = "Crate" group_imports = "StdExternalCrate" use_small_heuristics = "Max" +max_width = 120 diff --git a/src/answer/concept_map.rs b/src/answer/concept_map.rs index bebc9376..8d665ba7 100644 --- a/src/answer/concept_map.rs +++ b/src/answer/concept_map.rs @@ -24,7 +24,7 @@ use std::{ ops::Index, }; -use crate::{common::Result, concept::Concept}; +use crate::concept::Concept; #[derive(Debug)] pub struct ConceptMap { @@ -32,14 +32,6 @@ pub struct ConceptMap { } impl ConceptMap { - pub(crate) fn from_proto(proto: typedb_protocol::ConceptMap) -> Result { - let mut map = HashMap::with_capacity(proto.map.len()); - for (k, v) in proto.map { - map.insert(k, Concept::from_proto(v)?); - } - Ok(Self { map }) - } - pub fn get(&self, var_name: &str) -> Option<&Concept> { self.map.get(var_name) } diff --git a/src/answer/numeric.rs b/src/answer/numeric.rs index 53d7b54e..c1771b1d 100644 --- a/src/answer/numeric.rs +++ b/src/answer/numeric.rs @@ -19,10 +19,6 @@ * under the License. 
*/ -use typedb_protocol::numeric::Value; - -use crate::common::{Error, Result}; - #[derive(Clone, Debug)] pub enum Numeric { Long(i64), @@ -48,18 +44,6 @@ impl Numeric { } } -impl TryFrom for Numeric { - type Error = Error; - - fn try_from(value: typedb_protocol::Numeric) -> Result { - match value.value.unwrap() { - Value::LongValue(long) => Ok(Numeric::Long(long)), - Value::DoubleValue(double) => Ok(Numeric::Double(double)), - Value::Nan(_) => Ok(Numeric::NaN), - } - } -} - impl From for i64 { fn from(n: Numeric) -> Self { n.into_i64() diff --git a/src/common/address.rs b/src/common/address.rs index 8266d519..64201efc 100644 --- a/src/common/address.rs +++ b/src/common/address.rs @@ -21,12 +21,12 @@ use std::{fmt, str::FromStr}; -use tonic::transport::Uri; +use http::Uri; use crate::common::{Error, Result}; #[derive(Clone, Debug, Hash, PartialEq, Eq)] -pub struct Address { +pub(crate) struct Address { uri: Uri, } @@ -43,7 +43,7 @@ impl FromStr for Address { let uri = if address.contains("://") { address.parse::()? } else { - format!("http://{}", address).parse::()? + format!("http://{address}").parse::()? }; Ok(Self { uri }) } diff --git a/src/common/credential.rs b/src/common/credential.rs index 14d90d8a..1f90bdbb 100644 --- a/src/common/credential.rs +++ b/src/common/credential.rs @@ -19,35 +19,34 @@ * under the License. */ -use std::{ - fs, - path::{Path, PathBuf}, - sync::RwLock, -}; +use std::{fmt, fs, path::Path}; -use tonic::{ - transport::{Certificate, ClientTlsConfig}, - Request, -}; +use tonic::transport::{Certificate, ClientTlsConfig}; use crate::Result; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Credential { username: String, password: String, is_tls_enabled: bool, - tls_root_ca: Option, + tls_config: Option, } impl Credential { - pub fn with_tls(username: &str, password: &str, tls_root_ca: Option<&Path>) -> Self { - Credential { + pub fn with_tls(username: &str, password: &str, tls_root_ca: Option<&Path>) -> Result { + let tls_config = Some(if let Some(tls_root_ca) = tls_root_ca { + ClientTlsConfig::new().ca_certificate(Certificate::from_pem(fs::read_to_string(tls_root_ca)?)) + } else { + ClientTlsConfig::new() + }); + + Ok(Credential { username: username.to_owned(), password: password.to_owned(), is_tls_enabled: true, - tls_root_ca: tls_root_ca.map(Path::to_owned), - } + tls_config, + }) } pub fn without_tls(username: &str, password: &str) -> Self { @@ -55,7 +54,7 @@ impl Credential { username: username.to_owned(), password: password.to_owned(), is_tls_enabled: false, - tls_root_ca: None, + tls_config: None, } } @@ -71,51 +70,17 @@ impl Credential { self.is_tls_enabled } - pub fn tls_config(&self) -> Result { - if let Some(ref tls_root_ca) = self.tls_root_ca { - Ok(ClientTlsConfig::new() - .ca_certificate(Certificate::from_pem(fs::read_to_string(tls_root_ca)?))) - } else { - Ok(ClientTlsConfig::new()) - } + pub fn tls_config(&self) -> &Option { + &self.tls_config } } -#[derive(Debug)] -pub(crate) struct CallCredentials { - credential: Credential, - token: RwLock>, -} - -impl CallCredentials { - pub(super) fn new(credential: Credential) -> Self { - Self { credential, token: RwLock::new(None) } - } - - pub(super) fn username(&self) -> &str { - self.credential.username() - } - - pub(super) fn password(&self) -> &str { - self.credential.password() - } - - pub(super) fn set_token(&self, token: String) { - *self.token.write().unwrap() = Some(token); - } - - pub(super) fn reset_token(&self) { - *self.token.write().unwrap() = None; - } - - pub(super) fn inject(&self, mut 
request: Request<()>) -> Request<()> { - request.metadata_mut().insert("username", self.credential.username().try_into().unwrap()); - match &*self.token.read().unwrap() { - Some(token) => request.metadata_mut().insert("token", token.try_into().unwrap()), - None => request - .metadata_mut() - .insert("password", self.credential.password().try_into().unwrap()), - }; - request +impl fmt::Debug for Credential { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Credential") + .field("username", &self.username) + .field("is_tls_enabled", &self.is_tls_enabled) + .field("tls_config", &self.tls_config) + .finish() } } diff --git a/src/common/error.rs b/src/common/error.rs index bfadbbb1..d345243e 100644 --- a/src/common/error.rs +++ b/src/common/error.rs @@ -24,8 +24,12 @@ use std::{error::Error as StdError, fmt}; use tonic::{Code, Status}; use typeql_lang::error_messages; -error_messages! { ClientError - code: "CLI", type: "Client Error", +use crate::common::RequestID; + +error_messages! { ConnectionError + code: "CXN", type: "Connection Error", + ConnectionIsClosed() = + 1: "The connection has been closed and no further operation is allowed.", SessionIsClosed() = 2: "The session is closed and no further operation is allowed.", TransactionIsClosed() = @@ -38,7 +42,7 @@ error_messages! { ClientError 8: "The database '{}' does not exist.", MissingResponseField(&'static str) = 9: "Missing field in message received from server: '{}'.", - UnknownRequestId(String) = + UnknownRequestId(RequestID) = 10: "Received a response with unknown request id '{}'", ClusterUnableToConnect(String) = 12: "Unable to connect to TypeDB Cluster. Attempted connecting to the cluster members, but none are available: '{}'.", @@ -52,23 +56,35 @@ error_messages! { ClientError 17: "Failed to close session. It may still be open on the server: or it may already have been closed previously.", } -#[derive(Debug, PartialEq, Eq)] -pub enum Error { - Client(ClientError), - Other(String), +error_messages! 
{ InternalError + code: "INT", type: "Internal Error", + RecvError() = + 1: "Channel is closed.", + SendError() = + 2: "Channel is closed.", + UnexpectedRequestType(String) = + 3: "Unexpected request type for remote procedure call: {}.", + UnexpectedResponseType(String) = + 4: "Unexpected response type for remote procedure call: {}.", + UnknownConnectionAddress(String) = + 5: "Received unrecognized address from the server: {}.", + EnumOutOfBounds(i32, &'static str) = + 6: "Value '{}' is out of bounds for enum '{}'.", } -impl Error { - pub(crate) fn new(msg: String) -> Self { - Error::Other(msg) - } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Error { + Connection(ConnectionError), + Internal(InternalError), + Other(String), } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Error::Client(error) => write!(f, "{}", error), - Error::Other(message) => write!(f, "{}", message), + Error::Connection(error) => write!(f, "{error}"), + Error::Internal(error) => write!(f, "{error}"), + Error::Other(message) => write!(f, "{message}"), } } } @@ -76,15 +92,36 @@ impl fmt::Display for Error { impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { match self { - Error::Client(error) => Some(error), + Error::Connection(error) => Some(error), + Error::Internal(error) => Some(error), Error::Other(_) => None, } } } -impl From for Error { - fn from(error: ClientError) -> Self { - Error::Client(error) +impl From for Error { + fn from(error: ConnectionError) -> Self { + Error::Connection(error) + } +} + +impl From for Error { + fn from(error: InternalError) -> Self { + Error::Internal(error) + } +} + +impl From for Error { + fn from(status: Status) -> Self { + if is_rst_stream(&status) { + Self::Connection(ConnectionError::UnableToConnect()) + } else if is_replica_not_primary(&status) { + Self::Connection(ConnectionError::ClusterReplicaNotPrimary()) + } else if is_token_credential_invalid(&status) { + Self::Connection(ConnectionError::ClusterTokenCredentialInvalid()) + } else { + Self::Other(status.message().to_string()) + } } } @@ -103,22 +140,14 @@ fn is_token_credential_invalid(status: &Status) -> bool { status.code() == Code::Unauthenticated && status.message().contains("[CLS08]") } -impl From for Error { - fn from(status: Status) -> Self { - if is_rst_stream(&status) { - Self::Client(ClientError::UnableToConnect()) - } else if is_replica_not_primary(&status) { - Self::Client(ClientError::ClusterReplicaNotPrimary()) - } else if is_token_credential_invalid(&status) { - Self::Client(ClientError::ClusterTokenCredentialInvalid()) - } else { - Self::Other(status.message().to_string()) - } +impl From for Error { + fn from(err: http::uri::InvalidUri) -> Self { + Error::Other(err.to_string()) } } -impl From for Error { - fn from(err: futures::channel::mpsc::SendError) -> Self { +impl From for Error { + fn from(err: tonic::transport::Error) -> Self { Error::Other(err.to_string()) } } @@ -129,15 +158,27 @@ impl From> for Error { } } -impl From for Error { - fn from(err: tonic::codegen::http::uri::InvalidUri) -> Self { - Error::Other(err.to_string()) +impl From for Error { + fn from(_err: tokio::sync::oneshot::error::RecvError) -> Self { + Error::Internal(InternalError::RecvError()) } } -impl From for Error { - fn from(err: tonic::transport::Error) -> Self { - Error::Other(err.to_string()) +impl From for Error { + fn from(_err: crossbeam::channel::RecvError) -> Self { + Error::Internal(InternalError::RecvError()) + } +} + +impl From> 
for Error { + fn from(_err: crossbeam::channel::SendError) -> Self { + Error::Internal(InternalError::SendError()) + } +} + +impl From for Error { + fn from(err: String) -> Self { + Error::Other(err) } } diff --git a/src/common/macros.rs b/src/common/id.rs similarity index 56% rename from src/common/macros.rs rename to src/common/id.rs index 52fff1c3..0c5abf05 100644 --- a/src/common/macros.rs +++ b/src/common/id.rs @@ -19,18 +19,39 @@ * under the License. */ -#[macro_export] -macro_rules! async_enum_dispatch { - { - $variants:tt - $($vis:vis async fn $name:ident(&mut self, $arg:ident : $arg_type:ty $(,)?) -> $res:ty);+ $(;)? - } => { $(async_enum_dispatch!(@impl $variants, $vis, $name, $arg, $arg_type, $res);)+ }; - - (@impl {$($variant:ident),+}, $vis:vis, $name:ident, $arg:ident, $arg_type:ty, $res:ty) => { - $vis async fn $name(&mut self, $arg: $arg_type) -> $res { - match self { - $(Self::$variant(inner) => inner.$name($arg).await,)+ - } - } +use std::fmt; + +use uuid::Uuid; + +#[derive(Clone, Eq, Hash, PartialEq)] +pub struct ID(Vec); + +impl ID { + pub(crate) fn generate() -> Self { + Uuid::new_v4().as_bytes().to_vec().into() + } +} + +impl From for Vec { + fn from(id: ID) -> Self { + id.0 + } +} + +impl From> for ID { + fn from(vec: Vec) -> Self { + Self(vec) + } +} + +impl fmt::Debug for ID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ID[{self}]") + } +} + +impl fmt::Display for ID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.iter().try_for_each(|byte| write!(f, "{byte:02x}")) } } diff --git a/src/common/info.rs b/src/common/info.rs new file mode 100644 index 00000000..b373b30c --- /dev/null +++ b/src/common/info.rs @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::time::Duration; + +use super::{address::Address, SessionID}; + +#[derive(Clone, Debug)] +pub(crate) struct SessionInfo { + pub(crate) address: Address, + pub(crate) session_id: SessionID, + pub(crate) network_latency: Duration, +} + +#[derive(Debug)] +pub(crate) struct DatabaseInfo { + pub(crate) name: String, + pub(crate) replicas: Vec, +} + +#[derive(Debug)] +pub(crate) struct ReplicaInfo { + pub(crate) address: Address, + pub(crate) is_primary: bool, + pub(crate) is_preferred: bool, + pub(crate) term: i64, +} diff --git a/src/common/mod.rs b/src/common/mod.rs index 6fa466c8..35a175a9 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -19,51 +19,29 @@ * under the License. 
*/ -mod address; -pub mod credential; +pub(crate) mod address; +mod credential; pub mod error; -mod macros; -pub(crate) mod rpc; +mod id; +pub(crate) mod info; +mod options; -use tonic::{Response, Status}; -use typedb_protocol::{session as session_proto, transaction as transaction_proto}; - -pub(crate) use self::rpc::{ClusterRPC, ClusterServerRPC, CoreRPC, ServerRPC, TransactionRPC}; -pub use self::{address::Address, credential::Credential, error::Error}; +pub use self::{credential::Credential, error::Error, options::Options}; pub(crate) type StdResult = std::result::Result; pub type Result = StdResult; -pub(crate) type TonicResult = StdResult, Status>; -pub(crate) type TonicChannel = tonic::transport::Channel; -pub(crate) type Executor = futures::executor::ThreadPool; +pub(crate) type RequestID = id::ID; +pub(crate) type SessionID = id::ID; -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum SessionType { Data = 0, Schema = 1, } -impl SessionType { - pub(crate) fn to_proto(self) -> session_proto::Type { - match self { - SessionType::Data => session_proto::Type::Data, - SessionType::Schema => session_proto::Type::Schema, - } - } -} - -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum TransactionType { Read = 0, Write = 1, } - -impl TransactionType { - pub(crate) fn to_proto(self) -> transaction_proto::Type { - match self { - TransactionType::Read => transaction_proto::Type::Read, - TransactionType::Write => transaction_proto::Type::Write, - } - } -} diff --git a/src/common/options.rs b/src/common/options.rs new file mode 100644 index 00000000..176f1e36 --- /dev/null +++ b/src/common/options.rs @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use std::time::Duration; + +#[derive(Clone, Debug, Default)] +pub struct Options { + pub infer: Option, + pub trace_inference: Option, + pub explain: Option, + pub parallel: Option, + pub prefetch: Option, + pub prefetch_size: Option, + pub session_idle_timeout: Option, + pub transaction_timeout: Option, + pub schema_lock_acquire_timeout: Option, + pub read_any_replica: Option, +} + +impl Options { + pub fn new() -> Self { + Default::default() + } + + pub fn infer(self, infer: bool) -> Self { + Self { infer: Some(infer), ..self } + } + + pub fn trace_inference(self, trace_inference: bool) -> Self { + Self { trace_inference: Some(trace_inference), ..self } + } + + pub fn explain(self, explain: bool) -> Self { + Self { explain: Some(explain), ..self } + } + + pub fn parallel(self, parallel: bool) -> Self { + Self { parallel: Some(parallel), ..self } + } + + pub fn prefetch(self, prefetch: bool) -> Self { + Self { prefetch: Some(prefetch), ..self } + } + + pub fn prefetch_size(self, prefetch_size: i32) -> Self { + Self { prefetch_size: Some(prefetch_size), ..self } + } + + pub fn session_idle_timeout(self, timeout: Duration) -> Self { + Self { session_idle_timeout: Some(timeout), ..self } + } + + pub fn transaction_timeout(self, timeout: Duration) -> Self { + Self { transaction_timeout: Some(timeout), ..self } + } + + pub fn schema_lock_acquire_timeout(self, timeout: Duration) -> Self { + Self { schema_lock_acquire_timeout: Some(timeout), ..self } + } + + pub fn read_any_replica(self, read_any_replica: bool) -> Self { + Self { read_any_replica: Some(read_any_replica), ..self } + } +} diff --git a/src/common/rpc/builder.rs b/src/common/rpc/builder.rs deleted file mode 100644 index 574bb611..00000000 --- a/src/common/rpc/builder.rs +++ /dev/null @@ -1,312 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -pub(crate) mod core { - pub(crate) mod database_manager { - use typedb_protocol::core_database_manager::{all, contains, create}; - - pub(crate) fn contains_req(name: &str) -> contains::Req { - contains::Req { name: name.into() } - } - - pub(crate) fn create_req(name: &str) -> create::Req { - create::Req { name: name.into() } - } - - pub(crate) fn all_req() -> all::Req { - all::Req {} - } - } - - pub(crate) mod database { - use typedb_protocol::core_database::{delete, rule_schema, schema, type_schema}; - - pub(crate) fn delete_req(name: &str) -> delete::Req { - delete::Req { name: name.into() } - } - - pub(crate) fn rule_schema_req(name: &str) -> rule_schema::Req { - rule_schema::Req { name: name.into() } - } - - pub(crate) fn schema_req(name: &str) -> schema::Req { - schema::Req { name: name.into() } - } - - pub(crate) fn type_schema_req(name: &str) -> type_schema::Req { - type_schema::Req { name: name.into() } - } - } -} - -pub(crate) mod cluster { - pub(crate) mod server_manager { - use typedb_protocol::server_manager::all; - - pub(crate) fn all_req() -> all::Req { - all::Req {} - } - } - - pub(crate) mod user_manager { - use typedb_protocol::cluster_user_manager::{all, contains, create}; - - pub(crate) fn contains_req(username: &str) -> contains::Req { - contains::Req { username: username.into() } - } - - pub(crate) fn create_req(username: &str, password: &str) -> create::Req { - create::Req { username: username.into(), password: password.into() } - } - - pub(crate) fn all_req() -> all::Req { - all::Req {} - } - } - - pub(crate) mod user { - use typedb_protocol::cluster_user::{delete, password, token}; - - pub(crate) fn password_req(username: &str, password: &str) -> password::Req { - password::Req { username: username.into(), password: password.into() } - } - - pub(crate) fn token_req(username: &str) -> token::Req { - token::Req { username: username.into() } - } - - pub(crate) fn delete_req(username: &str) -> delete::Req { - delete::Req { username: username.into() } - } - } - - pub(crate) mod database_manager { - use typedb_protocol::cluster_database_manager::{all, get}; - - pub(crate) fn get_req(name: &str) -> get::Req { - get::Req { name: name.into() } - } - - pub(crate) fn all_req() -> all::Req { - all::Req {} - } - } -} - -pub(crate) mod session { - use typedb_protocol::{ - session, - session::{close, open}, - Options, - }; - - pub(crate) fn close_req(session_id: Vec) -> close::Req { - close::Req { session_id } - } - - pub(crate) fn open_req( - database: &str, - session_type: session::Type, - options: Options, - ) -> open::Req { - open::Req { - database: database.into(), - r#type: session_type.into(), - options: options.into(), - } - } -} - -pub(crate) mod transaction { - use typedb_protocol::{ - transaction, - transaction::{commit, open, rollback, stream}, - Options, - }; - use uuid::Uuid; - - pub(crate) fn client_msg(reqs: Vec) -> transaction::Client { - transaction::Client { reqs } - } - - pub(crate) fn stream_req(req_id: Vec) -> transaction::Req { - req_with_id(transaction::req::Req::StreamReq(stream::Req {}), req_id) - } - - pub(crate) fn open_req( - session_id: Vec, - transaction_type: transaction::Type, - options: Options, - network_latency_millis: i32, - ) -> transaction::Req { - req(transaction::req::Req::OpenReq(open::Req { - session_id, - r#type: transaction_type.into(), - options: options.into(), - network_latency_millis, - })) - } - - pub(crate) fn commit_req() -> transaction::Req { - req(transaction::req::Req::CommitReq(commit::Req {})) - } - - pub(crate) fn 
rollback_req() -> transaction::Req { - req(transaction::req::Req::RollbackReq(rollback::Req {})) - } - - pub(super) fn req(req: transaction::req::Req) -> transaction::Req { - transaction::Req { req_id: new_req_id(), metadata: Default::default(), req: req.into() } - } - - pub(super) fn req_with_id(req: transaction::req::Req, req_id: Vec) -> transaction::Req { - transaction::Req { req_id, metadata: Default::default(), req: req.into() } - } - - fn new_req_id() -> Vec { - Uuid::new_v4().as_bytes().to_vec() - } -} - -#[allow(dead_code)] -pub(crate) mod query_manager { - use typedb_protocol::{ - query_manager, - query_manager::{ - define, delete, explain, insert, match_aggregate, match_group, match_group_aggregate, - r#match, undefine, update, - }, - transaction, - transaction::req::Req::QueryManagerReq, - Options, - }; - - fn query_manager_req(req: query_manager::Req) -> transaction::Req { - super::transaction::req(QueryManagerReq(req)) - } - - pub(crate) fn define_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::DefineReq(define::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn undefine_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::UndefineReq(undefine::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn match_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::MatchReq(r#match::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn match_aggregate_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::MatchAggregateReq(match_aggregate::Req { - query: query.to_string(), - }) - .into(), - }) - } - - pub(crate) fn match_group_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::MatchGroupReq(match_group::Req { - query: query.to_string(), - }) - .into(), - }) - } - - pub(crate) fn match_group_aggregate_req( - query: &str, - options: Option, - ) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::MatchGroupAggregateReq(match_group_aggregate::Req { - query: query.to_string(), - }) - .into(), - }) - } - - pub(crate) fn insert_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::InsertReq(insert::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn delete_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::DeleteReq(delete::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn update_req(query: &str, options: Option) -> transaction::Req { - query_manager_req(query_manager::Req { - options, - req: query_manager::req::Req::UpdateReq(update::Req { query: query.to_string() }) - .into(), - }) - } - - pub(crate) fn explain_req(id: i64) -> transaction::Req { - query_manager_req(query_manager::Req { - options: None, - req: query_manager::req::Req::ExplainReq(explain::Req { explainable_id: id }).into(), - }) - } -} - -#[allow(dead_code)] -pub(crate) mod thing { - use typedb_protocol::{ - attribute, thing, thing::req::Req::AttributeGetOwnersReq, 
transaction, - transaction::req::Req::ThingReq, - }; - - fn thing_req(req: thing::Req) -> transaction::Req { - super::transaction::req(ThingReq(req)) - } - - pub(crate) fn attribute_get_owners_req(iid: &[u8]) -> transaction::Req { - thing_req(thing::Req { - iid: iid.to_vec(), - req: AttributeGetOwnersReq(attribute::get_owners::Req { filter: None }).into(), - }) - } -} diff --git a/src/common/rpc/channel.rs b/src/common/rpc/channel.rs deleted file mode 100644 index a320892f..00000000 --- a/src/common/rpc/channel.rs +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::sync::Arc; - -use tonic::{codegen::InterceptedService, service::Interceptor, Request, Status}; - -use crate::{ - common::{credential::CallCredentials, Address, Credential, TonicChannel}, - Result, -}; - -pub(crate) type CallCredChannel = InterceptedService; - -#[derive(Clone, Debug)] -pub(crate) enum Channel { - Plaintext(TonicChannel), - Encrypted(CallCredChannel), -} - -impl Channel { - pub(crate) fn open_plaintext(address: Address) -> Result { - Ok(Self::Plaintext(TonicChannel::builder(address.into_uri()).connect_lazy())) - } - - pub(crate) fn open_encrypted( - address: Address, - credential: Credential, - ) -> Result<(Self, Arc)> { - let mut builder = TonicChannel::builder(address.into_uri()); - if credential.is_tls_enabled() { - builder = builder.tls_config(credential.tls_config()?)?; - } - - let channel = builder.connect_lazy(); - let call_credentials = Arc::new(CallCredentials::new(credential)); - Ok(( - Self::Encrypted(InterceptedService::new( - channel, - CredentialInjector::new(call_credentials.clone()), - )), - call_credentials, - )) - } -} - -impl From for TonicChannel { - fn from(channel: Channel) -> Self { - match channel { - Channel::Plaintext(channel) => channel, - _ => panic!(), - } - } -} - -impl From for CallCredChannel { - fn from(channel: Channel) -> Self { - match channel { - Channel::Encrypted(channel) => channel, - _ => panic!(), - } - } -} - -#[derive(Clone, Debug)] -pub(crate) struct CredentialInjector { - call_credentials: Arc, -} - -impl CredentialInjector { - fn new(call_credentials: Arc) -> Self { - Self { call_credentials } - } -} - -impl Interceptor for CredentialInjector { - fn call(&mut self, request: Request<()>) -> std::result::Result, Status> { - Ok(self.call_credentials.inject(request)) - } -} diff --git a/src/common/rpc/cluster.rs b/src/common/rpc/cluster.rs deleted file mode 100644 index 0ec45a7e..00000000 --- a/src/common/rpc/cluster.rs +++ /dev/null @@ -1,280 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; - -use futures::{channel::mpsc, future::BoxFuture, FutureExt}; -use tonic::Streaming; -use typedb_protocol::{ - cluster_database_manager, cluster_user, core_database, core_database_manager, session, - transaction, type_db_cluster_client::TypeDbClusterClient as ClusterGRPC, -}; - -use crate::common::{ - credential::CallCredentials, - error::ClientError, - rpc::{ - builder::{cluster, cluster::user::token_req}, - channel::CallCredChannel, - Channel, CoreRPC, - }, - Address, Credential, Error, Executor, Result, -}; - -#[derive(Debug, Clone)] -pub(crate) struct ClusterRPC { - server_rpcs: HashMap, -} - -impl ClusterRPC { - pub(crate) fn new(addresses: HashSet
, credential: Credential) -> Result> { - let cluster_clients = addresses - .into_iter() - .map(|address| { - Ok((address.clone(), ClusterServerRPC::new(address, credential.clone())?)) - }) - .collect::>()?; - Ok(Arc::new(Self { server_rpcs: cluster_clients })) - } - - pub(crate) async fn fetch_current_addresses>( - addresses: &[T], - credential: &Credential, - ) -> Result> { - for address in addresses { - match ClusterServerRPC::new(address.as_ref().parse()?, credential.clone())? - .validated() - .await - { - Ok(mut client) => { - let servers = client.servers_all().await?.servers; - return servers.into_iter().map(|server| server.address.parse()).collect(); - } - Err(Error::Client(ClientError::UnableToConnect())) => (), - Err(err) => Err(err)?, - } - } - Err(ClientError::UnableToConnect())? - } - - pub(crate) fn server_rpc_count(&self) -> usize { - self.server_rpcs.len() - } - - pub(crate) fn addresses(&self) -> impl Iterator { - self.server_rpcs.keys() - } - - pub(crate) fn get_server_rpc(&self, address: &Address) -> ClusterServerRPC { - self.server_rpcs.get(address).unwrap().clone() - } - - pub(crate) fn get_any_server_rpc(&self) -> ClusterServerRPC { - // TODO round robin? - self.server_rpcs.values().next().unwrap().clone() - } - - pub(crate) fn iter_server_rpcs_cloned(&self) -> impl Iterator + '_ { - self.server_rpcs.values().cloned() - } - - pub(crate) fn unable_to_connect(&self) -> Error { - Error::Client(ClientError::ClusterUnableToConnect( - self.addresses().map(Address::to_string).collect::>().join(","), - )) - } -} - -#[derive(Clone, Debug)] -pub(crate) struct ClusterServerRPC { - address: Address, - core_rpc: CoreRPC, - cluster_grpc: ClusterGRPC, - call_credentials: Arc, - pub(crate) executor: Arc, -} - -impl ClusterServerRPC { - pub(crate) fn new(address: Address, credential: Credential) -> Result { - let (channel, call_credentials) = Channel::open_encrypted(address.clone(), credential)?; - Ok(Self { - address, - core_rpc: CoreRPC::new(channel.clone())?, - cluster_grpc: ClusterGRPC::new(channel.into()), - executor: Arc::new(Executor::new().expect("Failed to create Executor")), - call_credentials, - }) - } - - async fn validated(mut self) -> Result { - self.cluster_grpc.databases_all(cluster::database_manager::all_req()).await?; - Ok(self) - } - - pub(crate) fn address(&self) -> &Address { - &self.address - } - - async fn call_with_auto_renew_token(&mut self, call: F) -> Result - where - for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, Result>, - { - match call(self).await { - Err(Error::Client(ClientError::ClusterTokenCredentialInvalid())) => { - self.renew_token().await?; - call(self).await - } - res => res, - } - } - - async fn renew_token(&mut self) -> Result { - self.call_credentials.reset_token(); - let req = token_req(self.call_credentials.username()); - let token = self.user_token(req).await?.token; - self.call_credentials.set_token(token); - Ok(()) - } - - async fn user_token( - &mut self, - username: cluster_user::token::Req, - ) -> Result { - Ok(self.cluster_grpc.user_token(username).await?.into_inner()) - } - - pub(crate) async fn servers_all( - &mut self, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin( - this.cluster_grpc - .servers_all(cluster::server_manager::all_req()) - .map(|res| Ok(res?.into_inner())), - ) - }) - .await - } - - pub(crate) async fn databases_get( - &mut self, - req: cluster_database_manager::get::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.cluster_grpc.databases_get(req.clone()).map(|res| 
Ok(res?.into_inner()))) - }) - .await - } - - pub(crate) async fn databases_all( - &mut self, - req: cluster_database_manager::all::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.cluster_grpc.databases_all(req.clone()).map(|res| Ok(res?.into_inner()))) - }) - .await - } - - // server client pass-through - pub(crate) async fn databases_contains( - &mut self, - req: core_database_manager::contains::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.core_rpc.databases_contains(req.clone())) - }) - .await - } - - pub(crate) async fn databases_create( - &mut self, - req: core_database_manager::create::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.core_rpc.databases_create(req.clone())) - }) - .await - } - - pub(crate) async fn database_delete( - &mut self, - req: core_database::delete::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| Box::pin(this.core_rpc.database_delete(req.clone()))) - .await - } - - pub(crate) async fn database_schema( - &mut self, - req: core_database::schema::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| Box::pin(this.core_rpc.database_schema(req.clone()))) - .await - } - - pub(crate) async fn database_rule_schema( - &mut self, - req: core_database::rule_schema::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.core_rpc.database_rule_schema(req.clone())) - }) - .await - } - - pub(crate) async fn database_type_schema( - &mut self, - req: core_database::type_schema::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| { - Box::pin(this.core_rpc.database_type_schema(req.clone())) - }) - .await - } - - pub(crate) async fn session_open( - &mut self, - req: session::open::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| Box::pin(this.core_rpc.session_open(req.clone()))) - .await - } - - pub(crate) async fn session_close( - &mut self, - req: session::close::Req, - ) -> Result { - self.call_with_auto_renew_token(|this| Box::pin(this.core_rpc.session_close(req.clone()))) - .await - } - - pub(crate) async fn transaction( - &mut self, - req: transaction::Req, - ) -> Result<(mpsc::Sender, Streaming)> { - self.call_with_auto_renew_token(|this| Box::pin(this.core_rpc.transaction(req.clone()))) - .await - } -} diff --git a/src/common/rpc/core.rs b/src/common/rpc/core.rs deleted file mode 100644 index 59715ed1..00000000 --- a/src/common/rpc/core.rs +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -use std::{future::Future, sync::Arc}; - -use futures::{channel::mpsc, SinkExt}; -use tonic::{Response, Status, Streaming}; -use typedb_protocol::{ - core_database, core_database_manager, session, transaction, - type_db_client::TypeDbClient as RawCoreGRPC, -}; - -use crate::{ - async_enum_dispatch, - common::{ - rpc::{ - builder::{core, transaction::client_msg}, - channel::CallCredChannel, - Channel, - }, - Address, Executor, Result, StdResult, TonicChannel, - }, -}; - -#[derive(Clone, Debug)] -enum CoreGRPC { - Plaintext(RawCoreGRPC), - Encrypted(RawCoreGRPC), -} - -impl CoreGRPC { - pub fn new(channel: Channel) -> Self { - match channel { - Channel::Plaintext(channel) => Self::Plaintext(RawCoreGRPC::new(channel)), - Channel::Encrypted(channel) => Self::Encrypted(RawCoreGRPC::new(channel)), - } - } - - async_enum_dispatch! { { Plaintext, Encrypted } - pub async fn databases_contains( - &mut self, - request: core_database_manager::contains::Req, - ) -> StdResult, Status>; - - pub async fn databases_create( - &mut self, - request: core_database_manager::create::Req, - ) -> StdResult, Status>; - - pub async fn databases_all( - &mut self, - request: core_database_manager::all::Req, - ) -> StdResult, Status>; - - pub async fn database_schema( - &mut self, - request: core_database::schema::Req, - ) -> StdResult, Status>; - - pub async fn database_type_schema( - &mut self, - request: core_database::type_schema::Req, - ) -> StdResult, Status>; - - pub async fn database_rule_schema( - &mut self, - request: core_database::rule_schema::Req, - ) -> StdResult, Status>; - - pub async fn database_delete( - &mut self, - request: core_database::delete::Req, - ) -> StdResult, Status>; - - pub async fn session_open( - &mut self, - request: session::open::Req, - ) -> StdResult, Status>; - - pub async fn session_close( - &mut self, - request: session::close::Req, - ) -> StdResult, Status>; - - pub async fn session_pulse( - &mut self, - request: session::pulse::Req, - ) -> StdResult, Status>; - - pub async fn transaction( - &mut self, - request: impl tonic::IntoStreamingRequest, - ) -> StdResult>, Status>; - } -} - -#[derive(Clone, Debug)] -pub(crate) struct CoreRPC { - core_grpc: CoreGRPC, - pub(crate) executor: Arc, -} - -impl CoreRPC { - pub(crate) fn new(channel: Channel) -> Result { - Ok(Self { - core_grpc: CoreGRPC::new(channel), - executor: Arc::new(Executor::new().expect("Failed to create Executor")), - }) - } - - pub(crate) async fn connect(address: Address) -> Result { - Self::new(Channel::open_plaintext(address)?)?.validated().await - } - - async fn validated(mut self) -> Result { - // TODO: temporary hack to validate connection until we have client pulse - self.core_grpc.databases_all(core::database_manager::all_req()).await?; - Ok(self) - } - - pub(crate) async fn databases_contains( - &mut self, - req: core_database_manager::contains::Req, - ) -> Result { - single(self.core_grpc.databases_contains(req)).await - } - - pub(crate) async fn databases_create( - &mut self, - req: core_database_manager::create::Req, - ) -> Result { - single(self.core_grpc.databases_create(req)).await - } - - pub(crate) async fn databases_all( - &mut self, - req: core_database_manager::all::Req, - ) -> Result { - single(self.core_grpc.databases_all(req)).await - } - - pub(crate) async fn database_delete( - &mut self, - req: core_database::delete::Req, - ) -> Result { - single(self.core_grpc.database_delete(req)).await - } - - pub(crate) async fn database_schema( - &mut self, - req: core_database::schema::Req, - ) 
-> Result { - single(self.core_grpc.database_schema(req)).await - } - - pub(crate) async fn database_type_schema( - &mut self, - req: core_database::type_schema::Req, - ) -> Result { - single(self.core_grpc.database_type_schema(req)).await - } - - pub(crate) async fn database_rule_schema( - &mut self, - req: core_database::rule_schema::Req, - ) -> Result { - single(self.core_grpc.database_rule_schema(req)).await - } - - pub(crate) async fn session_open( - &mut self, - req: session::open::Req, - ) -> Result { - single(self.core_grpc.session_open(req)).await - } - - pub(crate) async fn session_close( - &mut self, - req: session::close::Req, - ) -> Result { - single(self.core_grpc.session_close(req)).await - } - - pub(crate) async fn transaction( - &mut self, - open_req: transaction::Req, - ) -> Result<(mpsc::Sender, Streaming)> { - // TODO: refactor to crossbeam channel - let (mut sender, receiver) = mpsc::channel::(256); - sender.send(client_msg(vec![open_req])).await.unwrap(); - bidi_stream(sender, self.core_grpc.transaction(receiver)).await - } -} - -pub(crate) async fn single( - res: impl Future, Status>>, -) -> Result { - Ok(res.await?.into_inner()) -} - -pub(crate) async fn bidi_stream( - req_sink: mpsc::Sender, - res: impl Future>, Status>>, -) -> Result<(mpsc::Sender, Streaming)> { - Ok((req_sink, res.await?.into_inner())) -} diff --git a/src/common/rpc/server.rs b/src/common/rpc/server.rs deleted file mode 100644 index c7a281fd..00000000 --- a/src/common/rpc/server.rs +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::sync::Arc; - -use futures::channel::mpsc; -use tonic::Streaming; -use typedb_protocol::{core_database, core_database_manager, session, transaction}; - -use crate::{ - async_enum_dispatch, - common::{ - rpc::{core::CoreRPC, ClusterServerRPC}, - Executor, Result, - }, -}; - -#[derive(Clone, Debug)] -pub(crate) enum ServerRPC { - Core(CoreRPC), - Cluster(ClusterServerRPC), -} - -impl From for ServerRPC { - fn from(server_client: CoreRPC) -> Self { - ServerRPC::Core(server_client) - } -} - -impl From for ServerRPC { - fn from(cluster_client: ClusterServerRPC) -> Self { - ServerRPC::Cluster(cluster_client) - } -} - -impl ServerRPC { - pub(crate) fn executor(&self) -> &Arc { - match self { - Self::Core(client) => &client.executor, - Self::Cluster(client) => &client.executor, - } - } - - async_enum_dispatch! 
{ { Core, Cluster } - pub(crate) async fn databases_contains( - &mut self, - req: core_database_manager::contains::Req, - ) -> Result; - - pub(crate) async fn databases_create( - &mut self, - req: core_database_manager::create::Req, - ) -> Result; - - pub(crate) async fn database_delete( - &mut self, - req: core_database::delete::Req, - ) -> Result; - - pub(crate) async fn database_schema( - &mut self, - req: core_database::schema::Req, - ) -> Result; - - pub(crate) async fn database_type_schema( - &mut self, - req: core_database::type_schema::Req, - ) -> Result; - - pub(crate) async fn database_rule_schema( - &mut self, - req: core_database::rule_schema::Req, - ) -> Result; - - pub(crate) async fn session_open( - &mut self, - req: session::open::Req, - ) -> Result; - - pub(crate) async fn session_close( - &mut self, - req: session::close::Req, - ) -> Result; - - pub(crate) async fn transaction( - &mut self, - req: transaction::Req, - ) -> Result<(mpsc::Sender, Streaming)>; - } -} diff --git a/src/common/rpc/transaction.rs b/src/common/rpc/transaction.rs deleted file mode 100644 index bfd46c25..00000000 --- a/src/common/rpc/transaction.rs +++ /dev/null @@ -1,430 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -use std::{ - collections::HashMap, - mem, - pin::Pin, - sync::{Arc, Mutex}, - task::{Context, Poll}, - thread::sleep, - time::Duration, -}; - -use crossbeam::atomic::AtomicCell; -use futures::{ - channel::{mpsc, oneshot}, - SinkExt, Stream, StreamExt, -}; -use tonic::Streaming; -use typedb_protocol::{ - transaction, - transaction::{res::Res, res_part, server::Server, stream::State}, -}; - -use crate::common::{ - error::{ClientError, Error}, - rpc::{ - builder::transaction::{client_msg, stream_req}, - ServerRPC, - }, - Executor, Result, -}; - -// TODO: This structure has become pretty messy - review -#[derive(Clone, Debug)] -pub(crate) struct TransactionRPC { - rpc_client: ServerRPC, - sender: Sender, - receiver: Receiver, -} - -impl TransactionRPC { - pub(crate) async fn new(rpc_client: &ServerRPC, open_req: transaction::Req) -> Result { - let mut rpc_client_clone = rpc_client.clone(); - let (req_sink, streaming_res): ( - mpsc::Sender, - Streaming, - ) = rpc_client_clone.transaction(open_req).await?; - let (close_signal_sink, close_signal_receiver) = oneshot::channel::>(); - Ok(TransactionRPC { - rpc_client: rpc_client_clone.clone(), - sender: Sender::new( - req_sink, - rpc_client_clone.executor().clone(), - close_signal_receiver, - ), - receiver: Receiver::new(streaming_res, rpc_client_clone.executor(), close_signal_sink) - .await, - }) - } - - pub(crate) async fn single(&mut self, req: transaction::Req) -> Result { - if !self.is_open() { - todo!() - } - let (res_sink, res_receiver) = oneshot::channel::>(); - self.receiver.add_single(&req.req_id, res_sink); - self.sender.submit_message(req); - match res_receiver.await { - Ok(result) => result, - Err(err) => Err(Error::new(err.to_string())), - } - } - - pub(crate) fn stream(&mut self, req: transaction::Req) -> ResPartStream { - const BUFFER_SIZE: usize = 1024; - let (res_part_sink, res_part_receiver) = - mpsc::channel::>(BUFFER_SIZE); - let (stream_req_sink, stream_req_receiver) = std::sync::mpsc::channel::(); - self.receiver.add_stream(&req.req_id, res_part_sink); - let res_part_stream = - ResPartStream::new(res_part_receiver, stream_req_sink, req.req_id.clone()); - self.sender.add_message_provider(stream_req_receiver); - self.sender.submit_message(req); - res_part_stream - } - - pub(crate) fn is_open(&self) -> bool { - self.sender.is_open() - } - - pub(crate) async fn close(&self) { - self.sender.close(None).await; - } -} - -#[derive(Clone, Debug)] -struct Sender { - state: Arc, - executor: Arc, -} - -#[derive(Debug)] -struct SenderState { - req_sink: mpsc::Sender, - // TODO: refactor to crossbeam_queue::ArrayQueue? 
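// Requests pushed by `submit_message` accumulate in this queue; `dispatch_loop` (below) wakes
// every DISPATCH_INTERVAL, drains the queue, and sends the whole batch as a single client
// message. A rough sketch of that batching step, with a hypothetical `queue` and `sink`
// standing in for the fields of this struct (illustrative only):
//
//     let batch = mem::take(&mut *queue.lock().unwrap());
//     if !batch.is_empty() {
//         sink.clone().send(client_msg(batch)).await?;
//     }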
- queued_messages: Mutex>, - // TODO: refactor to message passing for these atomics - ongoing_task_count: AtomicCell, - is_open: AtomicCell, -} - -type ReqId = Vec; - -impl SenderState { - fn new(req_sink: mpsc::Sender) -> Self { - SenderState { - req_sink, - queued_messages: Mutex::new(Vec::new()), - ongoing_task_count: AtomicCell::new(0), - is_open: AtomicCell::new(true), - } - } - - fn submit_message(&self, req: transaction::Req) { - self.queued_messages.lock().unwrap().push(req); - } - - async fn dispatch_loop(&self) { - while self.is_open.load() { - const DISPATCH_INTERVAL: Duration = Duration::from_millis(3); - sleep(DISPATCH_INTERVAL); - self.dispatch_messages().await; - } - } - - async fn dispatch_messages(&self) { - self.ongoing_task_count.fetch_add(1); - let messages = mem::take(&mut *self.queued_messages.lock().unwrap()); - if !messages.is_empty() { - self.req_sink.clone().send(client_msg(messages)).await.unwrap(); - } - self.ongoing_task_count.fetch_sub(1); - } - - async fn await_close_signal(&self, close_signal_receiver: CloseSignalReceiver) { - match close_signal_receiver.await { - Ok(close_signal) => { - self.close(close_signal).await; - } - Err(err) => { - self.close(Some(Error::new(err.to_string()))).await; - } - } - } - - async fn close(&self, error: Option) { - if let Ok(true) = self.is_open.compare_exchange(true, false) { - if error.is_none() { - self.dispatch_messages().await; - } - // TODO: refactor to non-busy wait? - // TODO: this loop should have a timeout - loop { - if self.ongoing_task_count.load() == 0 { - self.req_sink.clone().close().await.unwrap(); - break; - } - } - } - } -} - -impl Sender { - pub(crate) fn new( - req_sink: mpsc::Sender, - executor: Arc, - close_signal_receiver: CloseSignalReceiver, - ) -> Self { - let state = Arc::new(SenderState::new(req_sink)); - // // TODO: clarify lifetimes of these threads - executor.spawn_ok({ - let state = state.clone(); - async move { - state.await_close_signal(close_signal_receiver).await; - } - }); - - executor.spawn_ok({ - let state = state.clone(); - async move { - state.dispatch_loop().await; - } - }); - - Sender { state, executor } - } - - fn submit_message(&self, req: transaction::Req) { - self.state.submit_message(req); - } - - fn add_message_provider(&self, provider: std::sync::mpsc::Receiver) { - let cloned_state = self.state.clone(); - self.executor.spawn_ok(async move { - for req in provider.iter() { - cloned_state.submit_message(req); - } - }); - } - - fn is_open(&self) -> bool { - self.state.is_open.load() - } - - async fn close(&self, error: Option) { - self.state.close(error).await - } -} - -#[derive(Clone, Debug)] -struct Receiver { - state: Arc, -} - -#[derive(Debug)] -struct ReceiverState { - res_collectors: Mutex>, - res_part_collectors: Mutex>, - is_open: AtomicCell, -} - -impl ReceiverState { - fn new() -> Self { - ReceiverState { - res_collectors: Mutex::new(HashMap::new()), - res_part_collectors: Mutex::new(HashMap::new()), - is_open: AtomicCell::new(true), - } - } - - async fn listen( - self: &Arc, - mut grpc_stream: Streaming, - close_signal_sink: CloseSignalSink, - ) { - loop { - match grpc_stream.next().await { - Some(Ok(message)) => { - self.clone().on_receive(message).await; - } - Some(Err(err)) => { - self.close(Some(err.into()), close_signal_sink).await; - break; - } - None => { - self.close(None, close_signal_sink).await; - break; - } - } - } - } - - async fn on_receive(&self, message: transaction::Server) { - // TODO: If an error occurs here (or in some other background process), 
resources are not - // properly cleaned up, and the application may hang. - match message.server { - Some(Server::Res(res)) => self.collect_res(res), - Some(Server::ResPart(res_part)) => { - self.collect_res_part(res_part).await; - } - None => println!("{}", ClientError::MissingResponseField("server")), - } - } - - fn collect_res(&self, res: transaction::Res) { - match self.res_collectors.lock().unwrap().remove(&res.req_id) { - Some(collector) => collector.send(Ok(res)).unwrap(), - None => { - if let Res::OpenRes(_) = res.res.unwrap() { - // ignore open_res - } else { - println!("{}", ClientError::UnknownRequestId(format!("{:?}", &res.req_id))) - // println!("{}", MESSAGES.client.unknown_request_id.to_err( - // vec![std::str::from_utf8(&res.req_id).unwrap()]) - // ) - } - } - } - } - - async fn collect_res_part(&self, res_part: transaction::ResPart) { - let value = self.res_part_collectors.lock().unwrap().remove(&res_part.req_id); - match value { - Some(mut collector) => { - let req_id = res_part.req_id.clone(); - if collector.send(Ok(res_part)).await.is_ok() { - self.res_part_collectors.lock().unwrap().insert(req_id, collector); - } - } - None => { - let req_id_str = hex_string(&res_part.req_id); - println!("{}", ClientError::UnknownRequestId(req_id_str)); - } - } - } - - async fn close(&self, error: Option, close_signal_sink: CloseSignalSink) { - if let Ok(true) = self.is_open.compare_exchange(true, false) { - let error_str = error.map(|err| err.to_string()); - for (_, collector) in self.res_collectors.lock().unwrap().drain() { - collector.send(Err(close_reason(&error_str))).ok(); - } - let mut res_part_collectors = Vec::new(); - for (_, res_part_collector) in self.res_part_collectors.lock().unwrap().drain() { - res_part_collectors.push(res_part_collector) - } - for mut collector in res_part_collectors { - collector.send(Err(close_reason(&error_str))).await.ok(); - } - close_signal_sink.send(Some(close_reason(&error_str))).unwrap(); - } - } -} - -fn hex_string(v: &[u8]) -> String { - v.iter().map(|b| format!("{:02X}", b)).collect::() -} - -fn close_reason(error_str: &Option) -> Error { - match error_str { - None => ClientError::TransactionIsClosed(), - Some(value) => ClientError::TransactionIsClosedWithErrors(value.clone()), - } - .into() -} - -impl Receiver { - async fn new( - grpc_stream: Streaming, - executor: &Executor, - close_signal_sink: CloseSignalSink, - ) -> Self { - let state = Arc::new(ReceiverState::new()); - executor.spawn_ok({ - let state = state.clone(); - async move { - state.listen(grpc_stream, close_signal_sink).await; - } - }); - Receiver { state } - } - - fn add_single(&mut self, req_id: &ReqId, res_collector: ResCollector) { - self.state.res_collectors.lock().unwrap().insert(req_id.clone(), res_collector); - } - - fn add_stream(&mut self, req_id: &ReqId, res_part_sink: ResPartCollector) { - self.state.res_part_collectors.lock().unwrap().insert(req_id.clone(), res_part_sink); - } -} - -type ResCollector = oneshot::Sender>; -type ResPartCollector = mpsc::Sender>; -type CloseSignalSink = oneshot::Sender>; -type CloseSignalReceiver = oneshot::Receiver>; - -#[derive(Debug)] -pub(crate) struct ResPartStream { - source: mpsc::Receiver>, - stream_req_sink: std::sync::mpsc::Sender, - req_id: ReqId, -} - -impl ResPartStream { - fn new( - source: mpsc::Receiver>, - stream_req_sink: std::sync::mpsc::Sender, - req_id: ReqId, - ) -> Self { - ResPartStream { source, stream_req_sink, req_id } - } -} - -impl Stream for ResPartStream { - type Item = Result; - - fn poll_next(mut self: 
Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - let poll = Pin::new(&mut self.source).poll_next(ctx); - match poll { - Poll::Ready(Some(Ok(ref res_part))) => { - match &res_part.res { - Some(res_part::Res::StreamResPart(stream_res_part)) => { - // TODO: unwrap -> expect("enum out of range") - match State::from_i32(stream_res_part.state).unwrap() { - State::Done => Poll::Ready(None), - State::Continue => { - let req_id = self.req_id.clone(); - self.stream_req_sink.send(stream_req(req_id)).unwrap(); - ctx.waker().wake_by_ref(); - Poll::Pending - } - } - } - Some(_other) => poll, - None => panic!("{}", ClientError::MissingResponseField("res_part.res")), - } - } - poll => poll, - } - } -} diff --git a/src/concept/mod.rs b/src/concept/mod.rs index e3e1d450..98dd046d 100644 --- a/src/concept/mod.rs +++ b/src/concept/mod.rs @@ -30,12 +30,8 @@ use std::{ use chrono::NaiveDateTime; use futures::{FutureExt, Stream, StreamExt}; -use typedb_protocol::{ - attribute as attribute_proto, attribute_type as attribute_type_proto, - attribute_type::ValueType, concept as concept_proto, r#type as type_proto, r#type::Encoding, -}; -use crate::common::{error::ClientError, Result}; +use crate::common::{error::ConnectionError, Result}; #[derive(Clone, Debug)] pub enum Concept { @@ -43,41 +39,12 @@ pub enum Concept { Thing(Thing), } -impl Concept { - pub(crate) fn from_proto(mut proto: typedb_protocol::Concept) -> Result { - let concept = proto.concept.ok_or(ClientError::MissingResponseField("concept"))?; - match concept { - concept_proto::Concept::Thing(thing) => Ok(Self::Thing(Thing::from_proto(thing)?)), - concept_proto::Concept::Type(type_) => Ok(Self::Type(Type::from_proto(type_)?)), - } - } -} - #[derive(Clone, Debug)] pub enum Type { Thing(ThingType), Role(RoleType), } -impl Type { - pub(crate) fn from_proto(proto: typedb_protocol::Type) -> Result { - // TODO: replace unwrap() with ok_or(custom_error) throughout the module - match type_proto::Encoding::from_i32(proto.encoding).unwrap() { - Encoding::ThingType => Ok(Self::Thing(ThingType::Root(RootThingType::default()))), - Encoding::EntityType => { - Ok(Self::Thing(ThingType::Entity(EntityType::from_proto(proto)))) - } - Encoding::RelationType => { - Ok(Self::Thing(ThingType::Relation(RelationType::from_proto(proto)))) - } - Encoding::AttributeType => { - Ok(Self::Thing(ThingType::Attribute(AttributeType::from_proto(proto)?))) - } - Encoding::RoleType => Ok(Self::Role(RoleType::from_proto(proto))), - } - } -} - #[derive(Clone, Debug)] pub enum ThingType { Root(RootThingType), @@ -120,10 +87,6 @@ impl EntityType { pub fn new(label: String) -> Self { Self { label } } - - fn from_proto(proto: typedb_protocol::Type) -> Self { - Self::new(proto.label) - } } #[derive(Clone, Debug)] @@ -135,10 +98,6 @@ impl RelationType { pub fn new(label: String) -> Self { Self { label } } - - fn from_proto(proto: typedb_protocol::Type) -> Self { - Self::new(proto.label) - } } #[derive(Clone, Debug)] @@ -151,19 +110,6 @@ pub enum AttributeType { DateTime(DateTimeAttributeType), } -impl AttributeType { - pub(crate) fn from_proto(mut proto: typedb_protocol::Type) -> Result { - match attribute_type_proto::ValueType::from_i32(proto.value_type).unwrap() { - ValueType::Object => Ok(Self::Root(RootAttributeType::default())), - ValueType::Boolean => Ok(Self::Boolean(BooleanAttributeType { label: proto.label })), - ValueType::Long => Ok(Self::Long(LongAttributeType { label: proto.label })), - ValueType::Double => Ok(Self::Double(DoubleAttributeType { label: proto.label })), - 
ValueType::String => Ok(Self::String(StringAttributeType { label: proto.label })), - ValueType::Datetime => Ok(Self::DateTime(DateTimeAttributeType { label: proto.label })), - } - } -} - #[derive(Clone, Debug)] pub struct RootAttributeType { pub label: String, @@ -244,10 +190,6 @@ pub struct RoleType { } impl RoleType { - fn from_proto(proto: typedb_protocol::Type) -> Self { - Self::new(ScopedLabel::new(proto.scope, proto.label)) - } - pub fn new(label: ScopedLabel) -> Self { Self { label } } @@ -261,23 +203,6 @@ pub enum Thing { Attribute(Attribute), } -impl Thing { - pub(crate) fn from_proto(mut proto: typedb_protocol::Thing) -> Result { - match typedb_protocol::r#type::Encoding::from_i32(proto.r#type.clone().unwrap().encoding) - .unwrap() - { - type_proto::Encoding::EntityType => Ok(Self::Entity(Entity::from_proto(proto)?)), - type_proto::Encoding::RelationType => Ok(Self::Relation(Relation::from_proto(proto)?)), - type_proto::Encoding::AttributeType => { - Ok(Self::Attribute(Attribute::from_proto(proto)?)) - } - _ => { - todo!() - } - } - } -} - // impl ConceptApi for Thing {} // impl ThingApi for Thing { @@ -298,12 +223,6 @@ pub struct Entity { pub type_: EntityType, } -impl Entity { - pub(crate) fn from_proto(mut proto: typedb_protocol::Thing) -> Result { - Ok(Self { type_: EntityType::from_proto(proto.r#type.unwrap()), iid: proto.iid }) - } -} - // impl ThingApi for Entity { // // TODO: use enum_dispatch macro to avoid manually writing the duplicates of this method // fn get_iid(&self) -> &Vec { @@ -321,12 +240,6 @@ pub struct Relation { pub type_: RelationType, } -impl Relation { - pub(crate) fn from_proto(mut proto: typedb_protocol::Thing) -> Result { - Ok(Self { type_: RelationType::from_proto(proto.r#type.unwrap()), iid: proto.iid }) - } -} - // macro_rules! default_impl { // { impl $trait:ident $body:tt for $($t:ident),* $(,)? 
} => { // $(impl $trait for $t $body)* @@ -354,66 +267,6 @@ pub enum Attribute { DateTime(DateTimeAttribute), } -impl Attribute { - pub(crate) fn from_proto(mut proto: typedb_protocol::Thing) -> Result { - match attribute_type_proto::ValueType::from_i32(proto.r#type.unwrap().value_type).unwrap() { - ValueType::Object => { - todo!() - } - ValueType::Boolean => Ok(Self::Boolean(BooleanAttribute { - value: if let attribute_proto::value::Value::Boolean(value) = - proto.value.unwrap().value.unwrap() - { - value - } else { - todo!() - }, - iid: proto.iid, - })), - ValueType::Long => Ok(Self::Long(LongAttribute { - value: if let attribute_proto::value::Value::Long(value) = - proto.value.unwrap().value.unwrap() - { - value - } else { - todo!() - }, - iid: proto.iid, - })), - ValueType::Double => Ok(Self::Double(DoubleAttribute { - value: if let attribute_proto::value::Value::Double(value) = - proto.value.unwrap().value.unwrap() - { - value - } else { - todo!() - }, - iid: proto.iid, - })), - ValueType::String => Ok(Self::String(StringAttribute { - value: if let attribute_proto::value::Value::String(value) = - proto.value.unwrap().value.unwrap() - { - value - } else { - todo!() - }, - iid: proto.iid, - })), - ValueType::Datetime => Ok(Self::DateTime(DateTimeAttribute { - value: if let attribute_proto::value::Value::DateTime(value) = - proto.value.unwrap().value.unwrap() - { - NaiveDateTime::from_timestamp_opt(value / 1000, (value % 1000) as u32).unwrap() - } else { - todo!() - }, - iid: proto.iid, - })), - } - } -} - #[derive(Clone, Debug)] pub struct BooleanAttribute { pub iid: Vec, diff --git a/src/connection/cluster/client.rs b/src/connection/cluster/client.rs deleted file mode 100644 index af403b95..00000000 --- a/src/connection/cluster/client.rs +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -use std::sync::Arc; - -use super::{DatabaseManager, Session}; -use crate::common::{ClusterRPC, Credential, Result, SessionType}; - -pub struct Client { - databases: DatabaseManager, - cluster_rpc: Arc, -} - -impl Client { - pub async fn new>(init_addresses: &[T], credential: Credential) -> Result { - let addresses = ClusterRPC::fetch_current_addresses(init_addresses, &credential).await?; - let cluster_rpc = ClusterRPC::new(addresses, credential)?; - let databases = DatabaseManager::new(cluster_rpc.clone()); - Ok(Self { cluster_rpc, databases }) - } - - pub fn databases(&mut self) -> &mut DatabaseManager { - &mut self.databases - } - - pub async fn session( - &mut self, - database_name: &str, - session_type: SessionType, - ) -> Result { - Session::new( - self.databases.get(database_name).await?, - session_type, - self.cluster_rpc.clone(), - ) - .await - } -} diff --git a/src/connection/cluster/database.rs b/src/connection/cluster/database.rs deleted file mode 100644 index 357bdb79..00000000 --- a/src/connection/cluster/database.rs +++ /dev/null @@ -1,247 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -use std::{fmt, fmt::Debug, future::Future, sync::Arc, time::Duration}; - -use log::debug; -use tokio::time::sleep; - -use crate::{ - common::{ - error::ClientError, rpc::builder::cluster::database_manager::get_req, Address, ClusterRPC, - ClusterServerRPC, Error, Result, - }, - connection::server, -}; - -#[derive(Clone)] -pub struct Database { - pub name: String, - replicas: Vec, - cluster_rpc: Arc, -} - -impl Debug for Database { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("cluster::Database") - .field("name", &self.name) - .field("replicas", &self.replicas) - .finish() - } -} - -impl Database { - const PRIMARY_REPLICA_TASK_MAX_RETRIES: usize = 10; - const FETCH_REPLICAS_MAX_RETRIES: usize = 10; - const WAIT_FOR_PRIMARY_REPLICA_SELECTION: Duration = Duration::from_secs(2); - - pub(super) fn new( - proto: typedb_protocol::ClusterDatabase, - cluster_rpc: Arc, - ) -> Result { - let name = proto.name.clone(); - let replicas = Replica::from_proto(proto, &cluster_rpc); - Ok(Self { name, replicas, cluster_rpc }) - } - - pub(super) async fn get(name: &str, cluster_rpc: Arc) -> Result { - Ok(Self { - name: name.to_string(), - replicas: Replica::fetch_all(name, cluster_rpc.clone()).await?, - cluster_rpc, - }) - } - - pub async fn delete(mut self) -> Result { - self.run_on_primary_replica(|database, _, _| database.delete()).await - } - - pub async fn schema(&mut self) -> Result { - self.run_failsafe(|mut database, _, _| async move { database.schema().await }).await - } - - pub async fn type_schema(&mut self) -> Result { - self.run_failsafe(|mut database, _, _| async move { database.type_schema().await }).await - } - - pub async fn rule_schema(&mut self) -> Result { - self.run_failsafe(|mut database, _, _| async move { database.rule_schema().await }).await - } - - pub(crate) async fn run_failsafe(&mut self, task: F) -> Result - where - F: Fn(server::Database, ClusterServerRPC, bool) -> P, - P: Future>, - { - match self.run_on_any_replica(&task).await { - Err(Error::Client(ClientError::ClusterReplicaNotPrimary())) => { - debug!("Attempted to run on a non-primary replica, retrying on primary..."); - self.run_on_primary_replica(&task).await - } - res => res, - } - } - - async fn run_on_any_replica(&mut self, task: F) -> Result - where - F: Fn(server::Database, ClusterServerRPC, bool) -> P, - P: Future>, - { - let mut is_first_run = true; - for replica in self.replicas.iter() { - match task( - replica.database.clone(), - self.cluster_rpc.get_server_rpc(&replica.address), - is_first_run, - ) - .await - { - Err(Error::Client(ClientError::UnableToConnect())) => { - println!("Unable to connect to {}. Attempting next server.", replica.address); - } - res => return res, - } - is_first_run = false; - } - Err(self.cluster_rpc.unable_to_connect()) - } - - async fn run_on_primary_replica(&mut self, task: F) -> Result - where - F: Fn(server::Database, ClusterServerRPC, bool) -> P, - P: Future>, - { - let mut primary_replica = if let Some(replica) = self.primary_replica() { - replica - } else { - self.seek_primary_replica().await? 
- }; - - for retry in 0..Self::PRIMARY_REPLICA_TASK_MAX_RETRIES { - match task( - primary_replica.database.clone(), - self.cluster_rpc.get_server_rpc(&primary_replica.address), - retry == 0, - ) - .await - { - Err(Error::Client( - ClientError::ClusterReplicaNotPrimary() | ClientError::UnableToConnect(), - )) => { - debug!("Primary replica error, waiting..."); - Self::wait_for_primary_replica_selection().await; - primary_replica = self.seek_primary_replica().await?; - } - res => return res, - } - } - Err(self.cluster_rpc.unable_to_connect()) - } - - async fn seek_primary_replica(&mut self) -> Result { - for _ in 0..Self::FETCH_REPLICAS_MAX_RETRIES { - self.replicas = Replica::fetch_all(&self.name, self.cluster_rpc.clone()).await?; - if let Some(replica) = self.primary_replica() { - return Ok(replica); - } - Self::wait_for_primary_replica_selection().await; - } - Err(self.cluster_rpc.unable_to_connect()) - } - - fn primary_replica(&mut self) -> Option { - self.replicas.iter().filter(|r| r.is_primary).max_by_key(|r| r.term).cloned() - } - - async fn wait_for_primary_replica_selection() { - sleep(Self::WAIT_FOR_PRIMARY_REPLICA_SELECTION).await; - } -} - -#[derive(Clone)] -pub struct Replica { - address: Address, - database_name: String, - is_primary: bool, - term: i64, - is_preferred: bool, - database: server::Database, -} - -impl Debug for Replica { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Replica") - .field("address", &self.address) - .field("database_name", &self.database_name) - .field("is_primary", &self.is_primary) - .field("term", &self.term) - .field("is_preferred", &self.is_preferred) - .finish() - } -} - -impl Replica { - fn new( - name: &str, - metadata: typedb_protocol::cluster_database::Replica, - server_rpc: ClusterServerRPC, - ) -> Self { - Self { - address: metadata.address.parse().expect("Invalid URI received from the server"), - database_name: name.to_owned(), - is_primary: metadata.primary, - term: metadata.term, - is_preferred: metadata.preferred, - database: server::Database::new(name, server_rpc.into()), - } - } - - fn from_proto(proto: typedb_protocol::ClusterDatabase, cluster_rpc: &ClusterRPC) -> Vec { - proto - .replicas - .into_iter() - .map(|replica| { - let server_rpc = cluster_rpc.get_server_rpc(&replica.address.parse().unwrap()); - Replica::new(&proto.name, replica, server_rpc) - }) - .collect() - } - - async fn fetch_all(name: &str, cluster_rpc: Arc) -> Result> { - for mut rpc in cluster_rpc.iter_server_rpcs_cloned() { - let res = rpc.databases_get(get_req(name)).await; - match res { - Ok(res) => { - return Ok(Replica::from_proto(res.database.unwrap(), &cluster_rpc)); - } - Err(Error::Client(ClientError::UnableToConnect())) => { - println!( - "Failed to fetch replica info for database '{}' from {}. Attempting next server.", - name, - rpc.address() - ); - } - Err(err) => return Err(err), - } - } - Err(cluster_rpc.unable_to_connect()) - } -} diff --git a/src/connection/cluster/database_manager.rs b/src/connection/cluster/database_manager.rs deleted file mode 100644 index 48fd1190..00000000 --- a/src/connection/cluster/database_manager.rs +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{fmt::Debug, future::Future, sync::Arc}; - -use super::Database; -use crate::{ - common::{ - error::ClientError, - rpc::builder::{ - cluster::database_manager::all_req, - core::database_manager::{contains_req, create_req}, - }, - ClusterRPC, ClusterServerRPC, Result, - }, - connection::server, -}; - -#[derive(Clone, Debug)] -pub struct DatabaseManager { - cluster_rpc: Arc, -} - -impl DatabaseManager { - pub(crate) fn new(cluster_rpc: Arc) -> Self { - Self { cluster_rpc } - } - - pub async fn get(&mut self, name: &str) -> Result { - Database::get(name, self.cluster_rpc.clone()).await - } - - pub async fn contains(&mut self, name: &str) -> Result { - Ok(self - .run_failsafe(name, move |database, mut server_rpc, _| { - let req = contains_req(&database.name); - async move { server_rpc.databases_contains(req).await } - }) - .await? - .contains) - } - - pub async fn create(&mut self, name: &str) -> Result { - self.run_failsafe(name, |database, mut server_rpc, _| { - let req = create_req(&database.name); - async move { server_rpc.databases_create(req).await } - }) - .await?; - Ok(()) - } - - pub async fn all(&mut self) -> Result> { - let mut error_buffer = Vec::with_capacity(self.cluster_rpc.server_rpc_count()); - for mut server_rpc in self.cluster_rpc.iter_server_rpcs_cloned() { - match server_rpc.databases_all(all_req()).await { - Ok(list) => { - return list - .databases - .into_iter() - .map(|proto_db| Database::new(proto_db, self.cluster_rpc.clone())) - .collect() - } - Err(err) => error_buffer.push(format!("- {}: {}", server_rpc.address(), err)), - } - } - Err(ClientError::ClusterAllNodesFailed(error_buffer.join("\n")))? - } - - async fn run_failsafe(&mut self, name: &str, task: F) -> Result - where - F: Fn(server::Database, ClusterServerRPC, bool) -> P, - P: Future>, - { - Database::get(name, self.cluster_rpc.clone()).await?.run_failsafe(&task).await - } -} diff --git a/src/connection/cluster/session.rs b/src/connection/cluster/session.rs deleted file mode 100644 index 12bc7467..00000000 --- a/src/connection/cluster/session.rs +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -use std::sync::Arc; - -use super::Database; -use crate::{ - common::{ClusterRPC, Result, SessionType, TransactionType}, - connection::{core, server, server::Transaction}, -}; - -pub struct Session { - pub database: Database, - pub session_type: SessionType, - - server_session: server::Session, - cluster_rpc: Arc, -} - -impl Session { - // TODO options - pub(crate) async fn new( - mut database: Database, - session_type: SessionType, - cluster_rpc: Arc, - ) -> Result { - let server_session = database - .run_failsafe(|database, server_rpc, _| async { - let database_name = database.name; - server::Session::new( - database_name.as_str(), - session_type, - core::Options::default(), - server_rpc.into(), - ) - .await - }) - .await?; - - Ok(Self { database, session_type, server_session, cluster_rpc }) - } - - //TODO options - pub async fn transaction(&mut self, transaction_type: TransactionType) -> Result { - let (session, transaction) = self - .database - .run_failsafe(|database, server_rpc, is_first_run| { - let session_type = self.session_type; - let session = &self.server_session; - async move { - if is_first_run { - let transaction = session.transaction(transaction_type).await?; - Ok((None, transaction)) - } else { - let server_session = server::Session::new( - database.name.as_str(), - session_type, - core::Options::default(), - server_rpc.into(), - ) - .await?; - let transaction = server_session.transaction(transaction_type).await?; - Ok((Some(server_session), transaction)) - } - } - }) - .await?; - - if let Some(session) = session { - self.server_session = session; - } - - Ok(transaction) - } -} diff --git a/src/connection/connection.rs b/src/connection/connection.rs new file mode 100644 index 00000000..1d4d176d --- /dev/null +++ b/src/connection/connection.rs @@ -0,0 +1,332 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use std::{ + collections::{HashMap, HashSet}, + fmt, + sync::{Arc, Mutex}, + time::Duration, +}; + +use itertools::Itertools; +use tokio::{ + select, + sync::mpsc::{unbounded_channel as unbounded_async, UnboundedReceiver, UnboundedSender}, + time::{sleep_until, Instant}, +}; + +use super::{ + network::transmitter::{RPCTransmitter, TransactionTransmitter}, + runtime::BackgroundRuntime, + TransactionStream, +}; +use crate::{ + common::{ + address::Address, + error::{ConnectionError, Error}, + info::{DatabaseInfo, SessionInfo}, + Result, SessionID, SessionType, TransactionType, + }, + connection::message::{Request, Response, TransactionRequest}, + error::InternalError, + Credential, Options, +}; + +#[derive(Clone)] +pub struct Connection { + server_connections: HashMap, + background_runtime: Arc, +} + +impl Connection { + pub fn new_plaintext(address: impl AsRef) -> Result { + let address: Address = address.as_ref().parse()?; + let background_runtime = Arc::new(BackgroundRuntime::new()?); + let server_connection = ServerConnection::new_plaintext(background_runtime.clone(), address.clone())?; + Ok(Self { server_connections: [(address, server_connection)].into(), background_runtime }) + } + + pub fn new_encrypted + Sync>(init_addresses: &[T], credential: Credential) -> Result { + let background_runtime = Arc::new(BackgroundRuntime::new()?); + + let init_addresses = init_addresses.iter().map(|addr| addr.as_ref().parse()).try_collect()?; + let addresses = Self::fetch_current_addresses(background_runtime.clone(), init_addresses, credential.clone())?; + + let mut server_connections = HashMap::with_capacity(addresses.len()); + for address in addresses { + let server_connection = + ServerConnection::new_encrypted(background_runtime.clone(), address.clone(), credential.clone())?; + server_connections.insert(address, server_connection); + } + + Ok(Self { server_connections, background_runtime }) + } + + fn fetch_current_addresses( + background_runtime: Arc, + addresses: Vec
, + credential: Credential, + ) -> Result> { + for address in addresses { + let server_connection = + ServerConnection::new_encrypted(background_runtime.clone(), address.clone(), credential.clone())?; + match server_connection.servers_all() { + Ok(servers) => return Ok(servers.into_iter().collect()), + Err(Error::Connection(ConnectionError::UnableToConnect())) => (), + Err(err) => Err(err)?, + } + } + Err(ConnectionError::UnableToConnect())? + } + + pub fn force_close(self) -> Result { + self.server_connections.values().map(ServerConnection::force_close).try_collect()?; + self.background_runtime.force_close() + } + + pub(crate) fn server_count(&self) -> usize { + self.server_connections.len() + } + + pub(crate) fn addresses(&self) -> impl Iterator { + self.server_connections.keys() + } + + pub(crate) fn connection(&self, address: &Address) -> Result<&ServerConnection> { + self.server_connections + .get(address) + .ok_or_else(|| InternalError::UnknownConnectionAddress(address.to_string()).into()) + } + + pub(crate) fn connections(&self) -> impl Iterator + '_ { + self.server_connections.values() + } + + pub(crate) fn unable_to_connect_error(&self) -> Error { + Error::Connection(ConnectionError::ClusterUnableToConnect( + self.addresses().map(Address::to_string).collect::>().join(","), + )) + } +} + +impl fmt::Debug for Connection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Connection").field("server_connections", &self.server_connections).finish() + } +} + +#[derive(Clone)] +pub(crate) struct ServerConnection { + address: Address, + background_runtime: Arc, + open_sessions: Arc>>>, + request_transmitter: Arc, +} + +impl ServerConnection { + fn new_plaintext(background_runtime: Arc, address: Address) -> Result { + let request_transmitter = Arc::new(RPCTransmitter::start_plaintext(address.clone(), &background_runtime)?); + Ok(Self { address, background_runtime, open_sessions: Default::default(), request_transmitter }) + } + + fn new_encrypted( + background_runtime: Arc, + address: Address, + credential: Credential, + ) -> Result { + let request_transmitter = + Arc::new(RPCTransmitter::start_encrypted(address.clone(), credential, &background_runtime)?); + Ok(Self { address, background_runtime, open_sessions: Default::default(), request_transmitter }) + } + + pub(crate) fn address(&self) -> &Address { + &self.address + } + + async fn request_async(&self, request: Request) -> Result { + if !self.background_runtime.is_open() { + return Err(ConnectionError::ConnectionIsClosed().into()); + } + self.request_transmitter.request_async(request).await + } + + fn request_blocking(&self, request: Request) -> Result { + if !self.background_runtime.is_open() { + return Err(ConnectionError::ConnectionIsClosed().into()); + } + self.request_transmitter.request_blocking(request) + } + + pub(crate) fn force_close(&self) -> Result { + let session_ids: Vec = self.open_sessions.lock().unwrap().keys().cloned().collect(); + for session_id in session_ids.into_iter() { + self.close_session(session_id).ok(); + } + self.request_transmitter.force_close() + } + + pub(crate) fn servers_all(&self) -> Result> { + match self.request_blocking(Request::ServersAll)? { + Response::ServersAll { servers } => Ok(servers), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn database_exists(&self, database_name: String) -> Result { + match self.request_async(Request::DatabasesContains { database_name }).await? 
{ + Response::DatabasesContains { contains } => Ok(contains), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn create_database(&self, database_name: String) -> Result { + self.request_async(Request::DatabaseCreate { database_name }).await?; + Ok(()) + } + + pub(crate) async fn get_database_replicas(&self, database_name: String) -> Result { + match self.request_async(Request::DatabaseGet { database_name }).await? { + Response::DatabaseGet { database } => Ok(database), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn all_databases(&self) -> Result> { + match self.request_async(Request::DatabasesAll).await? { + Response::DatabasesAll { databases } => Ok(databases), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn database_schema(&self, database_name: String) -> Result { + match self.request_async(Request::DatabaseSchema { database_name }).await? { + Response::DatabaseSchema { schema } => Ok(schema), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn database_type_schema(&self, database_name: String) -> Result { + match self.request_async(Request::DatabaseTypeSchema { database_name }).await? { + Response::DatabaseTypeSchema { schema } => Ok(schema), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn database_rule_schema(&self, database_name: String) -> Result { + match self.request_async(Request::DatabaseRuleSchema { database_name }).await? { + Response::DatabaseRuleSchema { schema } => Ok(schema), + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) async fn delete_database(&self, database_name: String) -> Result { + self.request_async(Request::DatabaseDelete { database_name }).await?; + Ok(()) + } + + pub(crate) async fn open_session( + &self, + database_name: String, + session_type: SessionType, + options: Options, + ) -> Result { + let start = Instant::now(); + match self.request_async(Request::SessionOpen { database_name, session_type, options }).await? { + Response::SessionOpen { session_id, server_duration } => { + let (pulse_shutdown_sink, pulse_shutdown_source) = unbounded_async(); + self.open_sessions.lock().unwrap().insert(session_id.clone(), pulse_shutdown_sink); + self.background_runtime.spawn(session_pulse( + session_id.clone(), + self.request_transmitter.clone(), + pulse_shutdown_source, + )); + Ok(SessionInfo { + address: self.address.clone(), + session_id, + network_latency: start.elapsed() - server_duration, + }) + } + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } + + pub(crate) fn close_session(&self, session_id: SessionID) -> Result { + if let Some(sink) = self.open_sessions.lock().unwrap().remove(&session_id) { + sink.send(()).ok(); + } + self.request_blocking(Request::SessionClose { session_id })?; + Ok(()) + } + + pub(crate) async fn open_transaction( + &self, + session_id: SessionID, + transaction_type: TransactionType, + options: Options, + network_latency: Duration, + ) -> Result { + match self + .request_async(Request::Transaction(TransactionRequest::Open { + session_id, + transaction_type, + options: options.clone(), + network_latency, + })) + .await? 
+ { + Response::TransactionOpen { request_sink, response_source } => { + let transmitter = TransactionTransmitter::new(&self.background_runtime, request_sink, response_source); + Ok(TransactionStream::new(transaction_type, options, transmitter)) + } + other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()), + } + } +} + +impl fmt::Debug for ServerConnection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ServerConnection") + .field("address", &self.address) + .field("open_sessions", &self.open_sessions) + .finish() + } +} + +async fn session_pulse( + session_id: SessionID, + request_transmitter: Arc, + mut shutdown_source: UnboundedReceiver<()>, +) { + const PULSE_INTERVAL: Duration = Duration::from_secs(5); + let mut next_pulse = Instant::now(); + loop { + select! { + _ = sleep_until(next_pulse) => { + request_transmitter + .request_async(Request::SessionPulse { session_id: session_id.clone() }) + .await + .ok(); + next_pulse += PULSE_INTERVAL; + } + _ = shutdown_source.recv() => break, + } + } +} diff --git a/src/connection/core/client.rs b/src/connection/core/client.rs deleted file mode 100644 index aa613a93..00000000 --- a/src/connection/core/client.rs +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::{ - common::{CoreRPC, Result, SessionType}, - connection::{core, server}, -}; - -pub struct Client { - databases: core::DatabaseManager, - core_rpc: CoreRPC, -} - -impl Client { - pub async fn new(address: &str) -> Result { - let core_rpc = CoreRPC::connect(address.parse()?).await?; - Ok(Self { databases: core::DatabaseManager::new(core_rpc.clone()), core_rpc }) - } - - pub async fn with_default_address() -> Result { - Self::new("http://localhost:1729").await - } - - pub fn databases(&mut self) -> &mut core::DatabaseManager { - &mut self.databases - } - - pub async fn session( - &mut self, - database_name: &str, - session_type: SessionType, - ) -> Result { - self.session_with_options(database_name, session_type, core::Options::default()).await - } - - pub async fn session_with_options( - &mut self, - database_name: &str, - session_type: SessionType, - options: core::Options, - ) -> Result { - server::Session::new(database_name, session_type, options, self.core_rpc.clone().into()) - .await - } -} diff --git a/src/connection/core/database_manager.rs b/src/connection/core/database_manager.rs deleted file mode 100644 index 1d28d901..00000000 --- a/src/connection/core/database_manager.rs +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::{ - common::{ - error::ClientError, - rpc::builder::core::database_manager::{all_req, contains_req, create_req}, - CoreRPC, Result, - }, - connection::server, -}; - -/// An interface for performing database-level operations against the connected server. -/// These operations include: -/// -/// - Listing [all databases][DatabaseManager::all] -/// - Creating a [new database][DatabaseManager::create] -/// - Checking if a database [exists][DatabaseManager::contains] -/// - Retrieving a [specific database][DatabaseManager::get] in order to perform further operations on it -/// -/// These operations all connect to the server to retrieve results. In the event of a connection -/// failure or other problem executing the operation, they will return an [`Err`][Err] result. -#[derive(Clone, Debug)] -pub struct DatabaseManager { - pub(crate) core_rpc: CoreRPC, -} - -impl DatabaseManager { - pub(crate) fn new(core_rpc: CoreRPC) -> Self { - DatabaseManager { core_rpc } - } - - /// Retrieves a single [`Database`][Database] by name. Returns an [`Err`][Err] if there does not - /// exist a database with the provided name. - pub async fn get(&mut self, name: &str) -> Result { - match self.contains(name).await? { - true => Ok(server::Database::new(name, self.core_rpc.clone().into())), - false => Err(ClientError::DatabaseDoesNotExist(name.to_string()))?, - } - } - - pub async fn contains(&mut self, name: &str) -> Result { - self.core_rpc.databases_contains(contains_req(name)).await.map(|res| res.contains) - } - - pub async fn create(&mut self, name: &str) -> Result { - self.core_rpc.databases_create(create_req(name)).await.map(|_| ()) - } - - pub async fn all(&mut self) -> Result> { - self.core_rpc.databases_all(all_req()).await.map(|res| { - res.names - .iter() - .map(|name| server::Database::new(name, self.core_rpc.clone().into())) - .collect() - }) - } -} diff --git a/src/connection/core/options.rs b/src/connection/core/options.rs deleted file mode 100644 index 778d69c4..00000000 --- a/src/connection/core/options.rs +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::time::Duration; - -use typedb_protocol::{ - options::{ - ExplainOpt::Explain, InferOpt::Infer, ParallelOpt::Parallel, PrefetchOpt::Prefetch, - PrefetchSizeOpt::PrefetchSize, ReadAnyReplicaOpt::ReadAnyReplica, - SchemaLockAcquireTimeoutOpt::SchemaLockAcquireTimeoutMillis, - SessionIdleTimeoutOpt::SessionIdleTimeoutMillis, TraceInferenceOpt::TraceInference, - TransactionTimeoutOpt::TransactionTimeoutMillis, - }, - Options as OptionsProto, -}; - -macro_rules! options { - {pub struct $name:ident { $(pub $field_name:ident : Option<$field_type:ty>),* $(,)? }} => { - #[derive(Clone, Debug, Default)] - pub struct $name { - $(pub $field_name: Option<$field_type>,)* - } - - impl $name { - $( - pub fn $field_name(mut self, value: $field_type) -> Self { - self.$field_name = value.into(); - self - } - )* - } - }; -} - -options! { - pub struct Options { - pub infer: Option, - pub trace_inference: Option, - pub explain: Option, - pub parallel: Option, - pub prefetch: Option, - pub prefetch_size: Option, - pub session_idle_timeout: Option, - pub transaction_timeout: Option, - pub schema_lock_acquire_timeout: Option, - } -} - -options! { - pub struct ClusterOptions { - pub infer: Option, - pub trace_inference: Option, - pub explain: Option, - pub parallel: Option, - pub prefetch: Option, - pub prefetch_size: Option, - pub session_idle_timeout: Option, - pub transaction_timeout: Option, - pub schema_lock_acquire_timeout: Option, - pub read_any_replica: Option, - } -} - -impl Options { - pub fn new_core() -> Options { - Options::default() - } - - pub fn new_cluster() -> ClusterOptions { - ClusterOptions::default() - } - - pub(crate) fn to_proto(&self) -> OptionsProto { - OptionsProto { - infer_opt: self.infer.map(Infer), - trace_inference_opt: self.trace_inference.map(TraceInference), - explain_opt: self.explain.map(Explain), - parallel_opt: self.parallel.map(Parallel), - prefetch_size_opt: self.prefetch_size.map(PrefetchSize), - prefetch_opt: self.prefetch.map(Prefetch), - session_idle_timeout_opt: self - .session_idle_timeout - .map(|val| SessionIdleTimeoutMillis(val.as_millis() as i32)), - transaction_timeout_opt: self - .transaction_timeout - .map(|val| TransactionTimeoutMillis(val.as_millis() as i32)), - schema_lock_acquire_timeout_opt: self - .schema_lock_acquire_timeout - .map(|val| SchemaLockAcquireTimeoutMillis(val.as_millis() as i32)), - read_any_replica_opt: None, - } - } -} - -impl ClusterOptions { - pub(crate) fn to_proto(&self) -> OptionsProto { - OptionsProto { - infer_opt: self.infer.map(Infer), - trace_inference_opt: self.trace_inference.map(TraceInference), - explain_opt: self.explain.map(Explain), - parallel_opt: self.parallel.map(Parallel), - prefetch_size_opt: self.prefetch_size.map(PrefetchSize), - prefetch_opt: self.prefetch.map(Prefetch), - session_idle_timeout_opt: self - .session_idle_timeout - .map(|val| SessionIdleTimeoutMillis(val.as_millis() as i32)), - transaction_timeout_opt: self - .transaction_timeout - .map(|val| TransactionTimeoutMillis(val.as_millis() as i32)), - schema_lock_acquire_timeout_opt: self - .schema_lock_acquire_timeout - .map(|val| SchemaLockAcquireTimeoutMillis(val.as_millis() as i32)), - read_any_replica_opt: self.read_any_replica.map(ReadAnyReplica), - } - } -} diff --git a/src/connection/message.rs b/src/connection/message.rs new file mode 100644 index 00000000..9647fb13 --- /dev/null +++ b/src/connection/message.rs @@ -0,0 +1,147 @@ 
+/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::time::Duration; + +use tokio::sync::mpsc::UnboundedSender; +use tonic::Streaming; +use typedb_protocol::transaction; + +use crate::{ + answer::{ConceptMap, Numeric}, + common::{address::Address, info::DatabaseInfo, RequestID, SessionID}, + Options, SessionType, TransactionType, +}; + +#[derive(Debug)] +pub(super) enum Request { + ServersAll, + + DatabasesContains { database_name: String }, + DatabaseCreate { database_name: String }, + DatabaseGet { database_name: String }, + DatabasesAll, + + DatabaseSchema { database_name: String }, + DatabaseTypeSchema { database_name: String }, + DatabaseRuleSchema { database_name: String }, + DatabaseDelete { database_name: String }, + + SessionOpen { database_name: String, session_type: SessionType, options: Options }, + SessionClose { session_id: SessionID }, + SessionPulse { session_id: SessionID }, + + Transaction(TransactionRequest), +} + +#[derive(Debug)] +pub(super) enum Response { + ServersAll { + servers: Vec
, + }, + + DatabasesContains { + contains: bool, + }, + DatabaseCreate, + DatabaseGet { + database: DatabaseInfo, + }, + DatabasesAll { + databases: Vec, + }, + + DatabaseDelete, + DatabaseSchema { + schema: String, + }, + DatabaseTypeSchema { + schema: String, + }, + DatabaseRuleSchema { + schema: String, + }, + + SessionOpen { + session_id: SessionID, + server_duration: Duration, + }, + SessionPulse, + SessionClose, + + TransactionOpen { + request_sink: UnboundedSender, + response_source: Streaming, + }, +} + +#[derive(Debug)] +pub(super) enum TransactionRequest { + Open { session_id: SessionID, transaction_type: TransactionType, options: Options, network_latency: Duration }, + Commit, + Rollback, + Query(QueryRequest), + Stream { request_id: RequestID }, +} + +#[derive(Debug)] +pub(super) enum TransactionResponse { + Open, + Commit, + Rollback, + Query(QueryResponse), +} + +#[derive(Debug)] +pub(super) enum QueryRequest { + Define { query: String, options: Options }, + Undefine { query: String, options: Options }, + Delete { query: String, options: Options }, + + Match { query: String, options: Options }, + Insert { query: String, options: Options }, + Update { query: String, options: Options }, + + MatchAggregate { query: String, options: Options }, + + Explain { explainable_id: i64, options: Options }, // TODO: ID type + + MatchGroup { query: String, options: Options }, + MatchGroupAggregate { query: String, options: Options }, +} + +#[derive(Debug)] +pub(super) enum QueryResponse { + Define, + Undefine, + Delete, + + Match { answers: Vec }, + Insert { answers: Vec }, + Update { answers: Vec }, + + MatchAggregate { answer: Numeric }, + + Explain {}, // TODO: explanations + + MatchGroup {}, // TODO: ConceptMapGroup + MatchGroupAggregate {}, // TODO: NumericGroup +} diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 671b1c06..6b96ace1 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -19,6 +19,11 @@ * under the License. */ -pub mod cluster; -pub mod core; -pub mod server; +mod connection; +mod message; +mod network; +mod runtime; +mod transaction_stream; + +pub use self::connection::Connection; +pub(crate) use self::{connection::ServerConnection, transaction_stream::TransactionStream}; diff --git a/src/connection/network/channel.rs b/src/connection/network/channel.rs new file mode 100644 index 00000000..e75171b7 --- /dev/null +++ b/src/connection/network/channel.rs @@ -0,0 +1,137 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
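// Illustrative sketch, not part of the patch (crate-internal types): the
// Request/Response enums in connection/message.rs are the crate's native,
// protocol-agnostic messages. User-facing code builds a Request, and the
// network layer answers with the matching Response variant; DatabaseInfo
// exposes a `name` field, as used here.
fn database_names(response: Response) -> Option<Vec<String>> {
    match response {
        Response::DatabasesAll { databases } => {
            Some(databases.into_iter().map(|database| database.name).collect())
        }
        _ => None,
    }
}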
+ */ + +use std::sync::{Arc, RwLock}; + +use tonic::{ + body::BoxBody, + client::GrpcService, + service::{ + interceptor::{self, InterceptedService}, + Interceptor, + }, + transport::{channel, Channel, Error as TonicError}, + Request, Status, +}; + +use crate::{ + common::{address::Address, Result, StdResult}, + Credential, +}; + +type ResponseFuture = interceptor::ResponseFuture; + +pub(super) type PlainTextChannel = InterceptedService; +pub(super) type CallCredChannel = InterceptedService; + +pub(super) trait GRPCChannel: + GrpcService + Clone + Send + 'static +{ + fn is_plaintext(&self) -> bool; +} + +impl GRPCChannel for PlainTextChannel { + fn is_plaintext(&self) -> bool { + true + } +} + +impl GRPCChannel for CallCredChannel { + fn is_plaintext(&self) -> bool { + false + } +} + +pub(super) fn open_plaintext_channel(address: Address) -> PlainTextChannel { + PlainTextChannel::new(Channel::builder(address.into_uri()).connect_lazy(), PlainTextFacade) +} + +#[derive(Clone, Debug)] +pub(super) struct PlainTextFacade; + +impl Interceptor for PlainTextFacade { + fn call(&mut self, request: Request<()>) -> StdResult, Status> { + Ok(request) + } +} + +pub(super) fn open_encrypted_channel( + address: Address, + credential: Credential, +) -> Result<(CallCredChannel, Arc)> { + let mut builder = Channel::builder(address.into_uri()); + if credential.is_tls_enabled() { + builder = builder.tls_config(credential.tls_config().clone().unwrap())?; + } + let channel = builder.connect_lazy(); + let call_credentials = Arc::new(CallCredentials::new(credential)); + Ok((CallCredChannel::new(channel, CredentialInjector::new(call_credentials.clone())), call_credentials)) +} + +#[derive(Debug)] +pub(super) struct CallCredentials { + credential: Credential, + token: RwLock>, +} + +impl CallCredentials { + pub(super) fn new(credential: Credential) -> Self { + Self { credential, token: RwLock::new(None) } + } + + pub(super) fn username(&self) -> &str { + self.credential.username() + } + + pub(super) fn set_token(&self, token: String) { + *self.token.write().unwrap() = Some(token); + } + + pub(super) fn reset_token(&self) { + *self.token.write().unwrap() = None; + } + + pub(super) fn inject(&self, mut request: Request<()>) -> Request<()> { + request.metadata_mut().insert("username", self.credential.username().try_into().unwrap()); + match &*self.token.read().unwrap() { + Some(token) => request.metadata_mut().insert("token", token.try_into().unwrap()), + None => request.metadata_mut().insert("password", self.credential.password().try_into().unwrap()), + }; + request + } +} + +#[derive(Clone, Debug)] +pub(super) struct CredentialInjector { + call_credentials: Arc, +} + +impl CredentialInjector { + pub(super) fn new(call_credentials: Arc) -> Self { + Self { call_credentials } + } +} + +impl Interceptor for CredentialInjector { + fn call(&mut self, request: Request<()>) -> StdResult, Status> { + Ok(self.call_credentials.inject(request)) + } +} diff --git a/src/connection/core/mod.rs b/src/connection/network/mod.rs similarity index 86% rename from src/connection/core/mod.rs rename to src/connection/network/mod.rs index 4bbd6efe..6bd85c52 100644 --- a/src/connection/core/mod.rs +++ b/src/connection/network/mod.rs @@ -19,8 +19,7 @@ * under the License. 
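// Illustrative sketch, not part of the patch: what CredentialInjector's
// `inject` above attaches to every outgoing request. Before a token has been
// obtained (via the cluster's user_token RPC) the username/password pair is
// sent; afterwards the cached token replaces the password. `try_into`
// converts the &str values into ASCII gRPC metadata values.
use tonic::Request;

fn inject_example(
    mut request: Request<()>,
    username: &str,
    password: &str,
    token: Option<&str>,
) -> Request<()> {
    request.metadata_mut().insert("username", username.try_into().unwrap());
    match token {
        Some(token) => request.metadata_mut().insert("token", token.try_into().unwrap()),
        None => request.metadata_mut().insert("password", password.try_into().unwrap()),
    };
    request
}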
*/ -mod client; -mod database_manager; -mod options; - -pub use self::{client::Client, database_manager::DatabaseManager, options::Options}; +mod channel; +mod proto; +mod stub; +pub(super) mod transmitter; diff --git a/src/connection/network/proto/common.rs b/src/connection/network/proto/common.rs new file mode 100644 index 00000000..99c69a24 --- /dev/null +++ b/src/connection/network/proto/common.rs @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use typedb_protocol::{ + options::{ + ExplainOpt::Explain, InferOpt::Infer, ParallelOpt::Parallel, PrefetchOpt::Prefetch, + PrefetchSizeOpt::PrefetchSize, ReadAnyReplicaOpt::ReadAnyReplica, + SchemaLockAcquireTimeoutOpt::SchemaLockAcquireTimeoutMillis, SessionIdleTimeoutOpt::SessionIdleTimeoutMillis, + TraceInferenceOpt::TraceInference, TransactionTimeoutOpt::TransactionTimeoutMillis, + }, + session, transaction, Options as OptionsProto, +}; + +use super::IntoProto; +use crate::{Options, SessionType, TransactionType}; + +impl IntoProto for SessionType { + fn into_proto(self) -> session::Type { + match self { + SessionType::Data => session::Type::Data, + SessionType::Schema => session::Type::Schema, + } + } +} + +impl IntoProto for TransactionType { + fn into_proto(self) -> transaction::Type { + match self { + TransactionType::Read => transaction::Type::Read, + TransactionType::Write => transaction::Type::Write, + } + } +} + +impl IntoProto for Options { + fn into_proto(self) -> OptionsProto { + OptionsProto { + infer_opt: self.infer.map(Infer), + trace_inference_opt: self.trace_inference.map(TraceInference), + explain_opt: self.explain.map(Explain), + parallel_opt: self.parallel.map(Parallel), + prefetch_size_opt: self.prefetch_size.map(PrefetchSize), + prefetch_opt: self.prefetch.map(Prefetch), + session_idle_timeout_opt: self + .session_idle_timeout + .map(|val| SessionIdleTimeoutMillis(val.as_millis() as i32)), + transaction_timeout_opt: self + .transaction_timeout + .map(|val| TransactionTimeoutMillis(val.as_millis() as i32)), + schema_lock_acquire_timeout_opt: self + .schema_lock_acquire_timeout + .map(|val| SchemaLockAcquireTimeoutMillis(val.as_millis() as i32)), + read_any_replica_opt: self.read_any_replica.map(ReadAnyReplica), + } + } +} diff --git a/src/connection/network/proto/concept.rs b/src/connection/network/proto/concept.rs new file mode 100644 index 00000000..495f2529 --- /dev/null +++ b/src/connection/network/proto/concept.rs @@ -0,0 +1,196 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
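// Illustrative sketch, not part of the patch: all timeout options cross the
// wire as whole milliseconds, as in the IntoProto impl for Options above.
use std::time::Duration;

fn timeout_millis(timeout: Option<Duration>) -> Option<i32> {
    // e.g. Some(Duration::from_secs(30)) -> Some(30_000)
    timeout.map(|val| val.as_millis() as i32)
}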
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::collections::HashMap; + +use chrono::NaiveDateTime; +use typedb_protocol::{ + attribute::value::Value as ValueProto, attribute_type::ValueType, concept as concept_proto, numeric::Value, + r#type::Encoding, Concept as ConceptProto, ConceptMap as ConceptMapProto, Numeric as NumericProto, + Thing as ThingProto, Type as TypeProto, +}; + +use super::TryFromProto; +use crate::{ + answer::{ConceptMap, Numeric}, + concept::{ + Attribute, AttributeType, BooleanAttribute, BooleanAttributeType, Concept, DateTimeAttribute, + DateTimeAttributeType, DoubleAttribute, DoubleAttributeType, Entity, EntityType, LongAttribute, + LongAttributeType, Relation, RelationType, RoleType, RootAttributeType, RootThingType, ScopedLabel, + StringAttribute, StringAttributeType, Thing, ThingType, Type, + }, + connection::network::proto::FromProto, + error::{ConnectionError, InternalError}, + Result, +}; + +impl TryFromProto for Numeric { + fn try_from_proto(proto: NumericProto) -> Result { + match proto.value { + Some(Value::LongValue(long)) => Ok(Numeric::Long(long)), + Some(Value::DoubleValue(double)) => Ok(Numeric::Double(double)), + Some(Value::Nan(_)) => Ok(Numeric::NaN), + None => Err(ConnectionError::MissingResponseField("value").into()), + } + } +} + +impl TryFromProto for ConceptMap { + fn try_from_proto(proto: ConceptMapProto) -> Result { + let mut map = HashMap::with_capacity(proto.map.len()); + for (k, v) in proto.map { + map.insert(k, Concept::try_from_proto(v)?); + } + Ok(Self { map }) + } +} + +impl TryFromProto for Concept { + fn try_from_proto(proto: ConceptProto) -> Result { + let concept = proto.concept.ok_or(ConnectionError::MissingResponseField("concept"))?; + match concept { + concept_proto::Concept::Thing(thing) => Ok(Self::Thing(Thing::try_from_proto(thing)?)), + concept_proto::Concept::Type(type_) => Ok(Self::Type(Type::try_from_proto(type_)?)), + } + } +} + +impl TryFromProto for Encoding { + fn try_from_proto(proto: i32) -> Result { + Self::from_i32(proto).ok_or(InternalError::EnumOutOfBounds(proto, "Encoding").into()) + } +} + +impl TryFromProto for Type { + fn try_from_proto(proto: TypeProto) -> Result { + match Encoding::try_from_proto(proto.encoding)? 
{ + Encoding::ThingType => Ok(Self::Thing(ThingType::Root(RootThingType::default()))), + Encoding::EntityType => Ok(Self::Thing(ThingType::Entity(EntityType::from_proto(proto)))), + Encoding::RelationType => Ok(Self::Thing(ThingType::Relation(RelationType::from_proto(proto)))), + Encoding::AttributeType => Ok(Self::Thing(ThingType::Attribute(AttributeType::try_from_proto(proto)?))), + Encoding::RoleType => Ok(Self::Role(RoleType::from_proto(proto))), + } + } +} + +impl FromProto for EntityType { + fn from_proto(proto: TypeProto) -> Self { + Self::new(proto.label) + } +} + +impl FromProto for RelationType { + fn from_proto(proto: TypeProto) -> Self { + Self::new(proto.label) + } +} + +impl TryFromProto for ValueType { + fn try_from_proto(proto: i32) -> Result { + Self::from_i32(proto).ok_or(InternalError::EnumOutOfBounds(proto, "ValueType").into()) + } +} + +impl TryFromProto for AttributeType { + fn try_from_proto(proto: TypeProto) -> Result { + match ValueType::try_from_proto(proto.value_type)? { + ValueType::Object => Ok(Self::Root(RootAttributeType::default())), + ValueType::Boolean => Ok(Self::Boolean(BooleanAttributeType { label: proto.label })), + ValueType::Long => Ok(Self::Long(LongAttributeType { label: proto.label })), + ValueType::Double => Ok(Self::Double(DoubleAttributeType { label: proto.label })), + ValueType::String => Ok(Self::String(StringAttributeType { label: proto.label })), + ValueType::Datetime => Ok(Self::DateTime(DateTimeAttributeType { label: proto.label })), + } + } +} + +impl FromProto for RoleType { + fn from_proto(proto: TypeProto) -> Self { + Self::new(ScopedLabel::new(proto.scope, proto.label)) + } +} + +impl TryFromProto for Thing { + fn try_from_proto(proto: ThingProto) -> Result { + let encoding = proto.r#type.clone().ok_or(ConnectionError::MissingResponseField("type"))?.encoding; + match Encoding::try_from_proto(encoding)? { + Encoding::EntityType => Ok(Self::Entity(Entity::try_from_proto(proto)?)), + Encoding::RelationType => Ok(Self::Relation(Relation::try_from_proto(proto)?)), + Encoding::AttributeType => Ok(Self::Attribute(Attribute::try_from_proto(proto)?)), + _ => todo!(), + } + } +} + +impl TryFromProto for Entity { + fn try_from_proto(proto: ThingProto) -> Result { + Ok(Self { + type_: EntityType::from_proto(proto.r#type.ok_or(ConnectionError::MissingResponseField("type"))?), + iid: proto.iid, + }) + } +} + +impl TryFromProto for Relation { + fn try_from_proto(proto: ThingProto) -> Result { + Ok(Self { + type_: RelationType::from_proto(proto.r#type.ok_or(ConnectionError::MissingResponseField("type"))?), + iid: proto.iid, + }) + } +} + +impl TryFromProto for Attribute { + fn try_from_proto(proto: ThingProto) -> Result { + let value = proto.value.and_then(|v| v.value).ok_or(ConnectionError::MissingResponseField("value"))?; + + let value_type = proto.r#type.ok_or(ConnectionError::MissingResponseField("type"))?.value_type; + let iid = proto.iid; + + match ValueType::try_from_proto(value_type)? 
{ + ValueType::Object => todo!(), + ValueType::Boolean => Ok(Self::Boolean(BooleanAttribute { + value: if let ValueProto::Boolean(value) = value { value } else { unreachable!() }, + iid, + })), + ValueType::Long => Ok(Self::Long(LongAttribute { + value: if let ValueProto::Long(value) = value { value } else { unreachable!() }, + iid, + })), + ValueType::Double => Ok(Self::Double(DoubleAttribute { + value: if let ValueProto::Double(value) = value { value } else { unreachable!() }, + iid, + })), + ValueType::String => Ok(Self::String(StringAttribute { + value: if let ValueProto::String(value) = value { value } else { unreachable!() }, + iid, + })), + ValueType::Datetime => Ok(Self::DateTime(DateTimeAttribute { + value: if let ValueProto::DateTime(value) = value { + NaiveDateTime::from_timestamp_opt(value / 1000, (value % 1000) as u32 * 1_000_000).unwrap() + } else { + unreachable!() + }, + iid, + })), + } + } +} diff --git a/src/connection/network/proto/database.rs b/src/connection/network/proto/database.rs new file mode 100644 index 00000000..bc73473d --- /dev/null +++ b/src/connection/network/proto/database.rs @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use itertools::Itertools; +use typedb_protocol::{cluster_database::Replica as ReplicaProto, ClusterDatabase as DatabaseProto}; + +use super::TryFromProto; +use crate::{ + common::info::{DatabaseInfo, ReplicaInfo}, + Result, +}; + +impl TryFromProto for DatabaseInfo { + fn try_from_proto(proto: DatabaseProto) -> Result { + Ok(Self { + name: proto.name, + replicas: proto.replicas.into_iter().map(ReplicaInfo::try_from_proto).try_collect()?, + }) + } +} + +impl TryFromProto for ReplicaInfo { + fn try_from_proto(proto: ReplicaProto) -> Result { + Ok(Self { + address: proto.address.as_str().parse()?, + is_primary: proto.primary, + is_preferred: proto.preferred, + term: proto.term, + }) + } +} diff --git a/src/connection/network/proto/message.rs b/src/connection/network/proto/message.rs new file mode 100644 index 00000000..1a0c9f81 --- /dev/null +++ b/src/connection/network/proto/message.rs @@ -0,0 +1,370 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
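// Illustrative sketch, not part of the patch: datetime attribute values
// arrive as epoch milliseconds and are split into whole seconds plus
// nanoseconds for chrono, mirroring the Attribute conversion above.
use chrono::NaiveDateTime;

fn datetime_from_millis(millis: i64) -> NaiveDateTime {
    NaiveDateTime::from_timestamp_opt(millis / 1000, (millis % 1000) as u32 * 1_000_000)
        .expect("timestamp within chrono's representable range")
}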
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::time::Duration; + +use itertools::Itertools; +use typedb_protocol::{ + cluster_database_manager, core_database, core_database_manager, query_manager, server_manager, session, transaction, +}; + +use super::{FromProto, IntoProto, TryFromProto}; +use crate::{ + answer::{ConceptMap, Numeric}, + common::{info::DatabaseInfo, RequestID, Result}, + connection::{ + message::{QueryRequest, QueryResponse, Request, Response, TransactionRequest, TransactionResponse}, + network::proto::TryIntoProto, + }, + error::{ConnectionError, InternalError}, +}; + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::ServersAll => Ok(server_manager::all::Req {}), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabasesContains { database_name } => { + Ok(core_database_manager::contains::Req { name: database_name }) + } + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseCreate { database_name } => Ok(core_database_manager::create::Req { name: database_name }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseGet { database_name } => Ok(cluster_database_manager::get::Req { name: database_name }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabasesAll => Ok(cluster_database_manager::all::Req {}), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseDelete { database_name } => Ok(core_database::delete::Req { name: database_name }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseSchema { database_name } => Ok(core_database::schema::Req { name: database_name }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseTypeSchema { database_name } => { + Ok(core_database::type_schema::Req { name: database_name }) + } + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::DatabaseRuleSchema { database_name } => { + Ok(core_database::rule_schema::Req { name: database_name }) + } + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl 
TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::SessionOpen { database_name, session_type, options } => Ok(session::open::Req { + database: database_name, + r#type: session_type.into_proto().into(), + options: Some(options.into_proto()), + }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::SessionPulse { session_id } => Ok(session::pulse::Req { session_id: session_id.into() }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::SessionClose { session_id } => Ok(session::close::Req { session_id: session_id.into() }), + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryIntoProto for Request { + fn try_into_proto(self) -> Result { + match self { + Request::Transaction(transaction_req) => { + Ok(transaction::Client { reqs: vec![transaction_req.into_proto()] }) + } + other => Err(InternalError::UnexpectedRequestType(format!("{other:?}")).into()), + } + } +} + +impl TryFromProto for Response { + fn try_from_proto(proto: server_manager::all::Res) -> Result { + let servers = proto.servers.into_iter().map(|server| server.address.parse()).try_collect()?; + Ok(Response::ServersAll { servers }) + } +} + +impl FromProto for Response { + fn from_proto(proto: core_database_manager::contains::Res) -> Self { + Self::DatabasesContains { contains: proto.contains } + } +} + +impl FromProto for Response { + fn from_proto(_proto: core_database_manager::create::Res) -> Self { + Self::DatabaseCreate + } +} + +impl TryFromProto for Response { + fn try_from_proto(proto: cluster_database_manager::get::Res) -> Result { + Ok(Response::DatabaseGet { + database: DatabaseInfo::try_from_proto( + proto.database.ok_or(ConnectionError::MissingResponseField("database"))?, + )?, + }) + } +} + +impl TryFromProto for Response { + fn try_from_proto(proto: cluster_database_manager::all::Res) -> Result { + Ok(Response::DatabasesAll { + databases: proto.databases.into_iter().map(DatabaseInfo::try_from_proto).try_collect()?, + }) + } +} + +impl FromProto for Response { + fn from_proto(_proto: core_database::delete::Res) -> Self { + Self::DatabaseDelete + } +} + +impl FromProto for Response { + fn from_proto(proto: core_database::schema::Res) -> Self { + Self::DatabaseSchema { schema: proto.schema } + } +} + +impl FromProto for Response { + fn from_proto(proto: core_database::type_schema::Res) -> Self { + Self::DatabaseTypeSchema { schema: proto.schema } + } +} + +impl FromProto for Response { + fn from_proto(proto: core_database::rule_schema::Res) -> Self { + Self::DatabaseRuleSchema { schema: proto.schema } + } +} + +impl FromProto for Response { + fn from_proto(proto: session::open::Res) -> Self { + Self::SessionOpen { + session_id: proto.session_id.into(), + server_duration: Duration::from_millis(proto.server_duration_millis as u64), + } + } +} + +impl FromProto for Response { + fn from_proto(_proto: session::pulse::Res) -> Self { + Self::SessionPulse + } +} + +impl FromProto for Response { + fn from_proto(_proto: session::close::Res) -> Self { + Self::SessionClose + } +} + +impl IntoProto for TransactionRequest { + fn into_proto(self) -> transaction::Req { + let mut request_id = None; + + let req = match self { + TransactionRequest::Open { session_id, 
transaction_type, options, network_latency } => { + transaction::req::Req::OpenReq(transaction::open::Req { + session_id: session_id.into(), + r#type: transaction_type.into_proto().into(), + options: Some(options.into_proto()), + network_latency_millis: network_latency.as_millis() as i32, + }) + } + TransactionRequest::Commit => transaction::req::Req::CommitReq(transaction::commit::Req {}), + TransactionRequest::Rollback => transaction::req::Req::RollbackReq(transaction::rollback::Req {}), + TransactionRequest::Query(query_request) => { + transaction::req::Req::QueryManagerReq(query_request.into_proto()) + } + TransactionRequest::Stream { request_id: req_id } => { + request_id = Some(req_id); + transaction::req::Req::StreamReq(transaction::stream::Req {}) + } + }; + + transaction::Req { + req_id: request_id.unwrap_or_else(RequestID::generate).into(), + metadata: Default::default(), + req: Some(req), + } + } +} + +impl TryFromProto for TransactionResponse { + fn try_from_proto(proto: transaction::Res) -> Result { + match proto.res { + Some(transaction::res::Res::OpenRes(_)) => Ok(TransactionResponse::Open), + Some(transaction::res::Res::CommitRes(_)) => Ok(TransactionResponse::Commit), + Some(transaction::res::Res::RollbackRes(_)) => Ok(TransactionResponse::Rollback), + Some(transaction::res::Res::QueryManagerRes(res)) => { + Ok(TransactionResponse::Query(QueryResponse::try_from_proto(res)?)) + } + Some(_) => todo!(), + None => Err(ConnectionError::MissingResponseField("res").into()), + } + } +} + +impl TryFromProto for TransactionResponse { + fn try_from_proto(proto: transaction::ResPart) -> Result { + match proto.res { + Some(transaction::res_part::Res::QueryManagerResPart(res_part)) => { + Ok(TransactionResponse::Query(QueryResponse::try_from_proto(res_part)?)) + } + Some(_) => todo!(), + None => Err(ConnectionError::MissingResponseField("res").into()), + } + } +} + +impl IntoProto for QueryRequest { + fn into_proto(self) -> query_manager::Req { + let (req, options) = match self { + QueryRequest::Define { query, options } => { + (query_manager::req::Req::DefineReq(query_manager::define::Req { query }), options) + } + QueryRequest::Undefine { query, options } => { + (query_manager::req::Req::UndefineReq(query_manager::undefine::Req { query }), options) + } + QueryRequest::Delete { query, options } => { + (query_manager::req::Req::DeleteReq(query_manager::delete::Req { query }), options) + } + + QueryRequest::Match { query, options } => { + (query_manager::req::Req::MatchReq(query_manager::r#match::Req { query }), options) + } + QueryRequest::Insert { query, options } => { + (query_manager::req::Req::InsertReq(query_manager::insert::Req { query }), options) + } + QueryRequest::Update { query, options } => { + (query_manager::req::Req::UpdateReq(query_manager::update::Req { query }), options) + } + + QueryRequest::MatchAggregate { query, options } => { + (query_manager::req::Req::MatchAggregateReq(query_manager::match_aggregate::Req { query }), options) + } + + _ => todo!(), + }; + query_manager::Req { req: Some(req), options: Some(options.into_proto()) } + } +} + +impl TryFromProto for QueryResponse { + fn try_from_proto(proto: query_manager::Res) -> Result { + match proto.res { + Some(query_manager::res::Res::DefineRes(_)) => Ok(QueryResponse::Define), + Some(query_manager::res::Res::UndefineRes(_)) => Ok(QueryResponse::Undefine), + Some(query_manager::res::Res::DeleteRes(_)) => Ok(QueryResponse::Delete), + Some(query_manager::res::Res::MatchAggregateRes(res)) => 
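// Illustrative sketch, not part of the patch (crate-internal types): every
// transaction request is tagged with a fresh RequestID, except Stream
// continuation requests, which reuse the ID of the query they continue so
// the server resumes the right answer stream — the same rule encoded in the
// IntoProto impl for TransactionRequest above.
fn request_id_for(request: TransactionRequest) -> RequestID {
    match request {
        TransactionRequest::Stream { request_id } => request_id,
        _ => RequestID::generate(),
    }
}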
Ok(QueryResponse::MatchAggregate { + answer: Numeric::try_from_proto(res.answer.ok_or(ConnectionError::MissingResponseField("answer"))?)?, + }), + None => Err(ConnectionError::MissingResponseField("res").into()), + } + } +} + +impl TryFromProto for QueryResponse { + fn try_from_proto(proto: query_manager::ResPart) -> Result { + match proto.res { + Some(query_manager::res_part::Res::MatchResPart(res)) => Ok(QueryResponse::Match { + answers: res.answers.into_iter().map(ConceptMap::try_from_proto).try_collect()?, + }), + Some(query_manager::res_part::Res::InsertResPart(res)) => Ok(QueryResponse::Insert { + answers: res.answers.into_iter().map(ConceptMap::try_from_proto).try_collect()?, + }), + Some(_) => todo!(), + None => Err(ConnectionError::MissingResponseField("res").into()), + } + } +} diff --git a/src/common/rpc/mod.rs b/src/connection/network/proto/mod.rs similarity index 67% rename from src/common/rpc/mod.rs rename to src/connection/network/proto/mod.rs index e38fa19d..fc805bb2 100644 --- a/src/common/rpc/mod.rs +++ b/src/connection/network/proto/mod.rs @@ -19,17 +19,25 @@ * under the License. */ -pub(crate) mod builder; -mod channel; -mod cluster; -mod core; -mod server; -mod transaction; +mod common; +mod concept; +mod database; +mod message; -pub(crate) use self::{ - channel::Channel, - cluster::{ClusterRPC, ClusterServerRPC}, - core::CoreRPC, - server::ServerRPC, - transaction::TransactionRPC, -}; +use crate::Result; + +pub(super) trait IntoProto { + fn into_proto(self) -> Proto; +} + +pub(super) trait TryIntoProto { + fn try_into_proto(self) -> Result; +} + +pub(super) trait FromProto { + fn from_proto(proto: Proto) -> Self; +} + +pub(super) trait TryFromProto: Sized { + fn try_from_proto(proto: Proto) -> Result; +} diff --git a/src/connection/network/stub.rs b/src/connection/network/stub.rs new file mode 100644 index 00000000..fccae604 --- /dev/null +++ b/src/connection/network/stub.rs @@ -0,0 +1,264 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
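// Illustrative sketch, not part of the patch: the four conversion traits
// above keep protobuf types confined to connection::network::proto. A toy
// fallible conversion for a pair of hypothetical types:
struct StatusProto {
    code: i32,
}

enum Status {
    Ok,
    Failed,
}

impl TryFromProto<StatusProto> for Status {
    fn try_from_proto(proto: StatusProto) -> Result<Self> {
        match proto.code {
            0 => Ok(Status::Ok),
            1 => Ok(Status::Failed),
            other => Err(InternalError::EnumOutOfBounds(other, "Status").into()),
        }
    }
}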
+ */ + +use std::sync::Arc; + +use futures::{future::BoxFuture, FutureExt, TryFutureExt}; +use log::{debug, trace}; +use tokio::sync::mpsc::{unbounded_channel as unbounded_async, UnboundedSender}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::{Response, Status, Streaming}; +use typedb_protocol::{ + cluster_database::Replica, cluster_database_manager, cluster_user, core_database, core_database_manager, + server_manager, session, transaction, type_db_client::TypeDbClient as CoreGRPC, + type_db_cluster_client::TypeDbClusterClient as ClusterGRPC, ClusterDatabase, +}; + +use super::channel::{CallCredentials, GRPCChannel}; +use crate::common::{address::Address, error::ConnectionError, Error, Result, StdResult}; + +type TonicResult = StdResult, Status>; + +#[derive(Clone, Debug)] +pub(super) struct RPCStub { + address: Address, + channel: Channel, + core_grpc: CoreGRPC, + cluster_grpc: ClusterGRPC, + call_credentials: Option>, +} + +impl RPCStub { + pub(super) async fn new( + address: Address, + channel: Channel, + call_credentials: Option>, + ) -> Result { + let this = Self { + address, + core_grpc: CoreGRPC::new(channel.clone()), + cluster_grpc: ClusterGRPC::new(channel.clone()), + channel, + call_credentials, + }; + let mut this = this.validated().await?; + this.renew_token().await?; + Ok(this) + } + + pub(super) async fn validated(mut self) -> Result { + self.databases_all(cluster_database_manager::all::Req {}).await?; + Ok(self) + } + + fn address(&self) -> &Address { + &self.address + } + + async fn call_with_auto_renew_token(&mut self, call: F) -> Result + where + for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, Result>, + { + match call(self).await { + Err(Error::Connection(ConnectionError::ClusterTokenCredentialInvalid())) => { + self.renew_token().await?; + call(self).await + } + res => res, + } + } + + async fn renew_token(&mut self) -> Result { + if let Some(call_credentials) = &self.call_credentials { + trace!("renewing token..."); + call_credentials.reset_token(); + let req = cluster_user::token::Req { username: call_credentials.username().to_owned() }; + trace!("sending token request..."); + let token = self.cluster_grpc.user_token(req).await?.into_inner().token; + call_credentials.set_token(token); + trace!("renewed token"); + } + Ok(()) + } + + pub(super) async fn servers_all(&mut self, req: server_manager::all::Req) -> Result { + self.single(|this| Box::pin(this.cluster_grpc.servers_all(req.clone()))).await + } + + pub(super) async fn databases_contains( + &mut self, + req: core_database_manager::contains::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.databases_contains(req.clone()))).await + } + + pub(super) async fn databases_create( + &mut self, + req: core_database_manager::create::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.databases_create(req.clone()))).await + } + + // FIXME: merge after protocol merge + pub(super) async fn databases_get( + &mut self, + req: cluster_database_manager::get::Req, + ) -> Result { + if self.channel.is_plaintext() { + self.databases_get_core(req).await + } else { + self.databases_get_cluster(req).await + } + } + + pub(super) async fn databases_all( + &mut self, + req: cluster_database_manager::all::Req, + ) -> Result { + if self.channel.is_plaintext() { + self.databases_all_core(req).await + } else { + self.databases_all_cluster(req).await + } + } + + async fn databases_get_core( + &mut self, + req: cluster_database_manager::get::Req, + ) -> Result { + 
Ok(cluster_database_manager::get::Res { + database: Some(ClusterDatabase { + name: req.name, + replicas: vec![Replica { + address: self.address().to_string(), + primary: true, + preferred: true, + term: 0, + }], + }), + }) + } + + async fn databases_all_core( + &mut self, + _req: cluster_database_manager::all::Req, + ) -> Result { + let database_names = + self.single(|this| Box::pin(this.core_grpc.databases_all(core_database_manager::all::Req {}))).await?.names; + Ok(cluster_database_manager::all::Res { + databases: database_names + .into_iter() + .map(|db_name| ClusterDatabase { + name: db_name, + replicas: vec![Replica { + address: self.address().to_string(), + primary: true, + preferred: true, + term: 0, + }], + }) + .collect(), + }) + } + + async fn databases_get_cluster( + &mut self, + req: cluster_database_manager::get::Req, + ) -> Result { + self.single(|this| Box::pin(this.cluster_grpc.databases_get(req.clone()))).await + } + + async fn databases_all_cluster( + &mut self, + req: cluster_database_manager::all::Req, + ) -> Result { + self.single(|this| Box::pin(this.cluster_grpc.databases_all(req.clone()))).await + } + // FIXME: end FIXME + + pub(super) async fn database_delete( + &mut self, + req: core_database::delete::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.database_delete(req.clone()))).await + } + + pub(super) async fn database_schema( + &mut self, + req: core_database::schema::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.database_schema(req.clone()))).await + } + + pub(super) async fn database_type_schema( + &mut self, + req: core_database::type_schema::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.database_type_schema(req.clone()))).await + } + + pub(super) async fn database_rule_schema( + &mut self, + req: core_database::rule_schema::Req, + ) -> Result { + self.single(|this| Box::pin(this.core_grpc.database_rule_schema(req.clone()))).await + } + + pub(super) async fn session_open(&mut self, req: session::open::Req) -> Result { + self.single(|this| Box::pin(this.core_grpc.session_open(req.clone()))).await + } + + pub(super) async fn session_close(&mut self, req: session::close::Req) -> Result { + debug!("closing session"); + self.single(|this| Box::pin(this.core_grpc.session_close(req.clone()))).await + } + + pub(super) async fn session_pulse(&mut self, req: session::pulse::Req) -> Result { + self.single(|this| Box::pin(this.core_grpc.session_pulse(req.clone()))).await + } + + pub(super) async fn transaction( + &mut self, + open_req: transaction::Req, + ) -> Result<(UnboundedSender, Streaming)> { + self.call_with_auto_renew_token(|this| { + let transaction_req = transaction::Client { reqs: vec![open_req.clone()] }; + Box::pin(async { + let (sender, receiver) = unbounded_async(); + sender.send(transaction_req)?; + this.core_grpc + .transaction(UnboundedReceiverStream::new(receiver)) + .map_ok(|stream| Response::new((sender, stream.into_inner()))) + .map(|r| Ok(r?.into_inner())) + .await + }) + }) + .await + } + + async fn single(&mut self, call: F) -> Result + where + for<'a> F: Fn(&'a mut Self) -> BoxFuture<'a, TonicResult> + Send + Sync, + R: 'static, + { + self.call_with_auto_renew_token(|this| Box::pin(call(this).map(|r| Ok(r?.into_inner())))).await + } +} diff --git a/src/connection/server/mod.rs b/src/connection/network/transmitter/mod.rs similarity index 87% rename from src/connection/server/mod.rs rename to src/connection/network/transmitter/mod.rs index 8b865283..971012f7 100644 --- 
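// Illustrative sketch, not part of the patch: over a plaintext channel (a
// Core server without the cluster endpoints), RPCStub fabricates the
// cluster-shaped reply itself, presenting the single node as the primary and
// preferred replica of every database — see databases_get_core /
// databases_all_core above.
fn single_node_database(name: String, address: String) -> ClusterDatabase {
    ClusterDatabase {
        name,
        replicas: vec![Replica { address, primary: true, preferred: true, term: 0 }],
    }
}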
a/src/connection/server/mod.rs +++ b/src/connection/network/transmitter/mod.rs @@ -19,8 +19,8 @@ * under the License. */ -mod database; -mod session; +mod response_sink; +mod rpc; mod transaction; -pub use self::{database::Database, session::Session, transaction::Transaction}; +pub(in crate::connection) use self::{rpc::RPCTransmitter, transaction::TransactionTransmitter}; diff --git a/src/connection/network/transmitter/response_sink.rs b/src/connection/network/transmitter/response_sink.rs new file mode 100644 index 00000000..526a3110 --- /dev/null +++ b/src/connection/network/transmitter/response_sink.rs @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crossbeam::channel::Sender as SyncSender; +use log::error; +use tokio::sync::{mpsc::UnboundedSender, oneshot::Sender as AsyncOneshotSender}; + +use crate::{ + common::Result, + error::{ConnectionError, InternalError}, + Error, +}; + +#[derive(Debug)] +pub(super) enum ResponseSink { + AsyncOneShot(AsyncOneshotSender>), + BlockingOneShot(SyncSender>), + Streamed(UnboundedSender>), +} + +impl ResponseSink { + pub(super) fn finish(self, response: Result) { + let result = match self { + Self::AsyncOneShot(sink) => sink.send(response).map_err(|_| InternalError::SendError().into()), + Self::BlockingOneShot(sink) => sink.send(response).map_err(Error::from), + Self::Streamed(sink) => sink.send(response).map_err(Error::from), + }; + if let Err(err) = result { + error!("{}", err); + } + } + + pub(super) fn send(&self, response: Result) { + let result = match self { + Self::Streamed(sink) => sink.send(response).map_err(Error::from), + _ => unreachable!("attempted to stream over a one-shot callback"), + }; + if let Err(err) = result { + error!("{}", err); + } + } + + pub(super) async fn error(self, error: ConnectionError) { + match self { + Self::AsyncOneShot(sink) => sink.send(Err(error.into())).ok(), + Self::BlockingOneShot(sink) => sink.send(Err(error.into())).ok(), + Self::Streamed(sink) => sink.send(Err(error.into())).ok(), + }; + } +} diff --git a/src/connection/network/transmitter/rpc.rs b/src/connection/network/transmitter/rpc.rs new file mode 100644 index 00000000..def4f64e --- /dev/null +++ b/src/connection/network/transmitter/rpc.rs @@ -0,0 +1,161 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
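// Illustrative sketch, not part of the patch: the three ResponseSink
// flavours above are plain channel senders. An async caller awaits a tokio
// oneshot, a blocking caller parks on a zero-capacity crossbeam channel
// (send and recv rendezvous), and streamed answers flow through an unbounded
// sender drained as a Stream.
use crossbeam::channel::bounded;
use tokio::sync::oneshot;

async fn async_round_trip() -> String {
    let (sink, source) = oneshot::channel();
    // in the real code, the dispatcher thread calls `finish` on the sink
    sink.send("response".to_owned()).unwrap();
    source.await.unwrap()
}

fn blocking_round_trip() -> String {
    let (sink, source) = bounded(0);
    // the sender lives on the background thread, so the zero-capacity send
    // completes exactly when this thread's recv is ready
    std::thread::spawn(move || sink.send("response".to_owned()).unwrap());
    source.recv().unwrap()
}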
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crossbeam::channel::{bounded as bounded_blocking, Receiver as SyncReceiver, Sender as SyncSender}; +use tokio::{ + select, + sync::{ + mpsc::{unbounded_channel as unbounded_async, UnboundedReceiver, UnboundedSender}, + oneshot::channel as oneshot_async, + }, +}; + +use super::response_sink::ResponseSink; +use crate::{ + common::{address::Address, Result}, + connection::{ + message::{Request, Response}, + network::{ + channel::{open_encrypted_channel, open_plaintext_channel, GRPCChannel}, + proto::{FromProto, IntoProto, TryFromProto, TryIntoProto}, + stub::RPCStub, + }, + runtime::BackgroundRuntime, + }, + Credential, Error, +}; + +fn oneshot_blocking() -> (SyncSender, SyncReceiver) { + bounded_blocking::(0) +} + +pub(in crate::connection) struct RPCTransmitter { + request_sink: UnboundedSender<(Request, ResponseSink)>, + shutdown_sink: UnboundedSender<()>, +} + +impl RPCTransmitter { + pub(in crate::connection) fn start_plaintext(address: Address, runtime: &BackgroundRuntime) -> Result { + let (request_sink, request_source) = unbounded_async(); + let (shutdown_sink, shutdown_source) = unbounded_async(); + runtime.run_blocking(async move { + let channel = open_plaintext_channel(address.clone()); + let rpc = RPCStub::new(address.clone(), channel, None).await?; + tokio::spawn(Self::dispatcher_loop(rpc, request_source, shutdown_source)); + Ok::<(), Error>(()) + })?; + Ok(Self { request_sink, shutdown_sink }) + } + + pub(in crate::connection) fn start_encrypted( + address: Address, + credential: Credential, + runtime: &BackgroundRuntime, + ) -> Result { + let (request_sink, request_source) = unbounded_async(); + let (shutdown_sink, shutdown_source) = unbounded_async(); + runtime.run_blocking(async move { + let (channel, call_credentials) = open_encrypted_channel(address.clone(), credential)?; + let rpc = RPCStub::new(address.clone(), channel, Some(call_credentials)).await?; + tokio::spawn(Self::dispatcher_loop(rpc, request_source, shutdown_source)); + Ok::<(), Error>(()) + })?; + Ok(Self { request_sink, shutdown_sink }) + } + + pub(in crate::connection) async fn request_async(&self, request: Request) -> Result { + let (response_sink, response) = oneshot_async(); + self.request_sink.send((request, ResponseSink::AsyncOneShot(response_sink)))?; + response.await? + } + + pub(in crate::connection) fn request_blocking(&self, request: Request) -> Result { + let (response_sink, response) = oneshot_blocking(); + self.request_sink.send((request, ResponseSink::BlockingOneShot(response_sink)))?; + response.recv()? + } + + pub(in crate::connection) fn force_close(&self) -> Result { + self.shutdown_sink.send(()).map_err(Into::into) + } + + async fn dispatcher_loop( + rpc: RPCStub, + mut request_source: UnboundedReceiver<(Request, ResponseSink)>, + mut shutdown_signal: UnboundedReceiver<()>, + ) { + while let Some((request, response_sink)) = select! 
{ + request = request_source.recv() => request, + _ = shutdown_signal.recv() => None, + } { + let rpc = rpc.clone(); + tokio::spawn(async move { + let response = Self::send_request(rpc, request).await; + response_sink.finish(response); + }); + } + } + + async fn send_request(mut rpc: RPCStub, request: Request) -> Result { + match request { + Request::ServersAll => rpc.servers_all(request.try_into_proto()?).await.and_then(Response::try_from_proto), + + Request::DatabasesContains { .. } => { + rpc.databases_contains(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::DatabaseCreate { .. } => { + rpc.databases_create(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::DatabaseGet { .. } => { + rpc.databases_get(request.try_into_proto()?).await.and_then(Response::try_from_proto) + } + Request::DatabasesAll => { + rpc.databases_all(request.try_into_proto()?).await.and_then(Response::try_from_proto) + } + + Request::DatabaseDelete { .. } => { + rpc.database_delete(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::DatabaseSchema { .. } => { + rpc.database_schema(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::DatabaseTypeSchema { .. } => { + rpc.database_type_schema(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::DatabaseRuleSchema { .. } => { + rpc.database_rule_schema(request.try_into_proto()?).await.map(Response::from_proto) + } + + Request::SessionOpen { .. } => rpc.session_open(request.try_into_proto()?).await.map(Response::from_proto), + Request::SessionPulse { .. } => { + rpc.session_pulse(request.try_into_proto()?).await.map(Response::from_proto) + } + Request::SessionClose { .. } => { + rpc.session_close(request.try_into_proto()?).await.map(Response::from_proto) + } + + Request::Transaction(transaction_request) => { + let (request_sink, response_source) = rpc.transaction(transaction_request.into_proto()).await?; + Ok(Response::TransactionOpen { request_sink, response_source }) + } + } + } +} diff --git a/src/connection/network/transmitter/transaction.rs b/src/connection/network/transmitter/transaction.rs new file mode 100644 index 00000000..41f8427d --- /dev/null +++ b/src/connection/network/transmitter/transaction.rs @@ -0,0 +1,271 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
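// Illustrative sketch, not part of the patch (crate-internal API): how a
// caller such as ServerConnection is expected to drive the RPCTransmitter
// above, from either an async context (request_async) or a blocking one
// (request_blocking).
async fn fetch_databases(transmitter: &RPCTransmitter) -> Result<Vec<DatabaseInfo>> {
    match transmitter.request_async(Request::DatabasesAll).await? {
        Response::DatabasesAll { databases } => Ok(databases),
        // the real caller surfaces an internal "unexpected response" error here
        _ => todo!(),
    }
}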
+ */ + +use std::{ + collections::HashMap, + ops::DerefMut, + sync::{Arc, RwLock}, + time::Duration, +}; + +use crossbeam::atomic::AtomicCell; +use futures::{Stream, StreamExt, TryStreamExt}; +use log::error; +use prost::Message; +use tokio::{ + select, + sync::{ + mpsc::{error::SendError, unbounded_channel as unbounded_async, UnboundedReceiver, UnboundedSender}, + oneshot::channel as oneshot_async, + }, + time::{sleep_until, Instant}, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::Streaming; +use typedb_protocol::transaction::{self, server::Server, stream::State}; + +use super::response_sink::ResponseSink; +use crate::{ + common::{error::ConnectionError, RequestID, Result}, + connection::{ + message::{TransactionRequest, TransactionResponse}, + network::proto::{IntoProto, TryFromProto}, + runtime::BackgroundRuntime, + }, +}; + +pub(in crate::connection) struct TransactionTransmitter { + request_sink: UnboundedSender<(TransactionRequest, Option>)>, + is_open: Arc>, + shutdown_sink: UnboundedSender<()>, +} + +impl Drop for TransactionTransmitter { + fn drop(&mut self) { + self.is_open.store(false); + self.shutdown_sink.send(()).ok(); + } +} + +impl TransactionTransmitter { + pub(in crate::connection) fn new( + background_runtime: &BackgroundRuntime, + request_sink: UnboundedSender, + response_source: Streaming, + ) -> Self { + let (buffer_sink, buffer_source) = unbounded_async(); + let (shutdown_sink, shutdown_source) = unbounded_async(); + let is_open = Arc::new(AtomicCell::new(true)); + background_runtime.spawn(Self::start_workers( + buffer_sink.clone(), + buffer_source, + request_sink, + response_source, + is_open.clone(), + shutdown_source, + )); + Self { request_sink: buffer_sink, is_open, shutdown_sink } + } + + pub(in crate::connection) async fn single(&self, req: TransactionRequest) -> Result { + if !self.is_open.load() { + return Err(ConnectionError::SessionIsClosed().into()); + } + let (res_sink, recv) = oneshot_async(); + self.request_sink.send((req, Some(ResponseSink::AsyncOneShot(res_sink))))?; + recv.await?.map(Into::into) + } + + pub(in crate::connection) fn stream( + &self, + req: TransactionRequest, + ) -> Result>> { + if !self.is_open.load() { + return Err(ConnectionError::SessionIsClosed().into()); + } + let (res_part_sink, recv) = unbounded_async(); + self.request_sink.send((req, Some(ResponseSink::Streamed(res_part_sink))))?; + Ok(UnboundedReceiverStream::new(recv).map_ok(Into::into)) + } + + async fn start_workers( + queue_sink: UnboundedSender<(TransactionRequest, Option>)>, + queue_source: UnboundedReceiver<(TransactionRequest, Option>)>, + request_sink: UnboundedSender, + response_source: Streaming, + is_open: Arc>, + shutdown_signal: UnboundedReceiver<()>, + ) { + let collector = ResponseCollector { request_sink: queue_sink, callbacks: Default::default(), is_open }; + tokio::spawn(Self::dispatch_loop(queue_source, request_sink, collector.clone(), shutdown_signal)); + tokio::spawn(Self::listen_loop(response_source, collector)); + } + + async fn dispatch_loop( + mut request_source: UnboundedReceiver<(TransactionRequest, Option>)>, + request_sink: UnboundedSender, + mut collector: ResponseCollector, + mut shutdown_signal: UnboundedReceiver<()>, + ) { + const MAX_GRPC_MESSAGE_LEN: usize = 1_000_000; + const DISPATCH_INTERVAL: Duration = Duration::from_millis(3); + + let mut request_buffer = TransactionRequestBuffer::default(); + let mut next_dispatch = Instant::now() + DISPATCH_INTERVAL; + loop { + select! 
{ biased; + _ = shutdown_signal.recv() => { + if !request_buffer.is_empty() { + request_sink.send(request_buffer.take()).unwrap(); + } + break; + } + _ = sleep_until(next_dispatch) => { + if !request_buffer.is_empty() { + request_sink.send(request_buffer.take()).unwrap(); + } + next_dispatch = Instant::now() + DISPATCH_INTERVAL; + } + recv = request_source.recv() => { + if let Some((request, callback)) = recv { + let request = request.into_proto(); + if let Some(callback) = callback { + collector.register(request.req_id.clone().into(), callback); + } + if request_buffer.len() + request.encoded_len() > MAX_GRPC_MESSAGE_LEN { + request_sink.send(request_buffer.take()).unwrap(); + } + request_buffer.push(request); + } else { + break; + } + } + } + } + } + + async fn listen_loop(mut grpc_source: Streaming, collector: ResponseCollector) { + loop { + match grpc_source.next().await { + Some(Ok(message)) => collector.collect(message).await, + Some(Err(err)) => { + break collector.close(ConnectionError::TransactionIsClosedWithErrors(err.to_string())).await + } + None => break collector.close(ConnectionError::TransactionIsClosed()).await, + } + } + } +} + +#[derive(Default)] +struct TransactionRequestBuffer { + reqs: Vec, + len: usize, +} + +impl TransactionRequestBuffer { + fn is_empty(&self) -> bool { + self.reqs.is_empty() + } + + fn len(&self) -> usize { + self.len + } + + fn push(&mut self, request: transaction::Req) { + self.len += request.encoded_len(); + self.reqs.push(request); + } + + fn take(&mut self) -> transaction::Client { + self.len = 0; + transaction::Client { reqs: std::mem::take(&mut self.reqs) } + } +} + +#[derive(Clone)] +struct ResponseCollector { + request_sink: UnboundedSender<(TransactionRequest, Option>)>, + callbacks: Arc>>>, + is_open: Arc>, +} + +impl ResponseCollector { + fn register(&mut self, request_id: RequestID, callback: ResponseSink) { + self.callbacks.write().unwrap().insert(request_id, callback); + } + + async fn collect(&self, message: transaction::Server) { + match message.server { + Some(Server::Res(res)) => self.collect_res(res), + Some(Server::ResPart(res_part)) => self.collect_res_part(res_part).await, + None => error!("{}", ConnectionError::MissingResponseField("server")), + } + } + + fn collect_res(&self, res: transaction::Res) { + if matches!(res.res, Some(transaction::res::Res::OpenRes(_))) { + // Transaction::Open responses don't need to be collected. 
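// Illustrative sketch, not part of the patch: the dispatch loop above
// flushes its TransactionRequestBuffer when the 3ms tick fires, when
// shutdown is signalled, or when appending the next request would push the
// encoded batch past the ~1MB gRPC message budget. That last check reduces
// to:
fn must_flush_first<M: prost::Message>(buffered_len: usize, next: &M, max_len: usize) -> bool {
    buffered_len + next.encoded_len() > max_len
}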
+ return; + } + let req_id = res.req_id.clone().into(); + match self.callbacks.write().unwrap().remove(&req_id) { + Some(sink) => sink.finish(TransactionResponse::try_from_proto(res)), + _ => error!("{}", ConnectionError::UnknownRequestId(req_id)), + } + } + + async fn collect_res_part(&self, res_part: transaction::ResPart) { + let request_id = res_part.req_id.clone().into(); + + match res_part.res { + Some(transaction::res_part::Res::StreamResPart(stream_res_part)) => { + match State::from_i32(stream_res_part.state).expect("enum out of range") { + State::Done => { + self.callbacks.write().unwrap().remove(&request_id); + } + State::Continue => { + match self.request_sink.send((TransactionRequest::Stream { request_id }, None)) { + Err(SendError((TransactionRequest::Stream { request_id }, None))) => { + let callback = self.callbacks.write().unwrap().remove(&request_id).unwrap(); + callback.error(ConnectionError::TransactionIsClosed()).await; + } + _ => (), + } + } + } + } + Some(_) => match self.callbacks.read().unwrap().get(&request_id) { + Some(sink) => sink.send(TransactionResponse::try_from_proto(res_part)), + _ => error!("{}", ConnectionError::UnknownRequestId(request_id)), + }, + None => error!("{}", ConnectionError::MissingResponseField("res_part.res")), + } + } + + async fn close(self, error: ConnectionError) { + self.is_open.store(false); + let mut listeners = std::mem::take(self.callbacks.write().unwrap().deref_mut()); + for (_, listener) in listeners.drain() { + listener.error(error.clone()).await; + } + } +} diff --git a/src/connection/runtime.rs b/src/connection/runtime.rs new file mode 100644 index 00000000..faa45fd9 --- /dev/null +++ b/src/connection/runtime.rs @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+use std::{future::Future, thread};
+
+use crossbeam::{atomic::AtomicCell, channel::bounded as bounded_blocking};
+use tokio::{
+    runtime,
+    sync::mpsc::{unbounded_channel as unbounded_async, UnboundedSender},
+};
+
+use crate::common::Result;
+
+pub(super) struct BackgroundRuntime {
+    async_runtime_handle: runtime::Handle,
+    is_open: AtomicCell<bool>,
+    shutdown_sink: UnboundedSender<()>,
+}
+
+impl BackgroundRuntime {
+    pub(super) fn new() -> Result<Self> {
+        let is_open = AtomicCell::new(true);
+        let (shutdown_sink, mut shutdown_source) = unbounded_async();
+        let async_runtime = runtime::Builder::new_current_thread().enable_time().enable_io().build()?;
+        let async_runtime_handle = async_runtime.handle().clone();
+        thread::Builder::new().name("gRPC worker".to_string()).spawn(move || {
+            async_runtime.block_on(async move {
+                shutdown_source.recv().await;
+            })
+        })?;
+        Ok(Self { async_runtime_handle, is_open, shutdown_sink })
+    }
+
+    pub(super) fn is_open(&self) -> bool {
+        self.is_open.load()
+    }
+
+    pub(super) fn force_close(&self) -> Result {
+        self.is_open.store(false);
+        self.shutdown_sink.send(())?;
+        Ok(())
+    }
+
+    pub(super) fn spawn<F>(&self, future: F)
+    where
+        F: Future + Send + 'static,
+        F::Output: Send + 'static,
+    {
+        self.async_runtime_handle.spawn(future);
+    }
+
+    pub(super) fn run_blocking<F>(&self, future: F) -> F::Output
+    where
+        F: Future + Send + 'static,
+        F::Output: Send + 'static,
+    {
+        let (response_sink, response) = bounded_blocking(0);
+        self.async_runtime_handle.spawn(async move {
+            response_sink.send(future.await).ok();
+        });
+        response.recv().unwrap()
+    }
+}
+
+impl Drop for BackgroundRuntime {
+    fn drop(&mut self) {
+        self.is_open.store(false);
+        self.shutdown_sink.send(()).ok();
+    }
+}
diff --git a/src/connection/server/database.rs b/src/connection/server/database.rs
deleted file mode 100644
index 0f038213..00000000
--- a/src/connection/server/database.rs
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Copyright (C) 2022 Vaticle
- *
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */ - -use std::fmt::{Display, Formatter}; - -use crate::common::{ - rpc::builder::core::database::{delete_req, rule_schema_req, schema_req, type_schema_req}, - Result, ServerRPC, -}; - -#[derive(Clone, Debug)] -pub struct Database { - pub name: String, - server_rpc: ServerRPC, -} - -impl Database { - pub(crate) fn new(name: &str, server_rpc: ServerRPC) -> Self { - Database { name: name.into(), server_rpc } - } - - pub async fn delete(mut self) -> Result { - self.server_rpc.database_delete(delete_req(self.name.as_str())).await?; - Ok(()) - } - - pub async fn schema(&mut self) -> Result { - self.server_rpc.database_schema(schema_req(self.name.as_str())).await.map(|res| res.schema) - } - - pub async fn type_schema(&mut self) -> Result { - self.server_rpc - .database_type_schema(type_schema_req(self.name.as_str())) - .await - .map(|res| res.schema) - } - - pub async fn rule_schema(&mut self) -> Result { - self.server_rpc - .database_rule_schema(rule_schema_req(self.name.as_str())) - .await - .map(|res| res.schema) - } -} - -impl Display for Database { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.name) - } -} diff --git a/src/connection/server/session.rs b/src/connection/server/session.rs deleted file mode 100644 index ddf50c1b..00000000 --- a/src/connection/server/session.rs +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -use std::time::{Duration, Instant}; - -use crossbeam::atomic::AtomicCell; -use futures::executor; -use log::warn; - -use crate::{ - common::{ - error::ClientError, - rpc::builder::session::{close_req, open_req}, - Result, ServerRPC, SessionType, TransactionType, - }, - connection::{core, server::Transaction}, -}; - -pub(crate) type SessionId = Vec; - -#[derive(Debug)] -pub struct Session { - pub db_name: String, - pub session_type: SessionType, - pub(crate) id: SessionId, - pub(crate) server_rpc: ServerRPC, - is_open_atomic: AtomicCell, - network_latency: Duration, -} - -impl Session { - pub(crate) async fn new( - db_name: &str, - session_type: SessionType, - options: core::Options, - mut server_rpc: ServerRPC, - ) -> Result { - let start_time = Instant::now(); - let open_req = open_req(db_name, session_type.to_proto(), options.to_proto()); - let res = server_rpc.session_open(open_req).await?; - // TODO: pulse task - Ok(Session { - db_name: String::from(db_name), - session_type, - network_latency: Self::compute_network_latency(start_time, res.server_duration_millis), - id: res.session_id, - server_rpc, - is_open_atomic: AtomicCell::new(true), - }) - } - - pub async fn transaction(&self, transaction_type: TransactionType) -> Result { - self.transaction_with_options(transaction_type, core::Options::default()).await - } - - pub async fn transaction_with_options( - &self, - transaction_type: TransactionType, - options: core::Options, - ) -> Result { - match self.is_open() { - true => { - Transaction::new( - &self.id, - transaction_type, - options, - self.network_latency, - &self.server_rpc, - ) - .await - } - false => Err(ClientError::SessionIsClosed())?, - } - } - - pub fn is_open(&self) -> bool { - self.is_open_atomic.load() - } - - pub async fn close(&mut self) { - if let Ok(true) = self.is_open_atomic.compare_exchange(true, false) { - // let res = self.session_close_sink.send(self.id.clone()); - let res = self.server_rpc.session_close(close_req(self.id.clone())).await; - // TODO: the request errors harmlessly if the session is already closed. Protocol should - // expose the cause of the error and we can use that to decide whether to warn here. - if res.is_err() { - warn!("{}", ClientError::SessionCloseFailed()) - } - } - } - - fn compute_network_latency(start_time: Instant, server_duration_millis: i32) -> Duration { - Duration::from_millis( - (Instant::now() - start_time).as_millis() as u64 - server_duration_millis as u64, - ) - } -} - -impl Drop for Session { - fn drop(&mut self) { - // TODO: this will stall in a single-threaded environment - executor::block_on(self.close()); - } -} diff --git a/src/connection/server/transaction.rs b/src/connection/server/transaction.rs deleted file mode 100644 index 0894cfe3..00000000 --- a/src/connection/server/transaction.rs +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{fmt::Debug, time::Duration}; - -use futures::Stream; -use typedb_protocol::transaction as transaction_proto; - -use crate::{ - common::{ - rpc::builder::transaction::{commit_req, open_req, rollback_req}, - Result, ServerRPC, TransactionRPC, TransactionType, - }, - connection::core, - query::QueryManager, -}; - -#[derive(Clone, Debug)] -pub struct Transaction { - pub type_: TransactionType, - pub options: core::Options, - pub query: QueryManager, - rpc: TransactionRPC, -} - -impl Transaction { - pub(crate) async fn new( - session_id: &[u8], - transaction_type: TransactionType, - options: core::Options, - network_latency: Duration, - server_rpc: &ServerRPC, - ) -> Result { - let open_req = open_req( - session_id.to_vec(), - transaction_type.to_proto(), - options.to_proto(), - network_latency.as_millis() as i32, - ); - let rpc = TransactionRPC::new(server_rpc, open_req).await?; - Ok(Transaction { type_: transaction_type, options, query: QueryManager::new(&rpc), rpc }) - } - - pub async fn commit(&mut self) -> Result { - self.single_rpc(commit_req()).await.map(|_| ()) - } - - pub async fn rollback(&mut self) -> Result { - self.single_rpc(rollback_req()).await.map(|_| ()) - } - - pub(crate) async fn single_rpc( - &mut self, - req: transaction_proto::Req, - ) -> Result { - self.rpc.single(req).await - } - - pub(crate) fn streaming_rpc( - &mut self, - req: transaction_proto::Req, - ) -> impl Stream> { - self.rpc.stream(req) - } - - // TODO: refactor to delegate work to a background process - pub async fn close(&self) { - self.rpc.close().await; - } -} diff --git a/src/connection/transaction_stream.rs b/src/connection/transaction_stream.rs new file mode 100644 index 00000000..7adf560e --- /dev/null +++ b/src/connection/transaction_stream.rs @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+use std::{fmt, iter};
+
+use futures::{stream, Stream, StreamExt};
+
+use super::network::transmitter::TransactionTransmitter;
+use crate::{
+    answer::{ConceptMap, Numeric},
+    common::Result,
+    connection::message::{QueryRequest, QueryResponse, TransactionRequest, TransactionResponse},
+    error::InternalError,
+    Options, TransactionType,
+};
+
+pub(crate) struct TransactionStream {
+    type_: TransactionType,
+    options: Options,
+    transaction_transmitter: TransactionTransmitter,
+}
+
+impl TransactionStream {
+    pub(super) fn new(
+        type_: TransactionType,
+        options: Options,
+        transaction_transmitter: TransactionTransmitter,
+    ) -> Self {
+        Self { type_, options, transaction_transmitter }
+    }
+
+    pub(crate) fn type_(&self) -> TransactionType {
+        self.type_
+    }
+
+    pub(crate) fn options(&self) -> &Options {
+        &self.options
+    }
+
+    pub(crate) async fn commit(&self) -> Result {
+        self.single(TransactionRequest::Commit).await?;
+        Ok(())
+    }
+
+    pub(crate) async fn rollback(&self) -> Result {
+        self.single(TransactionRequest::Rollback).await?;
+        Ok(())
+    }
+
+    pub(crate) async fn define(&self, query: String, options: Options) -> Result {
+        self.single(TransactionRequest::Query(QueryRequest::Define { query, options })).await?;
+        Ok(())
+    }
+
+    pub(crate) async fn undefine(&self, query: String, options: Options) -> Result {
+        self.single(TransactionRequest::Query(QueryRequest::Undefine { query, options })).await?;
+        Ok(())
+    }
+
+    pub(crate) async fn delete(&self, query: String, options: Options) -> Result {
+        self.single(TransactionRequest::Query(QueryRequest::Delete { query, options })).await?;
+        Ok(())
+    }
+
+    pub(crate) fn match_(&self, query: String, options: Options) -> Result<impl Stream<Item = Result<ConceptMap>>> {
+        let stream = self.query_stream(QueryRequest::Match { query, options })?;
+        Ok(stream.flat_map(|result| match result {
+            Ok(QueryResponse::Match { answers }) => stream_iter(answers.into_iter().map(Ok)),
+            Ok(other) => stream_once(Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into())),
+            Err(err) => stream_once(Err(err)),
+        }))
+    }
+
+    pub(crate) fn insert(&self, query: String, options: Options) -> Result<impl Stream<Item = Result<ConceptMap>>> {
+        let stream = self.query_stream(QueryRequest::Insert { query, options })?;
+        Ok(stream.flat_map(|result| match result {
+            Ok(QueryResponse::Insert { answers }) => stream_iter(answers.into_iter().map(Ok)),
+            Ok(other) => stream_once(Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into())),
+            Err(err) => stream_once(Err(err)),
+        }))
+    }
+
+    pub(crate) fn update(&self, query: String, options: Options) -> Result<impl Stream<Item = Result<ConceptMap>>> {
+        let stream = self.query_stream(QueryRequest::Update { query, options })?;
+        Ok(stream.flat_map(|result| match result {
+            Ok(QueryResponse::Update { answers }) => stream_iter(answers.into_iter().map(Ok)),
+            Ok(other) => stream_once(Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into())),
+            Err(err) => stream_once(Err(err)),
+        }))
+    }
+
+    pub(crate) async fn match_aggregate(&self, query: String, options: Options) -> Result<Numeric> {
+        match self.query_single(QueryRequest::MatchAggregate { query, options }).await? {
+            QueryResponse::MatchAggregate { answer } => Ok(answer),
+            other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()),
+        }
+    }
+
+    async fn single(&self, req: TransactionRequest) -> Result<TransactionResponse> {
+        self.transaction_transmitter.single(req).await
+    }
+
+    async fn query_single(&self, req: QueryRequest) -> Result<QueryResponse> {
+        match self.single(TransactionRequest::Query(req)).await? {
+            TransactionResponse::Query(query) => Ok(query),
+            other => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()),
+        }
+    }
+
+    fn stream(&self, req: TransactionRequest) -> Result<impl Stream<Item = Result<TransactionResponse>>> {
+        self.transaction_transmitter.stream(req)
+    }
+
+    fn query_stream(&self, req: QueryRequest) -> Result<impl Stream<Item = Result<QueryResponse>>> {
+        Ok(self.stream(TransactionRequest::Query(req))?.map(|response| match response {
+            Ok(TransactionResponse::Query(query)) => Ok(query),
+            Ok(other) => Err(InternalError::UnexpectedResponseType(format!("{other:?}")).into()),
+            Err(err) => Err(err),
+        }))
+    }
+}
+
+impl fmt::Debug for TransactionStream {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("TransactionStream").field("type_", &self.type_).field("options", &self.options).finish()
+    }
+}
+
+fn stream_once<'a, T: Send + 'a>(value: T) -> stream::BoxStream<'a, T> {
+    stream_iter(iter::once(value))
+}
+
+fn stream_iter<'a, T: Send + 'a>(iter: impl Iterator<Item = T> + Send + 'a) -> stream::BoxStream<'a, T> {
+    Box::pin(stream::iter(iter))
+}
diff --git a/src/database/database.rs b/src/database/database.rs
new file mode 100644
index 00000000..74c823aa
--- /dev/null
+++ b/src/database/database.rs
@@ -0,0 +1,285 @@
+/*
+ * Copyright (C) 2022 Vaticle
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +use std::{fmt, future::Future, sync::RwLock, thread::sleep, time::Duration}; + +use itertools::Itertools; +use log::{debug, error}; + +use crate::{ + common::{ + address::Address, + error::ConnectionError, + info::{DatabaseInfo, ReplicaInfo}, + Error, Result, + }, + connection::ServerConnection, + Connection, +}; + +pub struct Database { + name: String, + replicas: RwLock>, + connection: Connection, +} + +impl Database { + const PRIMARY_REPLICA_TASK_MAX_RETRIES: usize = 10; + const FETCH_REPLICAS_MAX_RETRIES: usize = 10; + const WAIT_FOR_PRIMARY_REPLICA_SELECTION: Duration = Duration::from_secs(2); + + pub(super) fn new(database_info: DatabaseInfo, connection: Connection) -> Result { + let name = database_info.name.clone(); + let replicas = RwLock::new(Replica::try_from_info(database_info, &connection)?); + Ok(Self { name, replicas, connection }) + } + + pub(super) async fn get(name: String, connection: Connection) -> Result { + Ok(Self { + name: name.to_string(), + replicas: RwLock::new(Replica::fetch_all(name, connection.clone()).await?), + connection, + }) + } + + pub fn name(&self) -> &str { + self.name.as_str() + } + + pub(super) fn connection(&self) -> &Connection { + &self.connection + } + + pub async fn delete(self) -> Result { + self.run_on_primary_replica(|database, _, _| database.delete()).await + } + + pub async fn schema(&self) -> Result { + self.run_failsafe(|database, _, _| async move { database.schema().await }).await + } + + pub async fn type_schema(&self) -> Result { + self.run_failsafe(|database, _, _| async move { database.type_schema().await }).await + } + + pub async fn rule_schema(&self) -> Result { + self.run_failsafe(|database, _, _| async move { database.rule_schema().await }).await + } + + pub(super) async fn run_failsafe(&self, task: F) -> Result + where + F: Fn(ServerDatabase, ServerConnection, bool) -> P, + P: Future>, + { + match self.run_on_any_replica(&task).await { + Err(Error::Connection(ConnectionError::ClusterReplicaNotPrimary())) => { + debug!("Attempted to run on a non-primary replica, retrying on primary..."); + self.run_on_primary_replica(&task).await + } + res => res, + } + } + + async fn run_on_any_replica(&self, task: F) -> Result + where + F: Fn(ServerDatabase, ServerConnection, bool) -> P, + P: Future>, + { + let mut is_first_run = true; + let replicas = self.replicas.read().unwrap().clone(); + for replica in replicas.iter() { + match task(replica.database.clone(), self.connection.connection(&replica.address)?.clone(), is_first_run) + .await + { + Err(Error::Connection(ConnectionError::UnableToConnect())) => { + debug!("Unable to connect to {}. Attempting next server.", replica.address); + } + res => return res, + } + is_first_run = false; + } + Err(self.connection.unable_to_connect_error()) + } + + async fn run_on_primary_replica(&self, task: F) -> Result + where + F: Fn(ServerDatabase, ServerConnection, bool) -> P, + P: Future>, + { + let mut primary_replica = + if let Some(replica) = self.primary_replica() { replica } else { self.seek_primary_replica().await? 
}; + + for retry in 0..Self::PRIMARY_REPLICA_TASK_MAX_RETRIES { + match task( + primary_replica.database.clone(), + self.connection.connection(&primary_replica.address)?.clone(), + retry == 0, + ) + .await + { + Err(Error::Connection( + ConnectionError::ClusterReplicaNotPrimary() | ConnectionError::UnableToConnect(), + )) => { + debug!("Primary replica error, waiting..."); + Self::wait_for_primary_replica_selection().await; + primary_replica = self.seek_primary_replica().await?; + } + res => return res, + } + } + Err(self.connection.unable_to_connect_error()) + } + + async fn seek_primary_replica(&self) -> Result { + for _ in 0..Self::FETCH_REPLICAS_MAX_RETRIES { + let replicas = Replica::fetch_all(self.name.clone(), self.connection.clone()).await?; + *self.replicas.write().unwrap() = replicas; + if let Some(replica) = self.primary_replica() { + return Ok(replica); + } + Self::wait_for_primary_replica_selection().await; + } + Err(self.connection.unable_to_connect_error()) + } + + fn primary_replica(&self) -> Option { + self.replicas.read().unwrap().iter().filter(|r| r.is_primary).max_by_key(|r| r.term).cloned() + } + + async fn wait_for_primary_replica_selection() { + // FIXME: blocking sleep! Can't do agnostic async sleep. + sleep(Self::WAIT_FOR_PRIMARY_REPLICA_SELECTION); + } +} + +impl fmt::Debug for Database { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Database").field("name", &self.name).field("replicas", &self.replicas).finish() + } +} + +#[derive(Clone)] +pub(super) struct Replica { + address: Address, + database_name: String, + is_primary: bool, + term: i64, + is_preferred: bool, + database: ServerDatabase, +} + +impl Replica { + fn new(name: String, metadata: ReplicaInfo, server_connection: ServerConnection) -> Self { + Self { + address: metadata.address, + database_name: name.clone(), + is_primary: metadata.is_primary, + term: metadata.term, + is_preferred: metadata.is_preferred, + database: ServerDatabase::new(name, server_connection), + } + } + + fn try_from_info(database_info: DatabaseInfo, connection: &Connection) -> Result> { + database_info + .replicas + .into_iter() + .map(|replica| { + let server_connection = connection.connection(&replica.address)?.clone(); + Ok(Replica::new(database_info.name.clone(), replica, server_connection)) + }) + .try_collect() + } + + async fn fetch_all(name: String, connection: Connection) -> Result> { + for server_connection in connection.connections() { + let res = server_connection.get_database_replicas(name.clone()).await; + match res { + Ok(res) => { + return Replica::try_from_info(res, &connection); + } + Err(Error::Connection(ConnectionError::UnableToConnect())) => { + error!( + "Failed to fetch replica info for database '{}' from {}. 
Attempting next server.", + name, + server_connection.address() + ); + } + Err(err) => return Err(err), + } + } + Err(connection.unable_to_connect_error()) + } +} + +impl fmt::Debug for Replica { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Replica") + .field("address", &self.address) + .field("database_name", &self.database_name) + .field("is_primary", &self.is_primary) + .field("term", &self.term) + .field("is_preferred", &self.is_preferred) + .finish() + } +} + +#[derive(Clone, Debug)] +pub(super) struct ServerDatabase { + name: String, + connection: ServerConnection, +} + +impl ServerDatabase { + fn new(name: String, connection: ServerConnection) -> Self { + ServerDatabase { name, connection } + } + + pub(super) fn name(&self) -> &str { + self.name.as_str() + } + + pub(super) fn connection(&self) -> &ServerConnection { + &self.connection + } + + async fn delete(self) -> Result { + self.connection.delete_database(self.name).await + } + + async fn schema(&self) -> Result { + self.connection.database_schema(self.name.clone()).await + } + + async fn type_schema(&self) -> Result { + self.connection.database_type_schema(self.name.clone()).await + } + + async fn rule_schema(&self) -> Result { + self.connection.database_rule_schema(self.name.clone()).await + } +} + +impl fmt::Display for ServerDatabase { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name) + } +} diff --git a/src/database/database_manager.rs b/src/database/database_manager.rs new file mode 100644 index 00000000..f29ee9d8 --- /dev/null +++ b/src/database/database_manager.rs @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+use std::future::Future;
+
+use super::{database::ServerDatabase, Database};
+use crate::{
+    common::{error::ConnectionError, Result},
+    connection::ServerConnection,
+    Connection,
+};
+
+#[derive(Clone, Debug)]
+pub struct DatabaseManager {
+    connection: Connection,
+}
+
+impl DatabaseManager {
+    pub fn new(connection: Connection) -> Self {
+        Self { connection }
+    }
+
+    pub async fn get(&self, name: impl Into<String>) -> Result<Database> {
+        Database::get(name.into(), self.connection.clone()).await
+    }
+
+    pub async fn contains(&self, name: impl Into<String>) -> Result<bool> {
+        self.run_failsafe(name.into(), move |database, server_connection, _| async move {
+            server_connection.database_exists(database.name().to_owned()).await
+        })
+        .await
+    }
+
+    pub async fn create(&self, name: impl Into<String>) -> Result {
+        self.run_failsafe(name.into(), |database, server_connection, _| async move {
+            server_connection.create_database(database.name().to_owned()).await
+        })
+        .await
+    }
+
+    pub async fn all(&self) -> Result<Vec<Database>> {
+        let mut error_buffer = Vec::with_capacity(self.connection.server_count());
+        for server_connection in self.connection.connections() {
+            match server_connection.all_databases().await {
+                Ok(list) => {
+                    return list.into_iter().map(|db_info| Database::new(db_info, self.connection.clone())).collect()
+                }
+                Err(err) => error_buffer.push(format!("- {}: {}", server_connection.address(), err)),
+            }
+        }
+        Err(ConnectionError::ClusterAllNodesFailed(error_buffer.join("\n")))?
+    }
+
+    async fn run_failsafe<F, P, R>(&self, name: String, task: F) -> Result<R>
+    where
+        F: Fn(ServerDatabase, ServerConnection, bool) -> P,
+        P: Future<Output = Result<R>>,
+    {
+        Database::get(name, self.connection.clone()).await?.run_failsafe(&task).await
+    }
+}
diff --git a/src/connection/cluster/mod.rs b/src/database/mod.rs
similarity index 86%
rename from src/connection/cluster/mod.rs
rename to src/database/mod.rs
index 4bff4fa7..6c3c5606 100644
--- a/src/connection/cluster/mod.rs
+++ b/src/database/mod.rs
@@ -19,11 +19,10 @@
  * under the License.
  */
 
-mod client;
 mod database;
 mod database_manager;
+mod query;
 mod session;
+mod transaction;
 
-pub use self::{
-    client::Client, database::Database, database_manager::DatabaseManager, session::Session,
-};
+pub use self::{database::Database, database_manager::DatabaseManager, session::Session, transaction::Transaction};
diff --git a/src/database/query.rs b/src/database/query.rs
new file mode 100644
index 00000000..b7bd5a52
--- /dev/null
+++ b/src/database/query.rs
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2022 Vaticle
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +use std::sync::Arc; + +use futures::Stream; + +use crate::{ + answer::{ConceptMap, Numeric}, + common::Result, + connection::TransactionStream, + Options, +}; + +#[derive(Debug)] +pub struct QueryManager { + transaction_stream: Arc, +} + +impl QueryManager { + pub(super) fn new(transaction_stream: Arc) -> QueryManager { + QueryManager { transaction_stream } + } + + pub async fn define(&self, query: &str) -> Result { + self.define_with_options(query, Options::new()).await + } + + pub async fn define_with_options(&self, query: &str, options: Options) -> Result { + self.transaction_stream.define(query.to_string(), options).await + } + + pub async fn undefine(&self, query: &str) -> Result { + self.undefine_with_options(query, Options::new()).await + } + + pub async fn undefine_with_options(&self, query: &str, options: Options) -> Result { + self.transaction_stream.undefine(query.to_string(), options).await + } + + pub async fn delete(&self, query: &str) -> Result { + self.delete_with_options(query, Options::new()).await + } + + pub async fn delete_with_options(&self, query: &str, options: Options) -> Result { + self.transaction_stream.delete(query.to_string(), options).await + } + + pub fn match_(&self, query: &str) -> Result>> { + self.match_with_options(query, Options::new()) + } + + pub fn match_with_options(&self, query: &str, options: Options) -> Result>> { + self.transaction_stream.match_(query.to_string(), options) + } + + pub fn insert(&self, query: &str) -> Result>> { + self.insert_with_options(query, Options::new()) + } + + pub fn insert_with_options(&self, query: &str, options: Options) -> Result>> { + self.transaction_stream.insert(query.to_string(), options) + } + + pub fn update(&self, query: &str) -> Result>> { + self.update_with_options(query, Options::new()) + } + + pub fn update_with_options(&self, query: &str, options: Options) -> Result>> { + self.transaction_stream.update(query.to_string(), options) + } + + pub async fn match_aggregate(&self, query: &str) -> Result { + self.match_aggregate_with_options(query, Options::new()).await + } + + pub async fn match_aggregate_with_options(&self, query: &str, options: Options) -> Result { + self.transaction_stream.match_aggregate(query.to_string(), options).await + } +} diff --git a/src/database/session.rs b/src/database/session.rs new file mode 100644 index 00000000..b8262732 --- /dev/null +++ b/src/database/session.rs @@ -0,0 +1,129 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use std::sync::RwLock; + +use crossbeam::atomic::AtomicCell; +use log::warn; + +use crate::{ + common::{error::ConnectionError, info::SessionInfo, Result, SessionType, TransactionType}, + Database, Options, Transaction, +}; + +#[derive(Debug)] +pub struct Session { + database: Database, + server_session_info: RwLock, + session_type: SessionType, + is_open: AtomicCell, +} + +impl Drop for Session { + fn drop(&mut self) { + if let Err(err) = self.force_close() { + warn!("Error encountered while closing session: {}", err); + } + } +} + +impl Session { + pub async fn new(database: Database, session_type: SessionType) -> Result { + let server_session_info = RwLock::new( + database + .run_failsafe(|database, _, _| async move { + database + .connection() + .open_session(database.name().to_owned(), session_type, Options::default()) + .await + }) + .await?, + ); + + Ok(Self { database, session_type, server_session_info, is_open: AtomicCell::new(true) }) + } + + pub fn database_name(&self) -> &str { + self.database.name() + } + + pub fn type_(&self) -> SessionType { + self.session_type + } + + pub fn is_open(&self) -> bool { + self.is_open.load() + } + + pub fn force_close(&self) -> Result { + if self.is_open.compare_exchange(true, false).is_ok() { + let session_info = self.server_session_info.write().unwrap(); + let connection = self.database.connection().connection(&session_info.address).unwrap(); + connection.close_session(session_info.session_id.clone())?; + } + Ok(()) + } + + pub async fn transaction(&self, transaction_type: TransactionType) -> Result { + self.transaction_with_options(transaction_type, Options::new()).await + } + + pub async fn transaction_with_options( + &self, + transaction_type: TransactionType, + options: Options, + ) -> Result { + if !self.is_open() { + return Err(ConnectionError::SessionIsClosed().into()); + } + + let (session_info, transaction_stream) = self + .database + .run_failsafe(|database, _, is_first_run| { + let session_info = self.server_session_info.read().unwrap().clone(); + let session_type = self.session_type; + let options = options.clone(); + async move { + let connection = database.connection(); + let session_info = if is_first_run { + session_info + } else { + connection.open_session(database.name().to_owned(), session_type, options.clone()).await? + }; + Ok(( + session_info.clone(), + connection + .open_transaction( + session_info.session_id, + transaction_type, + options, + session_info.network_latency, + ) + .await?, + )) + } + }) + .await?; + + *self.server_session_info.write().unwrap() = session_info; + Transaction::new(transaction_stream) + } +} diff --git a/src/database/transaction.rs b/src/database/transaction.rs new file mode 100644 index 00000000..671ae9e4 --- /dev/null +++ b/src/database/transaction.rs @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{fmt, marker::PhantomData, sync::Arc}; + +use super::query::QueryManager; +use crate::{ + common::{Result, TransactionType}, + connection::TransactionStream, + Options, +}; + +pub struct Transaction<'a> { + type_: TransactionType, + options: Options, + + query: QueryManager, + transaction_stream: Arc, + + _lifetime_guard: PhantomData<&'a ()>, +} + +impl Transaction<'_> { + pub(super) fn new(transaction_stream: TransactionStream) -> Result { + let transaction_stream = Arc::new(transaction_stream); + Ok(Transaction { + type_: transaction_stream.type_(), + options: transaction_stream.options().clone(), + query: QueryManager::new(transaction_stream.clone()), + transaction_stream, + _lifetime_guard: PhantomData::default(), + }) + } + + pub fn query(&self) -> &QueryManager { + &self.query + } + + pub async fn commit(self) -> Result { + self.transaction_stream.commit().await + } + + pub async fn rollback(&self) -> Result { + self.transaction_stream.rollback().await + } +} + +impl fmt::Debug for Transaction<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Transaction").field("type_", &self.type_).field("options", &self.options).finish() + } +} diff --git a/src/lib.rs b/src/lib.rs index aa246581..721b9fe1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,15 +19,14 @@ * under the License. */ -#![allow(dead_code)] - -pub mod answer; -pub mod common; +mod answer; +mod common; pub mod concept; -pub(crate) mod connection; -pub mod query; +mod connection; +mod database; pub use self::{ - common::{Credential, Error, Result, SessionType, TransactionType}, - connection::{cluster, core, server}, + common::{error, Credential, Error, Options, Result, SessionType, TransactionType}, + connection::Connection, + database::{Database, DatabaseManager, Session, Transaction}, }; diff --git a/src/query/mod.rs b/src/query/mod.rs deleted file mode 100644 index 111caccb..00000000 --- a/src/query/mod.rs +++ /dev/null @@ -1,188 +0,0 @@ -/* - * Copyright (C) 2022 Vaticle - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -use std::iter::once; - -use futures::{stream, Stream, StreamExt}; -use query_manager::res::Res::MatchAggregateRes; -use typedb_protocol::{ - query_manager, - query_manager::res_part::Res::{InsertResPart, MatchResPart, UpdateResPart}, - transaction, -}; - -use crate::{ - answer::{ConceptMap, Numeric}, - common::{ - error::ClientError, - rpc::builder::query_manager::{ - define_req, delete_req, insert_req, match_aggregate_req, match_req, undefine_req, - update_req, - }, - Result, TransactionRPC, - }, - connection::core, -}; - -macro_rules! stream_concept_maps { - ($self:ident, $req:ident, $res_part_kind:ident, $query_type_str:tt) => { - $self.stream_answers($req).flat_map(|result: Result| { - match result { - Ok(res_part) => match res_part { - $res_part_kind(x) => { - stream::iter(x.answers.into_iter().map(|cm| ConceptMap::from_proto(cm))) - .left_stream() - } - _ => stream::iter(once(Err(ClientError::MissingResponseField(concat!( - "query_manager_res_part.", - $query_type_str, - "_res_part" - )) - .into()))) - .right_stream(), - }, - Err(err) => stream::iter(once(Err(err))).right_stream(), - } - }) - }; -} - -#[derive(Clone, Debug)] -pub struct QueryManager { - tx: TransactionRPC, -} - -impl QueryManager { - pub(crate) fn new(tx: &TransactionRPC) -> QueryManager { - QueryManager { tx: tx.clone() } - } - - pub async fn define(&mut self, query: &str) -> Result { - self.single_call(define_req(query, None)).await.map(|_| ()) - } - - pub async fn define_with_options(&mut self, query: &str, options: &core::Options) -> Result { - self.single_call(define_req(query, Some(options.to_proto()))).await.map(|_| ()) - } - - pub async fn delete(&mut self, query: &str) -> Result { - self.single_call(delete_req(query, None)).await.map(|_| ()) - } - - pub async fn delete_with_options(&mut self, query: &str, options: &core::Options) -> Result { - self.single_call(delete_req(query, Some(options.to_proto()))).await.map(|_| ()) - } - - pub fn insert(&mut self, query: &str) -> impl Stream> { - let req = insert_req(query, None); - stream_concept_maps!(self, req, InsertResPart, "insert") - } - - pub fn insert_with_options( - &mut self, - query: &str, - options: &core::Options, - ) -> impl Stream> { - let req = insert_req(query, Some(options.to_proto())); - stream_concept_maps!(self, req, InsertResPart, "insert") - } - - // TODO: investigate performance impact of using BoxStream - pub fn match_(&mut self, query: &str) -> impl Stream> { - let req = match_req(query, None); - stream_concept_maps!(self, req, MatchResPart, "match") - } - - pub fn match_with_options( - &mut self, - query: &str, - options: &core::Options, - ) -> impl Stream> { - let req = match_req(query, Some(options.to_proto())); - stream_concept_maps!(self, req, MatchResPart, "match") - } - - pub async fn match_aggregate(&mut self, query: &str) -> Result { - match self.single_call(match_aggregate_req(query, None)).await? { - MatchAggregateRes(res) => res.answer.unwrap().try_into(), - _ => Err(ClientError::MissingResponseField("match_aggregate_res"))?, - } - } - - pub async fn match_aggregate_with_options( - &mut self, - query: &str, - options: core::Options, - ) -> Result { - match self.single_call(match_aggregate_req(query, Some(options.to_proto()))).await? 
{ - MatchAggregateRes(res) => res.answer.unwrap().try_into(), - _ => Err(ClientError::MissingResponseField("match_aggregate_res"))?, - } - } - - pub async fn undefine(&mut self, query: &str) -> Result { - self.single_call(undefine_req(query, None)).await.map(|_| ()) - } - - pub async fn undefine_with_options(&mut self, query: &str, options: &core::Options) -> Result { - self.single_call(undefine_req(query, Some(options.to_proto()))).await.map(|_| ()) - } - - pub fn update(&mut self, query: &str) -> impl Stream> { - let req = update_req(query, None); - stream_concept_maps!(self, req, UpdateResPart, "update") - } - - pub fn update_with_options( - &mut self, - query: &str, - options: &core::Options, - ) -> impl Stream> { - let req = update_req(query, Some(options.to_proto())); - stream_concept_maps!(self, req, UpdateResPart, "update") - } - - async fn single_call(&mut self, req: transaction::Req) -> Result { - match self.tx.single(req).await?.res { - Some(transaction::res::Res::QueryManagerRes(res)) => { - res.res.ok_or(ClientError::MissingResponseField("res.query_manager_res").into()) - } - _ => Err(ClientError::MissingResponseField("res.query_manager_res"))?, - } - } - - fn stream_answers( - &mut self, - req: transaction::Req, - ) -> impl Stream> { - self.tx.stream(req).map(|result: Result| match result { - Ok(tx_res_part) => match tx_res_part.res { - Some(transaction::res_part::Res::QueryManagerResPart(res_part)) => { - res_part.res.ok_or( - ClientError::MissingResponseField("res_part.query_manager_res_part").into(), - ) - } - _ => Err(ClientError::MissingResponseField("res_part.query_manager_res_part"))?, - }, - Err(err) => Err(err), - }) - } -} diff --git a/tests/BUILD b/tests/BUILD index 5d209075..ea2bd313 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -24,31 +24,36 @@ package(default_visibility = ["//visibility:public"]) load("@rules_rust//rust:defs.bzl", "rust_test", "rustfmt_test") load("@vaticle_bazel_distribution//artifact:rules.bzl", "artifact_extractor") load("@vaticle_dependencies//tool/checkstyle:rules.bzl", "checkstyle_test") -load("@vaticle_typedb_common//test:rules.bzl", "native_typedb_artifact") +load("@vaticle_typedb_common//runner:rules.bzl", "native_typedb_artifact") rust_test( - name = "queries_core", - srcs = ["queries_core.rs"], + name = "queries", + srcs = [ + "common.rs", + "queries.rs", + ], deps = [ "//:typedb_client", - - "@vaticle_dependencies//library/crates:chrono", - "@vaticle_dependencies//library/crates:futures", - "@vaticle_dependencies//library/crates:serial_test", - "@vaticle_dependencies//library/crates:tokio", + "@crates//:chrono", + "@crates//:futures", + "@crates//:serial_test", + "@crates//:tokio", ], ) rust_test( - name = "queries_cluster", - srcs = ["queries_cluster.rs"], + name = "runtimes", + srcs = [ + "common.rs", + "runtimes.rs", + ], deps = [ "//:typedb_client", - - "@vaticle_dependencies//library/crates:futures", - "@vaticle_dependencies//library/crates:serial_test", - "@vaticle_dependencies//library/crates:tokio", + "@crates//:async-std", + "@crates//:futures", + "@crates//:serial_test", + "@crates//:smol", ], ) @@ -80,10 +85,7 @@ artifact_extractor( rustfmt_test( name = "queries_rustfmt_test", - targets = [ - "queries_core", - "queries_cluster", - ] + targets = ["queries", "runtimes"] ) checkstyle_test( diff --git a/tests/common.rs b/tests/common.rs new file mode 100644 index 00000000..f8bc7ac5 --- /dev/null +++ b/tests/common.rs @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::path::PathBuf; + +use futures::TryFutureExt; +use typedb_client::{ + Connection, Credential, Database, DatabaseManager, Session, SessionType::Schema, TransactionType::Write, +}; + +pub const TEST_DATABASE: &str = "test"; + +pub fn new_core_connection() -> typedb_client::Result { + Connection::new_plaintext("127.0.0.1:1729") +} + +pub fn new_cluster_connection() -> typedb_client::Result { + Connection::new_encrypted( + &["localhost:11729", "localhost:21729", "localhost:31729"], + Credential::with_tls( + "admin", + "password", + Some(&PathBuf::from( + std::env::var("ROOT_CA") + .expect("ROOT_CA environment variable needs to be set for cluster tests to run"), + )), + )?, + ) +} + +pub async fn create_test_database_with_schema(connection: Connection, schema: &str) -> typedb_client::Result { + let databases = DatabaseManager::new(connection); + if databases.contains(TEST_DATABASE).await? { + databases.get(TEST_DATABASE).and_then(Database::delete).await?; + } + databases.create(TEST_DATABASE).await?; + + let database = databases.get(TEST_DATABASE).await?; + let session = Session::new(database, Schema).await?; + let transaction = session.transaction(Write).await?; + transaction.query().define(schema).await?; + transaction.commit().await?; + Ok(()) +} diff --git a/tests/queries.rs b/tests/queries.rs new file mode 100644 index 00000000..55bdc102 --- /dev/null +++ b/tests/queries.rs @@ -0,0 +1,332 @@ +/* + * Copyright (C) 2022 Vaticle + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +mod common; + +use std::{sync::Arc, time::Instant}; + +use chrono::{NaiveDate, NaiveDateTime}; +use futures::StreamExt; +use serial_test::serial; +use tokio::sync::mpsc; +use typedb_client::{ + concept::{Attribute, Concept, DateTimeAttribute, StringAttribute, Thing}, + error::ConnectionError, + Connection, DatabaseManager, Error, Options, Session, + SessionType::Data, + TransactionType::{Read, Write}, +}; + +macro_rules! 
test_for_each_arg { + { + $perm_args:tt + $( $( #[ $extra_anno:meta ] )* $async:ident fn $test:ident $args:tt -> $ret:ty $test_impl:block )+ + } => { + test_for_each_arg!{ @impl $( $async fn $test $args $ret $test_impl )+ } + test_for_each_arg!{ @impl_per $perm_args { $( $( #[ $extra_anno ] )* $async fn $test )+ } } + }; + + { @impl $( $async:ident fn $test:ident $args:tt $ret:ty $test_impl:block )+ } => { + mod _impl { + use super::*; + $( pub $async fn $test $args -> $ret $test_impl )+ + } + }; + + { @impl_per { $($mod:ident => $arg:expr),+ $(,)? } $fns:tt } => { + $(test_for_each_arg!{ @impl_mod { $mod => $arg } $fns })+ + }; + + { @impl_mod { $mod:ident => $arg:expr } { $( $( #[ $extra_anno:meta ] )* async fn $test:ident )+ } } => { + mod $mod { + use super::*; + $( + #[tokio::test] + #[serial($mod)] + $( #[ $extra_anno ] )* + pub async fn $test() { + _impl::$test($arg).await.unwrap(); + } + )+ + } + }; +} + +test_for_each_arg! { + { + core => common::new_core_connection().unwrap(), + cluster => common::new_cluster_connection().unwrap(), + } + + async fn basic(connection: Connection) -> typedb_client::Result { + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection); + assert!(databases.contains(common::TEST_DATABASE).await?); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let answer_stream = transaction.query().match_("match $x sub thing;")?; + let results: Vec<_> = answer_stream.collect().await; + transaction.commit().await?; + assert_eq!(results.len(), 5); + assert!(results.into_iter().all(|res| res.is_ok())); + + Ok(()) + } + + async fn query_error(connection: Connection) -> typedb_client::Result { + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let answer_stream = transaction.query().match_("match $x sub nonexistent-type;")?; + let results: Vec<_> = answer_stream.collect().await; + assert_eq!(results.len(), 1); + assert!(results.into_iter().all(|res| res.unwrap_err().to_string().contains("[TYR03]"))); + + Ok(()) + } + + async fn concurrent_transactions(connection: Connection) -> typedb_client::Result { + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection); + + let session = Arc::new(Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?); + + let (sender, mut receiver) = mpsc::channel(5 * 5 * 8); + + for _ in 0..8 { + let sender = sender.clone(); + let session = session.clone(); + tokio::spawn(async move { + for _ in 0..5 { + let transaction = session.transaction(Read).await.unwrap(); + let mut answer_stream = transaction.query().match_("match $x sub thing;").unwrap(); + while let Some(result) = answer_stream.next().await { + sender.send(result).await.unwrap(); + } + } + }); + } + drop(sender); // receiver expects data while any sender is live + + let mut results = Vec::with_capacity(5 * 5 * 8); + while let Some(result) = receiver.recv().await { + results.push(result); + } + assert_eq!(results.len(), 5 * 5 * 8); + assert!(results.into_iter().all(|res| res.is_ok())); + + Ok(()) + } + + async fn query_options(connection: Connection) -> 
typedb_client::Result { + let schema = r#"define + person sub entity, + owns name, + owns age; + name sub attribute, value string; + age sub attribute, value long; + rule age-rule: when { $x isa person; } then { $x has age 25; };"#; + common::create_test_database_with_schema(connection.clone(), schema).await?; + let databases = DatabaseManager::new(connection); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let data = "insert $x isa person, has name 'Alice'; $y isa person, has name 'Bob';"; + let _ = transaction.query().insert(data); + transaction.commit().await?; + + let transaction = session.transaction(Read).await?; + let age_count = transaction.query().match_aggregate("match $x isa age; count;").await?; + assert_eq!(age_count.into_i64(), 0); + + let with_inference = Options::new().infer(true); + let transaction = session.transaction_with_options(Read, with_inference).await?; + let age_count = transaction.query().match_aggregate("match $x isa age; count;").await?; + assert_eq!(age_count.into_i64(), 1); + + Ok(()) + } + + async fn many_concept_types(connection: Connection) -> typedb_client::Result { + let schema = r#"define + person sub entity, + owns name, + owns date-of-birth, + plays friendship:friend; + name sub attribute, value string; + date-of-birth sub attribute, value datetime; + friendship sub relation, + relates friend;"#; + common::create_test_database_with_schema(connection.clone(), schema).await?; + let databases = DatabaseManager::new(connection); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + let transaction = session.transaction(Write).await?; + let data = r#"insert + $x isa person, has name "Alice", has date-of-birth 1994-10-03; + $y isa person, has name "Bob", has date-of-birth 1993-04-17; + (friend: $x, friend: $y) isa friendship;"#; + let _ = transaction.query().insert(data); + transaction.commit().await?; + + let transaction = session.transaction(Read).await?; + let mut answer_stream = transaction.query().match_( + r#"match + $p isa person, has name $name, has date-of-birth $date-of-birth; + $f($role: $p) isa friendship;"#, + )?; + + while let Some(result) = answer_stream.next().await { + assert!(result.is_ok()); + let mut result = result?.map; + let name = unwrap_string(result.remove("name").unwrap()); + let date_of_birth = unwrap_date_time(result.remove("date-of-birth").unwrap()).date(); + match name.as_str() { + "Alice" => assert_eq!(date_of_birth, NaiveDate::from_ymd_opt(1994, 10, 3).unwrap()), + "Bob" => assert_eq!(date_of_birth, NaiveDate::from_ymd_opt(1993, 4, 17).unwrap()), + _ => unreachable!(), + } + } + + Ok(()) + } + + async fn force_close_connection(connection: Connection) -> typedb_client::Result { + common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?; + let databases = DatabaseManager::new(connection.clone()); + + let database = databases.get(common::TEST_DATABASE).await?; + assert!(database.schema().await.is_ok()); + + let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?; + connection.clone().force_close()?; + + let schema = database.schema().await; + assert!(schema.is_err()); + assert!(matches!(schema, Err(Error::Connection(ConnectionError::ConnectionIsClosed())))); + + let database2 = databases.get(common::TEST_DATABASE).await; + assert!(database2.is_err()); + assert!(matches!(database2, 
+
+        let transaction = session.transaction(Write).await;
+        assert!(transaction.is_err());
+        assert!(matches!(transaction, Err(Error::Connection(ConnectionError::ConnectionIsClosed()))));
+
+        let session = Session::new(database, Data).await;
+        assert!(session.is_err());
+        assert!(matches!(session, Err(Error::Connection(ConnectionError::ConnectionIsClosed()))));
+
+        Ok(())
+    }
+
+    async fn force_close_session(connection: Connection) -> typedb_client::Result {
+        common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?;
+        let databases = DatabaseManager::new(connection.clone());
+
+        let session = Arc::new(Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?);
+        let transaction = session.transaction(Write).await?;
+
+        let session2 = session.clone();
+        session2.force_close()?;
+
+        let answer_stream = transaction.query().match_("match $x sub thing;");
+        assert!(answer_stream.is_err());
+        assert!(transaction.query().match_("match $x sub thing;").is_err());
+
+        let transaction = session.transaction(Write).await;
+        assert!(transaction.is_err());
+        assert!(matches!(transaction, Err(Error::Connection(ConnectionError::SessionIsClosed()))));
+
+        assert!(Session::new(databases.get(common::TEST_DATABASE).await?, Data).await.is_ok());
+
+        Ok(())
+    }
+
+    #[ignore]
+    async fn streaming_perf(connection: Connection) -> typedb_client::Result {
+        for i in 0..5 {
+            let schema = r#"define
+                person sub entity, owns name, owns age;
+                name sub attribute, value string;
+                age sub attribute, value long;"#;
+            common::create_test_database_with_schema(connection.clone(), schema).await?;
+            let databases = DatabaseManager::new(connection.clone());
+
+            let start_time = Instant::now();
+            let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?;
+            let transaction = session.transaction(Write).await?;
+            for j in 0..100_000 {
+                drop(transaction.query().insert(format!("insert $x {j} isa age;").as_str())?);
+            }
+            transaction.commit().await?;
+            println!("iteration {i}: inserted and committed 100k attrs in {}ms", start_time.elapsed().as_millis());
+
+            let mut start_time = Instant::now();
+            let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?;
+            let transaction = session.transaction(Read).await?;
+            let mut answer_stream = transaction.query().match_("match $x isa attribute;")?;
+            let mut sum: i64 = 0;
+            let mut idx = 0;
+            while let Some(result) = answer_stream.next().await {
+                match result {
+                    Ok(concept_map) => {
+                        for (_, concept) in concept_map {
+                            if let Concept::Thing(Thing::Attribute(Attribute::Long(long_attr))) = concept {
+                                sum += long_attr.value
+                            }
+                        }
+                    }
+                    Err(err) => {
+                        panic!("An error occurred fetching answers of a Match query: {}", err)
+                    }
+                }
+                idx += 1;
+                if idx == 100_000 {
+                    println!("iteration {i}: retrieved and summed 100k attrs in {}ms", start_time.elapsed().as_millis());
+                    start_time = Instant::now();
+                }
+            }
+            println!("sum is {}", sum);
+        }
+
+        Ok(())
+    }
+}
+
+// Concept helpers
+// FIXME: should be removed after concept API is implemented
+fn unwrap_date_time(concept: Concept) -> NaiveDateTime {
+    match concept {
+        Concept::Thing(Thing::Attribute(Attribute::DateTime(DateTimeAttribute { value, .. }))) => value,
+        _ => unreachable!(),
+    }
+}
+
+fn unwrap_string(concept: Concept) -> String {
+    match concept {
+        Concept::Thing(Thing::Attribute(Attribute::String(StringAttribute { value, .. }))) => value,
+        _ => unreachable!(),
+    }
+}
diff --git a/tests/queries_cluster.rs b/tests/queries_cluster.rs
deleted file mode 100644
index 058cafa4..00000000
--- a/tests/queries_cluster.rs
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Copyright (C) 2022 Vaticle
- *
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-use std::path::PathBuf;
-
-use futures::{StreamExt, TryFutureExt};
-use serial_test::serial;
-use typedb_client::{
-    cluster,
-    common::{Credential, SessionType::Data, TransactionType::Write},
-};
-
-const TEST_DATABASE: &str = "test";
-
-#[tokio::test(flavor = "multi_thread")]
-#[serial]
-async fn basic() {
-    let mut client = cluster::Client::new(
-        &["localhost:11729", "localhost:21729", "localhost:31729"],
-        Credential::with_tls(
-            "admin",
-            "password",
-            Some(&PathBuf::from(std::env::var("ROOT_CA").unwrap())),
-        ),
-    )
-    .await
-    .unwrap();
-
-    if client.databases().contains(TEST_DATABASE).await.unwrap() {
-        client.databases().get(TEST_DATABASE).and_then(|db| db.delete()).await.unwrap();
-    }
-    client.databases().create(TEST_DATABASE).await.unwrap();
-
-    assert!(client.databases().contains(TEST_DATABASE).await.unwrap());
-
-    let mut session = client.session(TEST_DATABASE, Data).await.unwrap();
-    let mut transaction = session.transaction(Write).await.unwrap();
-    let mut answer_stream = transaction.query.match_("match $x sub thing;");
-    while let Some(result) = answer_stream.next().await {
-        assert!(result.is_ok())
-    }
-    transaction.commit().await.unwrap();
-}
diff --git a/tests/queries_core.rs b/tests/queries_core.rs
deleted file mode 100644
index 3858a0b1..00000000
--- a/tests/queries_core.rs
+++ /dev/null
@@ -1,249 +0,0 @@
-/*
- * Copyright (C) 2022 Vaticle
- *
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-use std::{sync::mpsc, time::Instant};
-
-use chrono::{NaiveDate, NaiveDateTime};
-use futures::{StreamExt, TryFutureExt};
-use serial_test::serial;
-use typedb_client::{
-    common::{
-        SessionType::{Data, Schema},
-        TransactionType::{Read, Write},
-    },
-    concept::{Attribute, Concept, DateTimeAttribute, StringAttribute, Thing},
-    core, server,
-};
-
-const TEST_DATABASE: &str = "test";
-
-#[tokio::test(flavor = "multi_thread")]
-#[serial]
-async fn basic() {
-    let mut client = core::Client::with_default_address().await.unwrap();
-    create_test_database_with_schema(&mut client, "define person sub entity;").await.unwrap();
-    assert!(client.databases().contains(TEST_DATABASE).await.unwrap());
-
-    let session = client.session(TEST_DATABASE, Data).await.unwrap();
-    let mut transaction = session.transaction(Write).await.unwrap();
-    let mut answer_stream = transaction.query.match_("match $x sub thing;");
-    while let Some(result) = answer_stream.next().await {
-        assert!(result.is_ok())
-    }
-    transaction.commit().await.unwrap();
-}
-
-#[tokio::test(flavor = "multi_thread")]
-#[serial]
-async fn concurrent_queries() {
-    let mut client = core::Client::with_default_address().await.unwrap();
-    create_test_database_with_schema(&mut client, "define person sub entity;").await.unwrap();
-
-    let session = client.session(TEST_DATABASE, Data).await.unwrap();
-    let transaction = session.transaction(Write).await.unwrap();
-
-    let (sender, receiver) = mpsc::channel();
-
-    for _ in 0..5 {
-        let sender = sender.clone();
-        let mut transaction = transaction.clone();
-        tokio::spawn(async move {
-            for _ in 0..5 {
-                let mut answer_stream = transaction.query.match_("match $x sub thing;");
-                while let Some(result) = answer_stream.next().await {
-                    sender.send(result).unwrap();
-                }
-            }
-        });
-    }
-    drop(sender); // receiver expects data while any sender is live
-
-    for received in receiver {
-        assert!(received.is_ok());
-    }
-}
-
-#[tokio::test(flavor = "multi_thread")]
-#[serial]
-async fn query_options() {
-    let mut client = core::Client::with_default_address().await.unwrap();
-    let schema = r#"define
-        person sub entity,
-            owns name,
-            owns age;
-        name sub attribute, value string;
-        age sub attribute, value long;
-        rule age-rule: when { $x isa person; } then { $x has age 25; };"#;
-    create_test_database_with_schema(&mut client, schema).await.unwrap();
-
-    let session = client.session(TEST_DATABASE, Data).await.unwrap();
-    let mut transaction = session.transaction(Write).await.unwrap();
-    let data = "insert $x isa person, has name 'Alice'; $y isa person, has name 'Bob';";
-    let _ = transaction.query.insert(data);
-    transaction.commit().await.unwrap();
-
-    let mut transaction = session.transaction(Read).await.unwrap();
-    let age_count = transaction.query.match_aggregate("match $x isa age; count;").await.unwrap();
-    assert_eq!(age_count.into_i64(), 0);
-
-    let with_inference = core::Options::new_core().infer(true);
-    let mut transaction = session.transaction_with_options(Read, with_inference).await.unwrap();
-    let age_count = transaction.query.match_aggregate("match $x isa age; count;").await.unwrap();
-    assert_eq!(age_count.into_i64(), 1);
-}
-
-#[tokio::test(flavor = "multi_thread")]
-#[serial]
-async fn many_concept_types() {
-    let mut client = core::Client::with_default_address().await.unwrap();
-    let schema = r#"define
-        person sub entity,
-            owns name,
-            owns date-of-birth,
-            plays friendship:friend;
-        name sub attribute, value string;
-        date-of-birth sub attribute, value datetime;
-        friendship sub relation,
-            relates friend;"#;
-    create_test_database_with_schema(&mut client, schema).await.unwrap();
-
-    let session = client.session(TEST_DATABASE, Data).await.unwrap();
-    let mut transaction = session.transaction(Write).await.unwrap();
-    let data = r#"insert
-        $x isa person, has name "Alice", has date-of-birth 1994-10-03;
-        $y isa person, has name "Bob", has date-of-birth 1993-04-17;
-        (friend: $x, friend: $y) isa friendship;"#;
-    let _ = transaction.query.insert(data);
-    transaction.commit().await.unwrap();
-
-    let mut transaction = session.transaction(Read).await.unwrap();
-    let mut answer_stream = transaction.query.match_(
-        r#"match
-        $p isa person, has name $name, has date-of-birth $date-of-birth;
-        $f($role: $p) isa friendship;"#,
-    );
-
-    while let Some(result) = answer_stream.next().await {
-        assert!(result.is_ok());
-        let mut result = result.unwrap().map;
-        let name = unwrap_string(result.remove("name").unwrap());
-        let date_of_birth = unwrap_date_time(result.remove("date-of-birth").unwrap()).date();
-        match name.as_str() {
-            "Alice" => assert_eq!(date_of_birth, NaiveDate::from_ymd_opt(1994, 10, 3).unwrap()),
-            "Bob" => assert_eq!(date_of_birth, NaiveDate::from_ymd_opt(1993, 4, 17).unwrap()),
-            _ => unreachable!(),
-        }
-    }
-}
-
-#[tokio::test(flavor = "multi_thread")]
-#[serial]
-#[ignore]
-async fn streaming_perf() {
-    let mut client = core::Client::with_default_address().await.unwrap();
-    for i in 0..5 {
-        let schema = r#"define
-            person sub entity, owns name, owns age;
-            name sub attribute, value string;
-            age sub attribute, value long;"#;
-        create_test_database_with_schema(&mut client, schema).await.unwrap();
-
-        let start_time = Instant::now();
-        let session = client.session(TEST_DATABASE, Data).await.unwrap();
-        let mut transaction = session.transaction(Write).await.unwrap();
-        for j in 0..100_000 {
-            let _ = transaction.query.insert(format!("insert $x {j} isa age;").as_str());
-        }
-        transaction.commit().await.unwrap();
-        println!(
-            "iteration {i}: inserted and committed 100k attrs in {}ms",
-            (Instant::now() - start_time).as_millis()
-        );
-
-        let mut start_time = Instant::now();
-        let session = client.session(TEST_DATABASE, Data).await.unwrap();
-        let mut transaction = session.transaction(Read).await.unwrap();
-        let mut answer_stream = transaction.query.match_("match $x isa attribute;");
-        let mut sum: i64 = 0;
-        let mut idx = 0;
-        while let Some(result) = answer_stream.next().await {
-            match result {
-                Ok(concept_map) => {
-                    for (_, concept) in concept_map {
-                        if let Concept::Thing(Thing::Attribute(Attribute::Long(long_attr))) =
-                            concept
-                        {
-                            sum += long_attr.value
-                        }
-                    }
-                }
-                Err(err) => {
-                    panic!("An error occurred fetching answers of a Match query: {}", err)
-                }
-            }
-            idx = idx + 1;
-            if idx == 100_000 {
-                println!(
-                    "iteration {i}: retrieved and summed 100k attrs in {}ms",
-                    (Instant::now() - start_time).as_millis()
-                );
-                start_time = Instant::now();
-            }
-        }
-        println!("sum is {}", sum);
-    }
-}
-
-async fn create_test_database_with_schema(
-    client: &mut core::Client,
-    schema: &str,
-) -> typedb_client::Result {
-    if client.databases().contains(TEST_DATABASE).await.unwrap() {
-        client.databases().get(TEST_DATABASE).and_then(server::Database::delete).await.unwrap();
-    }
-    client.databases().create(TEST_DATABASE).await.unwrap();
-
-    let mut session = client.session(TEST_DATABASE, Schema).await.unwrap();
-    let mut transaction = session.transaction(Write).await.unwrap();
-    transaction.query.define(schema).await.unwrap();
-    transaction.commit().await.unwrap();
-    session.close().await;
-
-    Ok(())
-}
-
-// Concept helpers
-// FIXME should be removed after concept API is implemented
-fn unwrap_date_time(concept: Concept) -> NaiveDateTime {
-    match concept {
-        Concept::Thing(Thing::Attribute(Attribute::DateTime(DateTimeAttribute {
-            value, ..
-        }))) => value,
-        _ => unreachable!(),
-    }
-}
-
-fn unwrap_string(concept: Concept) -> String {
-    match concept {
-        Concept::Thing(Thing::Attribute(Attribute::String(StringAttribute { value, .. }))) => value,
-        _ => unreachable!(),
-    }
-}
diff --git a/tests/runtimes.rs b/tests/runtimes.rs
new file mode 100644
index 00000000..00943ed0
--- /dev/null
+++ b/tests/runtimes.rs
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2022 Vaticle
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+mod common;
+
+use futures::StreamExt;
+use serial_test::serial;
+use typedb_client::{DatabaseManager, Session, SessionType::Data, TransactionType::Write};
+
+#[test]
+#[serial]
+fn basic_async_std() {
+    async_std::task::block_on(async {
+        let connection = common::new_cluster_connection()?;
+        common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?;
+        let databases = DatabaseManager::new(connection);
+        assert!(databases.contains(common::TEST_DATABASE).await?);
+
+        let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?;
+        let transaction = session.transaction(Write).await?;
+        let answer_stream = transaction.query().match_("match $x sub thing;")?;
+        let results: Vec<_> = answer_stream.collect().await;
+        transaction.commit().await?;
+        assert_eq!(results.len(), 5);
+        assert!(results.into_iter().all(|res| res.is_ok()));
+        Ok::<(), typedb_client::Error>(())
+    })
+    .unwrap();
+}
+
+#[test]
+#[serial]
+fn basic_smol() {
+    smol::block_on(async {
+        let connection = common::new_cluster_connection()?;
+        common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?;
+        let databases = DatabaseManager::new(connection);
+        assert!(databases.contains(common::TEST_DATABASE).await?);
+
+        let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?;
+        let transaction = session.transaction(Write).await?;
+        let answer_stream = transaction.query().match_("match $x sub thing;")?;
+        let results: Vec<_> = answer_stream.collect().await;
+        transaction.commit().await?;
+        assert_eq!(results.len(), 5);
+        assert!(results.into_iter().all(|res| res.is_ok()));
+        Ok::<(), typedb_client::Error>(())
+    })
+    .unwrap();
+}
+
+#[test]
+#[serial]
+fn basic_futures() {
+    futures::executor::block_on(async {
+        let connection = common::new_cluster_connection()?;
+        common::create_test_database_with_schema(connection.clone(), "define person sub entity;").await?;
+        let databases = DatabaseManager::new(connection);
+        assert!(databases.contains(common::TEST_DATABASE).await?);
+
+        let session = Session::new(databases.get(common::TEST_DATABASE).await?, Data).await?;
+        let transaction = session.transaction(Write).await?;
+        let answer_stream = transaction.query().match_("match $x sub thing;")?;
+        let results: Vec<_> = answer_stream.collect().await;
+        transaction.commit().await?;
+        assert_eq!(results.len(), 5);
+        assert!(results.into_iter().all(|res| res.is_ok()));
+        Ok::<(), typedb_client::Error>(())
+    })
+    .unwrap();
+}