From fd5916476217b1ce8b0f74bb4289d0f157a42951 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Thu, 8 Jun 2023 15:32:11 -0400 Subject: [PATCH] Nulls not distinct support in postgresql Added support in autogenerate for NULLS NOT DISTINCT in the PostgreSQL dialect. Closes: #1249 Pull-request: https://github.com/sqlalchemy/alembic/pull/1249 Pull-request-sha: e4a7ffed54677d5aba9ab0251026a8a2a0e71278 Change-Id: I299a24fa7af4ae9387d6b48ce49fb516dfb84518 --- alembic/autogenerate/compare.py | 20 +++-- alembic/autogenerate/render.py | 43 +++++----- alembic/ddl/impl.py | 6 ++ alembic/ddl/postgresql.py | 27 ++++++- alembic/operations/base.py | 2 +- docs/build/unreleased/1249.rst | 6 ++ tests/requirements.py | 21 +++++ tests/test_postgresql.py | 139 ++++++++++++++++++++++++++++++++ 8 files changed, 228 insertions(+), 36 deletions(-) create mode 100644 docs/build/unreleased/1249.rst diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index db32a6a4..c441a200 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -444,11 +444,11 @@ class _uq_constraint_sig(_constraint_sig): is_index = False is_unique = True - def __init__(self, const: UniqueConstraint) -> None: + def __init__(self, const: UniqueConstraint, impl: DefaultImpl) -> None: self.const = const self.name = const.name - self.sig = ("UNIQUE_CONSTRAINT",) + tuple( - sorted([col.name for col in const.columns]) + self.sig = ("UNIQUE_CONSTRAINT",) + impl.create_unique_constraint_sig( + const ) @property @@ -616,6 +616,7 @@ def _compare_indexes_and_uniques( # 2a. if the dialect dupes unique indexes as unique constraints # (mysql and oracle), correct for that + impl = autogen_context.migration_context.impl if unique_constraints_duplicate_unique_indexes: _correct_for_uq_duplicates_uix( conn_uniques, @@ -623,6 +624,7 @@ def _compare_indexes_and_uniques( metadata_unique_constraints, metadata_indexes, autogen_context.dialect, + impl, ) # 3. give the dialect a chance to omit indexes and constraints that @@ -640,15 +642,16 @@ def _compare_indexes_and_uniques( # Index and UniqueConstraint so we can easily work with them # interchangeably metadata_unique_constraints_sig = { - _uq_constraint_sig(uq) for uq in metadata_unique_constraints + _uq_constraint_sig(uq, impl) for uq in metadata_unique_constraints } - impl = autogen_context.migration_context.impl metadata_indexes_sig = { _ix_constraint_sig(ix, impl) for ix in metadata_indexes } - conn_unique_constraints = {_uq_constraint_sig(uq) for uq in conn_uniques} + conn_unique_constraints = { + _uq_constraint_sig(uq, impl) for uq in conn_uniques + } conn_indexes_sig = {_ix_constraint_sig(ix, impl) for ix in conn_indexes} @@ -858,6 +861,7 @@ def _correct_for_uq_duplicates_uix( metadata_unique_constraints, metadata_indexes, dialect, + impl, ): # dedupe unique indexes vs. constraints, since MySQL / Oracle # doesn't really have unique constraints as a separate construct. @@ -880,7 +884,7 @@ def _correct_for_uq_duplicates_uix( } unnamed_metadata_uqs = { - _uq_constraint_sig(cons).sig + _uq_constraint_sig(cons, impl).sig for name, cons in metadata_cons_names if name is None } @@ -904,7 +908,7 @@ def _correct_for_uq_duplicates_uix( for overlap in uqs_dupe_indexes: if overlap not in metadata_uq_names: if ( - _uq_constraint_sig(uqs_dupe_indexes[overlap]).sig + _uq_constraint_sig(uqs_dupe_indexes[overlap], impl).sig not in unnamed_metadata_uqs ): conn_unique_constraints.discard(uqs_dupe_indexes[overlap]) diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 215af8ce..3dfb5e9e 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -26,6 +26,7 @@ if TYPE_CHECKING: from typing import Literal + from sqlalchemy.sql.base import DialectKWArgs from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.schema import CheckConstraint @@ -268,6 +269,15 @@ def _drop_table(autogen_context: AutogenContext, op: ops.DropTableOp) -> str: return text +def _render_dialect_kwargs_items( + autogen_context: AutogenContext, item: DialectKWArgs +) -> list[str]: + return [ + f"{key}={_render_potential_expr(val, autogen_context)}" + for key, val in item.dialect_kwargs.items() + ] + + @renderers.dispatch_for(ops.CreateIndexOp) def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str: index = op.to_index() @@ -286,6 +296,8 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str: ) assert index.table is not None + + opts = _render_dialect_kwargs_items(autogen_context, index) text = tmpl % { "prefix": _alembic_autogenerate_prefix(autogen_context), "name": _render_gen_name(autogen_context, index.name), @@ -297,18 +309,7 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str: "schema": (", schema=%r" % _ident(index.table.schema)) if index.table.schema else "", - "kwargs": ( - ", " - + ", ".join( - [ - "%s=%s" - % (key, _render_potential_expr(val, autogen_context)) - for key, val in index.kwargs.items() - ] - ) - ) - if len(index.kwargs) - else "", + "kwargs": ", " + ", ".join(opts) if opts else "", } return text @@ -326,24 +327,13 @@ def _drop_index(autogen_context: AutogenContext, op: ops.DropIndexOp) -> str: "%(prefix)sdrop_index(%(name)r, " "table_name=%(table_name)r%(schema)s%(kwargs)s)" ) - + opts = _render_dialect_kwargs_items(autogen_context, index) text = tmpl % { "prefix": _alembic_autogenerate_prefix(autogen_context), "name": _render_gen_name(autogen_context, op.index_name), "table_name": _ident(op.table_name), "schema": ((", schema=%r" % _ident(op.schema)) if op.schema else ""), - "kwargs": ( - ", " - + ", ".join( - [ - "%s=%s" - % (key, _render_potential_expr(val, autogen_context)) - for key, val in index.kwargs.items() - ] - ) - ) - if len(index.kwargs) - else "", + "kwargs": ", " + ", ".join(opts) if opts else "", } return text @@ -604,6 +594,7 @@ def _uq_constraint( opts.append( ("name", _render_gen_name(autogen_context, constraint.name)) ) + dialect_options = _render_dialect_kwargs_items(autogen_context, constraint) if alter: args = [repr(_render_gen_name(autogen_context, constraint.name))] @@ -611,6 +602,7 @@ def _uq_constraint( args += [repr(_ident(constraint.table.name))] args.append(repr([_ident(col.name) for col in constraint.columns])) args.extend(["%s=%r" % (k, v) for k, v in opts]) + args.extend(dialect_options) return "%(prefix)screate_unique_constraint(%(args)s)" % { "prefix": _alembic_autogenerate_prefix(autogen_context), "args": ", ".join(args), @@ -618,6 +610,7 @@ def _uq_constraint( else: args = [repr(_ident(col.name)) for col in constraint.columns] args.extend(["%s=%r" % (k, v) for k, v in opts]) + args.extend(dialect_options) return "%(prefix)sUniqueConstraint(%(args)s)" % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), "args": ", ".join(args), diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 726f1686..31667ef8 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -668,6 +668,12 @@ def create_index_sig(self, index: Index) -> Tuple[Any, ...]: # order of col matters in an index return tuple(col.name for col in index.columns) + def create_unique_constraint_sig( + self, const: UniqueConstraint + ) -> Tuple[Any, ...]: + # order of col does not matters in an unique constraint + return tuple(sorted([col.name for col in const.columns])) + def _skip_functional_indexes(self, metadata_indexes, conn_indexes): conn_indexes_by_name = {c.name: c for c in conn_indexes} diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index c2d31062..afabd6c0 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -12,7 +12,6 @@ from typing import Union from sqlalchemy import Column -from sqlalchemy import Index from sqlalchemy import literal_column from sqlalchemy import Numeric from sqlalchemy import text @@ -50,6 +49,8 @@ if TYPE_CHECKING: from typing import Literal + from sqlalchemy import Index + from sqlalchemy import UniqueConstraint from sqlalchemy.dialects.postgresql.array import ARRAY from sqlalchemy.dialects.postgresql.base import PGDDLCompiler from sqlalchemy.dialects.postgresql.hstore import HSTORE @@ -305,6 +306,21 @@ def _default_modifiers(self, exp: ClauseElement) -> str: break return to_remove + def _dialect_sig( + self, item: Union[Index, UniqueConstraint] + ) -> Tuple[Any, ...]: + if ( + item.dialect_kwargs.get("postgresql_nulls_not_distinct") + is not None + ): + return ( + ( + "nulls_not_distinct", + item.dialect_kwargs["postgresql_nulls_not_distinct"], + ), + ) + return () + def create_index_sig(self, index: Index) -> Tuple[Any, ...]: return tuple( self._cleanup_index_expr( @@ -316,7 +332,14 @@ def create_index_sig(self, index: Index) -> Tuple[Any, ...]: ), ) for e in index.expressions - ) + ) + self._dialect_sig(index) + + def create_unique_constraint_sig( + self, const: UniqueConstraint + ) -> Tuple[Any, ...]: + return tuple( + sorted([col.name for col in const.columns]) + ) + self._dialect_sig(const) def _compile_element(self, element: ClauseElement) -> str: return element.compile( diff --git a/alembic/operations/base.py b/alembic/operations/base.py index a2acafef..e2c1fd23 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -86,7 +86,7 @@ def __init__( @classmethod def register_operation( cls, name: str, sourcename: Optional[str] = None - ) -> Callable[..., Any]: + ) -> Callable[[_T], _T]: """Register a new operation for this class. This method is normally used to add new operations diff --git a/docs/build/unreleased/1249.rst b/docs/build/unreleased/1249.rst new file mode 100644 index 00000000..b3740cbe --- /dev/null +++ b/docs/build/unreleased/1249.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase, autogenerate + :tickets: 1248 + + Added support in autogenerate for NULLS NOT DISTINCT in + the PostgreSQL dialect. diff --git a/tests/requirements.py b/tests/requirements.py index dbbb88a5..d67a8479 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -1,4 +1,5 @@ from sqlalchemy import exc as sqla_exc +from sqlalchemy import Index from sqlalchemy import text from alembic.testing import exclusions @@ -430,3 +431,23 @@ def reflect_indexes_with_expressions(self): @property def indexes_with_expressions(self): return exclusions.only_on(["postgresql", "sqlite>=3.9.0"]) + + @property + def nulls_not_distinct_sa(self): + def _has_nulls_not_distinct(): + try: + Index("foo", "bar", postgresql_nulls_not_distinct=True) + return True + except sqla_exc.ArgumentError: + return False + + return exclusions.only_if( + _has_nulls_not_distinct, + "sqlalchemy with nulls not distinct support needed", + ) + + @property + def nulls_not_distinct_db(self): + return self.nulls_not_distinct_sa + exclusions.only_on( + ["postgresql>=15"] + ) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 7b7afdc0..8984437b 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1258,6 +1258,45 @@ def test_jsonb_type(self): "postgresql.JSONB(astext_type=sa.Text())", ) + @config.requirements.nulls_not_distinct_sa + def test_render_unique_nulls_not_distinct_constraint(self): + m = MetaData() + t = Table("tbl", m, Column("c", Integer)) + uc = UniqueConstraint( + t.c.c, + name="uq_1", + deferrable="XYZ", + postgresql_nulls_not_distinct=True, + ) + eq_ignore_whitespace( + autogenerate.render.render_op_text( + self.autogen_context, + ops.AddConstraintOp.from_constraint(uc), + ), + "op.create_unique_constraint('uq_1', 'tbl', ['c'], " + "deferrable='XYZ', postgresql_nulls_not_distinct=True)", + ) + eq_ignore_whitespace( + autogenerate.render._render_unique_constraint( + uc, self.autogen_context, None + ), + "sa.UniqueConstraint('c', deferrable='XYZ', name='uq_1', " + "postgresql_nulls_not_distinct=True)", + ) + + @config.requirements.nulls_not_distinct_sa + def test_render_index_nulls_not_distinct_constraint(self): + m = MetaData() + t = Table("tbl", m, Column("c", Integer)) + idx = Index("ix_42", t.c.c, postgresql_nulls_not_distinct=False) + eq_ignore_whitespace( + autogenerate.render.render_op_text( + self.autogen_context, ops.CreateIndexOp.from_index(idx) + ), + "op.create_index('ix_42', 'tbl', ['c'], unique=False, " + "postgresql_nulls_not_distinct=False)", + ) + class PGUniqueIndexAutogenerateTest(AutogenFixtureTest, TestBase): __only_on__ = "postgresql" @@ -1394,3 +1433,103 @@ def test_uq_dropped(self): eq_(diffs[0][0], "remove_constraint") eq_(diffs[0][1].name, "uq_name") eq_(len(diffs), 1) + + +case = combinations(False, True, None, argnames="case", id_="s") +name_type = combinations( + ( + "index", + lambda value: Index( + "nnd_obj", "name", unique=True, postgresql_nulls_not_distinct=value + ), + ), + ( + "constraint", + lambda value: UniqueConstraint( + "id", "name", name="nnd_obj", postgresql_nulls_not_distinct=value + ), + ), + argnames="name,type_", + id_="sa", +) + + +class PGNullsNotDistinctAutogenerateTest(AutogenFixtureTest, TestBase): + __requires__ = ("nulls_not_distinct_db",) + __only_on__ = "postgresql" + __backend__ = True + + @case + @name_type + def test_add(self, case, name, type_): + m1 = MetaData() + m2 = MetaData() + Table( + "tbl", + m1, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + Table( + "tbl", + m2, + Column("id", Integer, primary_key=True), + Column("name", String), + type_(case), + ) + diffs = self._fixture(m1, m2) + eq_(len(diffs), 1) + eq_(diffs[0][0], f"add_{name}") + added = diffs[0][1] + eq_(added.name, "nnd_obj") + eq_(added.dialect_kwargs["postgresql_nulls_not_distinct"], case) + + @case + @name_type + def test_remove(self, case, name, type_): + m1 = MetaData() + m2 = MetaData() + Table( + "tbl", + m1, + Column("id", Integer, primary_key=True), + Column("name", String), + type_(case), + ) + Table( + "tbl", + m2, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + diffs = self._fixture(m1, m2) + eq_(len(diffs), 1) + eq_(diffs[0][0], f"remove_{name}") + eq_(diffs[0][1].name, "nnd_obj") + + @case + @name_type + def test_toggle_not_distinct(self, case, name, type_): + m1 = MetaData() + m2 = MetaData() + to = not case + Table( + "tbl", + m1, + Column("id", Integer, primary_key=True), + Column("name", String), + type_(case), + ) + Table( + "tbl", + m2, + Column("id", Integer, primary_key=True), + Column("name", String), + type_(to), + ) + diffs = self._fixture(m1, m2) + eq_(len(diffs), 2) + eq_(diffs[0][0], f"remove_{name}") + eq_(diffs[1][0], f"add_{name}") + eq_(diffs[1][1].name, "nnd_obj") + eq_(diffs[1][1].dialect_kwargs["postgresql_nulls_not_distinct"], to)