[WIP]: Move k2.Fsa to C++ #814
Status: Merged (34 commits)
Commits:

1. f48563b (csukuangfj): Make k2 ragged tensor more PyTorch-y like.
2. 06a6d20 (csukuangfj): Refactoring: Start to add the wrapper class AnyTensor.
3. 2a70298 (csukuangfj): Refactoring.
4. 6bc05bf (csukuangfj): Initial attempt to support autograd.
5. c7bb9d5 (csukuangfj): First working version with autograd for Sum().
6. d569b42 (csukuangfj): Fix comments.
7. dcea808 (csukuangfj): Support __getitem__ and pickling.
8. cb4f00f (csukuangfj): Add more docs for k2.ragged.Tensor.
9. 1b5c015 (csukuangfj): Put documentation in header files.
10. a8d4a8e (csukuangfj): Minor fixes.
11. 1f78c93 (csukuangfj): Fix a typo.
12. 892fb04 (csukuangfj): Fix an error.
13. fb96d97 (csukuangfj): Add more doc.
14. 2f01361 (csukuangfj): Wrap RaggedShape.
15. 626cc7a (csukuangfj): [Not for Merge]: Move k2.Fsa related code to C++.
16. 0e60a69 (csukuangfj): Merge remote-tracking branch 'dan/master' into fsa.
17. f11947e (csukuangfj): Remove extra files.
18. 9ac1e78 (csukuangfj): Update doc URL. (#821)
19. 44ff35b (csukuangfj): Merge remote-tracking branch 'dan/master' into fsa.
20. 1dc7e1e (csukuangfj): Support manipulating attributes of k2.ragged.Fsa.
21. bbe0ded (csukuangfj): Support indexing 2-axes RaggedTensor, Support slicing for RaggedTenso…
22. 2c28070 (pkufool): Prune with max_arcs in IntersectDense. (#820)
23. 210175c (pkufool): Release v1.8.
24. 33a212c (pkufool): Create a ragged tensor from a regular tensor. (#827)
25. 971af7d (csukuangfj): Trigger GitHub actions manually. (#829)
26. 646704e (csukuangfj): Run GitHub actions on merging. (#830)
27. 8030001 (csukuangfj): Support printing ragged tensors in a more compact way. (#831)
28. 7029b1f (csukuangfj): Merge remote-tracking branch 'dan/master' into fsa.
29. d73a5b5 (csukuangfj): Add levenshtein alignment. (#828)
30. f2fd997 (pkufool): Release v1.9.
31. b2cb9c0 (pkufool): Add Fsa.get_forward_scores.
32. 13408aa (csukuangfj): Merge remote-tracking branch 'dan/master' into fsa.
33. cbff6a1 (csukuangfj): Implement backprop for Fsa.get_forward_scores().
34. cca7a54 (pkufool): Construct RaggedArc from unary function tensor. (#30)
New file `k2/python/csrc/torch/v2/autograd/arc_sort.h` (81 lines):

```cpp
/**
 * @brief A wrapper around k2::ArcSort to support autograd
 *
 * @copyright
 * Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
 *
 * @copyright
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef K2_PYTHON_CSRC_TORCH_V2_AUTOGRAD_ARC_SORT_H
#define K2_PYTHON_CSRC_TORCH_V2_AUTOGRAD_ARC_SORT_H

#include "k2/csrc/fsa_algo.h"
#include "k2/csrc/ragged_ops.h"
#include "k2/csrc/tensor.h"
#include "k2/csrc/tensor_ops.h"
#include "k2/python/csrc/torch/torch_util.h"
#include "k2/python/csrc/torch/v2/ragged_arc.h"

using namespace torch::autograd;

namespace k2 {

// see https://pytorch.org/tutorials/advanced/cpp_autograd
class ArcSortFunction : public torch::autograd::Function<ArcSortFunction> {
 public:
  /* ArcSort an Fsa. It is a wrapper around k2::ArcSort, supporting autograd.

     @param ragged The input Fsa.
     @param dummy  Its purpose is to make autograd track the operations on
                   the input `ragged`. It is the same as `ragged.scores`.
     @param out    The output Fsa.

     @return Return a 1-D unused tensor, which is out->scores.
   */
  static torch::Tensor forward(AutogradContext *ctx,
                               /*const*/ RaggedArc &ragged,
                               torch::Tensor /*dummy*/, RaggedArc *out) {
    Array1<int32_t> arc_map;
    ArcSort(ragged.fsa, &out->fsa, &arc_map);

    ctx->save_for_backward({ToTorch(arc_map)});

    return out->Scores();
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    torch::Tensor arc_map_tensor = saved[0];
    Array1<int32_t> arc_map = FromTorch<int32_t>(arc_map_tensor);
    torch::Tensor grad_output_tensor = grad_outputs[0];
    Tensor grad_output = FromTorch(grad_output_tensor, TensorTag{});

    Tensor ans = Index(grad_output, arc_map, /*allow_minus_one*/ false,
                       /*default_value*/ 0);

    return {
        torch::Tensor(),  // ragged
        ToTorch(ans),     // dummy
        torch::Tensor()   // out
    };
  }
};

}  // namespace k2

#endif  // K2_PYTHON_CSRC_TORCH_V2_AUTOGRAD_ARC_SORT_H
```
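The backward pass above relies on the `arc_map` produced by `k2::ArcSort`: `arc_map[i]` is the index of the input arc that became output arc `i`, so routing gradients back to the input is just a permutation of `grad_output` (the C++ code delegates this to k2's `Index` utility). A minimal pure-Python sketch of that gradient routing, using a hypothetical `arc_sort_backward` helper that is not part of k2:

```python
def arc_sort_backward(grad_output, arc_map):
    """Route output-arc gradients back to input arcs.

    arc_map[i] is the index of the input arc that became output arc i.
    For ArcSort, arc_map is a permutation, so each input slot receives
    exactly one contribution; the accumulation (+=) is written generally
    so the same sketch covers maps where several outputs share an input.
    """
    grad_input = [0.0] * len(grad_output)
    for i, j in enumerate(arc_map):
        grad_input[j] += grad_output[i]  # output arc i came from input arc j
    return grad_input


# Example: output arcs 0,1,2 came from input arcs 2,0,1 respectively.
print(arc_sort_backward([1.0, 2.0, 3.0], [2, 0, 1]))
```

Note how the returned tensor in the C++ `backward` occupies the `dummy` slot: `dummy` is the same storage as `ragged.scores`, which is exactly the input that autograd should see a gradient for.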
New file (68 lines), the pybind11 wrapper for `Ragged<Arc>`:

```cpp
/**
 * @brief python wrapper for Ragged<Arc>
 *
 * @copyright
 * Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
 *
 * @copyright
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "k2/csrc/ragged.h"
#include "k2/python/csrc/torch/v2/fsa.h"
#include "k2/python/csrc/torch/v2/ragged_arc.h"

namespace k2 {

void PybindRaggedArc(py::module &m) {
  py::class_<RaggedArc> fsa(m, "Fsa");
  fsa.def(py::init<>());

  fsa.def(py::init<const std::string &, py::list>(), py::arg("s"),
          py::arg("extra_label_names") = py::none());
  fsa.def("__str__", &RaggedArc::ToString);
  fsa.def("__repr__", &RaggedArc::ToString);

  fsa.def("requires_grad_", &RaggedArc::SetRequiresGrad,
          py::arg("requires_grad") = true);

  fsa.def("arc_sort", &RaggedArc::ArcSort);

  fsa.def_property(
      "scores", [](RaggedArc &self) -> torch::Tensor { return self.Scores(); },
      [](RaggedArc &self, torch::Tensor scores) {
        self.Scores().copy_(scores);
      });

  fsa.def_property_readonly(
      "grad", [](RaggedArc &self) -> torch::optional<torch::Tensor> {
        if (!self.scores.defined()) return {};

        return self.Scores().grad();
      });

  fsa.def_property(
      "requires_grad",
      [](RaggedArc &self) -> bool {
        if (!self.scores.defined()) return false;

        return self.Scores().requires_grad();
      },
      [](RaggedArc &self, bool requires_grad) -> void {
        self.SetRequiresGrad(requires_grad);
      });
}

}  // namespace k2
```
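A detail worth noting in the binding above: the `scores` setter calls `self.Scores().copy_(scores)` rather than rebinding the attribute, so existing views of the score tensor stay valid and autograd history on the underlying storage is preserved. A pure-Python sketch of that property pattern, using a hypothetical `FsaSketch` class (a stand-in for the bound `Fsa`, not k2's API):

```python
class FsaSketch:
    """Illustrative stand-in for the pybind11 `Fsa` class above.

    The `scores` setter copies values into the existing storage
    (mirroring `self.Scores().copy_(scores)`) instead of replacing
    the object, so previously obtained views observe the update.
    """

    def __init__(self, scores):
        self._scores = list(scores)  # stands in for the arcs' score column

    @property
    def scores(self):
        return self._scores  # a view of the storage, not a copy

    @scores.setter
    def scores(self, new_scores):
        if len(new_scores) != len(self._scores):
            raise ValueError("score count must match the number of arcs")
        self._scores[:] = new_scores  # in-place copy, like torch's copy_()


fsa = FsaSketch([0.0, 1.0])
view = fsa.scores      # grab a view before assigning
fsa.scores = [5.0, 6.0]
print(view)            # the view reflects the in-place update
```

The same reasoning explains why `copy_` also enforces shape agreement in the real binding: assigning a score tensor of the wrong length would silently corrupt the FSA's arcs otherwise.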
Review comments:

> This is how we do backprop for ArcSort.

> Ah OK, cool.