diff --git a/examples/pipelines/commoncrawl/components/download_commoncrawl_segments/fondant_component.yaml b/examples/pipelines/commoncrawl/components/download_commoncrawl_segments/fondant_component.yaml index 9a2c9a842..7672ae5a8 100644 --- a/examples/pipelines/commoncrawl/components/download_commoncrawl_segments/fondant_component.yaml +++ b/examples/pipelines/commoncrawl/components/download_commoncrawl_segments/fondant_component.yaml @@ -39,5 +39,11 @@ args: default: 5 target_language: description: Limit html extraction to target language based on metadata tags. - type: str - default: None \ No newline at end of file + aws_access_key_id: + description: AWS access key id used for authentication to load common crawl files + type: str + default: None + aws_secret_access_key: + description: AWS secret access key used for authentication to load common crawl files + type: str + default: None diff --git a/examples/pipelines/commoncrawl/components/download_commoncrawl_segments/src/main.py b/examples/pipelines/commoncrawl/components/download_commoncrawl_segments/src/main.py index fd6c24c5a..69d4b35ad 100644 --- a/examples/pipelines/commoncrawl/components/download_commoncrawl_segments/src/main.py +++ b/examples/pipelines/commoncrawl/components/download_commoncrawl_segments/src/main.py @@ -88,6 +88,8 @@ def __init__( retries: Optional[int] = None, backoff_factor: Optional[float] = None, target_language: Optional[str] = None, + aws_access_key_id: str = None, + aws_secret_access_key: str = None, ): """Downloads Commoncrawl segments based on a list of WARC paths. Args: @@ -95,6 +97,8 @@ def __init__( get_plain_text: Whether to convert the HTML content to plain text. n_records_to_download: The number of webpages to download from each segment. target_language: Limit html extraction to target language based on metadata tags. 
+ aws_access_key_id: AWS access key id used to initialise the boto client + aws_secret_access_key: AWS secret access key used to initialise the boto client """ self.use_s3 = use_s3 self.get_plain_text = get_plain_text @@ -104,7 +108,11 @@ def __init__( self.target_language = target_language # initialise s3 session - session = boto3.Session() + + session = boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) self.s3_client = session.client("s3") def transform( diff --git a/examples/pipelines/commoncrawl/components/load_from_commoncrawl/fondant_component.yaml b/examples/pipelines/commoncrawl/components/load_from_commoncrawl/fondant_component.yaml index 4037b77fa..8c0e3499d 100644 --- a/examples/pipelines/commoncrawl/components/load_from_commoncrawl/fondant_component.yaml +++ b/examples/pipelines/commoncrawl/components/load_from_commoncrawl/fondant_component.yaml @@ -15,4 +15,12 @@ args: n_segments_to_load: description: Number of segments to load from the commoncrawl index file type: int - default: None \ No newline at end of file + default: None + aws_access_key_id: + description: AWS access key id used for authentication to load common crawl files + type: str + default: None + aws_secret_access_key: + description: AWS secret access key used for authentication to load common crawl files + type: str + default: None diff --git a/examples/pipelines/commoncrawl/components/load_from_commoncrawl/src/main.py b/examples/pipelines/commoncrawl/components/load_from_commoncrawl/src/main.py index 6241d88d6..a233e1ade 100644 --- a/examples/pipelines/commoncrawl/components/load_from_commoncrawl/src/main.py +++ b/examples/pipelines/commoncrawl/components/load_from_commoncrawl/src/main.py @@ -17,17 +17,29 @@ S3_COMMONCRAWL_BUCKET = "commoncrawl" -def fetch_warc_file_from_s3(s3_bucket: str, s3_key: str) -> bytes: +def fetch_warc_file_from_s3( + s3_bucket: str, + s3_key: str, + aws_access_key_id: str = None, + aws_secret_access_key: str = None, +) -> 
bytes: """Fetches a WARC file from S3 and returns its content as a Dask DataFrame. Args: s3_bucket: The name of the S3 bucket. s3_key: The key of the S3 object to be downloaded. + aws_access_key_id: AWS access key id used to initialise the boto client + aws_secret_access_key: AWS secret access key used to initialise the boto client + Returns: File object containing the WARC file content. """ logger.info(f"Fetching WARC file from S3: {s3_bucket}/{s3_key}...") - s3 = boto3.client("s3") + session = boto3.Session( + aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key + ) + s3 = session.client("s3") + file_obj = io.BytesIO() s3.download_fileobj(s3_bucket, s3_key, file_obj) file_obj.seek(0) @@ -63,15 +75,24 @@ def read_warc_paths_file( class LoadFromCommonCrawlComponent(DaskLoadComponent): def __init__( - self, *args, index_name: str, n_segments_to_load: t.Optional[int] = None + self, + *args, + index_name: str, + n_segments_to_load: t.Optional[int] = None, + aws_access_key_id: str = None, + aws_secret_access_key: str = None, ) -> None: self.index_name = index_name self.n_segments_to_load = n_segments_to_load + self.aws_access_key_id = aws_access_key_id + self.aws_secret_access_key = aws_secret_access_key """Loads a dataset of segment file paths from CommonCrawl based on a given index. Args: index_name: The name of the CommonCrawl index to load. n_segments_to_load: The number of segments to load from the index. 
+ aws_access_key_id: AWS access key id used to initialise the boto client + aws_secret_access_key: AWS secret access key used to initialise the boto client """ def load(self) -> dd.DataFrame: @@ -82,7 +103,10 @@ def load(self) -> dd.DataFrame: logger.info(f"Loading CommonCrawl index {self.index_name}...") warc_paths_file_key = f"crawl-data/{self.index_name}/warc.paths.gz" warc_paths_file_content = fetch_warc_file_from_s3( - S3_COMMONCRAWL_BUCKET, warc_paths_file_key + S3_COMMONCRAWL_BUCKET, + warc_paths_file_key, + self.aws_access_key_id, + self.aws_secret_access_key, ) warc_paths_df = read_warc_paths_file(