Skip to content

Commit

Permalink
Merge pull request #3 from gty111/batch
Browse files Browse the repository at this point in the history
Fix batch dimension
  • Loading branch information
Eigensystem authored Sep 24, 2024
2 parents 90a4e0b + a067f21 commit a7e7ee7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion distvae/modules/patch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def forward(self, patch_hidden_state):
)
patch_hidden_state_list = [
torch.empty(
[1, patch_hidden_state.shape[1], patch_height_list[i].item(), patch_hidden_state.shape[-1]],
[patch_hidden_state.shape[0], patch_hidden_state.shape[1], patch_height_list[i].item(), patch_hidden_state.shape[-1]],
dtype=patch_hidden_state.dtype,
device=f"cuda:{self.rank}"
) for i in range(self.world_size)
Expand Down

0 comments on commit a7e7ee7

Please sign in to comment.