shoehorn gather into using TensorMath.lua
hughperkins committed Jun 24, 2015
1 parent cf80f67 commit 0e469f4
Showing 8 changed files with 122 additions and 45 deletions.
TensorMath.lua (53 changes: 45 additions & 8 deletions)
@@ -345,6 +345,51 @@ wrap("mul",
{name=Tensor, method={default=1}},
{name=real}})

-- wrap("gather",
-- cname("gather"),
-- {{name=Tensor, default=true, returned=true,
-- init=function(arg)
-- return table.concat(
-- {
-- arg.__metatable.init(arg),
-- string.format("THLongStorage* %s_size = THLongTensor_newSizeOf(%s);", arg:carg(), arg.args[4]:carg()),
-- string.format("TH%s_resize(%s, %s_size, NULL);", Tensor, arg:carg(), arg:carg()),
-- string.format("THLongStorage_free(%s_size);", arg:carg())
-- }, '\n')
-- end
-- },
-- {name=Tensor},
-- {name="index"},
-- {name=Tensor, noreadadd=true}})

--function diag(arg)
-- print('diag', arg)
-- for k,v in pairs(arg) do
-- print(k,v)
-- end
-- return ''
--end

wrap("gather",
cname("gather"),
{{name=Tensor, default=true, returned=true,
init=function(arg)
return table.concat(
{
string.format('THClState *state = cltorch_getstate(L);'),
string.format('THLongStorage *%s_newSize = THLongStorage_newWithSize(%s->nDimension);', arg:carg(), arg.args[4]:carg()),
string.format('THLongStorage_rawCopy(%s_newSize, %s->size);', arg:carg(), arg.args[4]:carg()),
string.format('%s = THClTensor_new(state);', arg:carg()),
string.format('THClTensor_resize(state, %s, %s_newSize, NULL);', arg:carg(), arg:carg()),
string.format('THLongStorage_free(%s_newSize);', arg:carg()),
}, '\n')
end
},
{name=Tensor},
{name="index"},
{name=Tensor, noreadadd=true}
})
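
For context (not part of the diff): a minimal sketch of the gather semantics this wrapper exposes, assuming cltorch's ClTensor follows stock torch gather behaviour. The result takes the shape of the index tensor, and each element is read from the source along the given dimension at the position the index supplies.

-- illustrative sketch only; assumes stock torch gather semantics
local src = torch.range(1, 6):resize(2, 3)  -- {{1,2,3},{4,5,6}}
local idx = torch.LongTensor{{1, 2, 1}}     -- 1x3 index along dimension 1
local out = src:gather(1, idx)              -- out[1][j] = src[idx[1][j]][j]
print(out)                                  -- 1  5  3

Note how the init function above mirrors this: it sizes the result from the index tensor (arg.args[4]) rather than from the source.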

wrap("div",
cname("div"),
{{name=Tensor, default=true, returned=true, method={default='nil'}},
@@ -375,14 +420,6 @@ wrap("addcdiv",
{name=Tensor},
{name=Tensor}})

--wrap("gather",
-- cname("gather"),
-- {{name=Tensor, default=true, returned=true, method={default='nil'}},
-- {name=Tensor},
-- {name="index", default=1},
-- {name=Tensor}
--})

wrap("maskedFill",
cname("maskedFill"),
{{name=Tensor, returned=true, method={default='nil'}},
lib/THCl/THClGather.cl (10 changes: 0 additions & 10 deletions)
@@ -15,10 +15,6 @@ kernel void THClTensor_kernel_gather(
int totalElements
)
{
// global float *dst = dst_data + dst_info.offset;
// global const float *src = src_data + src_info.offset;
// global const float *idx = idx_data + idx_info.offset;

for (int _linearId = get_global_id(0);
_linearId < totalElements;
_linearId += get_global_size(0)) {
@@ -47,12 +43,6 @@ kernel void THClTensor_kernel_gather(
if( d != dim ) { // this only matters for the source, the others are
// unaffected by which dimension we are on. I think.
srcOffset += curDimIndex * src_info->strides[d];
} else {
// do nothing... add it later, once we know the value
}
if( get_global_id(0) == 1 ) {
// dst_data[d] = idx_info->strides[d];
// dst_data[1] += 100;
}
linearId /= idx_info->sizes[d];
}
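
For illustration (not part of the diff): the loop above peels one coordinate per dimension off the flat element id, presumably innermost dimension first, and accumulates strided offsets for every dimension except the gather dimension, whose contribution is added later from the fetched index value. A sketch of the same arithmetic in plain Lua, assuming 1-based sizes/strides tables and a 0-based offset:

-- hypothetical helper mirroring the kernel's index decomposition
local function srcOffsetFor(linearId, sizes, strides, dim)
  local offset = 0
  for d = #sizes, 1, -1 do                  -- innermost dimension first
    local curDimIndex = linearId % sizes[d]
    if d ~= dim then                        -- the gather dim is added later, from the index value
      offset = offset + curDimIndex * strides[d]
    end
    linearId = math.floor(linearId / sizes[d])
  end
  return offset
end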
lib/THCl/THClGather.cpp (60 changes: 37 additions & 23 deletions)
@@ -49,7 +49,22 @@ THCL_API void THClTensor_gather(THClState *state, THClTensor *self, THClTensor *
// index will be ndims too, though one of the dims should have length 1
// self will be ndims
int nDims = src->nDimension;
cout << "nDims " << nDims << endl;
// int dim = lua_dim - 1;
// cout << "nDims " << nDims << "dim " << dim << endl;

// cout << "self dims " << self->nDimension << " nelem=" << THClTensor_nElement(state, self) << endl;
// cout << "src dims " << src->nDimension << " nelem=" << THClTensor_nElement(state, src) << endl;
// cout << "index dims " << index->nDimension << " nelem=" << THClTensor_nElement(state, index) << endl;

// if(self == src) {
// cout << "self == src" << endl;
// } else {
// cout << "self != src" << endl;
// }

// cout << "self " << THClTensor_toString(state, self) << endl;
// cout << "src " << THClTensor_toString(state, src) << endl;
// cout << "index " << THClTensor_toString(state, index) << endl;

THArgCheck(nDims >= 2, 2, "Tensors should have at least 2 dimensions"); // I guess?
// THArgCheck(self->nDimension == nDims, 2, "All tensors should have same number of dims");
@@ -67,17 +82,26 @@ THCL_API void THClTensor_gather(THClState *state, THClTensor *self, THClTensor *
if( i != dim ) {
THArgCheck(THClTensor_size(state, src, i) == THClTensor_size(state, index, i), 3, ("index tensor must have same dimensions as source tensor, but dimension " + easycl::toString(i) + " doesn't match").c_str());
}
cout << "index strides[" << i << "]=" << index->stride[i] << endl;
// cout << "index strides[" << i << "]=" << index->stride[i] << endl;
}

if( self != src ) {
newSize = THLongStorage_newWithSize(index->nDimension);
THLongStorage_rawCopy(newSize, index->size);
// newSize->data[dim] = nIndex;
THClTensor_resize(state, self, newSize, NULL);
THLongStorage_free(newSize);
}

newSize = THLongStorage_newWithSize(index->nDimension);
THLongStorage_rawCopy(newSize, index->size);
// newSize->data[dim] = nIndex;
THClTensor_resize(state, self, newSize, NULL);
THLongStorage_free(newSize);
// cout << "self dims " << self->nDimension << " nelem=" << THClTensor_nElement(state, self) << endl;
// cout << "src dims " << src->nDimension << " nelem=" << THClTensor_nElement(state, src) << endl;
// cout << "index dims " << index->nDimension << " nelem=" << THClTensor_nElement(state, index) << endl;
// cout << "self " << THClTensor_toString(state, self) << endl;
// cout << "src " << THClTensor_toString(state, src) << endl;
// cout << "index " << THClTensor_toString(state, index) << endl;

// This is just here to prove we are actually executing this function :-)
THClTensor_fill(state, self, 0);
// THClTensor_fill(state, self, -99);

// since self is write-only, and index and src are read-only, i.e. none are read-write
// so, we don't need to worry about contiguity (at least, not from the point of view of correctness)
@@ -93,11 +117,11 @@ THCL_API void THClTensor_gather(THClState *state, THClTensor *self, THClTensor *
TensorInfoCl selfInfoCl(self);
TensorInfoCl srcInfoCl(src);
TensorInfoCl indexInfoCl(index);
cout << "indexInfo.dims=" << index->nDimension << endl;
cout << "indexInfo.dims=" << indexInfoCl.dims << endl;
for( int i = 0; i < nDims; i++ ) {
cout << "index strides[" << i << "]=" << indexInfoCl.strides[i] << endl;
}
// cout << "indexInfo.dims=" << index->nDimension << endl;
// cout << "indexInfo.dims=" << indexInfoCl.dims << endl;
// for( int i = 0; i < nDims; i++ ) {
// cout << "index strides[" << i << "]=" << indexInfoCl.strides[i] << endl;
// }

const dim3 block = getApplyBlock(state);

@@ -149,10 +173,6 @@ static std::string getTemplate() {
" int totalElements\n"
")\n"
"{\n"
"// global float *dst = dst_data + dst_info.offset;\n"
"// global const float *src = src_data + src_info.offset;\n"
"// global const float *idx = idx_data + idx_info.offset;\n"
"\n"
" for (int _linearId = get_global_id(0);\n"
" _linearId < totalElements;\n"
" _linearId += get_global_size(0)) {\n"
Expand Down Expand Up @@ -181,12 +201,6 @@ static std::string getTemplate() {
" if( d != dim ) { // this only matters for the source, the others are\n"
" // unaffected by which dimension we are on. I think.\n"
" srcOffset += curDimIndex * src_info->strides[d];\n"
" } else {\n"
" // do nothing... add it later, once we know the value\n"
" }\n"
" if( get_global_id(0) == 1 ) {\n"
"// dst_data[d] = idx_info->strides[d];\n"
"// dst_data[1] += 100;\n"
" }\n"
" linearId /= idx_info->sizes[d];\n"
" }\n"
lib/THCl/THClTensor.cpp (27 changes: 27 additions & 0 deletions)
@@ -4,6 +4,7 @@
#include "THClTensorCopy.h"
#include "THAtomic.h"
//}
#include "util/easycl_stringhelper.h"

#include <iostream>

@@ -847,3 +848,29 @@ int THClTensor_checkGPU(THClState *state, unsigned int nTensors, ...)
// return valid;
//#endif
}

std::string THClTensor_toString(THClState *state, const THClTensor *tensor) {
string res = "";
res += "THClTensor{";
res += "size={";
for( int i = 0; i < tensor->nDimension; i++ ) {
if(i > 0) {
res += ",";
}
res += easycl::toString(tensor->size[i]);
}
res += "},";
res += "stride={";
for( int i = 0; i < tensor->nDimension; i++ ) {
if(i > 0) {
res += ",";
}
res += easycl::toString(tensor->stride[i]);
}
res += "},";
res += "offset=" + easycl::toString(tensor->storageOffset);
res += ",nElem=" + easycl::toString(THClTensor_nElement(state, tensor));
res += "}";
return res;
}
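
As a worked example (not from the diff): for a contiguous 2x3 tensor with no storage offset, this helper would return something like

THClTensor{size={2,3},stride={3,1},offset=0,nElem=6}

which is the kind of one-line summary the commented-out cout tracing in THClGather.cpp above was reaching for.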

lib/THCl/THClTensor.h (5 changes: 5 additions & 0 deletions)
@@ -132,4 +132,9 @@ THCL_API float THClTensor_get4d(THClState *state, const THClTensor *tensor, long
THCL_API int THClTensor_getDevice(THClState *state, const THClTensor *self);
THCL_API int THClTensor_checkGPU(THClState *state, unsigned int nTensors, ...);

// new
#ifdef __cplusplus
THCL_API std::string THClTensor_toString(THClState *state, const THClTensor *tensor);
#endif // __cplusplus

#endif
lib/THCl/THClTensorMathPairwise.cpp (8 changes: 6 additions & 2 deletions)
@@ -1,5 +1,3 @@
#include <string>

#include "THClTensorMath.h"
#include "THClGeneral.h"
//#include "THCBlas.h"
@@ -8,6 +6,9 @@
#include "THClApply.h"
//#include "THCReduce.cuh"

#include <iostream>
#include <string>

using namespace std;

#ifndef DIVUP
@@ -75,14 +76,17 @@ void THClTensor_mul(THClState *state, THClTensor *self_, THClTensor *src_, float value)

void THClTensor_div(THClState* state, THClTensor *self_, THClTensor *src_, float value)
{
cout << "THClTensorMathPairwise.THClTensor_div(state, self_, src_, float value)" << endl;
THAssert(THClTensor_checkGPU(state, 2, self_, src_));
THArgCheck(value != 0.0f, 3, "divide by zero");

if (self_ == src_) {
cout << "self_ == src_" << endl;
if (!THClTensor_pointwiseApply1(state, self_, TensorMulConstantOp(1.0f / value))) {
THArgCheck(false, 2, CLTORCH_DIM_WARNING);
}
} else {
cout << "self_ != src_, resizing self_ as src_" << endl;
THClTensor_resizeAs(state, self_, src_);

if (!THClTensor_pointwiseApply2(state, self_, src_, TensorMulConstantOp(1.0f / value))) {
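
As the hunk shows, division is implemented as multiplication by the reciprocal, with an in-place path when self_ == src_ and a resize-then-apply path otherwise. A minimal usage sketch (not part of the diff, assuming ClTensor follows stock torch div semantics):

-- illustrative only; assumes torch-style div dispatch for cl tensors
local a = torch.range(1, 4):cl()
a:div(2)                   -- in-place path: self_ == src_, single-tensor apply
local b = torch.div(a, 2)  -- out-of-place path: b resized as a, then scaled by 1/2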
test/test-tensor.lua (2 changes: 1 addition & 1 deletion)
@@ -703,9 +703,9 @@ if os.getenv('PROTOTYPING') ~= nil then
print('res', res)
-- print('gather cl', torch.gather(acl, 1, idxcl))
print('a:gather(1, idx)', a:gather(1, idx))
print('torch.gather(1, idxcl)', torch.gather(acl, 1, idxcl))
print('acl:gather(1, idxcl)', acl:gather(1, idxcl))


-- x = torch.range(1,12):double():resize(3,4):cl()
-- print('x', x)
-- mask = torch.ByteTensor(2,6):bernoulli():cl()
torch/generic/Tensor.c (2 changes: 1 addition & 1 deletion)
@@ -1273,7 +1273,7 @@ static const struct luaL_Reg torch_Tensor_(_) [] = {
{"clone", torch_Tensor_(clone)},
{"contiguous", torch_Tensor_(contiguous)},
{"resizeAs", torch_Tensor_(resizeAs)},
{"gather", torch_Tensor_(gather)},
// {"gather", torch_Tensor_(gather)},
{"resize", torch_Tensor_(resize)},
{"narrow", torch_Tensor_(narrow)},
{"sub", torch_Tensor_(sub)},
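
With the gather entry commented out of this luaL_Reg table, the generic Tensor.c binding no longer claims the method; presumably the TensorMath.lua wrapper added above supplies it instead, which matches the commit title.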
