Skip to content

Commit

Permalink
fix for issue #266
Browse files Browse the repository at this point in the history
  • Loading branch information
joanglaunes committed Dec 21, 2022
1 parent 2737511 commit 900f232
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 6 deletions.
8 changes: 7 additions & 1 deletion keopscore/keopscore/formulas/complex/ComplexAdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,11 @@ def ScalarOp(self, out, inF, inG):
return string

def DiffT(self, v, gradin):
from keopscore.formulas.complex.ComplexSum import ComplexSum
f, g = self.children
return ComplexAdd(f.DiffT(v, gradin), g.DiffT(v, gradin))
if f.dim == 2 and g.dim > 2:
return ComplexAdd(f.DiffT(v, ComplexSum(gradin)), g.DiffT(v, gradin))
elif g.dim == 2 and f.dim > 2:
return ComplexAdd(f.DiffT(v, gradin), g.DiffT(v, ComplexSum(gradin)))
else:
return ComplexAdd(f.DiffT(v, gradin), g.DiffT(v, gradin))
19 changes: 17 additions & 2 deletions keopscore/keopscore/formulas/complex/ComplexMult.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from keopscore.formulas.VectorizedComplexScalarOp import VectorizedComplexScalarOp
from keopscore.formulas.complex.Conj import Conj
from keopscore.formulas.complex.ComplexAdd import ComplexAdd

# /////////////////////////////////////////////////////////////////////////
# //// ComplexMult ////
Expand All @@ -16,7 +17,21 @@ def ScalarOp(self, out, inF, inG):
return string

def DiffT(self, v, gradin):
from keopscore.formulas.complex.ComplexSum import ComplexSum
f, g = self.children
return f.DiffT(v, ComplexMult(Conj(g), gradin)) + g.DiffT(
v, ComplexMult(Conj(f), gradin)
if f.dim == 2 and g.dim > 2:
return ComplexAdd(
f.DiffT(v, ComplexSum(ComplexMult(Conj(g), gradin))),
g.DiffT(v, ComplexMult(Conj(f), gradin))
)
elif g.dim == 2 and f.dim > 2:
return ComplexAdd(
f.DiffT(v, ComplexMult(Conj(g), gradin)),
g.DiffT(v, ComplexSum(ComplexMult(Conj(f), gradin)))
)
else:
return ComplexAdd(
f.DiffT(v, ComplexMult(Conj(g), gradin)),
g.DiffT(v, ComplexMult(Conj(f), gradin))
)

8 changes: 7 additions & 1 deletion keopscore/keopscore/formulas/complex/ComplexSubtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,11 @@ def ScalarOp(self, out, inF, inG):
return string

def DiffT(self, v, gradin):
from keopscore.formulas.complex.ComplexSum import ComplexSum
f, g = self.children
return ComplexSubtract(f.DiffT(v, gradin), g.DiffT(v, gradin))
if f.dim == 2 and g.dim > 2:
return ComplexSubtract(f.DiffT(v, ComplexSum(gradin)), g.DiffT(v, gradin))
elif g.dim == 2 and f.dim > 2:
return ComplexSubtract(f.DiffT(v, gradin), g.DiffT(v, ComplexSum(gradin)))
else:
return ComplexSubtract(f.DiffT(v, gradin), g.DiffT(v, gradin))
8 changes: 7 additions & 1 deletion keopscore/keopscore/formulas/maths/Add.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from keopscore.formulas.Operation import Broadcast
from keopscore.formulas.VectorizedScalarOp import VectorizedScalarOp
from keopscore.formulas.maths.Mult import Mult_Impl
from keopscore.formulas.maths.Sum import Sum
from keopscore.formulas.variables.IntCst import IntCst, IntCst_Impl
from keopscore.formulas.variables.Zero import Zero

Expand All @@ -23,7 +24,12 @@ def ScalarOp(self, out, arg0, arg1):

def DiffT(self, v, gradin):
fa, fb = self.children
return fa.DiffT(v, gradin) + fb.DiffT(v, gradin)
if fa.dim == 1 and fb.dim > 1:
return fa.DiffT(v, Sum(gradin)) + fb.DiffT(v, gradin)
elif fb.dim == 1 and fa.dim > 1:
return fa.DiffT(v, gradin) + fb.DiffT(v, Sum(gradin))
else:
return fa.DiffT(v, gradin) + fb.DiffT(v, gradin)

# parameters for testing the operation (optional)
nargs = 2 # number of arguments
Expand Down
8 changes: 7 additions & 1 deletion keopscore/keopscore/formulas/maths/Subtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from keopscore.formulas.VectorizedScalarOp import VectorizedScalarOp
from keopscore.formulas.variables.Zero import Zero
from keopscore.formulas.maths.Mult import Mult_Impl
from keopscore.formulas.maths.Sum import Sum
from keopscore.formulas.variables.IntCst import IntCst, IntCst_Impl

##########################
Expand All @@ -22,7 +23,12 @@ def ScalarOp(self, out, arg0, arg1):

def DiffT(self, v, gradin):
fa, fb = self.children
return fa.DiffT(v, gradin) - fb.DiffT(v, gradin)
if fa.dim == 1 and fb.dim > 1:
return fa.DiffT(v, Sum(gradin)) - fb.DiffT(v, gradin)
elif fb.dim == 1 and fa.dim > 1:
return fa.DiffT(v, gradin) - fb.DiffT(v, Sum(gradin))
else:
return fa.DiffT(v, gradin) - fb.DiffT(v, gradin)

# parameters for testing the operation (optional)
nargs = 2 # number of arguments
Expand Down
23 changes: 23 additions & 0 deletions pykeops/pykeops/sandbox/issue_266.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pykeops
pykeops.clean_pykeops()

import torch
from pykeops.torch import LazyTensor

H = 100
N = 1000
L = 5000
D = 10

dtype = torch.complex64
#dtype = torch.float32

# Complex
x_i = LazyTensor(torch.randn(H, N, 1, 1, dtype=dtype, requires_grad=True))
y_j = LazyTensor(torch.randn(1, 1, L, D, dtype=dtype, requires_grad=True))

D_ij = x_i * y_j

a_i = D_ij.sum(dim=1)

a_i.sum().backward()

0 comments on commit 900f232

Please sign in to comment.