Skip to content

Commit

Permalink
feat: MVI method updates (#959)
Browse files Browse the repository at this point in the history
* feat: MVI method updates

Rename addItemBatch to upsertItemBatch to reflect its new upsert
functionality.

Make metadataFields in SearchOptions either an array of metadata
fields, or a sentinel value representing all metadata.

Add a similarity metric enum to createIndex to let the user specify
which metric to use when creating an index.
  • Loading branch information
nand4011 authored Oct 5, 2023
1 parent 0a644f2 commit a228607
Show file tree
Hide file tree
Showing 18 changed files with 1,540 additions and 511 deletions.
776 changes: 718 additions & 58 deletions packages/client-sdk-nodejs/package-lock.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion packages/client-sdk-nodejs/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"uuid": "8.3.2"
},
"dependencies": {
"@gomomento/generated-types": "0.77.0",
"@gomomento/generated-types": "0.84.0",
"@gomomento/sdk-core": "file:../core",
"@grpc/grpc-js": "1.9.0",
"@types/google-protobuf": "3.15.6",
Expand Down
125 changes: 34 additions & 91 deletions packages/client-sdk-nodejs/src/internal/vector-index-control-client.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import {control} from '@gomomento/generated-types';
import grpcControl = control.control_client;
import {Header, HeaderInterceptorProvider} from './grpc/headers-interceptor';
import {ClientTimeoutInterceptor} from './grpc/client-timeout-interceptor';
import {Status} from '@grpc/grpc-js/build/src/constants';
import {cacheServiceErrorMapper} from '../errors/cache-service-error-mapper';
import {ChannelCredentials, Interceptor} from '@grpc/grpc-js';
import {
ListCaches,
CreateSigningKey,
ListSigningKeys,
RevokeSigningKey,
CredentialProvider,
InvalidArgumentError,
MomentoLogger,
VectorIndexConfiguration,
} from '..';
Expand All @@ -20,16 +16,18 @@ import {GrpcClientWrapper} from './grpc/grpc-client-wrapper';
import {
validateIndexName,
validateNumDimensions,
validateTtlMinutes,
} from '@gomomento/sdk-core/dist/src/internal/utils';
import {normalizeSdkError} from '@gomomento/sdk-core/dist/src/errors';
import {_SigningKey} from '@gomomento/sdk-core/dist/src/messages/responses/grpc-response-types';
import {
CreateVectorIndex,
DeleteVectorIndex,
ListVectorIndexes,
} from '@gomomento/sdk-core';
import {IVectorIndexControlClient} from '@gomomento/sdk-core/dist/src/internal/clients';
import {
IVectorIndexControlClient,
VectorSimilarityMetric,
} from '@gomomento/sdk-core/dist/src/internal/clients';
import grpcControl = control.control_client;

export interface ControlClientProps {
configuration: VectorIndexConfiguration;
Expand Down Expand Up @@ -73,7 +71,8 @@ export class VectorIndexControlClient implements IVectorIndexControlClient {

public async createIndex(
indexName: string,
numDimensions: number
numDimensions: number,
similarityMetric?: VectorSimilarityMetric
): Promise<CreateVectorIndex.Response> {
try {
validateIndexName(indexName);
Expand All @@ -85,6 +84,31 @@ export class VectorIndexControlClient implements IVectorIndexControlClient {
const request = new grpcControl._CreateIndexRequest();
request.index_name = indexName;
request.num_dimensions = numDimensions;

similarityMetric ??= VectorSimilarityMetric.COSINE_SIMILARITY;

switch (similarityMetric) {
case VectorSimilarityMetric.INNER_PRODUCT:
request.inner_product =
new grpcControl._CreateIndexRequest._InnerProduct();
break;
case VectorSimilarityMetric.EUCLIDEAN_SIMILARITY:
request.euclidean_similarity =
new grpcControl._CreateIndexRequest._EuclideanSimilarity();
break;
case VectorSimilarityMetric.COSINE_SIMILARITY:
request.cosine_similarity =
new grpcControl._CreateIndexRequest._CosineSimilarity();
break;
default:
return new CreateVectorIndex.Error(
new InvalidArgumentError(
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`Invalid similarity metric: ${similarityMetric}`
)
);
}

return await new Promise<CreateVectorIndex.Response>(resolve => {
this.clientWrapper.getClient().CreateIndex(
request,
Expand All @@ -107,7 +131,7 @@ export class VectorIndexControlClient implements IVectorIndexControlClient {
});
}

public async listIndexes(): Promise<ListCaches.Response> {
public async listIndexes(): Promise<ListVectorIndexes.Response> {
const request = new grpcControl._ListIndexesRequest();
this.logger.debug("Issuing 'listIndexes' request");
return await new Promise<ListVectorIndexes.Response>(resolve => {
Expand Down Expand Up @@ -160,85 +184,4 @@ export class VectorIndexControlClient implements IVectorIndexControlClient {
);
});
}

public async createSigningKey(
ttlMinutes: number,
endpoint: string
): Promise<CreateSigningKey.Response> {
try {
validateTtlMinutes(ttlMinutes);
} catch (err) {
return new CreateSigningKey.Error(normalizeSdkError(err as Error));
}
this.logger.debug("Issuing 'createSigningKey' request");
const request = new grpcControl._CreateSigningKeyRequest();
request.ttl_minutes = ttlMinutes;
return await new Promise<CreateSigningKey.Response>(resolve => {
this.clientWrapper
.getClient()
.CreateSigningKey(
request,
{interceptors: this.interceptors},
(err, resp) => {
if (err) {
resolve(new CreateSigningKey.Error(cacheServiceErrorMapper(err)));
} else {
const signingKey = new _SigningKey(resp?.key, resp?.expires_at);
resolve(new CreateSigningKey.Success(endpoint, signingKey));
}
}
);
});
}

public async revokeSigningKey(
keyId: string
): Promise<RevokeSigningKey.Response> {
const request = new grpcControl._RevokeSigningKeyRequest();
request.key_id = keyId;
this.logger.debug("Issuing 'revokeSigningKey' request");
return await new Promise<RevokeSigningKey.Response>(resolve => {
this.clientWrapper
.getClient()
.RevokeSigningKey(request, {interceptors: this.interceptors}, err => {
if (err) {
resolve(new RevokeSigningKey.Error(cacheServiceErrorMapper(err)));
} else {
resolve(new RevokeSigningKey.Success());
}
});
});
}

public async listSigningKeys(
endpoint: string
): Promise<ListSigningKeys.Response> {
const request = new grpcControl._ListSigningKeysRequest();
request.next_token = '';
this.logger.debug("Issuing 'listSigningKeys' request");
return await new Promise<ListSigningKeys.Response>(resolve => {
this.clientWrapper
.getClient()
.ListSigningKeys(
request,
{interceptors: this.interceptors},
(err, resp) => {
if (err || !resp) {
resolve(new ListSigningKeys.Error(cacheServiceErrorMapper(err)));
} else {
const signingKeys = resp.signing_key.map(
sk => new _SigningKey(sk.key_id, sk.expires_at)
);
resolve(
new ListSigningKeys.Success(
endpoint,
signingKeys,
resp.next_token
)
);
}
}
);
});
}
}
42 changes: 25 additions & 17 deletions packages/client-sdk-nodejs/src/internal/vector-index-data-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import {
MomentoLogger,
MomentoLoggerFactory,
SearchOptions,
VectorAddItemBatch,
VectorDeleteItemBatch,
VectorSearch,
VectorUpsertItemBatch,
} from '@gomomento/sdk-core';
import {VectorIndexClientProps} from '../vector-index-client-props';
import {VectorIndexConfiguration} from '../config/vector-index-configuration';
Expand All @@ -23,6 +23,7 @@ import {
validateTopK,
} from '@gomomento/sdk-core/dist/src/internal/utils';
import {normalizeSdkError} from '@gomomento/sdk-core/dist/src/errors';
import {ALL_VECTOR_METADATA} from '@gomomento/sdk-core/dist/src/clients/IVectorIndexClient';

export class VectorIndexDataClient implements IVectorIndexDataClient {
private readonly configuration: VectorIndexConfiguration;
Expand Down Expand Up @@ -72,23 +73,23 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
);
}

public async addItemBatch(
public async upsertItemBatch(
indexName: string,
items: Array<VectorIndexItem>
): Promise<VectorAddItemBatch.Response> {
): Promise<VectorUpsertItemBatch.Response> {
try {
validateIndexName(indexName);
} catch (err) {
return new VectorAddItemBatch.Error(normalizeSdkError(err as Error));
return new VectorUpsertItemBatch.Error(normalizeSdkError(err as Error));
}
return await this.sendAddItemBatch(indexName, items);
return await this.sendUpsertItemBatch(indexName, items);
}

private async sendAddItemBatch(
private async sendUpsertItemBatch(
indexName: string,
items: Array<VectorIndexItem>
): Promise<VectorAddItemBatch.Response> {
const request = new vectorindex._AddItemBatchRequest({
): Promise<VectorUpsertItemBatch.Response> {
const request = new vectorindex._UpsertItemBatchRequest({
index_name: indexName,
items: items.map(item => {
return new vectorindex._Item({
Expand All @@ -108,14 +109,16 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
}),
});
return await new Promise(resolve => {
this.client.AddItemBatch(
this.client.UpsertItemBatch(
request,
{interceptors: this.interceptors},
(err, resp) => {
if (resp) {
resolve(new VectorAddItemBatch.Success());
resolve(new VectorUpsertItemBatch.Success());
} else {
resolve(new VectorAddItemBatch.Error(cacheServiceErrorMapper(err)));
resolve(
new VectorUpsertItemBatch.Error(cacheServiceErrorMapper(err))
);
}
}
);
Expand Down Expand Up @@ -180,16 +183,21 @@ export class VectorIndexDataClient implements IVectorIndexDataClient {
queryVector: Array<number>,
options?: SearchOptions
): Promise<VectorSearch.Response> {
const metadataRequest = new vectorindex._MetadataRequest();
if (options?.metadataFields === ALL_VECTOR_METADATA) {
metadataRequest.all = new vectorindex._MetadataRequest.All();
} else {
metadataRequest.some = new vectorindex._MetadataRequest.Some({
fields:
options?.metadataFields === undefined ? [] : options.metadataFields,
});
}

const request = new vectorindex._SearchRequest({
index_name: indexName,
query_vector: new vectorindex._Vector({elements: queryVector}),
top_k: options?.topK,
metadata_fields: new vectorindex._MetadataRequest({
some: new vectorindex._MetadataRequest.Some({
fields:
options?.metadataFields === undefined ? [] : options.metadataFields,
}),
}),
metadata_fields: metadataRequest,
});

return await new Promise(resolve => {
Expand Down
14 changes: 7 additions & 7 deletions packages/client-sdk-web/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion packages/client-sdk-web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"xhr2": "^0.2.1"
},
"dependencies": {
"@gomomento/generated-types-webtext": "0.77.0",
"@gomomento/generated-types-webtext": "0.84.0",
"@gomomento/sdk-core": "file:../core",
"google-protobuf": "3.21.2",
"grpc-web": "1.4.2",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import {control} from '@gomomento/generated-types-webtext';
import {CredentialProvider, MomentoLogger, VectorIndexConfiguration} from '..';
import {
CredentialProvider,
InvalidArgumentError,
MomentoLogger,
VectorIndexConfiguration,
} from '..';
import {Request, StatusCode, UnaryResponse} from 'grpc-web';
import {
_CreateIndexRequest,
_ListIndexesRequest,
_DeleteIndexRequest,
} from '@gomomento/generated-types-webtext/dist/controlclient_pb';
import {cacheServiceErrorMapper} from '../errors/cache-service-error-mapper';
import {IVectorIndexControlClient} from '@gomomento/sdk-core/dist/src/internal/clients';
import {
IVectorIndexControlClient,
VectorSimilarityMetric,
} from '@gomomento/sdk-core/dist/src/internal/clients';
import {normalizeSdkError} from '@gomomento/sdk-core/dist/src/errors';
import {
validateIndexName,
Expand Down Expand Up @@ -60,7 +68,8 @@ export class VectorIndexControlClient<

public async createIndex(
indexName: string,
numDimensions: number
numDimensions: number,
similarityMetric?: VectorSimilarityMetric
): Promise<CreateVectorIndex.Response> {
try {
validateIndexName(indexName);
Expand All @@ -71,6 +80,32 @@ export class VectorIndexControlClient<
const request = new _CreateIndexRequest();
request.setIndexName(indexName);
request.setNumDimensions(numDimensions);

similarityMetric ??= VectorSimilarityMetric.COSINE_SIMILARITY;

switch (similarityMetric) {
case VectorSimilarityMetric.INNER_PRODUCT:
request.setInnerProduct(new _CreateIndexRequest._InnerProduct());
break;
case VectorSimilarityMetric.EUCLIDEAN_SIMILARITY:
request.setEuclideanSimilarity(
new _CreateIndexRequest._EuclideanSimilarity()
);
break;
case VectorSimilarityMetric.COSINE_SIMILARITY:
request.setCosineSimilarity(
new _CreateIndexRequest._CosineSimilarity()
);
break;
default:
return new CreateVectorIndex.Error(
new InvalidArgumentError(
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`Invalid similarity metric: ${similarityMetric}`
)
);
}

this.logger.debug("Issuing 'createIndex' request");
return await new Promise<CreateVectorIndex.Response>(resolve => {
this.clientWrapper.createIndex(
Expand Down
Loading

0 comments on commit a228607

Please sign in to comment.