diff --git a/test/knn_test.rb b/test/knn_test.rb index f3a4be0d..534fbfd3 100644 --- a/test/knn_test.rb +++ b/test/knn_test.rb @@ -4,7 +4,10 @@ class KnnTest < Minitest::Test def test_works store [{name: "A", embedding: [1, 2, 3]}, {name: "B", embedding: [-1, -2, -3]}] assert_order "*", ["A", "B"], knn: {embedding: [1, 2, 3]} + expected = Searchkick.opensearch? ? [1, 0] : [2, 1] - assert_equal expected, Product.search(knn: {embedding: [1, 2, 3]}).hits.map { |v| v["_score"] } + scores = Product.search(knn: {embedding: [1, 2, 3]}).hits.map { |v| v["_score"] } + assert_in_delta expected[0], scores[0] + assert_in_delta expected[1], scores[1] end end diff --git a/test/models/product.rb b/test/models/product.rb index 7a197dae..a3fcc230 100644 --- a/test/models/product.rb +++ b/test/models/product.rb @@ -23,12 +23,6 @@ class Product match: ENV["MATCH"] ? ENV["MATCH"].to_sym : nil, knn: {embedding: {dimensions: 3}} - if ActiveRecord::VERSION::STRING.to_f >= 7.1 - serialize :embedding, coder: JSON - else - serialize :embedding, JSON - end - attr_accessor :conversions, :user_ids, :aisle, :details class << self diff --git a/test/support/activerecord.rb b/test/support/activerecord.rb index c8412648..7d7e1c78 100644 --- a/test/support/activerecord.rb +++ b/test/support/activerecord.rb @@ -76,6 +76,12 @@ class Product < ActiveRecord::Base belongs_to :store + + if ActiveRecord::VERSION::STRING.to_f >= 7.1 + serialize :embedding, coder: JSON + else + serialize :embedding, JSON + end end class Store < ActiveRecord::Base diff --git a/test/support/mongoid.rb b/test/support/mongoid.rb index 9bcd6bf3..137616e2 100644 --- a/test/support/mongoid.rb +++ b/test/support/mongoid.rb @@ -21,6 +21,7 @@ class Product field :longitude, type: BigDecimal field :description field :alt_description + field :embedding, type: Array end class Store