Skip to content

Commit

Permalink
Avoid naming collisions in FDW, schema, connection id when data refre…
Browse files Browse the repository at this point in the history
…shes run concurrently
  • Loading branch information
stacimc committed Oct 7, 2024
1 parent d069d02 commit 276bd49
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
18 changes: 16 additions & 2 deletions catalog/dags/data_refresh/copy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,35 @@
def initialize_fdw(
upstream_conn_id: str,
downstream_conn_id: str,
media_type: str,
task: AbstractOperator = None,
):
"""Create the FDW and prepare it for copying."""
upstream_connection = Connection.get_connection_from_secrets(upstream_conn_id)
fdw_name = f"upstream_{media_type}"

run_sql.function(
postgres_conn_id=downstream_conn_id,
sql_template=queries.CREATE_FDW_QUERY,
task=task,
fdw_name=fdw_name,
host=upstream_connection.host,
port=upstream_connection.port,
dbname=upstream_connection.schema,
user=upstream_connection.login,
password=upstream_connection.password,
)

return fdw_name


@task(
max_active_tis_per_dagrun=1,
map_index_template="{{ task.op_kwargs['upstream_table_name'] }}",
)
def create_schema(downstream_conn_id: str, upstream_table_name: str) -> str:
def create_schema(
downstream_conn_id: str, upstream_table_name: str, fdw_name: str
) -> str:
"""
Create a new schema in the downstream DB through which the upstream table
can be accessed. Returns the schema name.
Expand All @@ -73,7 +80,9 @@ def create_schema(downstream_conn_id: str, upstream_table_name: str) -> str:
schema_name = f"upstream_{upstream_table_name}_schema"
downstream_pg.run(
queries.CREATE_SCHEMA_QUERY.format(
schema_name=schema_name, upstream_table_name=upstream_table_name
fdw_name=fdw_name,
schema_name=schema_name,
upstream_table_name=upstream_table_name,
)
)
return schema_name
Expand Down Expand Up @@ -183,6 +192,7 @@ def copy_data(
def copy_upstream_table(
upstream_conn_id: str,
downstream_conn_id: str,
fdw_name: str,
timeout: timedelta,
limit: int,
upstream_table_name: str,
Expand All @@ -206,6 +216,7 @@ def copy_upstream_table(
schema = create_schema(
downstream_conn_id=downstream_conn_id,
upstream_table_name=upstream_table_name,
fdw_name=fdw_name,
)

create_temp_table = run_sql.override(
Expand Down Expand Up @@ -286,6 +297,7 @@ def copy_upstream_tables(
init_fdw = initialize_fdw(
upstream_conn_id=upstream_conn_id,
downstream_conn_id=downstream_conn_id,
media_type=data_refresh_config.media_type,
)

limit = get_record_limit()
Expand All @@ -294,13 +306,15 @@ def copy_upstream_tables(
copy_tables = copy_upstream_table.partial(
upstream_conn_id=upstream_conn_id,
downstream_conn_id=downstream_conn_id,
fdw_name=init_fdw,
timeout=data_refresh_config.copy_data_timeout,
limit=limit,
).expand_kwargs([asdict(tm) for tm in data_refresh_config.table_mappings])

drop_fdw = run_sql.override(task_id="drop_fdw")(
postgres_conn_id=downstream_conn_id,
sql_template=queries.DROP_SERVER_QUERY,
fdw_name=init_fdw,
)

# Set up dependencies
Expand Down
4 changes: 3 additions & 1 deletion catalog/dags/data_refresh/distributed_reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,14 @@ def get_instance_ip_address(
)
def create_connection(
instance_id: str,
media_type: str,
server: str,
):
"""
Create an Airflow Connection for the given indexer worker and persist it. It will
later be dropped in a cleanup task.
"""
worker_conn_id = f"indexer_worker_{instance_id or 'localhost'}"
worker_conn_id = f"indexer_worker_{instance_id or media_type}"

# Create the Connection
logger.info(f"Creating connection with id {worker_conn_id}")
Expand Down Expand Up @@ -333,6 +334,7 @@ def reindex(

worker_conn = create_connection(
instance_id=instance_id,
media_type=data_refresh_config.media_type,
server=instance_ip_address,
)

Expand Down
10 changes: 5 additions & 5 deletions catalog/dags/data_refresh/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

CREATE_FDW_QUERY = dedent(
"""
DROP SERVER IF EXISTS upstream CASCADE;
CREATE SERVER upstream FOREIGN DATA WRAPPER postgres_fdw
DROP SERVER IF EXISTS {fdw_name} CASCADE;
CREATE SERVER {fdw_name} FOREIGN DATA WRAPPER postgres_fdw
OPTIONS (host '{host}', dbname '{dbname}', port '{port}');
CREATE USER MAPPING IF NOT EXISTS FOR deploy SERVER upstream
CREATE USER MAPPING IF NOT EXISTS FOR deploy SERVER {fdw_name}
OPTIONS (user '{user}', password '{password}');
"""
)
Expand All @@ -20,7 +20,7 @@
CREATE SCHEMA {schema_name} AUTHORIZATION deploy;
IMPORT FOREIGN SCHEMA public LIMIT TO ({upstream_table_name})
FROM SERVER upstream INTO {schema_name};
FROM SERVER {fdw_name} INTO {schema_name};
"""
)

Expand Down Expand Up @@ -79,7 +79,7 @@

ADD_PRIMARY_KEY_QUERY = "ALTER TABLE {temp_table_name} ADD PRIMARY KEY (id);"

DROP_SERVER_QUERY = "DROP SERVER upstream CASCADE;"
DROP_SERVER_QUERY = "DROP SERVER {fdw_name} CASCADE;"

SELECT_TABLE_INDICES_QUERY = (
"SELECT indexdef FROM pg_indexes WHERE tablename='{table_name}';"
Expand Down

0 comments on commit 276bd49

Please sign in to comment.