diff --git a/javascript/lib.cpp b/javascript/lib.cpp index c7036575..a01140a7 100644 --- a/javascript/lib.cpp +++ b/javascript/lib.cpp @@ -188,48 +188,13 @@ void Index::Add(Napi::CallbackInfo const& ctx) { if (ctx.Length() < 2) return Napi::TypeError::New(env, "Expects at least two arguments").ThrowAsJavaScriptException(); - std::vector keys; - std::vector vectors; - - Napi::Value arg = ctx[0]; - if (arg.IsArray()) { - Napi::Array keys_js = arg.As(); - for (std::size_t i = 0; i < keys_js.Length(); i++) { - Napi::Value key_js = keys_js[i]; - keys.push_back(key_js.As()); - } - } else if (arg.IsNumber()) - keys.push_back(arg.As()); - else - return Napi::TypeError::New(env, "Invalid argument type, expects integral key").ThrowAsJavaScriptException(); - - arg = ctx[1]; - if (arg.IsArray()) { - Napi::Array vectors_js = arg.As(); - for (std::size_t i = 0; i < vectors_js.Length(); i++) { - Napi::Value vector_js = vectors_js[i]; - vectors.push_back(vector_js.As()); - } - } else if (arg.IsTypedArray()) - vectors.push_back(arg.As()); - else - return Napi::TypeError::New(env, "Invalid argument type, expects float vector").ThrowAsJavaScriptException(); - - if (keys.size() != vectors.size()) - return Napi::TypeError::New(env, "The number of keys must match the number of vectors") - .ThrowAsJavaScriptException(); - using key_t = typename index_dense_t::key_t; std::size_t index_dimensions = native_->dimensions(); - if (native_->size() + keys.size() >= native_->capacity()) - native_->reserve(ceil2(native_->size() + keys.size())); - - for (std::size_t i = 0; i < keys.size(); i++) { - - key_t key = keys[i].Uint32Value(); - float const* vector = vectors[i].Data(); - std::size_t dimensions = static_cast(vectors[i].ElementLength()); + auto add = [&](Napi::Number key_js, Napi::Float32Array vector_js) { + key_t key = key_js.Uint32Value(); + float const* vector = vector_js.Data(); + std::size_t dimensions = static_cast(vector_js.ElementLength()); if (dimensions != index_dimensions) return Napi::TypeError::New(env, "Wrong number of dimensions").ThrowAsJavaScriptException(); @@ -241,14 +206,39 @@ void Index::Add(Napi::CallbackInfo const& ctx) { } catch (...) { return Napi::TypeError::New(env, "Insertion failed").ThrowAsJavaScriptException(); } - } + }; + + if (ctx[0].IsArray() && ctx[1].IsArray()) { + Napi::Array keys_js = ctx[0].As(); + Napi::Array vectors_js = ctx[1].As(); + auto length = keys_js.Length(); + + if (length != vectors_js.Length()) + return Napi::TypeError::New(env, "The number of keys must match the number of vectors") + .ThrowAsJavaScriptException(); + + if (native_->size() + length >= native_->capacity()) + native_->reserve(ceil2(native_->size() + length)); + + for (std::size_t i = 0; i < length; i++) { + Napi::Value key_js = keys_js[i]; + Napi::Value vector_js = vectors_js[i]; + add(key_js.As(), vector_js.As()); + } + + } else if (ctx[0].IsNumber() && ctx[1].IsTypedArray()) { + if (native_->size() + 1 >= native_->capacity()) + native_->reserve(ceil2(native_->size() + 1)); + add(ctx[0].As(), ctx[1].As()); + } else + return Napi::TypeError::New(env, "Invalid argument type, expects integral key(s) and float vector(s)") + .ThrowAsJavaScriptException(); } Napi::Value Index::Search(Napi::CallbackInfo const& ctx) { Napi::Env env = ctx.Env(); if (ctx.Length() < 2 || !ctx[0].IsTypedArray() || !ctx[1].IsNumber()) { - Napi::TypeError::New(env, "Expects a float vector and the number of wanted results") - .ThrowAsJavaScriptException(); + Napi::TypeError::New(env, "Expects a and the number of wanted results").ThrowAsJavaScriptException(); return {}; }