Skip to content

Commit

Permalink
Graph logic is now abstracted away from DEG implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralph Gasser committed May 6, 2024
1 parent 170b996 commit ec780fe
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 133 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
package org.vitrivr.cottontail.dbms.index.diskann.graph

import it.unimi.dsi.fastutil.longs.Long2ObjectArrayMap
import jetbrains.exodus.core.dataStructures.hash.LongHashSet
import it.unimi.dsi.fastutil.objects.Object2FloatLinkedOpenHashMap
import org.apache.lucene.search.Weight
import org.vitrivr.cottontail.core.database.TupleId
import org.vitrivr.cottontail.core.types.VectorValue
import org.vitrivr.cottontail.utilities.graph.Graph
import java.lang.Math.floorDiv
import java.util.*
import kotlin.collections.HashMap
import kotlin.collections.HashSet
import kotlin.math.max

typealias NodeId = Long

typealias Weight = Double

/**
* This class implements a Dynamic Exploration Graph (DEG) as proposed in [1]. It can be used to perform approximate nearest neighbour search (ANNS).
*
Expand All @@ -20,7 +20,7 @@ typealias Weight = Double
* @author Ralph Gasser
* @version 1.0.0
*/
abstract class AbstractDynamicExplorationGraph<I,V>(val degree: Int): Iterable<Pair<NodeId,AbstractDynamicExplorationGraph<I,V>.Node>> {
abstract class AbstractDynamicExplorationGraph<I,V>(private val degree: Int, val graph: Graph<AbstractDynamicExplorationGraph<I,V>.Node>) {


init {
Expand All @@ -35,53 +35,60 @@ abstract class AbstractDynamicExplorationGraph<I,V>(val degree: Int): Iterable<P
*/
fun index(identifier: I, vector: V, epsilon: Double) {
val count = this.size()
val newNodeId = count + 1
val newNode = Node(identifier, Long2ObjectArrayMap(this.degree))

if (size() <= this.degree + 1) { /* Case 1: Graph does not satisfy regularity condition since it is too small: Create new node and make all existing nodes connect to */
for ((nodeId, node) in this) {
val distance = this.distance(vector, node.vector)
node.addEdge(newNodeId, distance)
newNode.addEdge(nodeId, distance)
/* Create new (empty) node and store vector. */
val newNode = Node(identifier)
this.storeVector(identifier, vector)

if (count <= this.degree)
{ /* Case 1: Graph does not satisfy regularity condition since it is too small: Create new node and make all existing nodes connect to */
this.graph.addVertex(newNode)
for (node in this.graph) {
val distance = this.distance(vector, node.vector).toFloat()
if (node != newNode) {
this.graph.addEdge(newNode, node, distance)
this.graph.addEdge(node, newNode, distance)
}
}
} else { /* Case 2: Graph is not regular. */
val nearest = this.search(vector, this.degree, epsilon)
val search = this.search(vector, this.degree, epsilon)
val connect = HashMap<Node, Float>()
var skipRng = false

/* Start insert procedure (. */
while (nearest.size < this.degree) {
val nodesToExplore = nearest.entries.filter { !nearest.containsKey(it.key) }.associate { it.key to it.value.first }.toMutableMap()
while (nearest.size < this.degree && nodesToExplore.isNotEmpty()) {
var closestNodeId = nodesToExplore.keys.first()
var closestNode = nodesToExplore.values.first()
/* Start insert procedure. */
while (connect.size < this.degree) {
val nodesToExplore = search.entries.filter { !connect.contains(it.key) }.associate { it.key to it.value }.toMutableMap()
while (connect.size < this.degree && nodesToExplore.isNotEmpty()) {
var closestNode = nodesToExplore.keys.first()
var closestDistance = Double.MAX_VALUE
for ((nodeId, node) in nodesToExplore.entries) {
for ((node, _) in nodesToExplore.entries) {
val distance = this.distance(vector, node.vector)
if (distance < closestDistance) {
closestDistance = distance
closestNodeId = nodeId
closestNode = node
}
nodesToExplore.remove(closestNodeId)
}
nodesToExplore.remove(closestNode)

/* Identify the best vertex to connect to existing vertex. */
if (skipRng || checkMrng(newNode, closestNode)) {
val longestEdge = closestNode.neighbours.entries.filter { newNode.neighbours.containsKey(it.key) }.maxBy { it.value }
newNode.addEdge(longestEdge.key, distance)
newNode.addEdge(closestNodeId, closestDistance)
/* Identify the best vertex to connect to existing vertex. */
if (skipRng || checkMrng(newNode, connect, closestNode)) {
val farthestNodeFromClosest = this.graph.edges(closestNode).filter { !connect.contains(it.key) }.maxBy { it.value }.key
connect[closestNode] = this.distance(closestNode.vector, newNode.vector).toFloat()
connect[farthestNodeFromClosest] = this.distance(farthestNodeFromClosest.vector, newNode.vector).toFloat()

/* Update receiving node. */
closestNode.removeEdge(longestEdge.key)
storeNode(closestNodeId, closestNode)
}
/* Update receiving node. */
this.graph.removeEdge(farthestNodeFromClosest, closestNode)
}
}
skipRng = true
}
}

/* Store new node. */
this.storeNode(newNodeId, newNode)
/* */
this.graph.addVertex(newNode)
for ((node, weight) in connect) {
this.graph.addEdge(newNode, node, weight)
}
}
}

/**
Expand All @@ -92,88 +99,58 @@ abstract class AbstractDynamicExplorationGraph<I,V>(val degree: Int): Iterable<P
* @param epsilon The epsilon value for the search.
* @return [List] of [Triple]s containing the [TupleId], distance and [VectorValue] of the approximate nearest neighbours.
*/
fun search(query: V, k: Int, epsilon: Double): Map<NodeId,Pair<Node,Double>> {
val seed = this.getSeedNodes()
val checked = LongHashSet()
var r = Double.MAX_VALUE
fun search(query: V, k: Int, epsilon: Double): Map<Node,Float> {
val seed = this.getSeedNodes(this.degree)
val checked = HashSet<Node>()
var r = Float.MAX_VALUE

/* Results. */
val results = Long2ObjectArrayMap<Pair<Node,Double>>(k + 1)
val results = Object2FloatLinkedOpenHashMap<Node>(k + 1)

/* Perform search. */
while (seed.isNotEmpty()) {
/* Find seed node closest to query. */
var closestNodeId = seed.keys.first()
var closestNode: Node = seed.values.first()
var closestNode: Node = seed.first()
var closestDistance = Double.MAX_VALUE
for ((id, node) in seed) {
for (node in seed) {
val distance = this.distance(query, node.vector)
if (distance < closestDistance) {
closestDistance = distance
closestNodeId = id
closestNode = node
}
}
seed.remove(closestNodeId)
seed.remove(closestNode)

/* Abort condition. */
if (closestDistance > r * (1 + epsilon)) {
break
}

/* Load neighbouring nodes to continue search. */
for ((nodeId, _) in closestNode.neighbours) {
if (!checked.contains(nodeId)) {
val node = this.getNode(nodeId)
for ((node, _) in this.graph.edges(closestNode)) {
if (!checked.contains(node)) {
val distance = this.distance(query, node.vector)
if (distance < r * (1 + epsilon)) {
seed[nodeId] = node
seed.add(node)
if (distance <= r) {
results[nodeId] = node to distance
results[node] = distance.toFloat()
if (results.size > k) {
val largest = results.long2ObjectEntrySet().maxBy { it.value.second }
results.remove(largest.longKey)
r = largest.value.second
val largest = results.maxBy { it.value }
results.removeFloat(largest.key)
r = largest.value
}
}
}

/* Add node ID to set of checked nodes. */
checked.add(nodeId)
checked.add(node)
}
}
}

return results
}

/**
* Returns a [NodeIterator] over the [Node]s in this [AbstractDynamicExplorationGraph].
*
* The default implementation may not be ideal, depending on what underlying storage is used.
*
* @return [NodeIterator]
*/
override fun iterator(): Iterator<Pair<NodeId,Node>> = NodeIterator()

/**
* Stores the [Node] with the given [NodeId]
*
* @param nodeId The [NodeId] of the [Node] to return.
* @param node The [Node] to store.
* @throws NoSuchElementException If [Node] with [NodeId] doesn't exist.
*/
protected abstract fun storeNode(nodeId: NodeId, node: Node)

/**
* Returns the [Node] with the given [NodeId]
*
* @param nodeId The [NodeId] of the [Node] to return.
* @return [Node]
* @throws NoSuchElementException If [Node] with [NodeId] doesn't exist.
*/
protected abstract fun getNode(nodeId: NodeId): Node

/**
* Returns the size of this [AbstractDynamicExplorationGraph].
*
Expand All @@ -189,6 +166,8 @@ abstract class AbstractDynamicExplorationGraph<I,V>(val degree: Int): Iterable<P
*/
protected abstract fun loadVector(identifier: I): V

protected abstract fun storeVector(identifier: I, vector: V)

/**
* Calculates the distance between two vectors [V]s.
*
Expand All @@ -204,19 +183,16 @@ abstract class AbstractDynamicExplorationGraph<I,V>(val degree: Int): Iterable<P
* @param size The number of seed [Node]s to obtain.
* @return [MutableMap of [AbstractDynamicExplorationGraph.Node]s keyed by [NodeId]
*/
private fun getSeedNodes(size: Int = 10): MutableMap<NodeId, Node> {
val map = Long2ObjectArrayMap<Node>()
val random = SplittableRandom()
(0 until size).map {
while (true) {
val nextNodeId = random.nextLong(0L, this.size())
val nextNode = this.getNode(nextNodeId)
if (map.putIfAbsent(nextNodeId, nextNode) != null) {
break
}
private fun getSeedNodes(size: Int): MutableSet<Node> {
require(size <= this.size()) { "Negative size of $size" }
val set = HashSet<Node>()
for ((i, node) in this.graph.withIndex()) {
if (i % floorDiv(this.graph.size(), size.toLong()) == 0L) {
set.add(node)
}
if (set.size >= size) break
}
return map
return set
}

/**
Expand All @@ -226,11 +202,12 @@ abstract class AbstractDynamicExplorationGraph<I,V>(val degree: Int): Iterable<P
* @param v2 The second [Node].
* @return True if MRNG condition is satisfied, false otherwise.
*/
private fun checkMrng(v1: Node, v2: Node): Boolean {
val neighbours = v1.neighbours.keys.intersect(v2.neighbours.keys)
private fun checkMrng(v1: Node, v1N: Map<Node,Float>, v2: Node): Boolean {
val v2N = this.graph.edges(v2)
val neighbours = v1N.keys intersect v2N.keys
val distance = this.distance(v1.vector, v2.vector)
for (nodeId in neighbours) {
if (distance > max(v2.neighbours[nodeId] ?: 0.0, v1.neighbours[nodeId] ?: 0.0)) {
for (node in neighbours) {
if (distance > max(v2N[node] ?: 0.0f, v1N[node] ?: 0.0f)) {
return false
}
}
Expand All @@ -243,43 +220,10 @@ abstract class AbstractDynamicExplorationGraph<I,V>(val degree: Int): Iterable<P
* @author Ralph Gasser
* @version 1.0.0
*/
inner class Node(val identifier: I, private val _edges: MutableMap<NodeId,Weight>) {
inner class Node(val identifier: I) {
/** The [VectorValue]; this value is loaded lazily. */
val vector: V by lazy { loadVector(this.identifier) }

/** The neighbours of this [Node]. */
val neighbours: Map<NodeId,Weight>
get() = this._edges.toMap()

/**
* Adds a new edge to this [Node].
*
* @param nodeId The [NodeId] of the
*/
fun addEdge(nodeId: NodeId, weight: Weight) {
require(this._edges.size < this@AbstractDynamicExplorationGraph.degree) { "Node contains to many edges (maximum degree is ${this@AbstractDynamicExplorationGraph.degree})." }
require(nodeId > 0 && nodeId < this@AbstractDynamicExplorationGraph.size()) { "NodeId $nodeId is out-of-bounds (maximum size = ${size()})." }
this._edges[nodeId] = weight
}

/**
* Removes an edge from this [Node].
*
* @param nodeId The [NodeId] of the edge to remove.
*/
fun removeEdge(nodeId: NodeId) {
this._edges.remove(nodeId)
}
}

/**
* Returns an [Iterator] over the [Node]s in this [AbstractDynamicExplorationGraph].
*
* <strong>Important:</string> This is a fairly naive implementation that could be improved in concrete implementations.
*/
inner class NodeIterator: Iterator<Pair<NodeId,Node>> {
private var current: NodeId = 0L
override fun hasNext(): Boolean = this.current < this@AbstractDynamicExplorationGraph.size()
override fun next(): Pair<NodeId,Node> = this.current to this@AbstractDynamicExplorationGraph.getNode(this.current++)
override fun equals(other: Any?): Boolean = other is AbstractDynamicExplorationGraph<*,*>.Node && other.identifier == this.identifier
override fun hashCode(): Int = this.identifier.hashCode()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.vitrivr.cottontail.dbms.index.diskann.graph

import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap
import org.vitrivr.cottontail.utilities.graph.Graph
import org.vitrivr.cottontail.utilities.graph.memory.InMemoryGraph

/**
*
*/
class InMemoryDynamicExplorationGraph<I,V>(degree: Int, private val df: (V, V) -> Double): AbstractDynamicExplorationGraph<I,V>(degree, InMemoryGraph(degree)) {
private val vectors = Object2ObjectOpenHashMap<I,V>()
override fun size(): Long = this.graph.size()
override fun distance(a: V, b: V): Double = this.df(a, b)
override fun loadVector(identifier: I): V = this.vectors[identifier] ?: throw NoSuchElementException("Could not find vector for identifier $identifier")
override fun storeVector(identifier: I, vector: V) {
this.vectors[identifier] = vector
}
}
Loading

0 comments on commit ec780fe

Please sign in to comment.