diff --git a/target_redshift/sinks.py b/target_redshift/sinks.py
index 0949e30..ac59c1b 100644
--- a/target_redshift/sinks.py
+++ b/target_redshift/sinks.py
@@ -126,23 +126,23 @@ def process_batch(self, context: dict) -> None:
         self.path = Path(self.config["temp_dir"]) / self.file
         self.path.parent.mkdir(parents=True, exist_ok=True)
         self.object = f'{self.config["s3_key_prefix"]} / {self.file}'
-        self._bulk_insert_records(
+        self.bulk_insert_records(
             table=temp_table,
             records=context["records"],
             cursor=cursor,
         )
         self.logger.info(f'merging {len(context["records"])} records into {table}')  # noqa: G004
         # Merge data from temp table to main table
-        self._upsert(
+        self.upsert(
             from_table=temp_table,
             to_table=table,
             join_keys=self.key_properties,
             cursor=cursor,
         )
         # clean_resources
-        self._clean_resources()
+        self.clean_resources()

-    def _bulk_insert_records(  # type: ignore[override]
+    def bulk_insert_records(  # type: ignore[override]
         self,
         table: sqlalchemy.Table,
         records: Iterable[dict[str, Any]],
@@ -164,14 +164,14 @@ def _bulk_insert_records(  # type: ignore[override]
         Returns:
             True if table exists, False if not, None if unsure or undetectable.
         """
-        self._write_csv(records)
+        self.write_csv(records)
         msg = f'writing {len(records)} records to s3://{self.config["s3_bucket"]}/{self.object}'
         self.logger.info(msg)
-        self._copy_to_s3()
-        self._copy_to_redshift(table, cursor)
+        self.copy_to_s3()
+        self.copy_to_redshift(table, cursor)
         return True

-    def _upsert(
+    def upsert(
         self,
         from_table: sqlalchemy.Table,
         to_table: sqlalchemy.Table,
@@ -200,16 +200,21 @@ def _upsert(
             join_predicates.append(from_table_key == to_table_key)
         join_condition = sqlalchemy.and_(*join_predicates)

-        merge_sql = f"""
-            MERGE INTO {self.connector.quote(str(to_table))}
-            USING {self.connector.quote(str(from_table))}
-            ON {join_condition}
-            REMOVE DUPLICATES
-        """
-        cursor.execute(merge_sql)
-        return None
+        if len(join_keys) > 0:
+            sql = f"""
+                MERGE INTO {self.connector.quote(str(to_table))}
+                USING {self.connector.quote(str(from_table))}
+                ON {join_condition}
+                REMOVE DUPLICATES
+            """
+        else:
+            sql = f"""
+                INSERT INTO {self.connector.quote(str(to_table))}
+                SELECT * FROM {self.connector.quote(str(from_table))}
+            """  # noqa: S608
+        cursor.execute(sql)

-    def _write_csv(self, records: list[dict]) -> None:
+    def write_csv(self, records: list[dict]) -> None:
         """Write records to a local csv file.

         Parameters
@@ -255,7 +260,7 @@ def _write_csv(self, records: list[dict]) -> None:
         )
         writer.writerows(records)

-    def _copy_to_s3(self) -> None:
+    def copy_to_s3(self) -> None:
         """Copy the csv file to s3."""
         try:
             _ = self.s3_client.upload_file(
@@ -264,7 +269,7 @@ def _copy_to_s3(self) -> None:
         except ClientError:
             self.logger.exception()

-    def _copy_to_redshift(self, table: sqlalchemy.Table, cursor: Cursor) -> None:
+    def copy_to_redshift(self, table: sqlalchemy.Table, cursor: Cursor) -> None:
         """Copy the s3 csv file to redshift."""
         copy_credentials = f"IAM_ROLE '{self.config['aws_redshift_copy_role_arn']}'"

@@ -288,7 +293,7 @@ def _copy_to_redshift(self, table: sqlalchemy.Table, cursor: Cursor) -> None:
         """
         cursor.execute(copy_sql)

-    def _parse_timestamps_in_record(
+    def parse_timestamps_in_record(
         self,
         record: dict,
         schema: dict,
@@ -333,7 +338,7 @@ def _parse_timestamps_in_record(
         )
         record[key] = date_val

-    def _clean_resources(self) -> None:
+    def clean_resources(self) -> None:
         """Remove local and s3 resources."""
         Path.unlink(self.path)
         if self.config["remove_s3_files"]:
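
Note on the new `upsert` branch: when a stream has no `key_properties`, `join_keys` is empty, so Redshift's `MERGE ... REMOVE DUPLICATES` would have no usable `ON` condition; the method now falls back to a plain `INSERT INTO ... SELECT` for such append-only streams. A minimal standalone sketch of that decision (the `build_upsert_sql` helper and table names are hypothetical illustrations, not part of this diff):

```python
# Hypothetical helper mirroring the branch added to upsert();
# names are illustrative, not part of target_redshift.
def build_upsert_sql(to_table: str, from_table: str, join_keys: list[str]) -> str:
    if join_keys:
        # Keyed stream: Redshift's simplified MERGE dedupes on the join keys.
        on_clause = " AND ".join(
            f'{to_table}."{key}" = {from_table}."{key}"' for key in join_keys
        )
        return (
            f"MERGE INTO {to_table}\n"
            f"USING {from_table}\n"
            f"ON {on_clause}\n"
            f"REMOVE DUPLICATES"
        )
    # Keyless (append-only) stream: no ON condition is possible, so append.
    return f"INSERT INTO {to_table}\nSELECT * FROM {from_table}"


print(build_upsert_sql('"public"."users"', '"public"."users_temp"', ["id"]))
print(build_upsert_sql('"public"."events"', '"public"."events_temp"', []))
```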