diff --git a/conftest.py b/conftest.py index a1ab2aed5..ec975c428 100644 --- a/conftest.py +++ b/conftest.py @@ -5,6 +5,7 @@ from olmo.config import ( DataConfig, InitFnType, + InstanceFilterConfig, ModelConfig, OptimizerConfig, PaddingDirection, @@ -72,6 +73,9 @@ def train_config(tmp_path, model_config) -> TrainConfig: "test_fixtures/c4-sample.03.json.gz", ], pad_direction=PaddingDirection.right, + instance_filter=InstanceFilterConfig( + repetition_max_period=3, repetition_min_period=1, repetition_max_count=3 + ), ), tokenizer=TokenizerConfig(identifier=TEST_MODEL), save_folder=str(tmp_path / "checkpoints"), diff --git a/tests/data/collator_test.py b/tests/data/collator_test.py index e94451313..9c3b60f29 100644 --- a/tests/data/collator_test.py +++ b/tests/data/collator_test.py @@ -129,3 +129,20 @@ def test_collate_with_label_mask(train_config, pad_direction): [[True, False, True, True], [False, True, True, False]], ) ).all() + + +@pytest.mark.parametrize( + "pad_direction", + [pytest.param(PaddingDirection.right, id="pad-right"), pytest.param(PaddingDirection.left, id="pad-left")], +) +def test_collate_with_instance_filter(train_config, pad_direction): + train_config.data.pad_direction = pad_direction + collator = DataCollator.from_train_config(train_config) + + inputs = [torch.tensor([0, 0, 2, 3]), torch.tensor([1, 1, 1])] + batch = collator(inputs) + assert batch["input_ids"].shape == (2, 4) + if pad_direction == "right": + assert batch["input_ids"][1][-1] == train_config.model.pad_token_id + else: + assert batch["input_ids"][1][0] == train_config.model.pad_token_id diff --git a/tests/data/memmap_dataset_test.py b/tests/data/memmap_dataset_test.py index e267043ee..39579cef1 100644 --- a/tests/data/memmap_dataset_test.py +++ b/tests/data/memmap_dataset_test.py @@ -3,6 +3,7 @@ import numpy as np +from olmo.config import InstanceFilterConfig from olmo.data.memmap_dataset import MemMapDataset from olmo.tokenizer import Tokenizer @@ -106,3 +107,21 @@ def test_concat_mmap_datasets(tmp_path: Path): # Should get the same with negative index. assert ds[-1]["input_ids"].tolist() == [3, 4, 5] assert ds[-1]["metadata"]["label"] == "test2" + + +def test_instance_filter(tmp_path: Path): + # Write some bad data to disk. + mmap = np.memmap(tmp_path / "bad_tokens.npy", dtype=np.uint16, mode="w+", shape=(128,)) + mmap[:] = list(np.ones(31)) + list(range(64 - 31)) + list(np.ones(32)) + list(range(64 - 32)) + mmap.flush() + + instance_filter_config = InstanceFilterConfig( + repetition_min_period=1, repetition_max_period=13, repetition_max_count=32 + ) + ds = MemMapDataset(tmp_path / "bad_tokens.npy", chunk_size=64, instance_filter_config=instance_filter_config) + + out = ds[0] + assert out["instance_mask"] is True + + out = ds[1] + assert out["instance_mask"] is False