diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 57f9f2220289..073cfa4f14f6 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -233,7 +233,11 @@ void Imperative::RecordOp( nnvm::ObjectPtr node = nnvm::Node::Create(); node->attrs = std::move(attrs); - node->attrs.name = "node_" + std::to_string(node_count_++); + if (node->attrs.name == "") { + node->attrs.name = "node_" + std::to_string(node_count_++); + } else { + node_count_++; + } AGInfo& info = AGInfo::Create(node); info.state = state; info.ctx = outputs[0]->ctx(); @@ -322,7 +326,11 @@ void Imperative::RecordDeferredCompute(nnvm::NodeAttrs &&attrs, } node->attrs = std::move(attrs); // Need to support NameManager in imperative API to better name node->attrs.name - node->attrs.name = "node_" + std::to_string(node_count_++); + if (node->attrs.name == "") { + node->attrs.name = "node_" + std::to_string(node_count_++); + } else { + node_count_++; + } for (uint32_t i = 0; i < outputs.size(); ++i) { outputs[i]->deferredcompute_entry_ = nnvm::NodeEntry{node, i, 0}; diff --git a/tests/python/gpu/test_profiler_gpu.py b/tests/python/gpu/test_profiler_gpu.py index 929615446d2a..82510ddcc8bb 100644 --- a/tests/python/gpu/test_profiler_gpu.py +++ b/tests/python/gpu/test_profiler_gpu.py @@ -39,7 +39,7 @@ def test_gpu_memory_profiler_symbolic(): with profiler.scope("tensordot"): A = mx.sym.Variable('A') B = mx.sym.Variable('B') - C = mx.symbol.dot(A, B) + C = mx.symbol.dot(A, B, name="dot") executor = C._simple_bind(mx.gpu(), 'write', A=(1024, 2048), B=(2048, 4096)) @@ -62,7 +62,7 @@ def test_gpu_memory_profiler_symbolic(): 'Requested Size' : str(4 * a.size)}, {'Attribute Name' : 'tensordot:in_arg:B', 'Requested Size' : str(4 * b.size)}, - {'Attribute Name' : 'tensordot:dot0', + {'Attribute Name' : 'tensordot:dot', 'Requested Size' : str(4 * c.size)}, {'Attribute Name' : 'init:_random_uniform', 'Requested Size' : str(4 * a.size)}, @@ -77,7 +77,7 @@ def test_gpu_memory_profiler_symbolic(): # resource:temp_space (sample_op.h +365),8,0,4096,0 # symbol:arg_grad:unknown,8388608,0,8388608,0 # symbol:arg_grad:unknown,33554432,0,33554432,0 - # tensordot:dot0,16777216,0,16777216,0 + # tensordot:dot,16777216,0,16777216,0 # tensordot:in_arg:A,8388608,0,8388608,0 # tensordot:in_arg:B,33554432,0,33554432,0 # tensordot:node_0_backward,33554432,0,33554432,0