Skip to content

Commit

Permalink
Merge pull request #439 from genn-team/custom_update_merge_fix
Browse files Browse the repository at this point in the history
Nasty bugs in custom update kernels
  • Loading branch information
neworderofjamie authored Jul 22, 2021
2 parents 54ec28f + ea2f548 commit 2dfd51f
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 33 deletions.
92 changes: 66 additions & 26 deletions src/genn/backends/cuda/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,32 @@ class Timer
const bool m_TimingEnabled;
const bool m_SynchroniseOnStop;
};
//-----------------------------------------------------------------------
template<typename T, typename G>
void genGroupStartID(CodeStream &os, size_t &idStart, size_t &totalConstMem,
                     const T &m, G getPaddedNumThreads)
{
    // Bytes needed for one starting thread index per group in this merged group
    const size_t requiredBytes = sizeof(unsigned int) * m.getGroups().size();

    // Place the array in constant memory if it fits within what remains,
    // otherwise fall back to plain global device memory
    // NOTE(review): the check uses '<' rather than '<=', so an exact fit
    // goes to global memory - matches the sibling generator's convention
    const bool useConstMem = (requiredBytes < totalConstMem);
    if(useConstMem) {
        totalConstMem -= requiredBytes;
        os << "__device__ __constant__ ";
    }
    else {
        os << "__device__ ";
    }

    // Emit array of starting thread indices, one entry per group,
    // advancing idStart by each group's padded thread count
    os << "unsigned int d_merged" << T::name << "GroupStartID" << m.getIndex() << "[] = {";
    for(const auto &g : m.getGroups()) {
        os << idStart << ", ";
        idStart += getPaddedNumThreads(g.get());
    }
    os << "};" << std::endl;
}
//-----------------------------------------------------------------------
void genGroupStartIDs(CodeStream &, size_t&, size_t&)
{
Expand All @@ -94,26 +118,7 @@ void genGroupStartIDs(CodeStream &os, size_t &idStart, size_t &totalConstMem,
{
// Loop through merged groups
for(const auto &m : mergedGroups) {
// Calculate size of array
const size_t sizeBytes = m.getGroups().size() * sizeof(unsigned int);

// If there is enough constant memory left for group, declare it in constant memory space
if(sizeBytes < totalConstMem) {
os << "__device__ __constant__ ";
totalConstMem -= sizeBytes;
}
// Otherwise, declare it in global memory space
else {
os << "__device__ ";
}

// Declare array of starting thread indices for each neuron group
os << "unsigned int d_merged" << T::name << "GroupStartID" << m.getIndex() << "[] = {";
for(const auto &ng : m.getGroups()) {
os << idStart << ", ";
idStart += getPaddedNumThreads(ng.get());
}
os << "};" << std::endl;
genGroupStartID(os, idStart, totalConstMem, m, getPaddedNumThreads);
}

// Generate any remaining groups
Expand All @@ -128,7 +133,35 @@ void genMergedKernelDataStructures(CodeStream &os, size_t &totalConstMem,
size_t idStart = 0;
genGroupStartIDs(os, std::ref(idStart), std::ref(totalConstMem), args...);
}
//-----------------------------------------------------------------------
// Base case of the variadic recursion below - reached once every
// (merged groups, padding functor, filter) argument triple has been
// consumed; intentionally does nothing
void genFilteredGroupStartIDs(CodeStream &, size_t&, size_t&)
{
}
//-----------------------------------------------------------------------
template<typename T, typename G, typename F, typename ...Args>
void genFilteredGroupStartIDs(CodeStream &os, size_t &idStart, size_t &totalConstMem,
                              const std::vector<T> &mergedGroups, G getPaddedNumThreads, F filter,
                              Args... args)
{
    // Emit a start-ID array for each merged group the filter accepts;
    // rejected groups contribute no threads, so idStart is untouched for them
    for(const auto &mergedGroup : mergedGroups) {
        if(!filter(mergedGroup)) {
            continue;
        }
        genGroupStartID(os, idStart, totalConstMem, mergedGroup, getPaddedNumThreads);
    }

    // Recurse to process any remaining (groups, padding, filter) triples
    genFilteredGroupStartIDs(os, idStart, totalConstMem, args...);
}
//-----------------------------------------------------------------------
template<typename ...Args>
void genFilteredMergedKernelDataStructures(CodeStream &os, size_t &totalConstMem,
                                           Args... args)
{
    // Thread indices are assigned consecutively from zero; the counter is
    // wrapped in std::ref so the variadic recursion mutates this single value
    size_t idStart = 0;
    genFilteredGroupStartIDs(os, std::ref(idStart), std::ref(totalConstMem), args...);
}
//-----------------------------------------------------------------------
template<typename T, typename G>
size_t getNumMergedGroupThreads(const std::vector<T> &groups, G getNumThreads)
Expand Down Expand Up @@ -593,12 +626,6 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
getGroupStartIDSize(modelMerged.getMergedSynapseDynamicsGroups()) +
getGroupStartIDSize(modelMerged.getMergedNeuronUpdateGroups()));
size_t totalConstMem = (getChosenDeviceSafeConstMemBytes() > timestepGroupStartIDSize) ? (getChosenDeviceSafeConstMemBytes() - timestepGroupStartIDSize) : 0;
genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateGroups(),
[this](const CustomUpdateInternal &cg){ return padKernelSize(cg.getSize(), KernelCustomUpdate); });
genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); });
genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateTransposeWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); });

