diff --git a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala index 58333c66f..81a85949d 100644 --- a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala +++ b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala @@ -1,17 +1,116 @@ package com.sksamuel.elastic4s.requests.mappings -import com.sksamuel.elastic4s.ElasticApi -import com.sksamuel.elastic4s.fields.DenseVectorField -import com.sksamuel.elastic4s.handlers.fields.DenseVectorFieldBuilderFn +import com.sksamuel.elastic4s.{ElasticApi, JacksonSupport} +import com.sksamuel.elastic4s.fields.{DenseVectorField, FlatIndexOptions, HnswIndexOptions, Int8FlatIndexOptions, Int8HnswIndexOptions, L2Norm} +import com.sksamuel.elastic4s.handlers.fields.{DenseVectorFieldBuilderFn, ElasticFieldBuilderFn} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers class DenseVectorFieldTest extends AnyFlatSpec with Matchers with ElasticApi { - private val field = DenseVectorField(name = "myfield", dims = 3) - "A DenseVectorField" should "support dims property" in { + val field = DenseVectorField(name = "myfield", dims = 3) DenseVectorFieldBuilderFn.build(field).string shouldBe """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}""" } + + "A DenseVectorField" should "support a hnsw type of kNN algorithm for a index_options if a index property is true" in { + val field = DenseVectorField( + name = "myfield", + dims = 3, + index = true, + similarity = L2Norm, + indexOptions = Some(HnswIndexOptions(m = Some(100), efConstruction = Some(200))) + ) + val jsonStringValue = """{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"hnsw","m":100,"ef_construction":200}}""" + ElasticFieldBuilderFn(field).string shouldBe jsonStringValue + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field) + } + + "A DenseVectorField" should "don't support a hnsw type of kNN algorithm for a index_options if a index property is false" in { + val field = DenseVectorField( + name = "myfield", + dims = 3, + similarity = L2Norm, + indexOptions = Some(HnswIndexOptions(m = Some(100), efConstruction = Some(200))) + ) + val jsonStringValue = """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}""" + ElasticFieldBuilderFn(field).string shouldBe jsonStringValue + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field.copy(indexOptions = None)) + } + + "A DenseVectorField" should "support a int8_hnsw type of kNN algorithm for a index_options if a index property is true" in { + val field = DenseVectorField( + name = "myfield", + dims = 3, + index = true, + similarity = L2Norm, + indexOptions = Some(Int8HnswIndexOptions(m = Some(100), efConstruction = Some(200), confidenceInterval = Some(0.5d))) + ) + val jsonStringValue = """{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"int8_hnsw","m":100,"ef_construction":200,"confidence_interval":0.5}}""" + ElasticFieldBuilderFn(field).string shouldBe jsonStringValue + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field) + } + + "A DenseVectorField" should "don't support a int8_hnsw type of kNN algorithm for a index_options if a index property is false" in { + val field = DenseVectorField( + name = "myfield", + dims = 3, + similarity = L2Norm, + indexOptions = Some(Int8HnswIndexOptions(m = Some(100), efConstruction = Some(200), confidenceInterval = Some(0.5d))) + ) + val jsonStringValue = """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}""" + ElasticFieldBuilderFn(field).string shouldBe jsonStringValue + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field.copy(indexOptions = None)) + } + + "A DenseVectorField" should "support a flat type of kNN algorithm for a index_options if a index property is true" in { + val field = DenseVectorField( + name = "myfield", + dims = 3, + index = true, + similarity = L2Norm, + indexOptions = Some(FlatIndexOptions()) + ) + val jsonStringValue = """{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"flat"}}""" + ElasticFieldBuilderFn(field).string shouldBe jsonStringValue + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field) + } + + "A DenseVectorField" should "don't support a flat type of kNN algorithm for a index_options if a index property is false" in { + val field = DenseVectorField( + name = "myfield", + dims = 3, + similarity = L2Norm, + indexOptions = Some(FlatIndexOptions()) + ) + val jsonStringValue = """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}""" + ElasticFieldBuilderFn(field).string shouldBe jsonStringValue + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field.copy(indexOptions = None)) + } + + "A DenseVectorField" should "support a int8_flat type of kNN algorithm for a index_options if a index property is true" in { + val field = DenseVectorField( + name = "myfield", + dims = 3, + index = true, + similarity = L2Norm, + indexOptions = Some(Int8FlatIndexOptions(confidenceInterval = Some(0.5d))) + ) + val jsonStringValue = """{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"int8_flat","confidence_interval":0.5}}""" + ElasticFieldBuilderFn(field).string shouldBe jsonStringValue + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field) + } + + "A DenseVectorField" should "don't support a int8_flat type of kNN algorithm for a index_options if a index property is false" in { + val field = DenseVectorField( + name = "myfield", + dims = 3, + similarity = L2Norm, + indexOptions = Some(Int8FlatIndexOptions(confidenceInterval = Some(0.5d))) + ) + val jsonStringValue = """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}""" + ElasticFieldBuilderFn(field).string shouldBe jsonStringValue + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field.copy(indexOptions = None)) + } } diff --git a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala index 385440258..2815cf914 100644 --- a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala +++ b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala @@ -13,6 +13,25 @@ case object Cosine extends Similarity { val name = "cosine" } case class DenseVectorField(name: String, dims: Int, index: Boolean = false, - similarity: Similarity = L2Norm) extends ElasticField { - override def `type`: String = DenseVectorField.`type` + similarity: Similarity = L2Norm, + indexOptions: Option[DenseVectorIndexOptions] = None) extends ElasticField { + override def `type`: String = DenseVectorField.`type` +} + +sealed trait DenseVectorIndexOptions { + def `type`: String +} +case class HnswIndexOptions(m: Option[Int] = None, efConstruction: Option[Int] = None) extends DenseVectorIndexOptions { + val `type`: String = "hnsw" +} +case class Int8HnswIndexOptions(m: Option[Int] = None, + efConstruction: Option[Int] = None, + confidenceInterval: Option[Double] = None) extends DenseVectorIndexOptions { + val `type`: String = "int8_hnsw" +} +case class FlatIndexOptions() extends DenseVectorIndexOptions { + val `type`: String = "flat" +} +case class Int8FlatIndexOptions(confidenceInterval: Option[Double] = None) extends DenseVectorIndexOptions { + val `type`: String = "int8_flat" } diff --git a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala index 4619ee395..2414ef618 100644 --- a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala +++ b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala @@ -1,15 +1,32 @@ package com.sksamuel.elastic4s.handlers.fields -import com.sksamuel.elastic4s.fields.DenseVectorField +import com.sksamuel.elastic4s.fields.{DenseVectorField, DenseVectorIndexOptions, FlatIndexOptions, HnswIndexOptions, Int8FlatIndexOptions, Int8HnswIndexOptions} import com.sksamuel.elastic4s.json.{XContentBuilder, XContentFactory} object DenseVectorFieldBuilderFn { + + private def getIndexOptions(values: Map[String, Any]): DenseVectorIndexOptions= + values("type").asInstanceOf[String] match { + case "hnsw" => HnswIndexOptions( + values.get("m").map(_.asInstanceOf[Int]), + values.get("ef_construction").map(_.asInstanceOf[Int]) + ) + case "int8_hnsw" => Int8HnswIndexOptions( + values.get("m").map(_.asInstanceOf[Int]), + values.get("ef_construction").map(_.asInstanceOf[Int]), + values.get("confidence_interval").map(_.asInstanceOf[Double]) + ) + case "flat" => FlatIndexOptions() + case "int8_flat" => Int8FlatIndexOptions(values.get("confidence_interval").map(_.asInstanceOf[Double])) + } + def toField(name: String, values: Map[String, Any]): DenseVectorField = DenseVectorField( name, - values.get("dims").map(_.asInstanceOf[Int]).get + values("dims").asInstanceOf[Int], + values("index").asInstanceOf[Boolean], + indexOptions = values.get("index_options").map(_.asInstanceOf[Map[String, Any]]).map(getIndexOptions) ) - def build(field: DenseVectorField): XContentBuilder = { val builder = XContentFactory.jsonBuilder() @@ -17,6 +34,23 @@ object DenseVectorFieldBuilderFn { builder.field("dims", field.dims) builder.field("index", field.index) builder.field("similarity", field.similarity.name) + field.indexOptions.filter(_ => field.index).foreach { options => + builder.startObject("index_options") + builder.field("type", options.`type`) + options match { + case HnswIndexOptions(m, efConstruction) => + m.foreach(builder.field("m", _)) + efConstruction.foreach(builder.field("ef_construction", _)) + case Int8HnswIndexOptions(m, efConstruction, confidenceInterval) => + m.foreach(builder.field("m", _)) + efConstruction.foreach(builder.field("ef_construction", _)) + confidenceInterval.foreach(builder.field("confidence_interval", _)) + case FlatIndexOptions() => () + case Int8FlatIndexOptions(confidenceInterval) => + confidenceInterval.foreach(builder.field("confidence_interval", _)) + } + builder.endObject() + } builder.endObject() } }