add scatter composite rule. #52005

Merged Mar 30, 2023 (8 commits; changes shown from 5 commits)

1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
@@ -23,6 +23,7 @@
- concat
- elementwise_pow
- floor
- gather
- gather_nd
- log
- max
@@ -1202,6 +1202,27 @@ void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
set_output<T>(x_grad_tmp, x_grad);
}

template <typename T>
void scatter_grad(const Tensor& index,
                  const Tensor& updates,
                  const Tensor& out_grad,
                  bool overwrite,
                  Tensor* x_grad,
                  Tensor* updates_grad) {
  if (x_grad) {
    auto zero_tensor =
        full<T>(phi::vectorize(updates.dims()), 0.0, updates.dtype());
    auto tmp_grad = scatter<T>(out_grad, index, zero_tensor, false);
    set_output<T>(tmp_grad, x_grad);
  }

  if (updates_grad) {
    Scalar tmp_zero = 0;
    auto tmp_updates_grad = gather<T>(out_grad, index, tmp_zero);
    set_output<T>(tmp_updates_grad, updates_grad);
  }
}
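
For reference, here is a minimal Python sketch of the same decomposition written against the public paddle.scatter / paddle.gather ops (illustrative only, not part of this change; the helper name is made up):

import paddle

def scatter_grad_composite(index, updates, out_grad):
    # dL/dx: scattering zeros with overwrite=False first zeroes the rows of
    # out_grad selected by `index` and then adds zeros, so the gradient flows
    # through every row except the scattered ones.
    zeros = paddle.zeros(updates.shape, dtype=updates.dtype)
    x_grad = paddle.scatter(out_grad, index, zeros, overwrite=False)
    # dL/dupdates: each update row receives the gradient of the output row it
    # was written to, i.e. a gather of out_grad along axis 0.
    updates_grad = paddle.gather(out_grad, index, axis=0)
    return x_grad, updates_grad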

template <typename T>
void batch_norm_grad(const Tensor& x,
const Tensor& scale,
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -1264,6 +1264,7 @@
kernel :
func : scatter_grad
no_need_buffer : updates
composite: scatter_grad(index, updates, out_grad, overwrite, x_grad, updates_grad)

- backward_op : scatter_nd_add_grad
forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out)
138 changes: 112 additions & 26 deletions python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -28,6 +28,8 @@ class TestScatterOp(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 50)).astype(target_dtype)
@@ -46,10 +48,12 @@ def _set_dtype(self):
self.dtype = np.float32

def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)

def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out")
self.check_grad(
["X", "Updates"], "Out", check_prim=True
)


class TestScatterFP16Op(TestScatterOp):
@@ -69,18 +73,29 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(
place, check_prim=True
)


def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)



class TestScatterOp0(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -100,10 +115,12 @@ def _set_dtype(self):
self.dtype = np.float32

def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)

def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out")
self.check_grad(
["X", "Updates"], "Out", check_prim=True
)


class TestScatterFP16Op0(TestScatterOp0):
@@ -123,18 +140,27 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(
place, check_prim=True
)

def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)


class TestScatterOp1(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -157,10 +183,12 @@ def _set_dtype(self):
self.dtype = np.float32

def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)

def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out")
self.check_grad(
["X", "Updates"], "Out", check_prim=True
)


class TestScatterFP16Op1(TestScatterOp1):
@@ -180,12 +208,19 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(
place, check_prim=True
)

def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)


@unittest.skipIf(
@@ -195,6 +230,8 @@ class TestScatterOp2(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -215,12 +252,19 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)
self.check_output_with_place(
place, atol=1e-3, check_prim=True
)

def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)


@unittest.skipIf(
@@ -248,6 +292,8 @@ class TestScatterOp3(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -272,12 +318,20 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)
self.check_output_with_place(
place, atol=1e-3, check_prim=True
)


def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)


@unittest.skipIf(
@@ -302,6 +356,8 @@ class TestScatterOp4(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -320,10 +376,12 @@ def _set_dtype(self):
self.dtype = np.float32

def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)

def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out')
self.check_grad(
['X', 'Updates'], 'Out', check_prim=True
)


class TestScatterFP16Op4(TestScatterOp4):
@@ -343,12 +401,19 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(
place, check_prim=True
)

def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)


@unittest.skipIf(
@@ -358,6 +423,8 @@ class TestScatterOp5(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -378,12 +445,19 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)
self.check_output_with_place(
place, atol=1e-3, check_prim=True
)

def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)


@unittest.skipIf(
@@ -408,6 +482,9 @@ class TestScatterOp6(OpTest):
def setUp(self):
self.op_type = "scatter"
self.python_api = paddle.scatter
self.public_python_api = paddle.scatter
self.prim_op_type = "prim"
self.enable_cinn = False
self._set_dtype()
target_dtype = "float16" if self.dtype == np.float16 else "float32"
ref_np = np.ones((3, 50)).astype(target_dtype)
@@ -426,10 +503,12 @@ def _set_dtype(self):
self.dtype = np.float32

def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)

def test_check_grad(self):
self.check_grad(["X", "Updates"], "Out")
self.check_grad(
["X", "Updates"], "Out", check_prim=True
)


class TestScatterFP16Op6(TestScatterOp6):
@@ -449,12 +528,19 @@ def _set_dtype(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(
place, check_prim=True
)

def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
self.check_grad_with_place(
place,
['X', 'Updates'],
'Out',
check_prim=True,
)


class TestScatterAPI(unittest.TestCase):
2 changes: 1 addition & 1 deletion python/paddle/tensor/manipulation.py
@@ -2906,7 +2906,7 @@ def scatter(x, index, updates, overwrite=True, name=None):
check_variable_and_dtype(
x,
'dtype',
['float32', 'float64', 'float16', 'int32', 'int64'],
['float32', 'float64', 'float16', 'int32', 'int64', 'uint16'],
'scatter',
)
check_type(overwrite, 'overwrite', bool, 'scatter')
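
For context, a short usage sketch of paddle.scatter (illustrative only, not part of this change), showing the two overwrite modes the backward rules have to cover; the expected values follow the documented scatter semantics:

import paddle

x = paddle.to_tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
index = paddle.to_tensor([2, 1, 0, 1])
updates = paddle.to_tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])

# overwrite=False: the selected rows of x are zeroed first, then the matching
# update rows are accumulated, so duplicated indices sum up.
out = paddle.scatter(x, index, updates, overwrite=False)
# expected: [[3., 3.], [6., 6.], [1., 1.]]

# overwrite=True replaces each selected row with an update row instead of
# accumulating; with duplicated indices the result depends on write order.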