diff --git a/api/src/main.rs b/api/src/main.rs index f8680d6..6a03c3d 100644 --- a/api/src/main.rs +++ b/api/src/main.rs @@ -17,8 +17,9 @@ use models::api_models::AccessTokenClaims; use routes::{ clip_search::clip_search, cluster_previews::cluster_previews, face_previews::face_previews, faces::faces, login::login, logs::logs, media::media, preview::preview, previews::previews, - refresh::refresh, register::register, sync_full::sync_full, sync_partial::sync_partial, - upload_image::upload_image, + refresh::refresh, sync_full::sync_full, sync_partial::sync_partial, upload_image::upload_image, + create_face::create_face, + }; use s3::{creds::Credentials, error::S3Error, Bucket, BucketConfiguration, Region}; use serde::Deserialize; @@ -109,7 +110,7 @@ async fn main() -> Result<(), Box> { .route("/register", post(register)) .route("/refresh", post(refresh)); - let private_routes = Router::new() + let private_routes = Router::new() .route( "/image/upload", post(upload_image).route_layer(DefaultBodyLimit::max(10737418240)), @@ -124,6 +125,7 @@ async fn main() -> Result<(), Box> { .route("/cluster/:cluster_id", get(cluster_previews)) .route("/face/:face_id", get(face_previews)) .route("/search", get(clip_search)) + .route("/create_face", post(create_face)) .layer(middleware::from_fn_with_state( server_config.secret.clone(), auth_middleware, diff --git a/api/src/models/api_models.rs b/api/src/models/api_models.rs index c1dea88..6da5d77 100644 --- a/api/src/models/api_models.rs +++ b/api/src/models/api_models.rs @@ -118,3 +118,10 @@ pub struct SearchQuery { pub page: Option, pub page_size: Option, } + + +#[derive(Deserialize)] +pub struct CreateFacePayload { + pub ids: Vec, + pub name: String, +} \ No newline at end of file diff --git a/api/src/routes/create_face.rs b/api/src/routes/create_face.rs new file mode 100644 index 0000000..602cf45 --- /dev/null +++ b/api/src/routes/create_face.rs @@ -0,0 +1,24 @@ +use axum::{ + extract::{Json, State, Extension}, + http::StatusCode, +}; +use crate::models::api_models::CreateFacePayload; +use crate::ServerConfig; + +pub async fn create_face( + State(server_config): State, + Extension(user_id): Extension, + Json(payload): Json, +) -> StatusCode { + // println!("User ID: {}", user_id); + // println!("Payload IDs: {:?}", payload.ids); + // println!("Payload Name: {}", payload.name); + match server_config + .database + .insert_face(user_id.clone(), payload.ids.clone(), payload.name.clone()) + .await + { + Ok(_) => StatusCode::OK, + Err(_) => StatusCode::INTERNAL_SERVER_ERROR, + } +} diff --git a/api/src/routes/mod.rs b/api/src/routes/mod.rs index 841e696..5d64967 100644 --- a/api/src/routes/mod.rs +++ b/api/src/routes/mod.rs @@ -12,3 +12,4 @@ pub mod register; pub mod sync_full; pub mod sync_partial; pub mod upload_image; +pub mod create_face; \ No newline at end of file diff --git a/database/src/lib.rs b/database/src/lib.rs index b92db39..b924afc 100644 --- a/database/src/lib.rs +++ b/database/src/lib.rs @@ -581,6 +581,41 @@ impl DbManager { )), } } + + + pub async fn insert_face( + &self, + user_id: String, + cluster_ids: Vec, + name: String, + ) -> Result<(), DbErr> { + let new_face = face::ActiveModel { + name: Set(name), + ..Default::default() + }; + let face_result = face::Entity::insert(new_face) + .exec(&self.connection) + .await?; + let face_id = face_result.last_insert_id; + + for cluster_id in cluster_ids { + cluster::Entity::update_many() + .filter( + cluster::Column::Id + .eq(cluster_id) + .and(cluster::Column::UserId.eq(user_id.clone())), + ) + .set(cluster::ActiveModel { + face_id: Set(Some(face_id)), + ..Default::default() + }) + .exec(&self.connection) + .await?; + } + + Ok(()) + } + } pub enum GetPreviewError {