Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to database endpoint #593

Merged
merged 16 commits into from
Oct 30, 2024
Merged
12 changes: 9 additions & 3 deletions screenpipe-app-tauri/components/search-chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -683,9 +683,15 @@ export function SearchChat() {
{item.type === "Audio" && (
<>
<p className="mt-2">{item.content.transcription}</p>
<div className="flex justify-center mt-4">
<VideoComponent filePath={item.content.file_path} />
</div>
{item.content.file_path && item.content.file_path.trim() !== "" ? (
<div className="flex justify-center mt-4">
<VideoComponent filePath={item.content.file_path} />
</div>
) : (
<p className="text-gray-500 italic mt-2">
No file path available for this audio.
</p>
)}
</>
)}
{item.type === "FTS" && (
Expand Down
2 changes: 1 addition & 1 deletion screenpipe-server/benches/db_benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async fn setup_large_db(size: usize) -> DatabaseManager {

for _ in 0..size {
let _video_id = db.insert_video_chunk("test_video.mp4", "test_device").await.unwrap();
let frame_id = db.insert_frame().await.unwrap();
let frame_id = db.insert_frame("test_device", None).await.unwrap();
let ocr_text = format!("OCR text {}", rng.gen::<u32>());
let text_json = format!(r#"{{"text": "{}"}}"#, ocr_text);
db.insert_ocr_text(
Expand Down
2 changes: 1 addition & 1 deletion screenpipe-server/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ async fn record_video(
while is_running.load(Ordering::SeqCst) {
if let Some(frame) = video_capture.ocr_frame_queue.pop() {
for window_result in &frame.window_ocr_results {
match db.insert_frame(&device_name).await {
match db.insert_frame(&device_name, None).await {
Ok(frame_id) => {
let text_json =
serde_json::to_string(&window_result.text_json).unwrap_or_default();
Expand Down
20 changes: 13 additions & 7 deletions screenpipe-server/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ impl DatabaseManager {

Ok(id)
}

pub async fn update_audio_transcription(
&self,
audio_chunk_id: i64,
Expand Down Expand Up @@ -272,7 +273,10 @@ impl DatabaseManager {
Ok(id)
}

pub async fn insert_frame(&self, device_name: &str) -> Result<i64, sqlx::Error> {
pub async fn insert_frame(&self,
device_name: &str,
timestamp: Option<DateTime<Utc>>,
) -> Result<i64, sqlx::Error> {
let mut tx = self.pool.begin().await?;
debug!("insert_frame Transaction started");

Expand Down Expand Up @@ -302,14 +306,16 @@ impl DatabaseManager {
.fetch_one(&mut *tx)
.await?;
debug!("insert_frame Calculated offset_index: {}", offset_index);

let timestamp = timestamp.unwrap_or_else(Utc::now);

// Insert the new frame
let id = sqlx::query(
"INSERT INTO frames (video_chunk_id, offset_index, timestamp) VALUES (?1, ?2, ?3)",
)
.bind(video_chunk_id)
.bind(offset_index)
.bind(Utc::now())
.bind(timestamp)
.execute(&mut *tx)
.await?
.last_insert_rowid();
Expand Down Expand Up @@ -607,7 +613,7 @@ impl DatabaseManager {
AND (?5 IS NULL OR LENGTH(audio_transcriptions.transcription) <= ?5)
"#,
);

sql.push_str(
r#"
GROUP BY
Expand All @@ -620,7 +626,7 @@ impl DatabaseManager {
LIMIT ?6 OFFSET ?7
"#,
);

let query = sqlx::query_as::<_, AudioResultRaw>(&sql)
.bind(query)
.bind(start_time)
Expand All @@ -629,9 +635,9 @@ impl DatabaseManager {
.bind(max_length.map(|l| l as i64))
.bind(limit)
.bind(offset);

let audio_results_raw = query.fetch_all(&self.pool).await?;

// Parse the tags string into a Vec<String>
let audio_results = audio_results_raw
.into_iter()
Expand All @@ -654,7 +660,7 @@ impl DatabaseManager {
},
})
.collect();

Ok(audio_results)
}

Expand Down
217 changes: 215 additions & 2 deletions screenpipe-server/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use axum::{
extract::{Path, Query, State},
extract::{Path, Query, State, Json},
http::StatusCode,
response::Json as JsonResponse,
routing::{get, post},
Expand All @@ -12,12 +12,15 @@ use screenpipe_core::LLM;
#[cfg(feature = "llm")]
use screenpipe_core::{ChatRequest, ChatResponse};
use screenpipe_vision::monitor::list_monitors;
use screenpipe_vision::OcrEngine;
use image::ImageFormat::{self};

use crate::{
db::TagContentType,
pipe_manager::{PipeInfo, PipeManager},
video_utils::{merge_videos, MergeVideosRequest, MergeVideosResponse},
ContentType, DatabaseManager, SearchResult,
video::{start_ffmpeg_process, write_frame_to_ffmpeg, finish_ffmpeg_process, MAX_FPS}
};
use crate::{plugin::ApiPluginLayer, video_utils::extract_frame};
use chrono::{DateTime, Utc};
Expand Down Expand Up @@ -863,6 +866,215 @@ async fn execute_raw_sql(
}
}

/// JSON body for the `/add` endpoint: the originating device plus the
/// content (frames or an audio transcription) to persist for it.
#[derive(Deserialize)]
pub struct AddContentRequest {
    pub device_name: String,     // device the content originates from; used for chunk + DB rows
    pub content: AddContentData, // the actual payload (frames or transcription)
}

/// Wrapper distinguishing what kind of content is being added.
///
/// `content_type` is expected to be `"frames"` or `"transcription"` and to
/// agree with the variant serde picked for `data` (see [`ContentData`]).
#[derive(Deserialize)]
pub struct AddContentData {
    pub content_type: String,
    pub data: ContentData,
}

/// Payload variants for `/add`.
///
/// `#[serde(untagged)]` means serde picks the first variant whose shape
/// matches the incoming JSON; the handler dispatches on `content_type`
/// separately, so the two must agree for the request to take effect.
#[derive(Deserialize)]
#[serde(untagged)]
pub enum ContentData {
    Frames(Vec<FrameContent>),
    Transcription(AudioTranscription),
}

/// A single captured frame submitted through the `/add` endpoint.
#[derive(Deserialize)]
pub struct FrameContent {
    /// Path on disk to the frame image; read and re-encoded into the video chunk.
    pub file_path: String,
    /// Capture time; the server falls back to "now" when omitted.
    pub timestamp: Option<DateTime<Utc>>,
    pub app_name: Option<String>,
    pub window_name: Option<String>,
    /// OCR results to index alongside the frame, if any.
    pub ocr_results: Option<Vec<OCRResult>>,
    /// Vision tags to attach to the frame, if any.
    pub tags: Option<Vec<String>>,
}

/// One OCR extraction associated with a submitted frame.
#[derive(Serialize, Deserialize, Debug)]
pub struct OCRResult {
    pub text: String,
    /// Structured OCR output serialized as a JSON string, if available.
    pub text_json: Option<String>,
    // NOTE(review): currently ignored when persisting — the default engine is
    // stamped instead (see add_frame_to_db).
    pub ocr_engine: Option<String>,
    pub focused: Option<bool>,
}

/// An externally produced audio transcription submitted through `/add`.
#[derive(Deserialize)]
pub struct AudioTranscription {
    pub transcription: String,
    /// Name of the engine that produced the transcription.
    pub transcription_engine: String,
}

/// Response body for the `/add` endpoint.
#[derive(Serialize)]
pub struct AddContentResponse {
    /// Always `true` on a 2xx response; failures are reported via HTTP error codes.
    pub success: bool,
    /// Human-readable summary of what was added.
    pub message: Option<String>,
}

/// Persist one frame (plus its OCR results and tags) to the database.
///
/// Assumes a video chunk for `device_name` has already been inserted so the
/// frame can be associated with it. When the client supplied no capture
/// timestamp, the current UTC time is used.
///
/// # Errors
/// Propagates any database error from inserting the frame, its OCR text, or
/// its tags.
async fn add_frame_to_db(
    state: &AppState,
    frame: &FrameContent,
    device_name: &str,
) -> Result<(), anyhow::Error> {
    let db = &state.db;

    // Fall back to "now" when the client did not supply a capture timestamp.
    let frame_id = db
        .insert_frame(device_name, Some(frame.timestamp.unwrap_or_else(Utc::now)))
        .await?;

    if let Some(ocr_results) = &frame.ocr_results {
        for ocr in ocr_results {
            db.insert_ocr_text(
                frame_id,
                &ocr.text,
                ocr.text_json.as_deref().unwrap_or(""),
                frame.app_name.as_deref().unwrap_or(""),
                frame.window_name.as_deref().unwrap_or(""),
                // NOTE(review): the engine reported by the client is ignored and
                // the default OCR engine is stamped instead. Ideally the API
                // would accept an arbitrary engine string, since OCR may have
                // run outside of screenpipe.
                Arc::new(OcrEngine::default()),
                false, // not marked as the focused window
            )
            .await?;
        }
    }

    if let Some(tags) = &frame.tags {
        db.add_tags(frame_id, TagContentType::Vision, tags.clone()).await?;
    }

    Ok(())
}

/// Load an image from disk and re-encode it as PNG bytes suitable for piping
/// into the FFmpeg process.
///
/// # Errors
/// Returns an error if the file cannot be opened/decoded or the PNG encoding
/// fails.
fn encode_frame_from_file_path(file_path: &str) -> Result<Vec<u8>, anyhow::Error> {
    let decoded = image::open(file_path)?;
    let mut png_bytes = Vec::new();
    let mut sink = std::io::Cursor::new(&mut png_bytes);
    decoded.write_to(&mut sink, ImageFormat::Png)?;
    Ok(png_bytes)
}

/// Encode the given frames into a video file at `video_file_path` via FFmpeg.
///
/// Each frame is read from its `file_path`, re-encoded as PNG, and streamed to
/// FFmpeg's stdin; the process is then finalized. Fails fast on the first
/// frame that cannot be read or written.
///
/// # Errors
/// Returns an error if FFmpeg cannot be started, its stdin is unavailable, or
/// any frame fails to encode or write.
async fn write_frames_to_video(
    frames: &[FrameContent],
    video_file_path: &str,
    fps: f64,
) -> Result<(), anyhow::Error> {
    let mut ffmpeg_child = start_ffmpeg_process(video_file_path, fps).await?;
    // Taking stdin can only fail if FFmpeg was spawned without a piped stdin;
    // surface that as an error instead of panicking in a request handler.
    let mut ffmpeg_stdin = ffmpeg_child
        .stdin
        .take()
        .ok_or_else(|| anyhow::anyhow!("Failed to open stdin for FFmpeg"))?;

    for frame in frames {
        let encoded_frame = encode_frame_from_file_path(&frame.file_path)?;
        if let Err(e) = write_frame_to_ffmpeg(&mut ffmpeg_stdin, &encoded_frame).await {
            error!("Failed to write frame to FFmpeg: {}", e);
            return Err(e);
        }
    }

    // Hands stdin back so the helper can close it and let FFmpeg flush/exit.
    finish_ffmpeg_process(ffmpeg_child, Some(ffmpeg_stdin)).await;
    Ok(())
}

/// Persist an externally produced audio transcription for `device_name`.
///
/// The transcription arrives without any audio data, so a placeholder audio
/// chunk (empty file path) is inserted first to give the transcription row a
/// valid chunk id to reference.
///
/// # Errors
/// Propagates any database error from inserting the chunk or the transcription.
async fn add_transcription_to_db(
    state: &AppState,
    transcription: &AudioTranscription,
    device_name: &str,
) -> Result<(), anyhow::Error> {
    let db = &state.db;

    // The submitting device is recorded as an input device.
    let device = AudioDevice {
        name: device_name.to_string(),
        device_type: DeviceType::Input,
    };

    // Placeholder chunk: there is no real audio file behind this transcription.
    let dummy_audio_chunk_id = db.insert_audio_chunk("").await?;

    db.insert_audio_transcription(
        dummy_audio_chunk_id, // placeholder chunk inserted above
        &transcription.transcription,
        -1, // NOTE(review): -1 looks like a "no offset" sentinel — confirm against insert_audio_transcription
        &transcription.transcription_engine,
        &device,
    ).await?;

    Ok(())
}

pub(crate) async fn add_to_database(
State(state): State<Arc<AppState>>,
Json(payload): Json<AddContentRequest>,
) -> Result<JsonResponse<AddContentResponse>, (StatusCode, JsonResponse<Value>)> {

let device_name = payload.device_name.clone();
let mut success_messages = Vec::new();

match payload.content.content_type.as_str() {
"frames" => {
if let ContentData::Frames(frames) = &payload.content.data {
if !frames.is_empty() {
let output_dir = state.screenpipe_dir.join("data");
let time = Utc::now();
let formatted_time = time.format("%Y-%m-%d_%H-%M-%S").to_string();
let video_file_path = PathBuf::from(output_dir)
.join(format!("{}_{}.mp4", device_name, formatted_time))
.to_str()
.expect("Failed to create valid path")
.to_string();

if let Err(e) = state.db.insert_video_chunk(&video_file_path, &device_name).await {
error!("Failed to insert video chunk for device {}: {}", device_name, e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
JsonResponse(json!({"error": format!("Failed to insert video chunk: {}", e)})),
));
}

if let Err(e) = write_frames_to_video(frames, &video_file_path, MAX_FPS).await {
error!("Failed to write frames to video file {}: {}", video_file_path, e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
JsonResponse(json!({"error": format!("Failed to write frames to video: {}", e)})),
));
}

for frame in frames {
if let Err(e) = add_frame_to_db(&state, frame, &device_name).await {
error!("Failed to add frame content for device {}: {}", device_name, e);
}
}

success_messages.push("Frames added successfully".to_string());
}
}
}
"transcription" => {
if let ContentData::Transcription(transcription) = &payload.content.data {
if let Err(e) = add_transcription_to_db(&state, transcription, &device_name).await {
error!("Failed to add transcription for device {}: {}", device_name, e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
JsonResponse(json!({"error": format!("Failed to add transcription: {}", e)})),
));
}

success_messages.push("Transcription added successfully".to_string());
}
}
_ => {
error!("Unknown content type: {}", payload.content.content_type);
return Err((
StatusCode::BAD_REQUEST,
JsonResponse(json!({"error": "Unsupported content type"})),
));
}
}

Ok(JsonResponse(AddContentResponse {
success: true,
message: Some(success_messages.join(", ")),
}))
}

#[cfg(feature = "experimental")]
async fn input_control_handler(
JsonResponse(payload): JsonResponse<InputControlRequest>,
Expand Down Expand Up @@ -966,7 +1178,8 @@ pub fn create_router() -> Router<Arc<AppState>> {
.route("/pipes/update", post(update_pipe_config_handler))
.route("/experimental/frames/merge", post(merge_frames_handler))
.route("/health", get(health_check))
.route("/raw_sql", post(execute_raw_sql));
.route("/raw_sql", post(execute_raw_sql))
.route("/add", post(add_to_database));

#[cfg(feature = "llm")]
let router = router.route("/llm/chat", post(llm_chat_handler));
Expand Down
8 changes: 4 additions & 4 deletions screenpipe-server/src/video.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
use tokio::sync::mpsc::channel;
use tokio::time::sleep;

const MAX_FPS: f64 = 30.0; // Adjust based on your needs
pub(crate) const MAX_FPS: f64 = 30.0; // Adjust based on your needs
const MAX_QUEUE_SIZE: usize = 10;

pub struct VideoCapture {
Expand Down Expand Up @@ -144,7 +144,7 @@ impl VideoCapture {
}
}

async fn start_ffmpeg_process(output_file: &str, fps: f64) -> Result<Child, anyhow::Error> {
pub async fn start_ffmpeg_process(output_file: &str, fps: f64) -> Result<Child, anyhow::Error> {
// Overriding fps with max fps if over the max and warning user
let fps = if fps > MAX_FPS {
warn!("Overriding FPS from {} to {}", fps, MAX_FPS);
Expand Down Expand Up @@ -201,7 +201,7 @@ async fn start_ffmpeg_process(output_file: &str, fps: f64) -> Result<Child, anyh
Ok(child)
}

async fn write_frame_to_ffmpeg(stdin: &mut ChildStdin, buffer: &[u8]) -> Result<(), anyhow::Error> {
pub async fn write_frame_to_ffmpeg(stdin: &mut ChildStdin, buffer: &[u8]) -> Result<(), anyhow::Error> {
stdin.write_all(buffer).await?;
Ok(())
}
Expand Down Expand Up @@ -383,7 +383,7 @@ async fn flush_ffmpeg_input(stdin: &mut ChildStdin, frame_count: usize, fps: f64
}
}

async fn finish_ffmpeg_process(child: Child, stdin: Option<ChildStdin>) {
pub async fn finish_ffmpeg_process(child: Child, stdin: Option<ChildStdin>) {
drop(stdin); // Ensure stdin is closed
match child.wait_with_output().await {
Ok(output) => {
Expand Down
Loading
Loading