Skip to content

Commit

Permalink
Replace std::terminate with throw (#3)
Browse files Browse the repository at this point in the history
* Replace terminate with throw

* Add Jet exception and implement tests

* Remove code paths protected by static_assert

* Update testing to check exception output

* Update Abort docstrings

* Add tests for Abort.hpp

* Fix linting of tests

* Remove unneeded function

* Update JetException docstring

* Fix linting

* Remove Fatal word

* Rename invalid_tensor_file and extend from JetException

* Update CHANGELOG

* Assume row-major order for indices (#10)

* Switch multi-index to row-major order

* Replace 'utilities' tag with 'Utilities'

* Fix tests specifying indices in column-major order

* Update changelog

* Fix Python unit test using column-major indices

* Undo modification to previous changelog entry

* Fix PR number

* Update change-log

* Fix Abort formatting and tests

* Rename JetException and avoid naming collisions

* Fix class name in changelog

* Trigger CI

Co-authored-by: Mikhail Andrenkov <[email protected]>
  • Loading branch information
mlxd and Mandrenkov authored May 21, 2021
1 parent b512dc1 commit 45c6020
Show file tree
Hide file tree
Showing 17 changed files with 403 additions and 235 deletions.
4 changes: 4 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@

### Improvements

* Exceptions are now favoured in place of `std::terminate` with `Exception` being the new base type for all exceptions thrown by Jet. [(#3)](https://github.com/XanaduAI/jet/pull/3)

* `TaskBasedCpuContractor` now stores `Tensor` results. [(#8)](https://github.com/XanaduAI/jet/pull/8)

* `Tensor` class now checks data type at compile-time. [(#4)](https://github.com/XanaduAI/jet/pull/4)

### Breaking Changes

* Indices are now specified in row-major order. [(#10)](https://github.com/XanaduAI/jet/pull/10)

### Bug Fixes

* The output of `TensorNetwork::Contract()` and `TaskBasedCpuContractor::Contract()` now agree with one another. [(#6)](https://github.com/XanaduAI/jet/pull/6)
Expand Down
60 changes: 46 additions & 14 deletions include/jet/Abort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@

#include <exception>
#include <iostream>
#include <sstream>

/**
* @brief Macro that prints error message and source location to stderr
* and calls `std::terminate()`
* @brief Macro that throws `%Exception` with given message.
*
* @param message string literal describing error
*/
#define JET_ABORT(message) Jet::Abort(message, __FILE__, __LINE__, __func__)
/**
* @brief Macro that prints error message and source location to stderr
* and calls `std::terminate()` if expression evaluates to true
* @brief Macro that throws `%Exception` if expression evaluates to true.
*
* @param expression an expression
* @param message string literal describing error
Expand All @@ -22,8 +21,8 @@
JET_ABORT(message); \
}
/**
* @brief Macro that prints error message and source location to stderr
* and calls `std::terminate()` if expression evaluates to false
* @brief Macro that throws `%Exception` with error message if expression
* evaluates to false.
*
* @param expression an expression
* @param message string literal describing error
Expand All @@ -34,8 +33,8 @@
}

/**
* @brief Macro that prints expression and source location to stderr
* and calls `std::terminate()` if expression evaluates to false
* @brief Macro that throws `%Exception` with the given expression and source
* location if expression evaluates to false.
*
* @param expression an expression
*/
Expand All @@ -45,7 +44,40 @@
namespace Jet {

/**
* @brief Prints an error message to stderr and calls `std::terminate()`.
* @brief `%Exception` is the general exception thrown by Jet for runtime
* errors.
*
*/
class Exception : public std::exception {
public:
/**
* @brief Constructs a new `%Exception` exception.
*
* @param err_msg Error message explaining the exception condition.
*/
explicit Exception(const std::string &err_msg) noexcept : err_msg(err_msg)
{
}

/**
* @brief Destroys the `%Exception` object.
*/
virtual ~Exception() = default;

/**
* @brief Returns a string containing the exception message. Overrides
* the `std::exception` method.
*
* @return Exception message.
*/
const char *what() const noexcept { return err_msg.c_str(); }

private:
std::string err_msg;
};

/**
* @brief Throws an `%Exception` with the given error message.
*
* This function should not be called directly - use one of the `JET_ASSERT()`
* or `JET_ABORT()` macros, which provide the source location at compile time.
Expand All @@ -58,10 +90,10 @@ namespace Jet {
inline void Abort(const char *message, const char *file_name, int line,
const char *function_name)
{
std::cerr << "Fatal error in jet/" << file_name << ", line " << line
<< ", in " << function_name << ": " << message << std::endl;

std::terminate();
std::stringstream err_msg;
err_msg << "[" << file_name << "][Line:" << line
<< "][Method:" << function_name << "]: Error in Jet: " << message;
throw Exception(err_msg.str());
}

}; // namespace Jet
}; // namespace Jet
4 changes: 2 additions & 2 deletions include/jet/PathInfo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class PathInfo {
* @return Number of floating-point multiplications and additions needed to
* compute the tensor associated with the path step.
*/
double GetPathStepFlops(size_t id) const noexcept
double GetPathStepFlops(size_t id) const
{
JET_ABORT_IF_NOT(id < steps_.size(), "Step ID is invalid.");

Expand Down Expand Up @@ -208,7 +208,7 @@ class PathInfo {
*
* @return Number of elements in the tensor associated with the path step.
*/
double GetPathStepMemory(size_t id) const noexcept
double GetPathStepMemory(size_t id) const
{
JET_ABORT_IF_NOT(id < steps_.size(), "Step ID is invalid.");

Expand Down
12 changes: 6 additions & 6 deletions include/jet/Tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ Tensor<T> Reshape(const Tensor<T> &old_tensor,
using namespace Utilities;

JET_ABORT_IF_NOT(old_tensor.GetSize() ==
TensorHelpers::ShapeToSize(new_shape),
Jet::Utilities::ShapeToSize(new_shape),
"Size is inconsistent between tensors.");
Tensor<T> new_tensor(new_shape);
Utilities::FastCopy(old_tensor.GetData(), new_tensor.GetData());
Expand Down Expand Up @@ -482,7 +482,7 @@ template <class T> class Tensor {
* @param shape Dimension of each `%Tensor` index.
*/
Tensor(const std::vector<size_t> &shape)
: data_(TensorHelpers::ShapeToSize(shape))
: data_(Jet::Utilities::ShapeToSize(shape))
{
using namespace Utilities;
std::vector<std::string> indices(shape.size());
Expand All @@ -504,7 +504,7 @@ template <class T> class Tensor {
*/
Tensor(const std::vector<std::string> &indices,
const std::vector<size_t> &shape)
: data_(TensorHelpers::ShapeToSize(shape))
: data_(Jet::Utilities::ShapeToSize(shape))
{
InitIndicesAndShape(indices, shape);
}
Expand Down Expand Up @@ -690,7 +690,7 @@ template <class T> class Tensor {
/**
* @brief Sets the `%Tensor` data value at the given n-dimensional index.
*
* @param indices n-dimensional `%Tensor` data index.
* @param indices n-dimensional `%Tensor` data index in row-major order.
* @param value Data value to set at given index.
*/
void SetValue(const std::vector<size_t> &indices, const T &value)
Expand All @@ -701,7 +701,7 @@ template <class T> class Tensor {
/**
* @brief Returns the `%Tensor` data value at the given n-dimensional index.
*
* @param indices n-dimensional `%Tensor` data index.
* @param indices n-dimensional `%Tensor` data index in row-major order.
*
* @returns Complex data value.
*/
Expand Down Expand Up @@ -737,7 +737,7 @@ template <class T> class Tensor {
*
* @return Number of data elements.
*/
size_t GetSize() const { return TensorHelpers::ShapeToSize(shape_); }
size_t GetSize() const { return Jet::Utilities::ShapeToSize(shape_); }

/**
* @brief Returns a single scalar value from the `%Tensor` object.
Expand Down
13 changes: 0 additions & 13 deletions include/jet/TensorHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,5 @@ inline void MultiplyTensorData(const std::vector<ComplexPrecision> &A,
}
}

/**
* @brief Calulate the size of data from the tensor size.
*
* @param tensor_shape Size of each tensor index label.
*/
inline size_t ShapeToSize(const std::vector<size_t> &tensor_shape)
{
size_t total_dim = 1;
for (const auto &dim : tensor_shape)
total_dim *= dim;
return total_dim;
}

}; // namespace TensorHelpers
}; // namespace Jet
5 changes: 2 additions & 3 deletions include/jet/TensorNetwork.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ template <class Tensor> class TensorNetwork {
* how raveled values are interpreted.
*/
void SliceIndices(const std::vector<std::string> &indices,
unsigned long long value) noexcept
unsigned long long value)
{
std::unordered_map<size_t, std::vector<size_t>> node_to_index_map;
std::vector<size_t> index_sizes(indices.size());
Expand All @@ -219,7 +219,6 @@ template <class Tensor> class TensorNetwork {
const auto it = index_to_edge_map_.find(indices[i]);
JET_ABORT_IF(it == index_to_edge_map_.end(),
"Sliced index does not exist.");

const auto &edge = it->second;
index_sizes[i] = edge.dim;

Expand Down Expand Up @@ -296,7 +295,7 @@ template <class Tensor> class TensorNetwork {
* @param path Contraction path specified as a list of node ID pairs.
* @return Tensor associated with the result of the final contraction.
*/
const Tensor &Contract(const path_t &path = {}) noexcept
const Tensor &Contract(const path_t &path = {})
{
JET_ABORT_IF(nodes_.empty(),
"An empty tensor network cannot be contracted.");
Expand Down
27 changes: 13 additions & 14 deletions include/jet/TensorNetworkIO.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,25 @@ template <class Tensor> struct TensorNetworkFile {
};

/**
* @brief `%invalid_tensor_file` is thrown when the contents of a tensor network
* @brief `%TensorFileException` is thrown when the contents of a tensor network
* file are invalid.
*/
class invalid_tensor_file : public std::invalid_argument {
class TensorFileException : public Exception {
public:
/**
* @brief Constructs a new `%invalid_tensor_file` exception.
* @brief Constructs a new `%TensorFileException` exception.
*
* @param what_arg Error message explaining what went wrong while loading a
* tensor network file.
*/
explicit invalid_tensor_file(const std::string &what_arg)
: std::invalid_argument("Error parsing tensor network file: " +
what_arg){};
explicit TensorFileException(const std::string &what_arg)
: Exception("Error parsing tensor network file: " + what_arg){};

/**
* @see invalid_tensor_file(const std::string&).
* @see TensorFileException(const std::string&).
*/
explicit invalid_tensor_file(const char *what_arg)
: invalid_tensor_file(std::string(what_arg)){};
explicit TensorFileException(const char *what_arg)
: TensorFileException(std::string(what_arg)){};
};

/**
Expand Down Expand Up @@ -178,19 +177,19 @@ template <class Tensor> class TensorNetworkSerializer {
* are correct.
*
* Throw json::exception if string is invalid json,
* invalid_tensor_file if it does not have the correct
* TensorFileException if it does not have the correct
* keys.
*/
void LoadAndValidateJSON_(const std::string &js_str)
{
js = json::parse(js_str); // throws json::exception if invalid json

if (!js.is_object()) {
throw invalid_tensor_file("root element must be an object.");
throw TensorFileException("root element must be an object.");
}

if (js.find("tensors") == js.end()) {
throw invalid_tensor_file("root object must contain 'tensors' key");
throw TensorFileException("root object must contain 'tensors' key");
}
}

Expand Down Expand Up @@ -231,7 +230,7 @@ template <class Tensor> class TensorNetworkSerializer {
* @brief Convert json array of complex values into native
* format.
*
* Throws invalid_tensor_file exception if any of elements
* Throws TensorFileException exception if any of elements
* of js_data to not encode a complex value.
*/
template <typename S>
Expand All @@ -248,7 +247,7 @@ template <class Tensor> class TensorNetworkSerializer {
}
}
catch (const json::exception &) {
throw invalid_tensor_file(
throw TensorFileException(
"Invalid element at index " + std::to_string(i) +
" of tensor " + std::to_string(tensor_index) +
": Could not parse " + js_data[i].dump() + " as complex.");
Expand Down
Loading

0 comments on commit 45c6020

Please sign in to comment.