From dc8d574d3ee803ccc41e0419e8cedf36c4ea310c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Wed, 16 Oct 2024 13:50:50 +0200 Subject: [PATCH 01/14] Switch to reqwest and reqwest-websocket --- Cargo.toml | 16 +- src/account_manager.rs | 73 +++-- src/groups_v2/manager.rs | 18 +- src/messagepipe.rs | 27 -- src/push_service/account.rs | 47 +-- src/push_service/cdn.rs | 209 ++++-------- src/push_service/error.rs | 9 +- src/push_service/keys.rs | 113 ++++--- src/push_service/linking.rs | 27 +- src/push_service/mod.rs | 544 ++++++++----------------------- src/push_service/profile.rs | 44 +-- src/push_service/registration.rs | 166 +++++----- src/sender.rs | 3 +- src/websocket/mod.rs | 125 ++++--- src/websocket/tungstenite.rs | 202 ------------ 15 files changed, 573 insertions(+), 1050 deletions(-) delete mode 100644 src/websocket/tungstenite.rs diff --git a/Cargo.toml b/Cargo.toml index c3213bb4e..cc43f204a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,22 +35,14 @@ url = { version = "2.1", features = ["serde"] } uuid = { version = "1", features = ["serde"] } # http -hyper = "1.0" -hyper-util = { version = "0.1", features = ["client", "client-legacy"] } -hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "ring", "logging"] } -hyper-timeout = "0.5" -headers = "0.4" -http-body-util = "0.1" -mpart-async = "0.7" -async-tungstenite = { version = "0.27", features = ["tokio-rustls-native-certs", "url"] } -tokio = { version = "1.0", features = ["macros"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } - -rustls-pemfile = "2.0" +reqwest = { version = "0.12", default-features = false, features = ["json", "multipart", "rustls-tls-manual-roots", "stream"] } +reqwest-websocket = { version = "0.4.2", features = ["json"] } tracing = { version = "0.1", features = ["log"] } tracing-futures = "0.2" +tokio = { version = "1.0", features = ["macros"] } + [build-dependencies] prost-build = "0.13" diff --git a/src/account_manager.rs b/src/account_manager.rs index d8251c92b..84ade2573 100644 --- a/src/account_manager.rs +++ b/src/account_manager.rs @@ -1,5 +1,6 @@ use base64::prelude::*; use phonenumber::PhoneNumber; +use reqwest::Method; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; @@ -28,9 +29,9 @@ use crate::proto::sync_message::PniChangeNumber; use crate::proto::{DeviceName, SyncMessage}; use crate::provisioning::generate_registration_id; use crate::push_service::{ - AvatarWrite, DeviceActivationRequest, DeviceInfo, RecaptchaAttributes, - RegistrationMethod, ServiceIdType, VerifyAccountResponse, - DEFAULT_DEVICE_ID, + AvatarWrite, DeviceActivationRequest, DeviceInfo, HttpAuthOverride, + RecaptchaAttributes, RegistrationMethod, ReqwestExt, ServiceIdType, + VerifyAccountResponse, DEFAULT_DEVICE_ID, }; use crate::sender::OutgoingPushMessage; use crate::session_store::SessionStoreExt; @@ -44,9 +45,7 @@ use crate::{ profile_name::ProfileName, proto::{ProvisionEnvelope, ProvisionMessage, ProvisioningVersion}, provisioning::{ProvisioningCipher, ProvisioningError}, - push_service::{ - AccountAttributes, HttpAuthOverride, PushService, ServiceError, - }, + push_service::{AccountAttributes, PushService, ServiceError}, utils::serde_base64, }; @@ -224,13 +223,17 @@ impl AccountManager { let dc: DeviceCode = self .service - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/devices/provisioning/code", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; + Ok(dc.verification_code) } @@ -247,16 +250,19 @@ impl AccountManager { let body = env.encode_to_vec(); self.service - .put_json( + .request( + Method::PUT, Endpoint::Service, &format!("/v1/provisioning/{}", destination), - &[], HttpAuthOverride::NoOverride, - &ProvisioningMessage { - body: BASE64_RELAXED.encode(body), - }, - ) - .await + )? + .json(&ProvisioningMessage { + body: BASE64_RELAXED.encode(body), + }) + .send_to_signal() + .await?; + + Ok(()) } /// Link a new device, given a tsurl. @@ -582,15 +588,16 @@ impl AccountManager { } self.service - .put_json::<(), _>( + .request( + Method::PUT, Endpoint::Service, "/v1/accounts/name", - &[], HttpAuthOverride::NoOverride, - Data { - device_name: encrypted_device_name.encode_to_vec(), - }, - ) + )? + .json(&Data { + device_name: encrypted_device_name.encode_to_vec(), + }) + .send_to_signal() .await?; Ok(()) @@ -607,20 +614,22 @@ impl AccountManager { token: &str, captcha: &str, ) -> Result<(), ServiceError> { - let payload = RecaptchaAttributes { - r#type: String::from("recaptcha"), - token: String::from(token), - captcha: String::from(captcha), - }; self.service - .put_json( + .request( + Method::PUT, Endpoint::Service, "/v1/challenge", - &[], HttpAuthOverride::NoOverride, - payload, - ) - .await + )? + .json(&RecaptchaAttributes { + r#type: String::from("recaptcha"), + token: String::from(token), + captcha: String::from(captcha), + }) + .send_to_signal() + .await?; + + Ok(()) } /// Initialize PNI on linked devices. diff --git a/src/groups_v2/manager.rs b/src/groups_v2/manager.rs index d5e77dffd..9183bef10 100644 --- a/src/groups_v2/manager.rs +++ b/src/groups_v2/manager.rs @@ -2,11 +2,13 @@ use std::{collections::HashMap, convert::TryInto}; use crate::{ configuration::Endpoint, - groups_v2::model::{Group, GroupChanges}, - groups_v2::operations::{GroupDecodingError, GroupOperations}, + groups_v2::{ + model::{Group, GroupChanges}, + operations::{GroupDecodingError, GroupOperations}, + }, prelude::{PushService, ServiceError}, proto::GroupContextV2, - push_service::{HttpAuth, HttpAuthOverride, ServiceIds}, + push_service::{HttpAuth, HttpAuthOverride, ReqwestExt, ServiceIds}, utils::BASE64_RELAXED, }; @@ -15,6 +17,7 @@ use bytes::Bytes; use chrono::{Days, NaiveDate, NaiveTime, Utc}; use futures::AsyncReadExt; use rand::RngCore; +use reqwest::Method; use serde::Deserialize; use zkgroup::{ auth::AuthCredentialWithPniResponse, @@ -165,12 +168,15 @@ impl GroupsManager { let credentials_response: CredentialResponse = self .push_service - .get_json( + .request( + Method::GET, Endpoint::Service, &path, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; self.credentials_cache .write(credentials_response.parse()?)?; diff --git a/src/messagepipe.rs b/src/messagepipe.rs index cdb77b7a8..f3e0c67c5 100644 --- a/src/messagepipe.rs +++ b/src/messagepipe.rs @@ -1,11 +1,9 @@ -use bytes::Bytes; use futures::{ channel::{ mpsc::{self, Sender}, oneshot, }, prelude::*, - stream::FusedStream, }; pub use crate::{ @@ -18,24 +16,12 @@ pub use crate::{ use crate::{push_service::ServiceError, websocket::SignalWebSocket}; -pub enum WebSocketStreamItem { - Message(Bytes), - KeepAliveRequest, -} - #[derive(Debug)] pub enum Incoming { Envelope(Envelope), QueueEmpty, } -#[async_trait::async_trait] -pub trait WebSocketService { - type Stream: FusedStream + Unpin; - - async fn send_message(&mut self, msg: Bytes) -> Result<(), ServiceError>; -} - pub struct MessagePipe { ws: SignalWebSocket, credentials: ServiceCredentials, @@ -133,16 +119,3 @@ impl MessagePipe { combined.filter_map(|x| async { x }) } } - -/// WebSocketService that panics on every request, mainly for example code. -pub struct PanicingWebSocketService; - -#[allow(clippy::diverging_sub_expression)] -#[async_trait::async_trait] -impl WebSocketService for PanicingWebSocketService { - type Stream = futures::channel::mpsc::Receiver; - - async fn send_message(&mut self, _msg: Bytes) -> Result<(), ServiceError> { - todo!(); - } -} diff --git a/src/push_service/account.rs b/src/push_service/account.rs index d36eff544..293e0857e 100644 --- a/src/push_service/account.rs +++ b/src/push_service/account.rs @@ -2,10 +2,11 @@ use std::fmt; use chrono::{DateTime, Utc}; use phonenumber::PhoneNumber; +use reqwest::Method; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::{HttpAuthOverride, PushService, ServiceError}; +use super::{HttpAuthOverride, PushService, ReqwestExt, ServiceError}; use crate::{ configuration::Endpoint, utils::{serde_optional_base64, serde_phone_number}, @@ -127,13 +128,17 @@ pub struct WhoAmIResponse { impl PushService { /// Method used to check our own UUID pub async fn whoami(&mut self) -> Result { - self.get_json( + self.request( + Method::GET, Endpoint::Service, "/v1/accounts/whoami", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await + .map_err(Into::into) } /// Fetches a list of all devices tied to the authenticated account. @@ -146,12 +151,15 @@ impl PushService { } let devices: DeviceInfoList = self - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/devices/", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; Ok(devices.devices) @@ -166,18 +174,17 @@ impl PushService { "only one of PIN and registration lock can be set." ); - match self - .put_json( - Endpoint::Service, - "/v1/accounts/attributes/", - &[], - HttpAuthOverride::NoOverride, - attributes, - ) - .await - { - Err(ServiceError::JsonDecodeError { .. }) => Ok(()), - r => r, - } + self.request( + Method::PUT, + Endpoint::Service, + "/v1/accounts/attributes/", + HttpAuthOverride::NoOverride, + )? + .json(&attributes) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index 3ba962a6d..8309807b2 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -1,19 +1,15 @@ use std::io::{self, Read}; -use bytes::Bytes; -use futures::{FutureExt, StreamExt, TryStreamExt}; -use http_body_util::BodyExt; -use hyper::Method; +use futures::TryStreamExt; +use reqwest::{multipart::Part, Method}; use tracing::debug; use crate::{ - configuration::Endpoint, - prelude::AttachmentIdentifier, - proto::AttachmentPointer, - push_service::{HttpAuthOverride, RequestBody}, + configuration::Endpoint, prelude::AttachmentIdentifier, + proto::AttachmentPointer, push_service::HttpAuthOverride, }; -use super::{PushService, ServiceError}; +use super::{PushService, ReqwestExt, ServiceError}; #[derive(Debug, serde::Deserialize, Default)] #[serde(rename_all = "camelCase")] @@ -32,170 +28,109 @@ pub struct AttachmentV2UploadAttributes { } impl PushService { + pub async fn get_attachment( + &mut self, + ptr: &AttachmentPointer, + ) -> Result { + let id = match ptr.attachment_identifier.as_ref().unwrap() { + AttachmentIdentifier::CdnId(id) => &id.to_string(), + AttachmentIdentifier::CdnKey(key) => key, + }; + self.get_from_cdn(ptr.cdn_number(), &format!("attachments/{}", id)) + .await + } + #[tracing::instrument(skip(self))] pub(crate) async fn get_from_cdn( &mut self, cdn_id: u32, path: &str, ) -> Result { - let response = self + let response_stream = self .request( Method::GET, Endpoint::Cdn(cdn_id), path, - &[], HttpAuthOverride::Unidentified, // CDN requests are always without authentication - None, - ) - .await?; - - Ok(Box::new( - response - .into_body() - .into_data_stream() - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - .into_async_read(), - )) - } - - pub async fn get_attachment_by_id( - &mut self, - id: &str, - cdn_id: u32, - ) -> Result { - let path = format!("attachments/{}", id); - self.get_from_cdn(cdn_id, &path).await + )? + .send_to_signal() + .await? + .bytes_stream() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + .into_async_read(); + + Ok(response_stream) } - pub async fn get_attachment( - &mut self, - ptr: &AttachmentPointer, - ) -> Result { - match ptr.attachment_identifier.as_ref().unwrap() { - AttachmentIdentifier::CdnId(id) => { - // cdn_number did not exist for this part of the protocol. - // cdn_number(), however, returns 0 when the field does not - // exist. - self.get_attachment_by_id(&format!("{}", id), ptr.cdn_number()) - .await - }, - AttachmentIdentifier::CdnKey(key) => { - self.get_attachment_by_id(key, ptr.cdn_number()).await - }, - } - } - - pub async fn get_attachment_v2_upload_attributes( + pub(crate) async fn get_attachment_v2_upload_attributes( &mut self, ) -> Result { - self.get_json( + self.request( + Method::GET, Endpoint::Service, "/v2/attachments/form/upload", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await + .map_err(Into::into) } /// Upload attachment to CDN /// /// Returns attachment ID and the attachment digest - pub async fn upload_attachment<'s, C>( + pub async fn upload_attachment( &mut self, - attrs: &AttachmentV2UploadAttributes, - content: &'s mut C, - ) -> Result<(u64, Vec), ServiceError> - where - C: std::io::Read + Send + 's, - { - let values = [ - ("acl", &attrs.acl as &str), - ("key", &attrs.key), - ("policy", &attrs.policy), - ("Content-Type", "application/octet-stream"), - ("x-amz-algorithm", &attrs.algorithm), - ("x-amz-credential", &attrs.credential), - ("x-amz-date", &attrs.date), - ("x-amz-signature", &attrs.signature), - ]; - - let mut digester = crate::digeststream::DigestingReader::new(content); - - self.post_to_cdn0( - "attachments/", - &values, - Some(("file", &mut digester)), - ) - .await?; - Ok((attrs.attachment_id, digester.finalize())) + attrs: AttachmentV2UploadAttributes, + mut reader: impl Read + Send, + ) -> Result<(u64, Vec), ServiceError> { + let attachment_id = attrs.attachment_id; + let mut digester = + crate::digeststream::DigestingReader::new(&mut reader); + + self.post_to_cdn0("attachments/", attrs, "file".into(), &mut digester) + .await?; + + Ok((attachment_id, digester.finalize())) } - #[tracing::instrument(skip(self, value, file), fields(file = file.as_ref().map(|_| "")))] - pub async fn post_to_cdn0<'s, C>( + #[tracing::instrument(skip(self, upload_attributes, reader))] + pub async fn post_to_cdn0( &mut self, path: &str, - value: &[(&str, &str)], - file: Option<(&str, &'s mut C)>, - ) -> Result<(), ServiceError> - where - C: Read + Send + 's, - { - let mut form = mpart_async::client::MultipartRequest::default(); - - // mpart-async has a peculiar ordering of the form items, - // and Amazon S3 expects them in a very specific order (i.e., the file contents should - // go last. - // - // mpart-async uses a VecDeque internally for ordering the fields in the order given. - // - // https://github.com/cetra3/mpart-async/issues/16 - - for &(k, v) in value { - form.add_field(k, v); - } - - if let Some((filename, file)) = file { - // XXX Actix doesn't cope with none-'static lifetimes - // https://docs.rs/actix-web/3.2.0/actix_web/body/enum.Body.html - let mut buf = Vec::new(); - file.read_to_end(&mut buf) - .expect("infallible Read instance"); - form.add_stream( - "file", - filename, - "application/octet-stream", - futures::future::ok::<_, ()>(Bytes::from(buf)).into_stream(), - ); - } - - let content_type = - format!("multipart/form-data; boundary={}", form.get_boundary()); - - // XXX Amazon S3 needs the Content-Length, but we don't know it without depleting the whole - // stream. Sadly, Content-Length != contents.len(), but should include the whole form. - let mut body_contents = vec![]; - while let Some(b) = form.next().await { - // Unwrap, because no error type was used above - body_contents.extend(b.unwrap()); - } - tracing::trace!( - "Sending PUT with Content-Type={} and length {}", - content_type, - body_contents.len() - ); + upload_attributes: AttachmentV2UploadAttributes, + filename: String, + mut reader: impl Read + Send, + ) -> Result<(), ServiceError> { + let mut form = reqwest::multipart::Form::new(); + form = form.text("acl", upload_attributes.acl); + form = form.text("key", upload_attributes.key); + form = form.text("policy", upload_attributes.policy); + form = form.text("x-amz-algorithm", upload_attributes.algorithm); + form = form.text("x-amz-credential", upload_attributes.credential); + form = form.text("x-amz-date", upload_attributes.date); + form = form.text("x-amz-signature", upload_attributes.signature); + + let mut buf = Vec::new(); + reader + .read_to_end(&mut buf) + .expect("infallible Read instance"); + + form = form.text("Content-Type", "application/octet-stream"); + form = form.text("Content-Length", buf.len().to_string()); + form = form.part("file", Part::bytes(buf).file_name(filename)); let response = self .request( Method::POST, Endpoint::Cdn(0), path, - &[], HttpAuthOverride::NoOverride, - Some(RequestBody { - contents: body_contents, - content_type, - }), - ) + )? + .multipart(form) + .send_to_signal() .await?; debug!("HyperPushService::PUT response: {:?}", response); diff --git a/src/push_service/error.rs b/src/push_service/error.rs index d197667d9..2e96f27c7 100644 --- a/src/push_service/error.rs +++ b/src/push_service/error.rs @@ -39,10 +39,10 @@ pub enum ServiceError { #[error("Unexpected response: HTTP {http_code}")] UnhandledResponseCode { http_code: u16 }, - #[error("Websocket error: {reason}")] - WsError { reason: String }, + #[error("Websocket error: {0}")] + WsError(#[from] reqwest_websocket::Error), #[error("Websocket closing: {reason}")] - WsClosing { reason: String }, + WsClosing { reason: &'static str }, #[error("Invalid frame: {reason}")] InvalidFrameError { reason: String }, @@ -85,4 +85,7 @@ pub enum ServiceError { #[error("invalid device name")] InvalidDeviceName, + + #[error("HTTP reqwest error: {0}")] + Http(#[from] reqwest::Error), } diff --git a/src/push_service/keys.rs b/src/push_service/keys.rs index 50894d1f2..93b7d6636 100644 --- a/src/push_service/keys.rs +++ b/src/push_service/keys.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use libsignal_protocol::{IdentityKey, PreKeyBundle, SenderCertificate}; +use reqwest::Method; use serde::Deserialize; use crate::{ @@ -13,8 +14,8 @@ use crate::{ }; use super::{ - HttpAuthOverride, PushService, SenderCertificateJson, ServiceError, - ServiceIdType, VerifyAccountResponse, + HttpAuthOverride, PushService, ReqwestExt, SenderCertificateJson, + ServiceError, ServiceIdType, VerifyAccountResponse, }; #[derive(Debug, Deserialize, Default)] @@ -29,13 +30,17 @@ impl PushService { &mut self, service_id_type: ServiceIdType, ) -> Result { - self.get_json( + self.request( + Method::GET, Endpoint::Service, &format!("/v2/keys?identity={}", service_id_type), - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await + .map_err(Into::into) } pub async fn register_pre_keys( @@ -43,19 +48,17 @@ impl PushService { service_id_type: ServiceIdType, pre_key_state: PreKeyState, ) -> Result<(), ServiceError> { - match self - .put_json( - Endpoint::Service, - &format!("/v2/keys?identity={}", service_id_type), - &[], - HttpAuthOverride::NoOverride, - pre_key_state, - ) - .await - { - Err(ServiceError::JsonDecodeError { .. }) => Ok(()), - r => r, - } + self.request( + Method::PUT, + Endpoint::Service, + &format!("/v2/keys?identity={}", service_id_type), + HttpAuthOverride::NoOverride, + )? + .json(&pre_key_state) + .send_to_signal() + .await?; + + Ok(()) } pub async fn get_pre_key( @@ -67,13 +70,17 @@ impl PushService { format!("/v2/keys/{}/{}?pq=true", destination.uuid, device_id); let mut pre_key_response: PreKeyResponse = self - .get_json( + .request( + Method::GET, Endpoint::Service, &path, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; + assert!(!pre_key_response.devices.is_empty()); let identity = IdentityKey::decode(&pre_key_response.identity_key)?; @@ -92,12 +99,15 @@ impl PushService { format!("/v2/keys/{}/{}?pq=true", destination.uuid, device_id) }; let pre_key_response: PreKeyResponse = self - .get_json( + .request( + Method::GET, Endpoint::Service, &path, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; let mut pre_keys = vec![]; let identity = IdentityKey::decode(&pre_key_response.identity_key)?; @@ -111,12 +121,15 @@ impl PushService { &mut self, ) -> Result { let cert: SenderCertificateJson = self - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/certificate/delivery", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; Ok(SenderCertificate::deserialize(&cert.certificate)?) } @@ -125,12 +138,15 @@ impl PushService { &mut self, ) -> Result { let cert: SenderCertificateJson = self - .get_json( + .request( + Method::GET, Endpoint::Service, "/v1/certificate/delivery?includeE164=false", - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await?; Ok(SenderCertificate::deserialize(&cert.certificate)?) } @@ -160,23 +176,24 @@ impl PushService { pni_registration_ids: HashMap, signature_valid_on_each_signed_pre_key: bool, } - - let res: VerifyAccountResponse = self - .put_json( - Endpoint::Service, - "/v2/accounts/phone_number_identity_key_distribution", - &[], - HttpAuthOverride::NoOverride, - PniKeyDistributionRequest { - pni_identity_key: pni_identity_key.serialize().into(), - device_messages, - device_pni_signed_prekeys, - device_pni_last_resort_kyber_prekeys, - pni_registration_ids, - signature_valid_on_each_signed_pre_key, - }, - ) - .await?; - Ok(res) + self.request( + Method::PUT, + Endpoint::Service, + "/v2/accounts/phone_number_identity_key_distribution", + HttpAuthOverride::NoOverride, + )? + .json(&PniKeyDistributionRequest { + pni_identity_key: pni_identity_key.serialize().into(), + device_messages, + device_pni_signed_prekeys, + device_pni_last_resort_kyber_prekeys, + pni_registration_ids, + signature_valid_on_each_signed_pre_key, + }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/push_service/linking.rs b/src/push_service/linking.rs index 5d9b9ded8..d10f20116 100644 --- a/src/push_service/linking.rs +++ b/src/push_service/linking.rs @@ -1,3 +1,4 @@ +use reqwest::Method; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -5,7 +6,7 @@ use crate::configuration::Endpoint; use super::{ DeviceActivationRequest, HttpAuth, HttpAuthOverride, PushService, - ServiceError, + ReqwestExt, ServiceError, }; #[derive(Debug, Serialize)] @@ -59,18 +60,30 @@ impl PushService { link_request: &LinkRequest, http_auth: HttpAuth, ) -> Result { - self.put_json( + self.request( + Method::PUT, Endpoint::Service, "/v1/devices/link", - &[], HttpAuthOverride::Identified(http_auth), - link_request, - ) + )? + .json(&link_request) + .send_to_signal() + .await? + .json() .await + .map_err(Into::into) } pub async fn unlink_device(&mut self, id: i64) -> Result<(), ServiceError> { - self.delete_json(Endpoint::Service, &format!("/v1/devices/{}", id), &[]) - .await + self.request( + Method::DELETE, + Endpoint::Service, + format!("/v1/devices/{}", id), + HttpAuthOverride::NoOverride, + )? + .send_to_signal() + .await?; + + Ok(()) } } diff --git a/src/push_service/mod.rs b/src/push_service/mod.rs index f839f185b..daf683617 100644 --- a/src/push_service/mod.rs +++ b/src/push_service/mod.rs @@ -1,36 +1,23 @@ -use std::{io, time::Duration}; +use std::time::Duration; use crate::{ configuration::{Endpoint, ServiceCredentials}, pre_keys::{KyberPreKeyEntity, PreKeyEntity, SignedPreKeyEntity}, prelude::ServiceConfiguration, utils::serde_base64, - websocket::{tungstenite::TungsteniteWebSocket, SignalWebSocket}, + websocket::SignalWebSocket, }; -use bytes::{Buf, Bytes}; use derivative::Derivative; -use headers::{Authorization, HeaderMapExt}; -use http_body_util::{BodyExt, Full}; -use hyper::{ - body::Incoming, - header::{CONTENT_LENGTH, CONTENT_TYPE, USER_AGENT}, - Method, Request, Response, StatusCode, -}; -use hyper_rustls::HttpsConnector; -use hyper_timeout::TimeoutConnector; -use hyper_util::{ - client::legacy::{connect::HttpConnector, Client}, - rt::TokioExecutor, -}; use libsignal_protocol::{ error::SignalProtocolError, kem::{Key, Public}, IdentityKey, PreKeyBundle, PublicKey, }; -use prost::Message as ProtobufMessage; +use protobuf::ProtobufResponseExt; +use reqwest::{Method, RequestBuilder, Response, StatusCode}; +use reqwest_websocket::RequestBuilderExt; use serde::{Deserialize, Serialize}; -use tokio_rustls::rustls; use tracing::{debug_span, Instrument}; pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55); @@ -67,16 +54,6 @@ pub struct HttpAuth { pub password: String, } -/// This type is used in registration lock handling. -/// It's identical with HttpAuth, but used to avoid type confusion. -#[derive(Derivative, Clone, Serialize, Deserialize)] -#[derivative(Debug)] -pub struct AuthCredentials { - pub username: String, - #[derivative(Debug = "ignore")] - pub password: String, -} - #[derive(Debug, Clone)] pub enum HttpAuthOverride { NoOverride, @@ -164,127 +141,83 @@ pub struct StaleDevices { pub stale_devices: Vec, } -#[derive(Debug)] -struct RequestBody { - contents: Vec, - content_type: String, -} - #[derive(Clone)] pub struct PushService { cfg: ServiceConfiguration, - user_agent: String, credentials: Option, - client: - Client>, Full>, + client: reqwest::Client, } impl PushService { pub fn new( cfg: impl Into, credentials: Option, - user_agent: String, + user_agent: impl AsRef, ) -> Self { let cfg = cfg.into(); - let tls_config = Self::tls_config(&cfg); - - let https = hyper_rustls::HttpsConnectorBuilder::new() - .with_tls_config(tls_config) - .https_only() - .enable_http1() - .build(); - - // as in Signal-Android - let mut timeout_connector = TimeoutConnector::new(https); - timeout_connector.set_connect_timeout(Some(Duration::from_secs(10))); - timeout_connector.set_read_timeout(Some(Duration::from_secs(65))); - timeout_connector.set_write_timeout(Some(Duration::from_secs(65))); - - let client: Client<_, Full> = - Client::builder(TokioExecutor::new()).build(timeout_connector); + let client = reqwest::ClientBuilder::new() + .add_root_certificate( + reqwest::Certificate::from_pem( + &cfg.certificate_authority.as_bytes(), + ) + .unwrap(), + ) + .connect_timeout(Duration::from_secs(10)) + .timeout(Duration::from_secs(65)) + .user_agent(user_agent.as_ref()) + .build() + .unwrap(); Self { cfg, credentials: credentials.and_then(|c| c.authorization()), client, - user_agent, } } - fn tls_config(cfg: &ServiceConfiguration) -> rustls::ClientConfig { - let mut cert_bytes = io::Cursor::new(&cfg.certificate_authority); - let roots = rustls_pemfile::certs(&mut cert_bytes); - - let mut root_certs = rustls::RootCertStore::empty(); - root_certs.add_parsable_certificates( - roots.map(|c| c.expect("parsable PEM files")), - ); - - rustls::ClientConfig::builder() - .with_root_certificates(root_certs) - .with_no_client_auth() - } - - #[tracing::instrument(skip(self, path, body), fields(path = %path.as_ref()))] - async fn request( + #[tracing::instrument(skip(self, path), fields(path = %path.as_ref()))] + pub fn request( &self, method: Method, endpoint: Endpoint, path: impl AsRef, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - body: Option, - ) -> Result, ServiceError> { + auth_override: HttpAuthOverride, + ) -> Result { let url = self.cfg.base_url(endpoint).join(path.as_ref())?; - let mut builder = Request::builder() - .method(method) - .uri(url.as_str()) - .header(USER_AGENT, &self.user_agent); - - for (header, value) in additional_headers { - builder = builder.header(*header, *value); - } + let mut builder = self.client.request(method, url); - match credentials_override { + builder = match auth_override { HttpAuthOverride::NoOverride => { if let Some(HttpAuth { username, password }) = self.credentials.as_ref() { + builder.basic_auth(username, Some(password)) + } else { builder - .headers_mut() - .unwrap() - .typed_insert(Authorization::basic(username, password)); } }, HttpAuthOverride::Identified(HttpAuth { username, password }) => { - builder - .headers_mut() - .unwrap() - .typed_insert(Authorization::basic(&username, &password)); + builder.basic_auth(username, Some(password)) }, - HttpAuthOverride::Unidentified => (), + HttpAuthOverride::Unidentified => builder, }; - let request = if let Some(RequestBody { - contents, - content_type, - }) = body - { - builder - .header(CONTENT_LENGTH, contents.len() as u64) - .header(CONTENT_TYPE, content_type) - .body(Full::new(Bytes::from(contents))) - .unwrap() - } else { - builder.body(Full::default()).unwrap() - }; + Ok(builder) + } +} - let mut response = self.client.request(request).await.map_err(|e| { - ServiceError::SendError { - reason: e.to_string(), - } - })?; +#[async_trait::async_trait] +pub(crate) trait ReqwestExt +where + Self: Sized, +{ + async fn send_to_signal(self) -> Result; +} +#[async_trait::async_trait] +impl ReqwestExt for RequestBuilder { + async fn send_to_signal(self) -> Result { + let response = self.send().await?; match response.status() { StatusCode::OK => Ok(response), StatusCode::NO_CONTENT => Ok(response), @@ -301,10 +234,10 @@ impl PushService { }, StatusCode::CONFLICT => { let mismatched_devices = - Self::json(&mut response).await.map_err(|e| { + response.json().await.map_err(|error| { tracing::error!( - "Failed to decode HTTP 409 response: {}", - e + %error, + "failed to decode HTTP 409 status" ); ServiceError::UnhandledResponseCode { http_code: StatusCode::CONFLICT.as_u16(), @@ -315,25 +248,17 @@ impl PushService { )) }, StatusCode::GONE => { - let stale_devices = - Self::json(&mut response).await.map_err(|e| { - tracing::error!( - "Failed to decode HTTP 410 response: {}", - e - ); - ServiceError::UnhandledResponseCode { - http_code: StatusCode::GONE.as_u16(), - } - })?; + let stale_devices = response.json().await.map_err(|error| { + tracing::error!(%error, "failed to decode HTTP 410 status"); + ServiceError::UnhandledResponseCode { + http_code: StatusCode::GONE.as_u16(), + } + })?; Err(ServiceError::StaleDevices(stale_devices)) }, StatusCode::LOCKED => { - let locked = Self::json(&mut response).await.map_err(|e| { - tracing::error!( - ?response, - "Failed to decode HTTP 423 response: {}", - e - ); + let locked = response.json().await.map_err(|error| { + tracing::error!(%error, "failed to decode HTTP 423 status"); ServiceError::UnhandledResponseCode { http_code: StatusCode::LOCKED.as_u16(), } @@ -342,10 +267,10 @@ impl PushService { }, StatusCode::PRECONDITION_REQUIRED => { let proof_required = - Self::json(&mut response).await.map_err(|e| { + response.json().await.map_err(|error| { tracing::error!( - "Failed to decode HTTP 428 response: {}", - e + %error, + "failed to decode HTTP 428 status" ); ServiceError::UnhandledResponseCode { http_code: StatusCode::PRECONDITION_REQUIRED @@ -356,306 +281,107 @@ impl PushService { }, // XXX: fill in rest from PushServiceSocket code => { - tracing::trace!( - "Unhandled response {} with body: {}", - code.as_u16(), - Self::text(&mut response).await?, - ); + let response_text = response.text().await?; + tracing::trace!(status_code =% code, body = response_text, "unhandled HTTP response"); Err(ServiceError::UnhandledResponseCode { http_code: code.as_u16(), }) }, } } - - async fn body( - response: &mut Response, - ) -> Result { - Ok(response - .collect() - .await - .map_err(|e| ServiceError::ResponseError { - reason: format!("failed to aggregate HTTP response body: {e}"), - })? - .aggregate()) - } - - #[tracing::instrument(skip(response), fields(status = %response.status()))] - async fn json( - response: &mut Response, - ) -> Result - where - for<'de> T: Deserialize<'de>, - { - let body = Self::body(response).await?; - - if body.has_remaining() { - serde_json::from_reader(body.reader()) - } else { - serde_json::from_value(serde_json::Value::Null) - } - .map_err(|e| ServiceError::JsonDecodeError { - reason: e.to_string(), - }) - } - - #[tracing::instrument(skip(response), fields(status = %response.status()))] - async fn protobuf( - response: &mut Response, - ) -> Result - where - M: ProtobufMessage + Default, - { - let body = Self::body(response).await?; - M::decode(body).map_err(ServiceError::ProtobufDecodeError) - } - - #[tracing::instrument(skip(response), fields(status = %response.status()))] - async fn text( - response: &mut Response, - ) -> Result { - let body = Self::body(response).await?; - io::read_to_string(body.reader()).map_err(|e| { - ServiceError::ResponseError { - reason: format!("failed to read HTTP response body: {e}"), - } - }) - } } -impl PushService { - #[tracing::instrument(skip(self))] - pub(crate) async fn get_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - ) -> Result - where - for<'de> T: Deserialize<'de>, - { - let mut response = self - .request( - Method::GET, - service, - path, - additional_headers, - credentials_override, - None, - ) - .await?; +pub(crate) mod protobuf { + use async_trait::async_trait; + use prost::{EncodeError, Message}; + use reqwest::{header, RequestBuilder, Response}; - Self::json(&mut response).await - } + use super::ServiceError; - #[tracing::instrument(skip(self))] - async fn delete_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - ) -> Result + pub(crate) trait ProtobufRequestBuilderExt where - for<'de> T: Deserialize<'de>, + Self: Sized, { - let mut response = self - .request( - Method::DELETE, - service, - path, - additional_headers, - HttpAuthOverride::NoOverride, - None, - ) - .await?; - - Self::json(&mut response).await - } - - #[tracing::instrument(skip(self, value))] - pub async fn put_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Send + Serialize, - { - let json = serde_json::to_vec(&value).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - })?; - - let mut response = self - .request( - Method::PUT, - service, - path, - additional_headers, - credentials_override, - Some(RequestBody { - contents: json, - content_type: "application/json".into(), - }), - ) - .await?; - - Self::json(&mut response).await + /// Set the request payload encoded as protobuf. + /// Sets the `Content-Type` header to `application/protobuf` + #[allow(dead_code)] + fn protobuf( + self, + value: T, + ) -> Result; } - #[tracing::instrument(skip(self, value))] - async fn patch_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Send + Serialize, - { - let json = serde_json::to_vec(&value).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - })?; - - let mut response = self - .request( - Method::PATCH, - service, - path, - additional_headers, - credentials_override, - Some(RequestBody { - contents: json, - content_type: "application/json".into(), - }), - ) - .await?; - - Self::json(&mut response).await + #[async_trait::async_trait] + pub(crate) trait ProtobufResponseExt { + /// Get the response body decoded from Protobuf + async fn protobuf( + self, + ) -> Result; } - #[tracing::instrument(skip(self, value))] - async fn post_json( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Send + Serialize, - { - let json = serde_json::to_vec(&value).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - })?; - - let mut response = self - .request( - Method::POST, - service, - path, - additional_headers, - credentials_override, - Some(RequestBody { - contents: json, - content_type: "application/json".into(), - }), - ) - .await?; - - Self::json(&mut response).await - } - - #[tracing::instrument(skip(self))] - async fn get_protobuf( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - credentials_override: HttpAuthOverride, - ) -> Result - where - T: Default + ProtobufMessage, - { - let mut response = self - .request( - Method::GET, - service, - path, - additional_headers, - credentials_override, - None, - ) - .await?; - - Self::protobuf(&mut response).await + impl ProtobufRequestBuilderExt for RequestBuilder { + fn protobuf( + self, + value: T, + ) -> Result { + let mut buf = Vec::new(); + value.encode(&mut buf)?; + let this = + self.header(header::CONTENT_TYPE, "application/protobuf"); + Ok(this.body(buf)) + } } - #[tracing::instrument(skip(self, value))] - async fn put_protobuf( - &mut self, - service: Endpoint, - path: &str, - additional_headers: &[(&str, &str)], - value: S, - ) -> Result - where - D: Default + ProtobufMessage, - S: Sized + ProtobufMessage, - { - let protobuf = value.encode_to_vec(); - - let mut response = self - .request( - Method::PUT, - service, - path, - additional_headers, - HttpAuthOverride::NoOverride, - Some(RequestBody { - contents: protobuf, - content_type: "application/x-protobuf".into(), - }), - ) - .await?; - - Self::protobuf(&mut response).await + #[async_trait] + impl ProtobufResponseExt for Response { + async fn protobuf( + self, + ) -> Result { + let body = self.bytes().await?; + let decoded = T::decode(body)?; + Ok(decoded) + } } +} +impl PushService { pub async fn ws( &mut self, path: &str, keepalive_path: &str, - additional_headers: &[(&str, &str)], + additional_headers: &[(&'static str, &str)], credentials: Option, ) -> Result { let span = debug_span!("websocket"); - let (ws, stream) = TungsteniteWebSocket::with_tls_config( - Self::tls_config(&self.cfg), - self.cfg.base_url(Endpoint::Service), - path, - additional_headers, - credentials.as_ref(), - ) - .instrument(span.clone()) - .await?; + + let endpoint = self.cfg.base_url(Endpoint::Service); + let mut url = endpoint.join(path).expect("valid url"); + url.set_scheme("wss").expect("valid https base url"); + + if let Some(credentials) = credentials { + url.query_pairs_mut() + .append_pair("login", &credentials.login()) + .append_pair( + "password", + credentials.password.as_ref().expect("a password"), + ); + } + + let mut builder = self.client.get(url); + for (key, value) in additional_headers { + builder = builder.header(*key, *value); + } + + let ws = builder + .upgrade() + .send() + .await? + .into_websocket() + .instrument(span.clone()) + .await?; + let (ws, task) = - SignalWebSocket::from_socket(ws, stream, keepalive_path.to_owned()); + SignalWebSocket::from_socket(ws, keepalive_path.to_owned()); let task = task.instrument(span); tokio::task::spawn(task); Ok(ws) @@ -665,13 +391,17 @@ impl PushService { &mut self, credentials: HttpAuth, ) -> Result { - self.get_protobuf( + self.request( + Method::GET, Endpoint::Storage, "/v1/groups/", - &[], HttpAuthOverride::Identified(credentials), - ) + )? + .send() + .await? + .protobuf() .await + .map_err(Into::into) } } diff --git a/src/push_service/profile.rs b/src/push_service/profile.rs index 0a444fef6..b05aef56c 100644 --- a/src/push_service/profile.rs +++ b/src/push_service/profile.rs @@ -1,3 +1,4 @@ +use reqwest::Method; use serde::{Deserialize, Serialize}; use zkgroup::profiles::{ProfileKeyCommitment, ProfileKeyVersion}; @@ -5,12 +6,12 @@ use crate::{ configuration::Endpoint, content::ServiceError, profile_cipher::ProfileCipherError, - push_service::{AvatarWrite, HttpAuthOverride}, + push_service::AvatarWrite, utils::{serde_base64, serde_optional_base64}, Profile, ServiceAddress, }; -use super::{DeviceCapabilities, PushService}; +use super::{DeviceCapabilities, HttpAuthOverride, PushService, ReqwestExt}; #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] @@ -102,13 +103,17 @@ impl PushService { format!("/v1/profile/{}", address.uuid) }; // TODO: set locale to en_US - self.get_json( + self.request( + Method::GET, Endpoint::Service, &endpoint, - &[], HttpAuthOverride::NoOverride, - ) + )? + .send_to_signal() + .await? + .json() .await + .map_err(Into::into) } pub async fn retrieve_profile_avatar( @@ -161,35 +166,32 @@ impl PushService { }; // XXX this should be a struct; cfr ProfileAvatarUploadAttributes - let response: Result = self - .put_json( + let upload_url: Result = self + .request( + Method::PUT, Endpoint::Service, "/v1/profile", - &[], HttpAuthOverride::NoOverride, - command, - ) + )? + .json(&command) + .send_to_signal() + .await? + .json() .await; - match (response, avatar) { - (Ok(_url), AvatarWrite::NewAvatar(_avatar)) => { + + match (upload_url, avatar) { + (_url, AvatarWrite::NewAvatar(_avatar)) => { // FIXME unreachable!("Uploading avatar unimplemented"); }, // FIXME cleanup when #54883 is stable and MSRV: // or-patterns syntax is experimental // see issue #54883 for more information - ( - Err(ServiceError::JsonDecodeError { .. }), - AvatarWrite::RetainAvatar, - ) - | ( - Err(ServiceError::JsonDecodeError { .. }), - AvatarWrite::NoAvatar, - ) => { + (Err(_), AvatarWrite::RetainAvatar) + | (Err(_), AvatarWrite::NoAvatar) => { // OWS sends an empty string when there's no attachment Ok(None) }, - (Err(e), _) => Err(e), (Ok(_resp), AvatarWrite::RetainAvatar) | (Ok(_resp), AvatarWrite::NoAvatar) => { tracing::warn!( diff --git a/src/push_service/registration.rs b/src/push_service/registration.rs index 6d353b962..f968b6add 100644 --- a/src/push_service/registration.rs +++ b/src/push_service/registration.rs @@ -1,15 +1,27 @@ +use derivative::Derivative; use libsignal_protocol::IdentityKey; +use reqwest::Method; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::{AccountAttributes, AuthCredentials, PushService, ServiceError}; +use super::{AccountAttributes, PushService, ServiceError}; use crate::{ configuration::Endpoint, pre_keys::{KyberPreKeyEntity, SignedPreKeyEntity}, - push_service::HttpAuthOverride, + push_service::{HttpAuthOverride, ReqwestExt}, utils::serde_base64, }; +/// This type is used in registration lock handling. +/// It's identical with HttpAuth, but used to avoid type confusion. +#[derive(Derivative, Clone, Serialize, Deserialize)] +#[derivative(Debug)] +pub struct AuthCredentials { + pub username: String, + #[derivative(Debug = "ignore")] + pub password: String, +} + #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct RegistrationLockFailure { @@ -31,21 +43,13 @@ pub struct VerifyAccountResponse { pub number: Option, } -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] pub enum VerificationTransport { Sms, Voice, } -impl VerificationTransport { - pub fn as_str(&self) -> &str { - match self { - Self::Sms => "sms", - Self::Voice => "voice", - } - } -} - #[derive(Clone, Debug)] pub enum RegistrationMethod<'a> { SessionId(&'a str), @@ -148,7 +152,13 @@ impl PushService { device_activation_request: DeviceActivationRequest, } - let req = RegistrationSessionRequestBody { + self.request( + Method::POST, + Endpoint::Service, + "/v1/registration", + HttpAuthOverride::NoOverride, + )? + .json(&RegistrationSessionRequestBody { session_id: registration_method.session_id(), recovery_password: registration_method.recovery_password(), account_attributes, @@ -157,18 +167,12 @@ impl PushService { pni_identity_key: pni_identity_key.serialize().into(), device_activation_request, every_signed_key_valid: true, - }; - - let res: VerifyAccountResponse = self - .post_json( - Endpoint::Service, - "/v1/registration", - &[], - HttpAuthOverride::NoOverride, - req, - ) - .await?; - Ok(res) + }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } // Equivalent of Java's @@ -190,24 +194,24 @@ impl PushService { push_token_type: Option<&'a str>, } - let req = VerificationSessionMetadataRequestBody { + self.request( + Method::POST, + Endpoint::Service, + "/v1/verification/session", + HttpAuthOverride::Unidentified, + )? + .json(&VerificationSessionMetadataRequestBody { number, push_token_type: push_token.as_ref().map(|_| "fcm"), push_token, mcc, mnc, - }; - - let res: RegistrationSessionMetadataResponse = self - .post_json( - Endpoint::Service, - "/v1/verification/session", - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } pub async fn patch_verification_session<'a>( @@ -230,25 +234,25 @@ impl PushService { push_token_type: Option<&'a str>, } - let req = UpdateVerificationSessionRequestBody { + self.request( + Method::PATCH, + Endpoint::Service, + &format!("/v1/verification/session/{}", session_id), + HttpAuthOverride::Unidentified, + )? + .json(&UpdateVerificationSessionRequestBody { captcha, push_token_type: push_token.as_ref().map(|_| "fcm"), push_token, mcc, mnc, push_challenge, - }; - - let res: RegistrationSessionMetadataResponse = self - .patch_json( - Endpoint::Service, - &format!("/v1/verification/session/{}", session_id), - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } // Equivalent of Java's @@ -271,20 +275,24 @@ impl PushService { // locale: Option, transport: VerificationTransport, ) -> Result { - let mut req = std::collections::HashMap::new(); - req.insert("transport", transport.as_str()); - req.insert("client", client); + #[derive(Debug, Serialize)] + struct VerificationCodeRequest<'a> { + transport: VerificationTransport, + client: &'a str, + } - let res: RegistrationSessionMetadataResponse = self - .post_json( - Endpoint::Service, - &format!("/v1/verification/session/{}/code", session_id), - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + self.request( + Method::POST, + Endpoint::Service, + &format!("/v1/verification/session/{}/code", session_id), + HttpAuthOverride::Unidentified, + )? + .json(&VerificationCodeRequest { transport, client }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } pub async fn submit_verification_code( @@ -292,18 +300,24 @@ impl PushService { session_id: &str, verification_code: &str, ) -> Result { - let mut req = std::collections::HashMap::new(); - req.insert("code", verification_code); + #[derive(Debug, Serialize)] + struct VerificationCode<'a> { + code: &'a str, + } - let res: RegistrationSessionMetadataResponse = self - .put_json( - Endpoint::Service, - &format!("/v1/verification/session/{}/code", session_id), - &[], - HttpAuthOverride::Unidentified, - req, - ) - .await?; - Ok(res) + self.request( + Method::PUT, + Endpoint::Service, + &format!("/v1/verification/session/{}/code", session_id), + HttpAuthOverride::Unidentified, + )? + .json(&VerificationCode { + code: verification_code, + }) + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/sender.rs b/src/sender.rs index ae8a862e1..e755d8460 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -222,9 +222,10 @@ where .get_attachment_v2_upload_attributes() .instrument(tracing::trace_span!("requesting upload attributes")) .await?; + let (id, digest) = self .service - .upload_attachment(&attrs, &mut std::io::Cursor::new(&contents)) + .upload_attachment(attrs, &mut std::io::Cursor::new(&contents)) .instrument(tracing::trace_span!("Uploading attachment")) .await?; diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 4736edc45..b74d02ff1 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -9,18 +9,18 @@ use futures::channel::{mpsc, oneshot}; use futures::future::BoxFuture; use futures::prelude::*; use futures::stream::FuturesUnordered; -use prost::Message; +use reqwest_websocket::WebSocket; use serde::{Deserialize, Serialize}; +use tokio::time::Instant; -use crate::messagepipe::{WebSocketService, WebSocketStreamItem}; use crate::proto::{ web_socket_message, WebSocketMessage, WebSocketRequestMessage, WebSocketResponseMessage, }; -use crate::push_service::{MismatchedDevices, ServiceError}; +use crate::push_service::{self, MismatchedDevices, ServiceError}; mod sender; -pub(crate) mod tungstenite; +// pub(crate) mod tungstenite; type RequestStreamItem = ( WebSocketRequestMessage, @@ -61,7 +61,7 @@ struct SignalWebSocketInner { stream: Option, } -struct SignalWebSocketProcess { +struct SignalWebSocketProcess { /// Whether to enable keep-alive or not (and send a request to this path) keep_alive_path: String, @@ -85,16 +85,16 @@ struct SignalWebSocketProcess { >, // WS backend stuff - ws: WS, - stream: WS::Stream, + ws: WebSocket, } -impl SignalWebSocketProcess { +impl SignalWebSocketProcess { async fn process_frame( &mut self, - frame: Bytes, + frame: Vec, ) -> Result<(), ServiceError> { - let msg = WebSocketMessage::decode(frame)?; + use prost::Message; + let msg = WebSocketMessage::decode(Bytes::from(frame))?; if let Some(request) = &msg.request { tracing::trace!( "decoded WebSocketMessage request {{ r#type: {:?}, verb: {:?}, path: {:?}, body: {} bytes, headers: {:?}, id: {:?} }}", @@ -128,7 +128,7 @@ impl SignalWebSocketProcess { let (sink, recv) = oneshot::channel(); tracing::trace!("sending request with body"); self.request_sink.send((request, sink)).await.map_err( - |_| ServiceError::WsError { + |_| ServiceError::WsClosing { reason: "request handler failed".into(), }, )?; @@ -155,10 +155,11 @@ impl SignalWebSocketProcess { } else if let Some(_x) = self.outgoing_keep_alive_set.take(&id) { - if response.status() != 200 { + let status_code = response.status(); + if status_code != 200 { tracing::warn!( - "Response code for keep-alive is not 200: {:?}", - response + status_code, + "response code for keep-alive != 200" ); return Err(ServiceError::UnhandledResponseCode { http_code: response.status() as u16, @@ -166,8 +167,8 @@ impl SignalWebSocketProcess { } } else { tracing::warn!( - "Response for non existing request: {:?}", - response + ?response, + "response for non existing request" ); } } @@ -193,12 +194,40 @@ impl SignalWebSocketProcess { } async fn run(mut self) -> Result<(), ServiceError> { - loop { + let mut ka_interval = tokio::time::interval_at( + Instant::now(), + push_service::KEEPALIVE_TIMEOUT_SECONDS, + ); + + Ok(loop { futures::select! { + _ = ka_interval.tick().fuse() => { + use prost::Message; + tracing::debug!("sending keep-alive"); + let request = WebSocketRequestMessage { + id: Some(self.next_request_id()), + path: Some(self.keep_alive_path.clone()), + verb: Some("GET".into()), + ..Default::default() + }; + self.outgoing_keep_alive_set.insert(request.id.unwrap()); + let msg = WebSocketMessage { + r#type: Some(web_socket_message::Type::Request.into()), + request: Some(request), + ..Default::default() + }; + let buffer = msg.encode_to_vec(); + if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await { + tracing::info!("Websocket sink has closed: {:?}.", e); + break; + }; + }, // Process requests from the application, forward them to Signal x = self.requests.next() => { match x { Some((mut request, responder)) => { + use prost::Message; + // Regenerate ID if already in the table request.id = Some( request @@ -222,47 +251,44 @@ impl SignalWebSocketProcess { ..Default::default() }; let buffer = msg.encode_to_vec(); - self.ws.send_message(buffer.into()).await? + self.ws.send(reqwest_websocket::Message::Binary(buffer)).await? } None => { - return Err(ServiceError::WsError { - reason: "SignalWebSocket: end of application request stream; socket closing".into() + return Err(ServiceError::WsClosing { + reason: "SignalWebSocket: end of application request stream; socket closing" }); } } } - web_socket_item = self.stream.next() => { + web_socket_item = self.ws.next().fuse() => { + use reqwest_websocket::Message; match web_socket_item { - Some(WebSocketStreamItem::Message(frame)) => { + Some(Ok(Message::Close { code, reason })) => { + tracing::warn!(%code, reason, "websocket closed"); + break; + }, + Some(Ok(Message::Binary(frame))) => { self.process_frame(frame).await?; } - Some(WebSocketStreamItem::KeepAliveRequest) => { - // XXX: would be nicer if we could drop this request into the request - // queue above. - tracing::debug!("Sending keep alive upon request"); - let request = WebSocketRequestMessage { - id: Some(self.next_request_id()), - path: Some(self.keep_alive_path.clone()), - verb: Some("GET".into()), - ..Default::default() - }; - self.outgoing_keep_alive_set.insert(request.id.unwrap()); - let msg = WebSocketMessage { - r#type: Some(web_socket_message::Type::Request.into()), - request: Some(request), - ..Default::default() - }; - let buffer = msg.encode_to_vec(); - self.ws.send_message(buffer.into()).await?; + Some(Ok(Message::Ping(_))) => { + tracing::trace!("received ping"); } + Some(Ok(Message::Pong(_))) => { + tracing::trace!("received pong"); + } + Some(Ok(Message::Text(_))) => { + tracing::trace!("received text (unsupported, skipping)"); + } + Some(Err(e)) => return Err(ServiceError::WsError(e)), None => { - return Err(ServiceError::WsError { - reason: "end of web request stream; socket closing".into() + return Err(ServiceError::WsClosing { + reason: "end of web request stream; socket closing" }); } } } response = self.outgoing_responses.next() => { + use prost::Message; match response { Some(Ok(response)) => { tracing::trace!("sending response {:?}", response); @@ -273,7 +299,7 @@ impl SignalWebSocketProcess { ..Default::default() }; let buffer = msg.encode_to_vec(); - self.ws.send_message(buffer.into()).await?; + self.ws.send(buffer.into()).await?; } Some(Err(e)) => { tracing::error!("could not generate response to a Signal request; responder was canceled: {}. Continuing.", e); @@ -284,7 +310,7 @@ impl SignalWebSocketProcess { } } } - } + }) } } @@ -293,9 +319,8 @@ impl SignalWebSocket { self.inner.lock().unwrap() } - pub fn from_socket( - ws: WS, - stream: WS::Stream, + pub fn from_socket( + ws: WebSocket, keep_alive_path: String, ) -> (Self, impl Future) { // Create process @@ -316,7 +341,6 @@ impl SignalWebSocket { .into_iter() .collect(), ws, - stream, }; let process = process.run().map(|x| match x { Ok(()) => (), @@ -389,15 +413,14 @@ impl SignalWebSocket { async move { if let Err(_e) = request_sink.send((r, sink)).await { return Err(ServiceError::WsClosing { - reason: "WebSocket closing while sending request.".into(), + reason: "WebSocket closing while sending request.", }); } // Handle the oneshot sender error for dropped senders. match recv.await { Ok(x) => x, Err(_) => Err(ServiceError::WsClosing { - reason: "WebSocket closing while waiting for a response." - .into(), + reason: "WebSocket closing while waiting for a response.", }), } } diff --git a/src/websocket/tungstenite.rs b/src/websocket/tungstenite.rs deleted file mode 100644 index ae40caf09..000000000 --- a/src/websocket/tungstenite.rs +++ /dev/null @@ -1,202 +0,0 @@ -use std::sync::Arc; - -use async_tungstenite::{ - tokio::connect_async_with_tls_connector, - tungstenite::{ - client::IntoClientRequest, - http::{HeaderName, StatusCode}, - Error as TungsteniteError, Message, - }, -}; -use bytes::Bytes; -use futures::{channel::mpsc::*, prelude::*}; -use tokio::time::Instant; -use tokio_rustls::rustls; -use url::Url; - -use crate::{ - configuration::ServiceCredentials, - push_service::{self, ServiceError}, -}; - -use crate::messagepipe::{WebSocketService, WebSocketStreamItem}; - -pub struct TungsteniteWebSocket { - socket_sink: - Box + Send + Unpin>, -} - -#[derive(thiserror::Error, Debug)] -pub enum TungsteniteWebSocketError { - #[error("error while connecting to websocket: {0}")] - ConnectionError(#[from] TungsteniteError), -} - -impl From for ServiceError { - fn from(e: TungsteniteWebSocketError) -> Self { - match e { - TungsteniteWebSocketError::ConnectionError( - TungsteniteError::Http(response), - ) => match response.status() { - StatusCode::FORBIDDEN => ServiceError::Unauthorized, - s => ServiceError::WsError { - reason: format!("HTTP status {}", s), - }, - }, - e => ServiceError::WsError { - reason: e.to_string(), - }, - } - } -} - -// Process the WebSocket, until it times out. -async fn process( - socket_stream: S, - mut incoming_sink: Sender, -) -> Result<(), TungsteniteWebSocketError> -where - S: Stream> + Unpin, -{ - let mut socket_stream = socket_stream.fuse(); - - let mut ka_interval = tokio::time::interval_at( - Instant::now(), - push_service::KEEPALIVE_TIMEOUT_SECONDS, - ); - - loop { - tokio::select! { - _ = ka_interval.tick() => { - tracing::trace!("Triggering keep-alive"); - if let Err(e) = incoming_sink.send(WebSocketStreamItem::KeepAliveRequest).await { - tracing::info!("Websocket sink has closed: {:?}.", e); - break; - }; - }, - frame = socket_stream.next() => { - let frame = if let Some(frame) = frame { - frame - } else { - tracing::info!("process: Socket stream ended"); - break; - }; - - let frame = match frame? { - Message::Binary(s) => s, - Message::Ping(msg) => { - tracing::warn!("Received Ping({:?})", msg); - - continue; - }, - Message::Pong(msg) => { - tracing::trace!("Received Pong({:?})", msg); - - continue; - }, - Message::Text(frame) => { - tracing::warn!("Message::Text {:?}", frame); - - // this is a protocol violation, maybe break; is better? - continue; - }, - - Message::Close(c) => { - tracing::warn!("Websocket closing: {:?}", c); - - break; - }, - Message::Frame(_f) => unreachable!("handled internally in Tungstenite") - }; - - // Match SendError - if let Err(e) = incoming_sink.send(WebSocketStreamItem::Message(Bytes::from(frame))).await { - tracing::info!("Websocket sink has closed: {:?}.", e); - break; - } - }, - } - } - Ok(()) -} - -impl TungsteniteWebSocket { - pub(crate) async fn with_tls_config( - tls_config: rustls::ClientConfig, - base_url: impl std::borrow::Borrow, - path: &str, - additional_headers: &[(&str, &str)], - credentials: Option<&ServiceCredentials>, - ) -> Result< - (Self, ::Stream), - TungsteniteWebSocketError, - > { - let mut url = base_url.borrow().join(path).expect("valid url"); - url.set_scheme("wss").expect("valid https base url"); - - let tls_connector = - tokio_rustls::TlsConnector::from(Arc::new(tls_config)); - - if let Some(credentials) = credentials { - url.query_pairs_mut() - .append_pair("login", &credentials.login()) - .append_pair( - "password", - credentials.password.as_ref().expect("a password"), - ); - } - - tracing::trace!("Will start websocket at {:?}", url); - - let mut request = url.into_client_request()?; - - for (key, value) in additional_headers { - request.headers_mut().insert( - // FromStr is implemnted for HeaderName, but that expects a &'static str... - HeaderName::from_bytes(key.as_bytes()) - .expect("valid header name"), - value.parse().expect("valid header value"), - ); - } - - let (socket_stream, response) = - connect_async_with_tls_connector(request, Some(tls_connector)) - .await?; - - tracing::debug!("WebSocket connected: {:?}", response); - - let (incoming_sink, incoming_stream) = channel(5); - - let (socket_sink, socket_stream) = socket_stream.split(); - let processing_task = process(socket_stream, incoming_sink); - - // When the processing_task stops, the consuming stream and sink also - // terminate. - tokio::spawn(processing_task.map(|v| match v { - Ok(()) => (), - Err(e) => { - tracing::warn!("Processing task terminated with error: {:?}", e) - }, - })); - - Ok(( - Self { - socket_sink: Box::new(socket_sink), - }, - incoming_stream, - )) - } -} - -#[async_trait::async_trait] -impl WebSocketService for TungsteniteWebSocket { - type Stream = Receiver; - - async fn send_message(&mut self, msg: Bytes) -> Result<(), ServiceError> { - self.socket_sink - .send(Message::Binary(msg.to_vec())) - .await - .map_err(TungsteniteWebSocketError::from)?; - Ok(()) - } -} From 49baa11b2432410565c11234fb8518aa9b0464fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Thu, 17 Oct 2024 21:46:37 +0200 Subject: [PATCH 02/14] Use PushService instead of websocket like in libsignal.kt --- src/profile_service.rs | 38 ++++++++++++++++++++-------------- src/sender.rs | 1 - src/websocket/mod.rs | 47 +++++++++++++++++++++--------------------- 3 files changed, 46 insertions(+), 40 deletions(-) diff --git a/src/profile_service.rs b/src/profile_service.rs index e23a4b51b..0c02b12ba 100644 --- a/src/profile_service.rs +++ b/src/profile_service.rs @@ -1,17 +1,21 @@ +use reqwest::Method; + use crate::{ - proto::WebSocketRequestMessage, - push_service::{ServiceError, SignalServiceProfile}, - websocket::SignalWebSocket, + configuration::Endpoint, + prelude::PushService, + push_service::{ + HttpAuthOverride, ReqwestExt, ServiceError, SignalServiceProfile, + }, ServiceAddress, }; pub struct ProfileService { - ws: SignalWebSocket, + push_service: PushService, } impl ProfileService { - pub fn from_socket(ws: SignalWebSocket) -> Self { - ProfileService { ws } + pub fn from_socket(push_service: PushService) -> Self { + ProfileService { push_service } } pub async fn retrieve_profile_by_id( @@ -19,7 +23,7 @@ impl ProfileService { address: ServiceAddress, profile_key: Option, ) -> Result { - let endpoint = match profile_key { + let path = match profile_key { Some(key) => { let version = bincode::serialize(&key.get_profile_key_version( @@ -34,13 +38,17 @@ impl ProfileService { }, }; - let request = WebSocketRequestMessage { - path: Some(endpoint), - verb: Some("GET".into()), - // TODO: set locale to en_US - ..Default::default() - }; - - self.ws.request_json(request).await + self.push_service + .request( + Method::GET, + Endpoint::Service, + path, + HttpAuthOverride::NoOverride, + )? + .send_to_signal() + .await? + .json() + .await + .map_err(Into::into) } } diff --git a/src/sender.rs b/src/sender.rs index e755d8460..aea24d508 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -82,7 +82,6 @@ pub struct AttachmentSpec { pub blur_hash: Option, } -/// Equivalent of Java's `SignalServiceMessageSender`. #[derive(Clone)] pub struct MessageSender { identified_ws: SignalWebSocket, diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index b74d02ff1..c423bf65a 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -73,7 +73,7 @@ struct SignalWebSocketProcess { /// Signal's requests should go in here, to be delivered to the application. request_sink: mpsc::Sender, - outgoing_request_map: HashMap< + outgoing_requests: HashMap< u64, oneshot::Sender>, >, @@ -84,7 +84,6 @@ struct SignalWebSocketProcess { BoxFuture<'static, Result>, >, - // WS backend stuff ws: WebSocket, } @@ -97,23 +96,23 @@ impl SignalWebSocketProcess { let msg = WebSocketMessage::decode(Bytes::from(frame))?; if let Some(request) = &msg.request { tracing::trace!( - "decoded WebSocketMessage request {{ r#type: {:?}, verb: {:?}, path: {:?}, body: {} bytes, headers: {:?}, id: {:?} }}", - msg.r#type(), + msg_type =? msg.r#type(), + request.id, request.verb, request.path, - request.body.as_ref().map(|x| x.len()).unwrap_or(0), - request.headers, - request.id, + request_body_size_bytes = request.body.as_ref().map(|x| x.len()).unwrap_or(0), + ?request.headers, + "decoded WebSocketMessage request" ); } else if let Some(response) = &msg.response { tracing::trace!( - "decoded WebSocketMessage response {{ r#type: {:?}, status: {:?}, message: {:?}, body: {} bytes, headers: {:?}, id: {:?} }}", - msg.r#type(), + msg_type =? msg.r#type(), response.status, response.message, - response.body.as_ref().map(|x| x.len()).unwrap_or(0), - response.headers, + response_body_size_bytes = response.body.as_ref().map(|x| x.len()).unwrap_or(0), + ?response.headers, response.id, + "decoded WebSocketMessage response" ); } else { tracing::debug!("decoded {msg:?}"); @@ -142,8 +141,7 @@ impl SignalWebSocketProcess { }), (Type::Response, _, Some(response)) => { if let Some(id) = response.id { - if let Some(responder) = - self.outgoing_request_map.remove(&id) + if let Some(responder) = self.outgoing_requests.remove(&id) { if let Err(e) = responder.send(Ok(response)) { tracing::warn!( @@ -187,7 +185,7 @@ impl SignalWebSocketProcess { let mut rng = rand::thread_rng(); loop { let id = rng.gen(); - if !self.outgoing_request_map.contains_key(&id) { + if !self.outgoing_requests.contains_key(&id) { return id; } } @@ -232,19 +230,19 @@ impl SignalWebSocketProcess { request.id = Some( request .id - .filter(|x| !self.outgoing_request_map.contains_key(x)) + .filter(|x| !self.outgoing_requests.contains_key(x)) .unwrap_or_else(|| self.next_request_id()), ); tracing::trace!( - "sending WebSocketRequestMessage {{ verb: {:?}, path: {:?}, body (bytes): {:?}, headers: {:?}, id: {:?} }}", + request.id, request.verb, request.path, - request.body.as_ref().map(|x| x.len()), - request.headers, - request.id, + request_body_size_bytes = request.body.as_ref().map(|x| x.len()), + ?request.headers, + "sending WebSocketRequestMessage", ); - self.outgoing_request_map.insert(request.id.unwrap(), responder); + self.outgoing_requests.insert(request.id.unwrap(), responder); let msg = WebSocketMessage { r#type: Some(web_socket_message::Type::Request.into()), request: Some(request), @@ -255,11 +253,12 @@ impl SignalWebSocketProcess { } None => { return Err(ServiceError::WsClosing { - reason: "SignalWebSocket: end of application request stream; socket closing" + reason: "end of application request stream; socket closing" }); } } } + // Incoming websocket message web_socket_item = self.ws.next().fuse() => { use reqwest_websocket::Message; match web_socket_item { @@ -301,8 +300,8 @@ impl SignalWebSocketProcess { let buffer = msg.encode_to_vec(); self.ws.send(buffer.into()).await?; } - Some(Err(e)) => { - tracing::error!("could not generate response to a Signal request; responder was canceled: {}. Continuing.", e); + Some(Err(error)) => { + tracing::error!(%error, "could not generate response to a Signal request; responder was canceled. continuing."); } None => { unreachable!("outgoing responses should never fuse") @@ -331,7 +330,7 @@ impl SignalWebSocket { keep_alive_path, requests: outgoing_requests, request_sink: incoming_request_sink, - outgoing_request_map: HashMap::default(), + outgoing_requests: HashMap::default(), outgoing_keep_alive_set: HashSet::new(), // Initializing the FuturesUnordered with a `pending` future means it will never fuse // itself, so an "empty" FuturesUnordered will still allow new futures to be added. From 5c99403abddc618281f981d10b6a23b94ee75f92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Thu, 17 Oct 2024 23:09:43 +0200 Subject: [PATCH 03/14] Introduce trait to share error handling between PushService and WebSocketService --- src/account_manager.rs | 16 ++- src/groups_v2/manager.rs | 4 +- src/profile_service.rs | 4 +- src/push_service/account.rs | 16 ++- src/push_service/cdn.rs | 14 ++- src/push_service/error.rs | 4 +- src/push_service/keys.rs | 30 +++-- src/push_service/linking.rs | 12 +- src/push_service/mod.rs | 199 +++++++++---------------------- src/push_service/profile.rs | 4 +- src/push_service/registration.rs | 22 +++- src/push_service/response.rs | 160 +++++++++++++++++++++++++ src/websocket/mod.rs | 97 ++------------- src/websocket/sender.rs | 5 +- 14 files changed, 319 insertions(+), 268 deletions(-) create mode 100644 src/push_service/response.rs diff --git a/src/account_manager.rs b/src/account_manager.rs index 84ade2573..f3d4301ee 100644 --- a/src/account_manager.rs +++ b/src/account_manager.rs @@ -229,7 +229,9 @@ impl AccountManager { "/v1/devices/provisioning/code", HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await?; @@ -259,7 +261,9 @@ impl AccountManager { .json(&ProvisioningMessage { body: BASE64_RELAXED.encode(body), }) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await?; Ok(()) @@ -597,7 +601,9 @@ impl AccountManager { .json(&Data { device_name: encrypted_device_name.encode_to_vec(), }) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await?; Ok(()) @@ -626,7 +632,9 @@ impl AccountManager { token: String::from(token), captcha: String::from(captcha), }) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await?; Ok(()) diff --git a/src/groups_v2/manager.rs b/src/groups_v2/manager.rs index 9183bef10..93e6229c7 100644 --- a/src/groups_v2/manager.rs +++ b/src/groups_v2/manager.rs @@ -174,7 +174,9 @@ impl GroupsManager { &path, HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await?; diff --git a/src/profile_service.rs b/src/profile_service.rs index 0c02b12ba..016658876 100644 --- a/src/profile_service.rs +++ b/src/profile_service.rs @@ -45,7 +45,9 @@ impl ProfileService { path, HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await diff --git a/src/push_service/account.rs b/src/push_service/account.rs index 293e0857e..0cc99ea0b 100644 --- a/src/push_service/account.rs +++ b/src/push_service/account.rs @@ -6,7 +6,9 @@ use reqwest::Method; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::{HttpAuthOverride, PushService, ReqwestExt, ServiceError}; +use super::{ + response::ReqwestExt, HttpAuthOverride, PushService, ServiceError, +}; use crate::{ configuration::Endpoint, utils::{serde_optional_base64, serde_phone_number}, @@ -134,7 +136,9 @@ impl PushService { "/v1/accounts/whoami", HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await @@ -157,7 +161,9 @@ impl PushService { "/v1/devices/", HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await?; @@ -181,7 +187,9 @@ impl PushService { HttpAuthOverride::NoOverride, )? .json(&attributes) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index 8309807b2..a345b563e 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -9,7 +9,7 @@ use crate::{ proto::AttachmentPointer, push_service::HttpAuthOverride, }; -use super::{PushService, ReqwestExt, ServiceError}; +use super::{response::ReqwestExt, PushService, ServiceError}; #[derive(Debug, serde::Deserialize, Default)] #[serde(rename_all = "camelCase")] @@ -53,7 +53,9 @@ impl PushService { path, HttpAuthOverride::Unidentified, // CDN requests are always without authentication )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .bytes_stream() .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) @@ -71,7 +73,9 @@ impl PushService { "/v2/attachments/form/upload", HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await @@ -130,7 +134,9 @@ impl PushService { HttpAuthOverride::NoOverride, )? .multipart(form) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await?; debug!("HyperPushService::PUT response: {:?}", response); diff --git a/src/push_service/error.rs b/src/push_service/error.rs index 2e96f27c7..99561c9e7 100644 --- a/src/push_service/error.rs +++ b/src/push_service/error.rs @@ -21,8 +21,8 @@ pub enum ServiceError { #[error("Error decoding response: {reason}")] ResponseError { reason: String }, - #[error("Error decoding JSON response: {reason}")] - JsonDecodeError { reason: String }, + #[error("Error decoding JSON: {0}")] + JsonDecodeError(#[from] serde_json::Error), #[error("Error decoding protobuf frame: {0}")] ProtobufDecodeError(#[from] prost::DecodeError), #[error("error encoding or decoding bincode: {0}")] diff --git a/src/push_service/keys.rs b/src/push_service/keys.rs index 93b7d6636..df5892313 100644 --- a/src/push_service/keys.rs +++ b/src/push_service/keys.rs @@ -14,7 +14,7 @@ use crate::{ }; use super::{ - HttpAuthOverride, PushService, ReqwestExt, SenderCertificateJson, + response::ReqwestExt, HttpAuthOverride, PushService, SenderCertificateJson, ServiceError, ServiceIdType, VerifyAccountResponse, }; @@ -36,7 +36,9 @@ impl PushService { &format!("/v2/keys?identity={}", service_id_type), HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await @@ -55,7 +57,9 @@ impl PushService { HttpAuthOverride::NoOverride, )? .json(&pre_key_state) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await?; Ok(()) @@ -76,7 +80,9 @@ impl PushService { &path, HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await?; @@ -105,7 +111,9 @@ impl PushService { &path, HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await?; @@ -127,7 +135,9 @@ impl PushService { "/v1/certificate/delivery", HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await?; @@ -144,7 +154,9 @@ impl PushService { "/v1/certificate/delivery?includeE164=false", HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await?; @@ -190,7 +202,9 @@ impl PushService { pni_registration_ids, signature_valid_on_each_signed_pre_key, }) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await diff --git a/src/push_service/linking.rs b/src/push_service/linking.rs index d10f20116..5d5026c6c 100644 --- a/src/push_service/linking.rs +++ b/src/push_service/linking.rs @@ -5,8 +5,8 @@ use uuid::Uuid; use crate::configuration::Endpoint; use super::{ - DeviceActivationRequest, HttpAuth, HttpAuthOverride, PushService, - ReqwestExt, ServiceError, + response::ReqwestExt, DeviceActivationRequest, HttpAuth, HttpAuthOverride, + PushService, ServiceError, }; #[derive(Debug, Serialize)] @@ -67,7 +67,9 @@ impl PushService { HttpAuthOverride::Identified(http_auth), )? .json(&link_request) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await @@ -81,7 +83,9 @@ impl PushService { format!("/v1/devices/{}", id), HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send() + .await? + .service_error_for_status() .await?; Ok(()) diff --git a/src/push_service/mod.rs b/src/push_service/mod.rs index daf683617..b4e52aabc 100644 --- a/src/push_service/mod.rs +++ b/src/push_service/mod.rs @@ -15,7 +15,7 @@ use libsignal_protocol::{ IdentityKey, PreKeyBundle, PublicKey, }; use protobuf::ProtobufResponseExt; -use reqwest::{Method, RequestBuilder, Response, StatusCode}; +use reqwest::{Method, RequestBuilder}; use reqwest_websocket::RequestBuilderExt; use serde::{Deserialize, Serialize}; use tracing::{debug_span, Instrument}; @@ -30,6 +30,7 @@ mod keys; mod linking; mod profile; mod registration; +mod response; mod stickers; pub use account::*; @@ -39,6 +40,7 @@ pub use keys::*; pub use linking::*; pub use profile::*; pub use registration::*; +pub(crate) use response::{ReqwestExt, SignalServiceResponse}; #[derive(Debug, Serialize, Deserialize)] pub struct ProofRequired { @@ -204,147 +206,7 @@ impl PushService { Ok(builder) } -} - -#[async_trait::async_trait] -pub(crate) trait ReqwestExt -where - Self: Sized, -{ - async fn send_to_signal(self) -> Result; -} - -#[async_trait::async_trait] -impl ReqwestExt for RequestBuilder { - async fn send_to_signal(self) -> Result { - let response = self.send().await?; - match response.status() { - StatusCode::OK => Ok(response), - StatusCode::NO_CONTENT => Ok(response), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - Err(ServiceError::Unauthorized) - }, - StatusCode::NOT_FOUND => { - // This is 404 and means that e.g. recipient is not registered - Err(ServiceError::NotFoundError) - }, - StatusCode::PAYLOAD_TOO_LARGE => { - // This is 413 and means rate limit exceeded for Signal. - Err(ServiceError::RateLimitExceeded) - }, - StatusCode::CONFLICT => { - let mismatched_devices = - response.json().await.map_err(|error| { - tracing::error!( - %error, - "failed to decode HTTP 409 status" - ); - ServiceError::UnhandledResponseCode { - http_code: StatusCode::CONFLICT.as_u16(), - } - })?; - Err(ServiceError::MismatchedDevicesException( - mismatched_devices, - )) - }, - StatusCode::GONE => { - let stale_devices = response.json().await.map_err(|error| { - tracing::error!(%error, "failed to decode HTTP 410 status"); - ServiceError::UnhandledResponseCode { - http_code: StatusCode::GONE.as_u16(), - } - })?; - Err(ServiceError::StaleDevices(stale_devices)) - }, - StatusCode::LOCKED => { - let locked = response.json().await.map_err(|error| { - tracing::error!(%error, "failed to decode HTTP 423 status"); - ServiceError::UnhandledResponseCode { - http_code: StatusCode::LOCKED.as_u16(), - } - })?; - Err(ServiceError::Locked(locked)) - }, - StatusCode::PRECONDITION_REQUIRED => { - let proof_required = - response.json().await.map_err(|error| { - tracing::error!( - %error, - "failed to decode HTTP 428 status" - ); - ServiceError::UnhandledResponseCode { - http_code: StatusCode::PRECONDITION_REQUIRED - .as_u16(), - } - })?; - Err(ServiceError::ProofRequiredError(proof_required)) - }, - // XXX: fill in rest from PushServiceSocket - code => { - let response_text = response.text().await?; - tracing::trace!(status_code =% code, body = response_text, "unhandled HTTP response"); - Err(ServiceError::UnhandledResponseCode { - http_code: code.as_u16(), - }) - }, - } - } -} - -pub(crate) mod protobuf { - use async_trait::async_trait; - use prost::{EncodeError, Message}; - use reqwest::{header, RequestBuilder, Response}; - - use super::ServiceError; - - pub(crate) trait ProtobufRequestBuilderExt - where - Self: Sized, - { - /// Set the request payload encoded as protobuf. - /// Sets the `Content-Type` header to `application/protobuf` - #[allow(dead_code)] - fn protobuf( - self, - value: T, - ) -> Result; - } - - #[async_trait::async_trait] - pub(crate) trait ProtobufResponseExt { - /// Get the response body decoded from Protobuf - async fn protobuf( - self, - ) -> Result; - } - - impl ProtobufRequestBuilderExt for RequestBuilder { - fn protobuf( - self, - value: T, - ) -> Result { - let mut buf = Vec::new(); - value.encode(&mut buf)?; - let this = - self.header(header::CONTENT_TYPE, "application/protobuf"); - Ok(this.body(buf)) - } - } - - #[async_trait] - impl ProtobufResponseExt for Response { - async fn protobuf( - self, - ) -> Result { - let body = self.bytes().await?; - let decoded = T::decode(body)?; - Ok(decoded) - } - } -} -impl PushService { pub async fn ws( &mut self, path: &str, @@ -399,12 +261,67 @@ impl PushService { )? .send() .await? + .service_error_for_status() + .await? .protobuf() .await .map_err(Into::into) } } +pub(crate) mod protobuf { + use async_trait::async_trait; + use prost::{EncodeError, Message}; + use reqwest::{header, RequestBuilder, Response}; + + use super::ServiceError; + + pub(crate) trait ProtobufRequestBuilderExt + where + Self: Sized, + { + /// Set the request payload encoded as protobuf. + /// Sets the `Content-Type` header to `application/protobuf` + #[allow(dead_code)] + fn protobuf( + self, + value: T, + ) -> Result; + } + + #[async_trait::async_trait] + pub(crate) trait ProtobufResponseExt { + /// Get the response body decoded from Protobuf + async fn protobuf( + self, + ) -> Result; + } + + impl ProtobufRequestBuilderExt for RequestBuilder { + fn protobuf( + self, + value: T, + ) -> Result { + let mut buf = Vec::new(); + value.encode(&mut buf)?; + let this = + self.header(header::CONTENT_TYPE, "application/protobuf"); + Ok(this.body(buf)) + } + } + + #[async_trait] + impl ProtobufResponseExt for Response { + async fn protobuf( + self, + ) -> Result { + let body = self.bytes().await?; + let decoded = T::decode(body)?; + Ok(decoded) + } + } +} + #[cfg(test)] mod tests { use crate::configuration::SignalServers; diff --git a/src/push_service/profile.rs b/src/push_service/profile.rs index b05aef56c..c48c67679 100644 --- a/src/push_service/profile.rs +++ b/src/push_service/profile.rs @@ -109,7 +109,7 @@ impl PushService { &endpoint, HttpAuthOverride::NoOverride, )? - .send_to_signal() + .send().await?.service_error_for_status() .await? .json() .await @@ -174,7 +174,7 @@ impl PushService { HttpAuthOverride::NoOverride, )? .json(&command) - .send_to_signal() + .send().await?.service_error_for_status() .await? .json() .await; diff --git a/src/push_service/registration.rs b/src/push_service/registration.rs index f968b6add..3e56f1da9 100644 --- a/src/push_service/registration.rs +++ b/src/push_service/registration.rs @@ -8,7 +8,7 @@ use super::{AccountAttributes, PushService, ServiceError}; use crate::{ configuration::Endpoint, pre_keys::{KyberPreKeyEntity, SignedPreKeyEntity}, - push_service::{HttpAuthOverride, ReqwestExt}, + push_service::{response::ReqwestExt, HttpAuthOverride}, utils::serde_base64, }; @@ -168,7 +168,9 @@ impl PushService { device_activation_request, every_signed_key_valid: true, }) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await @@ -207,7 +209,9 @@ impl PushService { mcc, mnc, }) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await @@ -248,7 +252,9 @@ impl PushService { mnc, push_challenge, }) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await @@ -288,7 +294,9 @@ impl PushService { HttpAuthOverride::Unidentified, )? .json(&VerificationCodeRequest { transport, client }) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await @@ -314,7 +322,9 @@ impl PushService { .json(&VerificationCode { code: verification_code, }) - .send_to_signal() + .send() + .await? + .service_error_for_status() .await? .json() .await diff --git a/src/push_service/response.rs b/src/push_service/response.rs new file mode 100644 index 000000000..9fc02c6e9 --- /dev/null +++ b/src/push_service/response.rs @@ -0,0 +1,160 @@ +use reqwest::StatusCode; + +use crate::proto::WebSocketResponseMessage; + +use super::ServiceError; + +async fn service_error_for_status(response: R) -> Result +where + R: SignalServiceResponse, + ServiceError: From<::Error>, +{ + match response.status_code() { + StatusCode::OK => Ok(response), + StatusCode::NO_CONTENT => Ok(response), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + Err(ServiceError::Unauthorized) + }, + StatusCode::NOT_FOUND => { + // This is 404 and means that e.g. recipient is not registered + Err(ServiceError::NotFoundError) + }, + StatusCode::PAYLOAD_TOO_LARGE => { + // This is 413 and means rate limit exceeded for Signal. + Err(ServiceError::RateLimitExceeded) + }, + StatusCode::CONFLICT => { + let mismatched_devices = + response.json().await.map_err(|error| { + tracing::error!( + %error, + "failed to decode HTTP 409 status" + ); + ServiceError::UnhandledResponseCode { + http_code: StatusCode::CONFLICT.as_u16(), + } + })?; + Err(ServiceError::MismatchedDevicesException(mismatched_devices)) + }, + StatusCode::GONE => { + let stale_devices = response.json().await.map_err(|error| { + tracing::error!(%error, "failed to decode HTTP 410 status"); + ServiceError::UnhandledResponseCode { + http_code: StatusCode::GONE.as_u16(), + } + })?; + Err(ServiceError::StaleDevices(stale_devices)) + }, + StatusCode::LOCKED => { + let locked = response.json().await.map_err(|error| { + tracing::error!(%error, "failed to decode HTTP 423 status"); + ServiceError::UnhandledResponseCode { + http_code: StatusCode::LOCKED.as_u16(), + } + })?; + Err(ServiceError::Locked(locked)) + }, + StatusCode::PRECONDITION_REQUIRED => { + let proof_required = response.json().await.map_err(|error| { + tracing::error!( + %error, + "failed to decode HTTP 428 status" + ); + ServiceError::UnhandledResponseCode { + http_code: StatusCode::PRECONDITION_REQUIRED.as_u16(), + } + })?; + Err(ServiceError::ProofRequiredError(proof_required)) + }, + // XXX: fill in rest from PushServiceSocket + code => { + let response_text = response.text().await?; + tracing::trace!(status_code =% code, body = response_text, "unhandled HTTP response"); + Err(ServiceError::UnhandledResponseCode { + http_code: code.as_u16(), + }) + }, + } +} + +#[async_trait::async_trait] +pub(crate) trait SignalServiceResponse { + type Error: std::error::Error; + + fn status_code(&self) -> StatusCode; + + async fn json(self) -> Result + where + for<'de> U: serde::Deserialize<'de>; + + async fn text(self) -> Result; +} + +#[async_trait::async_trait] +impl SignalServiceResponse for reqwest::Response { + type Error = reqwest::Error; + + fn status_code(&self) -> StatusCode { + self.status() + } + + async fn json(self) -> Result + where + for<'de> U: serde::Deserialize<'de>, + { + reqwest::Response::json(self).await + } + + async fn text(self) -> Result { + reqwest::Response::text(self).await + } +} + +#[async_trait::async_trait] +impl SignalServiceResponse for WebSocketResponseMessage { + type Error = ServiceError; + + fn status_code(&self) -> StatusCode { + StatusCode::from_u16(self.status() as u16).unwrap_or_default() + } + + async fn json(self) -> Result + where + for<'de> U: serde::Deserialize<'de>, + { + serde_json::from_slice(self.body()).map_err(Into::into) + } + + async fn text(self) -> Result { + Ok(self + .body + .map(|body| String::from_utf8_lossy(&body).to_string()) + .unwrap_or_default()) + } +} + +#[async_trait::async_trait] +pub(crate) trait ReqwestExt +where + Self: Sized, +{ + /// convenience error handler to be used in the builder-style API of `reqwest::Response` + async fn service_error_for_status( + self, + ) -> Result; +} + +#[async_trait::async_trait] +impl ReqwestExt for reqwest::Response { + async fn service_error_for_status( + self, + ) -> Result { + service_error_for_status(self).await + } +} + +impl WebSocketResponseMessage { + pub async fn service_error_for_status(self) -> Result { + service_error_for_status(self).await + } +} diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index c423bf65a..1350bab7c 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -17,7 +17,7 @@ use crate::proto::{ web_socket_message, WebSocketMessage, WebSocketRequestMessage, WebSocketResponseMessage, }; -use crate::push_service::{self, MismatchedDevices, ServiceError}; +use crate::push_service::{self, ServiceError, SignalServiceResponse}; mod sender; // pub(crate) mod tungstenite; @@ -432,95 +432,16 @@ impl SignalWebSocket { where for<'de> T: serde::Deserialize<'de>, { - let response = self.request(r).await?; - if response.status() != 200 { - tracing::debug!( - "request_json with non-200 status code. message: {}", - response.message() - ); - } - - fn json(body: &[u8]) -> Result - where - for<'de> U: serde::Deserialize<'de>, - { - serde_json::from_slice(body).map_err(|e| { - ServiceError::JsonDecodeError { - reason: e.to_string(), - } - }) - } - - match response.status() { - 200 | 204 => json(response.body()), - 401 | 403 => Err(ServiceError::Unauthorized), - 404 => Err(ServiceError::NotFoundError), - 413 /* PAYLOAD_TOO_LARGE */ => Err(ServiceError::RateLimitExceeded) , - 409 /* CONFLICT */ => { - let mismatched_devices: MismatchedDevices = - json(response.body()).map_err(|e| { - tracing::error!( - "Failed to decode HTTP 409 response: {}", - e - ); - ServiceError::UnhandledResponseCode { - http_code: 409, - } - })?; - Err(ServiceError::MismatchedDevicesException( - mismatched_devices, - )) - }, - 410 /* GONE */ => { - let stale_devices = - json(response.body()).map_err(|e| { - tracing::error!( - "Failed to decode HTTP 410 response: {}", - e - ); - ServiceError::UnhandledResponseCode { - http_code: 410, - } - })?; - Err(ServiceError::StaleDevices(stale_devices)) - }, - 423 /* LOCKED */ => { - let locked = json(response.body()).map_err(|e| { - tracing::error!("Failed to decode HTTP 423 response: {}", e); - ServiceError::UnhandledResponseCode { - http_code: 423, - } - })?; - Err(ServiceError::Locked(locked)) - }, - 428 /* PRECONDITION_REQUIRED */ => { - let proof_required = json(response.body()).map_err(|e| { - tracing::error!("Failed to decode HTTP 428 response: {}", e); - ServiceError::UnhandledResponseCode { - http_code: 428, - } - })?; - Err(ServiceError::ProofRequiredError(proof_required)) - }, - _ => Err(ServiceError::UnhandledResponseCode { - http_code: response.status() as u16, - }), - } - } - - pub(crate) async fn put_json( - &mut self, - path: &str, - value: S, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Serialize, - { - self.put_json_with_headers(path, value, vec![]).await + self.request(r) + .await? + .service_error_for_status() + .await? + .json() + .await + .map_err(Into::into) } - pub(crate) async fn put_json_with_headers<'h, D, S>( + pub(crate) async fn put_json<'h, D, S>( &mut self, path: &str, value: S, diff --git a/src/websocket/sender.rs b/src/websocket/sender.rs index e9ea90ef6..341430b38 100644 --- a/src/websocket/sender.rs +++ b/src/websocket/sender.rs @@ -13,7 +13,7 @@ impl SignalWebSocket { messages: OutgoingPushMessages, ) -> Result { let path = format!("/v1/messages/{}", messages.destination); - self.put_json(&path, messages).await + self.put_json(&path, messages, vec![]).await } pub async fn send_messages_unidentified( @@ -26,7 +26,6 @@ impl SignalWebSocket { "Unidentified-Access-Key:{}", BASE64_RELAXED.encode(&access.key) ); - self.put_json_with_headers(&path, messages, vec![header]) - .await + self.put_json(&path, messages, vec![header]).await } } From 23d1a1b81d4dbec3a497f36c2f93d449691ec101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Thu, 17 Oct 2024 23:15:25 +0200 Subject: [PATCH 04/14] Clippy clip --- src/account_manager.rs | 2 +- src/messagepipe.rs | 2 +- src/push_service/keys.rs | 4 ++-- src/push_service/mod.rs | 9 +++------ src/push_service/registration.rs | 6 +++--- src/websocket/mod.rs | 7 ++++--- 6 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/account_manager.rs b/src/account_manager.rs index f3d4301ee..9e69aaccb 100644 --- a/src/account_manager.rs +++ b/src/account_manager.rs @@ -255,7 +255,7 @@ impl AccountManager { .request( Method::PUT, Endpoint::Service, - &format!("/v1/provisioning/{}", destination), + format!("/v1/provisioning/{}", destination), HttpAuthOverride::NoOverride, )? .json(&ProvisioningMessage { diff --git a/src/messagepipe.rs b/src/messagepipe.rs index f3e0c67c5..4fd1d54d4 100644 --- a/src/messagepipe.rs +++ b/src/messagepipe.rs @@ -97,7 +97,7 @@ impl MessagePipe { responder .send(response) .map_err(|_| ServiceError::WsClosing { - reason: "could not respond to message pipe request".into(), + reason: "could not respond to message pipe request", })?; Ok(result) diff --git a/src/push_service/keys.rs b/src/push_service/keys.rs index df5892313..f4addb013 100644 --- a/src/push_service/keys.rs +++ b/src/push_service/keys.rs @@ -33,7 +33,7 @@ impl PushService { self.request( Method::GET, Endpoint::Service, - &format!("/v2/keys?identity={}", service_id_type), + format!("/v2/keys?identity={}", service_id_type), HttpAuthOverride::NoOverride, )? .send() @@ -53,7 +53,7 @@ impl PushService { self.request( Method::PUT, Endpoint::Service, - &format!("/v2/keys?identity={}", service_id_type), + format!("/v2/keys?identity={}", service_id_type), HttpAuthOverride::NoOverride, )? .json(&pre_key_state) diff --git a/src/push_service/mod.rs b/src/push_service/mod.rs index b4e52aabc..c6cf31fd5 100644 --- a/src/push_service/mod.rs +++ b/src/push_service/mod.rs @@ -160,7 +160,7 @@ impl PushService { let client = reqwest::ClientBuilder::new() .add_root_certificate( reqwest::Certificate::from_pem( - &cfg.certificate_authority.as_bytes(), + cfg.certificate_authority.as_bytes(), ) .unwrap(), ) @@ -332,11 +332,8 @@ mod tests { let configs = &[SignalServers::Staging, SignalServers::Production]; for cfg in configs { - let _ = super::PushService::new( - cfg, - None, - "libsignal-service test".to_string(), - ); + let _ = + super::PushService::new(cfg, None, "libsignal-service test"); } } diff --git a/src/push_service/registration.rs b/src/push_service/registration.rs index 3e56f1da9..2ab6ea0c2 100644 --- a/src/push_service/registration.rs +++ b/src/push_service/registration.rs @@ -241,7 +241,7 @@ impl PushService { self.request( Method::PATCH, Endpoint::Service, - &format!("/v1/verification/session/{}", session_id), + format!("/v1/verification/session/{}", session_id), HttpAuthOverride::Unidentified, )? .json(&UpdateVerificationSessionRequestBody { @@ -290,7 +290,7 @@ impl PushService { self.request( Method::POST, Endpoint::Service, - &format!("/v1/verification/session/{}/code", session_id), + format!("/v1/verification/session/{}/code", session_id), HttpAuthOverride::Unidentified, )? .json(&VerificationCodeRequest { transport, client }) @@ -316,7 +316,7 @@ impl PushService { self.request( Method::PUT, Endpoint::Service, - &format!("/v1/verification/session/{}/code", session_id), + format!("/v1/verification/session/{}/code", session_id), HttpAuthOverride::Unidentified, )? .json(&VerificationCode { diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 1350bab7c..128956eac 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -128,7 +128,7 @@ impl SignalWebSocketProcess { tracing::trace!("sending request with body"); self.request_sink.send((request, sink)).await.map_err( |_| ServiceError::WsClosing { - reason: "request handler failed".into(), + reason: "request handler failed", }, )?; self.outgoing_responses.push(Box::pin(recv)); @@ -197,7 +197,7 @@ impl SignalWebSocketProcess { push_service::KEEPALIVE_TIMEOUT_SECONDS, ); - Ok(loop { + loop { futures::select! { _ = ka_interval.tick().fuse() => { use prost::Message; @@ -309,7 +309,8 @@ impl SignalWebSocketProcess { } } } - }) + } + Ok(()) } } From c758d8f0dabd2d89b141429d2ea1b605127e7d4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Thu, 17 Oct 2024 23:43:33 +0200 Subject: [PATCH 05/14] Add builder for WebSocketRequestMessage I will stop with the bike-shedding now... =D --- src/websocket/mod.rs | 41 +++++++------------------------- src/websocket/request.rs | 51 ++++++++++++++++++++++++++++++++++++++++ src/websocket/sender.rs | 20 +++++++++------- 3 files changed, 71 insertions(+), 41 deletions(-) create mode 100644 src/websocket/request.rs diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 128956eac..76d7d362b 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -9,8 +9,8 @@ use futures::channel::{mpsc, oneshot}; use futures::future::BoxFuture; use futures::prelude::*; use futures::stream::FuturesUnordered; +use reqwest::Method; use reqwest_websocket::WebSocket; -use serde::{Deserialize, Serialize}; use tokio::time::Instant; use crate::proto::{ @@ -19,8 +19,10 @@ use crate::proto::{ }; use crate::push_service::{self, ServiceError, SignalServiceResponse}; +mod request; mod sender; -// pub(crate) mod tungstenite; + +pub use request::WebSocketRequestMessageBuilder; type RequestStreamItem = ( WebSocketRequestMessage, @@ -202,12 +204,10 @@ impl SignalWebSocketProcess { _ = ka_interval.tick().fuse() => { use prost::Message; tracing::debug!("sending keep-alive"); - let request = WebSocketRequestMessage { - id: Some(self.next_request_id()), - path: Some(self.keep_alive_path.clone()), - verb: Some("GET".into()), - ..Default::default() - }; + let request = WebSocketRequestMessage::new(Method::GET) + .id(self.next_request_id()) + .path(&self.keep_alive_path) + .build(); self.outgoing_keep_alive_set.insert(request.id.unwrap()); let msg = WebSocketMessage { r#type: Some(web_socket_message::Type::Request.into()), @@ -441,29 +441,4 @@ impl SignalWebSocket { .await .map_err(Into::into) } - - pub(crate) async fn put_json<'h, D, S>( - &mut self, - path: &str, - value: S, - mut extra_headers: Vec, - ) -> Result - where - for<'de> D: Deserialize<'de>, - S: Serialize, - { - extra_headers.push("content-type:application/json".into()); - let request = WebSocketRequestMessage { - path: Some(path.into()), - verb: Some("PUT".into()), - headers: extra_headers, - body: Some(serde_json::to_vec(&value).map_err(|e| { - ServiceError::SendError { - reason: format!("Serializing JSON {}", e), - } - })?), - ..Default::default() - }; - self.request_json(request).await - } } diff --git a/src/websocket/request.rs b/src/websocket/request.rs new file mode 100644 index 000000000..f7a7b01fb --- /dev/null +++ b/src/websocket/request.rs @@ -0,0 +1,51 @@ +use reqwest::Method; +use serde::Serialize; + +use crate::proto::WebSocketRequestMessage; + +#[derive(Debug)] +pub struct WebSocketRequestMessageBuilder { + request: WebSocketRequestMessage, +} + +impl WebSocketRequestMessage { + pub fn new(method: Method) -> WebSocketRequestMessageBuilder { + WebSocketRequestMessageBuilder { + request: WebSocketRequestMessage { + verb: Some(method.to_string()), + ..Default::default() + }, + } + } +} + +impl WebSocketRequestMessageBuilder { + pub fn id(mut self, id: u64) -> Self { + self.request.id = Some(id); + self + } + + pub fn path(mut self, path: impl Into) -> Self { + self.request.path = Some(path.into()); + self + } + + pub fn header(mut self, key: &str, value: impl AsRef) -> Self { + self.request + .headers + .push(format!("{key}={}", value.as_ref())); + self + } + + pub fn json( + mut self, + value: S, + ) -> Result { + self.request.body = Some(serde_json::to_vec(&value)?); + Ok(self.header("Content-Type", "application/json").request) + } + + pub fn build(self) -> WebSocketRequestMessage { + self.request + } +} diff --git a/src/websocket/sender.rs b/src/websocket/sender.rs index 341430b38..394cf4c12 100644 --- a/src/websocket/sender.rs +++ b/src/websocket/sender.rs @@ -12,8 +12,10 @@ impl SignalWebSocket { &mut self, messages: OutgoingPushMessages, ) -> Result { - let path = format!("/v1/messages/{}", messages.destination); - self.put_json(&path, messages, vec![]).await + let request = WebSocketRequestMessage::new(Method::PUT) + .path(format!("/v1/messages/{}", messages.destination)) + .json(&messages)?; + self.request_json(request).await } pub async fn send_messages_unidentified( @@ -21,11 +23,13 @@ impl SignalWebSocket { messages: OutgoingPushMessages, access: &UnidentifiedAccess, ) -> Result { - let path = format!("/v1/messages/{}", messages.destination); - let header = format!( - "Unidentified-Access-Key:{}", - BASE64_RELAXED.encode(&access.key) - ); - self.put_json(&path, messages, vec![header]).await + let request = WebSocketRequestMessage::new(Method::PUT) + .path(format!("/v1/messages/{}", messages.destination)) + .header( + "Unidentified-Access-Key:{}", + BASE64_RELAXED.encode(&access.key), + ) + .json(&messages)?; + self.request_json(request).await } } From 267b28550040c2e97c3b6a1b6753e45cac9b7527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Thu, 17 Oct 2024 23:44:21 +0200 Subject: [PATCH 06/14] Clippy stfu --- src/websocket/request.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/websocket/request.rs b/src/websocket/request.rs index f7a7b01fb..71e0fd488 100644 --- a/src/websocket/request.rs +++ b/src/websocket/request.rs @@ -9,6 +9,7 @@ pub struct WebSocketRequestMessageBuilder { } impl WebSocketRequestMessage { + #[allow(clippy::new_ret_no_self)] pub fn new(method: Method) -> WebSocketRequestMessageBuilder { WebSocketRequestMessageBuilder { request: WebSocketRequestMessage { From 9dbbc7254e4b0f49300eb52f9be63220057089e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Thu, 17 Oct 2024 23:45:37 +0200 Subject: [PATCH 07/14] restore comment --- src/push_service/cdn.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index a345b563e..b88c9e901 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -108,6 +108,7 @@ impl PushService { filename: String, mut reader: impl Read + Send, ) -> Result<(), ServiceError> { + // Amazon S3 expects multipart fields in a very specific order (the file contents should go last.) let mut form = reqwest::multipart::Form::new(); form = form.text("acl", upload_attributes.acl); form = form.text("key", upload_attributes.key); From 660dc9173790329e091cfb6a4ed590d282798bb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Thu, 17 Oct 2024 23:48:18 +0200 Subject: [PATCH 08/14] fmt --- src/push_service/profile.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/push_service/profile.rs b/src/push_service/profile.rs index c48c67679..a1a859ac9 100644 --- a/src/push_service/profile.rs +++ b/src/push_service/profile.rs @@ -109,7 +109,9 @@ impl PushService { &endpoint, HttpAuthOverride::NoOverride, )? - .send().await?.service_error_for_status() + .send() + .await? + .service_error_for_status() .await? .json() .await @@ -174,7 +176,9 @@ impl PushService { HttpAuthOverride::NoOverride, )? .json(&command) - .send().await?.service_error_for_status() + .send() + .await? + .service_error_for_status() .await? .json() .await; From 59078db50af2e2459684fef3e59b273ad03de5df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Thu, 17 Oct 2024 23:50:40 +0200 Subject: [PATCH 09/14] Fix compilation for Rust 1.75 and remove unwrap --- src/cipher.rs | 10 +++++----- src/envelope.rs | 2 +- src/messagepipe.rs | 2 +- src/push_service/cdn.rs | 18 +++++++++++++----- src/push_service/error.rs | 2 +- src/websocket/mod.rs | 6 +++--- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/cipher.rs b/src/cipher.rs index 09649b679..d787f9472 100644 --- a/src/cipher.rs +++ b/src/cipher.rs @@ -134,7 +134,7 @@ where let ciphertext = if let Some(msg) = envelope.content.as_ref() { msg } else { - return Err(ServiceError::InvalidFrameError { + return Err(ServiceError::InvalidFrame { reason: "Envelope should have either a legacy message or content." .into(), @@ -311,7 +311,7 @@ where }, _ => { // else - return Err(ServiceError::InvalidFrameError { + return Err(ServiceError::InvalidFrame { reason: format!( "Envelope has unknown type {:?}.", envelope.r#type() @@ -408,7 +408,7 @@ struct Plaintext { #[allow(clippy::comparison_chain)] fn add_padding(version: u32, contents: &[u8]) -> Result, ServiceError> { if version < 2 { - Err(ServiceError::InvalidFrameError { + Err(ServiceError::InvalidFrame { reason: format!("Unknown version {}", version), }) } else if version == 2 { @@ -436,7 +436,7 @@ fn strip_padding_version( contents: &mut Vec, ) -> Result<(), ServiceError> { if version < 2 { - Err(ServiceError::InvalidFrameError { + Err(ServiceError::InvalidFrame { reason: format!("Unknown version {}", version), }) } else if version == 2 { @@ -450,7 +450,7 @@ fn strip_padding_version( #[allow(clippy::comparison_chain)] fn strip_padding(contents: &mut Vec) -> Result<(), ServiceError> { let new_length = Iso7816::raw_unpad(contents) - .map_err(|e| ServiceError::InvalidFrameError { + .map_err(|e| ServiceError::InvalidFrame { reason: format!("Invalid message padding: {:?}", e), })? .len(); diff --git a/src/envelope.rs b/src/envelope.rs index 1a33de669..694700b89 100644 --- a/src/envelope.rs +++ b/src/envelope.rs @@ -42,7 +42,7 @@ impl Envelope { if input.len() < VERSION_LENGTH || input[VERSION_OFFSET] != SUPPORTED_VERSION { - return Err(ServiceError::InvalidFrameError { + return Err(ServiceError::InvalidFrame { reason: "Unsupported signaling cryptogram version".into(), }); } diff --git a/src/messagepipe.rs b/src/messagepipe.rs index 4fd1d54d4..711ed0cff 100644 --- a/src/messagepipe.rs +++ b/src/messagepipe.rs @@ -79,7 +79,7 @@ impl MessagePipe { let body = if let Some(body) = request.body.as_ref() { body } else { - return Err(ServiceError::InvalidFrameError { + return Err(ServiceError::InvalidFrame { reason: "Request without body.".into(), }); }; diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index b88c9e901..51b19abe2 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -32,12 +32,20 @@ impl PushService { &mut self, ptr: &AttachmentPointer, ) -> Result { - let id = match ptr.attachment_identifier.as_ref().unwrap() { - AttachmentIdentifier::CdnId(id) => &id.to_string(), - AttachmentIdentifier::CdnKey(key) => key, + let path = match ptr.attachment_identifier.as_ref() { + Some(AttachmentIdentifier::CdnId(id)) => { + format!("attachments/{}", id) + }, + Some(AttachmentIdentifier::CdnKey(key)) => { + format!("attachments/{}", key) + }, + None => { + return Err(ServiceError::InvalidFrame { + reason: "no attachment identifier in pointer".into(), + }); + }, }; - self.get_from_cdn(ptr.cdn_number(), &format!("attachments/{}", id)) - .await + self.get_from_cdn(ptr.cdn_number(), &path).await } #[tracing::instrument(skip(self))] diff --git a/src/push_service/error.rs b/src/push_service/error.rs index 99561c9e7..6efa1ec38 100644 --- a/src/push_service/error.rs +++ b/src/push_service/error.rs @@ -45,7 +45,7 @@ pub enum ServiceError { WsClosing { reason: &'static str }, #[error("Invalid frame: {reason}")] - InvalidFrameError { reason: String }, + InvalidFrame { reason: String }, #[error("MAC error")] MacError, diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 76d7d362b..48b00ceb7 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -122,7 +122,7 @@ impl SignalWebSocketProcess { use web_socket_message::Type; match (msg.r#type(), msg.request, msg.response) { - (Type::Unknown, _, _) => Err(ServiceError::InvalidFrameError { + (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame { reason: "Unknown frame type".into(), }), (Type::Request, Some(request), _) => { @@ -137,7 +137,7 @@ impl SignalWebSocketProcess { Ok(()) }, - (Type::Request, None, _) => Err(ServiceError::InvalidFrameError { + (Type::Request, None, _) => Err(ServiceError::InvalidFrame { reason: "Type was request, but does not contain request." .into(), }), @@ -175,7 +175,7 @@ impl SignalWebSocketProcess { Ok(()) }, - (Type::Response, _, None) => Err(ServiceError::InvalidFrameError { + (Type::Response, _, None) => Err(ServiceError::InvalidFrame { reason: "Type was response, but does not contain response." .into(), }), From fd95398baad5f67bea3347ee74141c409425dce3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Fri, 18 Oct 2024 18:02:10 +0200 Subject: [PATCH 10/14] Clean ServiceError a little bit --- src/cipher.rs | 17 ++++------------- src/envelope.rs | 2 +- src/groups_v2/manager.rs | 12 +++--------- src/messagepipe.rs | 2 +- src/provisioning/mod.rs | 4 ++-- src/provisioning/pipe.rs | 16 ++++++---------- src/push_service/cdn.rs | 2 +- src/push_service/error.rs | 15 +++++++++++---- src/receiver.rs | 2 +- src/sender.rs | 15 +++++++-------- src/websocket/mod.rs | 12 +++++------- 11 files changed, 42 insertions(+), 57 deletions(-) diff --git a/src/cipher.rs b/src/cipher.rs index d787f9472..81460274f 100644 --- a/src/cipher.rs +++ b/src/cipher.rs @@ -312,10 +312,7 @@ where _ => { // else return Err(ServiceError::InvalidFrame { - reason: format!( - "Envelope has unknown type {:?}.", - envelope.r#type() - ), + reason: "envelope has unknown type", }); }, }; @@ -408,9 +405,7 @@ struct Plaintext { #[allow(clippy::comparison_chain)] fn add_padding(version: u32, contents: &[u8]) -> Result, ServiceError> { if version < 2 { - Err(ServiceError::InvalidFrame { - reason: format!("Unknown version {}", version), - }) + Err(ServiceError::PaddingVersion(version)) } else if version == 2 { Ok(contents.to_vec()) } else { @@ -437,7 +432,7 @@ fn strip_padding_version( ) -> Result<(), ServiceError> { if version < 2 { Err(ServiceError::InvalidFrame { - reason: format!("Unknown version {}", version), + reason: "unknown version", }) } else if version == 2 { Ok(()) @@ -449,11 +444,7 @@ fn strip_padding_version( #[allow(clippy::comparison_chain)] fn strip_padding(contents: &mut Vec) -> Result<(), ServiceError> { - let new_length = Iso7816::raw_unpad(contents) - .map_err(|e| ServiceError::InvalidFrame { - reason: format!("Invalid message padding: {:?}", e), - })? - .len(); + let new_length = Iso7816::raw_unpad(contents)?.len(); contents.resize(new_length, 0); Ok(()) } diff --git a/src/envelope.rs b/src/envelope.rs index 694700b89..ff247c9a3 100644 --- a/src/envelope.rs +++ b/src/envelope.rs @@ -43,7 +43,7 @@ impl Envelope { || input[VERSION_OFFSET] != SUPPORTED_VERSION { return Err(ServiceError::InvalidFrame { - reason: "Unsupported signaling cryptogram version".into(), + reason: "unsupported signaling cryptogram version", }); } diff --git a/src/groups_v2/manager.rs b/src/groups_v2/manager.rs index 93e6229c7..a1c7af0be 100644 --- a/src/groups_v2/manager.rs +++ b/src/groups_v2/manager.rs @@ -183,10 +183,9 @@ impl GroupsManager { self.credentials_cache .write(credentials_response.parse()?)?; self.credentials_cache.get(&today)?.ok_or_else(|| { - ServiceError::ResponseError { + ServiceError::InvalidFrame { reason: - "credentials received did not contain requested day" - .into(), + "credentials received did not contain requested day", } })? }; @@ -287,12 +286,7 @@ impl GroupsManager { .retrieve_groups_v2_profile_avatar(path) .await?; let mut result = Vec::with_capacity(10 * 1024 * 1024); - encrypted_avatar - .read_to_end(&mut result) - .await - .map_err(|e| ServiceError::ResponseError { - reason: format!("reading avatar data: {}", e), - })?; + encrypted_avatar.read_to_end(&mut result).await?; Ok(GroupOperations::new(group_secret_params).decrypt_avatar(&result)) } diff --git a/src/messagepipe.rs b/src/messagepipe.rs index 711ed0cff..9d4e89efa 100644 --- a/src/messagepipe.rs +++ b/src/messagepipe.rs @@ -80,7 +80,7 @@ impl MessagePipe { body } else { return Err(ServiceError::InvalidFrame { - reason: "Request without body.".into(), + reason: "request without body.", }); }; Some(Incoming::Envelope(Envelope::decrypt( diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 2455409a1..f1c16c589 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -75,8 +75,8 @@ pub enum ProvisioningError { DecodeError(#[from] prost::DecodeError), #[error("Websocket error: {reason}")] WsError { reason: String }, - #[error("Websocket closing: {reason}")] - WsClosing { reason: String }, + #[error("Websocket closing")] + WsClosing, #[error("Service error: {0}")] ServiceError(#[from] ServiceError), #[error("libsignal-protocol error: {0}")] diff --git a/src/provisioning/pipe.rs b/src/provisioning/pipe.rs index 7aa258b9f..0842ec895 100644 --- a/src/provisioning/pipe.rs +++ b/src/provisioning/pipe.rs @@ -105,11 +105,9 @@ impl ProvisioningPipe { ); // acknowledge - responder.send(ok).map_err(|_| { - ProvisioningError::WsClosing { - reason: "could not respond to provision request".into(), - } - })?; + responder + .send(ok) + .map_err(|_| ProvisioningError::WsClosing)?; Ok(Some(ProvisioningStep::Url(provisioning_url))) }, @@ -132,11 +130,9 @@ impl ProvisioningPipe { self.provisioning_cipher.decrypt(provision_envelope)?; // acknowledge - responder.send(ok).map_err(|_| { - ProvisioningError::WsClosing { - reason: "could not respond to provision request".into(), - } - })?; + responder + .send(ok) + .map_err(|_| ProvisioningError::WsClosing)?; Ok(Some(ProvisioningStep::Message(provision_message))) }, diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index 51b19abe2..daf37a1d7 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -41,7 +41,7 @@ impl PushService { }, None => { return Err(ServiceError::InvalidFrame { - reason: "no attachment identifier in pointer".into(), + reason: "no attachment identifier in pointer", }); }, }; diff --git a/src/push_service/error.rs b/src/push_service/error.rs index 6efa1ec38..f10f67866 100644 --- a/src/push_service/error.rs +++ b/src/push_service/error.rs @@ -1,3 +1,4 @@ +use aes::cipher::block_padding::UnpadError; use libsignal_protocol::SignalProtocolError; use zkgroup::ZkGroupDeserializationFailure; @@ -10,7 +11,7 @@ use super::{ #[derive(thiserror::Error, Debug)] pub enum ServiceError { #[error("Service request timed out: {reason}")] - Timeout { reason: String }, + Timeout { reason: &'static str }, #[error("invalid URL: {0}")] InvalidUrl(#[from] url::ParseError), @@ -18,8 +19,8 @@ pub enum ServiceError { #[error("Error sending request: {reason}")] SendError { reason: String }, - #[error("Error decoding response: {reason}")] - ResponseError { reason: String }, + #[error("i/o error")] + IO(#[from] std::io::Error), #[error("Error decoding JSON: {0}")] JsonDecodeError(#[from] serde_json::Error), @@ -44,8 +45,14 @@ pub enum ServiceError { #[error("Websocket closing: {reason}")] WsClosing { reason: &'static str }, + #[error("Invalid padding: {0}")] + Padding(#[from] UnpadError), + + #[error("unknown padding version {0}")] + PaddingVersion(u32), + #[error("Invalid frame: {reason}")] - InvalidFrame { reason: String }, + InvalidFrame { reason: &'static str }, #[error("MAC error")] MacError, diff --git a/src/receiver.rs b/src/receiver.rs index ba69320f0..57016dce7 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -64,7 +64,7 @@ impl MessageReceiver { retries += 1; if retries >= MAX_DOWNLOAD_RETRIES { return Err(ServiceError::Timeout { - reason: "too many retries".into(), + reason: "too many retries", }); } }, diff --git a/src/sender.rs b/src/sender.rs index aea24d508..8a5eb86c8 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -108,13 +108,18 @@ pub enum AttachmentUploadError { #[derive(thiserror::Error, Debug)] pub enum MessageSenderError { - #[error("{0}")] + #[error("service error: {0}")] ServiceError(#[from] ServiceError), + #[error("protocol error: {0}")] ProtocolError(#[from] SignalProtocolError), + #[error("Failed to upload attachment {0}")] AttachmentUploadError(#[from] AttachmentUploadError), + #[error("primary device can't send sync message {0:?}")] + SendSyncMessageError(sync_message::request::Type), + #[error("Untrusted identity key with {address:?}")] UntrustedIdentity { address: ServiceAddress }, @@ -778,13 +783,7 @@ where request_type: sync_message::request::Type, ) -> Result<(), MessageSenderError> { if self.device_id == DEFAULT_DEVICE_ID.into() { - let reason = format!( - "Primary device can't send sync requests, ignoring {:?}", - request_type - ); - return Err(MessageSenderError::ServiceError( - ServiceError::SendError { reason }, - )); + return Err(MessageSenderError::SendSyncMessageError(request_type)); } let msg = SyncMessage { diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 48b00ceb7..c279b1f28 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -123,7 +123,7 @@ impl SignalWebSocketProcess { use web_socket_message::Type; match (msg.r#type(), msg.request, msg.response) { (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame { - reason: "Unknown frame type".into(), + reason: "unknown frame type", }), (Type::Request, Some(request), _) => { let (sink, recv) = oneshot::channel(); @@ -138,8 +138,7 @@ impl SignalWebSocketProcess { Ok(()) }, (Type::Request, None, _) => Err(ServiceError::InvalidFrame { - reason: "Type was request, but does not contain request." - .into(), + reason: "type was request, but does not contain request", }), (Type::Response, _, Some(response)) => { if let Some(id) = response.id { @@ -176,8 +175,7 @@ impl SignalWebSocketProcess { Ok(()) }, (Type::Response, _, None) => Err(ServiceError::InvalidFrame { - reason: "Type was response, but does not contain response." - .into(), + reason: "type was response, but does not contain response", }), } } @@ -413,14 +411,14 @@ impl SignalWebSocket { async move { if let Err(_e) = request_sink.send((r, sink)).await { return Err(ServiceError::WsClosing { - reason: "WebSocket closing while sending request.", + reason: "WebSocket closing while sending request", }); } // Handle the oneshot sender error for dropped senders. match recv.await { Ok(x) => x, Err(_) => Err(ServiceError::WsClosing { - reason: "WebSocket closing while waiting for a response.", + reason: "WebSocket closing while waiting for a response", }), } } From ad0666c0d3503b9d7bc2e1fcc2f73f9b08c959b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Fri, 18 Oct 2024 21:17:08 +0200 Subject: [PATCH 11/14] Fix wrong header concat --- src/websocket/request.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websocket/request.rs b/src/websocket/request.rs index 71e0fd488..945d43c6d 100644 --- a/src/websocket/request.rs +++ b/src/websocket/request.rs @@ -34,7 +34,7 @@ impl WebSocketRequestMessageBuilder { pub fn header(mut self, key: &str, value: impl AsRef) -> Self { self.request .headers - .push(format!("{key}={}", value.as_ref())); + .push(format!("{key}:{}", value.as_ref())); self } From 771ec18e20e4a5c617b54beeb58a0790fe71ebb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Fri, 18 Oct 2024 22:23:03 +0200 Subject: [PATCH 12/14] WIP --- src/push_service/cdn.rs | 44 +++++++++++++++++++++------------------- src/sender.rs | 2 +- src/websocket/request.rs | 2 +- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index daf37a1d7..be1eaaf80 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -1,7 +1,7 @@ use std::io::{self, Read}; use futures::TryStreamExt; -use reqwest::{multipart::Part, Method}; +use reqwest::{header::CONTENT_TYPE, multipart::Part, Method}; use tracing::debug; use crate::{ @@ -116,37 +116,39 @@ impl PushService { filename: String, mut reader: impl Read + Send, ) -> Result<(), ServiceError> { - // Amazon S3 expects multipart fields in a very specific order (the file contents should go last.) - let mut form = reqwest::multipart::Form::new(); - form = form.text("acl", upload_attributes.acl); - form = form.text("key", upload_attributes.key); - form = form.text("policy", upload_attributes.policy); - form = form.text("x-amz-algorithm", upload_attributes.algorithm); - form = form.text("x-amz-credential", upload_attributes.credential); - form = form.text("x-amz-date", upload_attributes.date); - form = form.text("x-amz-signature", upload_attributes.signature); - let mut buf = Vec::new(); reader .read_to_end(&mut buf) .expect("infallible Read instance"); - form = form.text("Content-Type", "application/octet-stream"); - form = form.text("Content-Length", buf.len().to_string()); - form = form.part("file", Part::bytes(buf).file_name(filename)); - - let response = self + // Amazon S3 expects multipart fields in a very specific order (the file contents should go last.) + let form = reqwest::multipart::Form::new() + .text("acl", upload_attributes.acl) + .text("key", upload_attributes.key) + .text("policy", upload_attributes.policy) + .text("x-amz-algorithm", upload_attributes.algorithm) + .text("x-amz-credential", upload_attributes.credential) + .text("x-amz-date", upload_attributes.date) + .text("x-amz-signature", upload_attributes.signature) + .part("file", Part::bytes(buf)); + + let content_type = + format!("multipart/form-data; boundary={}", form.boundary()); + + dbg!(content_type); + + let response = dbg!(self .request( Method::POST, Endpoint::Cdn(0), path, HttpAuthOverride::NoOverride, )? - .multipart(form) - .send() - .await? - .service_error_for_status() - .await?; + .multipart(form)) + .send() + .await? + .service_error_for_status() + .await?; debug!("HyperPushService::PUT response: {:?}", response); diff --git a/src/sender.rs b/src/sender.rs index 8a5eb86c8..691e03939 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -68,7 +68,7 @@ pub struct SentMessage { /// Attachment specification to be used for uploading. /// /// Loose equivalent of Java's `SignalServiceAttachmentStream`. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct AttachmentSpec { pub content_type: String, pub length: usize, diff --git a/src/websocket/request.rs b/src/websocket/request.rs index 945d43c6d..d10306f29 100644 --- a/src/websocket/request.rs +++ b/src/websocket/request.rs @@ -43,7 +43,7 @@ impl WebSocketRequestMessageBuilder { value: S, ) -> Result { self.request.body = Some(serde_json::to_vec(&value)?); - Ok(self.header("Content-Type", "application/json").request) + Ok(self.header("content-type", "application/json").request) } pub fn build(self) -> WebSocketRequestMessage { From 5f0e3558893ef977d95811604765e328f8c3ff96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Fri, 18 Oct 2024 23:14:05 +0200 Subject: [PATCH 13/14] Fix CDN0 multipart upload --- src/push_service/cdn.rs | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index be1eaaf80..201d439e3 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -121,34 +121,36 @@ impl PushService { .read_to_end(&mut buf) .expect("infallible Read instance"); - // Amazon S3 expects multipart fields in a very specific order (the file contents should go last.) + // Amazon S3 expects multipart fields in a very specific order + // DO NOT CHANGE THIS (or do it, but feel the wrath of the gods) let form = reqwest::multipart::Form::new() .text("acl", upload_attributes.acl) .text("key", upload_attributes.key) .text("policy", upload_attributes.policy) + .text("Content-Type", "application/octet-stream") .text("x-amz-algorithm", upload_attributes.algorithm) .text("x-amz-credential", upload_attributes.credential) .text("x-amz-date", upload_attributes.date) .text("x-amz-signature", upload_attributes.signature) - .part("file", Part::bytes(buf)); - - let content_type = - format!("multipart/form-data; boundary={}", form.boundary()); - - dbg!(content_type); - - let response = dbg!(self + .part( + "file", + Part::stream(buf) + .mime_str("application/octet-stream")? + .file_name(filename), + ); + + let response = self .request( Method::POST, Endpoint::Cdn(0), path, HttpAuthOverride::NoOverride, )? - .multipart(form)) - .send() - .await? - .service_error_for_status() - .await?; + .multipart(form) + .send() + .await? + .service_error_for_status() + .await?; debug!("HyperPushService::PUT response: {:?}", response); From 69a09a99754f42bffd28aff9c2180860471e911f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Fri, 18 Oct 2024 23:15:48 +0200 Subject: [PATCH 14/14] Clippy --- src/cipher.rs | 3 +-- src/push_service/cdn.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/cipher.rs b/src/cipher.rs index 81460274f..057cc8332 100644 --- a/src/cipher.rs +++ b/src/cipher.rs @@ -136,8 +136,7 @@ where } else { return Err(ServiceError::InvalidFrame { reason: - "Envelope should have either a legacy message or content." - .into(), + "envelope should have either a legacy message or content.", }); }; diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index 201d439e3..b87e3e8a3 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -1,7 +1,7 @@ use std::io::{self, Read}; use futures::TryStreamExt; -use reqwest::{header::CONTENT_TYPE, multipart::Part, Method}; +use reqwest::{multipart::Part, Method}; use tracing::debug; use crate::{