diff --git a/py/trtorch/csrc/tensorrt_classes.h b/py/trtorch/csrc/tensorrt_classes.h index 3b8b3cdcdb..1ad32b3167 100644 --- a/py/trtorch/csrc/tensorrt_classes.h +++ b/py/trtorch/csrc/tensorrt_classes.h @@ -18,13 +18,13 @@ namespace pyapi { } // TODO: Make this error message more informative -#define ADD_ENUM_GET_SET(field_name, type, max_val) \ - void set_##field_name(int64_t val) { \ - TRTORCH_CHECK(val < max_val, "Invalid enum value for field"); \ - field_name = static_cast(val); \ - } \ - int64_t get_##field_name() { \ - return static_cast(field_name); \ +#define ADD_ENUM_GET_SET(field_name, type, max_val) \ + void set_##field_name(int64_t val) { \ + TRTORCH_CHECK(val >= 0 && val <= max_val, "Invalid enum value for field"); \ + field_name = static_cast(val); \ + } \ + int64_t get_##field_name() { \ + return static_cast(field_name); \ } struct InputRange : torch::CustomClassHolder {