Skip to content

Commit

Permalink
Adds first working example of (in-memory) DEG with simple unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
ppanopticon committed May 8, 2024
1 parent 41c0d71 commit bf0af0a
Show file tree
Hide file tree
Showing 15 changed files with 469 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ sealed class EuclideanDistance<T : VectorValue<*>>(type: Types.Vector<T,*>): Min
class FloatVectorVectorized(type: Types.Vector<FloatVectorValue,*>): EuclideanDistance<FloatVectorValue>(type), VectorisedFunction<DoubleValue> {
companion object {
@JvmStatic
private val SPECIES: VectorSpecies<Float> = jdk.incubator.vector.FloatVector.SPECIES_PREFERRED
private val SPECIES: VectorSpecies<Float> = jdk.incubator.vector.FloatVector.SPECIES_PREFERRED
}

override val name: Name.FunctionName = FUNCTION_NAME
Expand All @@ -197,19 +197,19 @@ sealed class EuclideanDistance<T : VectorValue<*>>(type: Types.Vector<T,*>): Min
val query = (arguments[1] as FloatVectorValue).data

/* Vectorised distance calculation. */
var vectorSum = jdk.incubator.vector.FloatVector.zero(SPECIES)
var sum = 0.0f
val bound = SPECIES.loopBound(this.vectorSize)
for (i in 0 until bound step SPECIES.length()) {
val vp = jdk.incubator.vector.FloatVector.fromArray(SPECIES, probing, i)
val vq = jdk.incubator.vector.FloatVector.fromArray(SPECIES, query, i)
val diff = vp.lanewise(VectorOperators.SUB, vq)
vectorSum = vectorSum.lanewise(VectorOperators.ADD, diff.lanewise(VectorOperators.MUL, diff))
val diff = vp.sub(vq)
sum += diff.mul(diff).reduceLanes(VectorOperators.ADD)
}

/* Scalar version for remainder. */
var sum = vectorSum.reduceLanes(VectorOperators.ADD)
for (i in bound until this.vectorSize) {
sum += (query[i] - probing[i]).pow(2)
val diff: Float = query[i] - probing[i]
sum = Math.fma(diff, diff, sum)
}
return DoubleValue(sqrt(sum))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.vitrivr.cottontail.utilities.math.ranking

/**
 * Collection of helper functions used to evaluate the quality of ranked result lists.
 *
 * @author Ralph Gasser
 * @version 1.0.0
 */
object RankingUtilities {
    /**
     * Calculates the recall at k for a given [List] of retrieved and relevant items. Comparison is based on object equality.
     *
     * Only the first [k] entries of [retrieved] are inspected; each of them that is also contained in [relevant]
     * counts as one hit. The number of hits is then divided by [k].
     *
     * @param retrieved [List] of retrieved items. Must contain at least [k] entries.
     * @param relevant [List] of relevant items. Must contain at least [k] entries.
     * @param k The number of items to consider. Must be greater than 0.
     * @return The fraction of the first [k] retrieved items that are contained in [relevant] (a value between 0.0 and 1.0).
     * @throws IllegalArgumentException If [k] is not positive or one of the lists holds fewer than [k] entries.
     */
    fun <V> recallAtK(retrieved: List<V>, relevant: List<V>, k: Int): Float {
        require(k > 0) { "Parameter k must be greater than 0." }
        require(retrieved.size >= k) { "Number of retrieved items must be greater than or equal to k." }
        require(relevant.size >= k) { "Number of relevant items must be greater than or equal to k." }

        /* Hash-based lookup avoids a linear scan of the relevant list for every retrieved item (relies on consistent equals()/hashCode()). */
        val relevantItems = relevant.toHashSet()
        val hits = retrieved.subList(0, k).count { it in relevantItems }
        return hits.toFloat() / k
    }
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package org.vitrivr.cottontail.dbms.index.diskann.graph

import it.unimi.dsi.fastutil.objects.Object2FloatLinkedOpenHashMap
import it.unimi.dsi.fastutil.objects.ObjectArraySet
import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet
import org.vitrivr.cottontail.core.database.TupleId
import org.vitrivr.cottontail.core.types.VectorValue
import org.vitrivr.cottontail.utilities.graph.Graph
import java.lang.Math.floorDiv
import kotlin.collections.HashSet
import java.util.*
import kotlin.math.max
import kotlin.streams.toList


/**
* This class implements a Dynamic Exploration Graph (DEG) as proposed in [1]. It can be used to perform approximate nearest neighbour search (ANNS).
Expand All @@ -18,8 +18,9 @@ import kotlin.math.max
* @author Ralph Gasser
* @version 1.0.0
*/
abstract class AbstractDynamicExplorationGraph<I:Comparable<I>,V>(private val degree: Int, val graph: Graph<AbstractDynamicExplorationGraph<I,V>.Node>, private val epsilonExt: Float = 0.3f, private val kExt: Int = 60) {
abstract class AbstractDynamicExplorationGraph<I:Comparable<I>,V>(private val degree: Int, val graph: Graph<AbstractDynamicExplorationGraph<I,V>.Node>, private val epsilonExt: Float = 0.2f, private val kExt: Int = 60) {

private val random = SplittableRandom()

init {
require(this.degree % 2 == 0) { "Dynamic Exploration Graph (DEG) must be even-regular." }
Expand All @@ -37,39 +38,41 @@ abstract class AbstractDynamicExplorationGraph<I:Comparable<I>,V>(private val de
/* Create new (empty) node and store vector. */
val newNode = Node(identifier)
this.storeVector(identifier, vector)
this.graph.addVertex(newNode)

if (count <= this.degree) { /* Case 1: Graph does not satisfy regularity condition since it is too small: Create new node and make all existing nodes connect to */
this.graph.addVertex(newNode)
for (node in this.graph) {
if (node == newNode) continue
val distance = this.distance(vector, node.vector)
this.graph.addEdge(newNode, node, distance)
this.graph.addEdge(node, newNode, distance)
}
} else { /* Case 2: Graph is regular. */
val results = this.search(vector, this.kExt, this.epsilonExt)
var skipRng = false
val results = this.search(vector, this.kExt, this.epsilonExt, this.getSeedNodes(1))
var phase = 0

/* Add new vertex. */
this.graph.addVertex(newNode)

/* Start insert procedure. */
var newNeighbours = this.graph.edges(newNode)
val newNeighbours = this.graph.edges(newNode)
while (newNeighbours.size < this.degree) {
for ((candidateNode, candidateWeight) in results) {
if (newNeighbours.size >= this.degree) break
if (newNeighbours.contains(candidateNode)) continue
if (!(skipRng || checkMrng(newNode, candidateNode, candidateWeight))) continue
if (newNeighbours.containsKey(candidateNode)) continue
if (phase <= 1 && checkMrng(candidateNode, newNode, candidateWeight)) continue

/* Find candidate neighbour. */
val (candidateNeighbour,candidateNeighbourWeight) = this.graph.edges(candidateNode).filter { !newNeighbours.contains(it.key) }.maxBy { it.value }
/* Find new neighbour. */
val newNeighbour = this.graph.edges(candidateNode).filter { it.key !in newNeighbours }.maxByOrNull { it.value }?.key ?: continue
val newNeighbourDistance = this.distance(newNode.vector, newNeighbour.vector)

/* Remove edge from candidate node to candidate neighbour. */
this.graph.removeEdge(candidateNode, candidateNeighbour)
this.graph.removeEdge(candidateNode, newNeighbour)

/* Add edges to new nodes. */
this.graph.addEdge(newNode, candidateNode, candidateWeight)
this.graph.addEdge(newNode, candidateNeighbour, candidateNeighbourWeight)
this.graph.addEdge(newNode, newNeighbour, newNeighbourDistance)
}
skipRng = true
newNeighbours = this.graph.edges(newNode) /* Fetch new neighbours. */
phase += 1
}
}
}
Expand All @@ -82,56 +85,92 @@ abstract class AbstractDynamicExplorationGraph<I:Comparable<I>,V>(private val de
* @param epsilon The epsilon value for the search.
* @return [List] of [Distance]s describing the approximate nearest neighbours.
*/
fun search(query: V, k: Int, epsilon: Float): List<Distance> {
val seed = this.getSeedNodes(this.degree)
val checked = HashSet<Node>()
var r = Float.MAX_VALUE
fun search(query: V, k: Int, eps: Float): List<Distance> = search(query, k, eps, this.getSeedNodes(10))

/* Results. */
val results = Object2FloatLinkedOpenHashMap<Node>(k + 1)
/**
* Performs a search in this [AbstractDynamicExplorationGraph].
*
* @param query The query [VectorValue] to search for.
* @param k The number of nearest neighbours to return.
* @return [List] of [Distance]s describing the approximate nearest neighbours.
*/
protected fun search(query: V, k: Int, eps: Float, seed: List<Node>): List<Distance> {
val checked = ObjectOpenHashSet<Node>()
val results: TreeSet<Distance> = TreeSet<Distance>()
var distanceComputationCount = 0

/* Perform search. */
while (seed.isNotEmpty()) {
/* Find seed node closest to query. */
var closestNode: Node = seed.first()
var closestDistance = Float.MAX_VALUE
for (node in seed) {
val distance = this.distance(query, node.vector)
if (distance < closestDistance) {
closestDistance = distance
closestNode = node
/* Case 1: Small graph - brute-force search. */
if (this.size() < 1000L) {
for (vertex in this.graph) {
val distance = Distance(vertex, this.distance(query, vertex.vector))
distanceComputationCount++
results.add(distance)
if (results.size > k) {
results.pollLast()
}
}
seed.remove(closestNode)
return results.toList()
}

/* Case 2a: DEG search. Initialize queue with results vertices to check. */
var radius = Float.MAX_VALUE
val nextNodes = PriorityQueue<Distance>(k * 10)
for (node in seed) {
if (!checked.contains(node)) {
/* Mark node as checked. */
checked.add(node)

/* Calculate distance and add to queue. */
val distance = Distance(node, this.distance(query, node.vector))
distanceComputationCount++
nextNodes.add(distance)
if (distance.distance < radius) {
results.add(distance)
if (results.size > k) {
results.pollLast()
radius = results.last().distance
}
}
}

}

/* Perform based on queue. */
while (nextNodes.isNotEmpty()) {
/* Find seed node closest to query. */
val next: Distance = nextNodes.poll()

/* Abort condition. */
if (closestDistance > r * (1 + epsilon)) {
if (next.distance > radius * (1 + eps)) {
break
}

/* Load neighbouring nodes to continue search. */
for ((node, _) in this.graph.edges(closestNode)) {
for ((node, _) in this.graph.edges(next.identifier)) {
if (!checked.contains(node)) {
val distance = this.distance(query, node.vector)
if (distance < r * (1 + epsilon)) {
seed.add(node)
if (distance <= r) {
results[node] = distance
/* Mark node as checked. */
checked.add(node)

/* Calculate distance and add to queue. */
val distance = Distance(node, this.distance(query, node.vector))
distanceComputationCount++
if (distance.distance <= radius * (1 + eps)) {
/* Add node ID to set of checked nodes. */
nextNodes.add(distance)

if (distance.distance < radius) {
results.add(distance)
if (results.size > k) {
val largest = results.maxBy { it.value }
results.removeFloat(largest.key)
r = largest.value
results.pollLast()
radius = results.last().distance
}
}
}

/* Add node ID to set of checked nodes. */
checked.add(node)
}
}
}

return results.map { Distance(it.key, it.value) }.sorted()
return results.toList()
}

/**
Expand Down Expand Up @@ -163,21 +202,22 @@ abstract class AbstractDynamicExplorationGraph<I:Comparable<I>,V>(private val de
/**
* Obtains random seed [Node]s for range search.
*
* @param size The number of seed [Node]s to obtain.
* @param sampleSize The number of seed [Node]s to obtain.
* @return [List] of [AbstractDynamicExplorationGraph.Node]s
*/
private fun getSeedNodes(size: Int): MutableSet<Node> {
private fun getSeedNodes(sampleSize: Int): List<Node> {
val graphSize = this.graph.size()
val sampleSize = size.toLong()
require(sampleSize <= graphSize) { "The sample size $sampleSize exceeds graph size of graph (s = $sampleSize, g = $graphSize)" }
val set = ObjectArraySet<Node>(size)
for ((i, node) in this.graph.withIndex()) {
if (i % floorDiv(graphSize, sampleSize) == 0L) {
set.add(node)
if (set.size >= size) break
val indexes = this.random.longs(0L, graphSize).distinct().limit(sampleSize.toLong()).sorted().toList().toSet()
val results = ArrayList<Node>(sampleSize)
val iterator = this.graph.iterator()
for (i in 0L until graphSize) {
val next = iterator.next()
if (i in indexes) {
results.add(next)
}
}
return set
return results
}

/**
Expand All @@ -188,10 +228,9 @@ abstract class AbstractDynamicExplorationGraph<I:Comparable<I>,V>(private val de
* @return True if MRNG condition is satisfied, false otherwise.
*/
private fun checkMrng(v1: Node, v2: Node, targetWeight: Float): Boolean {
val v1N = this.graph.edges(v1)
val v2N = this.graph.edges(v2)
for (node in (v1N.keys intersect v2N.keys)) {
if (targetWeight > max(v2N[node]!!, v1N[node]!!)) {
for ((neighbour, neighbourWeight) in this.graph.edges(v1)) {
val neighbourTargetWeight = this.graph.weight(neighbour, v2)
if (neighbourTargetWeight >= 0.0f && targetWeight > max(neighbourWeight, neighbourTargetWeight)) {
return false
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package org.vitrivr.cottontail.dbms.index.diskann.graph

import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap
import org.vitrivr.cottontail.utilities.graph.memory.InMemoryGraph
import org.vitrivr.cottontail.utilities.graph.undirected.WeightedUndirectedInMemoryGraph

/**
*
*/
class InMemoryDynamicExplorationGraph<I: Comparable<I>,V>(degree: Int, private val df: (V, V) -> Float): AbstractDynamicExplorationGraph<I,V>(degree, InMemoryGraph(degree)) {
class InMemoryDynamicExplorationGraph<I: Comparable<I>,V>(degree: Int, private val df: (V, V) -> Float): AbstractDynamicExplorationGraph<I,V>(degree, WeightedUndirectedInMemoryGraph(degree)) {
private val vectors = Object2ObjectOpenHashMap<I,V>()
override fun size(): Long = this.graph.size()
override fun distance(a: V, b: V): Float = this.df(a, b)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.vitrivr.cottontail.dbms.index.diskann.graph

import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap
import org.vitrivr.cottontail.utilities.graph.undirected.WeightedUndirectedInMemoryGraph

/**
 * An [AbstractDynamicExplorationGraph] that keeps its vectors in an [Object2ObjectOpenHashMap] keyed by identifier
 * and delegates all distance calculations to the function handed to the constructor.
 *
 * NOTE(review): despite its name, this class is wired to a [WeightedUndirectedInMemoryGraph] and a hash map;
 * nothing visible here writes to persistent storage - confirm whether that is intended.
 *
 * @author Ralph Gasser
 * @version 1.0.0
 */
class PersistentDynamicExplorationGraph<I: Comparable<I>,V>(degree: Int, private val df: (V, V) -> Float): AbstractDynamicExplorationGraph<I,V>(degree, WeightedUndirectedInMemoryGraph(degree)) {
    /** Lookup table from identifier to the vector stored for it. */
    private val vectors = Object2ObjectOpenHashMap<I,V>()

    /** Returns the size reported by the underlying graph. */
    override fun size(): Long {
        return this.graph.size()
    }

    /** Applies the user-provided distance function to [a] and [b]. */
    override fun distance(a: V, b: V): Float {
        return this.df.invoke(a, b)
    }

    /** Returns the vector stored for [identifier] or throws a [NoSuchElementException] if there is none. */
    override fun loadVector(identifier: I): V {
        val stored: V? = this.vectors[identifier]
        return stored ?: throw NoSuchElementException("Could not find vector for identifier $identifier")
    }

    /** Stores [vector] under [identifier], replacing any previous entry for the same identifier. */
    override fun storeVector(identifier: I, vector: V) {
        this.vectors.put(identifier, vector)
    }
}
Loading

0 comments on commit bf0af0a

Please sign in to comment.