Skip to content

Commit

Permalink
Support PROTECT via _NMODLMUTEX{,UN}LOCK. (#1557)
Browse files Browse the repository at this point in the history
* Fix thread_variable conditions.

* Support PROTECT via `_NMODLMUTEX{,UN}LOCK`.

* Add tests.
  • Loading branch information
1uc authored Nov 14, 2024
1 parent 16b5f95 commit 96299cb
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/codegen/codegen_helper_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ void CodegenHelperVisitor::find_non_range_variables() {

// if model is thread safe and if parameter is being written then
// those variables should be promoted to thread safe variable
if (info.vectorize && info.thread_safe && var->get_write_count() > 0) {
if (info.vectorize && info.declared_thread_safe && var->get_write_count() > 0) {
var->mark_thread_safe();
info.thread_variables.push_back(var);
info.thread_var_data_size += var->get_length();
Expand Down
9 changes: 9 additions & 0 deletions src/codegen/codegen_neuron_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,7 @@ void CodegenNeuronCppVisitor::print_neuron_includes() {
printer->add_multi_line(R"CODE(
#include "mech_api.h"
#include "neuron/cache/mechanism_range.hpp"
#include "nmodlmutex.h"
#include "nrniv_mf.h"
#include "section_fwd.hpp"
)CODE");
Expand Down Expand Up @@ -3063,6 +3064,14 @@ void CodegenNeuronCppVisitor::visit_for_netcon(const ast::ForNetcon& node) {
printer->pop_block();
}

void CodegenNeuronCppVisitor::visit_protect_statement(const ast::ProtectStatement& node) {
printer->add_line("_NMODLMUTEXLOCK");
printer->add_indent();
node.get_expression()->accept(*this);
printer->add_text(";");
printer->add_line("_NMODLMUTEXUNLOCK");
}


} // namespace codegen
} // namespace nmodl
1 change: 1 addition & 0 deletions src/codegen/codegen_neuron_cpp_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor {
void visit_for_netcon(const ast::ForNetcon& node) override;
void visit_longitudinal_diffusion_block(const ast::LongitudinalDiffusionBlock& node) override;
void visit_lon_diffuse(const ast::LonDiffuse& node) override;
void visit_protect_statement(const ast::ProtectStatement& node) override;

public:
/****************************************************************************************/
Expand Down
1 change: 1 addition & 0 deletions test/usecases/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ set(NMODL_USECASE_DIRS
point_process
pointer
procedure
protect
random
solve
state
Expand Down
17 changes: 17 additions & 0 deletions test/usecases/protect/shared_counter.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
NEURON {
SUFFIX shared_counter
GLOBAL g_cnt
}

ASSIGNED {
g_cnt
}


INITIAL {
PROTECT g_cnt = 0
}

BREAKPOINT {
PROTECT g_cnt = g_cnt + 1
}
30 changes: 30 additions & 0 deletions test/usecases/protect/test_shared_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from neuron import h, gui


def test_shared_counter():
nthreads = 32
nseg = 10
nsteps = 100

sections = [h.Section() for _ in range(nthreads)]
for s in sections:
s.insert("shared_counter")
s.nseg = nseg

pc = h.ParallelContext()

pc.nthread(nthreads)
for k, s in enumerate(sections):
pc.partition(k, h.SectionList([s]))

h.finitialize()
for _ in range(nsteps):
h.step()

expected = nthreads * nseg * nsteps
g_cnt = h.g_cnt_shared_counter
assert h.g_cnt_shared_counter == expected, f"{g_cnt}"


if __name__ == "__main__":
test_shared_counter()

0 comments on commit 96299cb

Please sign in to comment.