Skip to content

Commit

Permalink
Merge pull request #61 from Chronolens/face_creation_endpoint
Browse files Browse the repository at this point in the history
Face creation endpoint
  • Loading branch information
lucasverdelho authored Nov 28, 2024
2 parents 5be6c12 + c47ac0b commit 6370714
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 3 deletions.
8 changes: 5 additions & 3 deletions api/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ use models::api_models::AccessTokenClaims;
use routes::{
clip_search::clip_search, cluster_previews::cluster_previews, face_previews::face_previews,
faces::faces, login::login, logs::logs, media::media, preview::preview, previews::previews,
refresh::refresh, register::register, sync_full::sync_full, sync_partial::sync_partial,
upload_image::upload_image,
refresh::refresh, sync_full::sync_full, sync_partial::sync_partial, upload_image::upload_image,
create_face::create_face,

};
use s3::{creds::Credentials, error::S3Error, Bucket, BucketConfiguration, Region};
use serde::Deserialize;
Expand Down Expand Up @@ -109,7 +110,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.route("/register", post(register))
.route("/refresh", post(refresh));

let private_routes = Router::new()
let private_routes = Router::new()
.route(
"/image/upload",
post(upload_image).route_layer(DefaultBodyLimit::max(10737418240)),
Expand All @@ -124,6 +125,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.route("/cluster/:cluster_id", get(cluster_previews))
.route("/face/:face_id", get(face_previews))
.route("/search", get(clip_search))
.route("/create_face", post(create_face))
.layer(middleware::from_fn_with_state(
server_config.secret.clone(),
auth_middleware,
Expand Down
7 changes: 7 additions & 0 deletions api/src/models/api_models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,10 @@ pub struct SearchQuery {
pub page: Option<u32>,
pub page_size: Option<u32>,
}


#[derive(Deserialize)]
pub struct CreateFacePayload {
pub ids: Vec<i32>,
pub name: String,
}
24 changes: 24 additions & 0 deletions api/src/routes/create_face.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use axum::{
extract::{Json, State, Extension},
http::StatusCode,
};
use crate::models::api_models::CreateFacePayload;
use crate::ServerConfig;

pub async fn create_face(
State(server_config): State<ServerConfig>,
Extension(user_id): Extension<String>,
Json(payload): Json<CreateFacePayload>,
) -> StatusCode {
// println!("User ID: {}", user_id);
// println!("Payload IDs: {:?}", payload.ids);
// println!("Payload Name: {}", payload.name);
match server_config
.database
.insert_face(user_id.clone(), payload.ids.clone(), payload.name.clone())
.await
{
Ok(_) => StatusCode::OK,
Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
1 change: 1 addition & 0 deletions api/src/routes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ pub mod register;
pub mod sync_full;
pub mod sync_partial;
pub mod upload_image;
pub mod create_face;
35 changes: 35 additions & 0 deletions database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,41 @@ impl DbManager {
)),
}
}


pub async fn insert_face(
&self,
user_id: String,
cluster_ids: Vec<i32>,
name: String,
) -> Result<(), DbErr> {
let new_face = face::ActiveModel {
name: Set(name),
..Default::default()
};
let face_result = face::Entity::insert(new_face)
.exec(&self.connection)
.await?;
let face_id = face_result.last_insert_id;

for cluster_id in cluster_ids {
cluster::Entity::update_many()
.filter(
cluster::Column::Id
.eq(cluster_id)
.and(cluster::Column::UserId.eq(user_id.clone())),
)
.set(cluster::ActiveModel {
face_id: Set(Some(face_id)),
..Default::default()
})
.exec(&self.connection)
.await?;
}

Ok(())
}

}

pub enum GetPreviewError {
Expand Down

0 comments on commit 6370714

Please sign in to comment.