Skip to content

Commit

Permalink
add more copy_from method (#36978)
Browse files Browse the repository at this point in the history
  • Loading branch information
sneaxiy authored Nov 5, 2021
1 parent d572fa2 commit f00f4fc
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 10 deletions.
32 changes: 22 additions & 10 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,17 @@ static int GetNCCLVersion() {
}
#endif

template <typename PlaceType>
static void TensorCopyFrom(framework::Tensor *dst, const framework::Tensor &src,
const PlaceType &place, int64_t batch_size) {
if (batch_size < 0) {
framework::TensorCopy(src, place, dst);
} else {
auto sliced = src.Slice(0, batch_size);
framework::TensorCopy(sliced, place, dst);
}
}

#ifdef PADDLE_WITH_AVX
PYBIND11_MODULE(core_avx, m) {
#else
Expand Down Expand Up @@ -755,16 +766,17 @@ PYBIND11_MODULE(core_noavx, m) {
paddle::framework::proto::VarType::Type type) {
return reinterpret_cast<uintptr_t>(self.mutable_data(place, type));
})
.def("_copy_from",
[](framework::Tensor &self, const framework::Tensor &other,
const platform::Place &place, int64_t batch_size) {
if (batch_size < 0) {
framework::TensorCopy(other, place, &self);
} else {
auto sliced = other.Slice(0, batch_size);
framework::TensorCopy(sliced, place, &self);
}
},
.def("_copy_from", &TensorCopyFrom<paddle::platform::CPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::XPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::CUDAPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::NPUPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::CUDAPinnedPlace>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("_copy_from", &TensorCopyFrom<paddle::platform::Place>,
py::arg("tensor"), py::arg("place"), py::arg("batch_size") = -1)
.def("set", SetTensorFromPyArray<paddle::platform::CPUPlace>,
py::arg("array"), py::arg("place"), py::arg("zero_copy") = false)
Expand Down
39 changes: 39 additions & 0 deletions python/paddle/fluid/tests/unittests/test_tensor_copy_from.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import unittest
import numpy as np
from paddle.fluid.core import LoDTensor as Tensor


class TestTensorCopyFrom(unittest.TestCase):
def test_main(self):
place = paddle.CPUPlace()
np_value = np.random.random(size=[10, 30]).astype('float32')
t_src = Tensor()
t_src.set(np_value, place)
self.assertTrue(np.array_equal(np_value, t_src))

t_dst1 = Tensor()
t_dst1._copy_from(t_src, place)
self.assertTrue(np.array_equal(np_value, t_dst1))

t_dst2 = Tensor()
t_dst2._copy_from(t_src, place, 5)
self.assertTrue(np.array_equal(np.array(np_value[0:5]), t_dst2))


if __name__ == "__main__":
unittest.main()

0 comments on commit f00f4fc

Please sign in to comment.