Skip to content

Commit

Permalink
bug fix when saving config files for cached activations, added paired…
Browse files Browse the repository at this point in the history
… activations
  • Loading branch information
jkminder committed Oct 22, 2024
1 parent ca0ecc6 commit 12fe08a
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,20 @@ def collect(data : Dataset,
ActivationCache.collate_store_shards(store_dirs, shard_count, activation_cache, submodule_names, shuffle_shards, io)

# store configs
with open(os.path.join(store_dir, "config.json"), "w") as f:
json.dump({"batch_size" : batch_size, "context_len" : context_len, "shard_size" : shard_size, "d_model" : d_model, "shuffle_shards" : shuffle_shards, "io" : io, "total_size" : total_size, "shard_count" : shard_count}, f)
for i, store_dir in enumerate(store_dirs):
with open(os.path.join(store_dir, "config.json"), "w") as f:
json.dump({"batch_size" : batch_size, "context_len" : context_len, "shard_size" : shard_size, "d_model" : d_model, "shuffle_shards" : shuffle_shards, "io" : io, "total_size" : total_size, "shard_count" : shard_count}, f)
logger.info(f"Finished collecting activations. Total size: {total_size}")


class PairedActivationCache:
    """Index-aligned pairing of two activation caches.

    Two ``ActivationCache`` objects are opened from the given store
    directories. Item ``i`` of the pair is the stack of item ``i`` from the
    first cache and item ``i`` from the second cache.
    """

    def __init__(self, store_dir_1 : str, store_dir_2 : str):
        """Open both caches and verify that they have the same length.

        Args:
            store_dir_1: Store directory passed to the first ``ActivationCache``.
            store_dir_2: Store directory passed to the second ``ActivationCache``.

        Raises:
            AssertionError: If the two caches report different lengths.
        """
        self.activation_cache_1 = ActivationCache(store_dir_1)
        self.activation_cache_2 = ActivationCache(store_dir_2)

        len_1 = len(self.activation_cache_1)
        len_2 = len(self.activation_cache_2)
        if len_1 != len_2:
            # Raised explicitly instead of via an `assert` statement so the
            # check is not stripped under `python -O`; the exception type is
            # kept as AssertionError so existing callers see the same error.
            raise AssertionError(
                f"Paired activation caches must have the same length, "
                f"got {len_1} ({store_dir_1}) and {len_2} ({store_dir_2})"
            )

    def __len__(self) -> int:
        """Return the number of paired items (the length of the first cache)."""
        return len(self.activation_cache_1)

    def __getitem__(self, index : int):
        """Return the entries at ``index`` from both caches, stacked on a new dim 0.

        Slot 0 of the result holds the entry from the first cache and slot 1
        the entry from the second. ``th.stack`` requires both entries to be
        tensors of identical shape.
        """
        first = self.activation_cache_1[index]
        second = self.activation_cache_2[index]
        return th.stack((first, second), dim=0)

0 comments on commit 12fe08a

Please sign in to comment.