Skip to content

Commit

Permalink
Resolve bug with reseeding migrations. (#136)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lint Action <github-action[bot]@github.com>
  • Loading branch information
3 people authored Aug 28, 2021
1 parent da2f1e1 commit 7267569
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 66 deletions.
76 changes: 31 additions & 45 deletions migration_fixer/management/commands/makemigrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
from django.core.management.commands.makemigrations import Command as BaseCommand
from django.db import DEFAULT_DB_ALIAS, connections, router
from django.db.migrations.loader import MigrationLoader
from git import InvalidGitRepositoryError, Repo
from git import GitCommandError, InvalidGitRepositoryError, Repo

from migration_fixer.utils import (
fix_numbered_migration,
get_filename,
get_migration_module_path,
migration_sorter,
no_translations,
sibling_nodes,
)


Expand Down Expand Up @@ -127,11 +126,19 @@ def handle(self, *app_labels, **options):
force=self.force_update,
)
else:
remote = self.repo.remotes[self.remote]
remote.fetch(
f"{self.default_branch}:{self.default_branch}",
force=self.force_update,
)
try:
remote = self.repo.remotes[self.remote]
remote.fetch(
f"{self.default_branch}:{self.default_branch}",
force=self.force_update,
)
except GitCommandError as e: # pragma: no cover
raise CommandError(
self.style.ERROR(
f"Unable to fetch {self.remote} branch "
f"'{self.default_branch}': {e.stderr}",
),
)

if self.verbosity >= 2:
self.stdout.write(
Expand Down Expand Up @@ -170,13 +177,9 @@ def handle(self, *app_labels, **options):
):
loader.check_consistent_history(connection)

conflicts = {
app_name: sibling_nodes(loader.graph, app_name)
for app_name in loader.detect_conflicts()
}
conflict_leaf_nodes = loader.detect_conflicts()

for app_label in conflicts:
conflict = conflicts[app_label]
for app_label, leaf_nodes in conflict_leaf_nodes.items():
migration_module, _ = loader.migrations_module(app_label)
migration_path = get_migration_module_path(migration_module)

Expand All @@ -202,43 +205,24 @@ def handle(self, *app_labels, **options):
)
]

# Only consider files from the current conflict.
conflict_base = [
get_filename(path)
for path in changed_files
if get_filename(path) in conflict
][0]

sorted_changed_files = sorted(
changed_files,
key=partial(migration_sorter, app_label=app_label),
)

changed_files = [
path
for path in sorted_changed_files
if (
int(get_filename(path).split("_")[0])
>= int(conflict_base.split("_")[0])
)
]

# Local migration
local_filenames = [
get_filename(p) for p in changed_files
get_filename(p) for p in sorted_changed_files
]
if self.verbosity >= 2:
self.stdout.write(
f"Retrieving the last migration on: {self.default_branch}"
)

last_remote = [
# Calculate the last changed file on the default branch
conflict_bases = [
name
for name in conflict
for name in leaf_nodes
if name not in local_filenames
]

if not last_remote: # pragma: no cover
if not conflict_bases: # pragma: no cover
raise CommandError(
self.style.ERROR(
f"Unable to determine the last migration on: "
Expand All @@ -248,12 +232,14 @@ def handle(self, *app_labels, **options):
)
)

last_remote_filename, *rest = last_remote
changed_files = changed_files or [
f"{fname}.py" for fname in rest
]
conflict_base = conflict_bases[0]

if self.verbosity >= 2:
self.stdout.write(
f"Retrieving the last migration on: {self.default_branch}"
)

seed_split = last_remote_filename.split("_")
seed_split = conflict_base.split("_")

if (
seed_split
Expand All @@ -269,8 +255,8 @@ def handle(self, *app_labels, **options):
app_label=app_label,
migration_path=migration_path,
seed=int(seed_split[0]),
start_name=last_remote_filename,
changed_files=changed_files,
start_name=conflict_base,
changed_files=sorted_changed_files,
writer=(
lambda m: self.stdout.write(m)
if self.verbosity >= 2
Expand All @@ -279,7 +265,7 @@ def handle(self, *app_labels, **options):
)
else: # pragma: no cover
raise ValueError(
f"Unable to fix migration: {last_remote_filename}. \n"
f"Unable to fix migration: {conflict_base}. \n"
f"NOTE: It needs to begin with a number. eg. 0001_*",
)
except (ValueError, IndexError, TypeError) as e:
Expand Down
21 changes: 1 addition & 20 deletions migration_fixer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from importlib import import_module
from itertools import count
from pathlib import Path
from typing import Callable, List, Optional

from django.db.migrations.graph import MigrationGraph
from typing import Callable, List

DEFAULT_TIMEOUT = 120
MIGRATION_REGEX = "\\((?P<comma>['\"]){app_label}(['\"]),\\s(['\"])(?P<conflict_migration>.*)(['\"])\\),"
Expand Down Expand Up @@ -140,20 +138,3 @@ def get_migration_module_path(migration_module_path: str) -> Path:
raise

return Path(os.path.dirname(os.path.abspath(migration_module.__file__)))


def sibling_nodes(graph: MigrationGraph, app_name: Optional[str] = None) -> List[str]:
"""
Return all sibling nodes that have the same parent
- it's usually the result of a VCS merge and needs some user input.
"""
siblings = set()

for node in graph.nodes:
if len(graph.node_map[node].children) > 1 and (
not app_name or app_name == node[0]
):
for child in graph.node_map[node].children:
siblings.add(child[-1])

return sorted(siblings)

0 comments on commit 7267569

Please sign in to comment.