// Build set containing union of all custom update group names
std::set<std::string> customUpdateGroups;
Expand All @@ -618,6 +645,15 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
|| std::any_of(modelMerged.getMergedCustomUpdateWUGroups().cbegin(), modelMerged.getMergedCustomUpdateWUGroups().cend(),
[&g](const CustomUpdateWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g); }))
{
genFilteredMergedKernelDataStructures(os, totalConstMem,
modelMerged.getMergedCustomUpdateGroups(),
[this](const CustomUpdateInternal &cg){ return padKernelSize(cg.getSize(), KernelCustomUpdate); },
[g](const CustomUpdateGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; },

modelMerged.getMergedCustomUpdateWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); },
[g](const CustomUpdateWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; });

os << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(" << model.getTimePrecision() << " t)" << std::endl;
{
CodeStream::Scope b(os);
Expand All @@ -641,6 +677,10 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
if(std::any_of(modelMerged.getMergedCustomUpdateTransposeWUGroups().cbegin(), modelMerged.getMergedCustomUpdateTransposeWUGroups().cend(),
[&g](const CustomUpdateTransposeWUGroupMerged &c){ return (c.getArchetype().getUpdateGroupName() == g); }))
{
genFilteredMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateTransposeWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); },
[g](const CustomUpdateTransposeWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; });

os << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(" << model.getTimePrecision() << " t)" << std::endl;
{
CodeStream::Scope b(os);
Expand Down
58 changes: 51 additions & 7 deletions src/genn/backends/opencl/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,40 @@ void genMergedKernelDataStructures(CodeStream &os, Args... args)
genGroupStartIDs(os, std::ref(idStart), args...);
}
//-----------------------------------------------------------------------
// Base case of the variadic recursion below - reached once every
// (merged groups, padding functor, filter) argument triple has been
// consumed; intentionally does nothing
void genFilteredGroupStartIDs(CodeStream &, size_t &)
{
}
//-----------------------------------------------------------------------
template<typename T, typename G, typename F, typename ...Args>
void genFilteredGroupStartIDs(CodeStream &os, size_t &idStart,
                              const std::vector<T> &mergedGroups, G getPaddedNumThreads, F filter,
                              Args... args)
{
    // Emit a start-ID array for each merged group the filter accepts;
    // rejected groups contribute no threads, so idStart is untouched for them
    for(const auto &mergedGroup : mergedGroups) {
        if(!filter(mergedGroup)) {
            continue;
        }

        // Declare __constant array holding, for every group within this
        // merged group, its starting thread index in the kernel
        os << "__constant unsigned int d_merged" << T::name << "GroupStartID" << mergedGroup.getIndex() << "[] = {";
        for(const auto &group : mergedGroup.getGroups()) {
            os << idStart << ", ";
            idStart += getPaddedNumThreads(group.get());
        }
        os << "};" << std::endl;
    }

    // Recurse to process any remaining (groups, padding, filter) triples
    genFilteredGroupStartIDs(os, idStart, args...);
}
//-----------------------------------------------------------------------
template<typename ...Args>
void genFilteredMergedKernelDataStructures(CodeStream &os, Args... args)
{
    // Thread indices are assigned consecutively from zero; the counter is
    // wrapped in std::ref so the variadic recursion mutates this single value
    size_t idStart = 0;
    genFilteredGroupStartIDs(os, std::ref(idStart), args...);
}
//-----------------------------------------------------------------------
void genReadEventTiming(CodeStream &os, const std::string &name)
{
os << "const cl_ulong tmpStart = " << name << "Event.getProfilingInfo<CL_PROFILING_COMMAND_START>();" << std::endl;
Expand Down Expand Up @@ -845,13 +879,6 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged


// Generate data structure for accessing merged groups from within custom update kernels
genMergedKernelDataStructures(
customUpdateKernels,
modelMerged.getMergedCustomUpdateGroups(), [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelCustomUpdate); },
modelMerged.getMergedCustomUpdateWUGroups(), [&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); });
genMergedKernelDataStructures(
customUpdateKernels,
modelMerged.getMergedCustomUpdateTransposeWUGroups(), [&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); });
customUpdateKernels << std::endl;

// Generate kernels used to populate merged structs
Expand All @@ -867,6 +894,17 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
|| std::any_of(modelMerged.getMergedCustomUpdateWUGroups().cbegin(), modelMerged.getMergedCustomUpdateWUGroups().cend(),
[&g](const CustomUpdateWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); }))
{
genFilteredMergedKernelDataStructures(
customUpdateKernels,
modelMerged.getMergedCustomUpdateGroups(),
[this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelCustomUpdate); },
[g](const CustomUpdateGroupMerged &c){ return (c.getArchetype().getUpdateGroupName() == g.first); },

modelMerged.getMergedCustomUpdateWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); },
[&g](const CustomUpdateWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); });


customUpdateKernels << "__attribute__((reqd_work_group_size(" << getKernelBlockSize(KernelCustomUpdate) << ", 1, 1)))" << std::endl;
customUpdateKernels << "__kernel void " << KernelNames[KernelCustomUpdate] << g.first << "(";
genMergedGroupKernelParams(customUpdateKernels, modelMerged.getMergedCustomUpdateGroups(), true);
Expand Down Expand Up @@ -897,6 +935,12 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
if(std::any_of(modelMerged.getMergedCustomUpdateTransposeWUGroups().cbegin(), modelMerged.getMergedCustomUpdateTransposeWUGroups().cend(),
[&g](const CustomUpdateTransposeWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); }))
{
genFilteredMergedKernelDataStructures(
customUpdateKernels,
modelMerged.getMergedCustomUpdateTransposeWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); },
[&g](const CustomUpdateTransposeWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); });

customUpdateKernels << "__attribute__((reqd_work_group_size(" << getKernelBlockSize(KernelCustomUpdate) << ", 8, 1)))" << std::endl;
customUpdateKernels << "__kernel void " << KernelNames[KernelCustomTransposeUpdate] << g.first << "(";
genMergedGroupKernelParams(customUpdateKernels, modelMerged.getMergedCustomUpdateTransposeWUGroups(), true);
Expand Down

0 comments on commit 2dfd51f

Please sign in to comment.