-
Notifications
You must be signed in to change notification settings - Fork 0
/
blob.h
130 lines (114 loc) · 3.99 KB
/
blob.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#ifndef CAFFE2_CORE_BLOB_H_
#define CAFFE2_CORE_BLOB_H_
#include <cstddef>
#include <sstream>
#include <typeinfo>
#include <type_traits>
#include <vector>
#include "caffe2/core/common.h"
#include <ATen/core/blob.h>
#include <c10/util/typeid.h>
#include "caffe2/core/logging.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/tensor_int8.h"
namespace caffe2 {
/// Returns true when `blob` currently stores an int8::Int8TensorCPU.
inline bool BlobIsInt8TensorCPUType(const Blob& blob) {
  const auto& stored_meta = blob.meta();
  return stored_meta.Match<int8::Int8TensorCPU>();
}
inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
bool is_match = blob.meta().Match<Tensor>();
if (!is_match) {
return false;
}
const Tensor* tensor = &blob.Get<Tensor>();
return tensor && *tensor && tensor->GetDeviceType() == device_type;
}
/// Moves `tensor` into a heap-allocated Tensor, hands it to `blob` via
/// Blob::Reset, and returns the pointer Reset gives back.
inline Tensor* BlobSetTensor(Blob* blob, Tensor&& tensor) {
  auto* owned = new Tensor(std::move(tensor));
  return blob->Reset<Tensor>(owned);
}
/// Returns a Tensor of shape `dims` matching `options`, recycling
/// `previous_tensor` when possible.
///
/// `previous_tensor` is reused only when all of the following hold:
///  - it is defined;
///  - it is on a compatible device;
///  - it already has the requested dtype.
/// A reused tensor is resized in place if its shape differs. In every other
/// case a fresh tensor is created with caffe2::empty(dims, options).
///
/// @param previous_tensor Tensor to recycle; it is moved from.
/// @param dims Requested shape.
/// @param options Requested device and dtype.
/// @return The recycled tensor or a newly created one.
inline Tensor GetSizedTensorWithOptions(
    Tensor&& previous_tensor,
    at::IntArrayRef dims,
    at::TensorOptions options) {
  Tensor tensor = std::move(previous_tensor);
  if (!tensor.defined()) {
    return caffe2::empty(dims, options);
  }
  // When the tensor has no device index set, only the device type is
  // compared.
  const bool device_matches = tensor.GetDevice() == options.device() ||
      (!tensor.GetDevice().has_index() &&
       tensor.GetDeviceType() == options.device().type());
  // Decide whether the old tensor is reusable BEFORE mutating it. A dtype
  // mismatch means it is discarded anyway. Resizing it first would be wasted
  // work and would change the shape of a TensorImpl that other handles may
  // share.
  if (!device_matches || !(tensor.dtype() == options.dtype())) {
    return caffe2::empty(dims, options);
  }
  if (tensor.sizes() != dims) {
    // Resize when the dims don't match.
    tensor.Resize(dims);
  }
  // Called for its side effect (presumably materializing storage for the
  // current shape); the returned pointer is not needed here.
  tensor.raw_mutable_data();
  return tensor;
}
// Both this Tensor*-returning overload and the Tensor-returning
// XBlobGetMutableTensor are kept on purpose, for the clangr codemod.
/// Returns a pointer to a Tensor in `blob` with shape `dims`, typed as
/// `options.dtype()`, on a device compatible with `options.device()`.
///
/// An existing defined Tensor on a compatible device is reused. It is resized
/// when its shape differs, then raw_mutable_data(options.dtype()) is called
/// on it. Otherwise the blob's contents are replaced with
/// caffe2::empty(dims, options).
inline Tensor*
BlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) {
  Tensor* existing =
      blob->IsType<Tensor>() ? blob->GetMutable<Tensor>() : nullptr;
  if (existing != nullptr && *existing) {
    const auto existing_device = existing->GetDevice();
    // Only the device type is compared when the index is not set, because
    // some Tensors are created without a device index.
    // TODO: remove the extra check when all the Tensors are properly initialized
    const bool same_device = existing_device == options.device() ||
        (!existing_device.has_index() &&
         existing->GetDeviceType() == options.device().type());
    if (same_device) {
      if (existing->sizes() != dims) {
        // Resize when the dims don't match.
        existing->Resize(dims);
      }
      existing->raw_mutable_data(options.dtype());
      return existing;
    }
    // Device mismatch: fall through and replace the blob's contents.
  }
  VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
          << " dims: " << dims;
  // << " options: " << options; (operator<< for Options is in at:: now)
  return BlobSetTensor(blob, caffe2::empty(dims, options));
}
/// Value-returning counterpart of
/// BlobGetMutableTensor(blob, dims, options).
/// It returns the handle produced by UnsafeSharedInstance() on the blob's
/// Tensor (presumably aliasing it rather than copying).
inline Tensor
XBlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) {
  Tensor* stored = BlobGetMutableTensor(blob, dims, options);
  return stored->UnsafeSharedInstance();
}
/// Returns the Tensor stored in `blob` if it is defined and its device type
/// equals `device_type`. Otherwise the blob's contents are replaced with
/// Tensor(device_type) and a pointer to that new Tensor is returned.
inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) {
  Tensor* current = nullptr;
  if (blob->IsType<Tensor>()) {
    current = blob->GetMutable<Tensor>();
  }
  const bool reusable = current != nullptr && *current &&
      current->GetDeviceType() == device_type;
  if (reusable) {
    return current;
  }
  // Reaching this point means one of three things: the blob held no Tensor,
  // the held Tensor was undefined, or its DeviceType did not match.
  VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
          << " DeviceType:" << device_type;
  return BlobSetTensor(blob, Tensor(device_type));
}
/// Returns a const reference to the Tensor stored in `blob`.
///
/// @param blob Blob expected to hold a Tensor.
/// @param device_type Device type the stored Tensor must have.
/// @throws via CAFFE_THROW in three cases: `blob` does not hold a Tensor,
///         the held Tensor is undefined, or its device type differs from
///         `device_type`.
inline const Tensor& BlobGetTensor(const Blob& blob, DeviceType device_type) {
  if (blob.IsType<Tensor>()) {
    const auto& tensor = blob.Get<Tensor>();
    // Check definedness before querying the device, as BlobIsTensorType and
    // BlobGetMutableTensor do. Asking an undefined Tensor for its device type
    // may fail with an unrelated error (or worse), instead of the
    // descriptive throw below.
    if (tensor.defined() && tensor.GetDeviceType() == device_type) {
      return tensor;
    }
  }
  CAFFE_THROW("Blob didn't contain a Tensor or the device_type doesn't match");
}
/// If `blob` holds a Tensor, returns the handle produced by
/// UnsafeSharedInstance() on it. Otherwise returns an undefined Tensor().
inline Tensor BlobGetTensorOrUndefined(const Blob& blob) {
  if (!blob.IsType<Tensor>()) {
    return Tensor();
  }
  return blob.Get<Tensor>().UnsafeSharedInstance();
}
} // namespace caffe2
#endif // CAFFE2_CORE_BLOB_H_