diff --git a/migration_fixer/management/commands/makemigrations.py b/migration_fixer/management/commands/makemigrations.py index e65c176a..dc475aa4 100644 --- a/migration_fixer/management/commands/makemigrations.py +++ b/migration_fixer/management/commands/makemigrations.py @@ -11,7 +11,7 @@ from django.core.management.commands.makemigrations import Command as BaseCommand from django.db import DEFAULT_DB_ALIAS, connections, router from django.db.migrations.loader import MigrationLoader -from git import InvalidGitRepositoryError, Repo +from git import GitCommandError, InvalidGitRepositoryError, Repo from migration_fixer.utils import ( fix_numbered_migration, @@ -19,7 +19,6 @@ get_migration_module_path, migration_sorter, no_translations, - sibling_nodes, ) @@ -127,11 +126,19 @@ def handle(self, *app_labels, **options): force=self.force_update, ) else: - remote = self.repo.remotes[self.remote] - remote.fetch( - f"{self.default_branch}:{self.default_branch}", - force=self.force_update, - ) + try: + remote = self.repo.remotes[self.remote] + remote.fetch( + f"{self.default_branch}:{self.default_branch}", + force=self.force_update, + ) + except GitCommandError as e: # pragma: no cover + raise CommandError( + self.style.ERROR( + f"Unable to fetch {self.remote} branch " + f"'{self.default_branch}': {e.stderr}", + ), + ) if self.verbosity >= 2: self.stdout.write( @@ -170,13 +177,9 @@ def handle(self, *app_labels, **options): ): loader.check_consistent_history(connection) - conflicts = { - app_name: sibling_nodes(loader.graph, app_name) - for app_name in loader.detect_conflicts() - } + conflict_leaf_nodes = loader.detect_conflicts() - for app_label in conflicts: - conflict = conflicts[app_label] + for app_label, leaf_nodes in conflict_leaf_nodes.items(): migration_module, _ = loader.migrations_module(app_label) migration_path = get_migration_module_path(migration_module) @@ -202,43 +205,24 @@ def handle(self, *app_labels, **options): ) ] - # Only consider files from the current conflict. - conflict_base = [ - get_filename(path) - for path in changed_files - if get_filename(path) in conflict - ][0] - sorted_changed_files = sorted( changed_files, key=partial(migration_sorter, app_label=app_label), ) - changed_files = [ - path - for path in sorted_changed_files - if ( - int(get_filename(path).split("_")[0]) - >= int(conflict_base.split("_")[0]) - ) - ] - # Local migration local_filenames = [ - get_filename(p) for p in changed_files + get_filename(p) for p in sorted_changed_files ] - if self.verbosity >= 2: - self.stdout.write( - f"Retrieving the last migration on: {self.default_branch}" - ) - last_remote = [ + # Calculate the last changed file on the default branch + conflict_bases = [ name - for name in conflict + for name in leaf_nodes if name not in local_filenames ] - if not last_remote: # pragma: no cover + if not conflict_bases: # pragma: no cover raise CommandError( self.style.ERROR( f"Unable to determine the last migration on: " @@ -248,12 +232,14 @@ def handle(self, *app_labels, **options): ) ) - last_remote_filename, *rest = last_remote - changed_files = changed_files or [ - f"{fname}.py" for fname in rest - ] + conflict_base = conflict_bases[0] + + if self.verbosity >= 2: + self.stdout.write( + f"Retrieving the last migration on: {self.default_branch}" + ) - seed_split = last_remote_filename.split("_") + seed_split = conflict_base.split("_") if ( seed_split @@ -269,8 +255,8 @@ def handle(self, *app_labels, **options): app_label=app_label, migration_path=migration_path, seed=int(seed_split[0]), - start_name=last_remote_filename, - changed_files=changed_files, + start_name=conflict_base, + changed_files=sorted_changed_files, writer=( lambda m: self.stdout.write(m) if self.verbosity >= 2 @@ -279,7 +265,7 @@ def handle(self, *app_labels, **options): ) else: # pragma: no cover raise ValueError( - f"Unable to fix migration: {last_remote_filename}. \n" + f"Unable to fix migration: {conflict_base}. \n" f"NOTE: It needs to begin with a number. eg. 0001_*", ) except (ValueError, IndexError, TypeError) as e: diff --git a/migration_fixer/tests/demo b/migration_fixer/tests/demo index f7377ea6..7c99f0c5 160000 --- a/migration_fixer/tests/demo +++ b/migration_fixer/tests/demo @@ -1 +1 @@ -Subproject commit f7377ea62c393201ae3992b41524a9c0f0f45ce7 +Subproject commit 7c99f0c5f4097b40fc03b632d0aacbf4924e804b diff --git a/migration_fixer/utils.py b/migration_fixer/utils.py index 207e8e02..8cc53d37 100644 --- a/migration_fixer/utils.py +++ b/migration_fixer/utils.py @@ -3,9 +3,7 @@ from importlib import import_module from itertools import count from pathlib import Path -from typing import Callable, List, Optional - -from django.db.migrations.graph import MigrationGraph +from typing import Callable, List DEFAULT_TIMEOUT = 120 MIGRATION_REGEX = "\\((?P['\"]){app_label}(['\"]),\\s(['\"])(?P.*)(['\"])\\)," @@ -140,20 +138,3 @@ def get_migration_module_path(migration_module_path: str) -> Path: raise return Path(os.path.dirname(os.path.abspath(migration_module.__file__))) - - -def sibling_nodes(graph: MigrationGraph, app_name: Optional[str] = None) -> List[str]: - """ - Return all sibling nodes that have the same parent - - it's usually the result of a VCS merge and needs some user input. - """ - siblings = set() - - for node in graph.nodes: - if len(graph.node_map[node].children) > 1 and ( - not app_name or app_name == node[0] - ): - for child in graph.node_map[node].children: - siblings.add(child[-1]) - - return sorted(siblings)