-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathtorch.cpp
354 lines (288 loc) · 10.6 KB
/
torch.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
#define _GLIBCXX_USE_CXX11_ABI 0
#include <torch/torch.h>
#include <torch/script.h>
#include "torch.hpp"
#include <cstddef>
#include <cstring>
#include <exception>
#include <iostream>
#include <new>
#include <stdexcept>
#include <stdlib.h>
#include <string>
#include <utility>
#include <vector>
// HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS bracket the body of a C-API entry
// point: HANDLE_TH_ERRORS expands to `try {` and END_HANDLE_TH_ERRORS closes
// that block with the catch handlers below. Always use them as a pair.
#define HANDLE_TH_ERRORS \
try {
// END_HANDLE_TH_ERRORS(errVar, retVal): closes the try block opened by
// HANDLE_TH_ERRORS and converts a caught exception into a Torch_Error.
//   - torch::Error:   message taken from what_without_backtrace().
//   - std::exception: message taken from what().
// The message is copied into a buffer allocated with new[] (so whoever frees
// it must use delete[]), the Torch_Error is written to *errVar, and the
// enclosing function returns retVal (pass an empty argument for functions
// returning void). Only `.message` is set explicitly; any other Torch_Error
// fields are value-initialized by the designated initializer.
// NOTE(review): errVar is dereferenced unconditionally, so callers must pass a
// non-null Torch_Error*. Exceptions matching neither handler propagate out of
// the function. The `.message =` designated initializer requires C++20 or a
// compiler extension.
#define END_HANDLE_TH_ERRORS(errVar, retVal) \
} \
catch (const torch::Error& e) { \
auto msg = e.what_without_backtrace(); \
auto err = Torch_Error{ \
.message = new char[strlen(msg)+1], \
}; \
std::strcpy(err.message, msg); \
*errVar = err; \
return retVal; \
} \
catch (const std::exception& e) { \
auto msg = e.what(); \
auto err = Torch_Error{ \
.message = new char[strlen(msg)+1], \
}; \
std::strcpy(err.message, msg); \
*errVar = err; \
return retVal; \
}
// Heap-allocated object behind an opaque Torch_TensorContext; wraps a single
// torch::Tensor. Created with `new` in this file, released by Torch_DeleteTensor.
struct Torch_Tensor {
torch::Tensor tensor;
};
// Heap-allocated object behind an opaque Torch_JITModuleContext. Holds the
// TorchScript module through a std::shared_ptr; released by
// Torch_DeleteJITModule.
struct Torch_JITModule {
std::shared_ptr<torch::jit::script::Module> module;
};
// Heap-allocated object behind an opaque Torch_JITModuleMethodContext.
// `run` is a reference to a Method obtained from a Torch_JITModule's module
// (via get_method), not an owned copy.
// NOTE(review): presumably the module owns that Method, so this handle must
// not be used after the module handle is deleted — confirm against libtorch.
struct Torch_JITModule_Method {
torch::jit::script::Method& run;
};
// Maps a C-API Torch_DataType onto torch::TensorOptions carrying the matching
// scalar type. Data types without a case below fall back to a
// default-constructed TensorOptions (no dtype chosen explicitly).
torch::TensorOptions Torch_ConvertDataTypeToOptions(Torch_DataType dtype) {
  switch (dtype) {
    case Torch_Byte:   return torch::TensorOptions(torch::kByte);
    case Torch_Char:   return torch::TensorOptions(torch::kChar);
    case Torch_Short:  return torch::TensorOptions(torch::kShort);
    case Torch_Int:    return torch::TensorOptions(torch::kInt);
    case Torch_Long:   return torch::TensorOptions(torch::kLong);
    case Torch_Half:   return torch::TensorOptions(torch::kHalf);
    case Torch_Float:  return torch::TensorOptions(torch::kFloat);
    case Torch_Double: return torch::TensorOptions(torch::kDouble);
    default:
      // TODO handle other types
      return torch::TensorOptions();
  }
}
// Reports the C-API data type corresponding to a torch scalar type, or
// Torch_Unknown when the scalar type has no C-API counterpart.
Torch_DataType Torch_ConvertScalarTypeToDataType(torch::ScalarType type) {
  switch (type) {
    case torch::kByte:   return Torch_Byte;
    case torch::kChar:   return Torch_Char;
    case torch::kShort:  return Torch_Short;
    case torch::kInt:    return Torch_Int;
    case torch::kLong:   return Torch_Long;
    case torch::kHalf:   return Torch_Half;
    case torch::kFloat:  return Torch_Float;
    case torch::kDouble: return Torch_Double;
    default:             return Torch_Unknown;
  }
}
// Converts a torch::IValue (e.g. a TorchScript method result) into the C-API
// Torch_IValue representation.
//
//   - Tensor: data_ptr is a new Torch_Tensor (allocated with `new`).
//   - Tuple:  data_ptr is a malloc'ed Torch_IValueTuple whose `values` array is
//             also malloc'ed; elements are converted recursively.
//   - Anything else: a value-initialized Torch_IValue{} (not supported yet).
//
// Throws std::bad_alloc if malloc fails. The caller in this file
// (Torch_JITModuleMethodRun) invokes it inside HANDLE_TH_ERRORS, which turns
// the exception into a Torch_Error.
Torch_IValue Torch_ConvertIValueToTorchIValue(torch::IValue value) {
  if (value.isTensor()) {
    auto tensor = new Torch_Tensor();
    tensor->tensor = value.toTensor();
    return Torch_IValue{
      .itype = Torch_IValueTypeTensor,
      .data_ptr = tensor,
    };
  }
  if (value.isTuple()) {
    // Hold our own handle to the tuple so `elements` stays valid, and bind the
    // element list by reference instead of copying every IValue in it.
    auto src = value.toTuple();
    const auto& elements = src->elements();
    const size_t count = elements.size();

    auto values = static_cast<Torch_IValue*>(malloc(sizeof(Torch_IValue) * count));
    // malloc(0) may legitimately return NULL, so NULL only signals failure
    // when memory was actually requested.
    if (values == nullptr && count != 0) {
      throw std::bad_alloc();
    }
    auto tuple = static_cast<Torch_IValueTuple*>(malloc(sizeof(Torch_IValueTuple)));
    if (tuple == nullptr) {
      free(values);
      throw std::bad_alloc();
    }
    for (size_t i = 0; i < count; ++i) {
      values[i] = Torch_ConvertIValueToTorchIValue(elements[i]);
    }
    tuple->values = values;
    tuple->length = count;
    return Torch_IValue{
      .itype = Torch_IValueTypeTuple,
      .data_ptr = tuple,
    };
  }
  // TODO: other IValue kinds are not supported yet.
  return Torch_IValue{};
}
// Converts a C-API Torch_IValue back into a torch::IValue so it can be passed
// to a TorchScript method.
//
//   - Torch_IValueTypeTensor: data_ptr is a Torch_Tensor*; its tensor is returned.
//   - Torch_IValueTypeTuple:  data_ptr is a Torch_IValueTuple*; its `length`
//                             entries are converted recursively into a tuple.
//
// Throws std::invalid_argument for a null data_ptr or an unsupported itype.
// An unsupported itype used to become the integer IValue 0, which was then
// passed to the method as a bogus argument. The caller in this file
// (Torch_JITModuleMethodRun) runs inside HANDLE_TH_ERRORS, so the exception
// is reported through its Torch_Error.
torch::IValue Torch_ConvertTorchIValueToIValue(Torch_IValue value) {
  if (value.itype == Torch_IValueTypeTensor) {
    auto tensor = static_cast<Torch_Tensor*>(value.data_ptr);
    if (tensor == nullptr) {
      throw std::invalid_argument("Torch_IValue tensor has a null data_ptr");
    }
    return tensor->tensor;
  }
  if (value.itype == Torch_IValueTypeTuple) {
    auto tuple = static_cast<Torch_IValueTuple*>(value.data_ptr);
    if (tuple == nullptr || (tuple->values == nullptr && tuple->length != 0)) {
      throw std::invalid_argument("Torch_IValue tuple has a null data_ptr");
    }
    // Unsigned index so the comparison with `length` is not signed/unsigned.
    const auto count = static_cast<size_t>(tuple->length);
    std::vector<torch::IValue> values;
    values.reserve(count);
    for (size_t i = 0; i < count; ++i) {
      values.push_back(Torch_ConvertTorchIValueToIValue(tuple->values[i]));
    }
    return torch::jit::Tuple::create(std::move(values));
  }
  throw std::invalid_argument("unsupported Torch_IValue type");
}
// Wraps caller-provided memory `input_data` in a new Torch_Tensor handle whose
// shape is the `n_dim` sizes in `dimensions`, using the scalar type for
// `dtype`. Per the torch::from_blob contract the tensor does not copy or own
// `input_data`, so that memory must outlive every use of the tensor.
// Release the handle with Torch_DeleteTensor.
Torch_TensorContext Torch_NewTensor(void* input_data, int64_t* dimensions, int n_dim, Torch_DataType dtype) {
  const torch::TensorOptions opts = Torch_ConvertDataTypeToOptions(dtype);
  const std::vector<int64_t> shape(dimensions, dimensions + n_dim);
  // Build the tensor before allocating the handle so the handle is only
  // allocated once from_blob has succeeded.
  torch::Tensor wrapped = torch::from_blob(input_data, torch::IntList(shape), opts);
  auto handle = new Torch_Tensor();
  handle->tensor = wrapped;
  return static_cast<void*>(handle);
}
// Returns the raw data pointer of the tensor held by the handle `ctx`.
void* Torch_TensorValue(Torch_TensorContext ctx) {
  return static_cast<Torch_Tensor*>(ctx)->tensor.data_ptr();
}
// Reports the C-API data type of the tensor held by `ctx`
// (Torch_Unknown when its scalar type has no mapping).
Torch_DataType Torch_TensorType(Torch_TensorContext ctx) {
  const auto* handle = static_cast<Torch_Tensor*>(ctx);
  return Torch_ConvertScalarTypeToDataType(handle->tensor.scalar_type());
}
// Writes the tensor's number of dimensions to *dims and returns a pointer to
// its sizes. The pointer is taken from the tensor's sizes() view rather than a
// fresh allocation, so (presumably) it is only valid while the tensor is
// alive and its shape unchanged; the caller must not free or write through it.
int64_t* Torch_TensorShape(Torch_TensorContext ctx, size_t* dims){
  const auto* handle = static_cast<Torch_Tensor*>(ctx);
  const auto shape = handle->tensor.sizes();
  *dims = shape.size();
  return const_cast<int64_t*>(shape.data());
}
// Prints each of the `input_size` tensor handles in `tensors` to std::cout,
// each followed by a newline.
void Torch_PrintTensors(Torch_TensorContext* tensors, size_t input_size) {
  // size_t index: matches input_size, avoiding a signed/unsigned comparison
  // and int overflow for very large counts.
  for (size_t i = 0; i < input_size; ++i) {
    const auto* handle = static_cast<Torch_Tensor*>(tensors[i]);
    std::cout << handle->tensor << "\n";
  }
}
// Frees a Torch_Tensor handle (handles in this file are allocated with `new`).
void Torch_DeleteTensor(Torch_TensorContext ctx) {
  delete static_cast<Torch_Tensor*>(ctx);
}
// Compiles TorchScript source `cstring_script` into a new module handle.
// On failure returns NULL and fills *error (its message is allocated with
// new[]). Release the handle with Torch_DeleteJITModule.
Torch_JITModuleContext Torch_CompileTorchScript(char* cstring_script, Torch_Error* error) {
  HANDLE_TH_ERRORS
  if (cstring_script == nullptr) {
    throw std::invalid_argument("Torch_CompileTorchScript: script is NULL");
  }
  // Compile first and allocate the handle afterwards: previously the handle
  // was allocated up front and leaked whenever compilation threw.
  auto compiled = torch::jit::compile(std::string(cstring_script));
  return static_cast<void*>(new Torch_JITModule{std::move(compiled)});
  END_HANDLE_TH_ERRORS(error, NULL)
}
// Loads a serialized TorchScript module from the file `cstring_path` into a
// new module handle. On failure returns NULL and fills *error (its message is
// allocated with new[]). Release the handle with Torch_DeleteJITModule.
Torch_JITModuleContext Torch_LoadJITModule(char* cstring_path, Torch_Error* error) {
  HANDLE_TH_ERRORS
  if (cstring_path == nullptr) {
    throw std::invalid_argument("Torch_LoadJITModule: path is NULL");
  }
  // Load first and allocate the handle afterwards: previously the handle was
  // allocated up front and leaked whenever loading threw.
  auto loaded = torch::jit::load(std::string(cstring_path));
  return static_cast<void*>(new Torch_JITModule{std::move(loaded)});
  END_HANDLE_TH_ERRORS(error, NULL)
}
// Saves the module held by `ctx` to the file `cstring_path`.
// On failure fills *error (its message is allocated with new[]).
void Torch_ExportJITModule(Torch_JITModuleContext ctx, char* cstring_path, Torch_Error* error) {
  HANDLE_TH_ERRORS
  const std::string destination(cstring_path);
  auto* handle = static_cast<Torch_JITModule*>(ctx);
  handle->module->save(destination);
  END_HANDLE_TH_ERRORS(error,)
}
// Looks up the method named `cstring_method` on the module held by `ctx` and
// returns a new method handle. On failure returns NULL and fills *error.
// The handle stores a reference obtained from the module, not a copy.
// NOTE(review): presumably it must not be used after the module handle is
// deleted — confirm. Release it with Torch_DeleteJITModuleMethod.
Torch_JITModuleMethodContext Torch_JITModuleGetMethod(Torch_JITModuleContext ctx, char* cstring_method, Torch_Error* error) {
  HANDLE_TH_ERRORS
  const std::string wanted(cstring_method);
  auto* owner = static_cast<Torch_JITModule*>(ctx);
  auto& method = owner->module->get_method(wanted);
  return static_cast<void*>(new Torch_JITModule_Method{method});
  END_HANDLE_TH_ERRORS(error, NULL)
}
// Returns the names of the module's methods and stores their count in *len.
// The array itself is allocated with malloc, while each name is a separate
// NUL-terminated string allocated with new[], so the two need different
// deallocators.
char** Torch_JITModuleGetMethodNames(Torch_JITModuleContext ctx, size_t* len) {
  auto* owner = static_cast<Torch_JITModule*>(ctx);
  const auto& methods = owner->module->get_methods();
  *len = methods.size();
  auto names = static_cast<char**>(malloc(sizeof(char*) * methods.size()));
  size_t slot = 0;
  for (const auto& entry : methods) {
    const std::string& name = entry.value()->name();
    char* copy = new char[name.length() + 1];
    strcpy(copy, name.c_str());
    names[slot++] = copy;
  }
  return names;
}
// Runs the method held by `ctx` with the `input_size` arguments in `inputs`
// and converts the result to a Torch_IValue. On failure fills *error and
// returns a value-initialized Torch_IValue{}.
Torch_IValue Torch_JITModuleMethodRun(Torch_JITModuleMethodContext ctx, Torch_IValue* inputs, size_t input_size, Torch_Error* error) {
  HANDLE_TH_ERRORS
  auto* met = static_cast<Torch_JITModule_Method*>(ctx);
  std::vector<torch::IValue> inputs_vec;
  inputs_vec.reserve(input_size);
  // size_t index: avoids the signed/unsigned comparison with input_size.
  for (size_t i = 0; i < input_size; ++i) {
    inputs_vec.push_back(Torch_ConvertTorchIValueToIValue(inputs[i]));
  }
  // inputs_vec is not used afterwards, so hand it over instead of copying.
  auto res = met->run(std::move(inputs_vec));
  return Torch_ConvertIValueToTorchIValue(std::move(res));
  END_HANDLE_TH_ERRORS(error, Torch_IValue{})
}
// Duplicates `text` into a new[]-allocated, NUL-terminated C string.
static char* Torch_NewCString(const std::string& text) {
  char* copy = new char[text.length() + 1];
  strcpy(copy, text.c_str());
  return copy;
}

