diff --git a/src/codegen/codegen_helper_visitor.cpp b/src/codegen/codegen_helper_visitor.cpp index f2b5eb7ff..4067de9b7 100644 --- a/src/codegen/codegen_helper_visitor.cpp +++ b/src/codegen/codegen_helper_visitor.cpp @@ -290,7 +290,7 @@ void CodegenHelperVisitor::find_non_range_variables() { // if model is thread safe and if parameter is being written then // those variables should be promoted to thread safe variable - if (info.vectorize && info.thread_safe && var->get_write_count() > 0) { + if (info.vectorize && info.declared_thread_safe && var->get_write_count() > 0) { var->mark_thread_safe(); info.thread_variables.push_back(var); info.thread_var_data_size += var->get_length(); diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index 7e8420cce..ebcf222c5 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -1103,6 +1103,7 @@ void CodegenNeuronCppVisitor::print_neuron_includes() { printer->add_multi_line(R"CODE( #include "mech_api.h" #include "neuron/cache/mechanism_range.hpp" + #include "nmodlmutex.h" #include "nrniv_mf.h" #include "section_fwd.hpp" )CODE"); @@ -3063,6 +3064,14 @@ void CodegenNeuronCppVisitor::visit_for_netcon(const ast::ForNetcon& node) { printer->pop_block(); } +void CodegenNeuronCppVisitor::visit_protect_statement(const ast::ProtectStatement& node) { + printer->add_line("_NMODLMUTEXLOCK"); + printer->add_indent(); + node.get_expression()->accept(*this); + printer->add_text(";"); + printer->add_line("_NMODLMUTEXUNLOCK"); +} + } // namespace codegen } // namespace nmodl diff --git a/src/codegen/codegen_neuron_cpp_visitor.hpp b/src/codegen/codegen_neuron_cpp_visitor.hpp index 9b36b1c22..4e0ebb106 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.hpp +++ b/src/codegen/codegen_neuron_cpp_visitor.hpp @@ -839,6 +839,7 @@ class CodegenNeuronCppVisitor: public CodegenCppVisitor { void visit_for_netcon(const ast::ForNetcon& node) override; void visit_longitudinal_diffusion_block(const ast::LongitudinalDiffusionBlock& node) override; void visit_lon_diffuse(const ast::LonDiffuse& node) override; + void visit_protect_statement(const ast::ProtectStatement& node) override; public: /****************************************************************************************/ diff --git a/test/usecases/CMakeLists.txt b/test/usecases/CMakeLists.txt index f776c578d..73aa0ced3 100644 --- a/test/usecases/CMakeLists.txt +++ b/test/usecases/CMakeLists.txt @@ -25,6 +25,7 @@ set(NMODL_USECASE_DIRS point_process pointer procedure + protect random solve state diff --git a/test/usecases/protect/shared_counter.mod b/test/usecases/protect/shared_counter.mod new file mode 100644 index 000000000..c431c7ce3 --- /dev/null +++ b/test/usecases/protect/shared_counter.mod @@ -0,0 +1,17 @@ +NEURON { + SUFFIX shared_counter + GLOBAL g_cnt +} + +ASSIGNED { + g_cnt +} + + +INITIAL { + PROTECT g_cnt = 0 +} + +BREAKPOINT { + PROTECT g_cnt = g_cnt + 1 +} diff --git a/test/usecases/protect/test_shared_counter.py b/test/usecases/protect/test_shared_counter.py new file mode 100644 index 000000000..edc52ae6d --- /dev/null +++ b/test/usecases/protect/test_shared_counter.py @@ -0,0 +1,30 @@ +from neuron import h, gui + + +def test_shared_counter(): + nthreads = 32 + nseg = 10 + nsteps = 100 + + sections = [h.Section() for _ in range(nthreads)] + for s in sections: + s.insert("shared_counter") + s.nseg = nseg + + pc = h.ParallelContext() + + pc.nthread(nthreads) + for k, s in enumerate(sections): + pc.partition(k, h.SectionList([s])) + + h.finitialize() + for _ in range(nsteps): + h.step() + + expected = nthreads * nseg * nsteps + g_cnt = h.g_cnt_shared_counter + assert h.g_cnt_shared_counter == expected, f"{g_cnt}" + + +if __name__ == "__main__": + test_shared_counter()