From 3b09a89d95f765399324dd53b4cb8504b0a7903b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 25 Nov 2022 12:29:40 -0500 Subject: [PATCH] run pyupgrade command is: find alembic -name "*.py" | xargs pyupgrade --py37-plus --keep-runtime-typing --keep-percent-format I'm having some weird fighting with the tools/write_pyi, where in different runtime contexts it keeps losing "MigrationContext" and also Callable drops the args, but it's not consisistent. For whatever reason, under py311 things *do* work every time. im working w/ clean tox environments so not really sure what the change is. anyway, let's at least fix the quoting up around the types. This is towards getting the "*" in the op signatures for #1130. Change-Id: I9175905d3b4325e03a97d6752356b70be20e9fad --- alembic/autogenerate/api.py | 58 ++-- alembic/autogenerate/compare.py | 315 ++++++++++----------- alembic/autogenerate/render.py | 148 +++++----- alembic/autogenerate/rewriter.py | 58 ++-- alembic/command.py | 24 +- alembic/config.py | 2 +- alembic/ddl/base.py | 98 +++---- alembic/ddl/impl.py | 111 ++++---- alembic/ddl/mssql.py | 66 ++--- alembic/ddl/mysql.py | 56 ++-- alembic/ddl/oracle.py | 26 +- alembic/ddl/postgresql.py | 76 ++--- alembic/ddl/sqlite.py | 26 +- alembic/operations/base.py | 8 +- alembic/operations/batch.py | 62 ++-- alembic/operations/ops.py | 305 ++++++++++---------- alembic/operations/schemaobj.py | 26 +- alembic/runtime/environment.py | 6 +- alembic/runtime/migration.py | 78 +++-- alembic/script/base.py | 10 +- alembic/script/revision.py | 46 ++- alembic/testing/env.py | 1 - alembic/testing/fixtures.py | 1 - alembic/testing/suite/_autogen_fixtures.py | 3 +- alembic/testing/warnings.py | 1 - alembic/util/langhelpers.py | 4 +- alembic/util/messaging.py | 6 +- alembic/util/sqla_compat.py | 52 ++-- tests/test_autogen_composition.py | 10 +- tests/test_autogen_diffs.py | 28 +- tests/test_autogen_indexes.py | 38 ++- tests/test_autogen_render.py | 4 +- tests/test_batch.py | 46 ++- tests/test_command.py | 16 +- tests/test_config.py | 1 - tests/test_environment.py | 1 - tests/test_external_dialect.py | 32 +-- tests/test_postgresql.py | 8 +- tests/test_script_consumption.py | 6 +- tests/test_script_production.py | 13 +- tests/test_version_traversal.py | 162 +++++------ 41 files changed, 962 insertions(+), 1076 deletions(-) diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index cbd64e18..d7a0913d 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -40,7 +40,7 @@ from alembic.script.base import ScriptDirectory -def compare_metadata(context: "MigrationContext", metadata: "MetaData") -> Any: +def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any: """Compare a database schema to that given in a :class:`~sqlalchemy.schema.MetaData` instance. @@ -136,8 +136,8 @@ def compare_metadata(context: "MigrationContext", metadata: "MetaData") -> Any: def produce_migrations( - context: "MigrationContext", metadata: "MetaData" -) -> "MigrationScript": + context: MigrationContext, metadata: MetaData +) -> MigrationScript: """Produce a :class:`.MigrationScript` structure based on schema comparison. @@ -167,13 +167,13 @@ def produce_migrations( def render_python_code( - up_or_down_op: "UpgradeOps", + up_or_down_op: UpgradeOps, sqlalchemy_module_prefix: str = "sa.", alembic_module_prefix: str = "op.", render_as_batch: bool = False, imports: Tuple[str, ...] = (), render_item: None = None, - migration_context: Optional["MigrationContext"] = None, + migration_context: Optional[MigrationContext] = None, ) -> str: """Render Python code given an :class:`.UpgradeOps` or :class:`.DowngradeOps` object. @@ -205,7 +205,7 @@ def render_python_code( def _render_migration_diffs( - context: "MigrationContext", template_args: Dict[Any, Any] + context: MigrationContext, template_args: Dict[Any, Any] ) -> None: """legacy, used by test_autogen_composition at the moment""" @@ -229,7 +229,7 @@ class AutogenContext: """Maintains configuration and state that's specific to an autogenerate operation.""" - metadata: Optional["MetaData"] = None + metadata: Optional[MetaData] = None """The :class:`~sqlalchemy.schema.MetaData` object representing the destination. @@ -247,7 +247,7 @@ class AutogenContext: """ - connection: Optional["Connection"] = None + connection: Optional[Connection] = None """The :class:`~sqlalchemy.engine.base.Connection` object currently connected to the database backend being compared. @@ -256,7 +256,7 @@ class AutogenContext: """ - dialect: Optional["Dialect"] = None + dialect: Optional[Dialect] = None """The :class:`~sqlalchemy.engine.Dialect` object currently in use. This is normally obtained from the @@ -278,13 +278,13 @@ class AutogenContext: """ - migration_context: "MigrationContext" = None # type: ignore[assignment] + migration_context: MigrationContext = None # type: ignore[assignment] """The :class:`.MigrationContext` established by the ``env.py`` script.""" def __init__( self, - migration_context: "MigrationContext", - metadata: Optional["MetaData"] = None, + migration_context: MigrationContext, + metadata: Optional[MetaData] = None, opts: Optional[dict] = None, autogenerate: bool = True, ) -> None: @@ -342,7 +342,7 @@ def __init__( self._has_batch: bool = False @util.memoized_property - def inspector(self) -> "Inspector": + def inspector(self) -> Inspector: if self.connection is None: raise TypeError( "can't return inspector as this " @@ -397,18 +397,16 @@ def run_name_filters( def run_object_filters( self, object_: Union[ - "Table", - "Index", - "Column", - "UniqueConstraint", - "ForeignKeyConstraint", + Table, + Index, + Column, + UniqueConstraint, + ForeignKeyConstraint, ], name: Optional[str], type_: str, reflected: bool, - compare_to: Optional[ - Union["Table", "Index", "Column", "UniqueConstraint"] - ], + compare_to: Optional[Union[Table, Index, Column, UniqueConstraint]], ) -> bool: """Run the context's object filters and return True if the targets should be part of the autogenerate operation. @@ -476,8 +474,8 @@ class RevisionContext: def __init__( self, - config: "Config", - script_directory: "ScriptDirectory", + config: Config, + script_directory: ScriptDirectory, command_args: Dict[str, Any], process_revision_directives: Optional[Callable] = None, ) -> None: @@ -492,8 +490,8 @@ def __init__( self.generated_revisions = [self._default_revision()] def _to_script( - self, migration_script: "MigrationScript" - ) -> Optional["Script"]: + self, migration_script: MigrationScript + ) -> Optional[Script]: template_args: Dict[str, Any] = self.template_args.copy() if getattr(migration_script, "_needs_render", False): @@ -522,19 +520,19 @@ def _to_script( ) def run_autogenerate( - self, rev: tuple, migration_context: "MigrationContext" + self, rev: tuple, migration_context: MigrationContext ) -> None: self._run_environment(rev, migration_context, True) def run_no_autogenerate( - self, rev: tuple, migration_context: "MigrationContext" + self, rev: tuple, migration_context: MigrationContext ) -> None: self._run_environment(rev, migration_context, False) def _run_environment( self, rev: tuple, - migration_context: "MigrationContext", + migration_context: MigrationContext, autogenerate: bool, ) -> None: if autogenerate: @@ -587,7 +585,7 @@ def _run_environment( for migration_script in self.generated_revisions: migration_script._needs_render = True - def _default_revision(self) -> "MigrationScript": + def _default_revision(self) -> MigrationScript: command_args: Dict[str, Any] = self.command_args op = ops.MigrationScript( rev_id=command_args["rev_id"] or util.rev_id(), @@ -602,6 +600,6 @@ def _default_revision(self) -> "MigrationScript": ) return op - def generate_scripts(self) -> Iterator[Optional["Script"]]: + def generate_scripts(self) -> Iterator[Optional[Script]]: for generated_revision in self.generated_revisions: yield self._to_script(generated_revision) diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index c32ab4d9..c9971ea3 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -47,7 +47,7 @@ def _populate_migration_script( - autogen_context: "AutogenContext", migration_script: "MigrationScript" + autogen_context: AutogenContext, migration_script: MigrationScript ) -> None: upgrade_ops = migration_script.upgrade_ops_list[-1] downgrade_ops = migration_script.downgrade_ops_list[-1] @@ -60,14 +60,14 @@ def _populate_migration_script( def _produce_net_changes( - autogen_context: "AutogenContext", upgrade_ops: "UpgradeOps" + autogen_context: AutogenContext, upgrade_ops: UpgradeOps ) -> None: connection = autogen_context.connection assert connection is not None include_schemas = autogen_context.opts.get("include_schemas", False) - inspector: "Inspector" = inspect(connection) + inspector: Inspector = inspect(connection) default_schema = connection.dialect.default_schema_name schemas: Set[Optional[str]] @@ -93,8 +93,8 @@ def _produce_net_changes( @comparators.dispatch_for("schema") def _autogen_for_tables( - autogen_context: "AutogenContext", - upgrade_ops: "UpgradeOps", + autogen_context: AutogenContext, + upgrade_ops: UpgradeOps, schemas: Union[Set[None], Set[Optional[str]]], ) -> None: inspector = autogen_context.inspector @@ -135,11 +135,11 @@ def _autogen_for_tables( def _compare_tables( - conn_table_names: "set", - metadata_table_names: "set", - inspector: "Inspector", - upgrade_ops: "UpgradeOps", - autogen_context: "AutogenContext", + conn_table_names: set, + metadata_table_names: set, + inspector: Inspector, + upgrade_ops: UpgradeOps, + autogen_context: AutogenContext, ) -> None: default_schema = inspector.bind.dialect.default_schema_name @@ -159,17 +159,14 @@ def _compare_tables( # to adjust for the MetaData collection storing the tables either # as "schemaname.tablename" or just "tablename", create a new lookup # which will match the "non-default-schema" keys to the Table object. - tname_to_table = dict( - ( - no_dflt_schema, - autogen_context.table_key_to_table[ - sa_schema._get_table_key(tname, schema) - ], - ) + tname_to_table = { + no_dflt_schema: autogen_context.table_key_to_table[ + sa_schema._get_table_key(tname, schema) + ] for no_dflt_schema, (schema, tname) in zip( metadata_table_names_no_dflt_schema, metadata_table_names ) - ) + } metadata_table_names = metadata_table_names_no_dflt_schema for s, tname in metadata_table_names.difference(conn_table_names): @@ -279,9 +276,7 @@ def _compare_tables( upgrade_ops.ops.append(modify_table_ops) -def _make_index( - params: Dict[str, Any], conn_table: "Table" -) -> Optional["Index"]: +def _make_index(params: Dict[str, Any], conn_table: Table) -> Optional[Index]: exprs = [] for col_name in params["column_names"]: if col_name is None: @@ -302,8 +297,8 @@ def _make_index( def _make_unique_constraint( - params: Dict[str, Any], conn_table: "Table" -) -> "UniqueConstraint": + params: Dict[str, Any], conn_table: Table +) -> UniqueConstraint: uq = sa_schema.UniqueConstraint( *[conn_table.c[cname] for cname in params["column_names"]], name=params["name"], @@ -315,8 +310,8 @@ def _make_unique_constraint( def _make_foreign_key( - params: Dict[str, Any], conn_table: "Table" -) -> "ForeignKeyConstraint": + params: Dict[str, Any], conn_table: Table +) -> ForeignKeyConstraint: tname = params["referred_table"] if params["referred_schema"]: tname = "%s.%s" % (params["referred_schema"], tname) @@ -340,12 +335,12 @@ def _make_foreign_key( @contextlib.contextmanager def _compare_columns( schema: Optional[str], - tname: Union["quoted_name", str], - conn_table: "Table", - metadata_table: "Table", - modify_table_ops: "ModifyTableOps", - autogen_context: "AutogenContext", - inspector: "Inspector", + tname: Union[quoted_name, str], + conn_table: Table, + metadata_table: Table, + modify_table_ops: ModifyTableOps, + autogen_context: AutogenContext, + inspector: Inspector, ) -> Iterator[None]: name = "%s.%s" % (schema, tname) if schema else tname metadata_col_names = OrderedSet( @@ -411,9 +406,9 @@ def _compare_columns( class _constraint_sig: - const: Union["UniqueConstraint", "ForeignKeyConstraint", "Index"] + const: Union[UniqueConstraint, ForeignKeyConstraint, Index] - def md_name_to_sql_name(self, context: "AutogenContext") -> Optional[str]: + def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]: return sqla_compat._get_constraint_final_name( self.const, context.dialect ) @@ -432,7 +427,7 @@ class _uq_constraint_sig(_constraint_sig): is_index = False is_unique = True - def __init__(self, const: "UniqueConstraint") -> None: + def __init__(self, const: UniqueConstraint) -> None: self.const = const self.name = const.name self.sig = tuple(sorted([col.name for col in const.columns])) @@ -445,25 +440,25 @@ def column_names(self) -> List[str]: class _ix_constraint_sig(_constraint_sig): is_index = True - def __init__(self, const: "Index") -> None: + def __init__(self, const: Index) -> None: self.const = const self.name = const.name self.sig = tuple(sorted([col.name for col in const.columns])) self.is_unique = bool(const.unique) - def md_name_to_sql_name(self, context: "AutogenContext") -> Optional[str]: + def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]: return sqla_compat._get_constraint_final_name( self.const, context.dialect ) @property - def column_names(self) -> Union[List["quoted_name"], List[None]]: + def column_names(self) -> Union[List[quoted_name], List[None]]: return sqla_compat._get_index_column_names(self.const) class _fk_constraint_sig(_constraint_sig): def __init__( - self, const: "ForeignKeyConstraint", include_options: bool = False + self, const: ForeignKeyConstraint, include_options: bool = False ) -> None: self.const = const self.name = const.name @@ -508,12 +503,12 @@ def __init__( @comparators.dispatch_for("table") def _compare_indexes_and_uniques( - autogen_context: "AutogenContext", - modify_ops: "ModifyTableOps", + autogen_context: AutogenContext, + modify_ops: ModifyTableOps, schema: Optional[str], - tname: Union["quoted_name", str], - conn_table: Optional["Table"], - metadata_table: Optional["Table"], + tname: Union[quoted_name, str], + conn_table: Optional[Table], + metadata_table: Optional[Table], ) -> None: inspector = autogen_context.inspector @@ -522,11 +517,11 @@ def _compare_indexes_and_uniques( # 1a. get raw indexes and unique constraints from metadata ... if metadata_table is not None: - metadata_unique_constraints = set( + metadata_unique_constraints = { uq for uq in metadata_table.constraints if isinstance(uq, sa_schema.UniqueConstraint) - ) + } metadata_indexes = set(metadata_table.indexes) else: metadata_unique_constraints = set() @@ -589,16 +584,16 @@ def _compare_indexes_and_uniques( # for DROP TABLE uniques are inline, don't need them conn_uniques = set() # type:ignore[assignment] else: - conn_uniques = set( # type:ignore[assignment] + conn_uniques = { # type:ignore[assignment] _make_unique_constraint(uq_def, conn_table) for uq_def in conn_uniques - ) + } - conn_indexes = set( # type:ignore[assignment] + conn_indexes = { # type:ignore[assignment] index for index in (_make_index(ix, conn_table) for ix in conn_indexes) if index is not None - ) + } # 2a. if the dialect dupes unique indexes as unique constraints # (mysql and oracle), correct for that @@ -626,63 +621,59 @@ def _compare_indexes_and_uniques( # _constraint_sig() objects provide a consistent facade over both # Index and UniqueConstraint so we can easily work with them # interchangeably - metadata_unique_constraints_sig = set( + metadata_unique_constraints_sig = { _uq_constraint_sig(uq) for uq in metadata_unique_constraints - ) + } - metadata_indexes_sig = set( - _ix_constraint_sig(ix) for ix in metadata_indexes - ) + metadata_indexes_sig = {_ix_constraint_sig(ix) for ix in metadata_indexes} - conn_unique_constraints = set( - _uq_constraint_sig(uq) for uq in conn_uniques - ) + conn_unique_constraints = {_uq_constraint_sig(uq) for uq in conn_uniques} - conn_indexes_sig = set(_ix_constraint_sig(ix) for ix in conn_indexes) + conn_indexes_sig = {_ix_constraint_sig(ix) for ix in conn_indexes} # 5. index things by name, for those objects that have names - metadata_names = dict( - (cast(str, c.md_name_to_sql_name(autogen_context)), c) + metadata_names = { + cast(str, c.md_name_to_sql_name(autogen_context)): c for c in metadata_unique_constraints_sig.union( metadata_indexes_sig # type:ignore[arg-type] ) if isinstance(c, _ix_constraint_sig) or sqla_compat._constraint_is_named(c.const, autogen_context.dialect) - ) + } - conn_uniques_by_name = dict((c.name, c) for c in conn_unique_constraints) - conn_indexes_by_name: Dict[Optional[str], _ix_constraint_sig] = dict( - (c.name, c) for c in conn_indexes_sig - ) - conn_names = dict( - (c.name, c) + conn_uniques_by_name = {c.name: c for c in conn_unique_constraints} + conn_indexes_by_name: Dict[Optional[str], _ix_constraint_sig] = { + c.name: c for c in conn_indexes_sig + } + conn_names = { + c.name: c for c in conn_unique_constraints.union( conn_indexes_sig # type:ignore[arg-type] ) if c.name is not None - ) + } - doubled_constraints = dict( - (name, (conn_uniques_by_name[name], conn_indexes_by_name[name])) + doubled_constraints = { + name: (conn_uniques_by_name[name], conn_indexes_by_name[name]) for name in set(conn_uniques_by_name).intersection( conn_indexes_by_name ) - ) + } # 6. index things by "column signature", to help with unnamed unique # constraints. - conn_uniques_by_sig = dict((uq.sig, uq) for uq in conn_unique_constraints) - metadata_uniques_by_sig = dict( - (uq.sig, uq) for uq in metadata_unique_constraints_sig - ) - metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes_sig) - unnamed_metadata_uniques = dict( - (uq.sig, uq) + conn_uniques_by_sig = {uq.sig: uq for uq in conn_unique_constraints} + metadata_uniques_by_sig = { + uq.sig: uq for uq in metadata_unique_constraints_sig + } + metadata_indexes_by_sig = {ix.sig: ix for ix in metadata_indexes_sig} + unnamed_metadata_uniques = { + uq.sig: uq for uq in metadata_unique_constraints_sig if not sqla_compat._constraint_is_named( uq.const, autogen_context.dialect ) - ) + } # assumptions: # 1. a unique constraint or an index from the connection *always* @@ -864,37 +855,31 @@ def _correct_for_uq_duplicates_uix( for cons in metadata_unique_constraints ] - metadata_uq_names = set( + metadata_uq_names = { name for name, cons in metadata_cons_names if name is not None - ) + } - unnamed_metadata_uqs = set( - [ - _uq_constraint_sig(cons).sig - for name, cons in metadata_cons_names - if name is None - ] - ) + unnamed_metadata_uqs = { + _uq_constraint_sig(cons).sig + for name, cons in metadata_cons_names + if name is None + } - metadata_ix_names = set( - [ - sqla_compat._get_constraint_final_name(cons, dialect) - for cons in metadata_indexes - if cons.unique - ] - ) + metadata_ix_names = { + sqla_compat._get_constraint_final_name(cons, dialect) + for cons in metadata_indexes + if cons.unique + } # for reflection side, names are in their final database form # already since they're from the database - conn_ix_names = dict( - (cons.name, cons) for cons in conn_indexes if cons.unique - ) + conn_ix_names = {cons.name: cons for cons in conn_indexes if cons.unique} - uqs_dupe_indexes = dict( - (cons.name, cons) + uqs_dupe_indexes = { + cons.name: cons for cons in conn_unique_constraints if cons.info["duplicates_index"] - ) + } for overlap in uqs_dupe_indexes: if overlap not in metadata_uq_names: @@ -910,13 +895,13 @@ def _correct_for_uq_duplicates_uix( @comparators.dispatch_for("column") def _compare_nullable( - autogen_context: "AutogenContext", - alter_column_op: "AlterColumnOp", + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, schema: Optional[str], - tname: Union["quoted_name", str], - cname: Union["quoted_name", str], - conn_col: "Column", - metadata_col: "Column", + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column, + metadata_col: Column, ) -> None: metadata_col_nullable = metadata_col.nullable @@ -952,13 +937,13 @@ def _compare_nullable( @comparators.dispatch_for("column") def _setup_autoincrement( - autogen_context: "AutogenContext", - alter_column_op: "AlterColumnOp", + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, schema: Optional[str], - tname: Union["quoted_name", str], - cname: "quoted_name", - conn_col: "Column", - metadata_col: "Column", + tname: Union[quoted_name, str], + cname: quoted_name, + conn_col: Column, + metadata_col: Column, ) -> None: if metadata_col.table._autoincrement_column is metadata_col: @@ -971,13 +956,13 @@ def _setup_autoincrement( @comparators.dispatch_for("column") def _compare_type( - autogen_context: "AutogenContext", - alter_column_op: "AlterColumnOp", + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, schema: Optional[str], - tname: Union["quoted_name", str], - cname: Union["quoted_name", str], - conn_col: "Column", - metadata_col: "Column", + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column, + metadata_col: Column, ) -> None: conn_type = conn_col.type @@ -1015,8 +1000,8 @@ def _compare_type( def _render_server_default_for_compare( metadata_default: Optional[Any], - metadata_col: "Column", - autogen_context: "AutogenContext", + metadata_col: Column, + autogen_context: AutogenContext, ) -> Optional[str]: rendered = _user_defined_render( "server_default", metadata_default, autogen_context @@ -1055,13 +1040,13 @@ def _normalize_computed_default(sqltext: str) -> str: def _compare_computed_default( - autogen_context: "AutogenContext", - alter_column_op: "AlterColumnOp", + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, schema: Optional[str], - tname: "str", - cname: "str", - conn_col: "Column", - metadata_col: "Column", + tname: str, + cname: str, + conn_col: Column, + metadata_col: Column, ) -> None: rendered_metadata_default = str( cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile( @@ -1121,13 +1106,13 @@ def _compare_identity_default( @comparators.dispatch_for("column") def _compare_server_default( - autogen_context: "AutogenContext", - alter_column_op: "AlterColumnOp", + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, schema: Optional[str], - tname: Union["quoted_name", str], - cname: Union["quoted_name", str], - conn_col: "Column", - metadata_col: "Column", + tname: Union[quoted_name, str], + cname: Union[quoted_name, str], + conn_col: Column, + metadata_col: Column, ) -> Optional[bool]: metadata_default = metadata_col.server_default @@ -1210,14 +1195,14 @@ def _compare_server_default( @comparators.dispatch_for("column") def _compare_column_comment( - autogen_context: "AutogenContext", - alter_column_op: "AlterColumnOp", + autogen_context: AutogenContext, + alter_column_op: AlterColumnOp, schema: Optional[str], - tname: Union["quoted_name", str], - cname: "quoted_name", - conn_col: "Column", - metadata_col: "Column", -) -> Optional["Literal[False]"]: + tname: Union[quoted_name, str], + cname: quoted_name, + conn_col: Column, + metadata_col: Column, +) -> Optional[Literal[False]]: assert autogen_context.dialect is not None if not autogen_context.dialect.supports_comments: @@ -1239,12 +1224,12 @@ def _compare_column_comment( @comparators.dispatch_for("table") def _compare_foreign_keys( - autogen_context: "AutogenContext", - modify_table_ops: "ModifyTableOps", + autogen_context: AutogenContext, + modify_table_ops: ModifyTableOps, schema: Optional[str], - tname: Union["quoted_name", str], - conn_table: Optional["Table"], - metadata_table: Optional["Table"], + tname: Union[quoted_name, str], + conn_table: Optional[Table], + metadata_table: Optional[Table], ) -> None: # if we're doing CREATE TABLE, all FKs are created @@ -1253,11 +1238,11 @@ def _compare_foreign_keys( return inspector = autogen_context.inspector - metadata_fks = set( + metadata_fks = { fk for fk in metadata_table.constraints if isinstance(fk, sa_schema.ForeignKeyConstraint) - ) + } conn_fks_list = [ fk @@ -1273,10 +1258,10 @@ def _compare_foreign_keys( conn_fks_list and "options" in conn_fks_list[0] ) - conn_fks = set( + conn_fks = { _make_foreign_key(const, conn_table) # type: ignore[arg-type] for const in conn_fks_list - ) + } # give the dialect a chance to correct the FKs to match more # closely @@ -1284,25 +1269,23 @@ def _compare_foreign_keys( conn_fks, metadata_fks ) - metadata_fks_sig = set( + metadata_fks_sig = { _fk_constraint_sig(fk, include_options=backend_reflects_fk_options) for fk in metadata_fks - ) + } - conn_fks_sig = set( + conn_fks_sig = { _fk_constraint_sig(fk, include_options=backend_reflects_fk_options) for fk in conn_fks - ) + } - conn_fks_by_sig = dict((c.sig, c) for c in conn_fks_sig) - metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks_sig) + conn_fks_by_sig = {c.sig: c for c in conn_fks_sig} + metadata_fks_by_sig = {c.sig: c for c in metadata_fks_sig} - metadata_fks_by_name = dict( - (c.name, c) for c in metadata_fks_sig if c.name is not None - ) - conn_fks_by_name = dict( - (c.name, c) for c in conn_fks_sig if c.name is not None - ) + metadata_fks_by_name = { + c.name: c for c in metadata_fks_sig if c.name is not None + } + conn_fks_by_name = {c.name: c for c in conn_fks_sig if c.name is not None} def _add_fk(obj, compare_to): if autogen_context.run_object_filters( @@ -1361,12 +1344,12 @@ def _remove_fk(obj, compare_to): @comparators.dispatch_for("table") def _compare_table_comment( - autogen_context: "AutogenContext", - modify_table_ops: "ModifyTableOps", + autogen_context: AutogenContext, + modify_table_ops: ModifyTableOps, schema: Optional[str], - tname: Union["quoted_name", str], - conn_table: Optional["Table"], - metadata_table: Optional["Table"], + tname: Union[quoted_name, str], + conn_table: Optional[Table], + metadata_table: Optional[Table], ) -> None: assert autogen_context.dialect is not None diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 1ac6753d..41903d81 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -54,9 +54,9 @@ def _render_gen_name( - autogen_context: "AutogenContext", - name: Optional[Union["quoted_name", str]], -) -> Optional[Union["quoted_name", str, "_f_name"]]: + autogen_context: AutogenContext, + name: Optional[Union[quoted_name, str]], +) -> Optional[Union[quoted_name, str, _f_name]]: if isinstance(name, conv): return _f_name(_alembic_autogenerate_prefix(autogen_context), name) else: @@ -70,9 +70,9 @@ def _indent(text: str) -> str: def _render_python_into_templatevars( - autogen_context: "AutogenContext", - migration_script: "MigrationScript", - template_args: Dict[str, Union[str, "Config"]], + autogen_context: AutogenContext, + migration_script: MigrationScript, + template_args: Dict[str, Union[str, Config]], ) -> None: imports = autogen_context.imports @@ -92,8 +92,8 @@ def _render_python_into_templatevars( def _render_cmd_body( - op_container: "ops.OpContainer", - autogen_context: "AutogenContext", + op_container: ops.OpContainer, + autogen_context: AutogenContext, ) -> str: buf = StringIO() @@ -120,7 +120,7 @@ def _render_cmd_body( def render_op( - autogen_context: "AutogenContext", op: "ops.MigrateOperation" + autogen_context: AutogenContext, op: ops.MigrateOperation ) -> List[str]: renderer = renderers.dispatch(op) lines = util.to_list(renderer(autogen_context, op)) @@ -128,14 +128,14 @@ def render_op( def render_op_text( - autogen_context: "AutogenContext", op: "ops.MigrateOperation" + autogen_context: AutogenContext, op: ops.MigrateOperation ) -> str: return "\n".join(render_op(autogen_context, op)) @renderers.dispatch_for(ops.ModifyTableOps) def _render_modify_table( - autogen_context: "AutogenContext", op: "ModifyTableOps" + autogen_context: AutogenContext, op: ModifyTableOps ) -> List[str]: opts = autogen_context.opts render_as_batch = opts.get("render_as_batch", False) @@ -164,7 +164,7 @@ def _render_modify_table( @renderers.dispatch_for(ops.CreateTableCommentOp) def _render_create_table_comment( - autogen_context: "AutogenContext", op: "ops.CreateTableCommentOp" + autogen_context: AutogenContext, op: ops.CreateTableCommentOp ) -> str: templ = ( @@ -189,7 +189,7 @@ def _render_create_table_comment( @renderers.dispatch_for(ops.DropTableCommentOp) def _render_drop_table_comment( - autogen_context: "AutogenContext", op: "ops.DropTableCommentOp" + autogen_context: AutogenContext, op: ops.DropTableCommentOp ) -> str: templ = ( @@ -211,9 +211,7 @@ def _render_drop_table_comment( @renderers.dispatch_for(ops.CreateTableOp) -def _add_table( - autogen_context: "AutogenContext", op: "ops.CreateTableOp" -) -> str: +def _add_table(autogen_context: AutogenContext, op: ops.CreateTableOp) -> str: table = op.to_table() args = [ @@ -263,9 +261,7 @@ def _add_table( @renderers.dispatch_for(ops.DropTableOp) -def _drop_table( - autogen_context: "AutogenContext", op: "ops.DropTableOp" -) -> str: +def _drop_table(autogen_context: AutogenContext, op: ops.DropTableOp) -> str: text = "%(prefix)sdrop_table(%(tname)r" % { "prefix": _alembic_autogenerate_prefix(autogen_context), "tname": _ident(op.table_name), @@ -277,9 +273,7 @@ def _drop_table( @renderers.dispatch_for(ops.CreateIndexOp) -def _add_index( - autogen_context: "AutogenContext", op: "ops.CreateIndexOp" -) -> str: +def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str: index = op.to_index() has_batch = autogen_context._has_batch @@ -324,9 +318,7 @@ def _add_index( @renderers.dispatch_for(ops.DropIndexOp) -def _drop_index( - autogen_context: "AutogenContext", op: "ops.DropIndexOp" -) -> str: +def _drop_index(autogen_context: AutogenContext, op: ops.DropIndexOp) -> str: index = op.to_index() has_batch = autogen_context._has_batch @@ -362,14 +354,14 @@ def _drop_index( @renderers.dispatch_for(ops.CreateUniqueConstraintOp) def _add_unique_constraint( - autogen_context: "AutogenContext", op: "ops.CreateUniqueConstraintOp" + autogen_context: AutogenContext, op: ops.CreateUniqueConstraintOp ) -> List[str]: return [_uq_constraint(op.to_constraint(), autogen_context, True)] @renderers.dispatch_for(ops.CreateForeignKeyOp) def _add_fk_constraint( - autogen_context: "AutogenContext", op: "ops.CreateForeignKeyOp" + autogen_context: AutogenContext, op: ops.CreateForeignKeyOp ) -> str: args = [repr(_render_gen_name(autogen_context, op.constraint_name))] @@ -418,7 +410,7 @@ def _add_check_constraint(constraint, autogen_context): @renderers.dispatch_for(ops.DropConstraintOp) def _drop_constraint( - autogen_context: "AutogenContext", op: "ops.DropConstraintOp" + autogen_context: AutogenContext, op: ops.DropConstraintOp ) -> str: if autogen_context._has_batch: @@ -440,9 +432,7 @@ def _drop_constraint( @renderers.dispatch_for(ops.AddColumnOp) -def _add_column( - autogen_context: "AutogenContext", op: "ops.AddColumnOp" -) -> str: +def _add_column(autogen_context: AutogenContext, op: ops.AddColumnOp) -> str: schema, tname, column = op.schema, op.table_name, op.column if autogen_context._has_batch: @@ -462,9 +452,7 @@ def _add_column( @renderers.dispatch_for(ops.DropColumnOp) -def _drop_column( - autogen_context: "AutogenContext", op: "ops.DropColumnOp" -) -> str: +def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str: schema, tname, column_name = op.schema, op.table_name, op.column_name @@ -487,7 +475,7 @@ def _drop_column( @renderers.dispatch_for(ops.AlterColumnOp) def _alter_column( - autogen_context: "AutogenContext", op: "ops.AlterColumnOp" + autogen_context: AutogenContext, op: ops.AlterColumnOp ) -> str: tname = op.table_name @@ -556,7 +544,7 @@ def __repr__(self) -> str: return "%sf(%r)" % (self.prefix, _ident(self.name)) -def _ident(name: Optional[Union["quoted_name", str]]) -> Optional[str]: +def _ident(name: Optional[Union[quoted_name, str]]) -> Optional[str]: """produce a __repr__() object for a string identifier that may use quoted_name() in SQLAlchemy 0.9 and greater. @@ -574,7 +562,7 @@ def _ident(name: Optional[Union["quoted_name", str]]) -> Optional[str]: def _render_potential_expr( value: Any, - autogen_context: "AutogenContext", + autogen_context: AutogenContext, wrap_in_text: bool = True, is_server_default: bool = False, ) -> str: @@ -597,7 +585,7 @@ def _render_potential_expr( def _get_index_rendered_expressions( - idx: "Index", autogen_context: "AutogenContext" + idx: Index, autogen_context: AutogenContext ) -> List[str]: return [ repr(_ident(getattr(exp, "name", None))) @@ -608,8 +596,8 @@ def _get_index_rendered_expressions( def _uq_constraint( - constraint: "UniqueConstraint", - autogen_context: "AutogenContext", + constraint: UniqueConstraint, + autogen_context: AutogenContext, alter: bool, ) -> str: opts: List[Tuple[str, Any]] = [] @@ -654,11 +642,11 @@ def _user_autogenerate_prefix(autogen_context, target): return prefix -def _sqlalchemy_autogenerate_prefix(autogen_context: "AutogenContext") -> str: +def _sqlalchemy_autogenerate_prefix(autogen_context: AutogenContext) -> str: return autogen_context.opts["sqlalchemy_module_prefix"] or "" -def _alembic_autogenerate_prefix(autogen_context: "AutogenContext") -> str: +def _alembic_autogenerate_prefix(autogen_context: AutogenContext) -> str: if autogen_context._has_batch: return "batch_op." else: @@ -666,8 +654,8 @@ def _alembic_autogenerate_prefix(autogen_context: "AutogenContext") -> str: def _user_defined_render( - type_: str, object_: Any, autogen_context: "AutogenContext" -) -> Union[str, "Literal[False]"]: + type_: str, object_: Any, autogen_context: AutogenContext +) -> Union[str, Literal[False]]: if "render_item" in autogen_context.opts: render = autogen_context.opts["render_item"] if render: @@ -677,7 +665,7 @@ def _user_defined_render( return False -def _render_column(column: "Column", autogen_context: "AutogenContext") -> str: +def _render_column(column: Column, autogen_context: AutogenContext) -> str: rendered = _user_defined_render("column", column, autogen_context) if rendered is not False: return rendered @@ -734,7 +722,7 @@ def _render_column(column: "Column", autogen_context: "AutogenContext") -> str: def _should_render_server_default_positionally( - server_default: Union["Computed", "DefaultClause"] + server_default: Union[Computed, DefaultClause] ) -> bool: return sqla_compat._server_default_is_computed( server_default @@ -742,10 +730,8 @@ def _should_render_server_default_positionally( def _render_server_default( - default: Optional[ - Union["FetchedValue", str, "TextClause", "ColumnElement"] - ], - autogen_context: "AutogenContext", + default: Optional[Union[FetchedValue, str, TextClause, ColumnElement]], + autogen_context: AutogenContext, repr_: bool = True, ) -> Optional[str]: rendered = _user_defined_render("server_default", default, autogen_context) @@ -771,7 +757,7 @@ def _render_server_default( def _render_computed( - computed: "Computed", autogen_context: "AutogenContext" + computed: Computed, autogen_context: AutogenContext ) -> str: text = _render_potential_expr( computed.sqltext, autogen_context, wrap_in_text=False @@ -788,7 +774,7 @@ def _render_computed( def _render_identity( - identity: "Identity", autogen_context: "AutogenContext" + identity: Identity, autogen_context: AutogenContext ) -> str: # always=None means something different than always=False kwargs = OrderedDict(always=identity.always) @@ -802,7 +788,7 @@ def _render_identity( } -def _get_identity_options(identity_options: "Identity") -> OrderedDict: +def _get_identity_options(identity_options: Identity) -> OrderedDict: kwargs = OrderedDict() for attr in sqla_compat._identity_options_attrs: value = getattr(identity_options, attr, None) @@ -812,8 +798,8 @@ def _get_identity_options(identity_options: "Identity") -> OrderedDict: def _repr_type( - type_: "TypeEngine", - autogen_context: "AutogenContext", + type_: TypeEngine, + autogen_context: AutogenContext, _skip_variants: bool = False, ) -> str: rendered = _user_defined_render("type", type_, autogen_context) @@ -855,9 +841,7 @@ def _repr_type( return "%s%r" % (prefix, type_) -def _render_ARRAY_type( - type_: "ARRAY", autogen_context: "AutogenContext" -) -> str: +def _render_ARRAY_type(type_: ARRAY, autogen_context: AutogenContext) -> str: return cast( str, _render_type_w_subtype( @@ -867,7 +851,7 @@ def _render_ARRAY_type( def _render_Variant_type( - type_: "TypeEngine", autogen_context: "AutogenContext" + type_: TypeEngine, autogen_context: AutogenContext ) -> str: base_type, variant_mapping = sqla_compat._get_variant_mapping(type_) base = _repr_type(base_type, autogen_context, _skip_variants=True) @@ -882,12 +866,12 @@ def _render_Variant_type( def _render_type_w_subtype( - type_: "TypeEngine", - autogen_context: "AutogenContext", + type_: TypeEngine, + autogen_context: AutogenContext, attrname: str, regexp: str, prefix: Optional[str] = None, -) -> Union[Optional[str], "Literal[False]"]: +) -> Union[Optional[str], Literal[False]]: outer_repr = repr(type_) inner_type = getattr(type_, attrname, None) if inner_type is None: @@ -919,9 +903,9 @@ def _render_type_w_subtype( def _render_constraint( - constraint: "Constraint", - autogen_context: "AutogenContext", - namespace_metadata: Optional["MetaData"], + constraint: Constraint, + autogen_context: AutogenContext, + namespace_metadata: Optional[MetaData], ) -> Optional[str]: try: renderer = _constraint_renderers.dispatch(constraint) @@ -934,9 +918,9 @@ def _render_constraint( @_constraint_renderers.dispatch_for(sa_schema.PrimaryKeyConstraint) def _render_primary_key( - constraint: "PrimaryKeyConstraint", - autogen_context: "AutogenContext", - namespace_metadata: Optional["MetaData"], + constraint: PrimaryKeyConstraint, + autogen_context: AutogenContext, + namespace_metadata: Optional[MetaData], ) -> Optional[str]: rendered = _user_defined_render("primary_key", constraint, autogen_context) if rendered is not False: @@ -960,9 +944,9 @@ def _render_primary_key( def _fk_colspec( - fk: "ForeignKey", + fk: ForeignKey, metadata_schema: Optional[str], - namespace_metadata: "MetaData", + namespace_metadata: MetaData, ) -> str: """Implement a 'safe' version of ForeignKey._get_colspec() that won't fail if the remote table can't be resolved. @@ -997,7 +981,7 @@ def _fk_colspec( def _populate_render_fk_opts( - constraint: "ForeignKeyConstraint", opts: List[Tuple[str, str]] + constraint: ForeignKeyConstraint, opts: List[Tuple[str, str]] ) -> None: if constraint.onupdate: @@ -1014,9 +998,9 @@ def _populate_render_fk_opts( @_constraint_renderers.dispatch_for(sa_schema.ForeignKeyConstraint) def _render_foreign_key( - constraint: "ForeignKeyConstraint", - autogen_context: "AutogenContext", - namespace_metadata: "MetaData", + constraint: ForeignKeyConstraint, + autogen_context: AutogenContext, + namespace_metadata: MetaData, ) -> Optional[str]: rendered = _user_defined_render("foreign_key", constraint, autogen_context) if rendered is not False: @@ -1053,9 +1037,9 @@ def _render_foreign_key( @_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint) def _render_unique_constraint( - constraint: "UniqueConstraint", - autogen_context: "AutogenContext", - namespace_metadata: Optional["MetaData"], + constraint: UniqueConstraint, + autogen_context: AutogenContext, + namespace_metadata: Optional[MetaData], ) -> str: rendered = _user_defined_render("unique", constraint, autogen_context) if rendered is not False: @@ -1066,9 +1050,9 @@ def _render_unique_constraint( @_constraint_renderers.dispatch_for(sa_schema.CheckConstraint) def _render_check_constraint( - constraint: "CheckConstraint", - autogen_context: "AutogenContext", - namespace_metadata: Optional["MetaData"], + constraint: CheckConstraint, + autogen_context: AutogenContext, + namespace_metadata: Optional[MetaData], ) -> Optional[str]: rendered = _user_defined_render("check", constraint, autogen_context) if rendered is not False: @@ -1106,9 +1090,7 @@ def _render_check_constraint( @renderers.dispatch_for(ops.ExecuteSQLOp) -def _execute_sql( - autogen_context: "AutogenContext", op: "ops.ExecuteSQLOp" -) -> str: +def _execute_sql(autogen_context: AutogenContext, op: ops.ExecuteSQLOp) -> str: if not isinstance(op.sqltext, str): raise NotImplementedError( "Autogenerate rendering of SQL Expression language constructs " diff --git a/alembic/autogenerate/rewriter.py b/alembic/autogenerate/rewriter.py index 79f665a0..1a29b963 100644 --- a/alembic/autogenerate/rewriter.py +++ b/alembic/autogenerate/rewriter.py @@ -95,11 +95,11 @@ def add_column_idx(context, revision, op): def rewrites( self, operator: Union[ - Type["AddColumnOp"], - Type["MigrateOperation"], - Type["AlterColumnOp"], - Type["CreateTableOp"], - Type["ModifyTableOps"], + Type[AddColumnOp], + Type[MigrateOperation], + Type[AlterColumnOp], + Type[CreateTableOp], + Type[ModifyTableOps], ], ) -> Callable: """Register a function as rewriter for a given type. @@ -118,10 +118,10 @@ def add_column_nullable(context, revision, op): def _rewrite( self, - context: "MigrationContext", - revision: "Revision", - directive: "MigrateOperation", - ) -> Iterator["MigrateOperation"]: + context: MigrationContext, + revision: Revision, + directive: MigrateOperation, + ) -> Iterator[MigrateOperation]: try: _rewriter = self.dispatch.dispatch(directive) except ValueError: @@ -141,9 +141,9 @@ def _rewrite( def __call__( self, - context: "MigrationContext", - revision: "Revision", - directives: List["MigrationScript"], + context: MigrationContext, + revision: Revision, + directives: List[MigrationScript], ) -> None: self.process_revision_directives(context, revision, directives) if self._chained: @@ -152,9 +152,9 @@ def __call__( @_traverse.dispatch_for(ops.MigrationScript) def _traverse_script( self, - context: "MigrationContext", - revision: "Revision", - directive: "MigrationScript", + context: MigrationContext, + revision: Revision, + directive: MigrationScript, ) -> None: upgrade_ops_list = [] for upgrade_ops in directive.upgrade_ops_list: @@ -179,26 +179,26 @@ def _traverse_script( @_traverse.dispatch_for(ops.OpContainer) def _traverse_op_container( self, - context: "MigrationContext", - revision: "Revision", - directive: "OpContainer", + context: MigrationContext, + revision: Revision, + directive: OpContainer, ) -> None: self._traverse_list(context, revision, directive.ops) @_traverse.dispatch_for(ops.MigrateOperation) def _traverse_any_directive( self, - context: "MigrationContext", - revision: "Revision", - directive: "MigrateOperation", + context: MigrationContext, + revision: Revision, + directive: MigrateOperation, ) -> None: pass def _traverse_for( self, - context: "MigrationContext", - revision: "Revision", - directive: "MigrateOperation", + context: MigrationContext, + revision: Revision, + directive: MigrateOperation, ) -> Any: directives = list(self._rewrite(context, revision, directive)) for directive in directives: @@ -208,8 +208,8 @@ def _traverse_for( def _traverse_list( self, - context: "MigrationContext", - revision: "Revision", + context: MigrationContext, + revision: Revision, directives: Any, ) -> None: dest = [] @@ -220,8 +220,8 @@ def _traverse_list( def process_revision_directives( self, - context: "MigrationContext", - revision: "Revision", - directives: List["MigrationScript"], + context: MigrationContext, + revision: Revision, + directives: List[MigrationScript], ) -> None: self._traverse_list(context, revision, directives) diff --git a/alembic/command.py b/alembic/command.py index 162b3d0c..5c33a95e 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -37,7 +37,7 @@ def list_templates(config): def init( - config: "Config", + config: Config, directory: str, template: str = "generic", package: bool = False, @@ -114,7 +114,7 @@ def init( def revision( - config: "Config", + config: Config, message: Optional[str] = None, autogenerate: bool = False, sql: bool = False, @@ -125,7 +125,7 @@ def revision( rev_id: Optional[str] = None, depends_on: Optional[str] = None, process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None, -) -> Union[Optional["Script"], List[Optional["Script"]]]: +) -> Union[Optional[Script], List[Optional[Script]]]: """Create a new revision file. :param config: a :class:`.Config` object. @@ -241,12 +241,12 @@ def retrieve_migrations(rev, context): def merge( - config: "Config", + config: Config, revisions: str, message: Optional[str] = None, branch_label: Optional[str] = None, rev_id: Optional[str] = None, -) -> Optional["Script"]: +) -> Optional[Script]: """Merge two revisions together. Creates a new migration file. :param config: a :class:`.Config` instance @@ -280,7 +280,7 @@ def merge( def upgrade( - config: "Config", + config: Config, revision: str, sql: bool = False, tag: Optional[str] = None, @@ -323,7 +323,7 @@ def upgrade(rev, context): def downgrade( - config: "Config", + config: Config, revision: str, sql: bool = False, tag: Optional[str] = None, @@ -394,7 +394,7 @@ def show_current(rev, context): def history( - config: "Config", + config: Config, rev_range: Optional[str] = None, verbose: bool = False, indicate_current: bool = False, @@ -517,7 +517,7 @@ def branches(config, verbose=False): ) -def current(config: "Config", verbose: bool = False) -> None: +def current(config: Config, verbose: bool = False) -> None: """Display the current revision for a database. :param config: a :class:`.Config` instance. @@ -546,7 +546,7 @@ def display_version(rev, context): def stamp( - config: "Config", + config: Config, revision: str, sql: bool = False, tag: Optional[str] = None, @@ -615,7 +615,7 @@ def do_stamp(rev, context): script.run_env() -def edit(config: "Config", rev: str) -> None: +def edit(config: Config, rev: str) -> None: """Edit revision script(s) using $EDITOR. :param config: a :class:`.Config` instance. @@ -648,7 +648,7 @@ def edit_current(rev, context): util.open_in_editor(sc.path) -def ensure_version(config: "Config", sql: bool = False) -> None: +def ensure_version(config: Config, sql: bool = False) -> None: """Create the alembic version table if it doesn't exist already . :param config: a :class:`.Config` instance. diff --git a/alembic/config.py b/alembic/config.py index 8464407d..ac27d585 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -561,7 +561,7 @@ def run_cmd(self, config: Config, options: Namespace) -> None: fn( config, *[getattr(options, k, None) for k in positional], - **dict((k, getattr(options, k, None)) for k in kwarg), + **{k: getattr(options, k, None) for k in kwarg}, ) except util.CommandError as e: if options.raiseerr: diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index c9107867..c3bdaf38 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -46,7 +46,7 @@ class AlterTable(DDLElement): def __init__( self, table_name: str, - schema: Optional[Union["quoted_name", str]] = None, + schema: Optional[Union[quoted_name, str]] = None, ) -> None: self.table_name = table_name self.schema = schema @@ -56,10 +56,10 @@ class RenameTable(AlterTable): def __init__( self, old_table_name: str, - new_table_name: Union["quoted_name", str], - schema: Optional[Union["quoted_name", str]] = None, + new_table_name: Union[quoted_name, str], + schema: Optional[Union[quoted_name, str]] = None, ) -> None: - super(RenameTable, self).__init__(old_table_name, schema=schema) + super().__init__(old_table_name, schema=schema) self.new_table_name = new_table_name @@ -69,12 +69,12 @@ def __init__( name: str, column_name: str, schema: Optional[str] = None, - existing_type: Optional["TypeEngine"] = None, + existing_type: Optional[TypeEngine] = None, existing_nullable: Optional[bool] = None, existing_server_default: Optional[_ServerDefault] = None, existing_comment: Optional[str] = None, ) -> None: - super(AlterColumn, self).__init__(name, schema=schema) + super().__init__(name, schema=schema) self.column_name = column_name self.existing_type = ( sqltypes.to_instance(existing_type) @@ -90,15 +90,15 @@ class ColumnNullable(AlterColumn): def __init__( self, name: str, column_name: str, nullable: bool, **kw ) -> None: - super(ColumnNullable, self).__init__(name, column_name, **kw) + super().__init__(name, column_name, **kw) self.nullable = nullable class ColumnType(AlterColumn): def __init__( - self, name: str, column_name: str, type_: "TypeEngine", **kw + self, name: str, column_name: str, type_: TypeEngine, **kw ) -> None: - super(ColumnType, self).__init__(name, column_name, **kw) + super().__init__(name, column_name, **kw) self.type_ = sqltypes.to_instance(type_) @@ -106,7 +106,7 @@ class ColumnName(AlterColumn): def __init__( self, name: str, column_name: str, newname: str, **kw ) -> None: - super(ColumnName, self).__init__(name, column_name, **kw) + super().__init__(name, column_name, **kw) self.newname = newname @@ -118,15 +118,15 @@ def __init__( default: Optional[_ServerDefault], **kw, ) -> None: - super(ColumnDefault, self).__init__(name, column_name, **kw) + super().__init__(name, column_name, **kw) self.default = default class ComputedColumnDefault(AlterColumn): def __init__( - self, name: str, column_name: str, default: Optional["Computed"], **kw + self, name: str, column_name: str, default: Optional[Computed], **kw ) -> None: - super(ComputedColumnDefault, self).__init__(name, column_name, **kw) + super().__init__(name, column_name, **kw) self.default = default @@ -135,11 +135,11 @@ def __init__( self, name: str, column_name: str, - default: Optional["Identity"], - impl: "DefaultImpl", + default: Optional[Identity], + impl: DefaultImpl, **kw, ) -> None: - super(IdentityColumnDefault, self).__init__(name, column_name, **kw) + super().__init__(name, column_name, **kw) self.default = default self.impl = impl @@ -148,18 +148,18 @@ class AddColumn(AlterTable): def __init__( self, name: str, - column: "Column", - schema: Optional[Union["quoted_name", str]] = None, + column: Column, + schema: Optional[Union[quoted_name, str]] = None, ) -> None: - super(AddColumn, self).__init__(name, schema=schema) + super().__init__(name, schema=schema) self.column = column class DropColumn(AlterTable): def __init__( - self, name: str, column: "Column", schema: Optional[str] = None + self, name: str, column: Column, schema: Optional[str] = None ) -> None: - super(DropColumn, self).__init__(name, schema=schema) + super().__init__(name, schema=schema) self.column = column @@ -167,13 +167,13 @@ class ColumnComment(AlterColumn): def __init__( self, name: str, column_name: str, comment: Optional[str], **kw ) -> None: - super(ColumnComment, self).__init__(name, column_name, **kw) + super().__init__(name, column_name, **kw) self.comment = comment @compiles(RenameTable) def visit_rename_table( - element: "RenameTable", compiler: "DDLCompiler", **kw + element: RenameTable, compiler: DDLCompiler, **kw ) -> str: return "%s RENAME TO %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -182,9 +182,7 @@ def visit_rename_table( @compiles(AddColumn) -def visit_add_column( - element: "AddColumn", compiler: "DDLCompiler", **kw -) -> str: +def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str: return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), add_column(compiler, element.column, **kw), @@ -192,9 +190,7 @@ def visit_add_column( @compiles(DropColumn) -def visit_drop_column( - element: "DropColumn", compiler: "DDLCompiler", **kw -) -> str: +def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str: return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), drop_column(compiler, element.column.name, **kw), @@ -203,7 +199,7 @@ def visit_drop_column( @compiles(ColumnNullable) def visit_column_nullable( - element: "ColumnNullable", compiler: "DDLCompiler", **kw + element: ColumnNullable, compiler: DDLCompiler, **kw ) -> str: return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -213,9 +209,7 @@ def visit_column_nullable( @compiles(ColumnType) -def visit_column_type( - element: "ColumnType", compiler: "DDLCompiler", **kw -) -> str: +def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str: return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), @@ -224,9 +218,7 @@ def visit_column_type( @compiles(ColumnName) -def visit_column_name( - element: "ColumnName", compiler: "DDLCompiler", **kw -) -> str: +def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str: return "%s RENAME %s TO %s" % ( alter_table(compiler, element.table_name, element.schema), format_column_name(compiler, element.column_name), @@ -236,7 +228,7 @@ def visit_column_name( @compiles(ColumnDefault) def visit_column_default( - element: "ColumnDefault", compiler: "DDLCompiler", **kw + element: ColumnDefault, compiler: DDLCompiler, **kw ) -> str: return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -249,7 +241,7 @@ def visit_column_default( @compiles(ComputedColumnDefault) def visit_computed_column( - element: "ComputedColumnDefault", compiler: "DDLCompiler", **kw + element: ComputedColumnDefault, compiler: DDLCompiler, **kw ): raise exc.CompileError( 'Adding or removing a "computed" construct, e.g. GENERATED ' @@ -259,7 +251,7 @@ def visit_computed_column( @compiles(IdentityColumnDefault) def visit_identity_column( - element: "IdentityColumnDefault", compiler: "DDLCompiler", **kw + element: IdentityColumnDefault, compiler: DDLCompiler, **kw ): raise exc.CompileError( 'Adding, removing or modifying an "identity" construct, ' @@ -269,8 +261,8 @@ def visit_identity_column( def quote_dotted( - name: Union["quoted_name", str], quote: functools.partial -) -> Union["quoted_name", str]: + name: Union[quoted_name, str], quote: functools.partial +) -> Union[quoted_name, str]: """quote the elements of a dotted name""" if isinstance(name, quoted_name): @@ -280,10 +272,10 @@ def quote_dotted( def format_table_name( - compiler: "Compiled", - name: Union["quoted_name", str], - schema: Optional[Union["quoted_name", str]], -) -> Union["quoted_name", str]: + compiler: Compiled, + name: Union[quoted_name, str], + schema: Optional[Union[quoted_name, str]], +) -> Union[quoted_name, str]: quote = functools.partial(compiler.preparer.quote) if schema: return quote_dotted(schema, quote) + "." + quote(name) @@ -292,13 +284,13 @@ def format_table_name( def format_column_name( - compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]] -) -> Union["quoted_name", str]: + compiler: DDLCompiler, name: Optional[Union[quoted_name, str]] +) -> Union[quoted_name, str]: return compiler.preparer.quote(name) # type: ignore[arg-type] def format_server_default( - compiler: "DDLCompiler", + compiler: DDLCompiler, default: Optional[_ServerDefault], ) -> str: return compiler.get_column_default_string( @@ -306,27 +298,27 @@ def format_server_default( ) -def format_type(compiler: "DDLCompiler", type_: "TypeEngine") -> str: +def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str: return compiler.dialect.type_compiler.process(type_) def alter_table( - compiler: "DDLCompiler", + compiler: DDLCompiler, name: str, schema: Optional[str], ) -> str: return "ALTER TABLE %s" % format_table_name(compiler, name, schema) -def drop_column(compiler: "DDLCompiler", name: str, **kw) -> str: +def drop_column(compiler: DDLCompiler, name: str, **kw) -> str: return "DROP COLUMN %s" % format_column_name(compiler, name) -def alter_column(compiler: "DDLCompiler", name: str) -> str: +def alter_column(compiler: DDLCompiler, name: str) -> str: return "ALTER COLUMN %s" % format_column_name(compiler, name) -def add_column(compiler: "DDLCompiler", column: "Column", **kw) -> str: +def add_column(compiler: DDLCompiler, column: Column, **kw) -> str: text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw) const = " ".join( diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 79d5245e..728d1dae 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -52,7 +52,7 @@ class ImplMeta(type): def __init__( cls, classname: str, - bases: Tuple[Type["DefaultImpl"]], + bases: Tuple[Type[DefaultImpl]], dict_: Dict[str, Any], ): newtype = type.__init__(cls, classname, bases, dict_) @@ -61,7 +61,7 @@ def __init__( return newtype -_impls: Dict[str, Type["DefaultImpl"]] = {} +_impls: Dict[str, Type[DefaultImpl]] = {} Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"]) @@ -91,11 +91,11 @@ class DefaultImpl(metaclass=ImplMeta): def __init__( self, - dialect: "Dialect", - connection: Optional["Connection"], + dialect: Dialect, + connection: Optional[Connection], as_sql: bool, transactional_ddl: Optional[bool], - output_buffer: Optional["TextIO"], + output_buffer: Optional[TextIO], context_opts: Dict[str, Any], ) -> None: self.dialect = dialect @@ -116,7 +116,7 @@ def __init__( ) @classmethod - def get_by_dialect(cls, dialect: "Dialect") -> Type["DefaultImpl"]: + def get_by_dialect(cls, dialect: Dialect) -> Type[DefaultImpl]: return _impls[dialect.name] def static_output(self, text: str) -> None: @@ -125,7 +125,7 @@ def static_output(self, text: str) -> None: self.output_buffer.flush() def requires_recreate_in_batch( - self, batch_op: "BatchOperationsImpl" + self, batch_op: BatchOperationsImpl ) -> bool: """Return True if the given :class:`.BatchOperationsImpl` would need the table to be recreated and copied in order to @@ -138,7 +138,7 @@ def requires_recreate_in_batch( return False def prep_table_for_batch( - self, batch_impl: "ApplyBatchImpl", table: "Table" + self, batch_impl: ApplyBatchImpl, table: Table ) -> None: """perform any operations needed on a table before a new one is created to replace it in batch mode. @@ -149,16 +149,16 @@ def prep_table_for_batch( """ @property - def bind(self) -> Optional["Connection"]: + def bind(self) -> Optional[Connection]: return self.connection def _exec( self, - construct: Union["ClauseElement", str], + construct: Union[ClauseElement, str], execution_options: Optional[dict] = None, multiparams: Sequence[dict] = (), params: Dict[str, int] = util.immutabledict(), - ) -> Optional["CursorResult"]: + ) -> Optional[CursorResult]: if isinstance(construct, str): construct = text(construct) if self.as_sql: @@ -196,7 +196,7 @@ def _exec( def execute( self, - sql: Union["ClauseElement", str], + sql: Union[ClauseElement, str], execution_options: None = None, ) -> None: self._exec(sql, execution_options) @@ -206,15 +206,15 @@ def alter_column( table_name: str, column_name: str, nullable: Optional[bool] = None, - server_default: Union["_ServerDefault", "Literal[False]"] = False, + server_default: Union[_ServerDefault, Literal[False]] = False, name: Optional[str] = None, - type_: Optional["TypeEngine"] = None, + type_: Optional[TypeEngine] = None, schema: Optional[str] = None, autoincrement: Optional[bool] = None, - comment: Optional[Union[str, "Literal[False]"]] = False, + comment: Optional[Union[str, Literal[False]]] = False, existing_comment: Optional[str] = None, - existing_type: Optional["TypeEngine"] = None, - existing_server_default: Optional["_ServerDefault"] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, existing_nullable: Optional[bool] = None, existing_autoincrement: Optional[bool] = None, **kw: Any, @@ -316,15 +316,15 @@ def alter_column( def add_column( self, table_name: str, - column: "Column", - schema: Optional[Union[str, "quoted_name"]] = None, + column: Column, + schema: Optional[Union[str, quoted_name]] = None, ) -> None: self._exec(base.AddColumn(table_name, column, schema=schema)) def drop_column( self, table_name: str, - column: "Column", + column: Column, schema: Optional[str] = None, **kw, ) -> None: @@ -334,20 +334,20 @@ def add_constraint(self, const: Any) -> None: if const._create_rule is None or const._create_rule(self): self._exec(schema.AddConstraint(const)) - def drop_constraint(self, const: "Constraint") -> None: + def drop_constraint(self, const: Constraint) -> None: self._exec(schema.DropConstraint(const)) def rename_table( self, old_table_name: str, - new_table_name: Union[str, "quoted_name"], - schema: Optional[Union[str, "quoted_name"]] = None, + new_table_name: Union[str, quoted_name], + schema: Optional[Union[str, quoted_name]] = None, ) -> None: self._exec( base.RenameTable(old_table_name, new_table_name, schema=schema) ) - def create_table(self, table: "Table") -> None: + def create_table(self, table: Table) -> None: table.dispatch.before_create( table, self.connection, checkfirst=False, _ddl_runner=self ) @@ -370,7 +370,7 @@ def create_table(self, table: "Table") -> None: if comment and with_comment: self.create_column_comment(column) - def drop_table(self, table: "Table") -> None: + def drop_table(self, table: Table) -> None: table.dispatch.before_drop( table, self.connection, checkfirst=False, _ddl_runner=self ) @@ -379,24 +379,24 @@ def drop_table(self, table: "Table") -> None: table, self.connection, checkfirst=False, _ddl_runner=self ) - def create_index(self, index: "Index") -> None: + def create_index(self, index: Index) -> None: self._exec(schema.CreateIndex(index)) - def create_table_comment(self, table: "Table") -> None: + def create_table_comment(self, table: Table) -> None: self._exec(schema.SetTableComment(table)) - def drop_table_comment(self, table: "Table") -> None: + def drop_table_comment(self, table: Table) -> None: self._exec(schema.DropTableComment(table)) - def create_column_comment(self, column: "ColumnElement") -> None: + def create_column_comment(self, column: ColumnElement) -> None: self._exec(schema.SetColumnComment(column)) - def drop_index(self, index: "Index") -> None: + def drop_index(self, index: Index) -> None: self._exec(schema.DropIndex(index)) def bulk_insert( self, - table: Union["TableClause", "Table"], + table: Union[TableClause, Table], rows: List[dict], multiinsert: bool = True, ) -> None: @@ -408,19 +408,16 @@ def bulk_insert( for row in rows: self._exec( sqla_compat._insert_inline(table).values( - **dict( - ( - k, - sqla_compat._literal_bindparam( - k, v, type_=table.c[k].type - ) - if not isinstance( - v, sqla_compat._literal_bindparam - ) - else v, + **{ + k: sqla_compat._literal_bindparam( + k, v, type_=table.c[k].type ) + if not isinstance( + v, sqla_compat._literal_bindparam + ) + else v for k, v in row.items() - ) + } ) ) else: @@ -435,7 +432,7 @@ def bulk_insert( sqla_compat._insert_inline(table).values(**row) ) - def _tokenize_column_type(self, column: "Column") -> Params: + def _tokenize_column_type(self, column: Column) -> Params: definition = self.dialect.type_compiler.process(column.type).lower() # tokenize the SQLAlchemy-generated version of a type, so that @@ -474,7 +471,7 @@ def _tokenize_column_type(self, column: "Column") -> Params: return params def _column_types_match( - self, inspector_params: "Params", metadata_params: "Params" + self, inspector_params: Params, metadata_params: Params ) -> bool: if inspector_params.token0 == metadata_params.token0: return True @@ -496,7 +493,7 @@ def _column_types_match( return False def _column_args_match( - self, inspected_params: "Params", meta_params: "Params" + self, inspected_params: Params, meta_params: Params ) -> bool: """We want to compare column parameters. However, we only want to compare parameters that are set. If they both have `collation`, @@ -529,7 +526,7 @@ def _column_args_match( return True def compare_type( - self, inspector_column: "Column", metadata_column: "Column" + self, inspector_column: Column, metadata_column: Column ) -> bool: """Returns True if there ARE differences between the types of the two columns. Takes impl.type_synonyms into account between retrospected @@ -555,10 +552,10 @@ def compare_server_default( def correct_for_autogen_constraints( self, - conn_uniques: Set["UniqueConstraint"], - conn_indexes: Set["Index"], - metadata_unique_constraints: Set["UniqueConstraint"], - metadata_indexes: Set["Index"], + conn_uniques: Set[UniqueConstraint], + conn_indexes: Set[Index], + metadata_unique_constraints: Set[UniqueConstraint], + metadata_indexes: Set[Index], ) -> None: pass @@ -569,7 +566,7 @@ def cast_for_batch_migrate(self, existing, existing_transfer, new_type): ) def render_ddl_sql_expr( - self, expr: "ClauseElement", is_server_default: bool = False, **kw: Any + self, expr: ClauseElement, is_server_default: bool = False, **kw: Any ) -> str: """Render a SQL expression that is typically a server default, index expression, etc. @@ -587,15 +584,13 @@ def render_ddl_sql_expr( ) ) - def _compat_autogen_column_reflect( - self, inspector: "Inspector" - ) -> Callable: + def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable: return self.autogen_column_reflect def correct_for_autogen_foreignkeys( self, - conn_fks: Set["ForeignKeyConstraint"], - metadata_fks: Set["ForeignKeyConstraint"], + conn_fks: Set[ForeignKeyConstraint], + metadata_fks: Set[ForeignKeyConstraint], ) -> None: pass @@ -637,8 +632,8 @@ def emit_commit(self) -> None: self.static_output("COMMIT" + self.command_terminator) def render_type( - self, type_obj: "TypeEngine", autogen_context: "AutogenContext" - ) -> Union[str, "Literal[False]"]: + self, type_obj: TypeEngine, autogen_context: AutogenContext + ) -> Union[str, Literal[False]]: return False def _compare_identity_default(self, metadata_identity, inspector_identity): diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index 28f0678e..6a208ec6 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -62,13 +62,13 @@ class MSSQLImpl(DefaultImpl): ) def __init__(self, *arg, **kw) -> None: - super(MSSQLImpl, self).__init__(*arg, **kw) + super().__init__(*arg, **kw) self.batch_separator = self.context_opts.get( "mssql_batch_separator", self.batch_separator ) - def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]: - result = super(MSSQLImpl, self)._exec(construct, *args, **kw) + def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]: + result = super()._exec(construct, *args, **kw) if self.as_sql and self.batch_separator: self.static_output(self.batch_separator) return result @@ -77,7 +77,7 @@ def emit_begin(self) -> None: self.static_output("BEGIN TRANSACTION" + self.command_terminator) def emit_commit(self) -> None: - super(MSSQLImpl, self).emit_commit() + super().emit_commit() if self.as_sql and self.batch_separator: self.static_output(self.batch_separator) @@ -87,13 +87,13 @@ def alter_column( # type:ignore[override] column_name: str, nullable: Optional[bool] = None, server_default: Optional[ - Union["_ServerDefault", "Literal[False]"] + Union[_ServerDefault, Literal[False]] ] = False, name: Optional[str] = None, - type_: Optional["TypeEngine"] = None, + type_: Optional[TypeEngine] = None, schema: Optional[str] = None, - existing_type: Optional["TypeEngine"] = None, - existing_server_default: Optional["_ServerDefault"] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, existing_nullable: Optional[bool] = None, **kw: Any, ) -> None: @@ -136,7 +136,7 @@ def alter_column( # type:ignore[override] kw["server_default"] = server_default kw["existing_server_default"] = existing_server_default - super(MSSQLImpl, self).alter_column( + super().alter_column( table_name, column_name, nullable=nullable, @@ -158,7 +158,7 @@ def alter_column( # type:ignore[override] ) ) if server_default is not None: - super(MSSQLImpl, self).alter_column( + super().alter_column( table_name, column_name, schema=schema, @@ -166,11 +166,11 @@ def alter_column( # type:ignore[override] ) if name is not None: - super(MSSQLImpl, self).alter_column( + super().alter_column( table_name, column_name, schema=schema, name=name ) - def create_index(self, index: "Index") -> None: + def create_index(self, index: Index) -> None: # this likely defaults to None if not present, so get() # should normally not return the default value. being # defensive in any case @@ -182,25 +182,25 @@ def create_index(self, index: "Index") -> None: self._exec(CreateIndex(index)) def bulk_insert( # type:ignore[override] - self, table: Union["TableClause", "Table"], rows: List[dict], **kw: Any + self, table: Union[TableClause, Table], rows: List[dict], **kw: Any ) -> None: if self.as_sql: self._exec( "SET IDENTITY_INSERT %s ON" % self.dialect.identifier_preparer.format_table(table) ) - super(MSSQLImpl, self).bulk_insert(table, rows, **kw) + super().bulk_insert(table, rows, **kw) self._exec( "SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(table) ) else: - super(MSSQLImpl, self).bulk_insert(table, rows, **kw) + super().bulk_insert(table, rows, **kw) def drop_column( self, table_name: str, - column: "Column", + column: Column, schema: Optional[str] = None, **kw, ) -> None: @@ -221,9 +221,7 @@ def drop_column( drop_fks = kw.pop("mssql_drop_foreign_key", False) if drop_fks: self._exec(_ExecDropFKConstraint(table_name, column, schema)) - super(MSSQLImpl, self).drop_column( - table_name, column, schema=schema, **kw - ) + super().drop_column(table_name, column, schema=schema, **kw) def compare_server_default( self, @@ -244,9 +242,9 @@ def clean(value): ) def _compare_identity_default(self, metadata_identity, inspector_identity): - diff, ignored, is_alter = super( - MSSQLImpl, self - )._compare_identity_default(metadata_identity, inspector_identity) + diff, ignored, is_alter = super()._compare_identity_default( + metadata_identity, inspector_identity + ) if ( metadata_identity is None @@ -268,7 +266,7 @@ class _ExecDropConstraint(Executable, ClauseElement): def __init__( self, tname: str, - colname: Union["Column", str], + colname: Union[Column, str], type_: str, schema: Optional[str], ) -> None: @@ -282,7 +280,7 @@ class _ExecDropFKConstraint(Executable, ClauseElement): inherit_cache = False def __init__( - self, tname: str, colname: "Column", schema: Optional[str] + self, tname: str, colname: Column, schema: Optional[str] ) -> None: self.tname = tname self.colname = colname @@ -291,7 +289,7 @@ def __init__( @compiles(_ExecDropConstraint, "mssql") def _exec_drop_col_constraint( - element: "_ExecDropConstraint", compiler: "MSSQLCompiler", **kw + element: _ExecDropConstraint, compiler: MSSQLCompiler, **kw ) -> str: schema, tname, colname, type_ = ( element.schema, @@ -317,7 +315,7 @@ def _exec_drop_col_constraint( @compiles(_ExecDropFKConstraint, "mssql") def _exec_drop_col_fk_constraint( - element: "_ExecDropFKConstraint", compiler: "MSSQLCompiler", **kw + element: _ExecDropFKConstraint, compiler: MSSQLCompiler, **kw ) -> str: schema, tname, colname = element.schema, element.tname, element.colname @@ -336,22 +334,20 @@ def _exec_drop_col_fk_constraint( @compiles(AddColumn, "mssql") -def visit_add_column( - element: "AddColumn", compiler: "MSDDLCompiler", **kw -) -> str: +def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str: return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), mssql_add_column(compiler, element.column, **kw), ) -def mssql_add_column(compiler: "MSDDLCompiler", column: "Column", **kw) -> str: +def mssql_add_column(compiler: MSDDLCompiler, column: Column, **kw) -> str: return "ADD %s" % compiler.get_column_specification(column, **kw) @compiles(ColumnNullable, "mssql") def visit_column_nullable( - element: "ColumnNullable", compiler: "MSDDLCompiler", **kw + element: ColumnNullable, compiler: MSDDLCompiler, **kw ) -> str: return "%s %s %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -363,7 +359,7 @@ def visit_column_nullable( @compiles(ColumnDefault, "mssql") def visit_column_default( - element: "ColumnDefault", compiler: "MSDDLCompiler", **kw + element: ColumnDefault, compiler: MSDDLCompiler, **kw ) -> str: # TODO: there can also be a named constraint # with ADD CONSTRAINT here @@ -376,7 +372,7 @@ def visit_column_default( @compiles(ColumnName, "mssql") def visit_rename_column( - element: "ColumnName", compiler: "MSDDLCompiler", **kw + element: ColumnName, compiler: MSDDLCompiler, **kw ) -> str: return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % ( format_table_name(compiler, element.table_name, element.schema), @@ -387,7 +383,7 @@ def visit_rename_column( @compiles(ColumnType, "mssql") def visit_column_type( - element: "ColumnType", compiler: "MSDDLCompiler", **kw + element: ColumnType, compiler: MSDDLCompiler, **kw ) -> str: return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -398,7 +394,7 @@ def visit_column_type( @compiles(RenameTable, "mssql") def visit_rename_table( - element: "RenameTable", compiler: "MSDDLCompiler", **kw + element: RenameTable, compiler: MSDDLCompiler, **kw ) -> str: return "EXEC sp_rename '%s', %s" % ( format_table_name(compiler, element.table_name, element.schema), diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py index 0c03fbe1..a4527602 100644 --- a/alembic/ddl/mysql.py +++ b/alembic/ddl/mysql.py @@ -51,16 +51,16 @@ def alter_column( # type:ignore[override] table_name: str, column_name: str, nullable: Optional[bool] = None, - server_default: Union["_ServerDefault", "Literal[False]"] = False, + server_default: Union[_ServerDefault, Literal[False]] = False, name: Optional[str] = None, - type_: Optional["TypeEngine"] = None, + type_: Optional[TypeEngine] = None, schema: Optional[str] = None, - existing_type: Optional["TypeEngine"] = None, - existing_server_default: Optional["_ServerDefault"] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, existing_nullable: Optional[bool] = None, autoincrement: Optional[bool] = None, existing_autoincrement: Optional[bool] = None, - comment: Optional[Union[str, "Literal[False]"]] = False, + comment: Optional[Union[str, Literal[False]]] = False, existing_comment: Optional[str] = None, **kw: Any, ) -> None: @@ -71,7 +71,7 @@ def alter_column( # type:ignore[override] ): # modifying computed or identity columns is not supported # the default will raise - super(MySQLImpl, self).alter_column( + super().alter_column( table_name, column_name, nullable=nullable, @@ -147,17 +147,17 @@ def alter_column( # type:ignore[override] def drop_constraint( self, - const: "Constraint", + const: Constraint, ) -> None: if isinstance(const, schema.CheckConstraint) and _is_type_bound(const): return - super(MySQLImpl, self).drop_constraint(const) + super().drop_constraint(const) def _is_mysql_allowed_functional_default( self, - type_: Optional["TypeEngine"], - server_default: Union["_ServerDefault", "Literal[False]"], + type_: Optional[TypeEngine], + server_default: Union[_ServerDefault, Literal[False]], ) -> bool: return ( type_ is not None @@ -263,12 +263,12 @@ def correct_for_autogen_constraints( metadata_indexes.remove(idx) def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks): - conn_fk_by_sig = dict( - (compare._fk_constraint_sig(fk).sig, fk) for fk in conn_fks - ) - metadata_fk_by_sig = dict( - (compare._fk_constraint_sig(fk).sig, fk) for fk in metadata_fks - ) + conn_fk_by_sig = { + compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks + } + metadata_fk_by_sig = { + compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks + } for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig): mdfk = metadata_fk_by_sig[sig] @@ -299,7 +299,7 @@ def __init__( self, name: str, column_name: str, - default: "_ServerDefault", + default: _ServerDefault, schema: Optional[str] = None, ) -> None: super(AlterColumn, self).__init__(name, schema=schema) @@ -314,11 +314,11 @@ def __init__( column_name: str, schema: Optional[str] = None, newname: Optional[str] = None, - type_: Optional["TypeEngine"] = None, + type_: Optional[TypeEngine] = None, nullable: Optional[bool] = None, - default: Optional[Union["_ServerDefault", "Literal[False]"]] = False, + default: Optional[Union[_ServerDefault, Literal[False]]] = False, autoincrement: Optional[bool] = None, - comment: Optional[Union[str, "Literal[False]"]] = False, + comment: Optional[Union[str, Literal[False]]] = False, ) -> None: super(AlterColumn, self).__init__(name, schema=schema) self.column_name = column_name @@ -352,7 +352,7 @@ def _mysql_doesnt_support_individual(element, compiler, **kw): @compiles(MySQLAlterDefault, "mysql", "mariadb") def _mysql_alter_default( - element: "MySQLAlterDefault", compiler: "MySQLDDLCompiler", **kw + element: MySQLAlterDefault, compiler: MySQLDDLCompiler, **kw ) -> str: return "%s ALTER COLUMN %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -365,7 +365,7 @@ def _mysql_alter_default( @compiles(MySQLModifyColumn, "mysql", "mariadb") def _mysql_modify_column( - element: "MySQLModifyColumn", compiler: "MySQLDDLCompiler", **kw + element: MySQLModifyColumn, compiler: MySQLDDLCompiler, **kw ) -> str: return "%s MODIFY %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -383,7 +383,7 @@ def _mysql_modify_column( @compiles(MySQLChangeColumn, "mysql", "mariadb") def _mysql_change_column( - element: "MySQLChangeColumn", compiler: "MySQLDDLCompiler", **kw + element: MySQLChangeColumn, compiler: MySQLDDLCompiler, **kw ) -> str: return "%s CHANGE %s %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -401,12 +401,12 @@ def _mysql_change_column( def _mysql_colspec( - compiler: "MySQLDDLCompiler", + compiler: MySQLDDLCompiler, nullable: Optional[bool], - server_default: Optional[Union["_ServerDefault", "Literal[False]"]], - type_: "TypeEngine", + server_default: Optional[Union[_ServerDefault, Literal[False]]], + type_: TypeEngine, autoincrement: Optional[bool], - comment: Optional[Union[str, "Literal[False]"]], + comment: Optional[Union[str, Literal[False]]], ) -> str: spec = "%s %s" % ( compiler.dialect.type_compiler.process(type_), @@ -426,7 +426,7 @@ def _mysql_colspec( @compiles(schema.DropConstraint, "mysql", "mariadb") def _mysql_drop_constraint( - element: "DropConstraint", compiler: "MySQLDDLCompiler", **kw + element: DropConstraint, compiler: MySQLDDLCompiler, **kw ) -> str: """Redefine SQLAlchemy's drop constraint to raise errors for invalid constraint type.""" diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py index accd1fcf..920b70ae 100644 --- a/alembic/ddl/oracle.py +++ b/alembic/ddl/oracle.py @@ -41,13 +41,13 @@ class OracleImpl(DefaultImpl): identity_attrs_ignore = () def __init__(self, *arg, **kw) -> None: - super(OracleImpl, self).__init__(*arg, **kw) + super().__init__(*arg, **kw) self.batch_separator = self.context_opts.get( "oracle_batch_separator", self.batch_separator ) - def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]: - result = super(OracleImpl, self)._exec(construct, *args, **kw) + def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]: + result = super()._exec(construct, *args, **kw) if self.as_sql and self.batch_separator: self.static_output(self.batch_separator) return result @@ -61,7 +61,7 @@ def emit_commit(self) -> None: @compiles(AddColumn, "oracle") def visit_add_column( - element: "AddColumn", compiler: "OracleDDLCompiler", **kw + element: AddColumn, compiler: OracleDDLCompiler, **kw ) -> str: return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -71,7 +71,7 @@ def visit_add_column( @compiles(ColumnNullable, "oracle") def visit_column_nullable( - element: "ColumnNullable", compiler: "OracleDDLCompiler", **kw + element: ColumnNullable, compiler: OracleDDLCompiler, **kw ) -> str: return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -82,7 +82,7 @@ def visit_column_nullable( @compiles(ColumnType, "oracle") def visit_column_type( - element: "ColumnType", compiler: "OracleDDLCompiler", **kw + element: ColumnType, compiler: OracleDDLCompiler, **kw ) -> str: return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -93,7 +93,7 @@ def visit_column_type( @compiles(ColumnName, "oracle") def visit_column_name( - element: "ColumnName", compiler: "OracleDDLCompiler", **kw + element: ColumnName, compiler: OracleDDLCompiler, **kw ) -> str: return "%s RENAME COLUMN %s TO %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -104,7 +104,7 @@ def visit_column_name( @compiles(ColumnDefault, "oracle") def visit_column_default( - element: "ColumnDefault", compiler: "OracleDDLCompiler", **kw + element: ColumnDefault, compiler: OracleDDLCompiler, **kw ) -> str: return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -117,7 +117,7 @@ def visit_column_default( @compiles(ColumnComment, "oracle") def visit_column_comment( - element: "ColumnComment", compiler: "OracleDDLCompiler", **kw + element: ColumnComment, compiler: OracleDDLCompiler, **kw ) -> str: ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}" @@ -135,7 +135,7 @@ def visit_column_comment( @compiles(RenameTable, "oracle") def visit_rename_table( - element: "RenameTable", compiler: "OracleDDLCompiler", **kw + element: RenameTable, compiler: OracleDDLCompiler, **kw ) -> str: return "%s RENAME TO %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -143,17 +143,17 @@ def visit_rename_table( ) -def alter_column(compiler: "OracleDDLCompiler", name: str) -> str: +def alter_column(compiler: OracleDDLCompiler, name: str) -> str: return "MODIFY %s" % format_column_name(compiler, name) -def add_column(compiler: "OracleDDLCompiler", column: "Column", **kw) -> str: +def add_column(compiler: OracleDDLCompiler, column: Column, **kw) -> str: return "ADD %s" % compiler.get_column_specification(column, **kw) @compiles(IdentityColumnDefault, "oracle") def visit_identity_column( - element: "IdentityColumnDefault", compiler: "OracleDDLCompiler", **kw + element: IdentityColumnDefault, compiler: OracleDDLCompiler, **kw ): text = "%s %s " % ( alter_table(compiler, element.table_name, element.schema), diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 5d93803a..29efe4c9 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -136,13 +136,13 @@ def alter_column( # type:ignore[override] table_name: str, column_name: str, nullable: Optional[bool] = None, - server_default: Union["_ServerDefault", "Literal[False]"] = False, + server_default: Union[_ServerDefault, Literal[False]] = False, name: Optional[str] = None, - type_: Optional["TypeEngine"] = None, + type_: Optional[TypeEngine] = None, schema: Optional[str] = None, autoincrement: Optional[bool] = None, - existing_type: Optional["TypeEngine"] = None, - existing_server_default: Optional["_ServerDefault"] = None, + existing_type: Optional[TypeEngine] = None, + existing_server_default: Optional[_ServerDefault] = None, existing_nullable: Optional[bool] = None, existing_autoincrement: Optional[bool] = None, **kw: Any, @@ -169,7 +169,7 @@ def alter_column( # type:ignore[override] ) ) - super(PostgresqlImpl, self).alter_column( + super().alter_column( table_name, column_name, nullable=nullable, @@ -230,13 +230,13 @@ def correct_for_autogen_constraints( metadata_indexes, ): - conn_indexes_by_name = dict((c.name, c) for c in conn_indexes) + conn_indexes_by_name = {c.name: c for c in conn_indexes} - doubled_constraints = set( + doubled_constraints = { index for index in conn_indexes if index.info.get("duplicates_constraint") - ) + } for ix in doubled_constraints: conn_indexes.remove(ix) @@ -260,8 +260,8 @@ def correct_for_autogen_constraints( metadata_indexes.discard(idx) def render_type( - self, type_: "TypeEngine", autogen_context: "AutogenContext" - ) -> Union[str, "Literal[False]"]: + self, type_: TypeEngine, autogen_context: AutogenContext + ) -> Union[str, Literal[False]]: mod = type(type_).__module__ if not mod.startswith("sqlalchemy.dialects.postgresql"): return False @@ -273,7 +273,7 @@ def render_type( return False def _render_HSTORE_type( - self, type_: "HSTORE", autogen_context: "AutogenContext" + self, type_: HSTORE, autogen_context: AutogenContext ) -> str: return cast( str, @@ -283,7 +283,7 @@ def _render_HSTORE_type( ) def _render_ARRAY_type( - self, type_: "ARRAY", autogen_context: "AutogenContext" + self, type_: ARRAY, autogen_context: AutogenContext ) -> str: return cast( str, @@ -293,7 +293,7 @@ def _render_ARRAY_type( ) def _render_JSON_type( - self, type_: "JSON", autogen_context: "AutogenContext" + self, type_: JSON, autogen_context: AutogenContext ) -> str: return cast( str, @@ -303,7 +303,7 @@ def _render_JSON_type( ) def _render_JSONB_type( - self, type_: "JSONB", autogen_context: "AutogenContext" + self, type_: JSONB, autogen_context: AutogenContext ) -> str: return cast( str, @@ -315,17 +315,17 @@ def _render_JSONB_type( class PostgresqlColumnType(AlterColumn): def __init__( - self, name: str, column_name: str, type_: "TypeEngine", **kw + self, name: str, column_name: str, type_: TypeEngine, **kw ) -> None: using = kw.pop("using", None) - super(PostgresqlColumnType, self).__init__(name, column_name, **kw) + super().__init__(name, column_name, **kw) self.type_ = sqltypes.to_instance(type_) self.using = using @compiles(RenameTable, "postgresql") def visit_rename_table( - element: RenameTable, compiler: "PGDDLCompiler", **kw + element: RenameTable, compiler: PGDDLCompiler, **kw ) -> str: return "%s RENAME TO %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -335,7 +335,7 @@ def visit_rename_table( @compiles(PostgresqlColumnType, "postgresql") def visit_column_type( - element: PostgresqlColumnType, compiler: "PGDDLCompiler", **kw + element: PostgresqlColumnType, compiler: PGDDLCompiler, **kw ) -> str: return "%s %s %s %s" % ( alter_table(compiler, element.table_name, element.schema), @@ -347,7 +347,7 @@ def visit_column_type( @compiles(ColumnComment, "postgresql") def visit_column_comment( - element: "ColumnComment", compiler: "PGDDLCompiler", **kw + element: ColumnComment, compiler: PGDDLCompiler, **kw ) -> str: ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}" comment = ( @@ -369,7 +369,7 @@ def visit_column_comment( @compiles(IdentityColumnDefault, "postgresql") def visit_identity_column( - element: "IdentityColumnDefault", compiler: "PGDDLCompiler", **kw + element: IdentityColumnDefault, compiler: PGDDLCompiler, **kw ): text = "%s %s " % ( alter_table(compiler, element.table_name, element.schema), @@ -415,14 +415,14 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): def __init__( self, constraint_name: Optional[str], - table_name: Union[str, "quoted_name"], + table_name: Union[str, quoted_name], elements: Union[ Sequence[Tuple[str, str]], - Sequence[Tuple["ColumnClause", str]], + Sequence[Tuple[ColumnClause, str]], ], - where: Optional[Union["BinaryExpression", str]] = None, + where: Optional[Union[BinaryExpression, str]] = None, schema: Optional[str] = None, - _orig_constraint: Optional["ExcludeConstraint"] = None, + _orig_constraint: Optional[ExcludeConstraint] = None, **kw, ) -> None: self.constraint_name = constraint_name @@ -435,8 +435,8 @@ def __init__( @classmethod def from_constraint( # type:ignore[override] - cls, constraint: "ExcludeConstraint" - ) -> "CreateExcludeConstraintOp": + cls, constraint: ExcludeConstraint + ) -> CreateExcludeConstraintOp: constraint_table = sqla_compat._table_for_constraint(constraint) return cls( @@ -455,8 +455,8 @@ def from_constraint( # type:ignore[override] ) def to_constraint( - self, migration_context: Optional["MigrationContext"] = None - ) -> "ExcludeConstraint": + self, migration_context: Optional[MigrationContext] = None + ) -> ExcludeConstraint: if self._orig_constraint is not None: return self._orig_constraint schema_obj = schemaobj.SchemaObjects(migration_context) @@ -479,12 +479,12 @@ def to_constraint( @classmethod def create_exclude_constraint( cls, - operations: "Operations", + operations: Operations, constraint_name: str, table_name: str, *elements: Any, **kw: Any, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue an alter to create an EXCLUDE constraint using the current migration context. @@ -546,16 +546,16 @@ def batch_create_exclude_constraint( @render.renderers.dispatch_for(CreateExcludeConstraintOp) def _add_exclude_constraint( - autogen_context: "AutogenContext", op: "CreateExcludeConstraintOp" + autogen_context: AutogenContext, op: CreateExcludeConstraintOp ) -> str: return _exclude_constraint(op.to_constraint(), autogen_context, alter=True) @render._constraint_renderers.dispatch_for(ExcludeConstraint) def _render_inline_exclude_constraint( - constraint: "ExcludeConstraint", - autogen_context: "AutogenContext", - namespace_metadata: "MetaData", + constraint: ExcludeConstraint, + autogen_context: AutogenContext, + namespace_metadata: MetaData, ) -> str: rendered = render._user_defined_render( "exclude", constraint, autogen_context @@ -566,7 +566,7 @@ def _render_inline_exclude_constraint( return _exclude_constraint(constraint, autogen_context, False) -def _postgresql_autogenerate_prefix(autogen_context: "AutogenContext") -> str: +def _postgresql_autogenerate_prefix(autogen_context: AutogenContext) -> str: imports = autogen_context.imports if imports is not None: @@ -575,8 +575,8 @@ def _postgresql_autogenerate_prefix(autogen_context: "AutogenContext") -> str: def _exclude_constraint( - constraint: "ExcludeConstraint", - autogen_context: "AutogenContext", + constraint: ExcludeConstraint, + autogen_context: AutogenContext, alter: bool, ) -> str: opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = [] @@ -645,7 +645,7 @@ def _exclude_constraint( def _render_potential_column( - value: Union["ColumnClause", "Column"], autogen_context: "AutogenContext" + value: Union[ColumnClause, Column], autogen_context: AutogenContext ) -> str: if isinstance(value, ColumnClause): template = "%(prefix)scolumn(%(name)r)" diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py index f986c32c..51233326 100644 --- a/alembic/ddl/sqlite.py +++ b/alembic/ddl/sqlite.py @@ -41,7 +41,7 @@ class SQLiteImpl(DefaultImpl): """ def requires_recreate_in_batch( - self, batch_op: "BatchOperationsImpl" + self, batch_op: BatchOperationsImpl ) -> bool: """Return True if the given :class:`.BatchOperationsImpl` would need the table to be recreated and copied in order to @@ -68,7 +68,7 @@ def requires_recreate_in_batch( else: return False - def add_constraint(self, const: "Constraint"): + def add_constraint(self, const: Constraint): # attempt to distinguish between an # auto-gen constraint and an explicit one if const._create_rule is None: # type:ignore[attr-defined] @@ -85,7 +85,7 @@ def add_constraint(self, const: "Constraint"): "SQLite migrations using a copy-and-move strategy." ) - def drop_constraint(self, const: "Constraint"): + def drop_constraint(self, const: Constraint): if const._create_rule is None: # type:ignore[attr-defined] raise NotImplementedError( "No support for ALTER of constraints in SQLite dialect. " @@ -95,8 +95,8 @@ def drop_constraint(self, const: "Constraint"): def compare_server_default( self, - inspector_column: "Column", - metadata_column: "Column", + inspector_column: Column, + metadata_column: Column, rendered_metadata_default: Optional[str], rendered_inspector_default: Optional[str], ) -> bool: @@ -140,8 +140,8 @@ def _guess_if_default_is_unparenthesized_sql_expr( def autogen_column_reflect( self, - inspector: "Inspector", - table: "Table", + inspector: Inspector, + table: Table, column_info: Dict[str, Any], ) -> None: # SQLite expression defaults require parenthesis when sent @@ -152,11 +152,11 @@ def autogen_column_reflect( column_info["default"] = "(%s)" % (column_info["default"],) def render_ddl_sql_expr( - self, expr: "ClauseElement", is_server_default: bool = False, **kw + self, expr: ClauseElement, is_server_default: bool = False, **kw ) -> str: # SQLite expression defaults require parenthesis when sent # as DDL - str_expr = super(SQLiteImpl, self).render_ddl_sql_expr( + str_expr = super().render_ddl_sql_expr( expr, is_server_default=is_server_default, **kw ) @@ -169,9 +169,9 @@ def render_ddl_sql_expr( def cast_for_batch_migrate( self, - existing: "Column", - existing_transfer: Dict[str, Union["TypeEngine", "Cast"]], - new_type: "TypeEngine", + existing: Column, + existing_transfer: Dict[str, Union[TypeEngine, Cast]], + new_type: TypeEngine, ) -> None: if ( existing.type._type_affinity # type:ignore[attr-defined] @@ -185,7 +185,7 @@ def cast_for_batch_migrate( @compiles(RenameTable, "sqlite") def visit_rename_table( - element: "RenameTable", compiler: "DDLCompiler", **kw + element: RenameTable, compiler: DDLCompiler, **kw ) -> str: return "%s RENAME TO %s" % ( alter_table(compiler, element.table_name, element.schema), diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 2178998a..04b66b55 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -75,7 +75,7 @@ class Operations(util.ModuleClsProxy): """ - impl: Union["DefaultImpl", "BatchOperationsImpl"] + impl: Union[DefaultImpl, BatchOperationsImpl] _to_impl = util.Dispatcher() def __init__( @@ -222,13 +222,13 @@ def batch_alter_table( schema: Optional[str] = None, recreate: Literal["auto", "always", "never"] = "auto", partial_reordering: Optional[tuple] = None, - copy_from: Optional["Table"] = None, + copy_from: Optional[Table] = None, table_args: Tuple[Any, ...] = (), table_kwargs: Mapping[str, Any] = util.immutabledict(), reflect_args: Tuple[Any, ...] = (), reflect_kwargs: Mapping[str, Any] = util.immutabledict(), naming_convention: Optional[Dict[str, str]] = None, - ) -> Iterator["BatchOperations"]: + ) -> Iterator[BatchOperations]: """Invoke a series of per-table migrations in batch. Batch mode allows a series of operations specific to a table @@ -514,7 +514,7 @@ class BatchOperations(Operations): """ - impl: "BatchOperationsImpl" + impl: BatchOperationsImpl def _noop(self, operation): raise NotImplementedError( diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py index f1459e2b..0c773c68 100644 --- a/alembic/operations/batch.py +++ b/alembic/operations/batch.py @@ -86,11 +86,11 @@ def __init__( self.batch = [] @property - def dialect(self) -> "Dialect": + def dialect(self) -> Dialect: return self.operations.impl.dialect @property - def impl(self) -> "DefaultImpl": + def impl(self) -> DefaultImpl: return self.operations.impl def _should_recreate(self) -> bool: @@ -174,19 +174,19 @@ def add_column(self, *arg, **kw) -> None: def drop_column(self, *arg, **kw) -> None: self.batch.append(("drop_column", arg, kw)) - def add_constraint(self, const: "Constraint") -> None: + def add_constraint(self, const: Constraint) -> None: self.batch.append(("add_constraint", (const,), {})) - def drop_constraint(self, const: "Constraint") -> None: + def drop_constraint(self, const: Constraint) -> None: self.batch.append(("drop_constraint", (const,), {})) def rename_table(self, *arg, **kw): self.batch.append(("rename_table", arg, kw)) - def create_index(self, idx: "Index") -> None: + def create_index(self, idx: Index) -> None: self.batch.append(("create_index", (idx,), {})) - def drop_index(self, idx: "Index") -> None: + def drop_index(self, idx: Index) -> None: self.batch.append(("drop_index", (idx,), {})) def create_table_comment(self, table): @@ -208,8 +208,8 @@ def create_column_comment(self, column): class ApplyBatchImpl: def __init__( self, - impl: "DefaultImpl", - table: "Table", + impl: DefaultImpl, + table: Table, table_args: tuple, table_kwargs: Dict[str, Any], reflected: bool, @@ -236,12 +236,12 @@ def __init__( self._grab_table_elements() @classmethod - def _calc_temp_name(cls, tablename: Union["quoted_name", str]) -> str: + def _calc_temp_name(cls, tablename: Union[quoted_name, str]) -> str: return ("_alembic_tmp_%s" % tablename)[0:50] def _grab_table_elements(self) -> None: schema = self.table.schema - self.columns: Dict[str, "Column"] = OrderedDict() + self.columns: Dict[str, Column] = OrderedDict() for c in self.table.c: c_copy = _copy(c, schema=schema) c_copy.unique = c_copy.index = False @@ -250,11 +250,11 @@ def _grab_table_elements(self) -> None: if isinstance(c.type, SchemaEventTarget): assert c_copy.type is not c.type self.columns[c.name] = c_copy - self.named_constraints: Dict[str, "Constraint"] = {} + self.named_constraints: Dict[str, Constraint] = {} self.unnamed_constraints = [] self.col_named_constraints = {} - self.indexes: Dict[str, "Index"] = {} - self.new_indexes: Dict[str, "Index"] = {} + self.indexes: Dict[str, Index] = {} + self.new_indexes: Dict[str, Index] = {} for const in self.table.constraints: if _is_type_bound(const): @@ -336,14 +336,12 @@ def _transfer_elements_to_new_table(self) -> None: list(self.named_constraints.values()) + self.unnamed_constraints ): - const_columns = set( - [c.key for c in _columns_for_constraint(const)] - ) + const_columns = {c.key for c in _columns_for_constraint(const)} if not const_columns.issubset(self.column_transfers): continue - const_copy: "Constraint" + const_copy: Constraint if isinstance(const, ForeignKeyConstraint): if _fk_is_self_referential(const): # for self-referential constraint, refer to the @@ -368,7 +366,7 @@ def _transfer_elements_to_new_table(self) -> None: self._setup_referent(m, const) new_table.append_constraint(const_copy) - def _gather_indexes_from_both_tables(self) -> List["Index"]: + def _gather_indexes_from_both_tables(self) -> List[Index]: assert self.new_table is not None idx: List[Index] = [] @@ -402,7 +400,7 @@ def _gather_indexes_from_both_tables(self) -> List["Index"]: return idx def _setup_referent( - self, metadata: "MetaData", constraint: "ForeignKeyConstraint" + self, metadata: MetaData, constraint: ForeignKeyConstraint ) -> None: spec = constraint.elements[ 0 @@ -440,7 +438,7 @@ def colspec(elem: Any): schema=referent_schema, ) - def _create(self, op_impl: "DefaultImpl") -> None: + def _create(self, op_impl: DefaultImpl) -> None: self._transfer_elements_to_new_table() op_impl.prep_table_for_batch(self, self.table) @@ -484,11 +482,11 @@ def alter_column( table_name: str, column_name: str, nullable: Optional[bool] = None, - server_default: Optional[Union["Function", str, bool]] = False, + server_default: Optional[Union[Function, str, bool]] = False, name: Optional[str] = None, - type_: Optional["TypeEngine"] = None, + type_: Optional[TypeEngine] = None, autoincrement: None = None, - comment: Union[str, "Literal[False]"] = False, + comment: Union[str, Literal[False]] = False, **kw, ) -> None: existing = self.columns[column_name] @@ -587,9 +585,9 @@ def _setup_dependencies_for_add_column( insert_after = index_cols[idx] else: # insert before a column that is also new - insert_after = dict( - (b, a) for a, b in self.add_col_ordering - )[insert_before] + insert_after = { + b: a for a, b in self.add_col_ordering + }[insert_before] if insert_before: self.add_col_ordering += ((colname, insert_before),) @@ -607,7 +605,7 @@ def _setup_dependencies_for_add_column( def add_column( self, table_name: str, - column: "Column", + column: Column, insert_before: Optional[str] = None, insert_after: Optional[str] = None, **kw, @@ -621,7 +619,7 @@ def add_column( self.column_transfers[column.name] = {} def drop_column( - self, table_name: str, column: Union["ColumnClause", "Column"], **kw + self, table_name: str, column: Union[ColumnClause, Column], **kw ) -> None: if column.name in self.table.primary_key.columns: _remove_column_from_collection( @@ -663,7 +661,7 @@ def drop_table_comment(self, table): """ - def add_constraint(self, const: "Constraint") -> None: + def add_constraint(self, const: Constraint) -> None: if not const.name: raise ValueError("Constraint must have a name") if isinstance(const, sql_schema.PrimaryKeyConstraint): @@ -672,7 +670,7 @@ def add_constraint(self, const: "Constraint") -> None: self.named_constraints[const.name] = const - def drop_constraint(self, const: "Constraint") -> None: + def drop_constraint(self, const: Constraint) -> None: if not const.name: raise ValueError("Constraint must have a name") try: @@ -698,10 +696,10 @@ def drop_constraint(self, const: "Constraint") -> None: for col in const.columns: self.columns[col.name].primary_key = False - def create_index(self, idx: "Index") -> None: + def create_index(self, idx: Index) -> None: self.new_indexes[idx.name] = idx # type: ignore[index] - def drop_index(self, idx: "Index") -> None: + def drop_index(self, idx: Index) -> None: try: del self.indexes[idx.name] # type: ignore[arg-type] except KeyError: diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 85ffe149..a93596da 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -78,9 +78,9 @@ def info(self): """ return {} - _mutations: FrozenSet["Rewriter"] = frozenset() + _mutations: FrozenSet[Rewriter] = frozenset() - def reverse(self) -> "MigrateOperation": + def reverse(self) -> MigrateOperation: raise NotImplementedError def to_diff_tuple(self) -> Tuple[Any, ...]: @@ -105,21 +105,21 @@ def go(klass): return go @classmethod - def from_constraint(cls, constraint: "Constraint") -> "AddConstraintOp": + def from_constraint(cls, constraint: Constraint) -> AddConstraintOp: return cls.add_constraint_ops.dispatch(constraint.__visit_name__)( constraint ) @abstractmethod def to_constraint( - self, migration_context: Optional["MigrationContext"] = None - ) -> "Constraint": + self, migration_context: Optional[MigrationContext] = None + ) -> Constraint: pass - def reverse(self) -> "DropConstraintOp": + def reverse(self) -> DropConstraintOp: return DropConstraintOp.from_constraint(self.to_constraint()) - def to_diff_tuple(self) -> Tuple[str, "Constraint"]: + def to_diff_tuple(self) -> Tuple[str, Constraint]: return ("add_constraint", self.to_constraint()) @@ -134,7 +134,7 @@ def __init__( table_name: str, type_: Optional[str] = None, schema: Optional[str] = None, - _reverse: Optional["AddConstraintOp"] = None, + _reverse: Optional[AddConstraintOp] = None, ) -> None: self.constraint_name = constraint_name self.table_name = table_name @@ -142,12 +142,12 @@ def __init__( self.schema = schema self._reverse = _reverse - def reverse(self) -> "AddConstraintOp": + def reverse(self) -> AddConstraintOp: return AddConstraintOp.from_constraint(self.to_constraint()) def to_diff_tuple( self, - ) -> Tuple[str, "SchemaItem"]: + ) -> Tuple[str, SchemaItem]: if self.constraint_type == "foreignkey": return ("remove_fk", self.to_constraint()) else: @@ -156,8 +156,8 @@ def to_diff_tuple( @classmethod def from_constraint( cls, - constraint: "Constraint", - ) -> "DropConstraintOp": + constraint: Constraint, + ) -> DropConstraintOp: types = { "unique_constraint": "unique", "foreign_key_constraint": "foreignkey", @@ -178,7 +178,7 @@ def from_constraint( def to_constraint( self, - ) -> "Constraint": + ) -> Constraint: if self._reverse is not None: constraint = self._reverse.to_constraint() @@ -197,12 +197,12 @@ def to_constraint( @classmethod def drop_constraint( cls, - operations: "Operations", + operations: Operations, constraint_name: str, table_name: str, type_: Optional[str] = None, schema: Optional[str] = None, - ) -> Optional["Table"]: + ) -> Optional[Table]: r"""Drop a constraint of the given name, typically via DROP CONSTRAINT. :param constraint_name: name of the constraint. @@ -222,7 +222,7 @@ def drop_constraint( @classmethod def batch_drop_constraint( cls, - operations: "BatchOperations", + operations: BatchOperations, constraint_name: str, type_: Optional[str] = None, ) -> None: @@ -271,7 +271,7 @@ def __init__( self.kw = kw @classmethod - def from_constraint(cls, constraint: "Constraint") -> "CreatePrimaryKeyOp": + def from_constraint(cls, constraint: Constraint) -> CreatePrimaryKeyOp: constraint_table = sqla_compat._table_for_constraint(constraint) pk_constraint = cast("PrimaryKeyConstraint", constraint) @@ -284,8 +284,8 @@ def from_constraint(cls, constraint: "Constraint") -> "CreatePrimaryKeyOp": ) def to_constraint( - self, migration_context: Optional["MigrationContext"] = None - ) -> "PrimaryKeyConstraint": + self, migration_context: Optional[MigrationContext] = None + ) -> PrimaryKeyConstraint: schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.primary_key_constraint( @@ -299,12 +299,12 @@ def to_constraint( @classmethod def create_primary_key( cls, - operations: "Operations", + operations: Operations, constraint_name: Optional[str], table_name: str, columns: List[str], schema: Optional[str] = None, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue a "create primary key" instruction using the current migration context. @@ -347,7 +347,7 @@ def create_primary_key( @classmethod def batch_create_primary_key( cls, - operations: "BatchOperations", + operations: BatchOperations, constraint_name: str, columns: List[str], ) -> None: @@ -397,8 +397,8 @@ def __init__( @classmethod def from_constraint( - cls, constraint: "Constraint" - ) -> "CreateUniqueConstraintOp": + cls, constraint: Constraint + ) -> CreateUniqueConstraintOp: constraint_table = sqla_compat._table_for_constraint(constraint) @@ -419,8 +419,8 @@ def from_constraint( ) def to_constraint( - self, migration_context: Optional["MigrationContext"] = None - ) -> "UniqueConstraint": + self, migration_context: Optional[MigrationContext] = None + ) -> UniqueConstraint: schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.unique_constraint( self.constraint_name, @@ -433,7 +433,7 @@ def to_constraint( @classmethod def create_unique_constraint( cls, - operations: "Operations", + operations: Operations, constraint_name: Optional[str], table_name: str, columns: Sequence[str], @@ -484,7 +484,7 @@ def create_unique_constraint( @classmethod def batch_create_unique_constraint( cls, - operations: "BatchOperations", + operations: BatchOperations, constraint_name: str, columns: Sequence[str], **kw: Any, @@ -531,11 +531,11 @@ def __init__( self.remote_cols = remote_cols self.kw = kw - def to_diff_tuple(self) -> Tuple[str, "ForeignKeyConstraint"]: + def to_diff_tuple(self) -> Tuple[str, ForeignKeyConstraint]: return ("add_fk", self.to_constraint()) @classmethod - def from_constraint(cls, constraint: "Constraint") -> "CreateForeignKeyOp": + def from_constraint(cls, constraint: Constraint) -> CreateForeignKeyOp: fk_constraint = cast("ForeignKeyConstraint", constraint) kw: dict = {} @@ -576,8 +576,8 @@ def from_constraint(cls, constraint: "Constraint") -> "CreateForeignKeyOp": ) def to_constraint( - self, migration_context: Optional["MigrationContext"] = None - ) -> "ForeignKeyConstraint": + self, migration_context: Optional[MigrationContext] = None + ) -> ForeignKeyConstraint: schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.foreign_key_constraint( self.constraint_name, @@ -591,7 +591,7 @@ def to_constraint( @classmethod def create_foreign_key( cls, - operations: "Operations", + operations: Operations, constraint_name: Optional[str], source_table: str, referent_table: str, @@ -605,7 +605,7 @@ def create_foreign_key( source_schema: Optional[str] = None, referent_schema: Optional[str] = None, **dialect_kw: Any, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue a "create foreign key" instruction using the current migration context. @@ -671,7 +671,7 @@ def create_foreign_key( @classmethod def batch_create_foreign_key( cls, - operations: "BatchOperations", + operations: BatchOperations, constraint_name: str, referent_table: str, local_cols: List[str], @@ -736,7 +736,7 @@ def __init__( self, constraint_name: Optional[str], table_name: str, - condition: Union[str, "TextClause", "ColumnElement[Any]"], + condition: Union[str, TextClause, ColumnElement[Any]], schema: Optional[str] = None, **kw: Any, ) -> None: @@ -748,8 +748,8 @@ def __init__( @classmethod def from_constraint( - cls, constraint: "Constraint" - ) -> "CreateCheckConstraintOp": + cls, constraint: Constraint + ) -> CreateCheckConstraintOp: constraint_table = sqla_compat._table_for_constraint(constraint) ck_constraint = cast("CheckConstraint", constraint) @@ -763,8 +763,8 @@ def from_constraint( ) def to_constraint( - self, migration_context: Optional["MigrationContext"] = None - ) -> "CheckConstraint": + self, migration_context: Optional[MigrationContext] = None + ) -> CheckConstraint: schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.check_constraint( self.constraint_name, @@ -777,13 +777,13 @@ def to_constraint( @classmethod def create_check_constraint( cls, - operations: "Operations", + operations: Operations, constraint_name: Optional[str], table_name: str, - condition: Union[str, "BinaryExpression"], + condition: Union[str, BinaryExpression], schema: Optional[str] = None, **kw: Any, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue a "create check constraint" instruction using the current migration context. @@ -830,11 +830,11 @@ def create_check_constraint( @classmethod def batch_create_check_constraint( cls, - operations: "BatchOperations", + operations: BatchOperations, constraint_name: str, - condition: "TextClause", + condition: TextClause, **kw: Any, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue a "create check constraint" instruction using the current batch migration context. @@ -865,7 +865,7 @@ def __init__( self, index_name: str, table_name: str, - columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]], + columns: Sequence[Union[str, TextClause, ColumnElement[Any]]], schema: Optional[str] = None, unique: bool = False, **kw: Any, @@ -877,14 +877,14 @@ def __init__( self.unique = unique self.kw = kw - def reverse(self) -> "DropIndexOp": + def reverse(self) -> DropIndexOp: return DropIndexOp.from_index(self.to_index()) - def to_diff_tuple(self) -> Tuple[str, "Index"]: + def to_diff_tuple(self) -> Tuple[str, Index]: return ("add_index", self.to_index()) @classmethod - def from_index(cls, index: "Index") -> "CreateIndexOp": + def from_index(cls, index: Index) -> CreateIndexOp: assert index.table is not None return cls( index.name, # type: ignore[arg-type] @@ -896,8 +896,8 @@ def from_index(cls, index: "Index") -> "CreateIndexOp": ) def to_index( - self, migration_context: Optional["MigrationContext"] = None - ) -> "Index": + self, migration_context: Optional[MigrationContext] = None + ) -> Index: schema_obj = schemaobj.SchemaObjects(migration_context) idx = schema_obj.index( @@ -916,11 +916,11 @@ def create_index( operations: Operations, index_name: str, table_name: str, - columns: Sequence[Union[str, "TextClause", "Function"]], + columns: Sequence[Union[str, TextClause, Function]], schema: Optional[str] = None, unique: bool = False, **kw: Any, - ) -> Optional["Table"]: + ) -> Optional[Table]: r"""Issue a "create index" instruction using the current migration context. @@ -970,11 +970,11 @@ def create_index( @classmethod def batch_create_index( cls, - operations: "BatchOperations", + operations: BatchOperations, index_name: str, columns: List[str], **kw: Any, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue a "create index" instruction using the current batch migration context. @@ -1001,10 +1001,10 @@ class DropIndexOp(MigrateOperation): def __init__( self, - index_name: Union["quoted_name", str, "conv"], + index_name: Union[quoted_name, str, conv], table_name: Optional[str] = None, schema: Optional[str] = None, - _reverse: Optional["CreateIndexOp"] = None, + _reverse: Optional[CreateIndexOp] = None, **kw: Any, ) -> None: self.index_name = index_name @@ -1013,14 +1013,14 @@ def __init__( self._reverse = _reverse self.kw = kw - def to_diff_tuple(self) -> Tuple[str, "Index"]: + def to_diff_tuple(self) -> Tuple[str, Index]: return ("remove_index", self.to_index()) - def reverse(self) -> "CreateIndexOp": + def reverse(self) -> CreateIndexOp: return CreateIndexOp.from_index(self.to_index()) @classmethod - def from_index(cls, index: "Index") -> "DropIndexOp": + def from_index(cls, index: Index) -> DropIndexOp: assert index.table is not None return cls( index.name, # type: ignore[arg-type] @@ -1031,8 +1031,8 @@ def from_index(cls, index: "Index") -> "DropIndexOp": ) def to_index( - self, migration_context: Optional["MigrationContext"] = None - ) -> "Index": + self, migration_context: Optional[MigrationContext] = None + ) -> Index: schema_obj = schemaobj.SchemaObjects(migration_context) # need a dummy column name here since SQLAlchemy @@ -1048,12 +1048,12 @@ def to_index( @classmethod def drop_index( cls, - operations: "Operations", + operations: Operations, index_name: str, table_name: Optional[str] = None, schema: Optional[str] = None, **kw: Any, - ) -> Optional["Table"]: + ) -> Optional[Table]: r"""Issue a "drop index" instruction using the current migration context. @@ -1081,7 +1081,7 @@ def drop_index( @classmethod def batch_drop_index( cls, operations: BatchOperations, index_name: str, **kw: Any - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue a "drop index" instruction using the current batch migration context. @@ -1107,9 +1107,9 @@ class CreateTableOp(MigrateOperation): def __init__( self, table_name: str, - columns: Sequence["SchemaItem"], + columns: Sequence[SchemaItem], schema: Optional[str] = None, - _namespace_metadata: Optional["MetaData"] = None, + _namespace_metadata: Optional[MetaData] = None, _constraints_included: bool = False, **kw: Any, ) -> None: @@ -1123,18 +1123,18 @@ def __init__( self._namespace_metadata = _namespace_metadata self._constraints_included = _constraints_included - def reverse(self) -> "DropTableOp": + def reverse(self) -> DropTableOp: return DropTableOp.from_table( self.to_table(), _namespace_metadata=self._namespace_metadata ) - def to_diff_tuple(self) -> Tuple[str, "Table"]: + def to_diff_tuple(self) -> Tuple[str, Table]: return ("add_table", self.to_table()) @classmethod def from_table( - cls, table: "Table", _namespace_metadata: Optional["MetaData"] = None - ) -> "CreateTableOp": + cls, table: Table, _namespace_metadata: Optional[MetaData] = None + ) -> CreateTableOp: if _namespace_metadata is None: _namespace_metadata = table.metadata @@ -1157,8 +1157,8 @@ def from_table( ) def to_table( - self, migration_context: Optional["MigrationContext"] = None - ) -> "Table": + self, migration_context: Optional[MigrationContext] = None + ) -> Table: schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.table( @@ -1175,11 +1175,11 @@ def to_table( @classmethod def create_table( cls, - operations: "Operations", + operations: Operations, table_name: str, - *columns: "SchemaItem", + *columns: SchemaItem, **kw: Any, - ) -> "Optional[Table]": + ) -> Optional[Table]: r"""Issue a "create table" instruction using the current migration context. @@ -1269,7 +1269,7 @@ def __init__( table_name: str, schema: Optional[str] = None, table_kw: Optional[MutableMapping[Any, Any]] = None, - _reverse: Optional["CreateTableOp"] = None, + _reverse: Optional[CreateTableOp] = None, ) -> None: self.table_name = table_name self.schema = schema @@ -1279,16 +1279,16 @@ def __init__( self.prefixes = self.table_kw.pop("prefixes", None) self._reverse = _reverse - def to_diff_tuple(self) -> Tuple[str, "Table"]: + def to_diff_tuple(self) -> Tuple[str, Table]: return ("remove_table", self.to_table()) - def reverse(self) -> "CreateTableOp": + def reverse(self) -> CreateTableOp: return CreateTableOp.from_table(self.to_table()) @classmethod def from_table( - cls, table: "Table", _namespace_metadata: Optional["MetaData"] = None - ) -> "DropTableOp": + cls, table: Table, _namespace_metadata: Optional[MetaData] = None + ) -> DropTableOp: return cls( table.name, schema=table.schema, @@ -1304,8 +1304,8 @@ def from_table( ) def to_table( - self, migration_context: Optional["MigrationContext"] = None - ) -> "Table": + self, migration_context: Optional[MigrationContext] = None + ) -> Table: if self._reverse: cols_and_constraints = self._reverse.columns else: @@ -1329,7 +1329,7 @@ def to_table( @classmethod def drop_table( cls, - operations: "Operations", + operations: Operations, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1377,17 +1377,17 @@ def __init__( new_table_name: str, schema: Optional[str] = None, ) -> None: - super(RenameTableOp, self).__init__(old_table_name, schema=schema) + super().__init__(old_table_name, schema=schema) self.new_table_name = new_table_name @classmethod def rename_table( cls, - operations: "Operations", + operations: Operations, old_table_name: str, new_table_name: str, schema: Optional[str] = None, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Emit an ALTER TABLE to rename a table. :param old_table_name: old name. @@ -1424,12 +1424,12 @@ def __init__( @classmethod def create_table_comment( cls, - operations: "Operations", + operations: Operations, table_name: str, comment: Optional[str], existing_comment: None = None, schema: Optional[str] = None, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Emit a COMMENT ON operation to set the comment for a table. .. versionadded:: 1.0.6 @@ -1534,11 +1534,11 @@ def __init__( @classmethod def drop_table_comment( cls, - operations: "Operations", + operations: Operations, table_name: str, existing_comment: Optional[str] = None, schema: Optional[str] = None, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue a "drop table comment" operation to remove an existing comment set on a table. @@ -1609,13 +1609,13 @@ def __init__( existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, modify_nullable: Optional[bool] = None, - modify_comment: Optional[Union[str, "Literal[False]"]] = False, + modify_comment: Optional[Union[str, Literal[False]]] = False, modify_server_default: Any = False, modify_name: Optional[str] = None, modify_type: Optional[Any] = None, **kw: Any, ) -> None: - super(AlterColumnOp, self).__init__(table_name, schema=schema) + super().__init__(table_name, schema=schema) self.column_name = column_name self.existing_type = existing_type self.existing_server_default = existing_server_default @@ -1723,7 +1723,7 @@ def has_changes(self) -> bool: else: return False - def reverse(self) -> "AlterColumnOp": + def reverse(self) -> AlterColumnOp: kw = self.kw.copy() kw["existing_type"] = self.existing_type @@ -1740,11 +1740,11 @@ def reverse(self) -> "AlterColumnOp": kw["modify_comment"] = self.modify_comment # TODO: make this a little simpler - all_keys = set( + all_keys = { m.group(1) for m in [re.match(r"^(?:existing_|modify_)(.+)$", k) for k in kw] if m - ) + } for k in all_keys: if "modify_%s" % k in kw: @@ -1763,21 +1763,19 @@ def alter_column( table_name: str, column_name: str, nullable: Optional[bool] = None, - comment: Optional[Union[str, "Literal[False]"]] = False, + comment: Optional[Union[str, Literal[False]]] = False, server_default: Any = False, new_column_name: Optional[str] = None, - type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None, - existing_type: Optional[ - Union["TypeEngine", Type["TypeEngine"]] - ] = None, + type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, existing_server_default: Optional[ - Union[str, bool, "Identity", "Computed"] + Union[str, bool, Identity, Computed] ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, schema: Optional[str] = None, **kw: Any, - ) -> Optional["Table"]: + ) -> Optional[Table]: r"""Issue an "alter column" instruction using the current migration context. @@ -1891,20 +1889,18 @@ def batch_alter_column( operations: BatchOperations, column_name: str, nullable: Optional[bool] = None, - comment: Union[str, "Literal[False]"] = False, - server_default: Union["Function", bool] = False, + comment: Union[str, Literal[False]] = False, + server_default: Union[Function, bool] = False, new_column_name: Optional[str] = None, - type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None, - existing_type: Optional[ - Union["TypeEngine", Type["TypeEngine"]] - ] = None, + type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, existing_server_default: bool = False, existing_nullable: None = None, existing_comment: None = None, insert_before: None = None, insert_after: None = None, **kw: Any, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue an "alter column" instruction using the current batch migration context. @@ -1958,29 +1954,29 @@ class AddColumnOp(AlterTableOp): def __init__( self, table_name: str, - column: "Column", + column: Column, schema: Optional[str] = None, **kw: Any, ) -> None: - super(AddColumnOp, self).__init__(table_name, schema=schema) + super().__init__(table_name, schema=schema) self.column = column self.kw = kw - def reverse(self) -> "DropColumnOp": + def reverse(self) -> DropColumnOp: return DropColumnOp.from_column_and_tablename( self.schema, self.table_name, self.column ) def to_diff_tuple( self, - ) -> Tuple[str, Optional[str], str, "Column"]: + ) -> Tuple[str, Optional[str], str, Column]: return ("add_column", self.schema, self.table_name, self.column) - def to_column(self) -> "Column": + def to_column(self) -> Column: return self.column @classmethod - def from_column(cls, col: "Column") -> "AddColumnOp": + def from_column(cls, col: Column) -> AddColumnOp: return cls(col.table.name, col, schema=col.table.schema) @classmethod @@ -1988,18 +1984,18 @@ def from_column_and_tablename( cls, schema: Optional[str], tname: str, - col: "Column", - ) -> "AddColumnOp": + col: Column, + ) -> AddColumnOp: return cls(tname, col, schema=schema) @classmethod def add_column( cls, - operations: "Operations", + operations: Operations, table_name: str, - column: "Column", + column: Column, schema: Optional[str] = None, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue an "add column" instruction using the current migration context. @@ -2055,11 +2051,11 @@ def add_column( @classmethod def batch_add_column( cls, - operations: "BatchOperations", - column: "Column", + operations: BatchOperations, + column: Column, insert_before: Optional[str] = None, insert_after: Optional[str] = None, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue an "add column" instruction using the current batch migration context. @@ -2094,17 +2090,17 @@ def __init__( table_name: str, column_name: str, schema: Optional[str] = None, - _reverse: Optional["AddColumnOp"] = None, + _reverse: Optional[AddColumnOp] = None, **kw: Any, ) -> None: - super(DropColumnOp, self).__init__(table_name, schema=schema) + super().__init__(table_name, schema=schema) self.column_name = column_name self.kw = kw self._reverse = _reverse def to_diff_tuple( self, - ) -> Tuple[str, Optional[str], str, "Column"]: + ) -> Tuple[str, Optional[str], str, Column]: return ( "remove_column", self.schema, @@ -2112,7 +2108,7 @@ def to_diff_tuple( self.to_column(), ) - def reverse(self) -> "AddColumnOp": + def reverse(self) -> AddColumnOp: if self._reverse is None: raise ValueError( "operation is not reversible; " @@ -2128,8 +2124,8 @@ def from_column_and_tablename( cls, schema: Optional[str], tname: str, - col: "Column", - ) -> "DropColumnOp": + col: Column, + ) -> DropColumnOp: return cls( tname, col.name, @@ -2138,8 +2134,8 @@ def from_column_and_tablename( ) def to_column( - self, migration_context: Optional["MigrationContext"] = None - ) -> "Column": + self, migration_context: Optional[MigrationContext] = None + ) -> Column: if self._reverse is not None: return self._reverse.column schema_obj = schemaobj.SchemaObjects(migration_context) @@ -2148,12 +2144,12 @@ def to_column( @classmethod def drop_column( cls, - operations: "Operations", + operations: Operations, table_name: str, column_name: str, schema: Optional[str] = None, **kw: Any, - ) -> Optional["Table"]: + ) -> Optional[Table]: """Issue a "drop column" instruction using the current migration context. @@ -2196,8 +2192,8 @@ def drop_column( @classmethod def batch_drop_column( - cls, operations: "BatchOperations", column_name: str, **kw: Any - ) -> Optional["Table"]: + cls, operations: BatchOperations, column_name: str, **kw: Any + ) -> Optional[Table]: """Issue a "drop column" instruction using the current batch migration context. @@ -2221,7 +2217,7 @@ class BulkInsertOp(MigrateOperation): def __init__( self, - table: Union["Table", "TableClause"], + table: Union[Table, TableClause], rows: List[dict], multiinsert: bool = True, ) -> None: @@ -2233,7 +2229,7 @@ def __init__( def bulk_insert( cls, operations: Operations, - table: Union["Table", "TableClause"], + table: Union[Table, TableClause], rows: List[dict], multiinsert: bool = True, ) -> None: @@ -2322,7 +2318,7 @@ class ExecuteSQLOp(MigrateOperation): def __init__( self, - sqltext: Union["Update", str, "Insert", "TextClause"], + sqltext: Union[Update, str, Insert, TextClause], execution_options: None = None, ) -> None: self.sqltext = sqltext @@ -2332,9 +2328,9 @@ def __init__( def execute( cls, operations: Operations, - sqltext: Union[str, "TextClause", "Update"], + sqltext: Union[str, TextClause, Update], execution_options: None = None, - ) -> Optional["Table"]: + ) -> Optional[Table]: r"""Execute the given SQL using the current migration context. The given SQL can be a plain string, e.g.:: @@ -2434,12 +2430,11 @@ def as_diffs(self) -> Any: @classmethod def _ops_as_diffs( - cls, migrations: "OpContainer" + cls, migrations: OpContainer ) -> Iterator[Tuple[Any, ...]]: for op in migrations.ops: if hasattr(op, "ops"): - for sub_op in cls._ops_as_diffs(cast("OpContainer", op)): - yield sub_op + yield from cls._ops_as_diffs(cast("OpContainer", op)) else: yield op.to_diff_tuple() @@ -2453,11 +2448,11 @@ def __init__( ops: Sequence[MigrateOperation], schema: Optional[str] = None, ) -> None: - super(ModifyTableOps, self).__init__(ops) + super().__init__(ops) self.table_name = table_name self.schema = schema - def reverse(self) -> "ModifyTableOps": + def reverse(self) -> ModifyTableOps: return ModifyTableOps( self.table_name, ops=list(reversed([op.reverse() for op in self.ops])), @@ -2480,16 +2475,16 @@ def __init__( ops: Sequence[MigrateOperation] = (), upgrade_token: str = "upgrades", ) -> None: - super(UpgradeOps, self).__init__(ops=ops) + super().__init__(ops=ops) self.upgrade_token = upgrade_token - def reverse_into(self, downgrade_ops: "DowngradeOps") -> "DowngradeOps": + def reverse_into(self, downgrade_ops: DowngradeOps) -> DowngradeOps: downgrade_ops.ops[:] = list( # type:ignore[index] reversed([op.reverse() for op in self.ops]) ) return downgrade_ops - def reverse(self) -> "DowngradeOps": + def reverse(self) -> DowngradeOps: return self.reverse_into(DowngradeOps(ops=[])) @@ -2508,7 +2503,7 @@ def __init__( ops: Sequence[MigrateOperation] = (), downgrade_token: str = "downgrades", ) -> None: - super(DowngradeOps, self).__init__(ops=ops) + super().__init__(ops=ops) self.downgrade_token = downgrade_token def reverse(self): @@ -2546,8 +2541,8 @@ class MigrationScript(MigrateOperation): def __init__( self, rev_id: Optional[str], - upgrade_ops: "UpgradeOps", - downgrade_ops: "DowngradeOps", + upgrade_ops: UpgradeOps, + downgrade_ops: DowngradeOps, message: Optional[str] = None, imports: Set[str] = set(), head: Optional[str] = None, @@ -2618,7 +2613,7 @@ def downgrade_ops(self, downgrade_ops): assert isinstance(elem, DowngradeOps) @property - def upgrade_ops_list(self) -> List["UpgradeOps"]: + def upgrade_ops_list(self) -> List[UpgradeOps]: """A list of :class:`.UpgradeOps` instances. This is used in place of the :attr:`.MigrationScript.upgrade_ops` @@ -2629,7 +2624,7 @@ def upgrade_ops_list(self) -> List["UpgradeOps"]: return self._upgrade_ops @property - def downgrade_ops_list(self) -> List["DowngradeOps"]: + def downgrade_ops_list(self) -> List[DowngradeOps]: """A list of :class:`.DowngradeOps` instances. This is used in place of the :attr:`.MigrationScript.downgrade_ops` diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py index 6c6f9714..dfda8bbe 100644 --- a/alembic/operations/schemaobj.py +++ b/alembic/operations/schemaobj.py @@ -36,7 +36,7 @@ class SchemaObjects: def __init__( - self, migration_context: Optional["MigrationContext"] = None + self, migration_context: Optional[MigrationContext] = None ) -> None: self.migration_context = migration_context @@ -47,7 +47,7 @@ def primary_key_constraint( cols: Sequence[str], schema: Optional[str] = None, **dialect_kw, - ) -> "PrimaryKeyConstraint": + ) -> PrimaryKeyConstraint: m = self.metadata() columns = [sa_schema.Column(n, NULLTYPE) for n in cols] t = sa_schema.Table(table_name, m, *columns, schema=schema) @@ -71,7 +71,7 @@ def foreign_key_constraint( initially: Optional[str] = None, match: Optional[str] = None, **dialect_kw, - ) -> "ForeignKeyConstraint": + ) -> ForeignKeyConstraint: m = self.metadata() if source == referent and source_schema == referent_schema: t1_cols = local_cols + remote_cols @@ -120,7 +120,7 @@ def unique_constraint( local_cols: Sequence[str], schema: Optional[str] = None, **kw, - ) -> "UniqueConstraint": + ) -> UniqueConstraint: t = sa_schema.Table( source, self.metadata(), @@ -138,10 +138,10 @@ def check_constraint( self, name: Optional[str], source: str, - condition: Union[str, "TextClause", "ColumnElement[Any]"], + condition: Union[str, TextClause, ColumnElement[Any]], schema: Optional[str] = None, **kw, - ) -> Union["CheckConstraint"]: + ) -> Union[CheckConstraint]: t = sa_schema.Table( source, self.metadata(), @@ -182,7 +182,7 @@ def generic_constraint( t.append_constraint(const) return const - def metadata(self) -> "MetaData": + def metadata(self) -> MetaData: kw = {} if ( self.migration_context is not None @@ -193,7 +193,7 @@ def metadata(self) -> "MetaData": kw["naming_convention"] = mt.naming_convention return sa_schema.MetaData(**kw) - def table(self, name: str, *columns, **kw) -> "Table": + def table(self, name: str, *columns, **kw) -> Table: m = self.metadata() cols = [ @@ -230,17 +230,17 @@ def table(self, name: str, *columns, **kw) -> "Table": self._ensure_table_for_fk(m, f) return t - def column(self, name: str, type_: "TypeEngine", **kw) -> "Column": + def column(self, name: str, type_: TypeEngine, **kw) -> Column: return sa_schema.Column(name, type_, **kw) def index( self, name: str, tablename: Optional[str], - columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]], + columns: Sequence[Union[str, TextClause, ColumnElement[Any]]], schema: Optional[str] = None, **kw, - ) -> "Index": + ) -> Index: t = sa_schema.Table( tablename or "no_table", self.metadata(), @@ -264,9 +264,7 @@ def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]: sname = None return (sname, tname) - def _ensure_table_for_fk( - self, metadata: "MetaData", fk: "ForeignKey" - ) -> None: + def _ensure_table_for_fk(self, metadata: MetaData, fk: ForeignKey) -> None: """create a placeholder Table object for the referent of a ForeignKey. diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index 6dbbcc31..44dcd72d 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -99,14 +99,14 @@ def my_function(rev, context): """ - _migration_context: Optional["MigrationContext"] = None + _migration_context: Optional[MigrationContext] = None - config: "Config" = None # type:ignore[assignment] + config: Config = None # type:ignore[assignment] """An instance of :class:`.Config` representing the configuration file contents as well as other variables set programmatically within it.""" - script: "ScriptDirectory" = None # type:ignore[assignment] + script: ScriptDirectory = None # type:ignore[assignment] """An instance of :class:`.ScriptDirectory` which provides programmatic access to version files within the ``versions/`` directory. diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 677d0c74..95eb82a4 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -50,11 +50,11 @@ class _ProxyTransaction: - def __init__(self, migration_context: "MigrationContext") -> None: + def __init__(self, migration_context: MigrationContext) -> None: self.migration_context = migration_context @property - def _proxied_transaction(self) -> Optional["Transaction"]: + def _proxied_transaction(self) -> Optional[Transaction]: return self.migration_context._transaction def rollback(self) -> None: @@ -69,7 +69,7 @@ def commit(self) -> None: t.commit() self.migration_context._transaction = None - def __enter__(self) -> "_ProxyTransaction": + def __enter__(self) -> _ProxyTransaction: return self def __exit__(self, type_: None, value: None, traceback: None) -> None: @@ -127,22 +127,22 @@ class MigrationContext: def __init__( self, - dialect: "Dialect", - connection: Optional["Connection"], + dialect: Dialect, + connection: Optional[Connection], opts: Dict[str, Any], - environment_context: Optional["EnvironmentContext"] = None, + environment_context: Optional[EnvironmentContext] = None, ) -> None: self.environment_context = environment_context self.opts = opts self.dialect = dialect - self.script: Optional["ScriptDirectory"] = opts.get("script") + self.script: Optional[ScriptDirectory] = opts.get("script") as_sql: bool = opts.get("as_sql", False) transactional_ddl = opts.get("transactional_ddl") self._transaction_per_migration = opts.get( "transaction_per_migration", False ) self.on_version_apply_callbacks = opts.get("on_version_apply", ()) - self._transaction: Optional["Transaction"] = None + self._transaction: Optional[Transaction] = None if as_sql: self.connection = cast( @@ -215,14 +215,14 @@ def __init__( @classmethod def configure( cls, - connection: Optional["Connection"] = None, + connection: Optional[Connection] = None, url: Optional[str] = None, dialect_name: Optional[str] = None, - dialect: Optional["Dialect"] = None, - environment_context: Optional["EnvironmentContext"] = None, + dialect: Optional[Dialect] = None, + environment_context: Optional[EnvironmentContext] = None, dialect_opts: Optional[Dict[str, str]] = None, opts: Optional[Any] = None, - ) -> "MigrationContext": + ) -> MigrationContext: """Create a new :class:`.MigrationContext`. This is a factory method usually called @@ -366,7 +366,7 @@ def upgrade(): def begin_transaction( self, _per_migration: bool = False - ) -> Union["_ProxyTransaction", ContextManager]: + ) -> Union[_ProxyTransaction, ContextManager]: """Begin a logical transaction for migration operations. This method is used within an ``env.py`` script to demarcate where @@ -552,9 +552,7 @@ def _has_version_table(self) -> bool: self.connection, self.version_table, self.version_table_schema ) - def stamp( - self, script_directory: "ScriptDirectory", revision: str - ) -> None: + def stamp(self, script_directory: ScriptDirectory, revision: str) -> None: """Stamp the version table with a specific revision. This method calculates those branches to which the given revision @@ -653,7 +651,7 @@ def _in_connection_transaction(self) -> bool: def execute( self, - sql: Union["ClauseElement", str], + sql: Union[ClauseElement, str], execution_options: Optional[dict] = None, ) -> None: """Execute a SQL construct or string statement. @@ -667,15 +665,15 @@ def execute( self.impl._exec(sql, execution_options) def _stdout_connection( - self, connection: Optional["Connection"] - ) -> "MockConnection": + self, connection: Optional[Connection] + ) -> MockConnection: def dump(construct, *multiparams, **params): self.impl._exec(construct) return MockEngineStrategy.MockConnection(self.dialect, dump) @property - def bind(self) -> Optional["Connection"]: + def bind(self) -> Optional[Connection]: """Return the current "bind". In online mode, this is an instance of @@ -696,7 +694,7 @@ def bind(self) -> Optional["Connection"]: return self.connection @property - def config(self) -> Optional["Config"]: + def config(self) -> Optional[Config]: """Return the :class:`.Config` used by the current environment, if any.""" @@ -706,7 +704,7 @@ def config(self) -> Optional["Config"]: return None def _compare_type( - self, inspector_column: "Column", metadata_column: "Column" + self, inspector_column: Column, metadata_column: Column ) -> bool: if self._user_compare_type is False: return False @@ -726,8 +724,8 @@ def _compare_type( def _compare_server_default( self, - inspector_column: "Column", - metadata_column: "Column", + inspector_column: Column, + metadata_column: Column, rendered_metadata_default: Optional[str], rendered_column_default: Optional[str], ) -> bool: @@ -756,7 +754,7 @@ def _compare_server_default( class HeadMaintainer: - def __init__(self, context: "MigrationContext", heads: Any) -> None: + def __init__(self, context: MigrationContext, heads: Any) -> None: self.context = context self.heads = set(heads) @@ -820,7 +818,7 @@ def _update_version(self, from_: str, to_: str) -> None: % (from_, to_, self.context.version_table, ret.rowcount) ) - def update_to_step(self, step: Union["RevisionStep", "StampStep"]) -> None: + def update_to_step(self, step: Union[RevisionStep, StampStep]) -> None: if step.should_delete_branch(self.heads): vers = step.delete_version_num log.debug("branch delete %s", vers) @@ -916,12 +914,12 @@ class MigrationInfo: from dependencies. """ - revision_map: "RevisionMap" + revision_map: RevisionMap """The revision map inside of which this operation occurs.""" def __init__( self, - revision_map: "RevisionMap", + revision_map: RevisionMap, is_upgrade: bool, is_stamp: bool, up_revisions: Union[str, Tuple[str, ...]], @@ -1010,14 +1008,14 @@ def name(self) -> str: @classmethod def upgrade_from_script( - cls, revision_map: "RevisionMap", script: "Script" - ) -> "RevisionStep": + cls, revision_map: RevisionMap, script: Script + ) -> RevisionStep: return RevisionStep(revision_map, script, True) @classmethod def downgrade_from_script( - cls, revision_map: "RevisionMap", script: "Script" - ) -> "RevisionStep": + cls, revision_map: RevisionMap, script: Script + ) -> RevisionStep: return RevisionStep(revision_map, script, False) @property @@ -1046,7 +1044,7 @@ def __str__(self): class RevisionStep(MigrationStep): def __init__( - self, revision_map: "RevisionMap", revision: "Script", is_upgrade: bool + self, revision_map: RevisionMap, revision: Script, is_upgrade: bool ) -> None: self.revision_map = revision_map self.revision = revision @@ -1142,12 +1140,12 @@ def merge_branch_idents( other_heads = set(heads).difference(self.from_revisions) if other_heads: - ancestors = set( + ancestors = { r.revision for r in self.revision_map._get_ancestor_nodes( self.revision_map.get_revisions(other_heads), check=False ) - ) + } from_revisions = list( set(self.from_revisions).difference(ancestors) ) @@ -1164,12 +1162,12 @@ def merge_branch_idents( def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]: other_heads = set(heads).difference([self.revision.revision]) if other_heads: - ancestors = set( + ancestors = { r.revision for r in self.revision_map._get_ancestor_nodes( self.revision_map.get_revisions(other_heads), check=False ) - ) + } return tuple(set(self.to_revisions).difference(ancestors)) else: return self.to_revisions @@ -1253,7 +1251,7 @@ def insert_version_num(self) -> str: return self.revision.revision @property - def info(self) -> "MigrationInfo": + def info(self) -> MigrationInfo: return MigrationInfo( revision_map=self.revision_map, up_revisions=self.revision.revision, @@ -1270,7 +1268,7 @@ def __init__( to_: Optional[Union[str, Collection[str]]], is_upgrade: bool, branch_move: bool, - revision_map: Optional["RevisionMap"] = None, + revision_map: Optional[RevisionMap] = None, ) -> None: self.from_: Tuple[str, ...] = util.to_tuple(from_, default=()) self.to_: Tuple[str, ...] = util.to_tuple(to_, default=()) @@ -1368,7 +1366,7 @@ def should_unmerge_branches(self, heads: Set[str]) -> bool: return len(self.to_) > 1 @property - def info(self) -> "MigrationInfo": + def info(self) -> MigrationInfo: up, down = ( (self.to_, self.from_) if self.is_upgrade diff --git a/alembic/script/base.py b/alembic/script/base.py index cae0a2bc..3c09cef7 100644 --- a/alembic/script/base.py +++ b/alembic/script/base.py @@ -463,7 +463,7 @@ def _downgrade_revs( def _stamp_revs( self, revision: _RevIdType, heads: _RevIdType - ) -> List["StampStep"]: + ) -> List[StampStep]: with self._catch_revision_errors( multiple_heads="Multiple heads are present; please specify a " "single target revision" @@ -592,7 +592,7 @@ def _ensure_directory(self, path: str) -> None: if not os.path.exists(path): util.status("Creating directory %s" % path, os.makedirs, path) - def _generate_create_date(self) -> "datetime.datetime": + def _generate_create_date(self) -> datetime.datetime: if self.timezone is not None: if tz is None: raise util.CommandError( @@ -769,7 +769,7 @@ def _rev_path( path: str, rev_id: str, message: Optional[str], - create_date: "datetime.datetime", + create_date: datetime.datetime, ) -> str: epoch = int(create_date.timestamp()) slug = "_".join(_slug_re.findall(message or "")).lower() @@ -804,7 +804,7 @@ class Script(revision.Revision): def __init__(self, module: ModuleType, rev_id: str, path: str): self.module = module self.path = path - super(Script, self).__init__( + super().__init__( rev_id, module.down_revision, # type: ignore[attr-defined] branch_labels=util.to_tuple( @@ -964,7 +964,7 @@ def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]: # in the immediate path paths = os.listdir(path) - names = set(fname.split(".")[0] for fname in paths) + names = {fname.split(".")[0] for fname in paths} # look for __pycache__ if os.path.exists(os.path.join(path, "__pycache__")): diff --git a/alembic/script/revision.py b/alembic/script/revision.py index 6e25891d..39152969 100644 --- a/alembic/script/revision.py +++ b/alembic/script/revision.py @@ -51,7 +51,7 @@ def __init__( ) -> None: self.lower = lower self.upper = upper - super(RangeNotAncestorError, self).__init__( + super().__init__( "Revision %s is not an ancestor of revision %s" % (lower or "base", upper or "base") ) @@ -61,7 +61,7 @@ class MultipleHeads(RevisionError): def __init__(self, heads: Sequence[str], argument: Optional[str]) -> None: self.heads = heads self.argument = argument - super(MultipleHeads, self).__init__( + super().__init__( "Multiple heads are present for given argument '%s'; " "%s" % (argument, ", ".join(heads)) ) @@ -69,7 +69,7 @@ def __init__(self, heads: Sequence[str], argument: Optional[str]) -> None: class ResolutionError(RevisionError): def __init__(self, message: str, argument: str) -> None: - super(ResolutionError, self).__init__(message) + super().__init__(message) self.argument = argument @@ -78,7 +78,7 @@ class CycleDetected(RevisionError): def __init__(self, revisions: Sequence[str]) -> None: self.revisions = revisions - super(CycleDetected, self).__init__( + super().__init__( "%s is detected in revisions (%s)" % (self.kind, ", ".join(revisions)) ) @@ -88,21 +88,21 @@ class DependencyCycleDetected(CycleDetected): kind = "Dependency cycle" def __init__(self, revisions: Sequence[str]) -> None: - super(DependencyCycleDetected, self).__init__(revisions) + super().__init__(revisions) class LoopDetected(CycleDetected): kind = "Self-loop" def __init__(self, revision: str) -> None: - super(LoopDetected, self).__init__([revision]) + super().__init__([revision]) class DependencyLoopDetected(DependencyCycleDetected, LoopDetected): kind = "Dependency self-loop" def __init__(self, revision: Sequence[str]) -> None: - super(DependencyLoopDetected, self).__init__(revision) + super().__init__(revision) class RevisionMap: @@ -114,7 +114,7 @@ class RevisionMap: """ - def __init__(self, generator: Callable[[], Iterable["Revision"]]) -> None: + def __init__(self, generator: Callable[[], Iterable[Revision]]) -> None: """Construct a new :class:`.RevisionMap`. :param generator: a zero-arg callable that will generate an iterable @@ -180,10 +180,10 @@ def _revision_map(self) -> _RevisionMapType: # general) map_: _InterimRevisionMapType = sqlautil.OrderedDict() - heads: Set["Revision"] = sqlautil.OrderedSet() - _real_heads: Set["Revision"] = sqlautil.OrderedSet() - bases: Tuple["Revision", ...] = () - _real_bases: Tuple["Revision", ...] = () + heads: Set[Revision] = sqlautil.OrderedSet() + _real_heads: Set[Revision] = sqlautil.OrderedSet() + bases: Tuple[Revision, ...] = () + _real_bases: Tuple[Revision, ...] = () has_branch_labels = set() all_revisions = set() @@ -249,10 +249,10 @@ def _revision_map(self) -> _RevisionMapType: def _detect_cycles( self, rev_map: _InterimRevisionMapType, - heads: Set["Revision"], - bases: Tuple["Revision", ...], - _real_heads: Set["Revision"], - _real_bases: Tuple["Revision", ...], + heads: Set[Revision], + bases: Tuple[Revision, ...], + _real_heads: Set[Revision], + _real_bases: Tuple[Revision, ...], ) -> None: if not rev_map: return @@ -299,7 +299,7 @@ def _detect_cycles( raise DependencyCycleDetected(sorted(deleted_revs)) def _map_branch_labels( - self, revisions: Collection["Revision"], map_: _RevisionMapType + self, revisions: Collection[Revision], map_: _RevisionMapType ) -> None: for revision in revisions: if revision.branch_labels: @@ -320,7 +320,7 @@ def _map_branch_labels( map_[branch_label] = revision def _add_branches( - self, revisions: Collection["Revision"], map_: _RevisionMapType + self, revisions: Collection[Revision], map_: _RevisionMapType ) -> None: for revision in revisions: if revision.branch_labels: @@ -344,7 +344,7 @@ def _add_branches( break def _add_depends_on( - self, revisions: Collection["Revision"], map_: _RevisionMapType + self, revisions: Collection[Revision], map_: _RevisionMapType ) -> None: """Resolve the 'dependencies' for each revision in a collection in terms of actual revision ids, as opposed to branch labels or other @@ -367,7 +367,7 @@ def _add_depends_on( revision._resolved_dependencies = () def _normalize_depends_on( - self, revisions: Collection["Revision"], map_: _RevisionMapType + self, revisions: Collection[Revision], map_: _RevisionMapType ) -> None: """Create a collection of "dependencies" that omits dependencies that are already ancestor nodes for each revision in a given @@ -406,9 +406,7 @@ def _normalize_depends_on( else: revision._normalized_resolved_dependencies = () - def add_revision( - self, revision: "Revision", _replace: bool = False - ) -> None: + def add_revision(self, revision: Revision, _replace: bool = False) -> None: """add a single revision to an existing map. This method is for single-revision use cases, it's not @@ -602,7 +600,7 @@ def _revision_for_ident( else: branch_rev = None - revision: Union[Optional[Revision], "Literal[False]"] + revision: Union[Optional[Revision], Literal[False]] try: revision = self._revision_map[resolved_id] except KeyError: diff --git a/alembic/testing/env.py b/alembic/testing/env.py index 13d29ff9..3d42f1cb 100644 --- a/alembic/testing/env.py +++ b/alembic/testing/env.py @@ -1,4 +1,3 @@ -#!coding: utf-8 import importlib.machinery import os import shutil diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py index 26427507..ef1c3bba 100644 --- a/alembic/testing/fixtures.py +++ b/alembic/testing/fixtures.py @@ -1,4 +1,3 @@ -# coding: utf-8 from __future__ import annotations import configparser diff --git a/alembic/testing/suite/_autogen_fixtures.py b/alembic/testing/suite/_autogen_fixtures.py index f97dd753..e09fbfe5 100644 --- a/alembic/testing/suite/_autogen_fixtures.py +++ b/alembic/testing/suite/_autogen_fixtures.py @@ -208,8 +208,7 @@ class AutogenTest(_ComparesFKs): def _flatten_diffs(self, diffs): for d in diffs: if isinstance(d, list): - for fd in self._flatten_diffs(d): - yield fd + yield from self._flatten_diffs(d) else: yield d diff --git a/alembic/testing/warnings.py b/alembic/testing/warnings.py index d809dfe2..86d45a0d 100644 --- a/alembic/testing/warnings.py +++ b/alembic/testing/warnings.py @@ -5,7 +5,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from __future__ import absolute_import import warnings diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py index ff2687ce..8203358e 100644 --- a/alembic/util/langhelpers.py +++ b/alembic/util/langhelpers.py @@ -30,7 +30,7 @@ class _ModuleClsMeta(type): def __setattr__(cls, key: str, value: Callable) -> None: - super(_ModuleClsMeta, cls).__setattr__(key, value) + super().__setattr__(key, value) cls._update_module_proxies(key) # type: ignore @@ -270,7 +270,7 @@ def go(*arg, **kw): else: return fn_or_list # type: ignore - def branch(self) -> "Dispatcher": + def branch(self) -> Dispatcher: """Return a copy of this dispatcher that is independently writable.""" diff --git a/alembic/util/messaging.py b/alembic/util/messaging.py index 54dc04fd..7d9d090a 100644 --- a/alembic/util/messaging.py +++ b/alembic/util/messaging.py @@ -30,7 +30,7 @@ _h, TERMWIDTH, _hp, _wp = struct.unpack("HHHH", ioctl) if TERMWIDTH <= 0: # can occur if running in emacs pseudo-tty TERMWIDTH = None -except (ImportError, IOError): +except (ImportError, OSError): TERMWIDTH = None @@ -42,7 +42,7 @@ def write_outstream(stream: TextIO, *text) -> None: t = t.decode(encoding) try: stream.write(t) - except IOError: + except OSError: # suppress "broken pipe" errors. # no known way to handle this on Python 3 however # as the exception is "ignored" (noisily) in TextIOWrapper. @@ -92,7 +92,7 @@ def msg(msg: str, newline: bool = True, flush: bool = False) -> None: sys.stdout.flush() -def format_as_comma(value: Optional[Union[str, "Iterable[str]"]]) -> str: +def format_as_comma(value: Optional[Union[str, Iterable[str]]]) -> str: if value is None: return "" elif isinstance(value, str): diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 8046c9c4..23255be3 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -108,7 +108,7 @@ def _safe_int(value: str) -> Union[int, str]: @contextlib.contextmanager def _ensure_scope_for_ddl( - connection: Optional["Connection"], + connection: Optional[Connection], ) -> Iterator[None]: try: in_transaction = connection.in_transaction # type: ignore[union-attr] @@ -137,8 +137,8 @@ def url_render_as_string(url, hide_password=True): def _safe_begin_connection_transaction( - connection: "Connection", -) -> "Transaction": + connection: Connection, +) -> Transaction: transaction = _get_connection_transaction(connection) if transaction: return transaction @@ -147,7 +147,7 @@ def _safe_begin_connection_transaction( def _safe_commit_connection_transaction( - connection: "Connection", + connection: Connection, ) -> None: transaction = _get_connection_transaction(connection) if transaction: @@ -155,14 +155,14 @@ def _safe_commit_connection_transaction( def _safe_rollback_connection_transaction( - connection: "Connection", + connection: Connection, ) -> None: transaction = _get_connection_transaction(connection) if transaction: transaction.rollback() -def _get_connection_in_transaction(connection: Optional["Connection"]) -> bool: +def _get_connection_in_transaction(connection: Optional[Connection]) -> bool: try: in_transaction = connection.in_transaction # type: ignore except AttributeError: @@ -184,8 +184,8 @@ def _copy(schema_item: _CE, **kw) -> _CE: def _get_connection_transaction( - connection: "Connection", -) -> Optional["Transaction"]: + connection: Connection, +) -> Optional[Transaction]: if sqla_14: return connection.get_transaction() else: @@ -201,7 +201,7 @@ def _create_url(*arg, **kw) -> url.URL: def _connectable_has_table( - connectable: "Connection", tablename: str, schemaname: Union[str, None] + connectable: Connection, tablename: str, schemaname: Union[str, None] ) -> bool: if sqla_14: return inspect(connectable).has_table(tablename, schemaname) @@ -244,7 +244,7 @@ def _server_default_is_identity(*server_default) -> bool: return any(isinstance(sd, Identity) for sd in server_default) -def _table_for_constraint(constraint: "Constraint") -> "Table": +def _table_for_constraint(constraint: Constraint) -> Table: if isinstance(constraint, ForeignKeyConstraint): table = constraint.parent assert table is not None @@ -263,7 +263,7 @@ def _columns_for_constraint(constraint): def _reflect_table( - inspector: "Inspector", table: "Table", include_cols: None + inspector: Inspector, table: Table, include_cols: None ) -> None: if sqla_14: return inspector.reflect_table(table, None) @@ -326,7 +326,7 @@ def _fk_spec(constraint): ) -def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool: +def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool: spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined] tokens = spec.split(".") tokens.pop(-1) # colname @@ -335,7 +335,7 @@ def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool: return tablekey == constraint.parent.key -def _is_type_bound(constraint: "Constraint") -> bool: +def _is_type_bound(constraint: Constraint) -> bool: # this deals with SQLAlchemy #3260, don't copy CHECK constraints # that will be generated by the type. # new feature added for #3260 @@ -351,7 +351,7 @@ def _find_columns(clause): def _remove_column_from_collection( - collection: "ColumnCollection", column: Union["Column", "ColumnClause"] + collection: ColumnCollection, column: Union[Column, ColumnClause] ) -> None: """remove a column from a ColumnCollection.""" @@ -369,8 +369,8 @@ def _remove_column_from_collection( def _textual_index_column( - table: "Table", text_: Union[str, "TextClause", "ColumnElement"] -) -> Union["ColumnElement", "Column"]: + table: Table, text_: Union[str, TextClause, ColumnElement] +) -> Union[ColumnElement, Column]: """a workaround for the Index construct's severe lack of flexibility""" if isinstance(text_, str): c = Column(text_, sqltypes.NULLTYPE) @@ -384,7 +384,7 @@ def _textual_index_column( raise ValueError("String or text() construct expected") -def _copy_expression(expression: _CE, target_table: "Table") -> _CE: +def _copy_expression(expression: _CE, target_table: Table) -> _CE: def replace(col): if ( isinstance(col, Column) @@ -423,7 +423,7 @@ class _textual_index_element(sql.ColumnElement): __visit_name__ = "_textual_idx_element" - def __init__(self, table: "Table", text: "TextClause") -> None: + def __init__(self, table: Table, text: TextClause) -> None: self.table = table self.text = text self.key = text.text @@ -436,7 +436,7 @@ def get_children(self): @compiles(_textual_index_element) def _render_textual_index_column( - element: _textual_index_element, compiler: "SQLCompiler", **kw + element: _textual_index_element, compiler: SQLCompiler, **kw ) -> str: return compiler.process(element.text, **kw) @@ -447,7 +447,7 @@ class _literal_bindparam(BindParameter): @compiles(_literal_bindparam) def _render_literal_bindparam( - element: _literal_bindparam, compiler: "SQLCompiler", **kw + element: _literal_bindparam, compiler: SQLCompiler, **kw ) -> str: return compiler.render_literal_bindparam(element, **kw) @@ -460,7 +460,7 @@ def _get_index_column_names(idx): return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)] -def _column_kwargs(col: "Column") -> Mapping: +def _column_kwargs(col: Column) -> Mapping: if sqla_13: return col.kwargs else: @@ -468,7 +468,7 @@ def _column_kwargs(col: "Column") -> Mapping: def _get_constraint_final_name( - constraint: Union["Index", "Constraint"], dialect: Optional["Dialect"] + constraint: Union[Index, Constraint], dialect: Optional[Dialect] ) -> Optional[str]: if constraint.name is None: return None @@ -508,7 +508,7 @@ def _get_constraint_final_name( def _constraint_is_named( - constraint: Union["Constraint", "Index"], dialect: Optional["Dialect"] + constraint: Union[Constraint, Index], dialect: Optional[Dialect] ) -> bool: if sqla_14: if constraint.name is None: @@ -522,7 +522,7 @@ def _constraint_is_named( return constraint.name is not None -def _is_mariadb(mysql_dialect: "Dialect") -> bool: +def _is_mariadb(mysql_dialect: Dialect) -> bool: if sqla_14: return mysql_dialect.is_mariadb # type: ignore[attr-defined] else: @@ -536,7 +536,7 @@ def _mariadb_normalized_version_info(mysql_dialect): return mysql_dialect._mariadb_normalized_version_info -def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert": +def _insert_inline(table: Union[TableClause, Table]) -> Insert: if sqla_14: return table.insert().inline() else: @@ -554,5 +554,5 @@ def create_mock_engine(url, executor, **kw): # type: ignore[misc] "postgresql://", strategy="mock", executor=executor ) - def _select(*columns, **kw) -> "Select": # type: ignore[no-redef] + def _select(*columns, **kw) -> Select: # type: ignore[no-redef] return sql.select(list(columns), **kw) # type: ignore[call-overload] diff --git a/tests/test_autogen_composition.py b/tests/test_autogen_composition.py index acd3603b..99e5486f 100644 --- a/tests/test_autogen_composition.py +++ b/tests/test_autogen_composition.py @@ -243,12 +243,10 @@ def render_item(type_, col, autogen_context): autogenerate._render_migration_diffs(self.context, template_args) eq_( set(template_args["imports"].split("\n")), - set( - [ - "from foobar import bat", - "from mypackage import my_special_import", - ] - ), + { + "from foobar import bat", + "from mypackage import my_special_import", + }, ) diff --git a/tests/test_autogen_diffs.py b/tests/test_autogen_diffs.py index ead1a7cd..86b2460c 100644 --- a/tests/test_autogen_diffs.py +++ b/tests/test_autogen_diffs.py @@ -289,7 +289,7 @@ class AutogenDefaultSchemaIsNoneTest(AutogenFixtureTest, TestBase): __only_on__ = "sqlite" def setUp(self): - super(AutogenDefaultSchemaIsNoneTest, self).setUp() + super().setUp() # in SQLAlchemy 1.4, SQLite dialect is setting this name # to "main" as is the actual default schema name for SQLite. @@ -512,13 +512,11 @@ def include_object(obj, name, type_, reflected, compare_to): ) alter_cols = ( - set( - [ - d[2] - for d in self._flatten_diffs(diffs) - if d[0].startswith("modify") - ] - ) + { + d[2] + for d in self._flatten_diffs(diffs) + if d[0].startswith("modify") + } .union( d[3].name for d in self._flatten_diffs(diffs) @@ -530,7 +528,7 @@ def include_object(obj, name, type_, reflected, compare_to): if d[0] == "add_table" ) ) - eq_(alter_cols, set(["user_id", "order", "user"])) + eq_(alter_cols, {"user_id", "order", "user"}) def test_include_name(self): all_names = set() @@ -582,13 +580,11 @@ def include_name(name, type_, parent_names): ) alter_cols = ( - set( - [ - d[2] - for d in self._flatten_diffs(diffs) - if d[0].startswith("modify") - ] - ) + { + d[2] + for d in self._flatten_diffs(diffs) + if d[0].startswith("modify") + } .union( d[3].name for d in self._flatten_diffs(diffs) diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py index fb710991..68a6bd6f 100644 --- a/tests/test_autogen_indexes.py +++ b/tests/test_autogen_indexes.py @@ -552,7 +552,7 @@ def test_unnamed_cols_changed(self): diffs = self._fixture(m1, m2) - diffs = set( + diffs = { ( cmd, isinstance(obj, (UniqueConstraint, Index)) @@ -560,23 +560,21 @@ def test_unnamed_cols_changed(self): else False, ) for cmd, obj in diffs - ) + } if self.reports_unnamed_constraints: if self.reports_unique_constraints_as_indexes: eq_( diffs, - set([("remove_index", True), ("add_constraint", False)]), + {("remove_index", True), ("add_constraint", False)}, ) else: eq_( diffs, - set( - [ - ("remove_constraint", True), - ("add_constraint", False), - ] - ), + { + ("remove_constraint", True), + ("add_constraint", False), + }, ) def test_remove_named_unique_index(self): @@ -594,8 +592,8 @@ def test_remove_named_unique_index(self): diffs = self._fixture(m1, m2) if self.reports_unique_constraints: - diffs = set((cmd, obj.name) for cmd, obj in diffs) - eq_(diffs, set([("remove_index", "xidx")])) + diffs = {(cmd, obj.name) for cmd, obj in diffs} + eq_(diffs, {("remove_index", "xidx")}) else: eq_(diffs, []) @@ -614,11 +612,11 @@ def test_remove_named_unique_constraint(self): diffs = self._fixture(m1, m2) if self.reports_unique_constraints: - diffs = set((cmd, obj.name) for cmd, obj in diffs) + diffs = {(cmd, obj.name) for cmd, obj in diffs} if self.reports_unique_constraints_as_indexes: - eq_(diffs, set([("remove_index", "xidx")])) + eq_(diffs, {("remove_index", "xidx")}) else: - eq_(diffs, set([("remove_constraint", "xidx")])) + eq_(diffs, {("remove_constraint", "xidx")}) else: eq_(diffs, []) @@ -668,9 +666,9 @@ def test_add_uq_ix_on_table_create(self): eq_(diffs[0][0], "add_table") eq_(len(diffs), 2) - assert UniqueConstraint not in set( + assert UniqueConstraint not in { type(c) for c in diffs[0][1].constraints - ) + } eq_(diffs[1][0], "add_index") d_table = diffs[0][1] @@ -1071,9 +1069,7 @@ def test_drop_table_w_indexes(self): eq_(diffs[1][0], "remove_index") eq_(diffs[2][0], "remove_table") - eq_( - set([diffs[0][1].name, diffs[1][1].name]), set(["xy_idx", "y_idx"]) - ) + eq_({diffs[0][1].name, diffs[1][1].name}, {"xy_idx", "y_idx"}) def test_add_ix_on_table_create(self): m1 = MetaData() @@ -1083,9 +1079,9 @@ def test_add_ix_on_table_create(self): eq_(diffs[0][0], "add_table") eq_(len(diffs), 2) - assert UniqueConstraint not in set( + assert UniqueConstraint not in { type(c) for c in diffs[0][1].constraints - ) + } eq_(diffs[1][0], "add_index") eq_(diffs[1][1].unique, False) diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py index 67093284..0a2fc876 100644 --- a/tests/test_autogen_render.py +++ b/tests/test_autogen_render.py @@ -1296,7 +1296,7 @@ def render(type_, obj, context): ) eq_( self.autogen_context.imports, - set(["from mypackage import MySpecialType"]), + {"from mypackage import MySpecialType"}, ) def test_render_modify_type(self): @@ -1833,7 +1833,7 @@ def test_repr_dialect_type(self): ) eq_( self.autogen_context.imports, - set(["from sqlalchemy.dialects import mysql"]), + {"from sqlalchemy.dialects import mysql"}, ) def test_render_server_default_text(self): diff --git a/tests/test_batch.py b/tests/test_batch.py index 2d29f6c6..e0289aa4 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -1553,11 +1553,11 @@ def test_ix_existing(self): insp = inspect(self.conn) eq_( - set( + { (ix["name"], tuple(ix["column_names"])) for ix in insp.get_indexes("t_w_ix") - ), - set([("ix_data", ("data",)), ("ix_thing", ("thing",))]), + }, + {("ix_data", ("data",)), ("ix_thing", ("thing",))}, ) def test_fk_points_to_me_auto(self): @@ -2268,39 +2268,37 @@ def _datetime_server_default_fixture(self): @exclusions.fails() def test_drop_pk_col_readd_pk_col(self): - super(BatchRoundTripMySQLTest, self).test_drop_pk_col_readd_pk_col() + super().test_drop_pk_col_readd_pk_col() @exclusions.fails() def test_drop_pk_col_readd_col_also_pk_const(self): - super( - BatchRoundTripMySQLTest, self - ).test_drop_pk_col_readd_col_also_pk_const() + super().test_drop_pk_col_readd_col_also_pk_const() @exclusions.fails() def test_rename_column_pk(self): - super(BatchRoundTripMySQLTest, self).test_rename_column_pk() + super().test_rename_column_pk() @exclusions.fails() def test_rename_column(self): - super(BatchRoundTripMySQLTest, self).test_rename_column() + super().test_rename_column() @exclusions.fails() def test_change_type(self): - super(BatchRoundTripMySQLTest, self).test_change_type() + super().test_change_type() def test_create_drop_index(self): - super(BatchRoundTripMySQLTest, self).test_create_drop_index() + super().test_create_drop_index() # fails on mariadb 10.2, succeeds on 10.3 @exclusions.fails_if(config.requirements.mysql_check_col_name_change) def test_rename_column_boolean(self): - super(BatchRoundTripMySQLTest, self).test_rename_column_boolean() + super().test_rename_column_boolean() def test_change_type_boolean_to_int(self): - super(BatchRoundTripMySQLTest, self).test_change_type_boolean_to_int() + super().test_change_type_boolean_to_int() def test_change_type_int_to_boolean(self): - super(BatchRoundTripMySQLTest, self).test_change_type_int_to_boolean() + super().test_change_type_int_to_boolean() class BatchRoundTripPostgresqlTest(BatchRoundTripTest): @@ -2327,34 +2325,26 @@ def _datetime_server_default_fixture(self): @exclusions.fails() def test_drop_pk_col_readd_pk_col(self): - super( - BatchRoundTripPostgresqlTest, self - ).test_drop_pk_col_readd_pk_col() + super().test_drop_pk_col_readd_pk_col() @exclusions.fails() def test_drop_pk_col_readd_col_also_pk_const(self): - super( - BatchRoundTripPostgresqlTest, self - ).test_drop_pk_col_readd_col_also_pk_const() + super().test_drop_pk_col_readd_col_also_pk_const() @exclusions.fails() def test_change_type(self): - super(BatchRoundTripPostgresqlTest, self).test_change_type() + super().test_change_type() def test_create_drop_index(self): - super(BatchRoundTripPostgresqlTest, self).test_create_drop_index() + super().test_create_drop_index() @exclusions.fails() def test_change_type_int_to_boolean(self): - super( - BatchRoundTripPostgresqlTest, self - ).test_change_type_int_to_boolean() + super().test_change_type_int_to_boolean() @exclusions.fails() def test_change_type_boolean_to_int(self): - super( - BatchRoundTripPostgresqlTest, self - ).test_change_type_boolean_to_int() + super().test_change_type_boolean_to_int() def test_add_col_table_has_native_boolean(self): self._native_boolean_fixture() diff --git a/tests/test_command.py b/tests/test_command.py index 0c0ce378..e136c4e7 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -224,15 +224,13 @@ def _assert_lines(self, revs): yield - lines = set( - [ - re.match(r"(^.\w)", elem).group(1) - for elem in re.split( - "\n", buf.getvalue().decode("ascii", "replace").strip() - ) - if elem - ] - ) + lines = { + re.match(r"(^.\w)", elem).group(1) + for elem in re.split( + "\n", buf.getvalue().decode("ascii", "replace").strip() + ) + if elem + } eq_(lines, set(revs)) diff --git a/tests/test_config.py b/tests/test_config.py index 7957a1b7..9f3929a7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,3 @@ -#!coding: utf-8 import os import tempfile diff --git a/tests/test_environment.py b/tests/test_environment.py index d6c3a65d..d9c14ca4 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -1,4 +1,3 @@ -#!coding: utf-8 import os import sys diff --git a/tests/test_external_dialect.py b/tests/test_external_dialect.py index 9ddc12f0..de66517e 100644 --- a/tests/test_external_dialect.py +++ b/tests/test_external_dialect.py @@ -65,7 +65,7 @@ def __init__(self, item_type): if isinstance(item_type, type): item_type = item_type() self.item_type = item_type - super(EXT_ARRAY, self).__init__() + super().__init__() class FOOBARTYPE(sqla_types.TypeEngine): @@ -94,12 +94,10 @@ def test_render_type(self): eq_( self.autogen_context.imports, - set( - [ - "from tests.test_external_dialect " - "import custom_dialect_types" - ] - ), + { + "from tests.test_external_dialect " + "import custom_dialect_types" + }, ) def test_external_nested_render_sqla_type(self): @@ -121,12 +119,10 @@ def test_external_nested_render_sqla_type(self): eq_( self.autogen_context.imports, - set( - [ - "from tests.test_external_dialect " - "import custom_dialect_types" - ] - ), + { + "from tests.test_external_dialect " + "import custom_dialect_types" + }, ) def test_external_nested_render_external_type(self): @@ -141,10 +137,8 @@ def test_external_nested_render_external_type(self): eq_( self.autogen_context.imports, - set( - [ - "from tests.test_external_dialect " - "import custom_dialect_types" - ] - ), + { + "from tests.test_external_dialect " + "import custom_dialect_types" + }, ) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index b9be5cb3..6a67e0be 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -838,9 +838,7 @@ def _expect_default(self, c_expected, col, seq=None): insp = inspect(config.db) uo = ops.UpgradeOps(ops=[]) - _compare_tables( - set([(None, "t")]), set([]), insp, uo, self.autogen_context - ) + _compare_tables({(None, "t")}, set(), insp, uo, self.autogen_context) diffs = uo.as_diffs() tab = diffs[0][1] @@ -857,8 +855,8 @@ def _expect_default(self, c_expected, col, seq=None): Table("t", m2, Column("x", BigInteger())) self.autogen_context.metadata = m2 _compare_tables( - set([(None, "t")]), - set([(None, "t")]), + {(None, "t")}, + {(None, "t")}, insp, uo, self.autogen_context, diff --git a/tests/test_script_consumption.py b/tests/test_script_consumption.py index d478ae1b..fa84d7e3 100644 --- a/tests/test_script_consumption.py +++ b/tests/test_script_consumption.py @@ -1,5 +1,3 @@ -# coding: utf-8 - from contextlib import contextmanager import os import re @@ -369,7 +367,7 @@ def test_steps(self): alembic.mock_event_listener = None self._env_file_fixture() with mock.patch("alembic.mock_event_listener", mock.Mock()) as mymock: - super(CallbackEnvironmentTest, self).test_steps() + super().test_steps() calls = mymock.call_args_list assert calls for call in calls: @@ -682,7 +680,7 @@ def test_encode(self): bytes_io=True, output_encoding="utf-8" ) as buf: command.upgrade(self.cfg, self.a, sql=True) - assert "« S’il vous plaît…".encode("utf-8") in buf.getvalue() + assert "« S’il vous plaît…".encode() in buf.getvalue() class VersionNameTemplateTest(TestBase): diff --git a/tests/test_script_production.py b/tests/test_script_production.py index 2cf9052a..bedf545d 100644 --- a/tests/test_script_production.py +++ b/tests/test_script_production.py @@ -1,6 +1,7 @@ import datetime import os import re +from unittest.mock import patch from dateutil import tz import sqlalchemy as sa @@ -36,10 +37,6 @@ from alembic.testing.fixtures import TestBase from alembic.util import CommandError -try: - from unittest.mock import patch -except ImportError: - from mock import patch # noqa env, abc, def_ = None, None, None @@ -62,7 +59,7 @@ def test_steps(self): self._test_008_long_name_configurable() def _test_001_environment(self): - assert_set = set(["env.py", "script.py.mako", "README"]) + assert_set = {"env.py", "script.py.mako", "README"} eq_(assert_set.intersection(os.listdir(env.dir)), assert_set) def _test_002_rev_ids(self): @@ -101,7 +98,7 @@ def _test_005_nextrev(self): ) eq_(script.revision, def_) eq_(script.down_revision, abc) - eq_(env.get_revision(abc).nextrev, set([def_])) + eq_(env.get_revision(abc).nextrev, {def_}) assert script.module.down_revision == abc assert callable(script.module.upgrade) assert callable(script.module.downgrade) @@ -115,7 +112,7 @@ def _test_006_from_clean_env(self): env = staging_env(create=False) abc_rev = env.get_revision(abc) def_rev = env.get_revision(def_) - eq_(abc_rev.nextrev, set([def_])) + eq_(abc_rev.nextrev, {def_}) eq_(abc_rev.revision, abc) eq_(def_rev.down_revision, abc) eq_(env.get_heads(), [def_]) @@ -319,7 +316,7 @@ def test_create_script_splice(self): rev = script.get_revision(rev.revision) eq_(rev.down_revision, self.b) assert "some message" in rev.doc - eq_(set(script.get_heads()), set([rev.revision, self.c])) + eq_(set(script.get_heads()), {rev.revision, self.c}) def test_create_script_missing_splice(self): assert_raises_message( diff --git a/tests/test_version_traversal.py b/tests/test_version_traversal.py index 92413ac0..f7ad4f08 100644 --- a/tests/test_version_traversal.py +++ b/tests/test_version_traversal.py @@ -75,14 +75,14 @@ def test_upgrade_path(self): self.e.revision, self.c.revision, [self.up_(self.d), self.up_(self.e)], - set([self.e.revision]), + {self.e.revision}, ) self._assert_upgrade( self.c.revision, None, [self.up_(self.a), self.up_(self.b), self.up_(self.c)], - set([self.c.revision]), + {self.c.revision}, ) def test_relative_upgrade_path(self): @@ -90,32 +90,32 @@ def test_relative_upgrade_path(self): "+2", self.a.revision, [self.up_(self.b), self.up_(self.c)], - set([self.c.revision]), + {self.c.revision}, ) self._assert_upgrade( - "+1", self.a.revision, [self.up_(self.b)], set([self.b.revision]) + "+1", self.a.revision, [self.up_(self.b)], {self.b.revision} ) self._assert_upgrade( "+3", self.b.revision, [self.up_(self.c), self.up_(self.d), self.up_(self.e)], - set([self.e.revision]), + {self.e.revision}, ) self._assert_upgrade( "%s+2" % self.b.revision, self.a.revision, [self.up_(self.b), self.up_(self.c), self.up_(self.d)], - set([self.d.revision]), + {self.d.revision}, ) self._assert_upgrade( "%s-2" % self.d.revision, self.a.revision, [self.up_(self.b)], - set([self.b.revision]), + {self.b.revision}, ) def test_invalid_relative_upgrade_path(self): @@ -142,7 +142,7 @@ def test_downgrade_path(self): self.c.revision, self.e.revision, [self.down_(self.e), self.down_(self.d)], - set([self.c.revision]), + {self.c.revision}, ) self._assert_downgrade( @@ -155,28 +155,28 @@ def test_downgrade_path(self): def test_relative_downgrade_path(self): self._assert_downgrade( - "-1", self.c.revision, [self.down_(self.c)], set([self.b.revision]) + "-1", self.c.revision, [self.down_(self.c)], {self.b.revision} ) self._assert_downgrade( "-3", self.e.revision, [self.down_(self.e), self.down_(self.d), self.down_(self.c)], - set([self.b.revision]), + {self.b.revision}, ) self._assert_downgrade( "%s+2" % self.a.revision, self.d.revision, [self.down_(self.d)], - set([self.c.revision]), + {self.c.revision}, ) self._assert_downgrade( "%s-2" % self.c.revision, self.d.revision, [self.down_(self.d), self.down_(self.c), self.down_(self.b)], - set([self.a.revision]), + {self.a.revision}, ) def test_invalid_relative_downgrade_path(self): @@ -287,7 +287,7 @@ def test_upgrade_single_branch(self): self.d1.revision, self.b.revision, [self.up_(self.c1), self.up_(self.d1)], - set([self.d1.revision]), + {self.d1.revision}, ) def test_upgrade_multiple_branch(self): @@ -303,7 +303,7 @@ def test_upgrade_multiple_branch(self): self.up_(self.c1), self.up_(self.d1), ], - set([self.d1.revision, self.d2.revision]), + {self.d1.revision, self.d2.revision}, ) def test_downgrade_multiple_branch(self): @@ -317,7 +317,7 @@ def test_downgrade_multiple_branch(self): self.down_(self.c2), self.down_(self.b), ], - set([self.a.revision]), + {self.a.revision}, ) def test_relative_upgrade(self): @@ -326,7 +326,7 @@ def test_relative_upgrade(self): "c2branch@head-1", self.b.revision, [self.up_(self.c2)], - set([self.c2.revision]), + {self.c2.revision}, ) def test_relative_downgrade_baseplus2(self): @@ -340,7 +340,7 @@ def test_relative_downgrade_baseplus2(self): self.down_(self.d2), self.down_(self.c2), ], - set([self.b.revision]), + {self.b.revision}, ) def test_relative_downgrade_branchplus2(self): @@ -353,7 +353,7 @@ def test_relative_downgrade_branchplus2(self): "c2branch@base+2", [self.d2.revision, self.d1.revision], [self.down_(self.d2), self.down_(self.c2)], - set([self.d1.revision]), + {self.d1.revision}, ) def test_relative_downgrade_branchplus3(self): @@ -362,13 +362,13 @@ def test_relative_downgrade_branchplus3(self): self.c2.revision, [self.d2.revision, self.d1.revision], [self.down_(self.d2)], - set([self.d1.revision, self.c2.revision]), + {self.d1.revision, self.c2.revision}, ) self._assert_downgrade( "c2branch@base+3", [self.d2.revision, self.d1.revision], [self.down_(self.d2)], - set([self.d1.revision, self.c2.revision]), + {self.d1.revision, self.c2.revision}, ) # Old downgrade -1 behaviour depends on order of branch upgrades. @@ -381,7 +381,7 @@ def test_downgrade_once_order_right(self): "-1", [self.d2.revision, self.d1.revision], [self.down_(self.d2)], - set([self.d1.revision, self.c2.revision]), + {self.d1.revision, self.c2.revision}, ) def test_downgrade_once_order_right_unbalanced(self): @@ -390,7 +390,7 @@ def test_downgrade_once_order_right_unbalanced(self): "-1", [self.c2.revision, self.d1.revision], [self.down_(self.c2)], - set([self.d1.revision]), + {self.d1.revision}, ) def test_downgrade_once_order_left(self): @@ -399,7 +399,7 @@ def test_downgrade_once_order_left(self): "-1", [self.d1.revision, self.d2.revision], [self.down_(self.d1)], - set([self.d2.revision, self.c1.revision]), + {self.d2.revision, self.c1.revision}, ) def test_downgrade_once_order_left_unbalanced(self): @@ -408,7 +408,7 @@ def test_downgrade_once_order_left_unbalanced(self): "-1", [self.c1.revision, self.d2.revision], [self.down_(self.c1)], - set([self.d2.revision]), + {self.d2.revision}, ) def test_downgrade_once_order_left_unbalanced_labelled(self): @@ -416,73 +416,73 @@ def test_downgrade_once_order_left_unbalanced_labelled(self): "c1branch@-1", [self.d1.revision, self.d2.revision], [self.down_(self.d1)], - set([self.c1.revision, self.d2.revision]), + {self.c1.revision, self.d2.revision}, ) # Captures https://github.com/sqlalchemy/alembic/issues/765 def test_downgrade_relative_order_right(self): self._assert_downgrade( - "{}-1".format(self.d2.revision), + f"{self.d2.revision}-1", [self.d2.revision, self.c1.revision], [self.down_(self.d2)], - set([self.c1.revision, self.c2.revision]), + {self.c1.revision, self.c2.revision}, ) def test_downgrade_relative_order_left(self): self._assert_downgrade( - "{}-1".format(self.d2.revision), + f"{self.d2.revision}-1", [self.c1.revision, self.d2.revision], [self.down_(self.d2)], - set([self.c1.revision, self.c2.revision]), + {self.c1.revision, self.c2.revision}, ) def test_downgrade_single_branch_c1branch(self): """Use branch label to specify the branch to downgrade.""" self._assert_downgrade( - "c1branch@{}".format(self.b.revision), + f"c1branch@{self.b.revision}", (self.c1.revision, self.d2.revision), [ self.down_(self.c1), ], - set([self.d2.revision]), + {self.d2.revision}, ) def test_downgrade_single_branch_c1branch_from_d1_head(self): """Use branch label to specify the branch (where the branch label is not on the head revision).""" self._assert_downgrade( - "c2branch@{}".format(self.b.revision), + f"c2branch@{self.b.revision}", (self.c1.revision, self.d2.revision), [ self.down_(self.d2), self.down_(self.c2), ], - set([self.c1.revision]), + {self.c1.revision}, ) def test_downgrade_single_branch_c2(self): """Use a revision on the branch (not head) to specify the branch.""" self._assert_downgrade( - "{}@{}".format(self.c2.revision, self.b.revision), + f"{self.c2.revision}@{self.b.revision}", (self.d1.revision, self.d2.revision), [ self.down_(self.d2), self.down_(self.c2), ], - set([self.d1.revision]), + {self.d1.revision}, ) def test_downgrade_single_branch_d1(self): """Use the head revision to specify the branch.""" self._assert_downgrade( - "{}@{}".format(self.d1.revision, self.b.revision), + f"{self.d1.revision}@{self.b.revision}", (self.d1.revision, self.d2.revision), [ self.down_(self.d1), self.down_(self.c1), ], - set([self.d2.revision]), + {self.d2.revision}, ) def test_downgrade_relative_to_branch_head(self): @@ -490,7 +490,7 @@ def test_downgrade_relative_to_branch_head(self): "c1branch@head-1", (self.d1.revision, self.d2.revision), [self.down_(self.d1)], - set([self.c1.revision, self.d2.revision]), + {self.c1.revision, self.d2.revision}, ) def test_upgrade_other_branch_from_mergepoint(self): @@ -500,7 +500,7 @@ def test_upgrade_other_branch_from_mergepoint(self): "c2branch@+1", (self.c1.revision), [self.up_(self.c2)], - set([self.c1.revision, self.c2.revision]), + {self.c1.revision, self.c2.revision}, ) def test_upgrade_one_branch_of_heads(self): @@ -511,7 +511,7 @@ def test_upgrade_one_branch_of_heads(self): "c2branch@+1", (self.c1.revision, self.c2.revision), [self.up_(self.d2)], - set([self.c1.revision, self.d2.revision]), + {self.c1.revision, self.d2.revision}, ) def test_ambiguous_upgrade(self): @@ -525,13 +525,11 @@ def test_ambiguous_upgrade(self): def test_upgrade_from_base(self): self._assert_upgrade( - "base+1", [], [self.up_(self.a)], set([self.a.revision]) + "base+1", [], [self.up_(self.a)], {self.a.revision} ) def test_upgrade_from_base_implicit(self): - self._assert_upgrade( - "+1", [], [self.up_(self.a)], set([self.a.revision]) - ) + self._assert_upgrade("+1", [], [self.up_(self.a)], {self.a.revision}) def test_downgrade_minus1_to_base(self): self._assert_downgrade( @@ -553,13 +551,13 @@ def test_downgrade_no_effect_branched(self): self.c2.revision, [self.d1.revision, self.c2.revision], [], - set([self.d1.revision, self.c2.revision]), + {self.d1.revision, self.c2.revision}, ) self._assert_downgrade( self.d1.revision, [self.d1.revision, self.c2.revision], [], - set([self.d1.revision, self.c2.revision]), + {self.d1.revision, self.c2.revision}, ) @@ -614,7 +612,7 @@ def test_mergepoint_to_only_one_side_upgrade(self): self.d1.revision, (self.d2.revision, self.b1.revision), [self.up_(self.c1), self.up_(self.d1)], - set([self.d2.revision, self.d1.revision]), + {self.d2.revision, self.d1.revision}, ) def test_mergepoint_to_only_one_side_downgrade(self): @@ -623,7 +621,7 @@ def test_mergepoint_to_only_one_side_downgrade(self): self.b1.revision, (self.d2.revision, self.d1.revision), [self.down_(self.d1), self.down_(self.c1)], - set([self.d2.revision, self.b1.revision]), + {self.d2.revision, self.b1.revision}, ) @@ -698,7 +696,7 @@ def test_mergepoint_to_only_one_side_upgrade(self): self.d1.revision, (self.d3.revision, self.d2.revision, self.b1.revision), [self.up_(self.c1), self.up_(self.d1)], - set([self.d3.revision, self.d2.revision, self.d1.revision]), + {self.d3.revision, self.d2.revision, self.d1.revision}, ) def test_mergepoint_to_only_one_side_downgrade(self): @@ -706,7 +704,7 @@ def test_mergepoint_to_only_one_side_downgrade(self): self.b1.revision, (self.d3.revision, self.d2.revision, self.d1.revision), [self.down_(self.d1), self.down_(self.c1)], - set([self.d3.revision, self.d2.revision, self.b1.revision]), + {self.d3.revision, self.d2.revision, self.b1.revision}, ) def test_mergepoint_to_two_sides_upgrade(self): @@ -716,7 +714,7 @@ def test_mergepoint_to_two_sides_upgrade(self): (self.d3.revision, self.b2.revision, self.b1.revision), [self.up_(self.c2), self.up_(self.c1), self.up_(self.d1)], # this will merge b2 and b1 into d1 - set([self.d3.revision, self.d1.revision]), + {self.d3.revision, self.d1.revision}, ) # but then! b2 will break out again if we keep going with it @@ -724,7 +722,7 @@ def test_mergepoint_to_two_sides_upgrade(self): self.d2.revision, (self.d3.revision, self.d1.revision), [self.up_(self.d2)], - set([self.d3.revision, self.d2.revision, self.d1.revision]), + {self.d3.revision, self.d2.revision, self.d1.revision}, ) @@ -916,14 +914,14 @@ def test_downgrade_to_dependency(self): heads = [self.c2.revision, self.d1.revision] head = HeadMaintainer(mock.Mock(), heads) head.update_to_step(self.down_(self.d1)) - eq_(head.heads, set([self.c2.revision])) + eq_(head.heads, {self.c2.revision}) def test_stamp_across_dependency(self): heads = [self.e1.revision, self.c2.revision] head = HeadMaintainer(mock.Mock(), heads) for step in self.env._stamp_revs(self.b1.revision, heads): head.update_to_step(step) - eq_(head.heads, set([self.b1.revision])) + eq_(head.heads, {self.b1.revision}) class DependsOnBranchTestTwo(MigrationTest): @@ -1010,15 +1008,13 @@ def test_kaboom(self): self.b2.revision, heads, [self.down_(self.bmerge)], - set( - [ - self.amerge.revision, - self.b1.revision, - self.cmerge.revision, - # b2 isn't here, but d1 is, which implies b2. OK! - self.d1.revision, - ] - ), + { + self.amerge.revision, + self.b1.revision, + self.cmerge.revision, + # b2 isn't here, but d1 is, which implies b2. OK! + self.d1.revision, + }, ) # start with those heads.. @@ -1034,15 +1030,13 @@ def test_kaboom(self): "d1@base", heads, [self.down_(self.d1)], - set( - [ - self.amerge.revision, - self.b1.revision, - # b2 has to be INSERTed, because it was implied by d1 - self.b2.revision, - self.cmerge.revision, - ] - ), + { + self.amerge.revision, + self.b1.revision, + # b2 has to be INSERTed, because it was implied by d1 + self.b2.revision, + self.cmerge.revision, + }, ) # start with those heads ... @@ -1071,7 +1065,7 @@ def test_kaboom(self): self.down_(self.c2), self.down_(self.c3), ], - set([]), + set(), ) @@ -1122,7 +1116,7 @@ def test_downgrade_over_crisscross(self): "b1", ["a3", "b2"], [self.down_(self.b2)], - set(["a3"]), # we have b1 also, which is implied by a3 + {"a3"}, # we have b1 also, which is implied by a3 ) @@ -1145,7 +1139,7 @@ def test_traverse(self): self.a2.revision, None, [self.up_(self.a1), self.up_(self.a2)], - set(["a2"]), + {"a2"}, ) def test_traverse_down(self): @@ -1153,7 +1147,7 @@ def test_traverse_down(self): self.a1.revision, self.a2.revision, [self.down_(self.a2)], - set(["a1"]), + {"a1"}, ) @@ -1190,7 +1184,7 @@ def test_dependencies_are_normalized(self): heads, [self.down_(self.b4)], # a3 isn't here, because b3 still implies a3 - set([self.b3.revision]), + {self.b3.revision}, ) @@ -1239,7 +1233,7 @@ def test_upgrade_path(self): self.up_(self.b2), self.up_(self.c2), ], - set([self.c2.revision]), + {self.c2.revision}, ) @@ -1276,8 +1270,8 @@ def test_stamp_to_heads(self): revs = self.env._stamp_revs("heads", ()) eq_(len(revs), 2) eq_( - set(r.to_revisions for r in revs), - set([(self.b1.revision,), (self.b2.revision,)]), + {r.to_revisions for r in revs}, + {(self.b1.revision,), (self.b2.revision,)}, ) def test_stamp_to_heads_no_moves_needed(self): @@ -1448,19 +1442,19 @@ def test_downgrade_independent_branch(self): """c2branch depends on c1branch so can be taken down on its own. Current behaviour also takes down the dependency unnecessarily.""" self._assert_downgrade( - "c2branch@{}".format(self.b.revision), + f"c2branch@{self.b.revision}", (self.d1.revision, self.d2.revision), [ self.down_(self.d2), self.down_(self.c2), ], - set([self.d1.revision]), + {self.d1.revision}, ) def test_downgrade_branch_dependency(self): """c2branch depends on c1branch so taking down c1branch requires taking down both""" - destination = "c1branch@{}".format(self.b.revision) + destination = f"c1branch@{self.b.revision}" source = self.d1.revision, self.d2.revision revs = self.env._downgrade_revs(destination, source) # Drops c1, d1 as requested, also drops d2 due to dependence on d1. @@ -1483,4 +1477,4 @@ def test_downgrade_branch_dependency(self): head = HeadMaintainer(mock.Mock(), heads) for rev in revs: head.update_to_step(rev) - eq_(head.heads, set([self.c2.revision])) + eq_(head.heads, {self.c2.revision})