Skip to content

Commit

Permalink
Merge pull request #22 from RaghuSpaceRajan/RaghuSpaceRajan-fix_for_c…
Browse files Browse the repository at this point in the history
…heck_shapes

Updated check_shapes() to return the updated variables l_x, α, l_y, β.
  • Loading branch information
jeanfeydy authored Mar 27, 2021
2 parents 52d8b38 + cbc8694 commit 3ae5317
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions geomloss/samples_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def forward(self, *args):
Until then, please check the tutorials :-)"""

l_x, α, x, l_y, β, y = self.process_args(*args)
B, N, M, D = self.check_shapes(l_x, α, x, l_y, β, y)
B, N, M, D, l_x, α, l_y, β = self.check_shapes(l_x, α, x, l_y, β, y)

backend = (
self.backend
Expand Down Expand Up @@ -466,4 +466,4 @@ def check_shapes(self, l_x, α, x, l_y, β, y):
"Weights 'β' and samples 'y' should have compatible shapes."
)

return B, N, M, D
return B, N, M, D, l_x, α, l_y, β

0 comments on commit 3ae5317

Please sign in to comment.