diff --git a/dali/operators/generic/roi_random_crop.cc b/dali/operators/generic/roi_random_crop.cc index 4e4e3071ab3..cc8b6e7702a 100644 --- a/dali/operators/generic/roi_random_crop.cc +++ b/dali/operators/generic/roi_random_crop.cc @@ -40,6 +40,7 @@ either specified with `roi_start`/`roi_end` or `roi_start`/`roi_shape`. The operator produces an output representing the cropping window start coordinates. )code") + .AddRandomSeedArg() .AddArg("crop_shape", R"code(Cropping window dimensions.)code", DALI_INT_VEC, true) .AddArg("roi_start", diff --git a/dali/operators/image/crop/bbox_crop.cc b/dali/operators/image/crop/bbox_crop.cc index 04fffe0b118..0dc41d570f2 100644 --- a/dali/operators/image/crop/bbox_crop.cc +++ b/dali/operators/image/crop/bbox_crop.cc @@ -194,6 +194,7 @@ associated with each of the bounding boxes.)code") return spec.NumRegularInput() - 1 + // +1 if labels are provided spec.GetArgument<bool>("output_bbox_indices"); // +1 if output_bbox_indices=True }) + .AddRandomSeedArg() .AddOptionalArg( "thresholds", R"code(Minimum IoU or a different metric, if specified by `threshold_type`, of the diff --git a/dali/operators/image/crop/random_crop_attr.cc b/dali/operators/image/crop/random_crop_attr.cc index 9a474c5baa4..e15fc8c5a4e 100644 --- a/dali/operators/image/crop/random_crop_attr.cc +++ b/dali/operators/image/crop/random_crop_attr.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -29,6 +29,7 @@ The cropped image's area will be equal to ``A`` * original image's area.)code", std::vector<float>{0.08, 1.0}) .AddOptionalArg("num_attempts", R"code(Maximum number of attempts used to choose random area and aspect ratio.)code", - 10); + 10) + .AddRandomSeedArg(); } // namespace dali diff --git a/dali/operators/image/remap/jitter.cu b/dali/operators/image/remap/jitter.cu index 662cae7d423..fbdfedb99b6 100644 --- a/dali/operators/image/remap/jitter.cu +++ b/dali/operators/image/remap/jitter.cu @@ -29,6 +29,7 @@ and bounded by half of the `nDegree` parameter.)code") R"code(Each pixel is moved by a random amount in the ``[-nDegree/2, nDegree/2]`` range)code", 2) .InputLayout(0, "HWC") + .AddRandomSeedArg() .AddParent("DisplacementFilter"); DALI_REGISTER_OPERATOR(Jitter, Jitter<GPUBackend>, GPU); diff --git a/dali/operators/random/batch_permutation.cc b/dali/operators/random/batch_permutation.cc index d2ba0d754d5..84d0b9e2330 100644 --- a/dali/operators/random/batch_permutation.cc +++ b/dali/operators/random/batch_permutation.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ indexing samples in the batch.)") R"(If true, the output can contain repetitions and omissions.)", false) .AddOptionalArg("no_fixed_points", R"(If true, the the output permutation cannot contain fixed points, that is ``out[i] != i``. 
This argument is ignored when batch size is 1.)", false) + .AddRandomSeedArg() .AddParent("ImplicitScopeAttr"); void BatchPermutation::RunImpl(Workspace &ws) { diff --git a/dali/operators/random/choice_cpu.cc b/dali/operators/random/choice_cpu.cc index c873fdee998..6ac992f6927 100644 --- a/dali/operators/random/choice_cpu.cc +++ b/dali/operators/random/choice_cpu.cc @@ -48,7 +48,8 @@ that is: :meth:`nvidia.dali.types.DALIDataType`, :meth:`nvidia.dali.types.DALIIm "Distribution of the probabilities. " "If not specified, uniform distribution is assumed.", nullptr, true) - .AddOptionalArg<std::vector<int>>("shape", "Shape of the output data.", nullptr, true); + .AddOptionalArg<std::vector<int>>("shape", "Shape of the output data.", nullptr, true) + .AddRandomSeedArg(); DALI_REGISTER_OPERATOR(random__Choice, Choice<CPUBackend>, CPU); diff --git a/dali/operators/random/noise/gaussian_noise.cc b/dali/operators/random/noise/gaussian_noise.cc index eabbba37cb4..3fa795ce7a9 100644 --- a/dali/operators/random/noise/gaussian_noise.cc +++ b/dali/operators/random/noise/gaussian_noise.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ The shape and data type of the output will match the input. )code") .NumInput(1) .NumOutput(1) + .AddRandomSeedArg() .AddOptionalArg<float>("mean", R"code(Mean of the distribution.)code", 0.f, true) diff --git a/dali/operators/random/noise/salt_and_pepper_noise.cc b/dali/operators/random/noise/salt_and_pepper_noise.cc index f22a41471ce..d2e24303713 100644 --- a/dali/operators/random/noise/salt_and_pepper_noise.cc +++ b/dali/operators/random/noise/salt_and_pepper_noise.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ The shape and data type of the output will match the input. )code") .NumInput(1) .NumOutput(1) + .AddRandomSeedArg() .AddOptionalArg<float>("prob", R"code(Probability of an output value to take a salt or pepper value.)code", 0.05f, true) diff --git a/dali/operators/random/noise/shot_noise.cc b/dali/operators/random/noise/shot_noise.cc index 09932a03eda..ee59fe25e48 100644 --- a/dali/operators/random/noise/shot_noise.cc +++ b/dali/operators/random/noise/shot_noise.cc @@ -42,6 +42,7 @@ The shape and data type of the output will match the input. )code") .NumInput(1) .NumOutput(1) + .AddRandomSeedArg() .AddOptionalArg<float>("factor", R"code(Factor parameter.)code", 20.0f, true); diff --git a/dali/operators/random/rng_base.cc b/dali/operators/random/rng_base.cc index b8a4dd03361..4273600ac5b 100644 --- a/dali/operators/random/rng_base.cc +++ b/dali/operators/random/rng_base.cc @@ -29,6 +29,8 @@ It should be added as parent to all RNG operators.)code") .. note:: The generated numbers are converted to the output data type, rounding and clamping if necessary. -)code", nullptr); +)code", nullptr) + .AddRandomSeedArg(); + } // namespace dali diff --git a/dali/operators/reader/loader/loader.cc b/dali/operators/reader/loader/loader.cc index 7bbacf1504a..14402c52a5c 100644 --- a/dali/operators/reader/loader/loader.cc +++ b/dali/operators/reader/loader/loader.cc @@ -18,6 +18,7 @@ namespace dali { DALI_SCHEMA(LoaderBase) + .AddRandomSeedArg() .AddOptionalArg("random_shuffle", R"code(Determines whether to randomly shuffle data. 
diff --git a/dali/operators/reader/tfrecord_reader_op.cc b/dali/operators/reader/tfrecord_reader_op.cc index a77ea7a732d..1d6e8634b12 100644 --- a/dali/operators/reader/tfrecord_reader_op.cc +++ b/dali/operators/reader/tfrecord_reader_op.cc @@ -38,6 +38,7 @@ DALI_REGISTER_OPERATOR(readers___TFRecord, TFRecordReader, CPU); // Common part of schema for internal readers._tfrecord and public readers.tfrecord schema. DALI_SCHEMA(readers___TFRecordBase) .DocStr(R"code(Read sample data from a TensorFlow TFRecord file.)code") + .AddRandomSeedArg() .AddArg("path", R"code(List of paths to TFRecord files.)code", DALI_STRING_VEC) diff --git a/dali/operators/segmentation/random_mask_pixel.cc b/dali/operators/segmentation/random_mask_pixel.cc index 159e91a64a6..dc6d7aecde3 100644 --- a/dali/operators/segmentation/random_mask_pixel.cc +++ b/dali/operators/segmentation/random_mask_pixel.cc @@ -51,7 +51,8 @@ This argument is mutually exclusive with `value` argument. If 0, the pixel position is sampled uniformly from all available pixels.)code", 0, true) .NumInput(1) - .NumOutput(1); + .NumOutput(1) + .AddRandomSeedArg(); class RandomMaskPixelCPU : public rng::OperatorWithRng<CPUBackend> { public: diff --git a/dali/operators/segmentation/random_object_bbox.cc b/dali/operators/segmentation/random_object_bbox.cc index 07f675185aa..a168cf935c2 100644 --- a/dali/operators/segmentation/random_object_bbox.cc +++ b/dali/operators/segmentation/random_object_bbox.cc @@ -51,6 +51,7 @@ With probability 1-foreground_prob, the entire area of the input is returned.)") int output_class = spec.GetArgument<bool>("output_class"); return 1 + separate_corners + output_class; }) + .AddRandomSeedArg() .AddOptionalArg("ignore_class", R"(If True, all objects are picked with equal probability, regardless of the class they belong to. Otherwise, a class is picked first and then an object is randomly selected from this class. 
diff --git a/dali/operators/ssd/random_crop.cc b/dali/operators/ssd/random_crop.cc index 812f9178f64..35cfda7543a 100644 --- a/dali/operators/ssd/random_crop.cc +++ b/dali/operators/ssd/random_crop.cc @@ -34,6 +34,7 @@ cropped and valid bounding boxes and valid labels are returned.)code") .NumInput(3) // [img, bbox, label] .NumOutput(3) // [img, bbox, label] .AddOptionalArg("num_attempts", R"code(Number of attempts.)code", 1) + .AddRandomSeedArg() .Deprecate("RandomBBoxCrop"); // deprecated in DALI 0.30 /* diff --git a/dali/pipeline/operator/argument.h b/dali/pipeline/operator/argument.h index e419998c32a..d703ae689a4 100644 --- a/dali/pipeline/operator/argument.h +++ b/dali/pipeline/operator/argument.h @@ -111,8 +111,8 @@ class Argument { return has_name_; } - inline const string get_name() const { - return has_name() ? name_ : "<no name>"; + inline std::string_view get_name() const & { + return has_name() ? std::string_view(name_) : "<no name>"; } inline void set_name(string name) { @@ -126,7 +126,7 @@ class Argument { } virtual std::string ToString() const { - return get_name(); + return std::string(get_name()); } virtual DALIDataType GetTypeId() const = 0; @@ -230,8 +230,8 @@ template <typename T> T Argument::Get() { ArgumentInst<T>* self = dynamic_cast<ArgumentInst<T>*>(this); if (self == nullptr) { - DALI_FAIL("Invalid type of argument \"" + this->get_name() + "\". Expected " + - typeid(T).name()); + DALI_FAIL(make_string("Invalid type of argument \"", get_name(), "\". 
Expected ", + typeid(T).name())); } return self->Get(); } diff --git a/dali/pipeline/operator/builtin/input_operator.h b/dali/pipeline/operator/builtin/input_operator.h index 1d15c323a9d..631a20a0049 100644 --- a/dali/pipeline/operator/builtin/input_operator.h +++ b/dali/pipeline/operator/builtin/input_operator.h @@ -570,7 +570,7 @@ class InputOperator : public Operator<Backend>, virtual public BatchSizeProvider * Checks, if the Operator defined by provided Schema is an InputOperator */ inline bool IsInputOperator(const OpSchema &schema) { - const auto &parents = schema.GetParents(); + const auto &parents = schema.GetParentNames(); return std::any_of(parents.begin(), parents.end(), [](const std::string &p) { return p == "InputOperatorBase"; }); } diff --git a/dali/pipeline/operator/op_schema.cc b/dali/pipeline/operator/op_schema.cc index 528e7afb47a..e10a9e04fd2 100644 --- a/dali/pipeline/operator/op_schema.cc +++ b/dali/pipeline/operator/op_schema.cc @@ -14,6 +14,8 @@ #include <string> +#include <string_view> +#include <sstream> #include "dali/core/error_handling.h" #include "dali/core/python_util.h" @@ -22,46 +24,50 @@ namespace dali { -std::map<string, OpSchema> &SchemaRegistry::registry() { - static std::map<string, OpSchema> schema_map; +std::map<string, OpSchema, std::less<>> &SchemaRegistry::registry() { + static std::map<string, OpSchema, std::less<>> schema_map; return schema_map; } -OpSchema &SchemaRegistry::RegisterSchema(const std::string &name) { +OpSchema &SchemaRegistry::RegisterSchema(std::string_view name) { auto &schema_map = registry(); - DALI_ENFORCE(schema_map.count(name) == 0, - "OpSchema already " - "registered for operator '" + - name + - "'. 
DALI_SCHEMA(op) " - "should only be called once per op."); + + if (schema_map.count(name)) + throw std::logic_error(make_string( + "OpSchema already registered for operator '", name, "'.\n" + "DALI_SCHEMA(op) should only be called once per op.")); // Insert the op schema and return a reference to it - schema_map.emplace(std::make_pair(name, OpSchema(name))); - return schema_map.at(name); + auto [it, inserted] = schema_map.emplace(name, name); + return it->second; } -const OpSchema &SchemaRegistry::GetSchema(const std::string &name) { +const OpSchema &SchemaRegistry::GetSchema(std::string_view name) { auto &schema_map = registry(); auto it = schema_map.find(name); - DALI_ENFORCE(it != schema_map.end(), "Schema for operator '" + name + "' not registered"); + if (it == schema_map.end()) + throw invalid_key("Schema for operator '" + std::string(name) + "' not registered"); + return it->second; } -const OpSchema *SchemaRegistry::TryGetSchema(const std::string &name) { +const OpSchema *SchemaRegistry::TryGetSchema(std::string_view name) { auto &schema_map = registry(); auto it = schema_map.find(name); return it != schema_map.end() ? 
&it->second : nullptr; } const OpSchema &OpSchema::Default() { - static OpSchema default_schema(""); + static OpSchema default_schema(DefaultSchemaTag{}); return default_schema; } -OpSchema::OpSchema(const std::string &name) : name_(name) { +OpSchema::OpSchema(std::string_view name) : name_(name) { // Process the module path and operator name InitNames(); +} + +OpSchema::OpSchema(DefaultSchemaTag) : name_(""), default_(true) { // Fill internal arguments AddInternalArg("num_threads", "Number of CPU threads in a thread pool", -1); AddInternalArg("max_batch_size", "Max batch size", -1); @@ -70,10 +76,9 @@ OpSchema::OpSchema(const std::string &name) : name_(name) { AddInternalArg("default_cuda_stream_priority", "Default cuda stream priority", 0); // deprecated AddInternalArg("checkpointing", "Setting to `true` enables checkpointing", false); - AddOptionalArg("seed", R"code(Random seed. - + AddOptionalArg<int>("seed", R"code(Random seed. If not provided, it will be populated based on the global seed of the pipeline.)code", - -1); + nullptr); AddOptionalArg("bytes_per_sample_hint", R"code(Output size hint, in bytes per sample. @@ -129,6 +134,11 @@ a pipeline scope. False if it was defined without pipeline being set as current. 
"Operator name as presented in the API it was instantiated in (without the module " "path), for example: cast_like or CastLike.", OperatorName()); + + DeprecateArg("seed", true, + "The argument \"seed\" should not be used with operators that don't use " + "random numbers."); + arguments_["seed"].hidden = true; } @@ -156,31 +166,38 @@ void OpSchema::InitNames() { operator_name_ = name_.substr(start_pos); } -OpSchema &OpSchema::DocStr(const string &dox) { - dox_ = dox; +OpSchema &OpSchema::DocStr(std::string dox) { + dox_ = std::move(dox); return *this; } -OpSchema &OpSchema::InputDox(int index, const string &name, const string &type_doc, - const string &doc) { +OpSchema &OpSchema::InputDox(int index, std::string_view name, std::string type_doc, + std::string doc) { CheckInputIndex(index); - DALI_ENFORCE(!name.empty(), "Name of the argument should not be empty"); - DALI_ENFORCE(call_dox_.empty(), + if (name.empty()) + throw std::invalid_argument("Name of the argument should not be empty"); + if (!call_dox_.empty()) + throw std::logic_error( "Providing docstrings for inputs is not supported when the CallDocStr was used."); input_dox_set_ = true; - input_dox_[index] = {name, type_doc, doc}; + input_info_[index].name = name; + input_info_[index].doc = {std::move(type_doc), std::move(doc)}; return *this; } -OpSchema &OpSchema::CallDocStr(const string &doc, bool append_kwargs_section) { - DALI_ENFORCE(!doc.empty(), "The custom docstring for __call__ should not be empty."); +OpSchema &OpSchema::CallDocStr(std::string doc, bool append_kwargs_section) { + if (doc.empty()) + throw std::logic_error("The custom docstring for __call__ should not be empty."); - DALI_ENFORCE(!input_dox_set_, + if (input_dox_set_) { + throw std::logic_error( "Providing docstring for `__call__` is not supported when docstrings for separate " "inputs were set using InputDox."); - call_dox_ = doc; + } + + call_dox_ = std::move(doc); append_kwargs_section_ = append_kwargs_section; return *this; } @@ 
-199,49 +216,48 @@ OpSchema &OpSchema::AdditionalOutputsFn(SpecFunc f) { OpSchema &OpSchema::NumInput(int n) { - DALI_ENFORCE(n >= 0); + if (n < 0) + throw std::invalid_argument("The number of inputs must not be negative"); max_num_input_ = n; min_num_input_ = n; - input_dox_.resize(n); - input_layouts_.resize(n); - input_devices_.resize(n); + input_info_.resize(n); return *this; } OpSchema &OpSchema::NumInput(int min, int max) { - DALI_ENFORCE(min <= max); - DALI_ENFORCE(min >= 0); - DALI_ENFORCE(max >= 0); + if (min < 0 || max < 0) + throw std::invalid_argument("The number of inputs must not be negative"); + if (min > max) + throw std::invalid_argument("The min. number of inputs must not be greater than max."); min_num_input_ = min; max_num_input_ = max; - input_layouts_.resize(max); - input_dox_.resize(max); - input_devices_.resize(max); + input_info_.resize(max); return *this; } OpSchema &OpSchema::InputDevice(int first, int one_past, dali::InputDevice device) { for (int i = first; i < one_past; i++) - input_devices_[i] = device; + input_info_[i].device = device; return *this; } OpSchema &OpSchema::InputDevice(int index, dali::InputDevice device) { - input_devices_[index] = device; + input_info_[index].device = device; return *this; } DLL_PUBLIC dali::InputDevice OpSchema::GetInputDevice(int index) const { - return input_devices_[index]; + return input_info_[index].device; } OpSchema &OpSchema::NumOutput(int n) { - DALI_ENFORCE(n >= 0); + if (n < 0) + throw std::invalid_argument("The number of outputs must not be negative"); num_output_ = n; return *this; } @@ -253,12 +269,6 @@ OpSchema &OpSchema::DisableAutoInputDox() { } -OpSchema &OpSchema::DisallowInstanceGrouping() { - allow_instance_grouping_ = false; - return *this; -} - - OpSchema &OpSchema::SequenceOperator() { is_sequence_operator_ = true; return *this; @@ -295,10 +305,10 @@ OpSchema &OpSchema::MakeDocPartiallyHidden() { } -OpSchema &OpSchema::Deprecate(const std::string &in_favor_of, const 
std::string &explanation) { +OpSchema &OpSchema::Deprecate(std::string in_favor_of, std::string explanation) { is_deprecated_ = true; - deprecated_in_favor_of_ = in_favor_of; - deprecation_message_ = explanation; + deprecated_in_favor_of_ = std::move(in_favor_of); + deprecation_message_ = std::move(explanation); return *this; } @@ -309,19 +319,50 @@ OpSchema &OpSchema::Unserializable() { } -OpSchema &OpSchema::AddArg(const std::string &s, const std::string &doc, const DALIDataType dtype, - bool enable_tensor_input, bool support_per_frame_input) { - CheckArgument(s); - arguments_[s] = {doc, dtype}; - if (enable_tensor_input) { - tensor_arguments_[s] = {support_per_frame_input}; +ArgumentDef &OpSchema::AddArgumentImpl(std::string_view name) { + if (HasInternalArgument(name)) + throw std::invalid_argument(make_string( + "The argument name `", name, "` is reserved for internal use")); + + auto [it, inserted] = arguments_.emplace(name, ArgumentDef()); + if (!inserted) { + throw std::invalid_argument(make_string( + "The schema for operator `", name_, "` already contains an argument `", name, "`.")); } + auto &arg = it->second; + arg.defined_in = this; + arg.name = std::string(name); + return arg; +} + +ArgumentDef &OpSchema::AddArgumentImpl(std::string_view name, + DALIDataType type, + std::unique_ptr<Value> default_value, + std::string doc) { + auto &arg = AddArgumentImpl(name); + arg.dtype = type; + arg.default_value = std::move(default_value); + arg.doc = std::move(doc); + + if (ShouldHideArgument(name)) + arg.hidden = true; + + return arg; +} + +OpSchema &OpSchema::AddArg(std::string_view s, std::string doc, const DALIDataType dtype, + bool enable_tensor_input, bool support_per_frame_input) { + auto &arg = AddArgumentImpl(s, dtype, nullptr, std::move(doc)); + arg.required = true; + arg.tensor = enable_tensor_input; + if (arg.tensor) + arg.per_frame = support_per_frame_input; return *this; } -OpSchema &OpSchema::AddTypeArg(const std::string &s, const std::string &doc) 
{ - return AddArg(s, doc, DALI_DATA_TYPE); +OpSchema &OpSchema::AddTypeArg(std::string_view s, std::string doc) { + return AddArg(s, std::move(doc), DALI_DATA_TYPE); } @@ -332,12 +373,17 @@ OpSchema &OpSchema::InputLayout(int index, TensorLayout layout) { OpSchema &OpSchema::InputLayout(int index, std::initializer_list<TensorLayout> layouts) { CheckInputIndex(index); - DALI_ENFORCE(input_layouts_[index].empty(), - "Layouts for input " + std::to_string(index) + " already specified"); + if (!input_info_[index].layouts.empty()) + throw std::logic_error(make_string("Layouts for input ", index, " already specified")); + + std::set<TensorLayout> unique_layouts; for (auto &l : layouts) { - DALI_ENFORCE(!l.empty(), "Cannot specify an empty layout for an input"); + auto [it, inserted] = unique_layouts.insert(l); + if (!inserted) + throw std::logic_error(make_string( + "The layout \"", l, "\" for input ", index, " specified more than once.")); } - input_layouts_[index] = layouts; + input_info_[index].layouts = layouts; return *this; } @@ -357,16 +403,16 @@ OpSchema &OpSchema::InputLayout(std::initializer_list<TensorLayout> layouts) { const TensorLayout &OpSchema::GetInputLayout(int index, int sample_ndim, const TensorLayout &layout) const { CheckInputIndex(index); - DALI_ENFORCE(layout.empty() || layout.ndim() == sample_ndim, - make_string("The layout '", layout, "' is not valid for ", sample_ndim, - "-dimensional tensor")); + if (!layout.empty() && layout.ndim() != sample_ndim) + throw std::invalid_argument(make_string( + "The layout '", layout, "' is not valid for ", sample_ndim, "-dimensional tensor")); - if (input_layouts_[index].empty()) { + if (input_info_[index].layouts.empty()) { return layout; } if (layout.empty()) { - for (auto &l : input_layouts_[index]) + for (auto &l : input_info_[index].layouts) if (l.ndim() == sample_ndim) return l; std::stringstream ss; @@ -374,11 +420,11 @@ const TensorLayout &OpSchema::GetInputLayout(int index, int sample_ndim, << " does 
not match any of the allowed" " layouts for input " << index << ". Valid layouts are:\n"; - for (auto &l : input_layouts_[index]) + for (auto &l : input_info_[index].layouts) ss << l.c_str() << "\n"; - DALI_FAIL(ss.str()); + throw std::invalid_argument(ss.str()); } else { - for (auto &l : input_layouts_[index]) + for (auto &l : input_info_[index].layouts) if (l == layout) return l; std::stringstream ss; @@ -386,68 +432,75 @@ const TensorLayout &OpSchema::GetInputLayout(int index, int sample_ndim, << "\" does not match any of the allowed" " layouts for input " << index << ". Valid layouts are:\n"; - for (auto &l : input_layouts_[index]) + for (auto &l : input_info_[index].layouts) ss << l.c_str() << "\n"; - DALI_FAIL(ss.str()); + throw std::invalid_argument(ss.str()); } } const std::vector<TensorLayout> &OpSchema::GetSupportedLayouts(int input_idx) const { CheckInputIndex(input_idx); - return input_layouts_[input_idx]; + return input_info_[input_idx].layouts; } -OpSchema &OpSchema::AddOptionalArg(const std::string &s, const std::string &doc, DALIDataType dtype, +OpSchema &OpSchema::AddOptionalArg(std::string_view s, std::string doc, DALIDataType dtype, std::nullptr_t, bool enable_tensor_input, bool support_per_frame_input) { - CheckArgument(s); - optional_arguments_[s] = {doc, dtype, nullptr, ShouldHideArgument(s)}; - if (enable_tensor_input) { - tensor_arguments_[s] = {support_per_frame_input}; - } + auto &arg = AddArgumentImpl(s, dtype, nullptr, std::move(doc)); + arg.tensor = enable_tensor_input; + if (arg.tensor) + arg.per_frame = support_per_frame_input; return *this; } -OpSchema &OpSchema::AddOptionalTypeArg(const std::string &s, const std::string &doc, +OpSchema &OpSchema::AddOptionalTypeArg(std::string_view name, std::string doc, DALIDataType default_value) { - CheckArgument(s); - auto to_store = Value::construct(default_value); - optional_arguments_[s] = {doc, DALI_DATA_TYPE, to_store.get(), ShouldHideArgument(s)}; - 
optional_arguments_unq_.push_back(std::move(to_store)); + AddArgumentImpl(name, DALI_DATA_TYPE, Value::construct(default_value), std::move(doc)); return *this; } -OpSchema &OpSchema::AddOptionalTypeArg(const std::string &s, const std::string &doc) { - return AddOptionalArg<DALIDataType>(s, doc, nullptr); +OpSchema &OpSchema::AddOptionalTypeArg(std::string_view s, std::string doc) { + return AddOptionalArg<DALIDataType>(s, std::move(doc), nullptr); } -OpSchema &OpSchema::AddOptionalArg(const std::string &s, const std::string &doc, - const char *default_value) { - return AddOptionalArg(s, doc, std::string(default_value), false); +OpSchema &OpSchema::AddRandomSeedArg() { + AddOptionalArg<int>("seed", + "Random seed; if not set, one will be assigned automatically.", + -1); + return *this; } +bool OpSchema::HasRandomSeedArg() const { + return !IsDeprecatedArg("seed"); +} -OpSchema &OpSchema::DeprecateArgInFavorOf(const std::string &arg_name, std::string renamed_to, +OpSchema &OpSchema::DeprecateArgInFavorOf(std::string_view arg_name, std::string renamed_to, std::string msg) { if (msg.empty()) msg = DefaultDeprecatedArgMsg(arg_name, renamed_to, false); - deprecated_arguments_[arg_name] = {std::move(renamed_to), std::move(msg), false}; + + auto &alias = AddArgumentImpl(arg_name); + alias.defined_in = this; + alias.deprecated = std::make_unique<ArgumentDeprecation>(renamed_to, std::move(msg), false); + return *this; } - -OpSchema &OpSchema::DeprecateArg(const std::string &arg_name, bool removed, std::string msg) { - DALI_ENFORCE( - HasArgument(arg_name), - make_string("Argument \"", arg_name, - "\" has been marked for deprecation but it is not present in the schema.")); +OpSchema &OpSchema::DeprecateArg(std::string_view arg_name, bool removed, std::string msg) { if (msg.empty()) msg = DefaultDeprecatedArgMsg(arg_name, {}, removed); - deprecated_arguments_[arg_name] = {{}, std::move(msg), removed}; + + auto &arg = arguments_[std::string(arg_name)]; + + if (arg.deprecated) + 
throw std::logic_error(make_string("The argument \"", arg_name, "\" is already deprecated")); + + arg.deprecated = std::make_unique<ArgumentDeprecation>("", std::move(msg), removed); + return *this; } @@ -459,8 +512,8 @@ OpSchema &OpSchema::InPlaceFn(SpecFunc f) { } -OpSchema &OpSchema::AddParent(const std::string &parentName) { - parents_.push_back(parentName); +OpSchema &OpSchema::AddParent(std::string parent_name) { + parent_names_.push_back(std::move(parent_name)); return *this; } @@ -476,26 +529,78 @@ OpSchema &OpSchema::PassThrough(const std::map<int, int> &inout) { for (const auto &elems : inout) { outputs.insert(elems.second); } - DALI_ENFORCE(inout.size() == outputs.size(), - "Pass through can be defined only as 1-1 mapping between inputs and outputs, " - "without duplicates."); - DALI_ENFORCE(!HasSamplewisePassThrough(), "Two different modes of pass through can't be mixed."); + if (inout.size() != outputs.size()) + throw std::logic_error( + "Pass through can be defined only as 1-1 mapping between inputs and outputs, " + "without duplicates."); + + if (HasSamplewisePassThrough()) + throw std::logic_error("Two different modes of pass through can't be mixed."); passthrough_map_ = inout; return *this; } OpSchema &OpSchema::SamplewisePassThrough() { - DALI_ENFORCE(!HasStrictPassThrough(), "Two different modes of pass through can't be mixed."); + if (HasStrictPassThrough()) + throw std::logic_error("Two different modes of pass through can't be mixed."); samplewise_any_passthrough_ = true; return *this; } -const vector<std::string> &OpSchema::GetParents() const { - return parents_; +const vector<std::string> &OpSchema::GetParentNames() const { + return parent_names_; } +const vector<const OpSchema *> &OpSchema::GetParents() const { + return parents_.Get([&]() { + std::vector<const OpSchema *> parents; + if (default_) + return parents; // the default schema has no parents + + parents.reserve(parent_names_.size() + 1); // add one more for the default + for (auto 
&name : parent_names_) { + parents.push_back(&SchemaRegistry::GetSchema(name)); + } + parents.push_back(&Default()); + return parents; + }); +} + +std::map<std::string, const ArgumentDef *, std::less<>> &OpSchema::GetFlattenedArguments() const { + return flattened_arguments_.Get([&]() { + if (circular_inheritance_detector_) + throw std::logic_error(make_string( + "Circular schema inheritance detected in \"", name(), "\"")); + circular_inheritance_detector_++; + + std::map<std::string, const ArgumentDef *, std::less<>> args; + for (auto &[name, arg] : arguments_) + args.emplace(name, &arg); + + // First insert all non-deprecated arguments that don't come from the default schema. + // Once we've gone over those, add the deprecated ones and finally the default. + std::vector<std::pair<std::string_view, const ArgumentDef *>> deprecated, from_default; + for (auto *parent : GetParents()) { + for (auto &[name, arg] : parent->GetFlattenedArguments()) { + if (arg->defined_in == &Default()) + from_default.emplace_back(name, arg); + else if (arg->deprecated) + deprecated.emplace_back(name, arg); + else + args.emplace(name, arg); // this will skip arguments defined in this schema + } + } + for (auto &[name, arg] : deprecated) + args.emplace(name, arg); + for (auto &[name, arg] : from_default) + args.emplace(name, arg); + + circular_inheritance_detector_--; + return args; + }); +} string OpSchema::Dox() const { return dox_; @@ -518,7 +623,8 @@ DLL_PUBLIC bool OpSchema::HasCallDox() const { DLL_PUBLIC std::string OpSchema::GetCallDox() const { - DALI_ENFORCE(HasCallDox(), "__call__ docstring was not set"); + if (!HasCallDox()) + throw std::logic_error("__call__ docstring was not set"); return call_dox_; } @@ -529,16 +635,18 @@ DLL_PUBLIC bool OpSchema::HasInputDox() const { DLL_PUBLIC std::string OpSchema::GetCallSignatureInputs() const { - DALI_ENFORCE(HasInputDox(), "Input documentation was not specified for this operator."); + if (!HasInputDox()) + throw 
std::logic_error("Input documentation was not specified for this operator."); + std::stringstream result; for (int i = 0; i < MinNumInput(); i++) { - result << input_dox_[i].name; + result << input_info_[i].name; if (i < MaxNumInput() - 1) { result << ", "; } } for (int i = MinNumInput(); i < MaxNumInput(); i++) { - result << input_dox_[i].name << " = None"; + result << input_info_[i].name << " = None"; if (i < MaxNumInput() - 1) { result << ", "; } @@ -549,25 +657,28 @@ DLL_PUBLIC std::string OpSchema::GetCallSignatureInputs() const { DLL_PUBLIC std::string OpSchema::GetInputName(int input_idx) const { CheckInputIndex(input_idx); - DALI_ENFORCE(HasInputDox(), "Input documentation was not specified for this operator."); - DALI_ENFORCE(!input_dox_[input_idx].name.empty(), - make_string("Docstring for input ", input_idx, - "was not set. All inputs should be documented.")); - return input_dox_[input_idx].name; + if (!HasInputDox()) + throw std::logic_error("Input documentation was not specified for this operator."); + if (input_info_[input_idx].name.empty()) + throw std::logic_error(make_string("Docstring for input ", input_idx, + "was not set. 
All inputs should be documented.")); + return input_info_[input_idx].name; } DLL_PUBLIC std::string OpSchema::GetInputType(int input_idx) const { CheckInputIndex(input_idx); - DALI_ENFORCE(HasInputDox(), "Input documentation was not specified for this operator."); - return input_dox_[input_idx].type_doc; + if (!HasInputDox()) + throw std::logic_error("Input documentation was not specified for this operator."); + return input_info_[input_idx].doc.type_doc; } DLL_PUBLIC std::string OpSchema::GetInputDox(int input_idx) const { CheckInputIndex(input_idx); - DALI_ENFORCE(HasInputDox(), "Input documentation was not specified for this operator."); - return input_dox_[input_idx].doc; + if (!HasInputDox()) + throw std::logic_error("Input documentation was not specified for this operator."); + return input_info_[input_idx].doc.doc; } @@ -586,11 +697,6 @@ int OpSchema::NumOutput() const { } -bool OpSchema::AllowsInstanceGrouping() const { - return allow_instance_grouping_; -} - - bool OpSchema::IsSequenceOperator() const { return is_sequence_operator_; } @@ -636,29 +742,19 @@ const std::string &OpSchema::DeprecationMessage() const { } -DLL_PUBLIC bool OpSchema::IsDeprecatedArg(const std::string &arg_name) const { - if (deprecated_arguments_.find(arg_name) != deprecated_arguments_.end()) - return true; - for (const auto &parent_name : parents_) { - const OpSchema &parent = SchemaRegistry::GetSchema(parent_name); - if (parent.IsDeprecatedArg(arg_name)) - return true; - } +DLL_PUBLIC bool OpSchema::IsDeprecatedArg(std::string_view arg_name) const { + if (auto *arg = FindArgument(arg_name)) + return arg->deprecated != nullptr; return false; } -DLL_PUBLIC const DeprecatedArgDef &OpSchema::DeprecatedArgMeta(const std::string &arg_name) const { - auto it = deprecated_arguments_.find(arg_name); - if (it != deprecated_arguments_.end()) { - return it->second; - } - for (const auto &parent_name : parents_) { - const OpSchema &parent = SchemaRegistry::GetSchema(parent_name); - if 
(parent.IsDeprecatedArg(arg_name)) - return parent.DeprecatedArgMeta(arg_name); - } - DALI_FAIL(make_string("No deprecation metadata for argument \"", arg_name, "\" found.")); +DLL_PUBLIC const ArgumentDeprecation &OpSchema::DeprecatedArgInfo(std::string_view arg_name) const { + auto &arg = GetArgument(arg_name); + if (!arg.deprecated) + throw std::invalid_argument( + make_string("No deprecation info for argument \"", arg_name, "\" found.")); + return *arg.deprecated; } @@ -743,118 +839,77 @@ bool OpSchema::SupportsInPlace(const OpSpec &spec) const { void OpSchema::CheckArgs(const OpSpec &spec) const { - std::vector<string> vec = spec.ListArguments(); - std::set<std::string> req_arguments_left; - auto required_arguments = GetRequiredArguments(); - for (auto &arg_pair : required_arguments) { - req_arguments_left.insert(arg_pair.first); + auto args_in_spec = spec.ListArgumentNames(); + for (const auto &name : args_in_spec) { + auto *arg = FindArgument(name); + if (!arg) + throw std::invalid_argument(make_string("Got an unexpected argument \"", name, "\"")); } - for (const auto &s : vec) { - DALI_ENFORCE(HasArgument(s) || internal_arguments_.find(s) != internal_arguments_.end(), - "Got an unexpected argument \"" + s + "\""); - std::set<std::string>::iterator it = req_arguments_left.find(s); - if (it != req_arguments_left.end()) { - req_arguments_left.erase(it); - } - } - if (!req_arguments_left.empty()) { - std::string ret = "Not all required arguments were specified for op \"" + this->name() + - "\". 
Please specify values for arguments: "; - for (auto &str : req_arguments_left) { - ret += "\"" + str + "\", "; - } - ret.erase(ret.size() - 2); - ret += "."; - DALI_FAIL(ret); + std::vector<std::string_view> missing_args; + for (auto &[name, arg] : GetFlattenedArguments()) { + if (arg->required) + if (!args_in_spec.count(name)) + missing_args.push_back(name); } -} - -bool OpSchema::HasRequiredArgument(const std::string &name, bool local_only) const { - bool ret = arguments_.find(name) != arguments_.end(); - if (ret || local_only) { - return ret; - } - for (const auto &p : parents_) { - const OpSchema &parent = SchemaRegistry::GetSchema(p); - ret = ret || parent.HasRequiredArgument(name); + if (!missing_args.empty()) { + std::stringstream ss; + ss << "Not all required arguments were specified for op \"" << name() << "\". " + "Please specify values for arguments: " + "\""; + join(ss, missing_args, "\", \""); + ss << "\""; + throw std::runtime_error(ss.str()); } - return ret; } - -bool OpSchema::HasOptionalArgument(const std::string &name, bool local_only) const { - bool ret = optional_arguments_.find(name) != optional_arguments_.end(); - if (ret || local_only) { - return ret; - } - for (const auto &p : parents_) { - const OpSchema &parent = SchemaRegistry::GetSchema(p); - ret = ret || parent.HasOptionalArgument(name); - } - return ret; +bool OpSchema::HasOptionalArgument(std::string_view name) const { + return FindArgument(name, [](const ArgumentDef &arg) { + return !arg.required; + }); } - -bool OpSchema::HasInternalArgument(const std::string &name, bool local_only) const { - bool ret = internal_arguments_.find(name) != internal_arguments_.end(); - if (ret || local_only) { - return ret; - } - for (const auto &p : parents_) { - const OpSchema &parent = SchemaRegistry::GetSchema(p); - ret = ret || parent.HasInternalArgument(name); +bool OpSchema::HasInternalArgument(std::string_view name) const { + if (default_) { + auto it = arguments_.find(name); + if (it != 
arguments_.end()) + return it->second.internal; + return false; + } else { + return Default().HasInternalArgument(name); } - return ret; } +const ArgumentDef &OpSchema::GetArgument(std::string_view name) const { + if (auto *arg = FindArgument(name)) + return *arg; + throw invalid_key(make_string( + "Argument \"", name, "\" is not defined for operator \"", this->name(), "\".")); +} -std::string OpSchema::GetArgumentDox(const std::string &name) const { - DALI_ENFORCE(HasArgument(name), - "Argument \"" + name + "\" is not supported by operator \"" + this->name() + "\"."); - if (HasRequiredArgument(name)) { - return GetRequiredArguments().at(name).doc; - } else { - return GetOptionalArguments().at(name).doc; - } +const std::string &OpSchema::GetArgumentDox(std::string_view name) const { + return GetArgument(name).doc; } -DALIDataType OpSchema::GetArgumentType(const std::string &name) const { - DALI_ENFORCE(HasArgument(name), - "Argument \"" + name + "\" is not supported by operator \"" + this->name() + "\"."); - if (HasRequiredArgument(name)) { - return GetRequiredArguments().at(name).dtype; - } else { - return GetOptionalArguments().at(name).dtype; - } +DALIDataType OpSchema::GetArgumentType(std::string_view name) const { + return GetArgument(name).dtype; } -bool OpSchema::HasArgumentDefaultValue(const std::string &name) const { - DALI_ENFORCE(HasArgument(name, true), - "Argument \"" + name + "\" is not supported by operator \"" + this->name() + "\"."); - if (HasRequiredArgument(name)) { - return false; - } - if (HasInternalArgument(name, true)) { - return true; - } - auto *value_ptr = GetOptionalArguments().at(name).default_value; - return value_ptr != nullptr; +bool OpSchema::HasArgumentDefaultValue(std::string_view name) const { + return GetArgument(name).default_value != nullptr; } -std::string OpSchema::GetArgumentDefaultValueString(const std::string &name) const { - DALI_ENFORCE(HasOptionalArgument(name), "Argument \"" + name + - "\" is either not supported by 
operator \"" + - this->name() + "\" or is not optional."); - - auto *value_ptr = GetOptionalArguments().at(name).default_value; +std::string OpSchema::GetArgumentDefaultValueString(std::string_view name) const { + auto *value_ptr = GetArgument(name).default_value.get(); - DALI_ENFORCE(value_ptr, - make_string("Argument \"", name, - "\" in operator \"" + this->name() + "\" has no default value.")); + if (!value_ptr) { + throw std::invalid_argument( + make_string("Argument \"", name, + "\" in operator \"", this->name(), "\" has no default value.")); + } auto &val = *value_ptr; auto str = val.ToString(); @@ -871,71 +926,37 @@ std::string OpSchema::GetArgumentDefaultValueString(const std::string &name) con std::vector<std::string> OpSchema::GetArgumentNames() const { std::vector<std::string> ret; - const auto &required = GetRequiredArguments(); - const auto &optional = GetOptionalArguments(); - const auto &deprecated = GetDeprecatedArguments(); - for (auto &arg_pair : required) { - ret.push_back(arg_pair.first); - } - for (auto &arg_pair : optional) { - if (!arg_pair.second.hidden) { - ret.push_back(arg_pair.first); - } - } - for (auto &arg_pair : deprecated) { - // Deprecated aliases only appear in `deprecated` but regular - // deprecated arguments appear both in `deprecated` and either `required` or `optional`. 
- if (required.find(arg_pair.first) == required.end() && - optional.find(arg_pair.first) == optional.end()) - ret.push_back(arg_pair.first); - } + const auto &args = GetFlattenedArguments(); + for (auto it = args.begin(); it != args.end(); ++it) + if (!it->second->hidden) + ret.push_back(it->first); return ret; } -bool OpSchema::IsTensorArgument(const std::string &name) const { +bool OpSchema::IsTensorArgument(std::string_view name) const { return FindTensorArgument(name); } -bool OpSchema::ArgSupportsPerFrameInput(const std::string &arg_name) const { - auto arg_desc = FindTensorArgument(arg_name); - return arg_desc && arg_desc->supports_per_frame; +bool OpSchema::ArgSupportsPerFrameInput(std::string_view arg_name) const { + return FindArgument(arg_name, [](const ArgumentDef &arg) { return arg.per_frame; }); } -const TensorArgDesc *OpSchema::FindTensorArgument(const std::string &name) const { - auto it = tensor_arguments_.find(name); - if (it != tensor_arguments_.end()) { - return &it->second; - } - for (const auto &p : parents_) { - const OpSchema &parent = SchemaRegistry::GetSchema(p); - auto desc = parent.FindTensorArgument(name); - if (desc) { - return desc; - } - } - return nullptr; +const ArgumentDef *OpSchema::FindTensorArgument(std::string_view name) const { + return FindArgument(name, [](const ArgumentDef &def) { return def.tensor; }); } - -void OpSchema::CheckArgument(const std::string &s) { - DALI_ENFORCE(!HasArgument(s, false, true), "Argument \"" + s + "\" already added to the schema"); - DALI_ENFORCE(internal_arguments_.find(s) == internal_arguments_.end(), - "Argument name \"" + s + "\" is reserved for internal use"); -} - - void OpSchema::CheckInputIndex(int index) const { - DALI_ENFORCE(index >= 0 && index < max_num_input_, - "Output index (=" + std::to_string(index) + ") out of range [0.." 
+
-                 std::to_string(max_num_input_) + ").\nWas NumInput called?");
+  if (index < 0 || index >= max_num_input_)
+    throw std::out_of_range(make_string(
+        "Input index ", index, " is out of range [0..", max_num_input_, ").\nWas NumInput called?"));
 }
 
 
-std::string OpSchema::DefaultDeprecatedArgMsg(const std::string &arg_name,
-                                              const std::string &renamed_to, bool removed) const {
+std::string OpSchema::DefaultDeprecatedArgMsg(std::string_view arg_name,
+                                              std::string_view renamed_to, bool removed) const {
   std::stringstream ss;
   if (removed) {
     ss << "The argument `" << arg_name
@@ -950,70 +971,19 @@ std::string OpSchema::DefaultDeprecatedArgMsg(const std::string &arg_name,
 }
 
 
-std::map<std::string, RequiredArgumentDef> OpSchema::GetRequiredArguments() const {
-  auto ret = arguments_;
-  for (const auto &parent_name : parents_) {
-    const OpSchema &parent = SchemaRegistry::GetSchema(parent_name);
-    const auto &parent_args = parent.GetRequiredArguments();
-    ret.insert(parent_args.begin(), parent_args.end());
-  }
-  return ret;
-}
-
-
-std::map<std::string, DefaultedArgumentDef> OpSchema::GetOptionalArguments() const {
-  auto ret = optional_arguments_;
-  for (const auto &parent_name : parents_) {
-    const OpSchema &parent = SchemaRegistry::GetSchema(parent_name);
-    const auto &parent_args = parent.GetOptionalArguments();
-    ret.insert(parent_args.begin(), parent_args.end());
-  }
-  return ret;
-}
-
-
-std::map<std::string, DeprecatedArgDef> OpSchema::GetDeprecatedArguments() const {
-  auto ret = deprecated_arguments_;
-  for (const auto &parent_name : parents_) {
-    const OpSchema &parent = SchemaRegistry::GetSchema(parent_name);
-    const auto &parent_args = parent.GetDeprecatedArguments();
-    ret.insert(parent_args.begin(), parent_args.end());
-  }
-  return ret;
+const Value *OpSchema::FindDefaultValue(std::string_view name) const {
+  if (auto *arg = FindArgument(name))
+    return arg->default_value.get();
+  else
+    return nullptr;
 }
 
 
-std::pair<const OpSchema *, const Value *> 
OpSchema::FindDefaultValue(const std::string &name, - bool local_only, - bool include_internal) const { - auto it = optional_arguments_.find(name); - if (it != optional_arguments_.end()) { - return {this, it->second.default_value}; - } - if (include_internal) { - it = internal_arguments_.find(name); - if (it != internal_arguments_.end()) { - return {this, it->second.default_value}; - } - } - if (local_only) - return {nullptr, nullptr}; - - for (const auto &p : parents_) { - const OpSchema &parent = SchemaRegistry::GetSchema(p); - auto schema_val = parent.FindDefaultValue(name, false, include_internal); - if (schema_val.first && schema_val.second) - return schema_val; - } - return {nullptr, nullptr}; -} - - -bool OpSchema::HasArgument(const std::string &name, - bool include_internal, - bool local_only) const { - return HasRequiredArgument(name, local_only) || HasOptionalArgument(name, local_only) || - (include_internal && HasInternalArgument(name, true)); +bool OpSchema::HasArgument(std::string_view name, bool include_internal) const { + if (auto *arg = FindArgument(name)) + return arg && (include_internal || !arg->internal); + else + return false; } } // namespace dali diff --git a/dali/pipeline/operator/op_schema.h b/dali/pipeline/operator/op_schema.h index 236b072f317..6a1021635cd 100644 --- a/dali/pipeline/operator/op_schema.h +++ b/dali/pipeline/operator/op_schema.h @@ -18,10 +18,12 @@ #include <functional> #include <map> #include <memory> +#include <mutex> #include <numeric> #include <set> #include <sstream> #include <string> +#include <string_view> #include <utility> #include <vector> @@ -37,31 +39,6 @@ namespace dali { class OpSpec; -struct RequiredArgumentDef { - std::string doc; - DALIDataType dtype; -}; - -struct DefaultedArgumentDef { - std::string doc; - DALIDataType dtype; - Value *default_value; - // As opposed to purely internal argument, the hidden argument - // can be specified as any other argument on per-operator basis, through Python API, etc. 
- // It is just hidden from the docs. - bool hidden; -}; - -struct DeprecatedArgDef { - std::string renamed_to = {}; - std::string msg = {}; - bool removed = false; -}; - -struct TensorArgDesc { - bool supports_per_frame = false; -}; - enum class InputDevice : uint8_t { /** CPU for CPU and Mixed operators; GPU for GPU operators. */ MatchBackend = 0, @@ -93,46 +70,113 @@ enum class InputDevice : uint8_t { Metadata, }; +struct ArgumentDeprecation { + ArgumentDeprecation() = default; + ArgumentDeprecation(string renamed_to, string msg, bool removed = false) + : renamed_to(renamed_to), msg(msg), removed(removed) {} + + std::string renamed_to; + std::string msg; + bool removed = false; +}; + +class OpSchema; + +struct ArgumentDef { + const OpSchema *defined_in; + std::string name; + std::string doc; + DALIDataType dtype; + + // TODO(michalz): Convert to bit fields in C++20 (before C++20 bit fields can't have initializers) + bool required = false; + bool tensor = false; + bool per_frame = false; + bool internal = false; + bool hidden = false; + + std::unique_ptr<Value> default_value; + std::unique_ptr<ArgumentDeprecation> deprecated; +}; + +struct InputInfo { + std::string name; + InputDevice device; + + struct InputDoc { + std::string type_doc; + std::string doc; + } doc; + + std::vector<TensorLayout> layouts; +}; + +namespace detail { + +/** A helper class for lazy evaluation + * + * In some cases, it's impossible or wasteful to compute a value eagerly. + * This class provides a thread-safe storage for such value. + * + * Usage: place in your class (possibly as a mutable field) and call Get with a function + * that returns a value convertible to T. This function will be called only once per LazyValue's + * lifetime. + * + * NOTE: Copying the lazy value is a no-op. 
+ */ +template <typename T> +struct LazyValue { + LazyValue() = default; + LazyValue(const LazyValue &) {} + LazyValue(LazyValue &&) {} + LazyValue &operator=(const LazyValue &) { return *this; } + LazyValue &operator=(LazyValue &&) { return *this; } + + template <typename PopulateFn> + T &Get(PopulateFn &&fn) { + if (data) + return *data; + std::lock_guard g(lock); + if (data) + return *data; + data = std::make_unique<T>(fn()); + return *data; + } + std::unique_ptr<T> data; + std::recursive_mutex lock; +}; +} // namespace detail + + class DLL_PUBLIC OpSchema { public: typedef std::function<int(const OpSpec &spec)> SpecFunc; - OpSchema(OpSchema &&) = default; + OpSchema(OpSchema &&) = delete; OpSchema(const OpSchema &) = delete; OpSchema &operator=(const OpSchema &) = delete; - OpSchema &operator=(OpSchema &&) = default; + OpSchema &operator=(OpSchema &&) = delete; - DLL_PUBLIC explicit OpSchema(const std::string &name); + explicit OpSchema(std::string_view name); - DLL_PUBLIC inline ~OpSchema() = default; + inline ~OpSchema() = default; - /** - * @brief Returns an empty schema, with only internal arguments - */ - DLL_PUBLIC static const OpSchema &Default(); + /** Returns an empty schema, with only internal arguments */ + static const OpSchema &Default(); - /** - * @brief Returns the schema name of this operator. - */ - DLL_PUBLIC const std::string &name() const; + /** Returns the schema name of this operator. */ + const std::string &name() const; - /** - * @brief Returns the module path of this operator. - */ - DLL_PUBLIC const std::vector<std::string> &ModulePath() const; + /** Returns the module path of this operator. 
*/ + const std::vector<std::string> &ModulePath() const; - /** - * @brief Returns the camel case name of the operator (without the module path) - */ - DLL_PUBLIC const std::string &OperatorName() const; + /** Returns the camel case name of the operator (without the module path) */ + const std::string &OperatorName() const; - /** - * @brief Sets the doc string for this operator. - */ - DLL_PUBLIC OpSchema &DocStr(const string &dox); + /** Sets the doc string for this operator. */ + OpSchema &DocStr(std::string dox); - /** - * @brief Sets the docstring for input. + /** Sets the docstring for input. * * Set the documentation for intput at given `index`. * @@ -143,11 +187,9 @@ class DLL_PUBLIC OpSchema { * name : type_doc * doc */ - DLL_PUBLIC OpSchema &InputDox(int index, const string &name, const string &type_doc, - const string &doc); + OpSchema &InputDox(int index, std::string_view name, std::string type_doc, std::string doc); - /** - * @brief Allows to set a docstring for __call__ method of Operator. + /** Allows to set a docstring for __call__ method of Operator. * * The first line of the string can contain the signature that will be used * in the sphinx-generated documentation, for example: @@ -171,22 +213,18 @@ class DLL_PUBLIC OpSchema { * @param doc * @param append_kwargs_section */ - DLL_PUBLIC OpSchema &CallDocStr(const string &doc, bool append_kwargs_section = false); + OpSchema &CallDocStr(std::string doc, bool append_kwargs_section = false); - /** - * @brief Sets a function that infers the number of outputs this - * op will produce from the ops specification. This is required - * to expose the op to the python interface. + /** Sets a function that infers the number of outputs this op will produce from OpSpec. + * + * This is required to expose the op to the python interface. * * If the ops has a fixed number of outputs, this function - * does not need to be added to the schema + * does not need to be added to the schema. 
*/ - DLL_PUBLIC OpSchema &OutputFn(SpecFunc f); + OpSchema &OutputFn(SpecFunc f); - /** - * @brief Sets a function to determine the number of - * additional outputs (independent of output sets) from an - * op from the op's specification. + /** Sets a function to determine the number of additional outputs from the OpSpec. * * If this function is not set it will be assumed that no * additional outputs can be returned @@ -194,113 +232,74 @@ class DLL_PUBLIC OpSchema { * Use case is to expose additional information (such as random * numbers used within operators) to the user */ - DLL_PUBLIC OpSchema &AdditionalOutputsFn(SpecFunc f); + OpSchema &AdditionalOutputsFn(SpecFunc f); - /** - * @brief Sets the number of inputs that the op can receive. - */ - DLL_PUBLIC OpSchema &NumInput(int n); + /** Sets the number of inputs that the op can receive. */ + OpSchema &NumInput(int n); - /** - * @brief Sets the min and max number of inputs the op can receive. - */ - DLL_PUBLIC OpSchema &NumInput(int min, int max); + /** Sets the min and max number of inputs the op can receive. */ + OpSchema &NumInput(int min, int max); - /** - * @brief Sets the input device for given range of inputs - */ - DLL_PUBLIC OpSchema &InputDevice(int first, int one_past, dali::InputDevice device); + /** Sets the input device for given range of inputs */ + OpSchema &InputDevice(int first, int one_past, dali::InputDevice device); - /** - * @brief Sets the input device for given range of input - */ - DLL_PUBLIC OpSchema &InputDevice(int index, dali::InputDevice device); + /** Sets the input device for given range of input */ + OpSchema &InputDevice(int index, dali::InputDevice device); - /** - * @brief Gets the supported input device for given input - */ - DLL_PUBLIC dali::InputDevice GetInputDevice(int index) const; + /** Gets the supported input device for given input */ + dali::InputDevice GetInputDevice(int index) const; - /** - * @brief Sets the number of outputs that the op can receive. 
- */ - DLL_PUBLIC OpSchema &NumOutput(int n); + /** Sets the number of outputs that the op can receive. */ + OpSchema &NumOutput(int n); /** * @brief Indicates that this operator should not use auto-generated documentation * of inputs and `__call__` operator with custom signature. */ - DLL_PUBLIC OpSchema &DisableAutoInputDox(); + OpSchema &DisableAutoInputDox(); - /** - * @brief Indicates that multiple instances of this operator cannot share a logical ID to achieve - * uniform processing of multiple input sets - */ - DLL_PUBLIC OpSchema &DisallowInstanceGrouping(); + /** Notes that this operator expects sequence inputs exclusively */ + OpSchema &SequenceOperator(); - /** - * @brief Notes that this operator expects sequence inputs exclusively - */ - DLL_PUBLIC OpSchema &SequenceOperator(); + /** Notes that sequences can be used with this op */ + OpSchema &AllowSequences(); - /** - * @brief Notes that sequences can be used with this op - */ - DLL_PUBLIC OpSchema &AllowSequences(); + /** Notes that the operator can process 3D data. */ + OpSchema &SupportVolumetric(); - /** - * Notes that the operator can process 3D data. - * @return - */ - DLL_PUBLIC OpSchema &SupportVolumetric(); + /** Notes that this operator is internal and shouldn't be exposed in Python API. 
*/ + OpSchema &MakeInternal(); - /** - * @brief Notes that this operator is internal to DALI backend (and shouldn't be exposed in Python - * API) - */ - DLL_PUBLIC OpSchema &MakeInternal(); - - /** - * @brief Notes that this operator doc should not be visible (but the Op is exposed in Python API) - */ - DLL_PUBLIC OpSchema &MakeDocHidden(); + /** Notes that this operator doc should not be visible (but the Op is exposed in Python API) */ + OpSchema &MakeDocHidden(); /** * @brief Notes that for this operator only the doc_str should be visible, but not the docs for - * the inputs, outputs or argument (the Op is exposed in Python API) + * the inputs, outputs or argument (the Op is exposed in Python API) */ - DLL_PUBLIC OpSchema &MakeDocPartiallyHidden(); + OpSchema &MakeDocPartiallyHidden(); - /** - * @brief Notes that this operator is deprecated and optionally specifies the operator to be used - * instead + /** Notes that this operator is deprecated and optionally specifies its successor * * @param in_favor_of schema name of the replacement * @param explanation additional explanation */ - DLL_PUBLIC OpSchema &Deprecate(const std::string &in_favor_of = "", - const std::string &explanation = ""); + OpSchema &Deprecate(std::string in_favor_of = "", + std::string explanation = ""); - /** - * @brief Notes that this operator cannot be serialized - */ - DLL_PUBLIC OpSchema &Unserializable(); + /** Notes that this operator cannot be serialized */ + OpSchema &Unserializable(); - /** - * @brief Adds a required argument to op with its type - */ - DLL_PUBLIC OpSchema &AddArg(const std::string &s, const std::string &doc, - const DALIDataType dtype, bool enable_tensor_input = false, - bool support_per_frame_input = false); + /** Adds a required argument to op with its type */ + OpSchema &AddArg(std::string_view s, std::string doc, + const DALIDataType dtype, bool enable_tensor_input = false, + bool support_per_frame_input = false); - /** - * @brief Adds a required argument of type 
DALIDataType - */ - DLL_PUBLIC OpSchema &AddTypeArg(const std::string &s, const std::string &doc); + /** Adds a required argument of type DALIDataType */ + OpSchema &AddTypeArg(std::string_view s, std::string doc); - /** - * @brief Sets input layout constraints and default for given input. + /** Sets input layout constraints and default for given input. * * At run-time, when the operator encounters a tensor(list) with specified * layout, but different than one provided to this function, error is raised. @@ -308,10 +307,9 @@ class DLL_PUBLIC OpSchema { * If the input tensor has no layout, the one provided to this function is assumed * if number of dimensions matches. Otherswise, error is raised. */ - DLL_PUBLIC OpSchema &InputLayout(int index, TensorLayout layout); + OpSchema &InputLayout(int index, TensorLayout layout); - /** - * @brief Sets input layout constraints and default for given input. + /** Sets input layout constraints and default for given input. * * At run-time, when the operator encounters a tensor(list) with specified * layout, but not one of those provided to this function, error is raised. @@ -321,134 +319,134 @@ class DLL_PUBLIC OpSchema { * it will be the default value for this input. If number of dimensions doesn't * match any of the layouts provided here, an error is raised. */ - DLL_PUBLIC OpSchema &InputLayout(int index, std::initializer_list<TensorLayout> layouts); + OpSchema &InputLayout(int index, std::initializer_list<TensorLayout> layouts); - /** - * @brief Sets input layout constraint and default for all inputs. + /** Sets input layout constraint and default for all inputs. + * * @see InputLayout(int index, TensorLayout layout) */ - DLL_PUBLIC OpSchema &InputLayout(TensorLayout layout); + OpSchema &InputLayout(TensorLayout layout); - /** - * @brief Sets input layout constraint and default for all inputs. + /** Sets input layout constraint and default for all inputs. 
+ * * @see InputLayout(int index, TensorLayout layout) */ - DLL_PUBLIC OpSchema &InputLayout(std::initializer_list<TensorLayout> layouts); + OpSchema &InputLayout(std::initializer_list<TensorLayout> layouts); /** * @brief Verifies that the layout is valid for given input index and number of dimensions * or returns a default layout if the layout parameter is empty. */ - DLL_PUBLIC const TensorLayout &GetInputLayout(int index, int sample_ndim, - const TensorLayout &layout = {}) const; + const TensorLayout &GetInputLayout(int index, int sample_ndim, + const TensorLayout &layout = {}) const; - DLL_PUBLIC const std::vector<TensorLayout> &GetSupportedLayouts(int input_idx) const; + const std::vector<TensorLayout> &GetSupportedLayouts(int input_idx) const; /** * @brief Adds an optional non-vector argument without default to op * The type can be specified as enum, nullptr_t is used for overload resolution - * If the arg name starts is with an underscore, it will be marked hidden, which + * If the arg name starts with an underscore, it will be marked hidden, which * makes it not listed in the docs. */ - DLL_PUBLIC OpSchema &AddOptionalArg(const std::string &s, const std::string &doc, - DALIDataType dtype, std::nullptr_t, - bool enable_tensor_input = false, - bool support_per_frame_input = false); + OpSchema &AddOptionalArg(std::string_view s, std::string doc, + DALIDataType dtype, std::nullptr_t, + bool enable_tensor_input = false, + bool support_per_frame_input = false); /** * @brief Adds an optional non-vector argument without default to op. - * If the arg name starts is with an underscore, it will be marked hidden, which + * If the arg name starts with an underscore, it will be marked hidden, which * makes it not listed in the docs. 
*/ template <typename T> - DLL_PUBLIC inline OpSchema &AddOptionalArg(const std::string &s, const std::string &doc, - std::nullptr_t, bool enable_tensor_input = false, - bool support_per_frame_input = false) { - AddOptionalArg(s, doc, type2id<T>::value, nullptr, enable_tensor_input, + inline OpSchema &AddOptionalArg(std::string_view name, std::string doc, + std::nullptr_t, bool enable_tensor_input = false, + bool support_per_frame_input = false) { + AddOptionalArg(name, doc, type2id<T>::value, nullptr, enable_tensor_input, support_per_frame_input); return *this; } - /** - * @brief Adds an optional non-vector argument to op + /** Adds an optional non-vector argument to op * - * If the arg name starts is with an underscore, it will be marked hidden, which - * makes it not listed in the docs. + * If the arg name starts with an underscore, it will be marked hidden, which + * makes it not listed in the docs. */ template <typename T> - DLL_PUBLIC inline std::enable_if_t<!is_vector<T>::value && !is_std_array<T>::value, OpSchema &> - AddOptionalArg(const std::string &s, const std::string &doc, T default_value, + inline std::enable_if_t<!is_vector<T>::value && !is_std_array<T>::value, OpSchema &> + AddOptionalArg(std::string_view name, std::string doc, T default_value, bool enable_tensor_input = false, bool support_per_frame_input = false) { static_assert( !std::is_same<T, DALIDataType>::value, R"(Use `AddOptionalTypeArg` instead. 
`AddOptionalArg` with a default value should not be used with DALIDataType, to avoid confusion with `AddOptionalArg<type>(name, doc, nullptr)`)"); - CheckArgument(s); - auto to_store = Value::construct(default_value); - optional_arguments_[s] = {doc, type2id<T>::value, to_store.get(), ShouldHideArgument(s)}; - optional_arguments_unq_.push_back(std::move(to_store)); - if (enable_tensor_input) { - tensor_arguments_[s] = {support_per_frame_input}; - } + auto &arg = AddArgumentImpl(name, type2id<T>::value, + Value::construct(default_value), + std::move(doc)); + arg.tensor = enable_tensor_input; + arg.per_frame = support_per_frame_input; return *this; } - /** - * @brief Adds an optional argument of type DALIDataType with a default value + /** Adds an optional argument of type DALIDataType with a default value * - * If the arg name starts is with an underscore, it will be marked hidden, which - * makes it not listed in the docs. + * If the arg name starts with an underscore, it will be marked hidden, which + * makes it not listed in the docs. */ - DLL_PUBLIC OpSchema &AddOptionalTypeArg(const std::string &s, const std::string &doc, - DALIDataType default_value); + OpSchema &AddOptionalTypeArg(std::string_view name, std::string doc, DALIDataType default_value); - /** - * @brief Adds an optional argument of type DALIDataType without a default value + /** Adds an optional argument of type DALIDataType without a default value * - * If the arg name starts is with an underscore, it will be marked hidden, which - * makes it not listed in the docs. + * If the arg name starts with an underscore, it will be marked hidden, which + * makes it not listed in the docs. 
*/ - DLL_PUBLIC OpSchema &AddOptionalTypeArg(const std::string &s, const std::string &doc); + OpSchema &AddOptionalTypeArg(std::string_view name, std::string doc); - DLL_PUBLIC OpSchema &AddOptionalArg(const std::string &s, const std::string &doc, - const char *default_value); + inline OpSchema &AddOptionalArg(std::string_view name, std::string doc, + const char *default_value) { + return AddOptionalArg(name, std::move(doc), std::string(default_value), false); + } - /** - * @brief Adds an optional vector argument to op + inline OpSchema &AddOptionalArg(std::string_view name, std::string doc, + std::string_view default_value) { + return AddOptionalArg(name, std::move(doc), std::string(default_value), false); + } + + /** Adds an optional vector argument to op * - * If the arg name starts is with an underscore, it will be marked hidden, which - * makes it not listed in the docs. + * If the arg name starts with an underscore, it will be marked hidden, which + * makes it not listed in the docs. 
*/ template <typename T> - DLL_PUBLIC inline OpSchema &AddOptionalArg(const std::string &s, const std::string &doc, - std::vector<T> default_value, - bool enable_tensor_input = false, - bool support_per_frame_input = false) { - CheckArgument(s); + inline OpSchema &AddOptionalArg(std::string_view name, std::string doc, + std::vector<T> default_value, + bool enable_tensor_input = false, + bool support_per_frame_input = false) { using S = argument_storage_t<T>; - auto to_store = Value::construct(detail::convert_vector<S>(default_value)); - bool hide_argument = ShouldHideArgument(s); - optional_arguments_[s] = {doc, type2id<std::vector<T>>::value, to_store.get(), hide_argument}; - optional_arguments_unq_.push_back(std::move(to_store)); - if (enable_tensor_input) { - tensor_arguments_[s] = {support_per_frame_input}; - } + auto value = Value::construct(detail::convert_vector<S>(default_value)); + auto &arg = AddArgumentImpl(name, + type2id<std::vector<T>>::value, + std::move(value), + std::move(doc)); + arg.tensor = enable_tensor_input; + arg.per_frame = support_per_frame_input; return *this; } - /** - * @brief Marks an argument as deprecated in favor of a new argument + OpSchema &AddRandomSeedArg(); + + /** Marks an argument as deprecated in favor of a new argument * * Providing renamed_to means the argument has been renamed and we can safely * propagate the value to the new argument name. */ - DLL_PUBLIC OpSchema &DeprecateArgInFavorOf(const std::string &arg_name, std::string renamed_to, - std::string msg = {}); + OpSchema &DeprecateArgInFavorOf(std::string_view arg_name, std::string renamed_to, + std::string msg = {}); - /** - * @brief Marks an argument as deprecated + /** Marks an argument as deprecated + * * @remarks There are three ways to deprecate an argument * 1. removed==true, means the operator will not use the * argument at all and it can be safely discarded. 
@@ -456,28 +454,24 @@ used with DALIDataType, to avoid confusion with `AddOptionalArg<type>(name, doc, * deprecated argument until it is finally removed completely from the schema. * 3. For renaming the argument see DeprecateArgInFavorOf */ - DLL_PUBLIC OpSchema &DeprecateArg(const std::string &arg_name, bool removed = true, - std::string msg = {}); + OpSchema &DeprecateArg(std::string_view arg_name, bool removed = true, + std::string msg = {}); /** * @brief Sets a function that infers whether the op can - * be executed in-place depending on the ops specification. + * be executed in-place depending on the ops specification. */ - DLL_PUBLIC OpSchema &InPlaceFn(SpecFunc f); + OpSchema &InPlaceFn(SpecFunc f); - /** - * @brief Sets a parent (which could be used as a storage of default parameters) - * Does not support cyclic dependency. There can be multiple parents - * and the lookup is transitive. + /** Sets a parent (which could be used as a storage of default parameters) + * + * Does not support cyclic dependency. There can be multiple parents and the lookup is transitive. * Only arguments are inherited, inputs and outputs are not. */ - DLL_PUBLIC OpSchema &AddParent(const std::string &parentName); + OpSchema &AddParent(std::string parent_name); - /** - * @brief Notes that this operator should not be pruned from - * a graph even if its outputs are unused. - */ - DLL_PUBLIC OpSchema &NoPrune(); + /** Notes that this operator should not be pruned from a graph even if its outputs are unused. */ + OpSchema &NoPrune(); /** * @brief Informs that the data passes through this operator unchanged, only @@ -493,331 +487,294 @@ used with DALIDataType, to avoid confusion with `AddOptionalArg<type>(name, doc, * @param inout - tells which inputs are passed through to which outputs. * Only (partial - as in partial function) bijective mappings are allowed. 
*/ - DLL_PUBLIC OpSchema &PassThrough(const std::map<int, int> &inout); + OpSchema &PassThrough(const std::map<int, int> &inout); /** * @brief Informs that the operator passes through data unchanged, sharing the allocation * from input to output. - * The data is passed on sample basis, allowing to mix any input to any output. + * + * The data is passed on a per-sample basis, allowing to mix any input to any output. */ - DLL_PUBLIC OpSchema &SamplewisePassThrough(); + OpSchema &SamplewisePassThrough(); - /** - * @brief Get parent schemas (non-recursive) - */ - DLL_PUBLIC const vector<std::string> &GetParents() const; + /** Get parent schemas (non-recursive) */ + const vector<std::string> &GetParentNames() const; - /** - * @brief Get the docstring of the operator - provided by DocStr in the schema definition. - */ - DLL_PUBLIC string Dox() const; + const vector<const OpSchema *> &GetParents() const; - /** - * @brief Return true wether the default input docs can be used - */ - DLL_PUBLIC bool CanUseAutoInputDox() const; + /** Get the docstring of the operator - provided by DocStr in the schema definition. */ + string Dox() const; + + /** Return true whether the default input docs can be used */ + bool CanUseAutoInputDox() const; /** * @brief Whether the docstring for kwargs should be automatically generated and appended to the - * one provided in CallDocStr. + * one provided in CallDocStr. */ - DLL_PUBLIC bool AppendKwargsSection() const; + bool AppendKwargsSection() const; - /** - * @brief Return true when `__call__` docstring was explicitly set + /** Return true when `__call__` docstring was explicitly set * * Should be considered as highest preference */ - DLL_PUBLIC bool HasCallDox() const; + bool HasCallDox() const; - /** - * @brief Get the documentation for Operator __call__ signature provided by CallDocStr. - */ - DLL_PUBLIC std::string GetCallDox() const; + /** Get the documentation for Operator __call__ signature provided by CallDocStr. 
*/ + std::string GetCallDox() const; - /** - * @brief Check if this operator has input docstrings provided - */ - DLL_PUBLIC bool HasInputDox() const; + /** Check if this operator has input docstrings provided */ + bool HasInputDox() const; /** * @brief List all the inputs that should appear in `__call__` signature based on the input * docs that were specified. Requires HasInputDox() to return true * */ - DLL_PUBLIC std::string GetCallSignatureInputs() const; + std::string GetCallSignatureInputs() const; - /** - * @brief Get the docstring name of the input at given index. - */ - DLL_PUBLIC std::string GetInputName(int input_idx) const; + /** Get the docstring name of the input at given index. */ + std::string GetInputName(int input_idx) const; - /** - * @brief Get the docstring type of the input at given index. - */ - DLL_PUBLIC std::string GetInputType(int input_idx) const; + /** Get the docstring type of the input at given index. */ + std::string GetInputType(int input_idx) const; - /** - * @brief Get the docstring text of the input at given index. - */ - DLL_PUBLIC std::string GetInputDox(int input_idx) const; + /** Get the docstring text of the input at given index. */ + std::string GetInputDox(int input_idx) const; - /** - * @brief Get the maximal number of accepted inputs. - */ - DLL_PUBLIC int MaxNumInput() const; + /** Get the maximal number of accepted inputs. */ + int MaxNumInput() const; - /** - * @brief Get the minimal number of required inputs. - */ - DLL_PUBLIC int MinNumInput() const; + /** Get the minimal number of required inputs. 
*/ + int MinNumInput() const; - /** - * @brief Get the number of static outputs, see also CalculateOutputs and - * CalculateAdditionalOutputs - */ - DLL_PUBLIC int NumOutput() const; - - DLL_PUBLIC bool AllowsInstanceGrouping() const; + /** Get the number of static outputs, see also CalculateOutputs and CalculateAdditionalOutputs */ + int NumOutput() const; - /** - * @brief Whether this operator accepts ONLY sequences as inputs - */ - DLL_PUBLIC bool IsSequenceOperator() const; + /** Whether this operator accepts ONLY sequences as inputs */ + bool IsSequenceOperator() const; - /** - * @brief Whether this operator accepts sequences as inputs - */ - DLL_PUBLIC bool AllowsSequences() const; + /** Whether this operator accepts sequences as inputs */ + bool AllowsSequences() const; - /** - * @brief Whether this operator accepts volumes as inputs - */ - DLL_PUBLIC bool SupportsVolumetric() const; + /** Whether this operator accepts volumes as inputs */ + bool SupportsVolumetric() const; - /** - * @brief Whether this operator is internal to DALI backend (and shouldn't be exposed in Python - * API) - */ - DLL_PUBLIC bool IsInternal() const; + /** Whether this operator is internal to DALI backend (and shouldn't be exposed in Python API) */ + bool IsInternal() const; - /** - * @brief Whether this operator doc should not be visible (but the Op is exposed in Python API) - */ - DLL_PUBLIC bool IsDocHidden() const; + /** Whether this operator doc should not be visible (but the Op is exposed in Python API) */ + bool IsDocHidden() const; /** - * @brief Whether this operator doc should be visible without documenting any parameters + * Whether this operator doc should be visible without documenting any parameters. * Useful for deprecated ops. */ - DLL_PUBLIC bool IsDocPartiallyHidden() const; + bool IsDocPartiallyHidden() const; - /** - * @brief Whether this operator is deprecated. - */ - DLL_PUBLIC bool IsDeprecated() const; + /** Whether this operator is deprecated. 
*/ + bool IsDeprecated() const; - /** - * @brief What operator replaced the current one. - */ - DLL_PUBLIC const std::string &DeprecatedInFavorOf() const; + /** What operator replaced the current one. */ + const std::string &DeprecatedInFavorOf() const; - /** - * @brief Additional deprecation message - */ - DLL_PUBLIC const std::string &DeprecationMessage() const; + /** Additional deprecation message */ + const std::string &DeprecationMessage() const; - /** - * @brief Whether given argument is deprecated. - */ - DLL_PUBLIC bool IsDeprecatedArg(const std::string &arg_name) const; + /** Whether given argument is deprecated. */ + bool IsDeprecatedArg(std::string_view arg_name) const; - /** - * @brief Metadata about the argument deprecation - error message, renaming, removal, etc. - */ - DLL_PUBLIC const DeprecatedArgDef &DeprecatedArgMeta(const std::string &arg_name) const; + /** Information about the argument deprecation - error message, renaming, removal, etc. */ + const ArgumentDeprecation &DeprecatedArgInfo(std::string_view arg_name) const; - /** - * @brief Check whether this operator calculates number of outputs statically + /** Check whether this operator calculates number of outputs statically + * * @return false if static, true if dynamic */ - DLL_PUBLIC bool HasOutputFn() const; + bool HasOutputFn() const; - /** - * @brief Check whether this operator won't be pruned out of graph even if not used. - */ - DLL_PUBLIC bool IsNoPrune() const; + /** Check whether this operator won't be pruned out of graph even if not used. */ + bool IsNoPrune() const; - DLL_PUBLIC bool IsSerializable() const; + bool IsSerializable() const; - /** - * @brief Returns the index of the output to which the input is passed. + /** Returns the index of the output to which the input is passed. + * * @param strict consider only fully passed through batches * @return Output indicies or empty vector if given input is not passed through. 
*/ - DLL_PUBLIC std::vector<int> GetPassThroughOutputIdx(int input_idx, const OpSpec &spec, + std::vector<int> GetPassThroughOutputIdx(int input_idx, const OpSpec &spec, bool strict = true) const; - /** - * @brief Is the input_idx passed through to output_idx - */ - DLL_PUBLIC bool IsPassThrough(int input_idx, int output_idx, bool strict = true) const; + /** Is the input_idx passed through to output_idx */ + bool IsPassThrough(int input_idx, int output_idx, bool strict = true) const; - /** - * @brief Does this operator pass through any data? - */ - DLL_PUBLIC bool HasPassThrough() const; + /** Does this operator pass through any data? */ + bool HasPassThrough() const; - /** - * @brief Does this operator pass through any data as a whole batch to batch? - */ - DLL_PUBLIC bool HasStrictPassThrough() const; + /** Does this operator pass through any data as a whole batch to batch? */ + bool HasStrictPassThrough() const; - /** - * @brief Does this operator pass through any data by the means of sharing individual samples? - */ - DLL_PUBLIC bool HasSamplewisePassThrough() const; + /** Does this operator pass through any data by the means of sharing individual samples? 
*/ + bool HasSamplewisePassThrough() const; - /** - * @brief Return the static number of outputs or calculate regular outputs using output_fn - */ - DLL_PUBLIC int CalculateOutputs(const OpSpec &spec) const; + /** Return the static number of outputs or calculate regular outputs using output_fn */ + int CalculateOutputs(const OpSpec &spec) const; - /** - * @brief Calculate the number of additional outputs obtained from additional_outputs_fn - */ - DLL_PUBLIC int CalculateAdditionalOutputs(const OpSpec &spec) const; + /** Calculate the number of additional outputs obtained from additional_outputs_fn */ + int CalculateAdditionalOutputs(const OpSpec &spec) const; - DLL_PUBLIC bool SupportsInPlace(const OpSpec &spec) const; + bool SupportsInPlace(const OpSpec &spec) const; - DLL_PUBLIC void CheckArgs(const OpSpec &spec) const; + void CheckArgs(const OpSpec &spec) const; - /** - * @brief Get default value of optional or internal argument. The default value must be declared - */ + /** Get default value of optional or internal argument. 
The default value must be declared */ template <typename T> - DLL_PUBLIC inline T GetDefaultValueForArgument(const std::string &s) const; + inline T GetDefaultValueForArgument(std::string_view s) const; - DLL_PUBLIC bool HasRequiredArgument(const std::string &name, bool local_only = false) const; + /** Checks if the argument with the given name is defined and not required */ + bool HasOptionalArgument(std::string_view name) const; - DLL_PUBLIC bool HasOptionalArgument(const std::string &name, bool local_only = false) const; + /** Checks if the argument with the given name is defined and marked as internal */ + bool HasInternalArgument(std::string_view name) const; - DLL_PUBLIC bool HasInternalArgument(const std::string &name, bool local_only = false) const; - - /** - * @brief Finds default value for a given argument + /** Finds default value for a given argument + * * @return A pair of the defining schema and the value */ - DLL_PUBLIC std::pair<const OpSchema *, const Value *> FindDefaultValue( - const std::string &arg_name, bool local_only = false, bool include_internal = true) const; + const Value *FindDefaultValue(std::string_view arg_name) const; - /** - * @brief Checks whether the schema defines an argument with the given name + /** Checks whether the schema defines an argument with the given name + * * @param include_internal - returns `true` also for internal/implicit arugments * @param local_only - doesn't look in parent schemas */ - DLL_PUBLIC bool HasArgument(const std::string &name, - bool include_internal = false, - bool local_only = false) const; + bool HasArgument(std::string_view name, bool include_internal = false) const; - /** - * @brief Get docstring for operator argument of given name (Python Operator Kwargs). - */ - DLL_PUBLIC std::string GetArgumentDox(const std::string &name) const; + /** Returns true if the operator has a "seed" argument. 
*/ + bool HasRandomSeedArg() const; - /** - * @brief Get enum representing type of argument of given name. - */ - DLL_PUBLIC DALIDataType GetArgumentType(const std::string &name) const; + /** Get docstring for operator argument of given name (Python Operator Kwargs). */ + const std::string &GetArgumentDox(std::string_view name) const; - /** - * @brief Check if the argument has a default value. - * Required arguments always return false. - * Internal arguments always return true. + /** Get enum representing type of argument of given name. */ + DALIDataType GetArgumentType(std::string_view name) const; + + const ArgumentDef &GetArgument(std::string_view name) const; + + /** Check if the argument has a default value. + * + * Required arguments always return false. + * Internal arguments always return true. */ - DLL_PUBLIC bool HasArgumentDefaultValue(const std::string &name) const; + bool HasArgumentDefaultValue(std::string_view name) const; /** * @brief Get default value of optional argument represented as python-compatible repr string. * Not allowed for internal arguments. 
*/ - DLL_PUBLIC std::string GetArgumentDefaultValueString(const std::string &name) const; + std::string GetArgumentDefaultValueString(std::string_view name) const; - /** - * @brief Get names of all required, optional, and deprecated arguments - */ - DLL_PUBLIC std::vector<std::string> GetArgumentNames() const; - DLL_PUBLIC bool IsTensorArgument(const std::string &name) const; - DLL_PUBLIC bool ArgSupportsPerFrameInput(const std::string &arg_name) const; + /** Get names of all required, optional, and deprecated arguments */ + std::vector<std::string> GetArgumentNames() const; + bool IsTensorArgument(std::string_view name) const; + bool ArgSupportsPerFrameInput(std::string_view arg_name) const; + + bool IsDefault() const { return default_; } private: - static inline bool ShouldHideArgument(const std::string &name) { + struct DefaultSchemaTag {}; + /** Populates the default schema with common arguments */ + explicit OpSchema(DefaultSchemaTag); + + static inline bool ShouldHideArgument(std::string_view name) { return name.size() && name[0] == '_'; } - const TensorArgDesc *FindTensorArgument(const std::string &name) const; + const ArgumentDef *FindTensorArgument(std::string_view name) const; + + template <typename Pred> + const ArgumentDef *FindArgument(std::string_view name, Pred &&pred) const { + if (auto *arg = FindArgument(name)) + return pred(*arg) ? 
arg : nullptr; + else + return nullptr; + } + + const ArgumentDef *FindArgument(std::string_view name) const { + auto &args = GetFlattenedArguments(); + auto it = args.find(name); + if (it == args.end()) + return nullptr; + return it->second; + } - void CheckArgument(const std::string &s); + void CheckArgument(std::string_view s); void CheckInputIndex(int index) const; - std::string DefaultDeprecatedArgMsg(const std::string &arg_name, const std::string &renamed_to, + std::string DefaultDeprecatedArgMsg(std::string_view arg_name, std::string_view renamed_to, bool removed) const; - /** - * @brief Add internal argument to schema. It always has a value. - */ + /** Add internal argument to schema. It always has a value. */ template <typename T> - void AddInternalArg(const std::string &name, const std::string &doc, T value) { - auto v = Value::construct(value); - internal_arguments_[name] = {doc, type2id<T>::value, v.get(), true}; - internal_arguments_unq_.push_back(std::move(v)); + void AddInternalArg(std::string_view name, std::string doc, T value) { + auto &arg = AddArgumentImpl(name, type2id<T>::value, Value::construct(value), std::move(doc)); + arg.hidden = true; + arg.internal = true; } - std::map<std::string, RequiredArgumentDef> GetRequiredArguments() const; - std::map<std::string, DefaultedArgumentDef> GetOptionalArguments() const; - std::map<std::string, DeprecatedArgDef> GetDeprecatedArguments() const; + ArgumentDef &AddArgumentImpl(std::string_view name); - /** - * @brief Initialize the module_path_ and operator_name_ fields based on the schema name. - */ + ArgumentDef &AddArgumentImpl(std::string_view name, + DALIDataType type, + std::unique_ptr<Value> default_value, + std::string doc); + + /** Initialize the module_path_ and operator_name_ fields based on the schema name. 
*/ void InitNames(); std::string dox_; - /** @brief The name of the schema */ + /// The name of the schema std::string name_; - /** @brief The module path for the operator */ + /// The module path for the operator std::vector<std::string> module_path_; - /** @brief The camel case name of the operator (without the module path) */ + /// The PascalCase name of the operator (without the module path) std::string operator_name_; - bool disable_auto_input_dox_ = false; + //////////////////////////////////////////////////////////////////////////// + // Inputs, outputs and arguments - struct InputDoc { - std::string name = {}; - std::string type_doc = {}; - std::string doc = {}; - }; - std::vector<InputDoc> input_dox_ = {}; - bool input_dox_set_ = false; + /// All locally defined arguments + std::map<std::string, ArgumentDef, std::less<>> arguments_; - // Custom docstring, if not empty should be used in place of input_dox_ descriptions - std::string call_dox_ = {}; + mutable + detail::LazyValue<std::map<std::string, const ArgumentDef *, std::less<>>> flattened_arguments_; + mutable int circular_inheritance_detector_ = 0; - // Whether to append kwargs section to __call__ docstring. 
Off by default, - // can be turned on for call_dox_ specified manually - bool append_kwargs_section_ = false; + std::map<std::string, const ArgumentDef *, std::less<>> &GetFlattenedArguments() const; + + /// The properties of the inputs + std::vector<InputInfo> input_info_; + bool disable_auto_input_dox_ = false; + bool input_dox_set_ = false; SpecFunc output_fn_, in_place_fn_, additional_outputs_fn_; int min_num_input_ = 0, max_num_input_ = 0; int num_output_ = 0; - bool allow_instance_grouping_ = true; - vector<string> parents_; + //////////////////////////////////////////////////////////////////////////// + // Schema inheritance - bool support_volumetric_ = false; + /// Names of the parent schemas + vector<string> parent_names_; + /// Cached pointers to parent schemas, to avoid repeated lookups + mutable detail::LazyValue<std::vector<const OpSchema *>> parents_; + //////////////////////////////////////////////////////////////////////////// + // Documentation-related + bool support_volumetric_ = false; bool allow_sequences_ = false; bool is_sequence_operator_ = false; @@ -825,52 +782,62 @@ used with DALIDataType, to avoid confusion with `AddOptionalArg<type>(name, doc, bool is_doc_hidden_ = false; bool is_doc_partially_hidden_ = false; - bool no_prune_ = false; + /// Custom docstring, if not empty should be used in place of input_dox_ descriptions + std::string call_dox_ = {}; + /// Whether to append kwargs section to __call__ docstring. 
Off by default, + /// can be turned on for call_dox_ specified manually + bool append_kwargs_section_ = false; + + //////////////////////////////////////////////////////////////////////////// + // Internal flags + bool no_prune_ = false; bool serializable_ = true; + const bool default_ = false; + //////////////////////////////////////////////////////////////////////////// + // Passthrough operators std::map<int, int> passthrough_map_; bool samplewise_any_passthrough_ = false; + //////////////////////////////////////////////////////////////////////////// + // Deprecation bool is_deprecated_ = false; std::string deprecated_in_favor_of_; std::string deprecation_message_; - - std::map<std::string, RequiredArgumentDef> arguments_; - std::map<std::string, DefaultedArgumentDef> optional_arguments_; - std::map<std::string, DefaultedArgumentDef> internal_arguments_; - std::map<std::string, DeprecatedArgDef> deprecated_arguments_; - std::vector<std::unique_ptr<Value>> optional_arguments_unq_; - std::vector<std::unique_ptr<Value>> internal_arguments_unq_; - std::vector<std::vector<TensorLayout>> input_layouts_; - std::vector<dali::InputDevice> input_devices_; - - std::map<std::string, TensorArgDesc> tensor_arguments_; }; class SchemaRegistry { public: - DLL_PUBLIC static OpSchema &RegisterSchema(const std::string &name); - DLL_PUBLIC static const OpSchema &GetSchema(const std::string &name); - DLL_PUBLIC static const OpSchema *TryGetSchema(const std::string &name); + DLL_PUBLIC static OpSchema &RegisterSchema(std::string_view name); + DLL_PUBLIC static const OpSchema &GetSchema(std::string_view name); + DLL_PUBLIC static const OpSchema *TryGetSchema(std::string_view name); private: inline SchemaRegistry() {} - DLL_PUBLIC static std::map<string, OpSchema> ®istry(); + DLL_PUBLIC static std::map<string, OpSchema, std::less<>> ®istry(); }; template <typename T> -inline T OpSchema::GetDefaultValueForArgument(const std::string &s) const { - const Value *v = FindDefaultValue(s, 
false, true).second; - DALI_ENFORCE(v != nullptr, - make_string("The argument \"", s, "\" doesn't have a default value in schema \"", - name(), "\".")); +inline T OpSchema::GetDefaultValueForArgument(std::string_view name) const { + const Value *v = FindDefaultValue(name); + if (!v) { + (void)GetArgument(name); // throw an error if the argument is undefined + + // otherwise throw a different error + throw std::invalid_argument(make_string( + "The argument \"", name, "\" in operator \"", this->name(), + "\" doesn't have a default value.")); + } using S = argument_storage_t<T>; const ValueInst<S> *vS = dynamic_cast<const ValueInst<S> *>(v); - DALI_ENFORCE(vS != nullptr, "Unexpected type of the default value for argument \"" + s + - "\" of schema \"" + this->name() + "\""); + if (!vS) { + throw std::invalid_argument(make_string( + "Unexpected type of the default value for argument \"", name, "\" of schema \"", + this->name(), "\"")); + } return static_cast<T>(vS->Get()); } diff --git a/dali/pipeline/operator/op_schema_test.cc b/dali/pipeline/operator/op_schema_test.cc index 7e743a46913..6020e804609 100644 --- a/dali/pipeline/operator/op_schema_test.cc +++ b/dali/pipeline/operator/op_schema_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -80,9 +80,9 @@ TEST(OpSchemaTest, OptionalArgumentDefaultValue) { ASSERT_TRUE(schema.HasArgumentDefaultValue("foo")); ASSERT_FALSE(schema.HasArgumentDefaultValue("no_default")); - ASSERT_THROW(schema.GetDefaultValueForArgument<int>("no_default"), std::runtime_error); + ASSERT_THROW(schema.GetDefaultValueForArgument<int>("no_default"), std::invalid_argument); - ASSERT_THROW(schema.HasArgumentDefaultValue("don't have this one"), std::runtime_error); + ASSERT_THROW(schema.HasArgumentDefaultValue("don't have this one"), invalid_key); } DALI_SCHEMA(Dummy4) @@ -109,8 +109,19 @@ TEST(OpSchemaTest, OptionalArgumentDefaultValueInheritance) { ASSERT_FALSE(schema.HasArgumentDefaultValue("no_default")); ASSERT_FALSE(schema.HasArgumentDefaultValue("no_default2")); - ASSERT_THROW(schema.GetDefaultValueForArgument<int>("no_default"), std::runtime_error); - ASSERT_THROW(schema.GetDefaultValueForArgument<bool>("no_default2"), std::runtime_error); + ASSERT_THROW(schema.GetDefaultValueForArgument<int>("no_default"), std::invalid_argument); + ASSERT_THROW(schema.GetDefaultValueForArgument<bool>("no_default2"), std::invalid_argument); +} + +DALI_SCHEMA(Circular1) + .AddParent("Circular2"); + +DALI_SCHEMA(Circular2) + .AddParent("Circular1"); + +TEST(OpSchemaTest, CircularInheritance) { + EXPECT_THROW(SchemaRegistry::GetSchema("Circular1").HasArgument("foo"), std::logic_error); + EXPECT_THROW(SchemaRegistry::GetSchema("Circular2").HasArgument("foo"), std::logic_error); } DALI_SCHEMA(Dummy5) @@ -143,8 +154,8 @@ TEST(OpSchemaTest, OptionalArgumentDefaultValueMultipleInheritance) { ASSERT_FALSE(schema.HasArgumentDefaultValue("no_default")); ASSERT_FALSE(schema.HasArgumentDefaultValue("no_default2")); - ASSERT_THROW(schema.GetDefaultValueForArgument<int>("no_default"), std::runtime_error); - ASSERT_THROW(schema.GetDefaultValueForArgument<bool>("no_default2"), std::runtime_error); + ASSERT_THROW(schema.GetDefaultValueForArgument<int>("no_default"), std::invalid_argument); + 
ASSERT_THROW(schema.GetDefaultValueForArgument<bool>("no_default2"), std::invalid_argument); } DALI_SCHEMA(Dummy6) @@ -183,9 +194,9 @@ TEST(OpSchemaTest, OptionalArgumentDefaultValueMultipleParent) { ASSERT_FALSE(schema.HasArgumentDefaultValue("no_default2")); ASSERT_FALSE(schema.HasArgumentDefaultValue("no_default3")); - ASSERT_THROW(schema.GetDefaultValueForArgument<int>("no_default"), std::runtime_error); - ASSERT_THROW(schema.GetDefaultValueForArgument<bool>("no_default2"), std::runtime_error); - ASSERT_THROW(schema.GetDefaultValueForArgument<float>("no_default3"), std::runtime_error); + ASSERT_THROW(schema.GetDefaultValueForArgument<int>("no_default"), std::invalid_argument); + ASSERT_THROW(schema.GetDefaultValueForArgument<bool>("no_default2"), std::invalid_argument); + ASSERT_THROW(schema.GetDefaultValueForArgument<float>("no_default3"), std::invalid_argument); } DALI_SCHEMA(Dummy8) diff --git a/dali/pipeline/operator/op_spec.cc b/dali/pipeline/operator/op_spec.cc index d88eb818ab3..ad512e5375e 100644 --- a/dali/pipeline/operator/op_spec.cc +++ b/dali/pipeline/operator/op_spec.cc @@ -110,7 +110,7 @@ OpSpec& OpSpec::AddArgumentInput(const string &arg_name, const string &inp_name) "Argument '", arg_name, "' is already specified.")); const OpSchema& schema = GetSchemaOrDefault(); DALI_ENFORCE(schema.HasArgument(arg_name), - make_string("Argument '", arg_name, "' is not supported by operator `", + make_string("Argument '", arg_name, "' is not defined for operator `", GetOpDisplayName(*this, true), "`.")); DALI_ENFORCE(schema.IsTensorArgument(arg_name), make_string("Argument '", arg_name, "' in operator `", GetOpDisplayName(*this, true), @@ -124,7 +124,7 @@ OpSpec& OpSpec::AddArgumentInput(const string &arg_name, const string &inp_name) OpSpec& OpSpec::SetInitializedArg(const string& arg_name, std::shared_ptr<Argument> arg) { if (schema_ && schema_->IsDeprecatedArg(arg_name)) { - const auto& deprecation_meta = schema_->DeprecatedArgMeta(arg_name); + const auto& 
deprecation_meta = schema_->DeprecatedArgInfo(arg_name); // Argument was removed, and we can discard it if (deprecation_meta.removed) { return *this; diff --git a/dali/pipeline/operator/op_spec.h b/dali/pipeline/operator/op_spec.h index a901b38307b..bb7a259fe97 100644 --- a/dali/pipeline/operator/op_spec.h +++ b/dali/pipeline/operator/op_spec.h @@ -15,6 +15,7 @@ #ifndef DALI_PIPELINE_OPERATOR_OP_SPEC_H_ #define DALI_PIPELINE_OPERATOR_OP_SPEC_H_ +#include <functional> #include <map> #include <utility> #include <string> @@ -321,13 +322,13 @@ class DLL_PUBLIC OpSpec { /** * @brief Lists all arguments specified in this spec. */ - DLL_PUBLIC std::vector<std::string> ListArguments() const { - std::vector<std::string> ret; + DLL_PUBLIC auto ListArgumentNames() const { + std::set<std::string_view, std::less<>> ret; for (auto &a : arguments_) { - ret.push_back(a->get_name()); + ret.insert(a->get_name()); } for (auto &a : argument_inputs_) { - ret.push_back(a.first); + ret.insert(a.first); } return ret; } @@ -535,9 +536,9 @@ inline bool OpSpec::TryGetArgumentImpl( } } else if (schema.HasArgument(name, true) && schema.HasArgumentDefaultValue(name)) { // Argument wasn't present locally, get the default from the associated schema if any - auto schema_val = schema.FindDefaultValue(name); + auto *val = schema.FindDefaultValue(name); using VT = const ValueInst<S>; - if (VT *vt = dynamic_cast<VT *>(schema_val.second)) { + if (VT *vt = dynamic_cast<VT *>(val)) { result = static_cast<T>(vt->Get()); return true; } @@ -577,9 +578,9 @@ inline bool OpSpec::TryGetRepeatedArgumentImpl(C &result, const string &name) co } } else if (schema.HasArgument(name, true) && schema.HasArgumentDefaultValue(name)) { // Argument wasn't present locally, get the default from the associated schema if any - auto schema_val = schema.FindDefaultValue(name); + auto *val = schema.FindDefaultValue(name); using VT = const ValueInst<V>; - if (VT *vt = dynamic_cast<VT *>(schema_val.second)) { + if (VT *vt = 
dynamic_cast<VT *>(val)) { detail::copy_vector(result, vt->Get()); return true; } diff --git a/dali/pipeline/operator/op_spec_test.cc b/dali/pipeline/operator/op_spec_test.cc index 83707620dac..e369318afc8 100644 --- a/dali/pipeline/operator/op_spec_test.cc +++ b/dali/pipeline/operator/op_spec_test.cc @@ -80,36 +80,36 @@ TEST(OpSpecTest, GetArgumentTensorSet) { auto spec0 = OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArgumentInput(arg_name, "<not_used>"); - ASSERT_EQ(spec0.GetArgument<int32_t>(arg_name, &ws0, 0), 42); - ASSERT_EQ(spec0.GetArgument<int32_t>(arg_name, &ws0, 1), 43); + EXPECT_EQ(spec0.GetArgument<int32_t>(arg_name, &ws0, 0), 42); + EXPECT_EQ(spec0.GetArgument<int32_t>(arg_name, &ws0, 1), 43); int result = 0; ASSERT_TRUE(spec0.TryGetArgument<int32_t>(result, arg_name, &ws0, 0)); - ASSERT_EQ(result, 42); + EXPECT_EQ(result, 42); ASSERT_TRUE(spec0.TryGetArgument<int32_t>(result, arg_name, &ws0, 1)); - ASSERT_EQ(result, 43); - ASSERT_THROW(spec0.GetArgument<float>(arg_name, &ws0, 0), std::runtime_error); + EXPECT_EQ(result, 43); + EXPECT_THROW(spec0.GetArgument<float>(arg_name, &ws0, 0), std::runtime_error); float tmp = 0.f; - ASSERT_FALSE(spec0.TryGetArgument<float>(tmp, arg_name, &ws0, 0)); + EXPECT_FALSE(spec0.TryGetArgument<float>(tmp, arg_name, &ws0, 0)); ArgumentWorkspace ws1; auto spec1 = OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2); // If we have a default optional argument, we will just return its value if (arg_name != "default_tensor"s) { - ASSERT_THROW(spec1.GetArgument<int>(arg_name, &ws1, 0), std::runtime_error); - ASSERT_THROW(spec1.GetArgument<int>(arg_name, &ws1, 1), std::runtime_error); + EXPECT_THROW(spec1.GetArgument<int>(arg_name, &ws1, 0), std::invalid_argument); + EXPECT_THROW(spec1.GetArgument<int>(arg_name, &ws1, 1), std::invalid_argument); int result = 0; - ASSERT_FALSE(spec1.TryGetArgument<int>(result, arg_name, &ws1, 0)); - ASSERT_FALSE(spec1.TryGetArgument<int>(result, arg_name, &ws1, 1)); + 
EXPECT_FALSE(spec1.TryGetArgument<int>(result, arg_name, &ws1, 0)); + EXPECT_FALSE(spec1.TryGetArgument<int>(result, arg_name, &ws1, 1)); } else { - ASSERT_EQ(spec1.GetArgument<int>(arg_name, &ws1, 0), 11); - ASSERT_EQ(spec1.GetArgument<int>(arg_name, &ws1, 1), 11); + EXPECT_EQ(spec1.GetArgument<int>(arg_name, &ws1, 0), 11); + EXPECT_EQ(spec1.GetArgument<int>(arg_name, &ws1, 1), 11); int result = 0; - ASSERT_TRUE(spec1.TryGetArgument<int>(result, arg_name, &ws1, 0)); - ASSERT_EQ(result, 11); + EXPECT_TRUE(spec1.TryGetArgument<int>(result, arg_name, &ws1, 0)); + EXPECT_EQ(result, 11); result = 0; - ASSERT_TRUE(spec1.TryGetArgument<int>(result, arg_name, &ws1, 1)); - ASSERT_EQ(result, 11); + EXPECT_TRUE(spec1.TryGetArgument<int>(result, arg_name, &ws1, 1)); + EXPECT_EQ(result, 11); } } } @@ -121,14 +121,14 @@ TEST(OpSpecTest, GetArgumentValue) { auto spec0 = OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArg(arg_name, 42); - ASSERT_EQ(spec0.GetArgument<int>(arg_name, &ws), 42); + EXPECT_EQ(spec0.GetArgument<int>(arg_name, &ws), 42); int result = 0; ASSERT_TRUE(spec0.TryGetArgument(result, arg_name, &ws)); - ASSERT_EQ(result, 42); + EXPECT_EQ(result, 42); - ASSERT_THROW(spec0.GetArgument<float>(arg_name, &ws), std::runtime_error); + EXPECT_THROW(spec0.GetArgument<float>(arg_name, &ws), std::runtime_error); float tmp = 0.f; - ASSERT_FALSE(spec0.TryGetArgument(tmp, arg_name, &ws)); + EXPECT_FALSE(spec0.TryGetArgument(tmp, arg_name, &ws)); } for (const auto &arg_name : {"required"s, "no_default"s, @@ -136,28 +136,28 @@ TEST(OpSpecTest, GetArgumentValue) { ArgumentWorkspace ws; auto spec0 = OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2); - ASSERT_THROW(spec0.GetArgument<int>(arg_name, &ws), std::runtime_error); + EXPECT_THROW(spec0.GetArgument<int>(arg_name, &ws), std::invalid_argument); int result = 0; - ASSERT_FALSE(spec0.TryGetArgument(result, arg_name, &ws)); + EXPECT_FALSE(spec0.TryGetArgument(result, arg_name, &ws)); - 
ASSERT_THROW(spec0.GetArgument<float>(arg_name, &ws), std::runtime_error); + EXPECT_THROW(spec0.GetArgument<float>(arg_name, &ws), std::invalid_argument); float tmp = 0.f; - ASSERT_FALSE(spec0.TryGetArgument(tmp, arg_name, &ws)); + EXPECT_FALSE(spec0.TryGetArgument(tmp, arg_name, &ws)); } for (const auto &arg_name : {"default"s, "default_tensor"s}) { ArgumentWorkspace ws; auto spec0 = OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2); - ASSERT_EQ(spec0.GetArgument<int>(arg_name, &ws), 11); + EXPECT_EQ(spec0.GetArgument<int>(arg_name, &ws), 11); int result = 0; ASSERT_TRUE(spec0.TryGetArgument(result, arg_name, &ws)); - ASSERT_EQ(result, 11); + EXPECT_EQ(result, 11); - ASSERT_THROW(spec0.GetArgument<float>(arg_name, &ws), std::runtime_error); + EXPECT_THROW(spec0.GetArgument<float>(arg_name, &ws), std::invalid_argument); float tmp = 0.f; - ASSERT_FALSE(spec0.TryGetArgument(tmp, arg_name, &ws)); + EXPECT_FALSE(spec0.TryGetArgument(tmp, arg_name, &ws)); } } @@ -170,10 +170,10 @@ TEST(OpSpecTest, GetArgumentVec) { .AddArg("max_batch_size", 2) .AddArg(arg_name, value); - ASSERT_EQ(spec0.GetRepeatedArgument<int32_t>(arg_name), value); + EXPECT_EQ(spec0.GetRepeatedArgument<int32_t>(arg_name), value); std::vector<int32_t> result; ASSERT_TRUE(spec0.TryGetRepeatedArgument(result, arg_name)); - ASSERT_EQ(result, value); + EXPECT_EQ(result, value); } for (const auto &arg_name : {"required_vec"s, "no_default_vec"s}) { @@ -181,17 +181,17 @@ TEST(OpSpecTest, GetArgumentVec) { auto spec0 = OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2); - ASSERT_THROW(spec0.GetRepeatedArgument<int32_t>(arg_name), std::runtime_error); + EXPECT_THROW(spec0.GetRepeatedArgument<int32_t>(arg_name), std::invalid_argument); std::vector<int32_t> result_v; ASSERT_FALSE(spec0.TryGetRepeatedArgument(result_v, arg_name)); SmallVector<int32_t, 1> result_sv; - ASSERT_FALSE(spec0.TryGetRepeatedArgument(result_sv, arg_name)); + EXPECT_FALSE(spec0.TryGetRepeatedArgument(result_sv, arg_name)); - 
ASSERT_THROW(spec0.GetRepeatedArgument<float>(arg_name), std::runtime_error); + EXPECT_THROW(spec0.GetRepeatedArgument<float>(arg_name), std::invalid_argument); std::vector<float> tmp_v; - ASSERT_FALSE(spec0.TryGetRepeatedArgument(tmp_v, arg_name)); + EXPECT_FALSE(spec0.TryGetRepeatedArgument(tmp_v, arg_name)); SmallVector<float, 1> tmp_sv; - ASSERT_FALSE(spec0.TryGetRepeatedArgument(tmp_sv, arg_name)); + EXPECT_FALSE(spec0.TryGetRepeatedArgument(tmp_sv, arg_name)); } { @@ -200,7 +200,7 @@ TEST(OpSpecTest, GetArgumentVec) { auto spec0 = OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2); auto default_val = std::vector<int32_t>{0, 1}; - ASSERT_EQ(spec0.GetRepeatedArgument<int32_t>(arg_name), default_val); + EXPECT_EQ(spec0.GetRepeatedArgument<int32_t>(arg_name), default_val); } } @@ -208,36 +208,36 @@ TEST(OpSpecTest, GetArgumentVec) { TEST(OpSpecTest, GetArgumentNonExisting) { auto spec0 = OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2); - ASSERT_THROW(spec0.GetArgument<int>("<no_such_argument>"), DALIException); + EXPECT_THROW(spec0.GetArgument<int>("<no_such_argument>"), invalid_key); int result = 0; - ASSERT_FALSE(spec0.TryGetArgument<int>(result, "<no_such_argument>")); + EXPECT_FALSE(spec0.TryGetArgument<int>(result, "<no_such_argument>")); - ASSERT_THROW(spec0.GetRepeatedArgument<int>("<no_such_argument>"), DALIException); + EXPECT_THROW(spec0.GetRepeatedArgument<int>("<no_such_argument>"), invalid_key); std::vector<int> result_vec; - ASSERT_FALSE(spec0.TryGetRepeatedArgument(result_vec, "<no_such_argument>")); + EXPECT_FALSE(spec0.TryGetRepeatedArgument(result_vec, "<no_such_argument>")); SmallVector<int, 1> result_sv; - ASSERT_FALSE(spec0.TryGetRepeatedArgument(result_sv, "<no_such_argument>")); + EXPECT_FALSE(spec0.TryGetRepeatedArgument(result_sv, "<no_such_argument>")); } TEST(OpSpecTest, DeprecatedArgs) { auto spec0 = OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArg("deprecated_arg", 1); - 
ASSERT_THROW(spec0.GetArgument<int>("deprecated_arg"), DALIException); - ASSERT_EQ(spec0.GetArgument<int>("replacing_arg"), 1); + EXPECT_THROW(spec0.GetArgument<int>("deprecated_arg"), std::invalid_argument); + EXPECT_EQ(spec0.GetArgument<int>("replacing_arg"), 1); int result = 0; - ASSERT_FALSE(spec0.TryGetArgument<int>(result, "deprecated_arg")); + EXPECT_FALSE(spec0.TryGetArgument<int>(result, "deprecated_arg")); ASSERT_TRUE(spec0.TryGetArgument<int>(result, "replacing_arg")); - ASSERT_EQ(result, 1); + EXPECT_EQ(result, 1); - ASSERT_THROW(OpSpec("DummyOpForSpecTest") + EXPECT_THROW(OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArg("deprecated_arg", 1) .AddArg("replacing_arg", 2), DALIException); - ASSERT_THROW(OpSpec("DummyOpForSpecTest") + EXPECT_THROW(OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArg("replacing_arg", 1) .AddArg("deprecated_arg", 2), DALIException); @@ -247,7 +247,7 @@ TEST(OpSpecTest, DeprecatedArgs) { .AddArg("deprecated_ignored_arg", 42); // It is marked as to be ingored, but there's no reason we should not be // able to query for the argument if it was provided. 
- ASSERT_TRUE(spec0.TryGetArgument<int>(result, "deprecated_ignored_arg")); + EXPECT_TRUE(spec0.TryGetArgument<int>(result, "deprecated_ignored_arg")); } TEST(OpSpecTest, DeprecatedArgsParents) { @@ -256,55 +256,55 @@ TEST(OpSpecTest, DeprecatedArgsParents) { .AddArg("grandparent_deprecated_arg", 3) .AddArg("parent_zero_deprecated_arg", 4) .AddArg("parent_one_deprecated_arg", 5); - ASSERT_THROW(spec0.GetArgument<int>("grandparent_deprecated_arg"), DALIException); - ASSERT_THROW(spec0.GetArgument<int>("parent_zero_deprecated_arg"), DALIException); - ASSERT_THROW(spec0.GetArgument<int>("parent_one_deprecated_arg"), DALIException); - ASSERT_EQ(spec0.GetArgument<int>("grandparent_replacing_arg"), 3); - ASSERT_EQ(spec0.GetArgument<int>("parent_zero_replacing_arg"), 4); - ASSERT_EQ(spec0.GetArgument<int>("parent_one_replacing_arg"), 5); + EXPECT_THROW(spec0.GetArgument<int>("grandparent_deprecated_arg"), std::invalid_argument); + EXPECT_THROW(spec0.GetArgument<int>("parent_zero_deprecated_arg"), std::invalid_argument); + EXPECT_THROW(spec0.GetArgument<int>("parent_one_deprecated_arg"), std::invalid_argument); + EXPECT_EQ(spec0.GetArgument<int>("grandparent_replacing_arg"), 3); + EXPECT_EQ(spec0.GetArgument<int>("parent_zero_replacing_arg"), 4); + EXPECT_EQ(spec0.GetArgument<int>("parent_one_replacing_arg"), 5); int result = 0; - ASSERT_FALSE(spec0.TryGetArgument<int>(result, "grandparent_deprecated_arg")); + EXPECT_FALSE(spec0.TryGetArgument<int>(result, "grandparent_deprecated_arg")); ASSERT_TRUE(spec0.TryGetArgument<int>(result, "grandparent_replacing_arg")); - ASSERT_EQ(result, 3); + EXPECT_EQ(result, 3); - ASSERT_FALSE(spec0.TryGetArgument<int>(result, "parent_zero_deprecated_arg")); + EXPECT_FALSE(spec0.TryGetArgument<int>(result, "parent_zero_deprecated_arg")); ASSERT_TRUE(spec0.TryGetArgument<int>(result, "parent_zero_replacing_arg")); - ASSERT_EQ(result, 4); + EXPECT_EQ(result, 4); - ASSERT_FALSE(spec0.TryGetArgument<int>(result, "parent_one_deprecated_arg")); + 
EXPECT_FALSE(spec0.TryGetArgument<int>(result, "parent_one_deprecated_arg")); ASSERT_TRUE(spec0.TryGetArgument<int>(result, "parent_one_replacing_arg")); - ASSERT_EQ(result, 5); + EXPECT_EQ(result, 5); - ASSERT_THROW(OpSpec("DummyOpForSpecTest") + EXPECT_THROW(OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArg("grandparent_deprecated_arg", 1) .AddArg("grandparent_replacing_arg", 2), DALIException); - ASSERT_THROW(OpSpec("DummyOpForSpecTest") + EXPECT_THROW(OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArg("grandparent_replacing_arg", 1) .AddArg("grandparent_deprecated_arg", 2), DALIException); - ASSERT_THROW(OpSpec("DummyOpForSpecTest") + EXPECT_THROW(OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArg("parent_zero_deprecated_arg", 1) .AddArg("parent_zero_replacing_arg", 2), DALIException); - ASSERT_THROW(OpSpec("DummyOpForSpecTest") + EXPECT_THROW(OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArg("parent_zero_replacing_arg", 1) .AddArg("parent_zero_deprecated_arg", 2), DALIException); - ASSERT_THROW(OpSpec("DummyOpForSpecTest") + EXPECT_THROW(OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArg("parent_one_deprecated_arg", 1) .AddArg("parent_one_replacing_arg", 2), DALIException); - ASSERT_THROW(OpSpec("DummyOpForSpecTest") + EXPECT_THROW(OpSpec("DummyOpForSpecTest") .AddArg("max_batch_size", 2) .AddArg("parent_one_replacing_arg", 1) .AddArg("parent_one_deprecated_arg", 2), DALIException); @@ -369,7 +369,7 @@ class TestArgumentInput_Consumer : public Operator<CPUBackend> { } // Non-matching shapes (differnet than 1 scalar value per sample) should not work with // OpSpec::GetArgument() - ASSERT_THROW(auto z = spec_.GetArgument<float>("arg2", &ws, 0), std::runtime_error); + EXPECT_THROW(auto z = spec_.GetArgument<float>("arg2", &ws, 0), std::runtime_error); // They can be accessed as proper ArgumentInputs auto &ref_1 = ws.ArgumentInput("arg1"); diff --git a/dali/pipeline/pipeline.cc 
b/dali/pipeline/pipeline.cc index 84079f05ce7..09fa204b221 100644 --- a/dali/pipeline/pipeline.cc +++ b/dali/pipeline/pipeline.cc @@ -387,10 +387,6 @@ void Pipeline::AddToOpSpecs(const std::string &inst_name, const OpSpec &spec, in GetOpDisplayName(spec, true) + "` using logical_id=" + std::to_string(logical_id) + " which is already assigned to " + group_name + "."); const OpSchema &schema = SchemaRegistry::GetSchema(spec.SchemaName()); - DALI_ENFORCE(schema.AllowsInstanceGrouping(), - "Operator `" + GetOpDisplayName(spec, true) + - "` does not support synced random execution required " - "for multiple input sets processing."); } op_specs_.push_back({inst_name, spec, logical_id}); logical_ids_[logical_id].push_back(op_specs_.size() - 1); @@ -649,20 +645,26 @@ void Pipeline::ToGPU(std::map<string, EdgeMeta>::iterator it) { } void Pipeline::PrepareOpSpec(OpSpec *spec, int logical_id) { - if (logical_id_to_seed_.find(logical_id) == logical_id_to_seed_.end()) { - logical_id_to_seed_[logical_id] = seed_[current_seed_]; - } spec->AddArg("max_batch_size", max_batch_size_) .AddArg("num_threads", num_threads_) .AddArg("device_id", device_id_) - .AddArg("checkpointing", checkpointing_) - .AddArgIfNotExisting("seed", logical_id_to_seed_[logical_id]); + .AddArg("checkpointing", checkpointing_); string dev = spec->GetArgument<string>("device"); if (dev == "cpu" || dev == "mixed") spec->AddArg("cpu_prefetch_queue_depth", prefetch_queue_depth_.cpu_size); if (dev == "gpu" || dev == "mixed") spec->AddArg("gpu_prefetch_queue_depth", prefetch_queue_depth_.gpu_size); - current_seed_ = (current_seed_+1) % MAX_SEEDS; + + if (spec->GetSchemaOrDefault().HasRandomSeedArg()) { + if (spec->ArgumentDefined("seed")) { + logical_id_to_seed_[logical_id] = spec->GetArgument<int64_t>("seed"); + } else { + if (logical_id_to_seed_.find(logical_id) == logical_id_to_seed_.end()) + logical_id_to_seed_[logical_id] = seed_[current_seed_]; + spec->AddArg("seed", logical_id_to_seed_[logical_id]); + 
current_seed_ = (current_seed_+1) % MAX_SEEDS; + } + } } /** @@ -695,7 +697,7 @@ void SerializeToProtobuf(dali_proto::OpDef *op, const string &inst_name, const O for (auto& a : spec.Arguments()) { // filter out args that need to be dealt with on // loading a serialized pipeline - auto &name = a->get_name(); + auto name = a->get_name(); if (name == "max_batch_size" || name == "num_threads" || name == "bytes_per_sample_hint") { diff --git a/dali/pipeline/pipeline_test.cc b/dali/pipeline/pipeline_test.cc index bcffe86ee76..80b27951c69 100644 --- a/dali/pipeline/pipeline_test.cc +++ b/dali/pipeline/pipeline_test.cc @@ -489,7 +489,7 @@ TEST_F(PipelineTestOnce, TestPresize) { TYPED_TEST(PipelineTest, TestSeedSet) { int num_thread = TypeParam::nt; int batch_size = this->jpegs_.nImages(); - constexpr int seed_set = 567; + constexpr int64_t seed_set = 567; Pipeline pipe(batch_size, num_thread, 0); @@ -500,7 +500,7 @@ TYPED_TEST(PipelineTest, TestSeedSet) { pipe.AddExternalInput("data"); pipe.AddOperator( - OpSpec("Copy") + OpSpec("DummyOpToAdd") .AddArg("device", "cpu") .AddArg("seed", seed_set) .AddInput("data", "cpu") @@ -523,8 +523,40 @@ TYPED_TEST(PipelineTest, TestSeedSet) { // Check if seed can be manually set EXPECT_EQ(original_graph.GetOp("copy1")->spec.GetArgument<int64_t>("seed"), seed_set); - EXPECT_EQ(original_graph.GetOp("copy2")->spec.GetArgument<int64_t>("seed"), seed_set); - EXPECT_NE(original_graph.GetOp("data")->spec.GetArgument<int64_t>("seed"), seed_set); + // The "seed" argument is deprecated as removed - so the argument is not added to the OpSpec + EXPECT_FALSE(original_graph.GetOp("copy2")->spec.HasArgument("seed")); + EXPECT_FALSE(original_graph.GetOp("data")->spec.HasArgument("seed")); +} + + +TYPED_TEST(PipelineTest, TestSeedAuto) { + int num_thread = TypeParam::nt; + int batch_size = this->jpegs_.nImages(); + + Pipeline pipe(batch_size, num_thread, 0); + + + TensorList<CPUBackend> batch; + test::MakeRandomBatch(batch, batch_size); + + 
pipe.AddExternalInput("data"); + + pipe.AddOperator( + OpSpec("DummyOpToAdd") + .AddArg("device", "cpu") + .AddInput("data", "cpu") + .AddOutput("out0", "cpu"), "dummy"); + + pipe.Build({{"out0", "gpu"}}); + + pipe.SetExternalInput("data", batch); + + graph::OpGraph &original_graph = this->GetGraph(&pipe); + + // ExternalSource doesn't have a seed... + EXPECT_FALSE(original_graph.GetOp("data")->spec.HasArgument("seed")); + // ...but DumyOpToAdd does - check if it was set by the Pipeline + EXPECT_TRUE(original_graph.GetOp("dummy")->spec.HasArgument("seed")); } @@ -650,31 +682,8 @@ DALI_REGISTER_OPERATOR(DummyOpToAdd, DummyOpToAdd, CPU); DALI_SCHEMA(DummyOpToAdd) .DocStr("DummyOpToAdd") .NumInput(1) - .NumOutput(1); - - -class DummyOpNoSync : public Operator<CPUBackend> { - public: - explicit DummyOpNoSync(const OpSpec &spec) : Operator<CPUBackend>(spec) {} - - bool HasContiguousOutputs() const override { - return false; - } - - bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override { - return false; - } - - void RunImpl(Workspace &ws) override {} -}; - -DALI_REGISTER_OPERATOR(DummyOpNoSync, DummyOpNoSync, CPU); - -DALI_SCHEMA(DummyOpNoSync) - .DocStr("DummyOpNoSync") - .DisallowInstanceGrouping() - .NumInput(1) - .NumOutput(1); + .NumOutput(1) + .AddRandomSeedArg(); TEST(PipelineTest, AddOperator) { Pipeline pipe(10, 4, 0); @@ -702,16 +711,6 @@ TEST(PipelineTest, AddOperator) { EXPECT_EQ(third_op, second_op + 1); - int disallow_sync_op = pipe.AddOperator(OpSpec("DummyOpNoSync") - .AddArg("device", "cpu") - .AddInput("data_in0", "cpu") - .AddOutput("data_out3", "cpu"), "DummyOpNoSync"); - - ASSERT_THROW(pipe.AddOperator(OpSpec("DummyOpNoSync") - .AddArg("device", "cpu") - .AddInput("data_in0", "cpu") - .AddOutput("data_out4", "cpu"), "DummyOpNoSync2", disallow_sync_op), std::runtime_error); - vector<std::pair<string, string>> outputs = { {"data_out0", "cpu"}, {"data_out1", "cpu"}, {"data_out2", "cpu"}}; pipe.Build(outputs); @@ -811,7 
+810,6 @@ DALI_REGISTER_OPERATOR(DummyInputOperator, DummyInputOperator, CPU); DALI_SCHEMA(DummyInputOperator) .DocStr("DummyInputOperator") - .DisallowInstanceGrouping() .NumInput(0) .NumOutput(2); diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc index 5e3318b8840..b190a09a2dc 100644 --- a/dali/python/backend_impl.cc +++ b/dali/python/backend_impl.cc @@ -1684,7 +1684,7 @@ pools as well as from the host pinned memory pool. This function is safe to use while DALI pipelines are running.)"); } -py::dict DeprecatedArgMetaToDict(const DeprecatedArgDef & meta) { +py::dict ArgumentDeprecationInfoToDict(const ArgumentDeprecation & meta) { py::dict d; d["msg"] = meta.msg; d["removed"] = meta.removed; @@ -2396,8 +2396,7 @@ PYBIND11_MODULE(backend_impl, m) { .def("GetArgumentDefaultValueString", &OpSchema::GetArgumentDefaultValueString) .def("GetArgumentNames", &OpSchema::GetArgumentNames) .def("IsArgumentOptional", &OpSchema::HasOptionalArgument, - "arg_name"_a, - "local_only"_a = false) + "arg_name"_a) .def("IsTensorArgument", &OpSchema::IsTensorArgument) .def("ArgSupportsPerFrameInput", &OpSchema::ArgSupportsPerFrameInput) .def("IsSequenceOperator", &OpSchema::IsSequenceOperator) @@ -2411,10 +2410,10 @@ PYBIND11_MODULE(backend_impl, m) { .def("DeprecatedInFavorOf", &OpSchema::DeprecatedInFavorOf) .def("DeprecationMessage", &OpSchema::DeprecationMessage) .def("IsDeprecatedArg", &OpSchema::IsDeprecatedArg) - .def("DeprecatedArgMeta", + .def("DeprecatedArgInfo", [](OpSchema *schema, const std::string &arg_name) { - auto meta = schema->DeprecatedArgMeta(arg_name); - return DeprecatedArgMetaToDict(meta); + auto meta = schema->DeprecatedArgInfo(arg_name); + return ArgumentDeprecationInfoToDict(meta); }) .def("GetSupportedLayouts", &OpSchema::GetSupportedLayouts) .def("HasArgument", @@ -2444,6 +2443,8 @@ PYBIND11_MODULE(backend_impl, m) { try { if (p) std::rethrow_exception(p); + } catch (const invalid_key &e) { + PyErr_SetString(PyExc_KeyError, e.what()); } 
catch (const DaliRuntimeError &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); } catch (const DaliIndexError &e) { diff --git a/dali/python/nvidia/dali/ops/__init__.py b/dali/python/nvidia/dali/ops/__init__.py index 8e31bde3718..5cf40ce7411 100644 --- a/dali/python/nvidia/dali/ops/__init__.py +++ b/dali/python/nvidia/dali/ops/__init__.py @@ -188,7 +188,7 @@ def _handle_arg_deprecations(schema, kwargs, op_name): for arg_name in arg_names: if not schema.IsDeprecatedArg(arg_name): continue - meta = schema.DeprecatedArgMeta(arg_name) + meta = schema.DeprecatedArgInfo(arg_name) new_name = meta["renamed_to"] removed = meta["removed"] msg = meta["msg"] diff --git a/dali/python/nvidia/dali/ops/_docs.py b/dali/python/nvidia/dali/ops/_docs.py index 7fb19896e4e..6c8bbba9b30 100644 --- a/dali/python/nvidia/dali/ops/_docs.py +++ b/dali/python/nvidia/dali/ops/_docs.py @@ -107,7 +107,7 @@ def _get_kwargs(schema): doc = "" deprecation_warning = None if schema.IsDeprecatedArg(arg): - meta = schema.DeprecatedArgMeta(arg) + meta = schema.DeprecatedArgInfo(arg) msg = meta["msg"] assert msg is not None deprecation_warning = ".. 
warning::\n\n " + msg.replace("\n", "\n ") diff --git a/dali/test/python/operator_1/test_crop.py b/dali/test/python/operator_1/test_crop.py index f88cdec01a1..0832dce4384 100644 --- a/dali/test/python/operator_1/test_crop.py +++ b/dali/test/python/operator_1/test_crop.py @@ -678,7 +678,7 @@ def get_data(): pipe = get_pipe(batch_size=batch_size, device_id=0, num_threads=3) pipe.build() with assert_raises( - RuntimeError, glob=f'The layout "{layout}" does not match any of the allowed layouts' + ValueError, glob=f'The layout "{layout}" does not match any of the allowed layouts' ): pipe.run() diff --git a/dali/test/python/operator_1/test_crop_mirror_normalize.py b/dali/test/python/operator_1/test_crop_mirror_normalize.py index a1de4e39207..737216ec7b6 100644 --- a/dali/test/python/operator_1/test_crop_mirror_normalize.py +++ b/dali/test/python/operator_1/test_crop_mirror_normalize.py @@ -953,7 +953,7 @@ def get_data(): pipe = get_pipe(batch_size=batch_size, device_id=0, num_threads=3) pipe.build() with assert_raises( - RuntimeError, glob=f'The layout "{layout}" does not match any of the allowed layouts' + ValueError, glob=f'The layout "{layout}" does not match any of the allowed layouts' ): pipe.run() diff --git a/dali/test/python/operator_1/test_debayer.py b/dali/test/python/operator_1/test_debayer.py index 433c81bfd16..6882c5c6cf0 100644 --- a/dali/test/python/operator_1/test_debayer.py +++ b/dali/test/python/operator_1/test_debayer.py @@ -302,7 +302,7 @@ def test_too_many_channels(): def test_wrong_sample_dim(): with assert_raises( - RuntimeError, glob="The number of dimensions 5 does not match any of the allowed" + ValueError, glob="The number of dimensions 5 does not match any of the allowed" ): _test_shape_pipeline((1, 1, 1, 1, 1), np.uint8) diff --git a/dali/test/python/operator_2/test_resize.py b/dali/test/python/operator_2/test_resize.py index 0d8764df266..ef1217d6aaa 100644 --- a/dali/test/python/operator_2/test_resize.py +++ 
b/dali/test/python/operator_2/test_resize.py @@ -338,7 +338,7 @@ def build_pipes( batch_size=batch_size, num_threads=8, device_id=0, - seed=1234, + seed=12345, exec_async=False, exec_pipelined=False, ) diff --git a/dali/test/python/reader/test_coco.py b/dali/test/python/reader/test_coco.py index e652d8bbbdc..922093be764 100644 --- a/dali/test/python/reader/test_coco.py +++ b/dali/test/python/reader/test_coco.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -192,8 +192,8 @@ def test_operator_coco_reader_same_images(): @raises( - RuntimeError, - glob='Argument "preprocessed_annotations_dir" is not supported by operator *readers*COCO', + KeyError, + glob='Argument "preprocessed_annotations_dir" is not defined for operator *readers*COCO', ) def test_invalid_args(): pipeline = Pipeline(batch_size=2, num_threads=4, device_id=0) diff --git a/dali/test/python/test_pipeline.py b/dali/test/python/test_pipeline.py index 4362a740588..a64f481c15d 100644 --- a/dali/test/python/test_pipeline.py +++ b/dali/test/python/test_pipeline.py @@ -36,7 +36,7 @@ get_dali_extra_path, RandomDataIterator, ) -from nose_utils import raises, assert_raises, SkipTest +from nose_utils import raises, assert_raises, assert_warns, SkipTest test_data_root = get_dali_extra_path() caffe_db_folder = os.path.join(test_data_root, "db", "lmdb") @@ -493,6 +493,19 @@ def test_none_seed(): assert np.sum(np.abs(test_out_ref - test_out)) != 0 +def test_seed_deprecated(): + @pipeline_def(batch_size=1, device_id=None, num_threads=1) + def my_pipe(): + with assert_warns( + DeprecationWarning, + glob='The argument "seed" should not be used with operators ' + "that don't use random numbers.", + ): + return fn.reshape(np.float32([1, 2]), shape=[2], 
seed=123) + + my_pipe() + + def test_as_array(): batch_size = 64 diff --git a/include/dali/core/error_handling.h b/include/dali/core/error_handling.h index 7ba9cdfdf5d..fb88d1a0ee1 100644 --- a/include/dali/core/error_handling.h +++ b/include/dali/core/error_handling.h @@ -85,10 +85,21 @@ class DALIException : public std::runtime_error { }; struct unsupported_exception : std::runtime_error { - explicit unsupported_exception(const std::string &str) : runtime_error(str), msg(str) {} + explicit unsupported_exception(const std::string &str) : runtime_error(str) {} +}; - const char *what() const noexcept override { return msg.c_str(); } - std::string msg; +/** An exception thrown when an invalid dictionary key is provided + * + * The exception denotes an invalid key. It can be thrown when: + * - the key is not found and the function returns a non-nullable type + * - the key doesn't meet some constraints (e.g. a dictionary doesn't accept an empty + * string as a key). + * + * This exception is used at the Python boundary to raise KeyError rather than IndexError. + */ +struct invalid_key : std::out_of_range { + explicit invalid_key(const std::string &message) : std::out_of_range(message) {} + explicit invalid_key(const char *message) : std::out_of_range(message) {} }; inline string BuildErrorString(string statement, string file, int line) {