Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nasty bugs in custom update kernels #439

Merged
merged 4 commits into from
Jul 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 66 additions & 26 deletions src/genn/backends/cuda/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,32 @@ class Timer
const bool m_TimingEnabled;
const bool m_SynchroniseOnStop;
};
//-----------------------------------------------------------------------
// Emit the definition of the start-ID lookup array for a single merged group,
// placing it in __constant__ memory while the remaining budget allows and
// falling back to plain global device memory otherwise.
template<typename T, typename G>
void genGroupStartID(CodeStream &os, size_t &idStart, size_t &totalConstMem,
                     const T &m, G getPaddedNumThreads)
{
    // Bytes needed to hold one unsigned int start ID per child group
    const size_t requiredBytes = sizeof(unsigned int) * m.getGroups().size();

    // Prefer the constant memory space while budget remains
    // NOTE(review): the strict '<' means an array exactly filling the budget
    // goes to global memory - presumably intentional headroom; confirm
    if(requiredBytes < totalConstMem) {
        totalConstMem -= requiredBytes;
        os << "__device__ __constant__ ";
    }
    else {
        os << "__device__ ";
    }

    // Write the array of starting thread indices, advancing idStart by each
    // child group's padded thread count as we go
    os << "unsigned int d_merged" << T::name << "GroupStartID" << m.getIndex() << "[] = {";
    for(const auto &childGroup : m.getGroups()) {
        os << idStart << ", ";
        idStart += getPaddedNumThreads(childGroup.get());
    }
    os << "};" << std::endl;
}
//-----------------------------------------------------------------------
void genGroupStartIDs(CodeStream &, size_t&, size_t&)
{
Expand All @@ -94,26 +118,7 @@ void genGroupStartIDs(CodeStream &os, size_t &idStart, size_t &totalConstMem,
{
// Loop through merged groups
for(const auto &m : mergedGroups) {
// Calculate size of array
const size_t sizeBytes = m.getGroups().size() * sizeof(unsigned int);

// If there is enough constant memory left for group, declare it in constant memory space
if(sizeBytes < totalConstMem) {
os << "__device__ __constant__ ";
totalConstMem -= sizeBytes;
}
// Otherwise, declare it in global memory space
else {
os << "__device__ ";
}

// Declare array of starting thread indices for each neuron group
os << "unsigned int d_merged" << T::name << "GroupStartID" << m.getIndex() << "[] = {";
for(const auto &ng : m.getGroups()) {
os << idStart << ", ";
idStart += getPaddedNumThreads(ng.get());
}
os << "};" << std::endl;
genGroupStartID(os, idStart, totalConstMem, m, getPaddedNumThreads);
}

// Generate any remaining groups
Expand All @@ -128,7 +133,35 @@ void genMergedKernelDataStructures(CodeStream &os, size_t &totalConstMem,
size_t idStart = 0;
genGroupStartIDs(os, std::ref(idStart), std::ref(totalConstMem), args...);
}
//-----------------------------------------------------------------------
// Terminating overload of the variadic recursion below - reached once every
// (mergedGroups, getPaddedNumThreads, filter) triple has been consumed
void genFilteredGroupStartIDs(CodeStream &, size_t&, size_t&)
{
}
//-----------------------------------------------------------------------
// Generate start-ID arrays for every merged group accepted by 'filter',
// then recurse to process any remaining (groups, threads, filter) triples.
template<typename T, typename G, typename F, typename ...Args>
void genFilteredGroupStartIDs(CodeStream &os, size_t &idStart, size_t &totalConstMem,
                              const std::vector<T> &mergedGroups, G getPaddedNumThreads, F filter,
                              Args... args)
{
    // Emit one start-ID array per merged group that passes the filter;
    // rejected groups are skipped entirely
    for(const auto &mergedGroup : mergedGroups) {
        if(!filter(mergedGroup)) {
            continue;
        }
        genGroupStartID(os, idStart, totalConstMem, mergedGroup, getPaddedNumThreads);
    }

    // Process whatever triples remain in the parameter pack
    genFilteredGroupStartIDs(os, idStart, totalConstMem, args...);
}
//-----------------------------------------------------------------------
// Entry point: generate filtered group start-ID arrays for an arbitrary
// sequence of (mergedGroups, getPaddedNumThreads, filter) triples.
template<typename ...Args>
void genFilteredMergedKernelDataStructures(CodeStream &os, size_t &totalConstMem,
                                           Args... args)
{
    // Thread IDs are assigned contiguously across all filtered groups,
    // starting from zero; the reference is threaded through the recursion
    // (std::ref is unnecessary here - the callee takes size_t& directly)
    size_t idStart = 0;
    genFilteredGroupStartIDs(os, idStart, totalConstMem, args...);
}
//-----------------------------------------------------------------------
template<typename T, typename G>
size_t getNumMergedGroupThreads(const std::vector<T> &groups, G getNumThreads)
Expand Down Expand Up @@ -593,12 +626,6 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
getGroupStartIDSize(modelMerged.getMergedSynapseDynamicsGroups()) +
getGroupStartIDSize(modelMerged.getMergedNeuronUpdateGroups()));
size_t totalConstMem = (getChosenDeviceSafeConstMemBytes() > timestepGroupStartIDSize) ? (getChosenDeviceSafeConstMemBytes() - timestepGroupStartIDSize) : 0;
genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateGroups(),
[this](const CustomUpdateInternal &cg){ return padKernelSize(cg.getSize(), KernelCustomUpdate); });
genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); });
genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateTransposeWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); });

