diff --git a/src/genn/backends/cuda/backend.cc b/src/genn/backends/cuda/backend.cc index 61bf832809..065dbfb6ef 100644 --- a/src/genn/backends/cuda/backend.cc +++ b/src/genn/backends/cuda/backend.cc @@ -80,8 +80,32 @@ class Timer const bool m_TimingEnabled; const bool m_SynchroniseOnStop; }; +//----------------------------------------------------------------------- +template +void genGroupStartID(CodeStream &os, size_t &idStart, size_t &totalConstMem, + const T &m, G getPaddedNumThreads) +{ + // Calculate size of array + const size_t sizeBytes = m.getGroups().size() * sizeof(unsigned int); + // If there is enough constant memory left for group, declare it in constant memory space + if(sizeBytes < totalConstMem) { + os << "__device__ __constant__ "; + totalConstMem -= sizeBytes; + } + // Otherwise, declare it in global memory space + else { + os << "__device__ "; + } + // Declare array of starting thread indices for each neuron group + os << "unsigned int d_merged" << T::name << "GroupStartID" << m.getIndex() << "[] = {"; + for(const auto &ng : m.getGroups()) { + os << idStart << ", "; + idStart += getPaddedNumThreads(ng.get()); + } + os << "};" << std::endl; +} //----------------------------------------------------------------------- void genGroupStartIDs(CodeStream &, size_t&, size_t&) { @@ -94,26 +118,7 @@ void genGroupStartIDs(CodeStream &os, size_t &idStart, size_t &totalConstMem, { // Loop through merged groups for(const auto &m : mergedGroups) { - // Calculate size of array - const size_t sizeBytes = m.getGroups().size() * sizeof(unsigned int); - - // If there is enough constant memory left for group, declare it in constant memory space - if(sizeBytes < totalConstMem) { - os << "__device__ __constant__ "; - totalConstMem -= sizeBytes; - } - // Otherwise, declare it in global memory space - else { - os << "__device__ "; - } - - // Declare array of starting thread indices for each neuron group - os << "unsigned int d_merged" << T::name << "GroupStartID" << m.getIndex() << "[] = {"; - for(const auto &ng : m.getGroups()) { - os << idStart << ", "; - idStart += getPaddedNumThreads(ng.get()); - } - os << "};" << std::endl; + genGroupStartID(os, idStart, totalConstMem, m, getPaddedNumThreads); } // Generate any remaining groups @@ -128,7 +133,35 @@ void genMergedKernelDataStructures(CodeStream &os, size_t &totalConstMem, size_t idStart = 0; genGroupStartIDs(os, std::ref(idStart), std::ref(totalConstMem), args...); } +//----------------------------------------------------------------------- +void genFilteredGroupStartIDs(CodeStream &, size_t&, size_t&) +{ +} +//----------------------------------------------------------------------- +template +void genFilteredGroupStartIDs(CodeStream &os, size_t &idStart, size_t &totalConstMem, + const std::vector &mergedGroups, G getPaddedNumThreads, F filter, + Args... args) +{ + // Loop through merged groups + for(const auto &m : mergedGroups) { + if(filter(m)) { + genGroupStartID(os, idStart, totalConstMem, m, getPaddedNumThreads); + } + } + // Generate any remaining groups + genFilteredGroupStartIDs(os, idStart, totalConstMem, args...); +} +//----------------------------------------------------------------------- +template +void genFilteredMergedKernelDataStructures(CodeStream &os, size_t &totalConstMem, + Args... args) +{ + // Generate group start id arrays + size_t idStart = 0; + genFilteredGroupStartIDs(os, std::ref(idStart), std::ref(totalConstMem), args...); +} //----------------------------------------------------------------------- template size_t getNumMergedGroupThreads(const std::vector &groups, G getNumThreads) @@ -593,12 +626,6 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged getGroupStartIDSize(modelMerged.getMergedSynapseDynamicsGroups()) + getGroupStartIDSize(modelMerged.getMergedNeuronUpdateGroups())); size_t totalConstMem = (getChosenDeviceSafeConstMemBytes() > timestepGroupStartIDSize) ? (getChosenDeviceSafeConstMemBytes() - timestepGroupStartIDSize) : 0; - genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateGroups(), - [this](const CustomUpdateInternal &cg){ return padKernelSize(cg.getSize(), KernelCustomUpdate); }); - genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateWUGroups(), - [&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); }); - genMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateTransposeWUGroups(), - [&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); }); // Build set containing union of all custom update groupsnames std::set customUpdateGroups; @@ -618,6 +645,15 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged || std::any_of(modelMerged.getMergedCustomUpdateWUGroups().cbegin(), modelMerged.getMergedCustomUpdateWUGroups().cend(), [&g](const CustomUpdateWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g); })) { + genFilteredMergedKernelDataStructures(os, totalConstMem, + modelMerged.getMergedCustomUpdateGroups(), + [this](const CustomUpdateInternal &cg){ return padKernelSize(cg.getSize(), KernelCustomUpdate); }, + [g](const CustomUpdateGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }, + + modelMerged.getMergedCustomUpdateWUGroups(), + [&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); }, + [g](const CustomUpdateWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }); + os << "extern \"C\" __global__ void " << KernelNames[KernelCustomUpdate] << g << "(" << model.getTimePrecision() << " t)" << std::endl; { CodeStream::Scope b(os); @@ -641,6 +677,10 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged if(std::any_of(modelMerged.getMergedCustomUpdateTransposeWUGroups().cbegin(), modelMerged.getMergedCustomUpdateTransposeWUGroups().cend(), [&g](const CustomUpdateTransposeWUGroupMerged &c){ return (c.getArchetype().getUpdateGroupName() == g); })) { + genFilteredMergedKernelDataStructures(os, totalConstMem, modelMerged.getMergedCustomUpdateTransposeWUGroups(), + [&model, this](const CustomUpdateWUInternal &cg){ return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); }, + [g](const CustomUpdateTransposeWUGroupMerged &cg){ return cg.getArchetype().getUpdateGroupName() == g; }); + os << "extern \"C\" __global__ void " << KernelNames[KernelCustomTransposeUpdate] << g << "(" << model.getTimePrecision() << " t)" << std::endl; { CodeStream::Scope b(os); diff --git a/src/genn/backends/opencl/backend.cc b/src/genn/backends/opencl/backend.cc index 46dfc14e6b..e88b3f35d9 100644 --- a/src/genn/backends/opencl/backend.cc +++ b/src/genn/backends/opencl/backend.cc @@ -103,6 +103,40 @@ void genMergedKernelDataStructures(CodeStream &os, Args... args) genGroupStartIDs(os, std::ref(idStart), args...); } //----------------------------------------------------------------------- +void genFilteredGroupStartIDs(CodeStream &, size_t &) +{ +} +//----------------------------------------------------------------------- +template +void genFilteredGroupStartIDs(CodeStream &os, size_t &idStart, + const std::vector &mergedGroups, G getPaddedNumThreads, F filter, + Args... args) +{ + // Loop through merged groups + for(const auto &m : mergedGroups) { + if(filter(m)) { + // Declare array of starting thread indices for each neuron group + os << "__constant unsigned int d_merged" << T::name << "GroupStartID" << m.getIndex() << "[] = {"; + for(const auto &ng : m.getGroups()) { + os << idStart << ", "; + idStart += getPaddedNumThreads(ng.get()); + } + os << "};" << std::endl; + } + } + + // Generate any remaining groups + genFilteredGroupStartIDs(os, idStart, args...); +} +//----------------------------------------------------------------------- +template +void genFilteredMergedKernelDataStructures(CodeStream &os, Args... args) +{ + // Generate group start id arrays + size_t idStart = 0; + genFilteredGroupStartIDs(os, std::ref(idStart), args...); +} +//----------------------------------------------------------------------- void genReadEventTiming(CodeStream &os, const std::string &name) { os << "const cl_ulong tmpStart = " << name << "Event.getProfilingInfo();" << std::endl; @@ -845,13 +879,6 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged // Generate data structure for accessing merged groups from within custom update kernels - genMergedKernelDataStructures( - customUpdateKernels, - modelMerged.getMergedCustomUpdateGroups(), [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelCustomUpdate); }, - modelMerged.getMergedCustomUpdateWUGroups(), [&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); }); - genMergedKernelDataStructures( - customUpdateKernels, - modelMerged.getMergedCustomUpdateTransposeWUGroups(), [&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); }); customUpdateKernels << std::endl; // Generate kernels used to populate merged structs @@ -867,6 +894,17 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged || std::any_of(modelMerged.getMergedCustomUpdateWUGroups().cbegin(), modelMerged.getMergedCustomUpdateWUGroups().cend(), [&g](const CustomUpdateWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); })) { + genFilteredMergedKernelDataStructures( + customUpdateKernels, + modelMerged.getMergedCustomUpdateGroups(), + [this](const CustomUpdateInternal &cg) { return padKernelSize(cg.getSize(), KernelCustomUpdate); }, + [g](const CustomUpdateGroupMerged &c){ return (c.getArchetype().getUpdateGroupName() == g.first); }, + + modelMerged.getMergedCustomUpdateWUGroups(), + [&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateWUThreads(cg, model.getBatchSize()); }, + [&g](const CustomUpdateWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); }); + + customUpdateKernels << "__attribute__((reqd_work_group_size(" << getKernelBlockSize(KernelCustomUpdate) << ", 1, 1)))" << std::endl; customUpdateKernels << "__kernel void " << KernelNames[KernelCustomUpdate] << g.first << "("; genMergedGroupKernelParams(customUpdateKernels, modelMerged.getMergedCustomUpdateGroups(), true); @@ -897,6 +935,12 @@ void Backend::genCustomUpdate(CodeStream &os, const ModelSpecMerged &modelMerged if(std::any_of(modelMerged.getMergedCustomUpdateTransposeWUGroups().cbegin(), modelMerged.getMergedCustomUpdateTransposeWUGroups().cend(), [&g](const CustomUpdateTransposeWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); })) { + genFilteredMergedKernelDataStructures( + customUpdateKernels, + modelMerged.getMergedCustomUpdateTransposeWUGroups(), + [&model, this](const CustomUpdateWUInternal &cg) { return getPaddedNumCustomUpdateTransposeWUThreads(cg, model.getBatchSize()); }, + [&g](const CustomUpdateTransposeWUGroupMerged &c) { return (c.getArchetype().getUpdateGroupName() == g.first); }); + customUpdateKernels << "__attribute__((reqd_work_group_size(" << getKernelBlockSize(KernelCustomUpdate) << ", 8, 1)))" << std::endl; customUpdateKernels << "__kernel void " << KernelNames[KernelCustomTransposeUpdate] << g.first << "("; genMergedGroupKernelParams(customUpdateKernels, modelMerged.getMergedCustomUpdateTransposeWUGroups(), true);