diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/Node2vecAlgo.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/Node2vecAlgo.scala index 98e777a..36356c4 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/Node2vecAlgo.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/Node2vecAlgo.scala @@ -155,11 +155,14 @@ object Node2vecAlgo { graph = Graph(indexedNodes, indexedEdges) .mapVertices[NodeAttr] { case (vertexId, clickNode) => - val (j, q) = this.setupAlias(clickNode.neighbors) - val nextNodeIndex = this.drawAlias(j, q) - clickNode.path = Array(vertexId, clickNode.neighbors(nextNodeIndex)._1) - - clickNode + if (clickNode != null) { + val (j, q) = this.setupAlias(clickNode.neighbors) + val nextNodeIndex = this.drawAlias(j, q) + clickNode.path = Array(vertexId, clickNode.neighbors(nextNodeIndex)._1) + clickNode + } else { + NodeAttr() + } } .mapTriplets { edgeTriplet: EdgeTriplet[NodeAttr, EdgeAttr] => val (j, q) = this.setupEdgeAlias(bcP.value, bcQ.value)(edgeTriplet.srcId, @@ -210,11 +213,14 @@ object Node2vecAlgo { .map { case (edge, ((srcNodeId, pathBuffer), attr)) => try { - val nextNodeIndex = this.drawAlias(attr.J, attr.q) - val nextNodeId = attr.dstNeighbors(nextNodeIndex) - pathBuffer.append(nextNodeId) - - (srcNodeId, pathBuffer) + if (pathBuffer != null && pathBuffer.nonEmpty && attr.dstNeighbors != null && attr.dstNeighbors.nonEmpty) { + val nextNodeIndex = this.drawAlias(attr.J, attr.q) + val nextNodeId = attr.dstNeighbors(nextNodeIndex) + pathBuffer.append(nextNodeId) + (srcNodeId, pathBuffer) + } else { + (srcNodeId, pathBuffer) + } } catch { case e: Exception => throw new RuntimeException(e.getMessage) }