Skip to content

Commit

Permalink
Add source location information to error messages (pytorch#6059)
Browse files Browse the repository at this point in the history
  • Loading branch information
goldsborough authored and apaszke committed Mar 29, 2018
1 parent 7ffcb20 commit d42fcdb
Show file tree
Hide file tree
Showing 35 changed files with 286 additions and 195 deletions.
10 changes: 0 additions & 10 deletions aten/src/ATen/ATenAssert.h

This file was deleted.

6 changes: 4 additions & 2 deletions aten/src/ATen/ArrayRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
// removed a bunch of slice variants for simplicity...

#pragma once
#include <assert.h>

#include <ATen/Error.h>

#include <array>
#include <iterator>
#include <vector>
#include "ATenAssert.h"

namespace at {
/// ArrayRef - Represent a constant reference to an array (0 or more elements
Expand Down
5 changes: 3 additions & 2 deletions aten/src/ATen/CPUFixedAllocator.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "TH/TH.h"
#include "ATen/Error.h"

// This file creates a fake allocator that just throws exceptions if
// it is actually used.
Expand All @@ -11,11 +12,11 @@
namespace at {

static cpu_fixed_malloc(void *, ptrdiff_t) {
runtime_error("attempting to resize a tensor view of an external blob");
AT_ERROR("attempting to resize a tensor view of an external blob");
}

static cpu_fixed_realloc(void *, void*, ptrdiff_t) {
runtime_error("attempting to resize a tensor view of an external blob");
AT_ERROR("attempting to resize a tensor view of an external blob");
}

static cpu_fixed_free(void * state, void * allocation) {
Expand Down
9 changes: 5 additions & 4 deletions aten/src/ATen/CUDAFixedAllocator.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "THC/THC.h"
#include "ATen/Error.h"

// This file creates a fake allocator that just throws exceptions if
// it is actually used.
Expand All @@ -11,11 +12,11 @@
namespace at {

static cuda_fixed_malloc(void *, void**, size_t, cudaStream_t) {
runtime_error("attempting to resize a tensor view of an external blob");
AT_ERROR("attempting to resize a tensor view of an external blob");
}

static cpu_fixed_realloc(void*, void**, size_t, size_t, cudaStream_t) {
runtime_error("attempting to resize a tensor view of an external blob");
AT_ERROR("attempting to resize a tensor view of an external blob");
}

static cuda_fixed_free(void * state, void * allocation) {
Expand All @@ -25,11 +26,11 @@ static cuda_fixed_free(void * state, void * allocation) {
}

static cuda_fixed_emptyCache(void*) {
runtime_error("?? attempting to resize a tensor view of an external blob");
AT_ERROR("?? attempting to resize a tensor view of an external blob");
}

static cuda_fixed_cacheInfo(void*, int, size_t*, size_t*) {
runtime_error("?? attempting to resize a tensor view of an external blob");
AT_ERROR("?? attempting to resize a tensor view of an external blob");
}


Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/CheckGenerator.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "ATen/Error.h"
#include "ATen/Generator.h"
#include "ATen/Utils.h"

Expand All @@ -11,7 +12,7 @@ static inline T * check_generator(Generator * expr, Generator * defaultValue) {
expr = defaultValue;
if(auto result = dynamic_cast<T*>(expr))
return result;
runtime_error("Expected a '%s' but found '%s'", typeid(T).name(), typeid(expr).name());
AT_ERROR("Expected a '%s' but found '%s'", typeid(T).name(), typeid(expr).name());
}

} // namespace at
5 changes: 3 additions & 2 deletions aten/src/ATen/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ATen/Generator.h"
#include "ATen/Type.h"
#include "ATen/Utils.h"
#include "ATen/Error.h"

#include <memory>
#include <mutex>
Expand Down Expand Up @@ -32,15 +33,15 @@ class AT_API Context {
auto & undef = type_registry[static_cast<int>(Backend::Undefined)][static_cast<int>(ScalarType::Undefined)];
if (undef) return *undef;
}
runtime_error("%s%sType is not enabled.",toString(p),toString(s));
AT_ERROR("%s%sType is not enabled.",toString(p),toString(s));
}
return *type;
}
Generator & defaultGenerator(Backend p) {
initCUDAIfNeeded(p);
auto & generator = generator_registry[static_cast<int>(p)];
if(!generator)
runtime_error("%s backend type not enabled.",toString(p));
AT_ERROR("%s backend type not enabled.",toString(p));
return *generator;
}
bool hasMKL() const;
Expand Down
102 changes: 49 additions & 53 deletions aten/src/ATen/Dispatch.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <ATen/ATenAssert.h>
#include <ATen/Error.h>
#include <ATen/Half.h>
#include <ATen/Type.h>

Expand All @@ -10,62 +10,58 @@
return __VA_ARGS__(); \
}

#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
[&] { \
const at::Type& the_type = TYPE; \
switch (the_type.scalarType()) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
default: \
at::runtime_error( \
"%s not implemented for '%s'", (NAME), the_type.toString()); \
} \
#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
[&] { \
const at::Type& the_type = TYPE; \
switch (the_type.scalarType()) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
default: \
AT_ERROR("%s not implemented for '%s'", (NAME), the_type.toString()); \
} \
}()

#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
[&] { \
const at::Type& the_type = TYPE; \
switch (the_type.scalarType()) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, Half, __VA_ARGS__) \
default: \
at::runtime_error( \
"%s not implemented for '%s'", (NAME), the_type.toString()); \
} \
#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
[&] { \
const at::Type& the_type = TYPE; \
switch (the_type.scalarType()) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, Half, __VA_ARGS__) \
default: \
AT_ERROR("%s not implemented for '%s'", (NAME), the_type.toString()); \
} \
}()

