Skip to content

Commit

Permalink
Fixes some issues and adds unit test for in-memory and persistent DEG.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralph Gasser committed May 8, 2024
1 parent 412c8b9 commit 6a28f93
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ abstract class AbstractDynamicExplorationGraph<I:Comparable<I>,V>(val degree: In

if (count <= this.degree) { /* Case 1: Graph does not satisfy regularity condition since it is too small; make all existing nodes connect to the new node. */
this.graph.addVertex(newNode)
for (node in this.graph) {
for (node in this.graph.vertices()) {
if (node == newNode) continue
val distance = this.distance(value, this.getValue(node))
this.graph.addEdge(node, newNode, distance)
Expand Down Expand Up @@ -108,15 +108,17 @@ abstract class AbstractDynamicExplorationGraph<I:Comparable<I>,V>(val degree: In

/* Case 1: Small graph - brute-force search. */
if (this.size() < 1000L) {
for (vertex in this.graph) {
val distance = Distance(vertex.label, this.distance(query, this.getValue(vertex)))
distanceComputationCount++
results.add(distance)
if (results.size > k) {
results.pollLast()
this.graph.vertices().use { vertices ->
for (vertex in vertices) {
val distance = Distance(vertex.label, this.distance(query, this.getValue(vertex)))
distanceComputationCount++
results.add(distance)
if (results.size > k) {
results.pollLast()
}
}
return results.toList()
}
return results.toList()
}

/* Case 2a: DEG search. Initialize queue with results vertices to check. */
Expand Down Expand Up @@ -220,18 +222,19 @@ abstract class AbstractDynamicExplorationGraph<I:Comparable<I>,V>(val degree: In
private fun getSeedNodes(sampleSize: Int): List<Node<I>> {
val graphSize = this.graph.size()
require(sampleSize <= graphSize) { "The sample size $sampleSize exceeds graph size of graph (s = $sampleSize, g = $graphSize)." }
val iterator = this.graph.iterator()
var position = 0L
return this.random.longs(0L, graphSize).distinct().limit(sampleSize.toLong()).sorted().mapToObj {
while (iterator.hasNext()) {
if ((position++) == it) {
break
} else {
iterator.next()
this.graph.vertices().use { iterator ->
var position = 0L
return this.random.longs(0L, graphSize).distinct().limit(sampleSize.toLong()).sorted().mapToObj {
while (iterator.hasNext()) {
if ((position++) == it) {
break
} else {
iterator.next()
}
}
}
iterator.next()
}.toList()
iterator.next()
}.toList()
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package org.vitrivr.cottontail.utilities.graph

import org.vitrivr.cottontail.core.basics.CloseableIterator

/**
* A simple, weighted [Graph] data structure on elements of type [V].
*
* @author Ralph Gasser
* @version 1.0.0
*/
interface Graph<V>: Iterable<V> {
interface Graph<V> {
/**
* Returns the number of vertexes in this [Graph].
*
Expand Down Expand Up @@ -66,6 +68,13 @@ interface Graph<V>: Iterable<V> {
*/
fun edges(from: V): Map<V,Float>

/**
* Creates and returns a [CloseableIterator] for this [Graph].
*
* @return A [CloseableIterator] over all vertices [V] in this [Graph].
*/
fun vertices(): CloseableIterator<V>

/**
* Returns the weight from one vertex [V] to another vertex [V] in this [Graph].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.vitrivr.cottontail.utilities.graph.undirected

import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2ObjectLinkedOpenHashMap
import org.vitrivr.cottontail.core.basics.CloseableIterator
import org.vitrivr.cottontail.utilities.graph.Graph

/**
Expand Down Expand Up @@ -113,5 +114,13 @@ class WeightedUndirectedInMemoryGraph<V>(private val maxDegree: Int = Int.MAX_VA
}


override fun iterator(): Iterator<V> = this.map.keys.iterator()
/**
*
*/
override fun vertices(): CloseableIterator<V> = object : CloseableIterator<V> {
private val iterator = this@WeightedUndirectedInMemoryGraph.map.keys.iterator()
override fun hasNext(): Boolean = this.iterator.hasNext()
override fun next(): V = this.iterator.next()
override fun close() {/* No op */ }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import jetbrains.exodus.bindings.FloatBinding
import jetbrains.exodus.env.Store
import jetbrains.exodus.env.Transaction
import jetbrains.exodus.util.LightOutputStream
import org.vitrivr.cottontail.core.basics.CloseableIterator
import org.vitrivr.cottontail.utilities.graph.Graph
import java.io.ByteArrayInputStream
import java.util.LinkedList
Expand All @@ -21,7 +22,8 @@ class WeightedUndirectedXodusGraph<V>(private val store: Store, private val txn:
private var count: Long = 0L

init {
this.store.openCursor(this.txn.readonlySnapshot).use { cursor ->

this.store.openCursor(this.txn).use { cursor ->
while (cursor.nextNoDup) this.count++
}
}
Expand Down Expand Up @@ -181,18 +183,15 @@ class WeightedUndirectedXodusGraph<V>(private val store: Store, private val txn:
}

/**
* Returns a [CloseableIterator] over vertices [V] for this [WeightedUndirectedXodusGraph].
*
* @return [CloseableIterator]
*/
override fun iterator(): Iterator<V> = object: Iterator<V> {
override fun vertices(): CloseableIterator<V> = object: CloseableIterator<V> {
private val cursor = this@WeightedUndirectedXodusGraph.store.openCursor(this@WeightedUndirectedXodusGraph.txn)
override fun hasNext(): Boolean {
val ret = this.cursor.nextNoDup
if (!ret) {
this.cursor.close()
}
return ret
}
override fun hasNext(): Boolean = this.cursor.nextNoDup
override fun next(): V = this@WeightedUndirectedXodusGraph.serializer.deserialize(this.cursor.key)
override fun close() = this.cursor.close()
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package org.vitrivr.cottontail.dbms.index.deg

import jetbrains.exodus.env.Environments
import jetbrains.exodus.env.StoreConfig
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import org.vitrivr.cottontail.core.database.TupleId
import org.vitrivr.cottontail.core.queries.functions.math.distance.binary.EuclideanDistance
import org.vitrivr.cottontail.core.types.Types
import org.vitrivr.cottontail.core.values.DoubleValue
import org.vitrivr.cottontail.core.values.FloatVectorValue
import org.vitrivr.cottontail.dbms.index.diskann.graph.deg.AbstractDynamicExplorationGraph
import org.vitrivr.cottontail.dbms.index.diskann.graph.deg.InMemoryDynamicExplorationGraph
import org.vitrivr.cottontail.dbms.index.diskann.graph.deg.XodusDynamicExplorationGraph
import org.vitrivr.cottontail.dbms.index.diskann.graph.serializer.TupleIdNodeSerializer
import org.vitrivr.cottontail.test.TestConstants
import org.vitrivr.cottontail.utilities.formats.FVecsReader
import org.vitrivr.cottontail.utilities.math.ranking.RankingUtilities
import org.vitrivr.cottontail.utilities.selection.ComparablePair
import org.vitrivr.cottontail.utilities.selection.MinHeapSelection
import java.util.*
import kotlin.time.Duration
import kotlin.time.measureTime


/**
* This is a basic test case that makes sure that the [XodusDynamicExplorationGraph] and the [InMemoryDynamicExplorationGraph] work as expected.
*
* @author Ralph Gasser
* @version 1.0.0
*/
class BasicDEGTest {

companion object {
const val K = 100
}

/**
* Tests the [XodusDynamicExplorationGraph] with 10'000 SIFT vectors ([FloatVectorValue]).
*/
@Test
fun testPersistentDEGWithSIFTVector() {
Environments.newInstance(TestConstants.testConfig().root.toFile()).use { environment ->
environment.executeInTransaction { txn ->
/* Create a new store. */
val store = environment.openStore("test", StoreConfig.WITH_DUPLICATES, txn, true)!!

/* Prepare parameters. */
val type = Types.FloatVector(128)
val distance = EuclideanDistance.FloatVector(type)
val list = LinkedList<FloatVectorValue>()
val graph = XodusDynamicExplorationGraph<TupleId, FloatVectorValue>(4, 16, 0.2f, store, txn, TupleIdNodeSerializer()) { v1, v2 -> distance.invoke(v1, v2).value.toFloat() }

/* Index vectors and build graph & ground truth. */
this.index(graph, list)

/* Perform search. */
this.search(graph, list, distance)
}
}
}

/**
* Tests the [InMemoryDynamicExplorationGraph] with 10'000 SIFT vectors ([FloatVectorValue]).
*/
@Test
fun testInMemoryDEGWithSIFTVector() {
/* Prepare parameters. */
val type = Types.FloatVector(128)
val distance = EuclideanDistance.FloatVector(type)
val list = LinkedList<FloatVectorValue>()
val graph = InMemoryDynamicExplorationGraph<TupleId, FloatVectorValue>(4, 16, 0.2f) { v1, v2 -> distance.invoke(v1, v2).value.toFloat() }

/* Index vectors and build graph & ground truth. */
this.index(graph, list)

/* Perform search. */
this.search(graph, list, distance)
}

/**
* Indexes the SIFT test data.
*
* @param graph The [AbstractDynamicExplorationGraph] to add to.
* @param list The [LinkedList] for brute-force search.
*/
private fun index(graph: AbstractDynamicExplorationGraph<TupleId, FloatVectorValue>, list: LinkedList<FloatVectorValue> ) {
/* Read vectors and build graph. */
FVecsReader(this.javaClass.getResourceAsStream("/sift/siftsmall_base.fvecs")!!).use { reader ->
while (reader.hasNext()) {
val next = FloatVectorValue(reader.next())
list.add(next)
graph.index(list.size.toLong(), next)
}
}
}

/**
* Searches the SIFT test data.
*
* @param graph The [AbstractDynamicExplorationGraph] to add to.
* @param list The [LinkedList] for brute-force search.
* @param distance The [EuclideanDistance] function.
*/
private fun search(graph: AbstractDynamicExplorationGraph<TupleId, FloatVectorValue>, list: LinkedList<FloatVectorValue>, distance: EuclideanDistance<FloatVectorValue>) {
/* Fetch results through full table scan. */
FVecsReader(this.javaClass.getResourceAsStream("/sift/siftsmall_query.fvecs")!!).use { reader ->
var queries = 0
var recall = 0.0f
var bruteForceDuration = Duration.ZERO
var indexDuration = Duration.ZERO

while (reader.hasNext()) {
val query = FloatVectorValue(reader.next())
val bruteForceResults = MinHeapSelection<ComparablePair<TupleId, DoubleValue>>(K)
bruteForceDuration += measureTime {
list.forEachIndexed { index, vector ->
bruteForceResults.offer(ComparablePair((index + 1L), distance.invoke(query, vector)!!))
}
}

/* Fetch results through index. */
val indexResults = ArrayList<ComparablePair<TupleId, DoubleValue>>(K)
indexDuration += measureTime {
graph.search(query, K, 0.2f).forEach { indexResults.add(ComparablePair(it.label, DoubleValue(it.distance))) }
}
recall += RankingUtilities.recallAtK(bruteForceResults.toList().map { it.first }, indexResults.map { it.first }, K)
queries++
}
recall /= queries
indexDuration /= queries
bruteForceDuration /= queries

/* Since the data comes pre-clustered, accuracy should always be greater than 90%. */
Assertions.assertTrue(recall >= 0.8f) { "Recall attained by indexed search is too small (r = $recall)." }
println("Search using DEG completed (r = $recall, withIndex = $indexDuration, bruteForce = $bruteForceDuration). Brute-force duration is always in memory!")
}
}
}

0 comments on commit 6a28f93

Please sign in to comment.