Skip to content

Commit

Permalink
Add skeleton of EmbeddingExtractor
Browse files Browse the repository at this point in the history
  • Loading branch information
joelochlann committed Sep 27, 2024
1 parent 49f7357 commit 4a0c357
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
40 changes: 40 additions & 0 deletions backend/app/extraction/EmbeddingExtractor.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package extraction

import cats.syntax.either._
import model.manifest.Blob
import model.{English, Languages}
import org.apache.commons.io.FileUtils
import services.index.{Index, Pages}
import services.{ScratchSpace, TranscribeConfig}
import utils.FfMpeg.FfMpegSubprocessCrashedException
import utils._
import utils.attempt.{Failure, FfMpegFailure, UnknownFailure}

import java.io.File
import scala.concurrent.ExecutionContext

class EmbeddingExtractor(index: Index, pages: Pages, scratchSpace: ScratchSpace, transcribeConfig: TranscribeConfig)(implicit executionContext: ExecutionContext) extends FileExtractor(scratchSpace) with Logging {
val mimeTypes: Set[String] = Set(
"application/pdf"
)

def canProcessMimeType: String => Boolean = mimeTypes.contains

override def indexing = true
// set a low priority as transcription takes a long time, we don't want to block up the workers
override def priority = 1

override def extract(blob: Blob, file: File, params: ExtractionParams): Either[Failure, Unit] = {
logger.info(s"Running embedding extractor '${blob.uri.value}'")

// Get the pages from elasticsearch
for {
page <- pages.getAllPages(blob.uri)
} yield {
// Run page through embedding model
// Write embeddings to a vector field against the page
}

Right(())
}
}
14 changes: 14 additions & 0 deletions backend/app/services/index/ElasticsearchPages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,20 @@ class ElasticsearchPages(val client: ElasticClient, indexNamePrefix: String)(imp
}
}

override def getAllPages(uri: Uri): Attempt[List[Page]] = {
execute {
// TODO: check if this is the best way to query (seems a little overcomplicated)
search(textIndexName).query(
should(matchAllQuery())
.filter(List(termQuery(PagesFields.resourceId, uri.value)))
)
.sortBy(fieldSort(PagesFields.page).asc())
}.flatMap { resp =>
val pages = resp.to[Page].toList
Attempt.Right(pages)
}
}

// TODO MRB: collapse total page count and height into fields on the document itself
// TODO SC/JS: We agree.
private def getTotalPageCount(indexName: String, uri: Uri): Attempt[Long] = {
Expand Down
2 changes: 2 additions & 0 deletions backend/app/services/index/Pages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ trait Pages {
def getTextPages(uri: Uri, top: Double, bottom: Double, highlightQuery: Option[String]): Attempt[PageResult]

def getPage(uri: Uri, pageNumber: Int, highlightQuery: Option[String]): Attempt[Page]

def getAllPages(uri: Uri): Attempt[List[Page]]
}

0 comments on commit 4a0c357

Please sign in to comment.