Skip to content

Commit

Permalink
feat: User registration (#185)
Browse files Browse the repository at this point in the history
* Add a stub handler for the registration endpoint

* Add extractors for registration endpoints

Also we now use the `Bytes` and `Json` extractors to read the payload.
The payload size limit is configured through the respective extractor
config objects.

* Increment ua.command.register metric

* Implement router registration

* Store the user in the database during registration

* Add the channel to the database and accept more data in the request body

* Move make_endpoint to the common code and use in user registration

* Generate a secret for future requests and return the registration data

* Fix an incorrect expression value and missing current_month value

* Support compiling with OpenSSL 1.0

The `sign_oneshot_to_vec` method is only available with OpenSSL >=1.1.1.

* Add logging to register_uaid_route

* Use hyphenated UUIDs in the message table "chids" column

* Fix errors after rebase

* Simplify endpoint creation

* Use the lowercase hyphenated formatter (default) when returning UUIDs

Closes #176

Co-authored-by: JR Conlin <[email protected]>
  • Loading branch information
AzureMarker and jrconlin authored Jul 27, 2020
1 parent e55b977 commit 6df3e36
Show file tree
Hide file tree
Showing 21 changed files with 438 additions and 89 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 53 additions & 5 deletions autoendpoint/src/db/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::db::retry::{
};
use autopush_common::db::{DynamoDbNotification, DynamoDbUser};
use autopush_common::notification::Notification;
use autopush_common::util::sec_since_epoch;
use autopush_common::{ddb_item, hashmap, val};
use cadence::StatsdClient;
use rusoto_core::credential::StaticProvider;
Expand All @@ -17,6 +18,9 @@ use std::collections::HashSet;
use std::env;
use uuid::Uuid;

/// The maximum TTL for channels, 30 days
const MAX_CHANNEL_TTL: u64 = 30 * 24 * 60 * 60;

/// Provides high-level operations over the DynamoDB database
#[derive(Clone)]
pub struct DbClient {
Expand Down Expand Up @@ -53,6 +57,25 @@ impl DbClient {
})
}

/// Add a new user to the database. An error will occur if the user already
/// exists.
pub async fn add_user(&self, user: &DynamoDbUser) -> DbResult<()> {
let input = PutItemInput {
item: serde_dynamodb::to_hashmap(user)?,
table_name: self.router_table.to_string(),
condition_expression: Some("attribute_not_exists(uaid)".to_string()),
..Default::default()
};

retry_policy()
.retry_if(
move || self.ddb.put_item(input.clone()),
retryable_putitem_error(self.metrics.clone()),
)
.await?;
Ok(())
}

