From 2950fd768541fc902d8f7218e4243e8d83414c51 Mon Sep 17 00:00:00 2001 From: Dan Hansen Date: Mon, 14 Aug 2023 03:51:18 -0700 Subject: [PATCH] [Models] [Postgres] Check if the dynamically-added index is in the table schema before adding (#32731) * Check if the index is in the table schema before adding * add pre-condition assertion * static checks * Update test_models.py * integrate upstream auth manager changes --- airflow/auth/managers/fab/models/__init__.py | 8 ++- tests/auth/managers/fab/test_models.py | 62 ++++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 tests/auth/managers/fab/test_models.py diff --git a/airflow/auth/managers/fab/models/__init__.py b/airflow/auth/managers/fab/models/__init__.py index cb11e8fb06a2..0bc26adb7eb4 100644 --- a/airflow/auth/managers/fab/models/__init__.py +++ b/airflow/auth/managers/fab/models/__init__.py @@ -255,11 +255,15 @@ class RegisterUser(Model): def add_index_on_ab_user_username_postgres(table, conn, **kw): if conn.dialect.name != "postgresql": return - table.indexes.add(Index("idx_ab_user_username", func.lower(table.c.username), unique=True)) + index_name = "idx_ab_user_username" + if not any(table_index.name == index_name for table_index in table.indexes): + table.indexes.add(Index(index_name, func.lower(table.c.username), unique=True)) @event.listens_for(RegisterUser.__table__, "before_create") def add_index_on_ab_register_user_username_postgres(table, conn, **kw): if conn.dialect.name != "postgresql": return - table.indexes.add(Index("idx_ab_register_user_username", func.lower(table.c.username), unique=True)) + index_name = "idx_ab_register_user_username" + if not any(table_index.name == index_name for table_index in table.indexes): + table.indexes.add(Index(index_name, func.lower(table.c.username), unique=True)) diff --git a/tests/auth/managers/fab/test_models.py b/tests/auth/managers/fab/test_models.py new file mode 100644 index 000000000000..f2703e8d66e4 --- /dev/null +++ b/tests/auth/managers/fab/test_models.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +from sqlalchemy import Column, MetaData, String, Table + +from airflow.auth.managers.fab.models import ( + add_index_on_ab_register_user_username_postgres, + add_index_on_ab_user_username_postgres, +) + +_mock_conn = mock.MagicMock() +_mock_conn.dialect = mock.MagicMock() +_mock_conn.dialect.name = "postgresql" + + +def test_add_index_on_ab_user_username_postgres(): + table = Table("test_table", MetaData(), Column("username", String)) + + assert len(table.indexes) == 0 + + add_index_on_ab_user_username_postgres(table, _mock_conn) + + # Assert that the index was added to the table + assert len(table.indexes) == 1 + + add_index_on_ab_user_username_postgres(table, _mock_conn) + + # Assert that index is not re-added when the schema is recreated + assert len(table.indexes) == 1 + + +def test_add_index_on_ab_register_user_username_postgres(): + table = Table("test_table", MetaData(), Column("username", String)) + + assert len(table.indexes) == 0 + + add_index_on_ab_register_user_username_postgres(table, _mock_conn) + + # Assert that the index was added to the table + assert len(table.indexes) == 1 + + add_index_on_ab_register_user_username_postgres(table, _mock_conn) + + # Assert that index is not re-added when the schema is recreated + assert len(table.indexes) == 1