diff --git a/hivemind/client/averaging/matchmaking.py b/hivemind/client/averaging/matchmaking.py index 8ec866e51..c711f49a6 100644 --- a/hivemind/client/averaging/matchmaking.py +++ b/hivemind/client/averaging/matchmaking.py @@ -467,5 +467,13 @@ async def _declare_averager_periodically(self, key_manager: GroupKeyManager): looking_for_group=False) +def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes: + """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """ + schema_dicts = [{field_name: str(field_value) + for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()} + for tensor in tensors] + return DHTID.generate(source=schema_dicts).to_bytes() + + class MatchmakingException(Exception): """ An internal exception that marks undesired edge cases during averaging """