Add options for eval and gradient required #78

Closed · wants to merge 3 commits
58 changes: 25 additions & 33 deletions src/ctorch.cpp
@@ -43,9 +43,20 @@ constexpr auto get_device(torch_device_t device)
   }
 }
 
+void set_is_training(torch_jit_script_module_t module, const bool is_training)
+{
+  auto model = static_cast<torch::jit::script::Module*>(module);
+  if (is_training) {
+    model->train();
+  } else {
+    model->eval();
+  }
+}
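For reference, this PR threads two libtorch mechanisms through the C API: `Module::train()`/`Module::eval()` (wrapped by the `set_is_training` helper above) and the `torch::AutoGradMode` RAII guard (added at the top of each wrapper below), which switches autograd recording on or off for the rest of its scope. A minimal sketch of both, assuming only stock libtorch and nothing from this PR:

    #include <torch/script.h>
    #include <torch/torch.h>

    void configure_for_inference(torch::jit::script::Module& model) {
      // Equivalent to set_is_training(module, false): dropout and
      // batch-norm layers switch to their inference behaviour.
      model.eval();
      // While this guard is alive, operations are not recorded on the
      // autograd tape; the previous mode is restored on destruction.
      torch::AutoGradMode guard(/*enabled=*/false);
      torch::Tensor x = torch::ones({2, 3});  // created with grad mode off
    }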

 torch_tensor_t torch_zeros(int ndim, const int64_t* shape, torch_data_t dtype,
-                           torch_device_t device)
+                           torch_device_t device, const bool requires_grad)
 {
+  torch::AutoGradMode enable_grad(requires_grad);
   torch::Tensor* tensor = nullptr;
   try {
     // This doesn't throw if shape and dimensions are incompatible
@@ -66,8 +77,9 @@ torch_tensor_t torch_zeros(int ndim, const int64_t* shape, torch_data_t dtype,
 }
 
 torch_tensor_t torch_ones(int ndim, const int64_t* shape, torch_data_t dtype,
-                          torch_device_t device)
+                          torch_device_t device, const bool requires_grad)
 {
+  torch::AutoGradMode enable_grad(requires_grad);
   torch::Tensor* tensor = nullptr;
   try {
     // This doesn't throw if shape and dimensions are incompatible
@@ -88,8 +100,9 @@ torch_tensor_t torch_ones(int ndim, const int64_t* shape, torch_data_t dtype,
 }
 
 torch_tensor_t torch_empty(int ndim, const int64_t* shape, torch_data_t dtype,
-                           torch_device_t device)
+                           torch_device_t device, const bool requires_grad)
 {
+  torch::AutoGradMode enable_grad(requires_grad);
   torch::Tensor* tensor = nullptr;
   try {
     // This doesn't throw if shape and dimensions are incompatible
@@ -109,38 +122,12 @@ torch_tensor_t torch_empty(int ndim, const int64_t* shape, torch_data_t dtype,
   return tensor;
 }
 
-/*
-// Exposes the given data as a Tensor without taking ownership of the original
-// data
-torch_tensor_t torch_from_blob(void* data, int ndim, const int64_t* shape,
-                               torch_data_t dtype, torch_device_t device)
-{
-  torch::Tensor* tensor = nullptr;
-  try {
-    // This doesn't throw if shape and dimensions are incompatible
-    c10::IntArrayRef vshape(shape, ndim);
-    tensor = new torch::Tensor;
-    *tensor = torch::from_blob(
-        data, vshape,
-        torch::dtype(get_dtype(dtype))).to(get_device(device));
-  } catch (const torch::Error& e) {
-    std::cerr << "[ERROR]: " << e.msg() << std::endl;
-    delete tensor;
-    exit(EXIT_FAILURE);
-  } catch (const std::exception& e) {
-    std::cerr << "[ERROR]: " << e.what() << std::endl;
-    delete tensor;
-    exit(EXIT_FAILURE);
-  }
-  return tensor;
-}
-
-*/
 // New version of torch_from_blob that uses strides
 torch_tensor_t torch_from_blob(void* data, int ndim, const int64_t* shape,
                                const int64_t* strides, torch_data_t dtype,
-                               torch_device_t device)
+                               torch_device_t device, const bool requires_grad)
 {
+  torch::AutoGradMode enable_grad(requires_grad);
   torch::Tensor* tensor = nullptr;
 
   try {
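The strided overload wrapped here is `torch::from_blob(data, sizes, strides, options)`, which builds a Tensor view over caller-owned memory without copying; the explicit strides are what let a column-major (Fortran-ordered) buffer be viewed correctly. A small sketch with made-up sizes, not part of this diff:

    #include <torch/torch.h>

    void view_column_major_buffer() {
      // A 2x3 matrix {{1, 2, 3}, {4, 5, 6}} stored column-major.
      float data[6] = {1, 4, 2, 5, 3, 6};
      // Strides {1, 2}: moving down a column steps 1 element in memory,
      // moving along a row steps 2.
      torch::Tensor t = torch::from_blob(data, {2, 3}, {1, 2},
                                         torch::dtype(torch::kFloat32));
      // t[0][1].item<float>() == 2.0f; no copy is made, so `data`
      // must outlive `t`.
    }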
@@ -176,8 +163,11 @@ void torch_tensor_delete(torch_tensor_t tensor)
   delete t;
 }
 
-torch_jit_script_module_t torch_jit_load(const char* filename)
+torch_jit_script_module_t torch_jit_load(const char* filename,
+                                         const bool requires_grad,
+                                         const bool is_training)
 {
+  torch::AutoGradMode enable_grad(requires_grad);
   torch::jit::script::Module* module = nullptr;
   try {
     module = new torch::jit::script::Module;
@@ -191,14 +181,16 @@ torch_jit_script_module_t torch_jit_load(const char* filename)
     delete module;
     exit(EXIT_FAILURE);
   }
+  set_is_training(module, is_training);
 
   return module;
 }
 
 void torch_jit_module_forward(const torch_jit_script_module_t module,
                               const torch_tensor_t *inputs, const int nin,
-                              torch_tensor_t output)
+                              torch_tensor_t output, const bool requires_grad)
 {
+  torch::AutoGradMode enable_grad(requires_grad);
   // Here we cast the pointers we received into Tensor objects
   auto model = static_cast<torch::jit::script::Module*>(module);
   auto in = reinterpret_cast<torch::Tensor* const*>(inputs);
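Taken together, a hedged usage sketch of the updated tensor-creation API; the shape is illustrative, and the `torch_kFloat32` dtype value is an assumption about the `torch_data_t` enum (not shown in this diff), so check ctorch.h for the actual names:

    #include <cstdint>
    #include "ctorch.h"

    void make_tracked_tensor() {
      int64_t shape[2] = {3, 4};
      // New trailing argument: create the tensor with gradient
      // tracking enabled (requires_grad = true).
      torch_tensor_t t = torch_zeros(2, shape, torch_kFloat32, torch_kCPU, true);
      torch_tensor_print(t);
      torch_tensor_delete(t);
    }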
28 changes: 21 additions & 7 deletions src/ctorch.h
@@ -38,29 +38,35 @@ typedef enum { torch_kCPU, torch_kCUDA } torch_device_t;
  * @param shape of the Tensor
  * @param data type of the elements of the Tensor
  * @param device used (cpu, CUDA, etc.)
+ * @param whether gradient is required
  */
 EXPORT_C torch_tensor_t torch_zeros(int ndim, const int64_t* shape,
-                                    torch_data_t dtype, torch_device_t device);
+                                    torch_data_t dtype, torch_device_t device,
+                                    const bool requires_grad);

 /**
  * Function to generate a Torch Tensor of ones
  * @param number of dimensions of the Tensor
  * @param shape of the Tensor
  * @param data type of the elements of the Tensor
  * @param device used (cpu, CUDA, etc.)
+ * @param whether gradient is required
  */
 EXPORT_C torch_tensor_t torch_ones(int ndim, const int64_t* shape,
-                                   torch_data_t dtype, torch_device_t device);
+                                   torch_data_t dtype, torch_device_t device,
+                                   const bool requires_grad);

 /**
  * Function to generate an empty Torch Tensor
  * @param number of dimensions of the Tensor
  * @param shape of the Tensor
  * @param data type of the elements of the Tensor
  * @param device used (cpu, CUDA, etc.)
+ * @param whether gradient is required
  */
 EXPORT_C torch_tensor_t torch_empty(int ndim, const int64_t* shape,
-                                    torch_data_t dtype, torch_device_t device);
+                                    torch_data_t dtype, torch_device_t device,
+                                    const bool requires_grad);
 
 /**
  * Function to create a Torch Tensor from memory location given extra information
@@ -70,13 +76,15 @@ EXPORT_C torch_tensor_t torch_empty(int ndim, const int64_t* shape,
  * @param strides to take through data
  * @param data type of the elements of the Tensor
  * @param device used (cpu, CUDA, etc.)
+ * @param whether gradient is required
  * @return Torch Tensor interpretation of the data pointed at
  */
 EXPORT_C torch_tensor_t torch_from_blob(void* data, int ndim,
                                         const int64_t* shape,
                                         const int64_t* strides,
                                         torch_data_t dtype,
-                                        torch_device_t device);
+                                        torch_device_t device,
+                                        const bool requires_grad);
 
 /**
  * Function to print out a Torch Tensor
@@ -93,12 +101,14 @@ EXPORT_C void torch_tensor_print(const torch_tensor_t tensor);
  * @param shape of the Tensor
  * @param data type of the elements of the Tensor
  * @param device used (cpu, CUDA, etc.)
+ * @param whether gradient is required
  * @return Torch Tensor interpretation of the data pointed at
  */
 EXPORT_C torch_tensor_t torch_from_blob_f(void* data, int ndim,
                                           const int64_t* shape,
                                           torch_data_t dtype,
-                                          torch_device_t device);
+                                          torch_device_t device,
+                                          const bool requires_grad);
 
 /**
  * Function to delete a Torch Tensor to clean up
@@ -113,20 +123,24 @@ EXPORT_C void torch_tensor_delete(torch_tensor_t tensor);
 /**
  * Function to load in a Torch model from a TorchScript file and store in a Torch Module
  * @param filename where TorchScript description of model is stored
+ * @param whether gradient is required
+ * @param whether model is being trained
  * @return Torch Module loaded in from file
  */
-EXPORT_C torch_jit_script_module_t torch_jit_load(const char* filename);
+EXPORT_C torch_jit_script_module_t torch_jit_load(const char* filename,
+    const bool requires_grad, const bool is_training);
 
 /**
  * Function to run the `forward` method of a Torch Module
  * @param Torch Module containing the model
  * @param vector of Torch Tensors as inputs to the model
  * @param number of input Tensors in the input vector
  * @param the output Tensor from running the model
+ * @param whether gradient is required
  */
 EXPORT_C void torch_jit_module_forward(const torch_jit_script_module_t module,
                                        const torch_tensor_t *inputs, const int nin,
-                                       torch_tensor_t output);
+                                       torch_tensor_t output, const bool requires_grad);
 
 /**
  * Function to delete a Torch Module to clean up
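Finally, a sketch of the module-level calls with the new flags, configured for inference (gradients off, eval mode). "model.pt" is a placeholder, and the name of the module clean-up function (`torch_jit_module_delete`) is assumed from the doc comment above rather than shown in this diff:

    #include "ctorch.h"

    void run_inference(torch_tensor_t input, torch_tensor_t output) {
      // requires_grad = false, is_training = false: the model is put in
      // eval mode at load time and no autograd state is recorded.
      torch_jit_script_module_t model = torch_jit_load("model.pt", false, false);
      // One input tensor; the forward pass also runs with gradients off.
      torch_jit_module_forward(model, &input, 1, output, false);
      torch_jit_module_delete(model);
    }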