Skip to content

Commit

Permalink
Add readme. migrate_ckpt also returns list of done migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Feb 20, 2024
1 parent 74d3c3a commit 25ea34c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 7 deletions.
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1 +1,33 @@
# migrate-ckpt

```python
import torch
from migrate_ckpt import Migration, migrate_ckpt


def update_some_keys_callback(ckpt):
"""
Define a callback that takes a checkpoints and updates it.
"""
ckpt["some_keys"] = ckpt["some_other_keys"]
del ckpt["some_other_keys"]
return ckpt


# List a set of migrations. Whenever you update your model architecture,
# you should add one that updates the model starting from the previous
# state (output of the previous migration)
model_migrations = [
Migration("Update some keys", update_some_keys_callback),
]

# Will only perform new migrations.
# done_migrations returns the list of migration objects that were executed.
ckpt, done_migrations = migrate_ckpt(
torch.load("/path/to/some/checkpoint.ckpt"),
model_migrations,
)

# This has no effect, the model was already migrated.
ckpt_2, _ = migrate_ckpt(ckpt, model_migrations)
```
4 changes: 2 additions & 2 deletions migrate_ckpt/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _mark_ckpt(ckpt: CkptType, migration: Migration) -> CkptType:
def migrate_ckpt(
ckpt: CkptType,
migrations: Sequence[Migration],
) -> CkptType:
) -> tuple[CkptType, Sequence[Migration]]:
"""
Migrate checkpoint using provided migrations
Args:
Expand All @@ -55,4 +55,4 @@ def migrate_ckpt(
for migration in missing_migrations:
ckpt = migration.callback(ckpt)
ckpt = _mark_ckpt(ckpt, migration)
return ckpt
return ckpt, missing_migrations
10 changes: 5 additions & 5 deletions tests/test_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,21 @@ def update_test_callback(x):

def test_missing_fields():
ckpt: dict[str, Any] = {}
new_ckpt = migrate_ckpt(ckpt, [blank_migration])
new_ckpt, _ = migrate_ckpt(ckpt, [blank_migration])
assert ckpt_migration_key in new_ckpt
assert isinstance(new_ckpt[ckpt_migration_key], list)
assert len(new_ckpt[ckpt_migration_key]) == 1
assert new_ckpt[ckpt_migration_key][0] == "blank"


def test_missing_one_migration():
ckpt = migrate_ckpt({}, [blank_migration])
new_ckpt = migrate_ckpt(ckpt, [blank2_migration])
ckpt, _ = migrate_ckpt({}, [blank_migration])
new_ckpt, _ = migrate_ckpt(ckpt, [blank2_migration])
assert new_ckpt[ckpt_migration_key][1] == "blank2"


def test_execute_migration():
ckpt = migrate_ckpt(
ckpt, _ = migrate_ckpt(
{},
[add_field_migration],
)
Expand All @@ -49,7 +49,7 @@ def test_execute_migration():


def test_execute_related_migrations():
ckpt = migrate_ckpt(
ckpt, _ = migrate_ckpt(
{},
[add_field_migration, update_test_migration],
)
Expand Down

0 comments on commit 25ea34c

Please sign in to comment.