diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AlgorithmPool.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AlgorithmPool.java index 0299cf501d..9ded512ef4 100644 --- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AlgorithmPool.java +++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/AlgorithmPool.java @@ -22,11 +22,11 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import com.baidu.hugegraph.job.algorithm.cent.BetweennessCentralityAlgorithmV2; -import com.baidu.hugegraph.job.algorithm.cent.StressCentralityAlgorithm; +import com.baidu.hugegraph.job.algorithm.cent.BetweennessCentralityAlgorithm; import com.baidu.hugegraph.job.algorithm.cent.ClosenessCentralityAlgorithm; import com.baidu.hugegraph.job.algorithm.cent.DegreeCentralityAlgorithm; import com.baidu.hugegraph.job.algorithm.cent.EigenvectorCentralityAlgorithm; +import com.baidu.hugegraph.job.algorithm.cent.StressCentralityAlgorithm; import com.baidu.hugegraph.job.algorithm.comm.ClusterCoeffcientAlgorithm; import com.baidu.hugegraph.job.algorithm.comm.KCoreAlgorithm; import com.baidu.hugegraph.job.algorithm.comm.LouvainAlgorithm; @@ -36,6 +36,7 @@ import com.baidu.hugegraph.job.algorithm.path.RingsDetectAlgorithm; import com.baidu.hugegraph.job.algorithm.rank.PageRankAlgorithm; import com.baidu.hugegraph.job.algorithm.similarity.FusiformSimilarityAlgorithm; +import com.baidu.hugegraph.util.E; public class AlgorithmPool { @@ -47,7 +48,7 @@ public class AlgorithmPool { INSTANCE.register(new DegreeCentralityAlgorithm()); INSTANCE.register(new StressCentralityAlgorithm()); - INSTANCE.register(new BetweennessCentralityAlgorithmV2()); + INSTANCE.register(new BetweennessCentralityAlgorithm()); INSTANCE.register(new ClosenessCentralityAlgorithm()); INSTANCE.register(new EigenvectorCentralityAlgorithm()); @@ -81,6 +82,13 @@ public Algorithm find(String name) { return this.algorithms.get(name); } + public Algorithm get(String name) { + Algorithm algorithm = this.algorithms.get(name); + E.checkArgument(algorithm != null, + "Not found algorithm '%s'", name); + return algorithm; + } + public static AlgorithmPool instance() { return INSTANCE; } diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/SubgraphStatAlgorithm.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/SubgraphStatAlgorithm.java index 46aa797822..d91748e41e 100644 --- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/SubgraphStatAlgorithm.java +++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/SubgraphStatAlgorithm.java @@ -33,14 +33,6 @@ import com.baidu.hugegraph.config.CoreOptions; import com.baidu.hugegraph.config.HugeConfig; import com.baidu.hugegraph.job.UserJob; -import com.baidu.hugegraph.job.algorithm.cent.BetweennessCentralityAlgorithmV2; -import com.baidu.hugegraph.job.algorithm.cent.StressCentralityAlgorithm; -import com.baidu.hugegraph.job.algorithm.cent.ClosenessCentralityAlgorithm; -import com.baidu.hugegraph.job.algorithm.cent.DegreeCentralityAlgorithm; -import com.baidu.hugegraph.job.algorithm.cent.EigenvectorCentralityAlgorithm; -import com.baidu.hugegraph.job.algorithm.comm.ClusterCoeffcientAlgorithm; -import com.baidu.hugegraph.job.algorithm.path.RingsDetectAlgorithm; -import com.baidu.hugegraph.job.algorithm.rank.PageRankAlgorithm; import com.baidu.hugegraph.task.HugeTask; import com.baidu.hugegraph.traversal.algorithm.HugeTraverser; import com.baidu.hugegraph.traversal.optimize.HugeScriptTraversal; @@ -149,34 +141,35 @@ public Traverser(UserJob job) { } public Object subgraphStat(UserJob job) { + AlgorithmPool pool = AlgorithmPool.instance(); Map results = InsertionOrderUtil.newMap(); GraphTraversalSource g = job.graph().traversal(); results.put("vertices_count", g.V().count().next()); results.put("edges_count", g.E().count().next()); - Algorithm algo = new DegreeCentralityAlgorithm(); + Algorithm algo = pool.get("degree_centrality"); Map parameters = ImmutableMap.copyOf(PARAMS); results.put("degrees", algo.call(job, parameters)); - algo = new StressCentralityAlgorithm(); + algo = pool.get("stress_centrality"); results.put("stress", algo.call(job, parameters)); - algo = new BetweennessCentralityAlgorithmV2(); + algo = pool.get("betweenness_centrality"); results.put("betweenness", algo.call(job, parameters)); - algo = new EigenvectorCentralityAlgorithm(); + algo = pool.get("eigenvector_centrality"); results.put("eigenvectors", algo.call(job, parameters)); - algo = new ClosenessCentralityAlgorithm(); + algo = pool.get("closeness_centrality"); results.put("closeness", algo.call(job, parameters)); results.put("page_ranks", pageRanks(job)); - algo = new ClusterCoeffcientAlgorithm(); + algo = pool.get("cluster_coeffcient"); results.put("cluster_coeffcient", algo.call(job, parameters)); - algo = new RingsDetectAlgorithm(); + algo = pool.get("rings"); parameters = ImmutableMap.builder() .putAll(PARAMS) .put("count_only", true) @@ -189,7 +182,7 @@ public Object subgraphStat(UserJob job) { } private Map pageRanks(UserJob job) { - PageRankAlgorithm algo = new PageRankAlgorithm(); + Algorithm algo = AlgorithmPool.instance().get("page_rank"); algo.call(job, ImmutableMap.of("alpha", 0.15)); // Collect page ranks diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/AbstractCentAlgorithm.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/AbstractCentAlgorithm.java index 752dc74ba5..066234873b 100644 --- a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/AbstractCentAlgorithm.java +++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/AbstractCentAlgorithm.java @@ -150,7 +150,9 @@ protected GraphTraversal filterNonShortestPath( return t.filter(it -> { Id start = it.path(Pop.first, "v").id(); Id end = it.path(Pop.last, "v").id(); - int len = it.>path(Pop.all, "v").size(); + int len = it.path().size(); + assert len == it.>path(Pop.all, "v").size(); + Pair key = Pair.of(start, end); Integer shortest = triples.get(key); if (shortest != null && len > shortest) { diff --git a/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/BetweennessCentralityAlgorithm.java b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/BetweennessCentralityAlgorithm.java new file mode 100644 index 0000000000..46f4d4a405 --- /dev/null +++ b/hugegraph-core/src/main/java/com/baidu/hugegraph/job/algorithm/cent/BetweennessCentralityAlgorithm.java @@ -0,0 +1,150 @@ +/* + * Copyright 2017 HugeGraph Authors + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.baidu.hugegraph.job.algorithm.cent; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.mutable.MutableFloat; +import org.apache.tinkerpop.gremlin.process.traversal.P; +import org.apache.tinkerpop.gremlin.process.traversal.Pop; +import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversal; +import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.__; +import org.apache.tinkerpop.gremlin.structure.Vertex; + +import com.baidu.hugegraph.backend.id.Id; +import com.baidu.hugegraph.backend.id.SplicingIdGenerator; +import com.baidu.hugegraph.job.UserJob; +import com.baidu.hugegraph.structure.HugeElement; +import com.baidu.hugegraph.type.define.Directions; + +public class BetweennessCentralityAlgorithm extends AbstractCentAlgorithm { + + @Override + public String name() { + return "betweenness_centrality"; + } + + @Override + public void checkParameters(Map parameters) { + super.checkParameters(parameters); + } + + @Override + public Object call(UserJob job, Map parameters) { + try (Traverser traverser = new Traverser(job)) { + return traverser.betweennessCentrality(direction(parameters), + edgeLabel(parameters), + depth(parameters), + degree(parameters), + sample(parameters), + sourceLabel(parameters), + sourceSample(parameters), + sourceCLabel(parameters), + top(parameters)); + } + } + + private static class Traverser extends AbstractCentAlgorithm.Traverser { + + public Traverser(UserJob job) { + super(job); + } + + public Object betweennessCentrality(Directions direction, + String label, + int depth, + long degree, + long sample, + String sourceLabel, + long sourceSample, + String sourceCLabel, + long topN) { + assert depth > 0; + assert degree > 0L || degree == NO_LIMIT; + assert topN >= 0L || topN == NO_LIMIT; + + GraphTraversal t = constructSource(sourceLabel, + sourceSample, + sourceCLabel); + t = constructPath(t, direction, label, degree, sample, + sourceLabel, sourceCLabel); + t = t.emit().until(__.loops().is(P.gte(depth))); + t = filterNonShortestPath(t, false); + + GraphTraversal tg = this.groupPathByEndpoints(t); + tg = this.computeBetweenness(tg); + GraphTraversal tLimit = topN(tg, topN); + + return this.execute(tLimit, () -> tLimit.next()); + } + + protected GraphTraversal groupPathByEndpoints( + GraphTraversal t) { + return t.map(it -> { + // t.select(Pop.all, "v").unfold().id() + List path = it.path(Pop.all, "v"); + List pathById = new ArrayList<>(path.size()); + for (HugeElement v : path) { + pathById.add(v.id()); + } + return pathById; + }).group().by(it -> { + // group by the first and last vertex + @SuppressWarnings("unchecked") + List path = (List) it; + assert path.size() >= 2; + String first = path.get(0).toString(); + String last = path.get(path.size() -1).toString(); + return SplicingIdGenerator.concat(first, last); + }).unfold(); + } + + protected GraphTraversal computeBetweenness( + GraphTraversal t) { + return t.fold(new HashMap(), (results, it) -> { + @SuppressWarnings("unchecked") + Map.Entry> entry = (Map.Entry>) it; + @SuppressWarnings("unchecked") + List> paths = (List>) entry.getValue(); + for (List path : paths) { + int len = path.size(); + if (len <= 2) { + // only two vertex, no betweenness vertex + continue; + } + // skip the first and last vertex + for (int i = 1; i < len - 1; i++) { + Id vertex = path.get(i); + MutableFloat value = results.get(vertex); + if (value == null) { + value = new MutableFloat(); + results.put(vertex, value); + } + value.add(1.0f / paths.size()); + } + } + return results; + }); + } + } +}