// Copies a schema argument list (parameters or returns) into a malloc'ed array
// of Torch_ModuleMethodArgument and stores its length in *res_size. Each
// entry's `name` and `typ` are separate new[]-allocated C strings. Templated
// on the list type so it accepts whatever FunctionSchema returns.
template <typename ArgumentList>
static Torch_ModuleMethodArgument* Torch_SchemaArgumentsToC(const ArgumentList& arguments, size_t* res_size) {
  const size_t count = arguments.size();
  auto result = static_cast<Torch_ModuleMethodArgument*>(malloc(sizeof(Torch_ModuleMethodArgument) * count));
  *res_size = count;
  for (size_t i = 0; i < count; ++i) {
    const auto& argument = arguments[i];
    result[i] = Torch_ModuleMethodArgument{
      .name = Torch_NewCString(argument.name()),
      .typ = Torch_NewCString(argument.type()->str()),
    };
  }
  return result;
}

// Describes the parameters of the method held by `ctx` (name and type string
// per parameter); the count is stored in *res_size.
Torch_ModuleMethodArgument* Torch_JITModuleMethodArguments(Torch_JITModuleMethodContext ctx, size_t* res_size) {
  auto* met = static_cast<Torch_JITModule_Method*>(ctx);
  // Read the schema in place instead of copying the whole FunctionSchema.
  return Torch_SchemaArgumentsToC(met->run.getSchema().arguments(), res_size);
}

// Describes the return values of the method held by `ctx` (name and type
// string per return value); the count is stored in *res_size.
Torch_ModuleMethodArgument* Torch_JITModuleMethodReturns(Torch_JITModuleMethodContext ctx, size_t* res_size) {
  auto* met = static_cast<Torch_JITModule_Method*>(ctx);
  return Torch_SchemaArgumentsToC(met->run.getSchema().returns(), res_size);
}
// Frees a method handle created by Torch_JITModuleGetMethod. Only the handle
// is freed; the Method it refers to is left untouched.
void Torch_DeleteJITModuleMethod(Torch_JITModuleMethodContext ctx) {
  delete static_cast<Torch_JITModule_Method*>(ctx);
}
// Frees a module handle created by Torch_CompileTorchScript or
// Torch_LoadJITModule, dropping its shared_ptr reference to the module.
// NOTE(review): method handles obtained from this module presumably must not
// be used afterwards — confirm.
void Torch_DeleteJITModule(Torch_JITModuleContext ctx) {
  delete static_cast<Torch_JITModule*>(ctx);
}