[WIP]: Move k2.Fsa to C++ #814

Merged (34 commits, Sep 30, 2021)
Commits
f48563b
Make k2 ragged tensor more PyTorch-y like.
csukuangfj Aug 25, 2021
06a6d20
Refactoring: Start to add the wrapper class AnyTensor.
csukuangfj Aug 26, 2021
2a70298
Refactoring.
csukuangfj Aug 26, 2021
6bc05bf
initial attempt to support autograd.
csukuangfj Aug 26, 2021
c7bb9d5
First working version with autograd for Sum().
csukuangfj Aug 27, 2021
d569b42
Fix comments.
csukuangfj Aug 27, 2021
dcea808
Support __getitem__ and pickling.
csukuangfj Aug 27, 2021
cb4f00f
Add more docs for k2.ragged.Tensor
csukuangfj Aug 27, 2021
1b5c015
Put documentation in header files.
csukuangfj Aug 28, 2021
a8d4a8e
Minor fixes.
csukuangfj Aug 28, 2021
1f78c93
Fix a typo.
csukuangfj Aug 28, 2021
892fb04
Fix an error.
csukuangfj Aug 28, 2021
fb96d97
Add more doc.
csukuangfj Aug 28, 2021
2f01361
Wrap RaggedShape.
csukuangfj Aug 29, 2021
626cc7a
[Not for Merge]: Move k2.Fsa related code to C++.
csukuangfj Aug 29, 2021
0e60a69
Merge remote-tracking branch 'dan/master' into fsa
csukuangfj Sep 8, 2021
f11947e
Remove extra files.
csukuangfj Sep 8, 2021
9ac1e78
Update doc URL. (#821)
csukuangfj Sep 8, 2021
44ff35b
Merge remote-tracking branch 'dan/master' into fsa
csukuangfj Sep 12, 2021
1dc7e1e
Support manipulating attributes of k2.ragged.Fsa.
csukuangfj Sep 12, 2021
bbe0ded
Support indexing 2-axes RaggedTensor, Support slicing for RaggedTenso…
pkufool Sep 14, 2021
2c28070
Prune with max_arcs in IntersectDense (#820)
pkufool Sep 14, 2021
210175c
Release v1.8
pkufool Sep 14, 2021
33a212c
Create a ragged tensor from a regular tensor. (#827)
csukuangfj Sep 15, 2021
971af7d
Trigger GitHub actions manually. (#829)
csukuangfj Sep 16, 2021
646704e
Run GitHub actions on merging. (#830)
csukuangfj Sep 16, 2021
8030001
Support printing ragged tensors in a more compact way. (#831)
csukuangfj Sep 17, 2021
7029b1f
Merge remote-tracking branch 'dan/master' into fsa
csukuangfj Sep 18, 2021
d73a5b5
Add levenshtein alignment (#828)
pkufool Sep 19, 2021
f2fd997
Release v1.9
pkufool Sep 19, 2021
b2cb9c0
Add Fsa.get_forward_scores.
csukuangfj Sep 19, 2021
13408aa
Merge remote-tracking branch 'dan/master' into fsa
csukuangfj Sep 19, 2021
cbff6a1
Implement backprop for Fsa.get_forward_scores()
csukuangfj Sep 19, 2021
cca7a54
Construct RaggedArc from unary function tensor (#30)
pkufool Sep 30, 2021
4 changes: 3 additions & 1 deletion k2/python/csrc/torch/CMakeLists.txt
@@ -11,10 +11,12 @@ set(torch_srcs
ragged_ops.cu
torch_util.cu

- v2/k2.cu
v2/any.cu
v2/doc/doc.cu
+ v2/fsa.cu
+ v2/k2.cu
v2/ragged_any.cu
+ v2/ragged_arc.cu
v2/ragged_shape.cu
)

69 changes: 38 additions & 31 deletions k2/python/csrc/torch/v2/any.cu
@@ -83,19 +83,19 @@ void PybindRaggedAny(py::module &m) {
any.def(
"clone",
[](const RaggedAny &self) -> RaggedAny {
- DeviceGuard guard(self.any_.Context());
+ DeviceGuard guard(self.any.Context());
return self.Clone();
},
kRaggedAnyCloneDoc);

any.def(
"__eq__",
[](const RaggedAny &self, const RaggedAny &other) -> bool {
- DeviceGuard guard(self.any_.Context());
- Dtype t = self.any_.GetDtype();
+ DeviceGuard guard(self.any.Context());
+ Dtype t = self.any.GetDtype();
bool ans = false;
FOR_REAL_AND_INT32_TYPES(t, T, {
- ans = Equal<T>(self.any_.Specialize<T>(), other.any_.Specialize<T>());
+ ans = Equal<T>(self.any.Specialize<T>(), other.any.Specialize<T>());
});
return ans;
},
@@ -104,12 +104,11 @@ void PybindRaggedAny(py::module &m) {
any.def(
"__ne__",
[](const RaggedAny &self, const RaggedAny &other) -> bool {
- DeviceGuard guard(self.any_.Context());
- Dtype t = self.any_.GetDtype();
+ DeviceGuard guard(self.any.Context());
+ Dtype t = self.any.GetDtype();
bool ans = false;
FOR_REAL_AND_INT32_TYPES(t, T, {
- ans =
- !Equal<T>(self.any_.Specialize<T>(), other.any_.Specialize<T>());
+ ans = !Equal<T>(self.any.Specialize<T>(), other.any.Specialize<T>());
});
return ans;
},
@@ -124,37 +123,37 @@ void PybindRaggedAny(py::module &m) {
any.def(
"numel",
[](RaggedAny &self) -> int32_t {
- DeviceGuard guard(self.any_.Context());
- return self.any_.NumElements();
+ DeviceGuard guard(self.any.Context());
+ return self.any.NumElements();
},
kRaggedAnyNumelDoc);

any.def(
"tot_size",
[](const RaggedAny &self, int32_t axis) -> int32_t {
- DeviceGuard guard(self.any_.Context());
- return self.any_.TotSize(axis);
+ DeviceGuard guard(self.any.Context());
+ return self.any.TotSize(axis);
},
py::arg("axis"), kRaggedAnyTotSizeDoc);

any.def(py::pickle(
[](const RaggedAny &self) -> py::tuple {
- DeviceGuard guard(self.any_.Context());
- K2_CHECK(self.any_.NumAxes() == 2 || self.any_.NumAxes() == 3)
+ DeviceGuard guard(self.any.Context());
+ K2_CHECK(self.any.NumAxes() == 2 || self.any.NumAxes() == 3)
<< "Only support Ragged with NumAxes() == 2 or 3 for now, given "
- << self.any_.NumAxes();
- Array1<int32_t> row_splits1 = self.any_.RowSplits(1);
- Dtype t = self.any_.GetDtype();
+ << self.any.NumAxes();
+ Array1<int32_t> row_splits1 = self.any.RowSplits(1);
+ Dtype t = self.any.GetDtype();

FOR_REAL_AND_INT32_TYPES(t, T, {
- auto values = self.any_.Specialize<T>().values;
+ auto values = self.any.Specialize<T>().values;
// We use "row_ids" placeholder here to make it compatible for the
// old format file.
- if (self.any_.NumAxes() == 2) {
+ if (self.any.NumAxes() == 2) {
return py::make_tuple(ToTorch(row_splits1), "row_ids1",
ToTorch(values));
} else {
- Array1<int32_t> row_splits2 = self.any_.RowSplits(2);
+ Array1<int32_t> row_splits2 = self.any.RowSplits(2);
return py::make_tuple(ToTorch(row_splits1), "row_ids1",
ToTorch(row_splits2), "row_ids2",
ToTorch(values));
@@ -213,7 +212,7 @@ void PybindRaggedAny(py::module &m) {
any.def_property_readonly(
"dtype",
[](const RaggedAny &self) -> py::object {
- Dtype t = self.any_.GetDtype();
+ Dtype t = self.any.GetDtype();
auto torch = py::module::import("torch");
switch (t) {
case kFloatDtype:
@@ -234,10 +233,10 @@ void PybindRaggedAny(py::module &m) {
any.def_property_readonly(
"device",
[](const RaggedAny &self) -> py::object {
- DeviceType d = self.any_.Context()->GetDeviceType();
+ DeviceType d = self.any.Context()->GetDeviceType();
torch::DeviceType device_type = ToTorchDeviceType(d);

- torch::Device device(device_type, self.any_.Context()->GetDeviceId());
+ torch::Device device(device_type, self.any.Context()->GetDeviceId());

PyObject *ptr = THPDevice_New(device);

@@ -251,22 +250,23 @@ void PybindRaggedAny(py::module &m) {
any.def_property_readonly(
"data",
[](RaggedAny &self) -> torch::Tensor {
- Dtype t = self.any_.GetDtype();
+ Dtype t = self.any.GetDtype();
FOR_REAL_AND_INT32_TYPES(
- t, T, { return ToTorch(self.any_.values.Specialize<T>()); });
+ t, T, { return ToTorch(self.any.values.Specialize<T>()); });

// Unreachable code
return {};
},
kRaggedAnyDataDoc);

any.def_property_readonly(
"shape", [](RaggedAny &self) -> RaggedShape { return self.any_.shape; });
"shape", [](RaggedAny &self) -> RaggedShape { return self.any.shape; },
"Return the ``Shape`` of this tensor.");

any.def_property_readonly(
"grad",
[](RaggedAny &self) -> torch::optional<torch::Tensor> {
- if (!self.data_.defined()) return {};
+ if (!self.data.defined()) return {};

return self.Data().grad();
},
Expand All @@ -275,7 +275,7 @@ void PybindRaggedAny(py::module &m) {
any.def_property(
"requires_grad",
[](RaggedAny &self) -> bool {
- if (!self.data_.defined()) return false;
+ if (!self.data.defined()) return false;

return self.Data().requires_grad();
},
@@ -287,19 +287,19 @@ void PybindRaggedAny(py::module &m) {
any.def_property_readonly(
"is_cuda",
[](RaggedAny &self) -> bool {
- return self.any_.Context()->GetDeviceType() == kCuda;
+ return self.any.Context()->GetDeviceType() == kCuda;
},
kRaggedAnyIsCudaDoc);

// NumAxes() does not access GPU memory
any.def_property_readonly(
"num_axes",
- [](const RaggedAny &self) -> int32_t { return self.any_.NumAxes(); },
+ [](const RaggedAny &self) -> int32_t { return self.any.NumAxes(); },
kRaggedAnyNumAxesDoc);

// Dim0() does not access GPU memory
any.def_property_readonly(
"dim0", [](const RaggedAny &self) -> int32_t { return self.any_.Dim0(); },
"dim0", [](const RaggedAny &self) -> int32_t { return self.any.Dim0(); },
kRaggedAnyDim0Doc);

//==================================================
@@ -313,6 +313,13 @@ void PybindRaggedAny(py::module &m) {
return RaggedAny(data, dtype);
},
py::arg("data"), py::arg("dtype") = py::none(), kRaggedAnyInitDataDoc);

+ m.def(
+     "create_tensor",
+     [](const std::string &s, py::object dtype = py::none()) -> RaggedAny {
+       return RaggedAny(s, dtype);
+     },
+     py::arg("s"), py::arg("dtype") = py::none(), kRaggedAnyInitStrDoc);
}

} // namespace k2
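For orientation while reading these bindings: `self.any` and `self.data` are members of the RaggedAny wrapper (this PR renames them from `any_` and `data_`). The wrapper's declaration is not part of this diff; based only on the members used above, it looks roughly like the following hedged reconstruction (the header path and exact member types are assumptions):

// Hedged reconstruction, not the actual declaration (which presumably lives
// in k2/python/csrc/torch/v2/ragged_any.h): RaggedAny pairs a dtype-erased
// k2 ragged tensor with a torch::Tensor view of its values so that PyTorch
// autograd can track operations on it.
struct RaggedAny {
  Ragged<Any> any;     // the underlying ragged tensor: values plus shape
  torch::Tensor data;  // torch view of any.values, created lazily for autograd

  torch::Tensor &Data();    // returns `data`, creating it on first access
  RaggedAny Clone() const;  // deep copy on the same device

  // constructors and operations (Sum, indexing, pickling, ...) omitted
};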
81 changes: 81 additions & 0 deletions k2/python/csrc/torch/v2/autograd/arc_sort.h
@@ -0,0 +1,81 @@
/**
* @brief A wrapper around k2::ArcSort to support autograd
*
* @copyright
* Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef K2_PYTHON_CSRC_TORCH_V2_AUTOGRAD_ARC_SORT_H
#define K2_PYTHON_CSRC_TORCH_V2_AUTOGRAD_ARC_SORT_H

#include "k2/csrc/fsa_algo.h"
#include "k2/csrc/ragged_ops.h"
#include "k2/csrc/tensor.h"
#include "k2/csrc/tensor_ops.h"
#include "k2/python/csrc/torch/torch_util.h"
#include "k2/python/csrc/torch/v2/ragged_arc.h"

using namespace torch::autograd;

namespace k2 {

// see https://pytorch.org/tutorials/advanced/cpp_autograd
class ArcSortFunction : public torch::autograd::Function<ArcSortFunction> {
Review comment from csukuangfj (Collaborator, Author): This is how we do backprop for ArcSort.

Reply from a collaborator: Ah OK, cool.

public:
/* ArcSort an Fsa. It is a wrapper around k2::ArcSort, supporting autograd.


@param ragged The input Fsa.
@param dummy Its purpose is to make autograd track the operations on
the input `ragged`. It is the same as `ragged.scores`.
@param out The output Fsa.

@return Return a 1-D unused tensor, which is out->scores.
*/
static torch::Tensor forward(AutogradContext *ctx,
/*const*/ RaggedArc &ragged,
torch::Tensor /*dummy*/, RaggedArc *out) {
Array1<int32_t> arc_map;
ArcSort(ragged.fsa, &out->fsa, &arc_map);

ctx->save_for_backward({ToTorch(arc_map)});

return out->Scores();
}

static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
torch::Tensor arc_map_tensor = saved[0];
Array1<int32_t> arc_map = FromTorch<int32_t>(arc_map_tensor);
torch::Tensor grad_output_tensor = grad_outputs[0];
Tensor grad_output = FromTorch(grad_output_tensor, TensorTag{});

Tensor ans = Index(grad_output, arc_map, /*allow_minus_one*/ false,
/*default_value*/ 0);

return {
torch::Tensor(), // ragged
ToTorch(ans), // dummy
torch::Tensor() // out
};
}
};

} // namespace k2

#endif // K2_PYTHON_CSRC_TORCH_V2_AUTOGRAD_ARC_SORT_H
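To see how such a Function is meant to be invoked, here is a hedged caller-side sketch; the real RaggedArc::ArcSort lives in ragged_arc.cu and is not part of this diff, so its exact signature is an assumption. The point is that passing Scores() as the "dummy" tensor makes autograd treat arc-sorting as an operation on the input scores, so gradients of out.Scores() flow back to the input through ArcSortFunction::backward via the saved arc_map.

// Hypothetical sketch only, modeled on the torch C++ autograd tutorial's
// Function::apply pattern; not the actual k2 implementation.
RaggedArc RaggedArc::ArcSort() /*const*/ {
  RaggedArc out;
  // forward() fills `out` and returns out.Scores(); the caller only needs
  // the side effect, so the returned tensor is discarded.
  (void)ArcSortFunction::apply(*this, Scores(), &out);
  return out;
}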
15 changes: 7 additions & 8 deletions k2/python/csrc/torch/v2/autograd/sum.h
@@ -41,7 +41,7 @@ class SumFunction : public torch::autograd::Function<SumFunction> {

@param ragged The input RaggedAny
@param dummy Its purpose is to make autograd to track the operations on
- the input `ragged`. It is the same as `ragged.data_`.
+ the input `ragged`. It is the same as `ragged.data`.
@param initial_value This value is added to the sum of each sublist,
so when a sublist is empty, its sum is this value.

@@ -50,21 +50,20 @@ class SumFunction : public torch::autograd::Function<SumFunction> {
*/
static torch::Tensor forward(AutogradContext *ctx, const RaggedAny &ragged,
torch::Tensor /*dummy*/, float initial_value) {
ctx->saved_data["n"] = ragged.any_.values.Dim();
ctx->saved_data["n"] = ragged.any.values.Dim();

- int32_t num_axes = ragged.any_.NumAxes();
+ int32_t num_axes = ragged.any.NumAxes();

torch::Tensor row_ids =
- ToTorch(const_cast<RaggedAny &>(ragged).any_.RowIds(num_axes - 1));
+ ToTorch(const_cast<RaggedAny &>(ragged).any.RowIds(num_axes - 1));

ctx->save_for_backward({row_ids});

- Dtype t = ragged.any_.GetDtype();
+ Dtype t = ragged.any.GetDtype();

FOR_REAL_AND_INT32_TYPES(t, T, {
- Array1<T> values(ragged.any_.Context(),
- ragged.any_.TotSize(num_axes - 2));
- SumPerSublist<T>(ragged.any_.Specialize<T>(), initial_value, &values);
+ Array1<T> values(ragged.any.Context(), ragged.any.TotSize(num_axes - 2));
+ SumPerSublist<T>(ragged.any.Specialize<T>(), initial_value, &values);
return ToTorch(values);
});

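The matching backward pass is outside the lines shown in this hunk. Conceptually, each input element inherits the gradient of the sublist it was summed into, looked up via the row_ids saved in forward(). A minimal sketch of that idea (an assumption about naming and layout, not the code actually in sum.h):

// Hedged sketch only; the real backward in sum.h may differ in detail.
static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
  auto saved = ctx->get_saved_variables();
  torch::Tensor row_ids = saved[0];  // maps element index -> sublist index

  // grad_in[i] = grad_output[row_ids[i]]: every element of a sublist gets
  // that sublist's gradient; there is one such value per input element.
  torch::Tensor grad_in =
      grad_outputs[0].index_select(0, row_ids.to(torch::kLong));

  return {
      torch::Tensor(),  // ragged: not a tensor, no gradient
      grad_in,          // dummy, i.e. ragged.data
      torch::Tensor()   // initial_value: no gradient
  };
}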
68 changes: 68 additions & 0 deletions k2/python/csrc/torch/v2/fsa.cu
@@ -0,0 +1,68 @@
/**
* @brief python wrapper for Ragged<Arc>
*
* @copyright
* Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "k2/csrc/ragged.h"
#include "k2/python/csrc/torch/v2/fsa.h"
#include "k2/python/csrc/torch/v2/ragged_arc.h"

namespace k2 {

void PybindRaggedArc(py::module &m) {
py::class_<RaggedArc> fsa(m, "Fsa");
fsa.def(py::init<>());

fsa.def(py::init<const std::string &, py::list>(), py::arg("s"),
py::arg("extra_label_names") = py::none());
fsa.def("__str__", &RaggedArc::ToString);
fsa.def("__repr__", &RaggedArc::ToString);

fsa.def("requires_grad_", &RaggedArc::SetRequiresGrad,
py::arg("requires_grad") = true);

fsa.def("arc_sort", &RaggedArc::ArcSort);

fsa.def_property(
"scores", [](RaggedArc &self) -> torch::Tensor { return self.Scores(); },
[](RaggedArc &self, torch::Tensor scores) {
self.Scores().copy_(scores);
});

fsa.def_property_readonly(
"grad", [](RaggedArc &self) -> torch::optional<torch::Tensor> {
if (!self.scores.defined()) return {};

return self.Scores().grad();
});

fsa.def_property(
"requires_grad",
[](RaggedArc &self) -> bool {
if (!self.scores.defined()) return false;

return self.Scores().requires_grad();
},
[](RaggedArc &self, bool requires_grad) -> void {
self.SetRequiresGrad(requires_grad);
});
}

} // namespace k2
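Both requires_grad_ and the requires_grad property setter above forward to RaggedArc::SetRequiresGrad, whose body is not part of this diff. Presumably it just toggles the autograd flag on the underlying scores tensor; a hedged sketch of that idea (return type and chaining are assumptions, modeled on torch.Tensor.requires_grad_):

// Hedged sketch; the actual implementation lives in ragged_arc.cu and is not
// shown here. Marking the scores tensor as requiring grad is what lets the
// autograd Functions above (e.g. ArcSortFunction) record operations on it.
RaggedArc &RaggedArc::SetRequiresGrad(bool requires_grad /*= true*/) {
  Scores().requires_grad_(requires_grad);
  return *this;
}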