/// Read a user from the database
pub async fn get_user(&self, uaid: Uuid) -> DbResult<Option<DynamoDbUser>> {
let input = GetItemInput {
Expand Down Expand Up @@ -91,8 +114,33 @@ impl DbClient {
Ok(())
}

/// Add a channel to a user
pub async fn add_channel(&self, uaid: Uuid, channel_id: Uuid) -> DbResult<()> {
let input = UpdateItemInput {
table_name: self.message_table.clone(),
key: ddb_item! {
uaid: s => uaid.to_simple().to_string(),
chidmessageid: s => " ".to_string()
},
update_expression: Some("ADD chids :channel_id SET expiry = :expiry".to_string()),
expression_attribute_values: Some(hashmap! {
":channel_id".to_string() => val!(SS => Some(channel_id.to_hyphenated())),
":expiry".to_string() => val!(N => sec_since_epoch() + MAX_CHANNEL_TTL)
}),
..Default::default()
};

retry_policy()
.retry_if(
|| self.ddb.update_item(input.clone()),
retryable_updateitem_error(self.metrics.clone()),
)
.await?;
Ok(())
}

/// Get the set of channel IDs for a user
pub async fn get_user_channels(&self, uaid: Uuid) -> DbResult<HashSet<Uuid>> {
pub async fn get_channels(&self, uaid: Uuid) -> DbResult<HashSet<Uuid>> {
// Channel IDs are stored in a special row in the message table, where
// chidmessageid = " "
let input = GetItemInput {
Expand Down Expand Up @@ -141,7 +189,7 @@ impl DbClient {
node_id: String,
connected_at: u64,
) -> DbResult<()> {
let update_item = UpdateItemInput {
let input = UpdateItemInput {
key: ddb_item! { uaid: s => uaid.to_simple().to_string() },
update_expression: Some("REMOVE node_id".to_string()),
condition_expression: Some("(node_id = :node) and (connected_at = :conn)".to_string()),
Expand All @@ -155,7 +203,7 @@ impl DbClient {

retry_policy()
.retry_if(
|| self.ddb.update_item(update_item.clone()),
|| self.ddb.update_item(input.clone()),
retryable_updateitem_error(self.metrics.clone()),
)
.await?;
Expand All @@ -165,15 +213,15 @@ impl DbClient {

/// Store a single message
pub async fn store_message(&self, uaid: Uuid, message: Notification) -> DbResult<()> {
let put_item = PutItemInput {
let input = PutItemInput {
item: serde_dynamodb::to_hashmap(&DynamoDbNotification::from_notif(&uaid, message))?,
table_name: self.message_table.clone(),
..Default::default()
};

retry_policy()
.retry_if(
|| self.ddb.put_item(put_item.clone()),
|| self.ddb.put_item(input.clone()),
retryable_putitem_error(self.metrics.clone()),
)
.await?;
Expand Down
47 changes: 33 additions & 14 deletions autoendpoint/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::headers::vapid::VapidError;
use crate::routers::RouterError;
use actix_web::{
dev::{HttpResponseBuilder, ServiceResponse},
error::{PayloadError, ResponseError},
error::{JsonPayloadError, PayloadError, ResponseError},
http::StatusCode,
middleware::errhandlers::ErrorHandlerResponse,
HttpResponse, Result,
Expand Down Expand Up @@ -55,9 +55,8 @@ pub enum ApiErrorKind {
#[error(transparent)]
Validation(#[from] validator::ValidationErrors),

// PayloadError does not implement std Error
#[error("{0}")]
PayloadError(PayloadError),
#[error(transparent)]
PayloadError(actix_web::Error),

#[error(transparent)]
VapidError(#[from] VapidError),
Expand All @@ -71,6 +70,12 @@ pub enum ApiErrorKind {
#[error("Error while validating token")]
TokenHashValidation(#[source] openssl::error::ErrorStack),

#[error("Error while creating secret")]
RegistrationSecretHash(#[source] openssl::error::ErrorStack),

#[error("Error while creating endpoint URL: {0}")]
EndpointUrl(#[source] autopush_common::errors::Error),

#[error("Database error: {0}")]
Database(#[from] DbError),

Expand All @@ -87,16 +92,19 @@ pub enum ApiErrorKind {
#[error("{0}")]
InvalidEncryption(String),

#[error("Data payload must be smaller than {0} bytes")]
PayloadTooLarge(usize),

/// Used if the API version given is not v1 or v2
#[error("Invalid API version")]
InvalidApiVersion,

#[error("Missing TTL value")]
NoTTL,

#[error("Invalid router type")]
InvalidRouterType,

#[error("Invalid router token")]
InvalidRouterToken,

#[error("{0}")]
Internal(String),
}
Expand All @@ -105,25 +113,27 @@ impl ApiErrorKind {
/// Get the associated HTTP status code
pub fn status(&self) -> StatusCode {
match self {
ApiErrorKind::PayloadError(e) => e.status_code(),
ApiErrorKind::PayloadError(e) => e.as_response_error().status_code(),
ApiErrorKind::Router(e) => e.status(),

ApiErrorKind::Validation(_)
| ApiErrorKind::InvalidEncryption(_)
| ApiErrorKind::TokenHashValidation(_)
| ApiErrorKind::NoTTL => StatusCode::BAD_REQUEST,
| ApiErrorKind::NoTTL
| ApiErrorKind::InvalidRouterType
| ApiErrorKind::InvalidRouterToken => StatusCode::BAD_REQUEST,

ApiErrorKind::NoUser | ApiErrorKind::NoSubscription => StatusCode::GONE,

ApiErrorKind::VapidError(_) | ApiErrorKind::Jwt(_) => StatusCode::UNAUTHORIZED,

ApiErrorKind::InvalidToken | ApiErrorKind::InvalidApiVersion => StatusCode::NOT_FOUND,

ApiErrorKind::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE,

ApiErrorKind::Io(_)
| ApiErrorKind::Metrics(_)
| ApiErrorKind::Database(_)
| ApiErrorKind::EndpointUrl(_)
| ApiErrorKind::RegistrationSecretHash(_)
| ApiErrorKind::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
Expand All @@ -139,11 +149,17 @@ impl ApiErrorKind {

ApiErrorKind::NoUser => Some(103),

ApiErrorKind::PayloadError(PayloadError::Overflow)
| ApiErrorKind::PayloadTooLarge(_) => Some(104),
ApiErrorKind::PayloadError(error)
if matches!(error.as_error(), Some(PayloadError::Overflow))
|| matches!(error.as_error(), Some(JsonPayloadError::Overflow)) =>
{
Some(104)
}

ApiErrorKind::NoSubscription => Some(106),

ApiErrorKind::InvalidRouterType => Some(108),

ApiErrorKind::VapidError(_)
| ApiErrorKind::TokenHashValidation(_)
| ApiErrorKind::Jwt(_) => Some(109),
Expand All @@ -157,7 +173,10 @@ impl ApiErrorKind {
ApiErrorKind::Io(_)
| ApiErrorKind::Metrics(_)
| ApiErrorKind::Database(_)
| ApiErrorKind::PayloadError(_) => None,
| ApiErrorKind::PayloadError(_)
| ApiErrorKind::InvalidRouterToken
| ApiErrorKind::RegistrationSecretHash(_)
| ApiErrorKind::EndpointUrl(_) => None,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions autoendpoint/src/extractors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
pub mod notification;
pub mod notification_headers;
pub mod registration_path_args;
pub mod router_data_input;
pub mod routers;
pub mod subscription;
pub mod token_info;
Expand Down
17 changes: 5 additions & 12 deletions autoendpoint/src/extractors/notification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use crate::extractors::subscription::Subscription;
use crate::server::ServerState;
use actix_web::dev::{Payload, PayloadStream};
use actix_web::web::Data;
use actix_web::{FromRequest, HttpRequest};
use actix_web::{web, FromRequest, HttpRequest};
use autopush_common::util::sec_since_epoch;
use cadence::Counted;
use fernet::MultiFernet;
use futures::{future, FutureExt, StreamExt};
use futures::{future, FutureExt};
use std::collections::HashMap;
use uuid::Uuid;

Expand Down Expand Up @@ -38,16 +38,9 @@ impl FromRequest for Notification {
.expect("No server state found");

// Read data
let mut data = Vec::new();
while let Some(item) = payload.next().await {
data.extend_from_slice(&item.map_err(ApiErrorKind::PayloadError)?);

// Make sure the payload isn't too big
let max_bytes = state.settings.max_data_bytes;
if data.len() > max_bytes {
return Err(ApiErrorKind::PayloadTooLarge(max_bytes).into());
}
}
let data = web::Bytes::from_request(&req, &mut payload)
.await
.map_err(ApiErrorKind::PayloadError)?;

// Convert data to base64
let data = if data.is_empty() {
Expand Down
38 changes: 38 additions & 0 deletions autoendpoint/src/extractors/registration_path_args.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use crate::error::{ApiError, ApiErrorKind};
use crate::extractors::routers::RouterType;
use actix_web::dev::{Payload, PayloadStream};
use actix_web::{FromRequest, HttpRequest};
use futures::future;

/// Extracts and validates the `router_type` and `app_id` path arguments
pub struct RegistrationPathArgs {
pub router_type: RouterType,
pub app_id: String,
}

impl FromRequest for RegistrationPathArgs {
type Error = ApiError;
type Future = future::Ready<Result<Self, Self::Error>>;
type Config = ();

fn from_request(req: &HttpRequest, _: &mut Payload<PayloadStream>) -> Self::Future {
let match_info = req.match_info();
let router_type = match match_info
.get("router_type")
.expect("{router_type} must be part of the path")
.parse::<RouterType>()
{
Ok(router_type) => router_type,
Err(_) => return future::err(ApiErrorKind::InvalidRouterType.into()),
};
let app_id = match_info
.get("app_id")
.expect("{app_id} must be part of the path")
.to_string();

future::ok(Self {
router_type,
app_id,
})
}
}
Loading

0 comments on commit 6df3e36

Please sign in to comment.