
Commit

remove private methods and make sure upsert runs append as well
tobiascadee committed Sep 4, 2024
1 parent fbbc33c commit cff2d07
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions target_redshift/sinks.py
@@ -126,23 +126,23 @@ def process_batch(self, context: dict) -> None:
self.path = Path(self.config["temp_dir"]) / self.file
self.path.parent.mkdir(parents=True, exist_ok=True)
self.object = f'{self.config["s3_key_prefix"]} / {self.file}'
- self._bulk_insert_records(
+ self.bulk_insert_records(
table=temp_table,
records=context["records"],
cursor=cursor,
)
self.logger.info(f'merging {len(context["records"])} records into {table}') # noqa: G004
# Merge data from temp table to main table
- self._upsert(
+ self.upsert(
from_table=temp_table,
to_table=table,
join_keys=self.key_properties,
cursor=cursor,
)
# clean_resources
- self._clean_resources()
+ self.clean_resources()

- def _bulk_insert_records( # type: ignore[override]
+ def bulk_insert_records( # type: ignore[override]
self,
table: sqlalchemy.Table,
records: Iterable[dict[str, Any]],
@@ -164,14 +164,14 @@ def _bulk_insert_records( # type: ignore[override]
Returns:
True if table exists, False if not, None if unsure or undetectable.
"""
- self._write_csv(records)
+ self.write_csv(records)
msg = f'writing {len(records)} records to s3://{self.config["s3_bucket"]}/{self.object}'
self.logger.info(msg)
- self._copy_to_s3()
- self._copy_to_redshift(table, cursor)
+ self.copy_to_s3()
+ self.copy_to_redshift(table, cursor)
return True

- def _upsert(
+ def upsert(
self,
from_table: sqlalchemy.Table,
to_table: sqlalchemy.Table,
@@ -200,16 +200,21 @@ def _upsert(
join_predicates.append(from_table_key == to_table_key)

join_condition = sqlalchemy.and_(*join_predicates)
- merge_sql = f"""
- MERGE INTO {self.connector.quote(str(to_table))}
- USING {self.connector.quote(str(from_table))}
- ON {join_condition}
- REMOVE DUPLICATES
- """
- cursor.execute(merge_sql)
- return None
+ if len(join_keys) > 0:
+ sql = f"""
+ MERGE INTO {self.connector.quote(str(to_table))}
+ USING {self.connector.quote(str(from_table))}
+ ON {join_condition}
+ REMOVE DUPLICATES
+ """
+ else:
+ sql = f"""
+ INSERT INTO {self.connector.quote(str(to_table))}
+ SELECT * FROM {self.connector.quote(str(from_table))}
+ """ # noqa: S608
+ cursor.execute(sql)

- def _write_csv(self, records: list[dict]) -> None:
+ def write_csv(self, records: list[dict]) -> None:
"""Write records to a local csv file.
Parameters
Expand Down Expand Up @@ -255,7 +260,7 @@ def _write_csv(self, records: list[dict]) -> None:
)
writer.writerows(records)

- def _copy_to_s3(self) -> None:
+ def copy_to_s3(self) -> None:
"""Copy the csv file to s3."""
try:
_ = self.s3_client.upload_file(
@@ -264,7 +269,7 @@ def _copy_to_s3(self) -> None:
except ClientError:
self.logger.exception()

- def _copy_to_redshift(self, table: sqlalchemy.Table, cursor: Cursor) -> None:
+ def copy_to_redshift(self, table: sqlalchemy.Table, cursor: Cursor) -> None:
"""Copy the s3 csv file to redshift."""
copy_credentials = f"IAM_ROLE '{self.config['aws_redshift_copy_role_arn']}'"

@@ -288,7 +293,7 @@ def _copy_to_redshift(self, table: sqlalchemy.Table, cursor: Cursor) -> None:
"""
cursor.execute(copy_sql)

- def _parse_timestamps_in_record(
+ def parse_timestamps_in_record(
self,
record: dict,
schema: dict,
@@ -333,7 +338,7 @@ def _parse_timestamps_in_record(
)
record[key] = date_val

- def _clean_resources(self) -> None:
+ def clean_resources(self) -> None:
"""Remove local and s3 resources."""
Path.unlink(self.path)
if self.config["remove_s3_files"]:
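
The behavioural change in this commit sits in upsert(): when a stream has key properties, the temp table is merged into the target with Redshift's MERGE ... REMOVE DUPLICATES; when it has none, the records are now appended with a plain INSERT ... SELECT instead of always going through the MERGE path. Below is a minimal standalone sketch of the two statements. The names build_upsert_sql and quote_ident are made-up helpers for illustration only; the real method works on sqlalchemy.Table objects and quotes identifiers via self.connector.quote().

def quote_ident(name: str) -> str:
    """Naive identifier quoting, for the sake of the example only."""
    return '"' + name.replace('"', '""') + '"'


def build_upsert_sql(to_table: str, from_table: str, join_keys: list[str]) -> str:
    """Sketch of the SQL the reworked upsert() sends to Redshift."""
    if len(join_keys) > 0:
        # Key properties present: deduplicating merge from the temp table.
        on_clause = " AND ".join(
            f"{quote_ident(from_table)}.{quote_ident(key)} = {quote_ident(to_table)}.{quote_ident(key)}"
            for key in join_keys
        )
        return (
            f"MERGE INTO {quote_ident(to_table)} "
            f"USING {quote_ident(from_table)} "
            f"ON {on_clause} "
            "REMOVE DUPLICATES"
        )
    # No key properties: plain append, which is what the commit message refers to.
    return f"INSERT INTO {quote_ident(to_table)} SELECT * FROM {quote_ident(from_table)}"

For a stream keyed on id this yields MERGE INTO "users" USING "users__temp" ON "users__temp"."id" = "users"."id" REMOVE DUPLICATES, while a keyless stream yields INSERT INTO "events" SELECT * FROM "events__temp".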
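The COPY statement assembled in copy_to_redshift() is collapsed in the hunks above; only the IAM_ROLE credential string is visible. For orientation, a Redshift COPY from S3 looks roughly like the sketch below. The bucket, prefix, role ARN, and CSV options are illustrative placeholders, not values taken from this file.

# Illustrative only: shape of a Redshift COPY from S3, in the style of
# copy_to_redshift(). Every literal below is a placeholder.
copy_credentials = "IAM_ROLE 'arn:aws:iam::123456789012:role/example-redshift-copy-role'"
copy_sql = f"""
    COPY "public"."users__temp"
    FROM 's3://example-bucket/example-prefix/users.csv'
    {copy_credentials}
    FORMAT AS CSV
    IGNOREHEADER 1
    TIMEFORMAT 'auto'
"""
# executed with the same cursor that later runs the MERGE / INSERT:
# cursor.execute(copy_sql)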