#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
[&] { \
const at::Type& the_type = TYPE; \
switch (the_type.scalarType()) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
default: \
at::runtime_error( \
"%s not implemented for '%s'", (NAME), the_type.toString()); \
} \
#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
[&] { \
const at::Type& the_type = TYPE; \
switch (the_type.scalarType()) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
default: \
AT_ERROR("%s not implemented for '%s'", (NAME), the_type.toString()); \
} \
}()

#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
[&] { \
const at::Type& the_type = TYPE; \
switch (the_type.scalarType()) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, Half, __VA_ARGS__) \
default: \
at::runtime_error( \
"%s not implemented for '%s'", (NAME), the_type.toString()); \
} \
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
[&] { \
const at::Type& the_type = TYPE; \
switch (the_type.scalarType()) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, Half, __VA_ARGS__) \
default: \
AT_ERROR("%s not implemented for '%s'", (NAME), the_type.toString()); \
} \
}()
94 changes: 94 additions & 0 deletions aten/src/ATen/Error.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#pragma once

#include <ATen/ATenGeneral.h> // for AT_API

#include <cstdint>
#include <cstdio>
#include <exception>
#include <stdexcept>
#include <string>
#include <type_traits>

#include <stdarg.h>

namespace at {
namespace detail {
/// A tiny implementation of static `all_of`.
template <bool...>
struct pack;
template <bool... values>
struct all_of : std::is_same<pack<values..., true>, pack<true, values...>> {};

/// A printf wrapper that returns an std::string.
inline std::string format(const char* format_string, ...) {
static constexpr size_t kMaximumStringLength = 4096;
char buffer[kMaximumStringLength];

va_list format_args;
va_start(format_args, format_string);
vsnprintf(buffer, sizeof(buffer), format_string, format_args);
va_end(format_args);

return buffer;
}

/// Represents a location in source code (for debugging).
struct SourceLocation {
std::string toString() const {
return format("%s at %s:%d", function, file, line);
}

const char* function;
const char* file;
uint32_t line;
};
} // namespace detail

/// The primary ATen error class.
/// Provides a complete error message with source location information via
/// `what()`, and a more concise message via `what_without_location()`. Should
/// primarily be used with the `AT_ERROR` macro.
struct AT_API Error : public std::exception {
template <typename... FormatArgs>
Error(
detail::SourceLocation source_location,
const char* format_string,
FormatArgs&&... format_args)
: what_without_location_(detail::format(
format_string,
std::forward<FormatArgs>(format_args)...)),
what_(
what_without_location_ + " (" + source_location.toString() + ")") {
// NOTE: A "literal type"
// (http://en.cppreference.com/w/cpp/concept/LiteralType) could also be a
// constexpr struct, so it's not 100% guaranteed that the `printf` call
// inside `format` is safe, but it will catch 99.9% of all errors we'll run
// into, such as passing `std::string`.
static_assert(
detail::all_of<std::is_literal_type<FormatArgs>::value...>::value,
"arguments to `format` must be literal types!");
}

/// Returns the complete error message including the source location.
const char* what() const noexcept override {
return what_.c_str();
}

/// Returns only the error message string, without source location.
const char* what_without_location() const noexcept {
return what_without_location_.c_str();
}

private:
std::string what_without_location_;
std::string what_;
};
} // namespace at

#define AT_ERROR(...) \
throw at::Error({__func__, __FILE__, __LINE__}, __VA_ARGS__)

#define AT_ASSERT(cond, ...) \
if (!(cond)) { \
AT_ERROR(__VA_ARGS__); \
}
4 changes: 3 additions & 1 deletion aten/src/ATen/ExpandUtils.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include "ATen/Tensor.h"
#include "ATen/Error.h"

#include <functional>
#include <sstream>
#include <tuple>
Expand All @@ -14,7 +16,7 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t> > inferExpandGeometry(cons
inline void check_defined(std::initializer_list<std::reference_wrapper<const Tensor>> tensors, const char *api_name) {
for (auto& t : tensors) {
if (!t.get().defined()) {
runtime_error("%s(...) called with an undefined Tensor", api_name);
AT_ERROR("%s(...) called with an undefined Tensor", api_name);
}
}
}
Expand Down
13 changes: 7 additions & 6 deletions aten/src/ATen/UndefinedTensor.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "ATen/UndefinedTensor.h"
#include "ATen/Context.h"
#include "ATen/Error.h"

namespace at {

Expand All @@ -13,28 +14,28 @@ const char * UndefinedTensor::toString() const {
}

IntList UndefinedTensor::sizes() const {
runtime_error("sizes() called on undefined Tensor");
AT_ERROR("sizes() called on undefined Tensor");
}

int64_t UndefinedTensor::dim() const {
runtime_error("dim() called on undefined Tensor");
AT_ERROR("dim() called on undefined Tensor");
}

const char * UndefinedTensor::typeString() {
return "UndefinedType";
}
void * UndefinedTensor::unsafeGetTH(bool retain) {
runtime_error("unsafeGetTH(bool retain) called on undefined Tensor");
AT_ERROR("unsafeGetTH(bool retain) called on undefined Tensor");
}
std::unique_ptr<Storage> UndefinedTensor::storage() {
runtime_error("storage() called on undefined Tensor");
AT_ERROR("storage() called on undefined Tensor");
}

IntList UndefinedTensor::strides() const {
runtime_error("strides() called on undefined Tensor");
AT_ERROR("strides() called on undefined Tensor");
}
Scalar UndefinedTensor::localScalar() {
runtime_error("localScalar() called on undefined Tensor");
AT_ERROR("localScalar() called on undefined Tensor");
}

UndefinedTensor UndefinedTensor::_singleton;
Expand Down
Loading

0 comments on commit d42fcdb

Please sign in to comment.