diff --git a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala index 81a85949d..a6c1c5291 100644 --- a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala +++ b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/mappings/DenseVectorFieldTest.scala @@ -1,116 +1,60 @@ package com.sksamuel.elastic4s.requests.mappings -import com.sksamuel.elastic4s.{ElasticApi, JacksonSupport} -import com.sksamuel.elastic4s.fields.{DenseVectorField, FlatIndexOptions, HnswIndexOptions, Int8FlatIndexOptions, Int8HnswIndexOptions, L2Norm} -import com.sksamuel.elastic4s.handlers.fields.{DenseVectorFieldBuilderFn, ElasticFieldBuilderFn} +import com.sksamuel.elastic4s.fields.DenseVectorField.{Flat, Hnsw, Int8Flat, Int8Hnsw} +import com.sksamuel.elastic4s.ElasticApi +import com.sksamuel.elastic4s.fields.{Cosine, DenseVectorField, DenseVectorIndexOptions, DotProduct, L2Norm, MaxInnerProduct} +import com.sksamuel.elastic4s.handlers.fields.{DenseVectorFieldBuilderFn} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers class DenseVectorFieldTest extends AnyFlatSpec with Matchers with ElasticApi { + private val denseVectorIndexOptions = DenseVectorIndexOptions(Int8Hnsw, Some(10), Some(100), Some(1.0f)) "A DenseVectorField" should "support dims property" in { val field = DenseVectorField(name = "myfield", dims = 3) DenseVectorFieldBuilderFn.build(field).string shouldBe - """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}""" + """{"type":"dense_vector","dims":3,"index":false}""" } - "A DenseVectorField" should "support a hnsw type of kNN algorithm for a index_options if a index property is true" in { - val field = DenseVectorField( - name = "myfield", - dims = 3, - index = true, - similarity = L2Norm, - indexOptions = Some(HnswIndexOptions(m = Some(100), efConstruction = Some(200))) - ) - val jsonStringValue = """{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"hnsw","m":100,"ef_construction":200}}""" - ElasticFieldBuilderFn(field).string shouldBe jsonStringValue - ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field) - } - - "A DenseVectorField" should "don't support a hnsw type of kNN algorithm for a index_options if a index property is false" in { - val field = DenseVectorField( - name = "myfield", - dims = 3, - similarity = L2Norm, - indexOptions = Some(HnswIndexOptions(m = Some(100), efConstruction = Some(200))) - ) - val jsonStringValue = """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}""" - ElasticFieldBuilderFn(field).string shouldBe jsonStringValue - ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field.copy(indexOptions = None)) - } - - "A DenseVectorField" should "support a int8_hnsw type of kNN algorithm for a index_options if a index property is true" in { - val field = DenseVectorField( - name = "myfield", - dims = 3, - index = true, - similarity = L2Norm, - indexOptions = Some(Int8HnswIndexOptions(m = Some(100), efConstruction = Some(200), confidenceInterval = Some(0.5d))) - ) - val jsonStringValue = """{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"int8_hnsw","m":100,"ef_construction":200,"confidence_interval":0.5}}""" - ElasticFieldBuilderFn(field).string shouldBe jsonStringValue - ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field) - } - - "A DenseVectorField" should "don't support a int8_hnsw type of kNN algorithm for a index_options if a index property is false" in { - val field = DenseVectorField( - name = "myfield", - dims = 3, - similarity = L2Norm, - indexOptions = Some(Int8HnswIndexOptions(m = Some(100), efConstruction = Some(200), confidenceInterval = Some(0.5d))) - ) - val jsonStringValue = """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}""" - ElasticFieldBuilderFn(field).string shouldBe jsonStringValue - ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field.copy(indexOptions = None)) + it should "support all similarity options" in { + val field = DenseVectorField(name = "myfield", dims = 3, index = true, similarity = L2Norm) + DenseVectorFieldBuilderFn.build(field).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm"}""" + DenseVectorFieldBuilderFn.build(field.similarity(DotProduct)).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"similarity":"dot_product"}""" + DenseVectorFieldBuilderFn.build(field.similarity(Cosine)).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"similarity":"cosine"}""" + DenseVectorFieldBuilderFn.build(field.similarity(MaxInnerProduct)).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"similarity":"max_inner_product"}""" } - "A DenseVectorField" should "support a flat type of kNN algorithm for a index_options if a index property is true" in { - val field = DenseVectorField( - name = "myfield", - dims = 3, - index = true, - similarity = L2Norm, - indexOptions = Some(FlatIndexOptions()) - ) - val jsonStringValue = """{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"flat"}}""" - ElasticFieldBuilderFn(field).string shouldBe jsonStringValue - ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field) + it should "support elementType property" in { + val field = DenseVectorField(name = "myfield", dims = 3).elementType("byte") + DenseVectorFieldBuilderFn.build(field).string shouldBe + """{"type":"dense_vector","element_type":"byte","dims":3,"index":false}""" } - "A DenseVectorField" should "don't support a flat type of kNN algorithm for a index_options if a index property is false" in { - val field = DenseVectorField( - name = "myfield", - dims = 3, - similarity = L2Norm, - indexOptions = Some(FlatIndexOptions()) - ) - val jsonStringValue = """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}""" - ElasticFieldBuilderFn(field).string shouldBe jsonStringValue - ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field.copy(indexOptions = None)) + it should "not set similarity or indexOptions when index = false" in { + val field = DenseVectorField(name = "myfield", dims = Some(3), index = Some(false), indexOptions = Some(denseVectorIndexOptions)) + DenseVectorFieldBuilderFn.build(field).string shouldBe + """{"type":"dense_vector","dims":3,"index":false}""" } - "A DenseVectorField" should "support a int8_flat type of kNN algorithm for a index_options if a index property is true" in { - val field = DenseVectorField( - name = "myfield", - dims = 3, - index = true, - similarity = L2Norm, - indexOptions = Some(Int8FlatIndexOptions(confidenceInterval = Some(0.5d))) - ) - val jsonStringValue = """{"type":"dense_vector","dims":3,"index":true,"similarity":"l2_norm","index_options":{"type":"int8_flat","confidence_interval":0.5}}""" - ElasticFieldBuilderFn(field).string shouldBe jsonStringValue - ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field) + it should "support indexOptions property" in { + val field = DenseVectorField(name = "myfield", dims = Some(3), index = Some(true), indexOptions = Some(denseVectorIndexOptions)) + DenseVectorFieldBuilderFn.build(field).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int8_hnsw","m":10,"ef_construction":100,"confidence_interval":1.0}}""" } - "A DenseVectorField" should "don't support a int8_flat type of kNN algorithm for a index_options if a index property is false" in { - val field = DenseVectorField( - name = "myfield", - dims = 3, - similarity = L2Norm, - indexOptions = Some(Int8FlatIndexOptions(confidenceInterval = Some(0.5d))) - ) - val jsonStringValue = """{"type":"dense_vector","dims":3,"index":false,"similarity":"l2_norm"}""" - ElasticFieldBuilderFn(field).string shouldBe jsonStringValue - ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonStringValue)) shouldBe (field.copy(indexOptions = None)) + it should "support all index options types and only set m, efConstruction and confidenceInterval when applicable" in { + val field = DenseVectorField(name = "myfield", dims = Some(3), index = Some(true), indexOptions = Some(denseVectorIndexOptions)) + DenseVectorFieldBuilderFn.build(field).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int8_hnsw","m":10,"ef_construction":100,"confidence_interval":1.0}}""" + DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Hnsw))).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"hnsw","m":10,"ef_construction":100}}""" + DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Flat))).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"flat"}}""" + DenseVectorFieldBuilderFn.build(field.indexOptions(denseVectorIndexOptions.copy(`type` = Int8Flat))).string shouldBe + """{"type":"dense_vector","dims":3,"index":true,"index_options":{"type":"int8_flat","confidence_interval":1.0}}""" } } diff --git a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala index 2815cf914..7c91c1e3c 100644 --- a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala +++ b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/fields/DenseVectorField.scala @@ -2,36 +2,58 @@ package com.sksamuel.elastic4s.fields object DenseVectorField { val `type`: String = "dense_vector" + + sealed trait KnnType { + def name: String + } + case object Hnsw extends KnnType { val name = "hnsw" } + case object Int8Hnsw extends KnnType { val name = "int8_hnsw" } + case object Flat extends KnnType { val name = "flat" } + case object Int8Flat extends KnnType { val name = "int8_flat" } + + @deprecated("Use the new apply method", "8.14.0") + def apply(name: String, + dims: Int): DenseVectorField = + DenseVectorField(name, None, Some(dims), Some(false), Some(L2Norm)) + + @deprecated("Use the new apply method", "8.14.0") + def apply(name: String, + dims: Int, + index: Boolean): DenseVectorField = + DenseVectorField(name, None, Some(dims), Some(index), Some(L2Norm)) + + @deprecated("Use the new apply method", "8.14.0") + def apply(name: String, + dims: Int, + index: Boolean, + similarity: Similarity): DenseVectorField = + DenseVectorField(name, None, Some(dims), Some(index), Some(similarity)) } + sealed trait Similarity { def name: String } + case object L2Norm extends Similarity { val name = "l2_norm" } case object DotProduct extends Similarity { val name = "dot_product" } case object Cosine extends Similarity { val name = "cosine" } +case object MaxInnerProduct extends Similarity { val name = "max_inner_product" } + +case class DenseVectorIndexOptions(`type`: DenseVectorField.KnnType, m: Option[Int] = None, efConstruction: Option[Int] = None, confidenceInterval: Option[Float] = None) { + +} case class DenseVectorField(name: String, - dims: Int, - index: Boolean = false, - similarity: Similarity = L2Norm, + elementType: Option[String] = None, + dims: Option[Int] = None, + index: Option[Boolean] = None, + similarity: Option[Similarity] = None, indexOptions: Option[DenseVectorIndexOptions] = None) extends ElasticField { - override def `type`: String = DenseVectorField.`type` -} + override def `type`: String = DenseVectorField.`type` -sealed trait DenseVectorIndexOptions { - def `type`: String -} -case class HnswIndexOptions(m: Option[Int] = None, efConstruction: Option[Int] = None) extends DenseVectorIndexOptions { - val `type`: String = "hnsw" -} -case class Int8HnswIndexOptions(m: Option[Int] = None, - efConstruction: Option[Int] = None, - confidenceInterval: Option[Double] = None) extends DenseVectorIndexOptions { - val `type`: String = "int8_hnsw" -} -case class FlatIndexOptions() extends DenseVectorIndexOptions { - val `type`: String = "flat" -} -case class Int8FlatIndexOptions(confidenceInterval: Option[Double] = None) extends DenseVectorIndexOptions { - val `type`: String = "int8_flat" + def elementType(elementType: String): DenseVectorField = copy(elementType = Some(elementType)) + def dims(dims: Int): DenseVectorField = copy(dims = Some(dims)) + def index(index: Boolean): DenseVectorField = copy(index = Some(index)) + def similarity(similarity: Similarity): DenseVectorField = copy(similarity = Some(similarity)) + def indexOptions(indexOptions: DenseVectorIndexOptions): DenseVectorField = copy(indexOptions = Some(indexOptions)) } diff --git a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala index 2414ef618..1fa27c03d 100644 --- a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala +++ b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/fields/DenseVectorFieldBuilderFn.scala @@ -1,55 +1,66 @@ package com.sksamuel.elastic4s.handlers.fields -import com.sksamuel.elastic4s.fields.{DenseVectorField, DenseVectorIndexOptions, FlatIndexOptions, HnswIndexOptions, Int8FlatIndexOptions, Int8HnswIndexOptions} +import com.sksamuel.elastic4s.fields.DenseVectorField.{Hnsw, Int8Flat, Int8Hnsw} +import com.sksamuel.elastic4s.fields.{Cosine, DenseVectorField, DenseVectorIndexOptions, DotProduct, L2Norm, MaxInnerProduct, Similarity} import com.sksamuel.elastic4s.json.{XContentBuilder, XContentFactory} object DenseVectorFieldBuilderFn { + private def similarityFromString(similarity: String): Similarity = similarity match { + case "l2_norm" => L2Norm + case "dot_product" => DotProduct + case "cosine" => Cosine + case "max_inner_product" => MaxInnerProduct + } - private def getIndexOptions(values: Map[String, Any]): DenseVectorIndexOptions= + private def getIndexOptions(values: Map[String, Any]): DenseVectorIndexOptions = values("type").asInstanceOf[String] match { - case "hnsw" => HnswIndexOptions( + case "hnsw" => DenseVectorIndexOptions( + DenseVectorField.Hnsw, values.get("m").map(_.asInstanceOf[Int]), values.get("ef_construction").map(_.asInstanceOf[Int]) ) - case "int8_hnsw" => Int8HnswIndexOptions( + case "int8_hnsw" => DenseVectorIndexOptions( + DenseVectorField.Int8Hnsw, values.get("m").map(_.asInstanceOf[Int]), values.get("ef_construction").map(_.asInstanceOf[Int]), - values.get("confidence_interval").map(_.asInstanceOf[Double]) + values.get("confidence_interval").map(d => d.asInstanceOf[Double].toFloat) + ) + case "flat" => DenseVectorIndexOptions( + DenseVectorField.Flat + ) + case "int8_flat" => DenseVectorIndexOptions( + DenseVectorField.Int8Flat, + None, + None, + values.get("confidence_interval").map(d => d.asInstanceOf[Double].toFloat) ) - case "flat" => FlatIndexOptions() - case "int8_flat" => Int8FlatIndexOptions(values.get("confidence_interval").map(_.asInstanceOf[Double])) } def toField(name: String, values: Map[String, Any]): DenseVectorField = DenseVectorField( name, - values("dims").asInstanceOf[Int], - values("index").asInstanceOf[Boolean], - indexOptions = values.get("index_options").map(_.asInstanceOf[Map[String, Any]]).map(getIndexOptions) + values.get("element_type").map(_.asInstanceOf[String]), + values.get("dims").map(_.asInstanceOf[Int]), + values.get("index").map(_.asInstanceOf[Boolean]), + values.get("similarity").map(s => similarityFromString(s.asInstanceOf[String])), + indexOptions = values.get("index_options").map(_.asInstanceOf[Map[String, Any]]).map(getIndexOptions), ) def build(field: DenseVectorField): XContentBuilder = { - val builder = XContentFactory.jsonBuilder() builder.field("type", field.`type`) - builder.field("dims", field.dims) - builder.field("index", field.index) - builder.field("similarity", field.similarity.name) - field.indexOptions.filter(_ => field.index).foreach { options => - builder.startObject("index_options") - builder.field("type", options.`type`) - options match { - case HnswIndexOptions(m, efConstruction) => - m.foreach(builder.field("m", _)) - efConstruction.foreach(builder.field("ef_construction", _)) - case Int8HnswIndexOptions(m, efConstruction, confidenceInterval) => - m.foreach(builder.field("m", _)) - efConstruction.foreach(builder.field("ef_construction", _)) - confidenceInterval.foreach(builder.field("confidence_interval", _)) - case FlatIndexOptions() => () - case Int8FlatIndexOptions(confidenceInterval) => - confidenceInterval.foreach(builder.field("confidence_interval", _)) + field.elementType.foreach(builder.field("element_type", _)) + field.dims.foreach(builder.field("dims", _)) + field.index.foreach(builder.field("index", _)) + if (field.index.getOrElse(true)) { + field.similarity.foreach(similarity => builder.field("similarity", similarity.name)) + field.indexOptions.foreach { options => + builder.startObject("index_options") + builder.field("type", options.`type`.name) + if (Seq(Hnsw, Int8Hnsw).contains(options.`type`)) options.m.foreach(builder.field("m", _)) + if (Seq(Hnsw, Int8Hnsw).contains(options.`type`)) options.efConstruction.foreach(builder.field("ef_construction", _)) + if (Seq(Int8Hnsw, Int8Flat).contains(options.`type`)) options.confidenceInterval.foreach(builder.field("confidence_interval", _)) + builder.endObject() } - builder.endObject() } builder.endObject() } diff --git a/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/fields/ElasticFieldBuilderFnTest.scala b/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/fields/ElasticFieldBuilderFnTest.scala index f59030373..b8085cffe 100644 --- a/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/fields/ElasticFieldBuilderFnTest.scala +++ b/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/fields/ElasticFieldBuilderFnTest.scala @@ -112,6 +112,25 @@ class ElasticFieldBuilderFnTest extends AnyWordSpec with Matchers { ElasticFieldBuilderFn.construct(fieldSet.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonString)) shouldBe(fieldSet) } - } + "support DenseVectorField" in { + val field = DenseVectorField("dense_vector_field", elementType = Some("byte"), dims = Some(3), index = Some(true), indexOptions = Some(DenseVectorIndexOptions(DenseVectorField.Flat))) + val jsonString = """{"type":"dense_vector","element_type":"byte","dims":3,"index":true,"index_options":{"type":"flat"}}""" + ElasticFieldBuilderFn(field).string shouldBe jsonString + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonString)) shouldBe field + } + + "support DenseVectorField with similarity" in { + val field = DenseVectorField("dense_vector_field", elementType = Some("byte"), dims = Some(3), index = Some(true), similarity = Some(MaxInnerProduct)) + val jsonString = """{"type":"dense_vector","element_type":"byte","dims":3,"index":true,"similarity":"max_inner_product"}""" + ElasticFieldBuilderFn(field).string shouldBe jsonString + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonString)) shouldBe field + } + "support DenseVectorField with all index options" in { + val field = DenseVectorField("dense_vector_field", elementType = Some("byte"), dims = Some(3), index = Some(true), indexOptions = Some(DenseVectorIndexOptions(DenseVectorField.Int8Hnsw, Some(100), Some(200), Some(0.5f)))) + val jsonString = """{"type":"dense_vector","element_type":"byte","dims":3,"index":true,"index_options":{"type":"int8_hnsw","m":100,"ef_construction":200,"confidence_interval":0.5}}""" + ElasticFieldBuilderFn(field).string shouldBe jsonString + ElasticFieldBuilderFn.construct(field.name, JacksonSupport.mapper.readValue[Map[String, Any]](jsonString)) shouldBe field + } + } }