Skip to content

Commit

Permalink
Use url column instead of file path
Browse files Browse the repository at this point in the history
  • Loading branch information
DhruvaBansal00 committed Nov 11, 2024
1 parent 6b4ce4d commit 02cad6d
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/autolabel/transforms/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,26 @@
from autolabel.transforms import BaseTransform
from autolabel.cache import BaseCache

from autolabel.transforms.schema import TransformError, TransformErrorType


class OCRTransform(BaseTransform):
"""This class is used to extract text from any document using OCR. The output columns dictionary for this class should include the keys 'content_column' and 'metadata_column'
This transform supports the following image formats: PDF, PNG, JPEG, TIFF, JPEG 2000, GIF, WebP, BMP, and PNM
"""

COLUMN_NAMES = [
"content_column",
"metadata_column",
]
COLUMN_NAMES = ["content_column"]

def __init__(
self,
cache: BaseCache,
output_columns: Dict[str, Any],
file_path_column: str,
url_column: str,
lang: str = None,
) -> None:
super().__init__(cache, output_columns)
self.file_path_column = file_path_column
self.url_column = url_column
self.lang = lang

try:
Expand Down Expand Up @@ -79,12 +78,15 @@ async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The dict of output columns.
"""
curr_file_location = row[self.file_path_column]
curr_file_location = row[self.url_column]
# download file to temp location if a url
if curr_file_location.startswith("http"):
try:
curr_file_path = self.download_file(curr_file_location)
else:
curr_file_path = curr_file_location
except Exception as e:
raise TransformError(
TransformErrorType.TRANSFORM_ERROR,
f"Error downloading file: {e}",
)
ocr_output = []
if curr_file_path.endswith(".pdf"):
pages = self.convert_from_path(curr_file_path)
Expand All @@ -94,16 +96,15 @@ async def _apply(self, row: Dict[str, Any]) -> Dict[str, Any]:

transformed_row = {
self.output_columns["content_column"]: "\n\n".join(ocr_output),
self.output_columns["metadata_column"]: {"num_pages": len(ocr_output)},
}
return self._return_output_row(transformed_row)

def params(self) -> Dict[str, Any]:
return {
"output_columns": self.output_columns,
"file_path_column": self.file_path_column,
"url_column": self.url_column,
"lang": self.lang,
}

def input_columns(self) -> List[str]:
return [self.file_path_column]
return [self.url_column]

0 comments on commit 02cad6d

Please sign in to comment.