-
Notifications
You must be signed in to change notification settings - Fork 168
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: ExaTrkX edge building KDTree on CPU + fixes + refactor + tests (#…
…2360) * Replaces brute force edge building with `Acts::KDTree` method * Abstracts `std::vector` to `torch::Tensor` conversions * Fixes bug in edge duplicate removal * Add more unit tests and enable in CI
- Loading branch information
1 parent
9cd59fd
commit 0486e42
Showing
12 changed files
with
640 additions
and
288 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 0 additions & 36 deletions
36
Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/buildEdges.hpp
This file was deleted.
Oops, something went wrong.
92 changes: 92 additions & 0 deletions
92
Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2023 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#pragma once | ||
|
||
#include "Acts/Utilities/Concepts.hpp" | ||
|
||
#include <cstdint> | ||
#include <vector> | ||
|
||
#include <torch/torch.h> | ||
|
||
namespace Acts::detail { | ||
|
||
/// So far this is only needed for integers | ||
template <typename T> | ||
struct TorchTypeMap {}; | ||
|
||
template <> | ||
struct TorchTypeMap<int64_t> { | ||
constexpr static torch::Dtype type = torch::kInt64; | ||
}; | ||
|
||
template <> | ||
struct TorchTypeMap<int32_t> { | ||
constexpr static torch::Dtype type = torch::kInt32; | ||
}; | ||
|
||
template <> | ||
struct TorchTypeMap<int16_t> { | ||
constexpr static torch::Dtype type = torch::kInt16; | ||
}; | ||
|
||
template <> | ||
struct TorchTypeMap<int8_t> { | ||
constexpr static torch::Dtype type = torch::kInt8; | ||
}; | ||
|
||
template <> | ||
struct TorchTypeMap<float> { | ||
constexpr static torch::Dtype type = torch::kFloat32; | ||
}; | ||
|
||
template <> | ||
struct TorchTypeMap<double> { | ||
constexpr static torch::Dtype type = torch::kFloat64; | ||
}; | ||
|
||
/// Converts vector to 2D tensor | ||
/// Make sure your vector has a even number of elements! | ||
/// @Note Input must be mutable, due to torch API. | ||
/// @Note Tensor does not take ownership! `.clone()` afterwards to get | ||
/// ownership of the data | ||
template <typename T> | ||
at::Tensor vectorToTensor2D(std::vector<T> &vec, std::size_t cols) { | ||
assert(vec.size() % cols == 0); | ||
|
||
auto opts = | ||
at::TensorOptions().dtype(TorchTypeMap<T>::type).device(torch::kCPU); | ||
|
||
return torch::from_blob( | ||
vec.data(), | ||
{static_cast<long>(vec.size() / cols), static_cast<long>(cols)}, opts); | ||
} | ||
|
||
/// Converts 2D tensor to vector | ||
/// @Note Automatically converts tensor to target type! | ||
template <typename T> | ||
std::vector<T> tensor2DToVector(const at::Tensor &tensor) { | ||
assert(tensor.sizes().size() == 2); | ||
|
||
// clone to make sure we own the data | ||
// bring to CPU | ||
// convert to requested type | ||
// ensure the tensor is contiguous (e.g. not the case if indexed with step) | ||
|
||
at::Tensor transformedTensor = | ||
tensor.to(torch::kCPU).to(TorchTypeMap<T>::type).contiguous(); | ||
|
||
std::vector<T> edgeIndex( | ||
transformedTensor.template data_ptr<T>(), | ||
transformedTensor.template data_ptr<T>() + transformedTensor.numel()); | ||
|
||
return edgeIndex; | ||
} | ||
|
||
} // namespace Acts::detail |
47 changes: 47 additions & 0 deletions
47
Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/detail/buildEdges.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// This file is part of the Acts project. | ||
// | ||
// Copyright (C) 2022 CERN for the benefit of the Acts project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
#pragma once | ||
|
||
#include <cstdint> | ||
|
||
namespace at { | ||
class Tensor; | ||
} | ||
|
||
namespace Acts { | ||
namespace detail { | ||
|
||
/// Post process edges | ||
at::Tensor postprocessEdgeTensor(at::Tensor edges, bool removeSelfLoops = true, | ||
bool removeDuplicates = true, | ||
bool flipDirections = false); | ||
|
||
/// Edge building using FRNN and CUDA. | ||
/// Raises an exception if not built with CUDA | ||
at::Tensor buildEdgesFRNN(at::Tensor& embedFeatures, float rVal, int kVal, | ||
bool flipDirections = false); | ||
|
||
/// Edge building using the Acts KD-Tree implementation | ||
/// Note that this implementation has no maximum number of neighbours | ||
/// in the NN search. kVal is only a hint for reserving memory | ||
at::Tensor buildEdgesKDTree(at::Tensor& embedFeatures, float rVal, int kVal, | ||
bool flipDirections = false); | ||
|
||
/// Dispatches either to FRNN or KD-Tree based edge building | ||
/// | ||
/// @param embedFeatures Tensor of shape (n_nodes, embedding_dim) | ||
/// @param rVal radius for NN search | ||
/// @param kVal max number of neighbours in NN search | ||
/// @param flipDirections if we want to randomly flip directions of the | ||
/// edges after the edge building | ||
at::Tensor buildEdges(at::Tensor& embedFeatures, float rVal, int kVal, | ||
bool flipDirections = false); | ||
|
||
} // namespace detail | ||
} // namespace Acts |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.