diff --git a/example/sparse/matrix_fact_parallel_model.py b/example/sparse/matrix_fact_parallel_model.py
index 350ecc0a4b64..0b349a694aa4 100644
--- a/example/sparse/matrix_fact_parallel_model.py
+++ b/example/sparse/matrix_fact_parallel_model.py
@@ -32,9 +32,12 @@ def matrix_fact_model_parallel_net(factor_size, num_hidden, max_user, max_item):
         item_weight = mx.symbol.Variable('item_weight', stype='row_sparse')
         item = mx.symbol.contrib.SparseEmbedding(data=item, weight=item_weight,
                                                  input_dim=max_item, output_dim=factor_size)
-    # set ctx_group attribute to 'dev2' for the symbols created in this scope,
-    # the symbols will be bound to the context that 'dev2' map to in group2ctxs
+    # set ctx_group attribute to 'dev2' for the symbols created in this scope,
+    # the symbols will be bound to the context that 'dev2' map to in group2ctxs
     with mx.AttrScope(ctx_group='dev2'):
+        weight = mx.symbol.Variable('ufcweight')
+        bias = mx.symbol.Variable('ufcbias')
+        user = mx.symbol.FullyConnected(data=user, weight=weight, bias=bias, num_hidden=num_hidden)
         # predict by the inner product, which is elementwise product and then sum
         pred = user * item
         pred = mx.symbol.sum(data=pred, axis=1)
diff --git a/example/sparse/matrix_factorization_model_parallel.py b/example/sparse/matrix_factorization_model_parallel.py
index 525ce881dbdd..c622d67d4ccf 100644
--- a/example/sparse/matrix_factorization_model_parallel.py
+++ b/example/sparse/matrix_factorization_model_parallel.py
@@ -83,7 +83,7 @@
 
     # initialize the module
     # map the ctx_group attribute to the context assignment
-    group2ctxs={'dev1':mx.cpu(), 'dev2':[mx.gpu(i) for i in range(num_gpus)]}
+    group2ctxs={'dev1':mx.cpu(), 'dev2':[mx.cpu(i) for i in range(num_gpus)]}
     mod = mx.module.Module(symbol=net, context=[mx.cpu()]*num_gpus, data_names=['user', 'item'],
                            label_names=['score'], group2ctxs=group2ctxs)
     mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index dd4867559d5a..09096a1134eb 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -362,6 +362,7 @@ Graph AssignContext(Graph g,
 
   // loop through all the rest of input nodes not specified
   // in the ctx_map and populate maps and lists
+  LOG(INFO) << "args context";
   size_t arg_top = 0, aux_top = 0;
   for (size_t i = 0; i < num_forward_inputs; ++i) {
     const uint32_t nid = idx.input_nodes().at(i);
@@ -380,10 +381,22 @@ Graph AssignContext(Graph g,
       ctx_list.push_back(ctx);  // save the current ctx in the list
     }
     device[nid] = ctx2id.at(ctx);  // assign device id to the current node
+    LOG(INFO) << "nid: " << nid << " ctx.dev_id " << ctx.dev_id;
   }
-
+  LOG(INFO) << "=====================";
+  LOG(INFO) << num_forward_outputs << " num_forward_outputs";
+  LOG(INFO) << g.outputs.size() << " g.outputs.size()";
+  LOG(INFO) << arg_grad_ctxes.size() << " arg_grad_ctxes.size()";
+
   // loop through backward input nodes and populate maps and lists
   // the backward input nodes is the gradient of the loss wrt the output
+  LOG(INFO) << "arg grads contexts";
+  for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i){
+    const uint32_t nid = idx.outputs()[i].node_id;
+    Context ctx = arg_grad_ctxes[i - num_forward_outputs];
+    LOG(INFO) << "nid " << nid << " ctx " << ctx.dev_id;
+  }
+  LOG(INFO) << "=====================";
   for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i) {
     const uint32_t nid = idx.outputs()[i].node_id;
     Context ctx = arg_grad_ctxes[i - num_forward_outputs];
@@ -393,7 +406,34 @@ Graph AssignContext(Graph g,
     }
     int devid = ctx2id.at(ctx);
     if (device[nid] != -1) {
-      CHECK_EQ(device[nid], devid) << "device of same output not equal to each other";
+      LOG(INFO) << "fail nid " << nid << " ctx " << ctx.dev_id;
+      const nnvm::IndexedGraph::Node fail_node = idx[nid];
+      // print the graph structure
+      const auto& ret = g;
+      const auto &idx = ret.indexed_graph();
+      uint32_t node_start = 0, node_end = idx.num_nodes();
+      if (ret.attrs.count("node_range")) {
+        const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
+        node_start = range.first;
+        node_end = range.second;
+      }
+      for (uint32_t nid = node_start; nid < node_end; ++nid) {
+        const auto& inode = idx[nid];
+        if (inode.source->is_variable()) {
+          LOG(INFO) << "node " << nid << " var " << inode.source->attrs.name;
+        } else {
+          LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name;
+          for (const auto& e : inode.inputs) {
+            auto eid = idx.entry_id(e);
+            LOG(INFO) << "\t\tinput " << eid << " (entry id)";
+          }
+          for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
+            uint32_t eid = idx.entry_id(nid, index);
+            LOG(INFO) << "\t\toutput " << eid << " (entry id)";
+          }
+        }
+      }  // end of the print
+      CHECK_EQ(device[nid], devid) << fail_node.source->attrs.name << " device of same output not equal to each other";
     } else {
       device[nid] = devid;
     }
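
For reference (not part of the patch above), here is a minimal, self-contained sketch of the ctx_group / group2ctxs mechanism whose node placement the instrumented AssignContext pass logs. It assumes the MXNet 1.x Module API; the symbol names, shapes, and the all-CPU group mapping are illustrative choices, not taken from the example scripts.

    import mxnet as mx

    # symbols created inside an AttrScope carry the ctx_group attribute
    with mx.AttrScope(ctx_group='dev1'):
        data = mx.symbol.Variable('data')
        fc1 = mx.symbol.FullyConnected(data=data, num_hidden=16, name='fc1')
    with mx.AttrScope(ctx_group='dev2'):
        fc2 = mx.symbol.FullyConnected(data=fc1, num_hidden=2, name='fc2')
        net = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')

    # group2ctxs maps each group name to the context(s) its symbols are bound to;
    # AssignContext performs this placement when the module is bound.
    # Both groups map to CPU here only so the sketch runs anywhere.
    group2ctxs = {'dev1': mx.cpu(0), 'dev2': mx.cpu(0)}
    mod = mx.module.Module(symbol=net, context=mx.cpu(), data_names=['data'],
                           label_names=['softmax_label'], group2ctxs=group2ctxs)
    mod.bind(data_shapes=[('data', (4, 8))], label_shapes=[('softmax_label', (4,))])
    mod.init_params()

    batch = mx.io.DataBatch(data=[mx.nd.ones((4, 8))], label=[mx.nd.zeros((4,))])
    mod.forward(batch)
    print(mod.get_outputs()[0].shape)  # (4, 2)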