From 9c9c26c3cb8aeb6090f08dcbcfc310309c034b31 Mon Sep 17 00:00:00 2001 From: Jermy Li Date: Wed, 6 May 2020 15:15:53 +0800 Subject: [PATCH] louvain: add modularity parameter and fix isolated community lost (#14) * add modularity parameter for louvain * fix: louvain lost isolated community from one to next pass Change-Id: I6a7dadc80635429aa2898939aa337aae01bc8d12 --- .../job/algorithm/AbstractAlgorithm.java | 3 +- .../job/algorithm/comm/LouvainAlgorithm.java | 20 +- .../job/algorithm/comm/LouvainTraverser.java | 187 ++++++++++++------ 3 files changed, 145 insertions(+), 65 deletions(-) diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AbstractAlgorithm.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AbstractAlgorithm.java index 248a92bdb1..969bda1d8d 100644 --- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AbstractAlgorithm.java +++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AbstractAlgorithm.java @@ -59,7 +59,7 @@ public abstract class AbstractAlgorithm implements Algorithm { public static final long MAX_RESULT_SIZE = 100L * Bytes.MB; - public static final long MAX_QUERY_LIMIT = 10000000L; // about 10GB + public static final long MAX_QUERY_LIMIT = 100000000L; // about 100GB public static final int BATCH = 500; public static final String CATEGORY_AGGR = "aggregate"; @@ -81,6 +81,7 @@ public abstract class AbstractAlgorithm implements Algorithm { public static final String KEY_TIMES = "times"; public static final String KEY_STABLE_TIMES = "stable_times"; public static final String KEY_PRECISION = "precision"; + public static final String KEY_SHOW_MOD= "show_modularity"; public static final String KEY_SHOW_COMM = "show_community"; public static final String KEY_CLEAR = "clear"; public static final String KEY_CAPACITY = "capacity"; diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainAlgorithm.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainAlgorithm.java index 3f6de63e8c..c0c05f9a22 100644 --- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainAlgorithm.java +++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainAlgorithm.java @@ -22,7 +22,7 @@ import java.util.Map; import com.baidu.hugegraph.job.Job; -import com.baidu.hugegraph.util.E; +import com.baidu.hugegraph.traversal.algorithm.HugeTraverser; public class LouvainAlgorithm extends AbstractCommAlgorithm { @@ -39,6 +39,7 @@ public void checkParameters(Map parameters) { degree(parameters); sourceLabel(parameters); sourceCLabel(parameters); + showModularity(parameters); showCommunity(parameters); clearPass(parameters); } @@ -52,10 +53,13 @@ public Object call(Job job, Map parameters) { LouvainTraverser traverser = new LouvainTraverser(job, degree, label, clabel); Long clearPass = clearPass(parameters); + Long modPass = showModularity(parameters); String showComm = showCommunity(parameters); try { if (clearPass != null) { return traverser.clearPass(clearPass.intValue()); + } else if (modPass != null) { + return traverser.modularity(modPass.intValue()); } else if (showComm != null) { return traverser.showCommunity(showComm); } else { @@ -74,10 +78,16 @@ protected static Long clearPass(Map parameters) { return null; } long pass = parameterLong(parameters, KEY_CLEAR); - // TODO: change to checkNonNegative() - E.checkArgument(pass >= 0 || pass == -1, - "The %s parameter must be >= 0 or == -1, but got %s", - KEY_CLEAR, pass); + HugeTraverser.checkNonNegativeOrNoLimit(pass, KEY_CLEAR); + return pass; + } + + protected static Long showModularity(Map parameters) { + if (!parameters.containsKey(KEY_SHOW_MOD)) { + return null; + } + long pass = parameterLong(parameters, KEY_SHOW_MOD); + HugeTraverser.checkNonNegative(pass, KEY_SHOW_MOD); return pass; } } diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainTraverser.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainTraverser.java index 0177d8f2d7..a63a1259dc 100644 --- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainTraverser.java +++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/comm/LouvainTraverser.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.NoSuchElementException; import java.util.Set; import org.apache.commons.lang3.mutable.MutableInt; @@ -52,6 +53,7 @@ import com.baidu.hugegraph.structure.HugeEdge; import com.baidu.hugegraph.structure.HugeVertex; import com.baidu.hugegraph.type.define.Directions; +import com.baidu.hugegraph.util.InsertionOrderUtil; import com.baidu.hugegraph.util.Log; import com.google.common.collect.ImmutableMap; @@ -89,23 +91,6 @@ public LouvainTraverser(Job job, long degree, this.cache = new Cache(); } - @SuppressWarnings("unused") - private Id genId2(int pass, Id cid) { - // gen id for merge-community vertex - String id = cid.toString(); - if (pass == 0) { - // conncat pass with cid - id = pass + "~" + id; - } else { - // replace last pass with current pass - String lastPass = String.valueOf(pass - 1); - assert id.startsWith(lastPass); - id = id.substring(lastPass.length()); - id = pass + id; - } - return IdGenerator.of(id); - } - private void defineSchemaOfPk() { String label = this.labelOfPassN(0); if (this.graph().existsVertexLabel(label) || @@ -131,8 +116,7 @@ private void defineSchemaOfPassN(int pass) { SchemaManager schema = this.graph().schema(); try { schema.vertexLabel(this.passLabel).useCustomizeStringId() - .properties(C_KIN, C_MEMBERS) - .nullableKeys(C_KIN, C_MEMBERS) + .properties(C_KIN, C_MEMBERS, C_WEIGHT) .create(); schema.edgeLabel(this.passLabel) .sourceLabel(this.passLabel) @@ -189,9 +173,16 @@ private float weightOfEdges(List edges) { return weight; } - private Vertex newCommunityNode(Id cid, int kin, List members) { + private Vertex newCommunityNode(Id cid, float cweight, + int kin, List members) { assert !members.isEmpty() : members; - return this.graph().addVertex(T.label, this.passLabel, T.id, cid, + /* + * cweight: members size(all pass) of the community, just for debug + * kin: edges weight in the community + * members: members id of the community of last pass + */ + return this.graph().addVertex(T.label, this.passLabel, + T.id, cid, C_WEIGHT, cweight, C_KIN, kin, C_MEMBERS, members); } @@ -204,12 +195,12 @@ private Edge newCommunityEdge(Vertex source, Vertex target, float weight) { return source.addEdge(this.passLabel, target, C_WEIGHT, weight); } - private void insertNewCommunity(int pass, Id cid, int kin, - List members, + private void insertNewCommunity(int pass, Id cid, float cweight, + int kin, List members, Map cedges) { // create backend vertex if it's the first time Id vid = this.cache.genId(pass, cid); - Vertex node = this.newCommunityNode(vid, kin, members); + Vertex node = this.newCommunityNode(vid, cweight, kin, members); commitIfNeeded(); // update backend vertex edges for (Map.Entry e : cedges.entrySet()) { @@ -262,6 +253,7 @@ private List neighbors(Id vid) { } private float weightOfVertex(Vertex v, List edges) { + // degree/weight of vertex Float value = this.cache.vertexWeight((Id) v.id()); if (value != null) { return value; @@ -281,9 +273,21 @@ private int kinOfVertex(Vertex v) { return 0; } - private Id cidOfVertex(Vertex v) { + private float cweightOfVertex(Vertex v) { + if (v.label().startsWith(C_PASS) && v.property(C_WEIGHT).isPresent()) { + return v.value(C_WEIGHT); + } + return 1f; + } + + private Id cidOfVertex(Vertex v, List nbs) { Id vid = (Id) v.id(); Community c = this.cache.vertex2Community(vid); + // ensure source vertex exist in cache + if (c == null) { + c = this.wrapCommunity(v, nbs); + assert c != null; + } return c != null ? c.cid : vid; } @@ -292,15 +296,15 @@ private Id cidOfVertex(Vertex v) { // and save as community vertex when merge() // 3: wrap community vertex as community node, // and repeat step 2 and step 3. - private Community wrapCommunity(Vertex otherV) { - Id vid = (Id) otherV.id(); + private Community wrapCommunity(Vertex v, List nbs) { + Id vid = (Id) v.id(); Community comm = this.cache.vertex2Community(vid); if (comm != null) { return comm; } comm = new Community(vid); - comm.add(this, otherV, null); // will traverse the neighbors of otherV + comm.add(this, v, nbs); this.cache.vertex2Community(vid, comm); return comm; } @@ -316,7 +320,8 @@ private Collection> nbCommunities( // skip the old intermediate data, or filter clabel continue; } - Community c = wrapCommunity(otherV); + // will traverse the neighbors of otherV + Community c = this.wrapCommunity(otherV, null); if (!comms.containsKey(c.cid)) { comms.put(c.cid, Pair.of(c, new MutableInt(0))); } @@ -359,8 +364,8 @@ private double moveCommunities(int pass) { continue; } total++; - Id cid = cidOfVertex(v); List nbs = neighbors((Id) v.id()); + Id cid = cidOfVertex(v, nbs); double ki = kinOfVertex(v) + weightOfVertex(v, nbs); // update community of v if △Q changed double maxDeltaQ = 0d; @@ -377,13 +382,13 @@ private double moveCommunities(int pass) { // weight between c and otherC double kiin = nbc.getRight().floatValue(); // weight of otherC - int tot = otherC.kin() + otherC.kout(); + double tot = otherC.kin() + otherC.kout(); if (cid.equals(otherC.cid)) { tot -= ki; - assert tot >= 0; + assert tot >= 0d; // expect tot >= 0, but may be something wrong? - if (tot < 0) { - tot = 0; + if (tot < 0d) { + tot = 0d; } } double deltaQ = kiin - ki * tot / this.m; @@ -407,6 +412,7 @@ private double moveCommunities(int pass) { private void mergeCommunities(int pass) { // merge each community as a vertex Collection>> comms = this.cache.communities(); + assert this.allMembersExist(comms, pass -1); this.cache.resetVertexWeight(); for (Pair> pair : comms) { Community c = pair.getKey(); @@ -417,6 +423,7 @@ private void mergeCommunities(int pass) { int kin = c.kin(); Set vertices = pair.getRight(); assert !vertices.isEmpty(); + assert vertices.size() == c.size(); List members = new ArrayList<>(vertices.size()); Map cedges = new HashMap<>(vertices.size()); for (Id v : vertices) { @@ -432,7 +439,8 @@ private void mergeCommunities(int pass) { kin += weightOfEdge(edge); continue; } - Id otherCid = cidOfVertex(otherV); + assert this.cache.vertex2Community(otherV.id()) != null; + Id otherCid = cidOfVertex(otherV, null); if (otherCid.compareTo(c.cid) < 0) { // skip if it should be collected by otherC continue; @@ -440,17 +448,33 @@ private void mergeCommunities(int pass) { if (!cedges.containsKey(otherCid)) { cedges.put(otherCid, new MutableInt(0)); } + // update edge weight cedges.get(otherCid).add(weightOfEdge(edge)); } } // insert new community vertex and edges into storage - this.insertNewCommunity(pass, c.cid, kin, members, cedges); + this.insertNewCommunity(pass, c.cid, c.weight(), kin, members, cedges); } this.graph().tx().commit(); // reset communities this.cache.reset(); } + private boolean allMembersExist(Collection>> comms, + int pass) { + String lastLabel = labelOfPassN(pass); + GraphTraversal t = pass < 0 ? this.g.V().id() : + this.g.V().hasLabel(lastLabel).id(); + Set all = this.execute(t, t::toSet); + for (Pair> comm : comms) { + all.removeAll(comm.getRight()); + } + if (all.size() > 0) { + LOG.warn("Lost members of last pass: {}", all); + } + return all.isEmpty(); + } + public Object louvain(int maxTimes, int stableTimes, double precision) { assert maxTimes > 0; assert precision > 0d; @@ -496,31 +520,40 @@ public Object louvain(int maxTimes, int stableTimes, double precision) { } } - long communities = 0L; + Map results = InsertionOrderUtil.newMap(); + results.putAll(ImmutableMap.of("pass_times", times, + "phase1_times", movedTimes, + "last_precision", movedPercent, + "times", maxTimes)); + Number communities = 0L; + Number modularity = -1L; String commLabel = this.passLabel; if (!commLabel.isEmpty()) { - GraphTraversal t = this.g.V().hasLabel(commLabel).count(); - communities = this.execute(t, t::next); + communities = tryNext(this.g.V().hasLabel(commLabel).count()); + modularity = this.modularity(commLabel); } - return ImmutableMap.of("pass_times", times, - "phase1_times", movedTimes, - "last_precision", movedPercent, - "times", maxTimes, - "communities", communities); + results.putAll(ImmutableMap.of("communities", communities, + "modularity", modularity)); + return results; } public double modularity(int pass) { - // pass: label the last pass + // community vertex label of one pass String label = labelOfPassN(pass); - Number kin = this.g.V().hasLabel(label).values(C_KIN).sum().next(); - Number weight = this.g.E().hasLabel(label).values(C_WEIGHT).sum().next(); + return this.modularity(label); + } + + private double modularity(String label) { + // label: community vertex label of one pass + Number kin = tryNext(this.g.V().hasLabel(label).values(C_KIN).sum()); + Number weight = tryNext(this.g.E().hasLabel(label).values(C_WEIGHT).sum()); double m = kin.intValue() + weight.floatValue() * 2.0d; double q = 0.0d; - Iterator coms = this.g.V().hasLabel(label); - while (coms.hasNext()) { - Vertex com = coms.next(); - int cin = com.value(C_KIN); - Number cout = this.g.V(com).bothE().values(C_WEIGHT).sum().next(); + Iterator comms = this.vertices(label, LIMIT); + while (comms.hasNext()) { + Vertex comm = comms.next(); + int cin = comm.value(C_KIN); + Number cout = tryNext(this.g.V(comm).bothE().values(C_WEIGHT).sum()); double cdegree = cin + cout.floatValue(); // Q = ∑(I/M - ((2I+O)/2M)^2) q += cin / m - Math.pow(cdegree / m, 2); @@ -528,6 +561,16 @@ public double modularity(int pass) { return q; } + private Number tryNext(GraphTraversal iter) { + return this.execute(iter, () -> { + try { + return iter.next(); + } catch (NoSuchElementException e) { + return 0; + } + }); + } + public Collection showCommunity(String community) { final String C_PASS0 = labelOfPassN(0); Collection comms = Arrays.asList(community); @@ -604,8 +647,10 @@ private static class Community { // community id (stored as a backend vertex) private final Id cid; - // community members size + // community members size of last pass [just for skip large community] private int size = 0; + // community members size of origin vertex [just for debug members lost] + private float weight = 0f; /* * weight of all edges in community(2X), sum of kin of new members * [each is from the last pass, stored in backend vertex] @@ -615,8 +660,7 @@ private static class Community { * weight of all edges between communities, sum of kout of new members * [each is last pass, calculated in real time by neighbors] */ - // - private int kout = 0; + private float kout = 0f; public Community(Id cid) { this.cid = cid; @@ -630,14 +674,20 @@ public int size() { return this.size; } + public float weight() { + return this.weight; + } + public void add(LouvainTraverser t, Vertex v, List nbs) { this.size++; + this.weight += t.cweightOfVertex(v); this.kin += t.kinOfVertex(v); this.kout += t.weightOfVertex(v, nbs); } public void remove(LouvainTraverser t, Vertex v, List nbs) { this.size--; + this.weight -= t.cweightOfVertex(v); this.kin -= t.kinOfVertex(v); this.kout -= t.weightOfVertex(v, nbs); } @@ -646,14 +696,15 @@ public int kin() { return this.kin; } - public int kout() { + public float kout() { return this.kout; } @Override public String toString() { - return String.format("[%s](size=%s kin=%s kout=%s)", - this.cid , this.size, this.kin, this.kout); + return String.format("[%s](size=%s weight=%s kin=%s kout=%s)", + this.cid , this.size, this.weight, + this.kin, this.kout); } } @@ -669,7 +720,8 @@ public Cache() { this.genIds = new HashMap<>(); } - public Community vertex2Community(Id id) { + public Community vertex2Community(Object id) { + assert id instanceof Id; return this.vertex2Community.get(id); } @@ -703,6 +755,23 @@ public Id genId(int pass, Id cid) { return IdGenerator.of(id); } + @SuppressWarnings("unused") + public Id genId2(int pass, Id cid) { + // gen id for merge-community vertex + String id = cid.toString(); + if (pass == 0) { + // conncat pass with cid + id = pass + "~" + id; + } else { + // replace last pass with current pass + String lastPass = String.valueOf(pass - 1); + assert id.startsWith(lastPass); + id = id.substring(lastPass.length()); + id = pass + id; + } + return IdGenerator.of(id); + } + public Collection>> communities(){ // TODO: get communities from backend store instead of ram Map>> comms = new HashMap<>();