Skip to content

Commit

Permalink
Initialize JWT and protected routes (#4)
Browse files Browse the repository at this point in the history
* Add jwt module

* Add axum_extras and jsonwebtoken packages

* Return user_id when inserting user

* Add new middleware modules

* Add new user route

* Return JWT when registering user

* Add protected /user/me route

* Fix error conversion for bad encoding
  • Loading branch information
sneakycrow authored Oct 31, 2024
1 parent ca10c87 commit 9fb85d6
Show file tree
Hide file tree
Showing 10 changed files with 498 additions and 40 deletions.
405 changes: 374 additions & 31 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions packages/api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ edition = "2021"

[dependencies]
axum = { version = "0.7", features = ["multipart", "tracing", "ws"] }
axum-extra = "0.3"
chrono = { version = "0.4.31", features = ["serde"] }
db = { path = "../db" }
jsonwebtoken = "8.1"
nanoid = "0.4.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
Expand Down
53 changes: 53 additions & 0 deletions packages/api/src/jwt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation};
use serde::{Deserialize, Serialize};

#[derive(Serialize)]
pub enum JWTError {
DecodingError,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub exp: usize, // Expiry time of the token
pub iat: usize, // Issued at time of the token
pub user_id: String, // User ID associated with the token
}

/// Gets the JWT_SECRET from the environment variables and converts to bytes
fn get_secret() -> Vec<u8> {
std::env::var("JWT_SECRET")
.expect("JWT_SECRET must be set")
.into_bytes()
}

/// Creates a JWT token containing the given user_id
pub fn encode_jwt(user_id: &str) -> Result<String, jsonwebtoken::errors::Error> {
let jwt_secret = get_secret();
let now = Utc::now();
let expire: chrono::TimeDelta = Duration::hours(24);
let exp: usize = (now + expire).timestamp() as usize;
let iat: usize = now.timestamp() as usize;
let claim = Claims {
iat,
exp,
user_id: user_id.to_owned(),
};

encode(
&Header::default(),
&claim,
&EncodingKey::from_secret(&jwt_secret),
)
}

/// Decodes a JWT token
pub fn decode_jwt(jwt_token: String) -> Result<TokenData<Claims>, JWTError> {
let jwt_secret = get_secret();
decode(
&jwt_token,
&DecodingKey::from_secret(&jwt_secret),
&Validation::default(),
)
.map_err(|_| JWTError::DecodingError)
}
13 changes: 11 additions & 2 deletions packages/api/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod config;
mod jwt;
mod middleware;
mod routes;

use axum::{
extract::DefaultBodyLimit,
middleware as axum_mw,
response::IntoResponse,
routing::{get, post},
Router,
Expand Down Expand Up @@ -57,8 +59,15 @@ async fn main() {
"/auth",
Router::new().route("/register", post(routes::auth::register)),
)
.nest(
"/user",
Router::new()
.route("/me", get(routes::user::get_user))
.layer(axum_mw::from_fn(middleware::auth::auth_middleware)),
)
.route("/register", post(routes::auth::register))
.layer(DefaultBodyLimit::max(5 * 1024 * 1024 * 1024)) // 5GB limit
// TODO: Attach this to the upload route when you re-add it
// .layer(DefaultBodyLimit::max(5 * 1024 * 1024 * 1024)) // 5GB limit
.with_state(state)
.layer(
TraceLayer::new_for_http()
Expand Down
29 changes: 29 additions & 0 deletions packages/api/src/middleware/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use crate::jwt::decode_jwt;
use axum::{
body::Body,
extract::Request,
http::{self, Response, StatusCode},
middleware::Next,
};

/// A middleware for checking the validity of the JWT token
pub async fn auth_middleware(mut req: Request, next: Next) -> Result<Response<Body>, StatusCode> {
// Get the auth header from the request
let raw_auth_header = req.headers_mut().get(http::header::AUTHORIZATION);
// Pull the full header string out of the header
let auth_header = match raw_auth_header {
Some(header) => header.to_str().map_err(|_| StatusCode::BAD_REQUEST),
None => return Err(StatusCode::BAD_REQUEST),
}?;
// Full header is expected to be `Bearer token`, split by whitespace
let mut split_header = auth_header.split_whitespace();
// It _should_ only be two values, we care about the token value
let (_bearer, token) = (split_header.next(), split_header.next());
let jwt_token = token.expect("Could not parse token").to_owned();
let token_claims = match decode_jwt(jwt_token) {
Ok(token) => token,
Err(jwt_err) => return Err(StatusCode::UNAUTHORIZED),
};
// TODO: Using the token_claims, get the User and add it to the request for downstream usage
Ok(next.run(req).await)
}
1 change: 1 addition & 0 deletions packages/api/src/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod auth;
21 changes: 18 additions & 3 deletions packages/api/src/routes/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use axum::{extract::State, http::StatusCode, Json};
use db::users::{hash_password, insert_user};
use serde::{Deserialize, Serialize};

use crate::AppState;
use crate::{jwt::encode_jwt, AppState};

#[derive(Deserialize)]
pub struct RegisterRequest {
Expand All @@ -14,6 +14,11 @@ pub struct RegisterRequest {
password_confirmation: String,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct RegisterResponse {
token: String,
}

#[derive(Serialize)]
pub struct ErrorResponse {
message: String,
Expand All @@ -23,7 +28,7 @@ pub struct ErrorResponse {
pub async fn register(
State(state): State<Arc<AppState>>,
Json(payload): Json<RegisterRequest>,
) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
) -> Result<Json<RegisterResponse>, (StatusCode, Json<ErrorResponse>)> {
// Validate inputs
if payload.password != payload.password_confirmation {
return Err((
Expand All @@ -48,7 +53,17 @@ pub async fn register(

// Insert the new user into the database
match insert_user(&state.db, &payload.email, &payload.username, &password_hash).await {
Ok(_) => Ok(StatusCode::CREATED),
Ok(user_id) => {
let token = encode_jwt(&user_id).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
message: "Could not encode JWT token".to_string(),
}),
)
})?;
Ok(Json(RegisterResponse { token }))
}
Err(e) => {
let error_message = match e {
sqlx::Error::Database(ref db_error) if db_error.is_unique_violation() => {
Expand Down
1 change: 1 addition & 0 deletions packages/api/src/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod auth;
pub mod user;
4 changes: 4 additions & 0 deletions packages/api/src/routes/user.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
/// Gets a user by their ID
pub async fn get_user() {
todo!("Fetch user information")
}
8 changes: 4 additions & 4 deletions packages/db/src/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ pub async fn insert_user(
email: &str,
username: &str,
password: &str,
) -> Result<(), sqlx::Error> {
let id = Uuid::new_v4();
) -> Result<String, sqlx::Error> {
let user_id = Uuid::new_v4();
sqlx::query("INSERT INTO users (id, email, username, password_hash) VALUES ($1, $2, $3, $4)")
.bind(id)
.bind(user_id)
.bind(email)
.bind(username)
.bind(password)
.execute(pool)
.await?;

Ok(())
Ok(user_id.to_string())
}

// Hash a password using Argon2
Expand Down

0 comments on commit 9fb85d6

Please sign in to comment.