// Build set containing union of all custom update group names
std::set<std::string> customUpdateGroups;
Expand All @@ -618,6 +645,15 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
|| std::any_of(modelMerged.getMergedCustomUpdateWUGroups().cbegin(), modelMerged.getMergedCustomUpdateWUGroups().cend(),
[&g](const CustomUpdateWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g); }))
{
genFilteredMergedKernelDataStructures(os, totalConstMem,
modelMerged.getMergedCustomUpdateGroups(),
[this](const CustomUpdateInternal &cg){ return padKernelSize(cg.getSize(), KernelCustomUpdate); },
[g](const CustomUpdateGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; },

modelMerged.getMergedCustomUpdateWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); },
[g](const CustomUpdateWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; });

os << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(" << model.getTimePrecision() << " t)" << std::endl;
{
CodeStream::Scope b(os);
Expand All @@ -641,6 +677,10 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
if(std::any_of(modelMerged.getMergedCustomUpdateTransposeWUGroups().cbegin(), modelMerged.getMergedCustomUpdateTransposeWUGroups().cend(),
[&g](const CustomUpdateTransposeWUGroupMerged &c){ return (c.getArchetype().getUpdateGroupName() == g); }))
{
genFilteredMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateTransposeWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); },
[g](const CustomUpdateTransposeWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; });

os << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(" << model.getTimePrecision() << " t)" << std::endl;
{
CodeStream::Scope b(os);
Expand Down
58 changes: 51 additions & 7 deletions src/genn/backends/opencl/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,40 @@ void genMergedKernelDataStructures(CodeStream &os, Args... args)
genGroupStartIDs(os, std::ref(idStart), args...);
}
//-----------------------------------------------------------------------
// Terminating overload of the variadic recursion below - reached once every
// (mergedGroups, getPaddedNumThreads, filter) triple has been consumed
void genFilteredGroupStartIDs(CodeStream &, size_t &)
{
}
//-----------------------------------------------------------------------
// Generate OpenCL __constant start-ID arrays for every merged group accepted
// by 'filter', then recurse over any remaining triples.
template<typename T, typename G, typename F, typename ...Args>
void genFilteredGroupStartIDs(CodeStream &os, size_t &idStart,
                              const std::vector<T> &mergedGroups, G getPaddedNumThreads, F filter,
                              Args... args)
{
    // Emit one start-ID array per merged group that passes the filter
    for(const auto &mergedGroup : mergedGroups) {
        if(!filter(mergedGroup)) {
            continue;
        }

        // Declare __constant array of starting thread indices, advancing
        // idStart by each child group's padded thread count
        os << "__constant unsigned int d_merged" << T::name << "GroupStartID" << mergedGroup.getIndex() << "[] = {";
        for(const auto &childGroup : mergedGroup.getGroups()) {
            os << idStart << ", ";
            idStart += getPaddedNumThreads(childGroup.get());
        }
        os << "};" << std::endl;
    }

    // Process whatever triples remain in the parameter pack
    genFilteredGroupStartIDs(os, idStart, args...);
}
//-----------------------------------------------------------------------
// Entry point: generate filtered group start-ID arrays for an arbitrary
// sequence of (mergedGroups, getPaddedNumThreads, filter) triples.
template<typename ...Args>
void genFilteredMergedKernelDataStructures(CodeStream &os, Args... args)
{
    // Thread IDs are assigned contiguously across all filtered groups,
    // starting from zero; the reference is threaded through the recursion
    // (std::ref is unnecessary here - the callee takes size_t& directly)
    size_t idStart = 0;
    genFilteredGroupStartIDs(os, idStart, args...);
}
//-----------------------------------------------------------------------
void genReadEventTiming(CodeStream &os, const std::string &name)
{
os << "const cl_ulong tmpStart = " << name << "Event.getProfilingInfo<CL_PROFILING_COMMAND_START>();" << std::endl;
Expand Down Expand Up @@ -845,13 +879,6 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged


// Generate data structure for accessing merged groups from within custom update kernels
genMergedKernelDataStructures(
customUpdateKernels,
modelMerged.getMergedCustomUpdateGroups(), [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelCustomUpdate); },
modelMerged.getMergedCustomUpdateWUGroups(), [&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); });
genMergedKernelDataStructures(
customUpdateKernels,
modelMerged.getMergedCustomUpdateTransposeWUGroups(), [&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); });
customUpdateKernels << std::endl;

// Generate kernels used to populate merged structs
Expand All @@ -867,6 +894,17 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
|| std::any_of(modelMerged.getMergedCustomUpdateWUGroups().cbegin(), modelMerged.getMergedCustomUpdateWUGroups().cend(),
[&g](const CustomUpdateWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); }))
{
genFilteredMergedKernelDataStructures(
customUpdateKernels,
modelMerged.getMergedCustomUpdateGroups(),
[this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelCustomUpdate); },
[g](const CustomUpdateGroupMerged &c){ return (c.getArchetype().getUpdateGroupName() == g.first); },

modelMerged.getMergedCustomUpdateWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); },
[&g](const CustomUpdateWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); });


customUpdateKernels << "__attribute__((reqd_work_group_size(" << getKernelBlockSize(KernelCustomUpdate) << ", 1, 1)))" << std::endl;
customUpdateKernels << "__kernel void " << KernelNames[KernelCustomUpdate] << g.first << "(";
genMergedGroupKernelParams(customUpdateKernels, modelMerged.getMergedCustomUpdateGroups(), true);
Expand Down Expand Up @@ -897,6 +935,12 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged
if(std::any_of(modelMerged.getMergedCustomUpdateTransposeWUGroups().cbegin(), modelMerged.getMergedCustomUpdateTransposeWUGroups().cend(),
[&g](const CustomUpdateTransposeWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); }))
{
genFilteredMergedKernelDataStructures(
customUpdateKernels,
modelMerged.getMergedCustomUpdateTransposeWUGroups(),
[&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); },
[&g](const CustomUpdateTransposeWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); });

customUpdateKernels << "__attribute__((reqd_work_group_size(" << getKernelBlockSize(KernelCustomUpdate) << ", 8, 1)))" << std::endl;
customUpdateKernels << "__kernel void " << KernelNames[KernelCustomTransposeUpdate] << g.first << "(";
genMergedGroupKernelParams(customUpdateKernels, modelMerged.getMergedCustomUpdateTransposeWUGroups(), true);
Expand Down