From f6cb441db3a7e735d50d8f1730ae2f3a2af61aa0 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sat, 5 Oct 2024 19:45:33 -0700 Subject: [PATCH] Added Hamming distance for MySQL --- README.md | 19 +++++++++++++++++++ lib/neighbor/model.rb | 7 ++++++- lib/neighbor/utils.rb | 9 ++++++++- test/mysql_bit_test.rb | 35 +++++++++++++++++++++++++++++++++++ test/support/mysql.rb | 2 ++ 5 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 test/mysql_bit_test.rb diff --git a/README.md b/README.md index 17e0e28..d110549 100644 --- a/README.md +++ b/README.md @@ -417,9 +417,28 @@ Supported values are: - `euclidean` - `cosine` +- `hamming` Note: The `DISTANCE()` function is [only available on HeatWave](https://dev.mysql.com/doc/refman/9.0/en/vector-functions.html) +### Binary Vectors + +Use the `binary` type to store binary vectors + +```ruby +class AddEmbeddingToItems < ActiveRecord::Migration[7.2] + def change + add_column :items, :embedding, :binary + end +end +``` + +Get the nearest neighbors by Hamming distance + +```ruby +Item.nearest_neighbors(:embedding, "\x05", distance: "hamming").first(5) +``` + ## Examples - [Embeddings](#openai-embeddings) with OpenAI diff --git a/lib/neighbor/model.rb b/lib/neighbor/model.rb index 33f5bfd..5b10e71 100644 --- a/lib/neighbor/model.rb +++ b/lib/neighbor/model.rb @@ -66,6 +66,7 @@ def self.neighbor_attributes dimensions = v[:dimensions] dimensions ||= column_info&.limit unless column_info&.type == :binary type = v[:type] || column_info&.type + type = :bit if type == :binary && adapter == :mysql if !Neighbor::Utils.validate_dimensions(value, type, dimensions, adapter).nil? errors.add(k, "must have #{dimensions} dimensions") @@ -144,7 +145,11 @@ def self.neighbor_attributes when :mariadb "VEC_DISTANCE(#{quoted_attribute}, #{query})" when :mysql - "DISTANCE(#{quoted_attribute}, #{query}, #{connection.quote(operator)})" + if operator == "BIT_COUNT" + "BIT_COUNT(#{quoted_attribute} ^ #{query})" + else + "DISTANCE(#{quoted_attribute}, #{query}, #{connection.quote(operator)})" + end else if operator == "#" "bit_count(#{quoted_attribute} # #{query})" diff --git a/lib/neighbor/utils.rb b/lib/neighbor/utils.rb index c964a2b..b782005 100644 --- a/lib/neighbor/utils.rb +++ b/lib/neighbor/utils.rb @@ -2,7 +2,7 @@ module Neighbor module Utils def self.validate_dimensions(value, type, expected, adapter) dimensions = type == :sparsevec ? value.dimensions : value.size - dimensions *= 8 if type == :bit && adapter == :sqlite + dimensions *= 8 if type == :bit && [:sqlite, :mysql].include?(adapter) if expected && dimensions != expected "Expected #{expected} dimensions, not #{dimensions}" end @@ -20,6 +20,8 @@ def self.validate_finite(value, type) end def self.validate(value, dimensions:, type:, adapter:) + type = :bit if type == :binary && adapter == :mysql + if (message = validate_dimensions(value, type, dimensions, adapter)) raise Error, message end @@ -87,6 +89,11 @@ def self.operator(adapter, column_type, distance) when "euclidean" "EUCLIDEAN" end + when :binary + case distance + when "hamming" + "BIT_COUNT" + end else raise ArgumentError, "Unsupported type: #{column_type}" end diff --git a/test/mysql_bit_test.rb b/test/mysql_bit_test.rb new file mode 100644 index 0000000..2698407 --- /dev/null +++ b/test/mysql_bit_test.rb @@ -0,0 +1,35 @@ +require_relative "test_helper" +require_relative "support/mysql" + +class MysqlBitTest < Minitest::Test + def setup + MysqlItem.delete_all + end + + def test_hamming + create_bit_items + result = MysqlItem.find(1).nearest_neighbors(:binary_embedding, distance: "hamming").first(3) + assert_equal [2, 3], result.map(&:id) + assert_elements_in_delta [2, 3].map { |v| v * 1024 }, result.map(&:neighbor_distance) + end + + def test_hamming_scope + create_bit_items + result = MysqlItem.nearest_neighbors(:binary_embedding, "\x05" * 1024, distance: "hamming").first(5) + assert_equal [2, 3, 1], result.map(&:id) + assert_elements_in_delta [0, 1, 2].map { |v| v * 1024 }, result.map(&:neighbor_distance) + end + + def test_invalid_dimensions + error = assert_raises(ActiveRecord::RecordInvalid) do + MysqlItem.create!(binary_embedding: "\x00" * 1024 + "\x11") + end + assert_equal "Validation failed: Binary embedding must have 8192 dimensions", error.message + end + + def create_bit_items + MysqlItem.create!(id: 1, binary_embedding: "\x00" * 1024) + MysqlItem.create!(id: 2, binary_embedding: "\x05" * 1024) + MysqlItem.create!(id: 3, binary_embedding: "\x07" * 1024) + end +end diff --git a/test/support/mysql.rb b/test/support/mysql.rb index 92c3703..409bcdf 100644 --- a/test/support/mysql.rb +++ b/test/support/mysql.rb @@ -12,11 +12,13 @@ class MysqlRecord < ActiveRecord::Base MysqlRecord.connection.instance_eval do create_table :mysql_items, force: true do |t| t.vector :embedding, limit: 3 + t.binary :binary_embedding end end class MysqlItem < MysqlRecord has_neighbors :embedding + has_neighbors :binary_embedding, dimensions: 8192 end # ensure has_neighbors does not cause model schema to load