Skip to content

Commit

Permalink
RVV: ShuffleChannel, fp32, pack1
Browse files Browse the repository at this point in the history
  • Loading branch information
thelastlin committed Aug 20, 2023
1 parent 034a301 commit 326189f
Showing 1 changed file with 81 additions and 9 deletions.
90 changes: 81 additions & 9 deletions src/layer/riscv/shufflechannel_riscv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,41 @@ int ShuffleChannel_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const O
// group too large or shuffle inside elempack
if (_group > elempack || (_group > 8 && elempack > 8) || channels % _group != 0)
{
Option opt_pack = opt;
opt_pack.blob_allocator = opt.workspace_allocator;

// convert to pack1
Mat bottom_blob_unpacked;
convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);

convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt);
// convert packing won't change w,h
int channels_unpacked = bottom_blob_unpacked.c;
size_t elemsize_unpacked = bottom_blob_unpacked.elemsize;
int _group_unpack = reverse ? channels_unpacked / group : group;
if (channels_unpacked % group != 0)
{
// reject invalid group
return -100;
}
Mat top_blob_unpacked;
int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
if (ret != 0)
return ret;
top_blob_unpacked.create(w, h, channels_unpacked, elemsize_unpacked, opt.blob_allocator);

int channels_unpacked_per_group = channels_unpacked / _group_unpack;
const size_t feature_sz = (size_t)w * h;
for (int i = 0; i < _group_unpack; i++)
{
for (int j = 0; j < channels_unpacked_per_group; j++)
{
float* p_dst = top_blob_unpacked.channel(_group_unpack * j + i);
const float* p_src = bottom_blob_unpacked.channel(channels_unpacked_per_group * i + j);
int n = feature_sz;
while (n > 0)
{
size_t vl = vsetvl_e32m8(n);
vfloat32m8_t _src = vle32_v_f32m8(p_src, vl);
vse32_v_f32m8(p_dst, _src, vl);
n -= vl;
p_src += vl;
p_dst += vl;
}
}
}
convert_packing(top_blob_unpacked, top_blob, elempack, opt);
return 0;
}
Expand Down Expand Up @@ -475,8 +500,55 @@ int ShuffleChannel_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const O

return 0;
}
#endif

#if __riscv_vector
if (elempack == 1)
{
#endif
if (channels % group != 0)
{
// reject invalid group
return -100;
}

top_blob.create(w, h, channels, elemsize, opt.blob_allocator);
if (top_blob.empty())
return -100;

#if __riscv_vector
const size_t feature_sz = (size_t)w * h;
#else
const size_t feature_sz = (size_t)w * h * elemsize;
#endif
for (int i = 0; i < _group; i++)
{
for (int j = 0; j < channels_per_group; j++)
{
#if __riscv_vector
float* p_dst = top_blob.channel(_group * j + i);
const float* p_src = bottom_blob.channel(channels_per_group * i + j);
int n = feature_sz;
while(n>0)
{
size_t vl = vsetvl_e32m8(n);
vfloat32m8_t _src = vle32_v_f32m8(p_src, vl);
vse32_v_f32m8(p_dst, _src, vl);
n -= vl;
p_src += vl;
p_dst += vl;
}
#else
int src_q = channels_per_group * i + j;
int dst_q = _group * j + i;
memcpy(top_blob.channel(dst_q), bottom_blob.channel(src_q), feature_sz);
#endif // __riscv_vector
return ShuffleChannel::forward(bottom_blob, top_blob, opt);
}
}
#if __riscv_vector
}
#endif // __riscv_vector

return 0;
}
} // namespace ncnn

0 comments on commit 326189f

Please sign in to comment.