From 4a0c357b2f5847e85f687c06ae8f6aae1989083a Mon Sep 17 00:00:00 2001
From: Joseph Smith
Date: Fri, 27 Sep 2024 17:41:58 +0100
Subject: [PATCH] Add skeleton of EmbeddingExtractor

---
 .../app/extraction/EmbeddingExtractor.scala   | 40 +++++++++++++++++++
 .../services/index/ElasticsearchPages.scala   | 14 +++++++
 backend/app/services/index/Pages.scala        |  2 +
 3 files changed, 56 insertions(+)
 create mode 100644 backend/app/extraction/EmbeddingExtractor.scala

diff --git a/backend/app/extraction/EmbeddingExtractor.scala b/backend/app/extraction/EmbeddingExtractor.scala
new file mode 100644
index 00000000..678f55c8
--- /dev/null
+++ b/backend/app/extraction/EmbeddingExtractor.scala
@@ -0,0 +1,40 @@
+package extraction