Skip to content

Commit

Permalink
refactor(//core/conversion/var): Rename Arg -> Var
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed May 26, 2020
1 parent ac2b102 commit fa509d9
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 45 deletions.
13 changes: 6 additions & 7 deletions core/conversion/arg/BUILD → core/conversion/var/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@ config_setting(
)

cc_library(
name = "arg",
name = "var",
hdrs = [
"Arg.h",
"Arg_inl.h"
"Var.h",
"Var_inl.h"
],
srcs = [
"Arg.cpp",
"Var.cpp",
],
deps = [
"@tensorrt//:nvinfer",
"//core/util:prelude",
"//core/conversion/conversionctx",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
Expand All @@ -33,7 +32,7 @@ pkg_tar(
name = "include",
package_dir = "core/conversion/arg/",
srcs = [
"Arg.h",
"Arg_inl.h"
"Var.h",
"Var_inl.h"
],
)
34 changes: 17 additions & 17 deletions core/conversion/arg/Arg.cpp → core/conversion/var/Var.cpp
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
#include "core/util/prelude.h"
#include "core/conversion/arg/Arg.h"
#include "core/conversion/var/Var.h"

namespace trtorch {
namespace core {
namespace conversion {

Arg::Arg() {
Var::Var() {
ptr_.none = nullptr;
type_ = Type::kNone;
}

Arg::Arg(const torch::jit::IValue* p)
Var::Var(const torch::jit::IValue* p)
: type_(Type::kIValue) {
ptr_.ivalue = p;
}

Arg::Arg(nvinfer1::ITensor* p)
Var::Var(nvinfer1::ITensor* p)
: type_(Type::kITensor) {
ptr_.tensor = p;
}

Arg::Arg(const Arg& a) {
Var::Var(const Var& a) {
switch(a.type_) {
case Type::kITensor:
ptr_.tensor = a.ptr_.tensor;
Expand All @@ -37,7 +37,7 @@ Arg::Arg(const Arg& a) {
}
}

Arg& Arg::operator=(const Arg& a) {
Var& Var::operator=(const Var& a) {
switch(a.type_) {
case Type::kITensor:
ptr_.tensor = a.ptr_.tensor;
Expand All @@ -55,23 +55,23 @@ Arg& Arg::operator=(const Arg& a) {
return (*this);
}

Arg& Arg::operator=(const torch::jit::IValue* in) {
Var& Var::operator=(const torch::jit::IValue* in) {
ptr_.ivalue = in;
type_ = Type::kIValue;
return (*this);
}

Arg& Arg::operator=(nvinfer1::ITensor* in) {
Var& Var::operator=(nvinfer1::ITensor* in) {
ptr_.tensor = in;
type_ = Type::kITensor;
return (*this);
}

Arg::Type Arg::type() const {
Var::Type Var::type() const {
return type_;
}

std::string Arg::type_name() const {
std::string Var::type_name() const {
switch(type_) {
case Type::kITensor:
return "nvinfer1::ITensor";
Expand All @@ -85,41 +85,41 @@ std::string Arg::type_name() const {
}
}

const torch::jit::IValue* Arg::IValue() const {
TRTORCH_CHECK(isIValue(), "Requested IValue from Arg, however arg type is " << type_name());
const torch::jit::IValue* Var::IValue() const {
TRTORCH_CHECK(isIValue(), "Requested IValue from Var, however Var type is " << type_name());
if (type_ == Type::kIValue) {
return ptr_.ivalue;
} else {
return nullptr;
}
}

nvinfer1::ITensor* Arg::ITensor() const {
TRTORCH_CHECK(isITensor(), "Requested ITensor from Arg, however arg type is " << type_name());
nvinfer1::ITensor* Var::ITensor() const {
TRTORCH_CHECK(isITensor(), "Requested ITensor from Var, however Var type is " << type_name());
if (type_ == Type::kITensor) {
return ptr_.tensor;
} else {
return nullptr;
}
}

bool Arg::isITensor() const {
bool Var::isITensor() const {
if (type_ == Type::kITensor) {
return true;
} else {
return false;
}
}

bool Arg::isIValue() const {
bool Var::isIValue() const {
if (type_ == Type::kIValue) {
return true;
} else {
return false;
}
}

bool Arg::isNone() const {
bool Var::isNone() const {
if (type_ == Type::kNone) {
return true;
} else {
Expand Down
28 changes: 13 additions & 15 deletions core/conversion/arg/Arg.h → core/conversion/var/Var.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,29 @@
#include <string>
#include <map>

#include "torch/csrc/jit/runtime/custom_operator.h"
#include "ATen/core/function_schema.h"
#include "torch/csrc/jit/ir/ir.h"

#include "core/util/prelude.h"
#include "core/conversion/conversionctx/ConversionCtx.h"

namespace trtorch {
namespace core {
namespace conversion {

class Arg {
class Var : torch::CustomClassHolder {
public:
enum Type {
kITensor,
kIValue,
kNone
};

Arg();
Arg(const torch::jit::IValue* p);
Arg(nvinfer1::ITensor* p);
Arg(const Arg& a);
Arg& operator=(const Arg& a);
Arg& operator=(const torch::jit::IValue* in);
Arg& operator=(nvinfer1::ITensor* in);
Var();
Var(const torch::jit::IValue* p);
Var(nvinfer1::ITensor* p);
Var(const Var& a);
Var& operator=(const Var& a);
Var& operator=(const torch::jit::IValue* in);
Var& operator=(nvinfer1::ITensor* in);
const torch::jit::IValue* IValue() const;
nvinfer1::ITensor* ITensor() const;

Expand Down Expand Up @@ -59,21 +57,21 @@ class Arg {
bool isIValue() const;
bool isITensor() const;
bool isNone() const;
Arg::Type type() const;
Var::Type type() const;
std::string type_name() const;
private:
union ArgContainer {
union VarContainer {
const torch::jit::IValue* ivalue;
nvinfer1::ITensor* tensor;
void* none;
};

ArgContainer ptr_;
VarContainer ptr_;
Type type_;
};

} // namespace conversion
} // namespace core
} // namespace trtorch

#include "core/conversion/arg/Arg_inl.h"
#include "core/conversion/var/Var_inl.h"
10 changes: 4 additions & 6 deletions core/conversion/arg/Arg_inl.h → core/conversion/var/Var_inl.h
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
#pragma once

//#include "core/conversion/arg/Arg.h"

namespace trtorch {
namespace core {
namespace conversion {

#define DEFINE_UNWRAP_TO(ival_type, method_variant) \
template<> \
inline ival_type Arg::unwrapTo<ival_type>() { \
inline ival_type Var::unwrapTo<ival_type>() { \
TRTORCH_CHECK(isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name()); \
auto ivalue = ptr_.ivalue; \
TRTORCH_CHECK(ivalue->is##method_variant(), "Requested unwrapping of arg IValue assuming it was " << typeid(ival_type).name() << " however type is " << *(ptr_.ivalue->type())); \
return ptr_.ivalue->to<ival_type>(); \
} \
template<> \
inline ival_type Arg::unwrapTo(ival_type default_val) { \
inline ival_type Var::unwrapTo(ival_type default_val) { \
try { \
return this->unwrapTo<ival_type>(); \
} catch(trtorch::Error& e) { \
Expand All @@ -24,11 +22,11 @@ inline ival_type Arg::unwrapTo(ival_type default_val) { \
} \
} \
\
inline ival_type Arg::unwrapTo##method_variant(ival_type default_val) { \
inline ival_type Var::unwrapTo##method_variant(ival_type default_val) { \
return this->unwrapTo<ival_type>(default_val); \
} \
\
inline ival_type Arg::unwrapTo##method_variant() { \
inline ival_type Var::unwrapTo##method_variant() { \
return this->unwrapTo<ival_type>(); \
}

Expand Down

0 comments on commit fa509d9

Please sign in to comment.