Skip to content

Commit

Permalink
Extend context and name propagation in errors (#5396)
Browse files Browse the repository at this point in the history
Add error context related to particular operator during graph construction and pipeline build:
* error propagation in Pipeline::Build() is adjusted
* Python calls to backend: Pipeline::AddOperator now augment errors with context.
* OpSpec building errors propagate the proper operator name.

Replace most of the SchemaName() occurrences that were used to indicate the name operator
with the fully formatted operator name in the correct API.

Make sure that the naming information is added as soon as possible, so it can be accessed in partially
constructed schema with the correct values present.

Note: This PR mostly adds the context like:
```
Error in <device> operator `nvidia.dali.fn.operator_name`,
which was used in the pipeline definition with the following traceback:
<traceback>
encountered:
<Original error message>
```
to places where it was not previously used, but we are processing a single operator.

Error messages are adjusted to show the user-facing input/output/argument name (in uniform way)
rather than the internal one. 
Otherwise the checks and messages are preserved.
The types of error messages are not adjusted.

Signed-off-by: Krzysztof Lecki <[email protected]>
  • Loading branch information
klecki authored Apr 11, 2024
1 parent 57eb5f2 commit 3b1b5f8
Show file tree
Hide file tree
Showing 21 changed files with 295 additions and 197 deletions.
11 changes: 6 additions & 5 deletions dali/operators/generic/reshape.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -21,6 +21,7 @@
#include "dali/core/tensor_shape_print.h"
#include "dali/operators/generic/reshape.h"
#include "dali/pipeline/data/views.h"
#include "dali/pipeline/operator/name_utils.h"

namespace dali {

Expand Down Expand Up @@ -135,11 +136,11 @@ Reshape<Backend>::Reshape(const OpSpec &spec) : Base(spec) {
&& !has_src_dims_arg) {
bool can_have_dtype = spec.GetSchema().HasArgument("dtype");
if (can_have_dtype) {
DALI_ENFORCE(output_type_arg_ != DALI_NO_TYPE, make_string(OpName(),
" is no-op: arguments specify neither new shape, layout nor type."));
DALI_ENFORCE(output_type_arg_ != DALI_NO_TYPE, make_string("`", GetOpDisplayName(spec, true),
"` is no-op: arguments specify neither new shape, layout nor type."));
} else {
DALI_FAIL(make_string(OpName(),
" is no-op: arguments specify neither new shape nor layout."));
DALI_FAIL(make_string("`", GetOpDisplayName(spec, true),
"` is no-op: arguments specify neither new shape nor layout."));
}
}
use_layout_ = has_layout_arg;
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/reshape.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ class Reshape : public StatelessOperator<Backend> {
TensorLayout layout_;

private:
inline const std::string &OpName() const {
return this->spec_.SchemaName();
}

TensorListShape<> input_shape_;
TensorShape<> uniform_shape_;
std::vector<float> rel_uniform_shape_;
Expand Down
5 changes: 3 additions & 2 deletions dali/operators/reader/reader_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "dali/operators/reader/loader/loader.h"
#include "dali/operators/reader/parser/parser.h"
#include "dali/pipeline/operator/checkpointing/snapshot_serializer.h"
#include "dali/pipeline/operator/name_utils.h"
#include "dali/pipeline/operator/operator.h"

namespace dali {
Expand Down Expand Up @@ -104,7 +105,7 @@ class DataReader : public Operator<Backend> {

void SaveState(OpCheckpoint &cpt, AccessOrder order) override {
if constexpr (!supports_checkpointing) {
DALI_FAIL("The reader ", spec_.SchemaName(), " does not support checkpointing.");
DALI_FAIL("The reader `", GetOpDisplayName(spec_, true), "` does not support checkpointing.");
} else {
DALI_ENFORCE(checkpointing_,
"Cannot save the checkpoint, because "
Expand All @@ -116,7 +117,7 @@ class DataReader : public Operator<Backend> {

void RestoreState(const OpCheckpoint &cpt) override {
if constexpr (!supports_checkpointing) {
DALI_FAIL("The reader ", spec_.SchemaName(), " does not support checkpointing.");
DALI_FAIL("The reader `", GetOpDisplayName(spec_, true), "` does not support checkpointing.");
} else {
DALI_ENFORCE(checkpointing_,
"Cannot restore the checkpoint, because "
Expand Down
38 changes: 19 additions & 19 deletions dali/pipeline/graph/graph_descr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
// limitations under the License.

#include <algorithm>
#include <exception>
#include <string>

#include "dali/core/error_handling.h"
#include "dali/pipeline/graph/op_graph.h"

#include "dali/pipeline/operator/error_reporting.h"
#include "dali/pipeline/operator/name_utils.h"
#include "dali/pipeline/operator/op_schema.h"

#include "dali/pipeline/operator/builtin/make_contiguous.h"
Expand Down Expand Up @@ -55,19 +57,20 @@ void CheckOpConstraints(const OpSpec &spec) {
const int additional_outputs = schema.CalculateAdditionalOutputs(spec);

DALI_ENFORCE(schema.SupportsInPlace(spec) || !spec.GetArgument<bool>("inplace"),
"Op '" + spec.SchemaName() + "' does not support in-place execution.");
DALI_ENFORCE(spec.NumRegularInput() <= schema.MaxNumInput(),
"Operator '" + spec.SchemaName() +
"' supports a maximum of " + std::to_string(schema.MaxNumInput()) + " inputs, "
"but was passed " + std::to_string(spec.NumRegularInput()) + ".");
DALI_ENFORCE(spec.NumRegularInput() >= schema.MinNumInput(),
"Operator '" + spec.SchemaName() +
"' supports a minimum of " + std::to_string(schema.MinNumInput()) + " inputs, "
"but was passed " + std::to_string(spec.NumRegularInput()) + ".");
make_string("Operator `", GetOpDisplayName(spec, true),
"` does not support in-place execution."));
DALI_ENFORCE(
spec.NumRegularInput() <= schema.MaxNumInput(),
make_string("Operator `", GetOpDisplayName(spec, true), "` supports a maximum of ",
schema.MaxNumInput(), " inputs, but was passed ", spec.NumRegularInput(), "."));
DALI_ENFORCE(
spec.NumRegularInput() >= schema.MinNumInput(),
make_string("Operator `", GetOpDisplayName(spec, true), "` supports a minimum of ",
schema.MinNumInput(), " inputs, but was passed ", spec.NumRegularInput(), "."));
DALI_ENFORCE(spec.NumOutput() == schema.CalculateOutputs(spec) + additional_outputs,
"Operator '" + spec.SchemaName() + "' supports "
+ std::to_string(schema.CalculateOutputs(spec) + additional_outputs)
+ " outputs, but was passed " + std::to_string(spec.NumOutput()) + ".");
make_string("Operator `", GetOpDisplayName(spec, true), "` supports ",
schema.CalculateOutputs(spec) + additional_outputs,
" outputs, but was passed ", spec.NumOutput(), "."));
}

OpType ParseOpType(const std::string &device) {
Expand Down Expand Up @@ -207,17 +210,14 @@ void OpGraph::InstantiateOperators() {

for (auto op_type : order) {
for (auto op_id : op_partitions_[static_cast<int>(op_type)]) {
std::exception_ptr eptr;
try {
op_nodes_[op_id].InstantiateOperator();
} catch (...) {
eptr = std::current_exception();
PropagateError({std::current_exception(),
"Critical error when building pipeline:\n" +
GetErrorContextMessage(op_nodes_[op_id].spec),
"\nCurrent pipeline object is no longer valid."});
}

PropagateError({eptr,
"Critical error when building pipeline:\n" +
GetErrorContextMessage(op_nodes_[op_id].spec),
"\nCurrent pipeline object is no longer valid."});
}
}
}
Expand Down
6 changes: 1 addition & 5 deletions dali/pipeline/operator/builtin/external_source.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,10 +64,6 @@ class ExternalSource : public InputOperator<Backend> {

virtual ~ExternalSource() = default;

inline string name() const override {
return "ExternalSource (" + output_name_ + ")";
}

const TensorLayout& in_layout() const override {
return layout_;
}
Expand Down
5 changes: 3 additions & 2 deletions dali/pipeline/operator/error_reporting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -96,7 +97,7 @@ void PropagateError(ErrorInfo error) {
}
}

std::string GetErrorContextMessage(const OpSpec &spec) {
std::string GetErrorContextMessage(const OpSpec &spec, std::string_view message_name) {
auto device = spec.GetArgument<std::string>("device");
auto op_name = GetOpDisplayName(spec, true);
std::transform(device.begin(), device.end(), device.begin(), ::toupper);
Expand All @@ -109,7 +110,7 @@ std::string GetErrorContextMessage(const OpSpec &spec) {
formatted_origin_stack + "\n") :
" "; // we need space before "encountered"

return make_string("Error in ", device, " operator `", op_name, "`",
return make_string(message_name, " in ", device, " operator `", op_name, "`",
optional_stack_mention, "encountered:\n\n");
}

Expand Down
5 changes: 4 additions & 1 deletion dali/pipeline/operator/error_reporting.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <exception>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -183,8 +184,10 @@ class DaliStopIteration : public DaliError {
* * the origin stack trace of the operator within pipeline definition.
*
* It can be prepended to the original error message.
* @param message_name Will be used as the prefix of the error message, for example:
* "Error in <device> operator <op_name>" or "Warning in <device> operator <op_name>"
*/
std::string GetErrorContextMessage(const OpSpec &spec);
std::string GetErrorContextMessage(const OpSpec &spec, std::string_view message_name = "Error");

} // namespace dali

Expand Down
18 changes: 17 additions & 1 deletion dali/pipeline/operator/name_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
namespace dali {

std::string GetOpModule(const OpSpec &spec) {
return spec.GetArgument<std::string>("_module");;
return spec.GetArgument<std::string>("_module");
}

std::string GetOpDisplayName(const OpSpec &spec, bool include_module_path) {
Expand All @@ -39,4 +39,20 @@ std::string GetOpDisplayName(const OpSpec &spec, bool include_module_path) {
}
}

std::string FormatInput(const OpSpec &spec, int input_idx, bool capitalize) {
if (spec.GetSchema().HasInputDox()) {
return make_string(capitalize ? "I" : "i", "nput `", input_idx, "` ('__",
spec.GetSchema().GetInputName(input_idx), "')");
}
return make_string(capitalize ? "I" : "i", "nput `", input_idx, "`");
}

std::string FormatOutput(const OpSpec &spec, int output_idx, bool capitalize) {
return make_string(capitalize ? "O" : "o", "utput `", output_idx, "`");
}

std::string FormatArgument(const OpSpec &spec, const std::string &argument, bool capitalize) {
return make_string(capitalize ? "A" : "a", "rgument '", argument, "'");
}

} // namespace dali
29 changes: 29 additions & 0 deletions dali/pipeline/operator/name_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,35 @@ DLL_PUBLIC std::string GetOpModule(const OpSpec &spec);
*/
DLL_PUBLIC std::string GetOpDisplayName(const OpSpec &spec, bool include_module_path = false);

/**
* @brief Uniformly format the display of the operator input index, optionally including the name
* if provided in schema doc.
*
* @param input_idx Index of the input
* @param capitalize should be true if the output should start with capital letter (used at the
* start of the sentence)
*/
DLL_PUBLIC std::string FormatInput(const OpSpec &spec, int input_idx, bool capitalize = false);

/**
* @brief Uniformly format the display of the operator output index.
*
* @param input_idx Index of the output
* @param capitalize should be true if the output should start with capital letter (used at the
* start of the sentence)
*/
DLL_PUBLIC std::string FormatOutput(const OpSpec &spec, int output_idx, bool capitalize = false);

/**
* @brief Uniformly format the display of the operator argument name
*
* @param argument string representing the name of the argument (without additional quotes)
* @param capitalize should be true if the output should start with capital letter (used at the
* start of the sentence)
*/
DLL_PUBLIC std::string FormatArgument(const OpSpec &spec, const std::string &argument,
bool capitalize = false);

} // namespace dali

#endif // DALI_PIPELINE_OPERATOR_NAME_UTILS_H_
61 changes: 54 additions & 7 deletions dali/pipeline/operator/op_spec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.

#include "dali/pipeline/operator/op_spec.h"

#include "dali/pipeline/data/types.h"
#include "dali/pipeline/operator/name_utils.h"

namespace dali {

Expand All @@ -26,7 +26,7 @@ OpSpec& OpSpec::AddInput(const string &name, const string &device, bool regular_
// We rely on the fact that regular inputs are first in inputs_ vector
DALI_ENFORCE(NumArgumentInput() == 0,
"All regular inputs (particularly, `" + name + "`) need to be added to the op `" +
this->SchemaName() + "` before argument inputs.");
GetOpDisplayName(*this, true) + "` before argument inputs.");
}

inputs_.push_back({name, device});
Expand All @@ -50,17 +50,64 @@ OpSpec& OpSpec::AddOutput(const string &name, const string &device) {

OpSpec& OpSpec::AddArgumentInput(const string &arg_name, const string &inp_name) {
DALI_ENFORCE(!this->HasArgument(arg_name), make_string(
"Argument ", arg_name, " is already specified."));
"Argument '", arg_name, "' is already specified."));
const OpSchema& schema = GetSchema();
DALI_ENFORCE(schema.HasArgument(arg_name), make_string(
"Argument '", arg_name, "' is not part of the op schema '", schema.name(), "'"));
DALI_ENFORCE(schema.IsTensorArgument(arg_name), make_string(
"Argument `", arg_name, "` in operator `", schema.name(), "` is not a a tensor argument."));
DALI_ENFORCE(schema.HasArgument(arg_name),
make_string("Argument '", arg_name, "' is not supported by operator `",
GetOpDisplayName(*this, true), "`."));
DALI_ENFORCE(schema.IsTensorArgument(arg_name),
make_string("Argument '", arg_name, "' in operator `", GetOpDisplayName(*this, true),
"` is not an argument input."));
int idx = inputs_.size();
argument_inputs_.push_back({ arg_name, idx });
argument_input_idxs_[arg_name] = idx;
AddInput(inp_name, "cpu", false);
return *this;
}

OpSpec& OpSpec::SetInitializedArg(const string& arg_name, std::shared_ptr<Argument> arg) {
if (schema_ && schema_->IsDeprecatedArg(arg_name)) {
const auto& deprecation_meta = schema_->DeprecatedArgMeta(arg_name);
// Argument was removed, and we can discard it
if (deprecation_meta.removed) {
return *this;
}
if (!deprecation_meta.renamed_to.empty()) {
const auto& new_arg_name = deprecation_meta.renamed_to;
DALI_ENFORCE(argument_idxs_.find(new_arg_name) == argument_idxs_.end(),
make_string("Operator `", GetOpDisplayName(*this, true), "` got an unexpected '",
arg_name, "' deprecated argument when '", new_arg_name,
"' was already provided."));

set_through_deprecated_arguments_[new_arg_name] = arg_name;
// Adjust the arg so it carries the proper name for serialization
if (arg->has_name()) {
arg->set_name(new_arg_name);
}
auto [it, inserted] = argument_idxs_.insert({new_arg_name, arguments_.size()});
if (inserted)
arguments_.push_back(std::move(arg));
else
arguments_[it->second] = std::move(arg);
return *this;
}
}
EnforceNoAliasWithDeprecated(arg_name);
auto [it, inserted] = argument_idxs_.insert({arg_name, arguments_.size()});
if (inserted)
arguments_.push_back(std::move(arg));
else
arguments_[it->second] = std::move(arg);
return *this;
}

void OpSpec::EnforceNoAliasWithDeprecated(const string& arg_name) {
auto set_through = set_through_deprecated_arguments_.find(arg_name);
DALI_ENFORCE(set_through == set_through_deprecated_arguments_.end(),
make_string("Operator `", GetOpDisplayName(*this, true), "` got an unexpected '",
set_through->second, "' deprecated argument when '", arg_name,
"' was already provided."));
}


} // namespace dali
Loading

0 comments on commit 3b1b5f8

Please sign in to comment.