From 8b889646b6bf442c788740270a06be95f1d12e52 Mon Sep 17 00:00:00 2001 From: Rhys Date: Mon, 4 Nov 2024 10:41:52 -0500 Subject: [PATCH] chore(generative-ai): add url endpoints for compass web usage COMPASS-8376 (#6423) --- package-lock.json | 4 + packages/atlas-service/src/atlas-service.ts | 8 +- .../modules/pipeline-builder/pipeline-ai.ts | 31 +- packages/compass-generative-ai/package.json | 2 + .../src/atlas-ai-service.spec.ts | 557 ++++++++++-------- .../src/atlas-ai-service.ts | 125 +++- .../compass-generative-ai/src/provider.tsx | 38 +- .../src/stores/ai-query-reducer.ts | 31 +- packages/compass-web/src/entrypoint.tsx | 4 +- .../compass/src/app/components/entrypoint.tsx | 4 +- 10 files changed, 488 insertions(+), 316 deletions(-) diff --git a/package-lock.json b/package-lock.json index 8d1ec369fa8..e51763a1c15 100644 --- a/package-lock.json +++ b/package-lock.json @@ -44752,6 +44752,7 @@ "dependencies": { "@mongodb-js/atlas-service": "^0.31.0", "@mongodb-js/compass-components": "^1.31.0", + "@mongodb-js/compass-connections": "^1.45.0", "@mongodb-js/compass-intercom": "^0.13.2", "@mongodb-js/compass-logging": "^1.4.9", "bson": "^6.8.0", @@ -44762,6 +44763,7 @@ "react": "^17.0.2" }, "devDependencies": { + "@mongodb-js/connection-info": "^0.9.2", "@mongodb-js/eslint-config-compass": "^1.1.7", "@mongodb-js/mocha-config-compass": "^1.4.2", "@mongodb-js/prettier-config-compass": "^1.0.2", @@ -55973,8 +55975,10 @@ "requires": { "@mongodb-js/atlas-service": "^0.31.0", "@mongodb-js/compass-components": "^1.31.0", + "@mongodb-js/compass-connections": "^1.45.0", "@mongodb-js/compass-intercom": "^0.13.2", "@mongodb-js/compass-logging": "^1.4.9", + "@mongodb-js/connection-info": "^0.9.2", "@mongodb-js/eslint-config-compass": "^1.1.7", "@mongodb-js/mocha-config-compass": "^1.4.2", "@mongodb-js/prettier-config-compass": "^1.0.2", diff --git a/packages/atlas-service/src/atlas-service.ts b/packages/atlas-service/src/atlas-service.ts index dd80e78bfa3..528e4ef6de5 100644 --- a/packages/atlas-service/src/atlas-service.ts +++ b/packages/atlas-service/src/atlas-service.ts @@ -29,12 +29,8 @@ export class AtlasService { ) { this.config = getAtlasConfig(preferences); } - adminApiEndpoint(path?: string, requestId?: string): string { - const uri = `${this.config.atlasApiBaseUrl}${normalizePath(path)}`; - const query = requestId - ? `?request_id=${encodeURIComponent(requestId)}` - : ''; - return `${uri}${query}`; + adminApiEndpoint(path?: string): string { + return `${this.config.atlasApiBaseUrl}${normalizePath(path)}`; } cloudEndpoint(path?: string): string { return `${this.config.cloudBaseUrl}${normalizePath(path)}`; diff --git a/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.ts b/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.ts index ff5882dde11..5c60b683f8c 100644 --- a/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.ts +++ b/packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.ts @@ -281,20 +281,23 @@ export const runAIPipelineGeneration = ( const { collection: collectionName, database: databaseName } = toNS(namespace); - jsonResponse = await atlasAiService.getAggregationFromUserInput({ - signal: abortController.signal, - userInput, - collectionName, - databaseName, - schema, - // Provide sample documents when the user has opted in in their settings. - ...(provideSampleDocuments - ? { - sampleDocuments, - } - : undefined), - requestId, - }); + jsonResponse = await atlasAiService.getAggregationFromUserInput( + { + signal: abortController.signal, + userInput, + collectionName, + databaseName, + schema, + // Provide sample documents when the user has opted in in their settings. + ...(provideSampleDocuments + ? { + sampleDocuments, + } + : undefined), + requestId, + }, + connectionInfo + ); } catch (err: any) { if (signal.aborted) { // If we already aborted so we ignore the error. diff --git a/packages/compass-generative-ai/package.json b/packages/compass-generative-ai/package.json index bbf9b1142dc..3f4ffbf2ba8 100644 --- a/packages/compass-generative-ai/package.json +++ b/packages/compass-generative-ai/package.json @@ -54,6 +54,7 @@ "dependencies": { "@mongodb-js/atlas-service": "^0.31.0", "@mongodb-js/compass-components": "^1.31.0", + "@mongodb-js/compass-connections": "^1.45.0", "@mongodb-js/compass-intercom": "^0.13.2", "@mongodb-js/compass-logging": "^1.4.9", "bson": "^6.8.0", @@ -64,6 +65,7 @@ "react": "^17.0.2" }, "devDependencies": { + "@mongodb-js/connection-info": "^0.9.2", "@mongodb-js/eslint-config-compass": "^1.1.7", "@mongodb-js/mocha-config-compass": "^1.4.2", "@mongodb-js/prettier-config-compass": "^1.0.2", diff --git a/packages/compass-generative-ai/src/atlas-ai-service.spec.ts b/packages/compass-generative-ai/src/atlas-ai-service.spec.ts index b274e6d0ad7..ad904fb4a68 100644 --- a/packages/compass-generative-ai/src/atlas-ai-service.spec.ts +++ b/packages/compass-generative-ai/src/atlas-ai-service.spec.ts @@ -5,6 +5,7 @@ import type { PreferencesAccess } from 'compass-preferences-model'; import { createSandboxFromDefaultPreferences } from 'compass-preferences-model'; import { createNoopLogger } from '@mongodb-js/compass-logging/provider'; import { ObjectId } from 'mongodb'; +import type { ConnectionInfo } from '@mongodb-js/connection-info'; const ATLAS_USER = { firstName: 'John', @@ -21,12 +22,28 @@ const PREFERENCES_USER = { const BASE_URL = 'http://example.com'; +const mockConnectionInfo: ConnectionInfo = { + id: 'TEST', + connectionOptions: { + connectionString: 'mongodb://localhost:27020', + }, + atlasMetadata: { + orgId: 'testOrg', + projectId: 'testProject', + clusterName: 'pineapple', + regionalBaseUrl: 'https://example.com', + metricsId: 'metricsId', + metricsType: 'replicaSet', + instanceSize: 'M10', + clusterType: 'REPLICASET', + clusterUniqueId: 'clusterUniqueId', + }, +}; + class MockAtlasService { getCurrentUser = () => Promise.resolve(ATLAS_USER); - adminApiEndpoint = (url: string, requestId?: string) => - `${[BASE_URL, url].join('/')}${ - requestId ? `?request_id=${requestId}` : '' - }`; + cloudEndpoint = (url: string) => `${['/cloud', url].join('/')}`; + adminApiEndpoint = (url: string) => `${[BASE_URL, url].join('/')}`; authenticatedFetch = (url: string, init: RequestInit) => { return fetch(url, init); }; @@ -45,7 +62,6 @@ function makeResponse(content: any) { describe('AtlasAiService', function () { let sandbox: Sinon.SinonSandbox; - let atlasAiService: AtlasAiService; let preferences: PreferencesAccess; const initialFetch = global.fetch; @@ -53,12 +69,6 @@ describe('AtlasAiService', function () { sandbox = Sinon.createSandbox(); preferences = await createSandboxFromDefaultPreferences(); preferences['getPreferencesUser'] = () => PREFERENCES_USER; - - atlasAiService = new AtlasAiService( - new MockAtlasService() as any, - preferences, - createNoopLogger() - ); }); afterEach(function () { @@ -66,278 +76,333 @@ describe('AtlasAiService', function () { global.fetch = initialFetch; }); - describe('ai api calls', function () { - beforeEach(async function () { - // Enable the AI feature - const fetchStub = sandbox.stub().resolves( - makeResponse({ - features: { - GEN_AI_COMPASS: { - enabled: true, - }, - }, - }) - ); - global.fetch = fetchStub; - await atlasAiService['setupAIAccess'](); - global.fetch = initialFetch; - }); - - const atlasAIServiceTests = [ - { - functionName: 'getQueryFromUserInput', - aiEndpoint: 'mql-query', - responses: { - success: { - content: { query: { filter: "{ test: 'pineapple' }" } }, - }, - invalid: [ - [undefined, 'internal server error'], - [{}, 'unexpected response'], - [{ countent: {} }, 'unexpected response'], - [{ content: { qooery: {} } }, 'unexpected keys'], - [ - { content: { query: { filter: { foo: 1 } } } }, - 'unexpected response', - ], - ], - }, + const endpointBasepathTests = [ + { + apiURLPreset: 'admin-api', + expectedEndpoints: { + 'user-access': 'http://example.com/unauth/ai/api/v1/hello/1234', + 'mql-aggregation': `http://example.com/ai/api/v1/mql-aggregation?request_id=abc`, + 'mql-query': `http://example.com/ai/api/v1/mql-query?request_id=abc`, }, - { - functionName: 'getAggregationFromUserInput', - aiEndpoint: 'mql-aggregation', - responses: { - success: { - content: { aggregation: { pipeline: "[{ test: 'pineapple' }]" } }, - }, - invalid: [ - [undefined, 'internal server error'], - [{}, 'unexpected response'], - [{ content: { aggregation: {} } }, 'unexpected response'], - [{ content: { aggrogation: {} } }, 'unexpected keys'], - [ - { content: { aggregation: { pipeline: true } } }, - 'unexpected response', - ], - ], - }, + }, + { + apiURLPreset: 'cloud', + expectedEndpoints: { + 'user-access': '/cloud/ai/v1/hello/1234', + 'mql-aggregation': + '/cloud/ai/v1/groups/testProject/mql-aggregation?request_id=abc', + 'mql-query': '/cloud/ai/v1/groups/testProject/mql-query?request_id=abc', }, - ] as const; - - for (const { functionName, aiEndpoint, responses } of atlasAIServiceTests) { - describe(functionName, function () { - it('makes a post request with the user input to the endpoint in the environment', async function () { - const fetchStub = sandbox - .stub() - .resolves(makeResponse(responses.success)); - global.fetch = fetchStub; - - const res = await atlasAiService[functionName]({ - userInput: 'test', - signal: new AbortController().signal, - collectionName: 'jam', - databaseName: 'peanut', - schema: { _id: { types: [{ bsonType: 'ObjectId' }] } }, - sampleDocuments: [ - { _id: new ObjectId('642d766b7300158b1f22e972') }, - ], - requestId: 'abc', - }); - - expect(fetchStub).to.have.been.calledOnce; - - const { args } = fetchStub.firstCall; + }, + ] as const; + + for (const { apiURLPreset, expectedEndpoints } of endpointBasepathTests) { + describe(`api URL Preset "${apiURLPreset}"`, function () { + let atlasAiService: AtlasAiService; + + beforeEach(function () { + const mockAtlasService = new MockAtlasService(); + atlasAiService = new AtlasAiService({ + apiURLPreset, + atlasService: mockAtlasService as any, + preferences, + logger: createNoopLogger(), + }); + }); - expect(args[0]).to.eq( - `http://example.com/ai/api/v1/${aiEndpoint}?request_id=abc` - ); - expect(args[1].body).to.eq( - '{"userInput":"test","collectionName":"jam","databaseName":"peanut","schema":{"_id":{"types":[{"bsonType":"ObjectId"}]}},"sampleDocuments":[{"_id":{"$oid":"642d766b7300158b1f22e972"}}]}' + describe('ai api calls', function () { + beforeEach(async function () { + // Enable the AI feature + const fetchStub = sandbox.stub().resolves( + makeResponse({ + features: { + GEN_AI_COMPASS: { + enabled: true, + }, + }, + }) ); - expect(res).to.deep.eq(responses.success); + global.fetch = fetchStub; + await atlasAiService['setupAIAccess'](); + global.fetch = initialFetch; }); - it('should fail when response is not matching expected schema', async function () { - for (const [res, error] of responses.invalid) { - const fetchStub = sandbox.stub().resolves(makeResponse(res)); - global.fetch = fetchStub; - - try { - await atlasAiService[functionName]({ - userInput: 'test', - collectionName: 'test', - databaseName: 'peanut', - requestId: 'abc', - signal: new AbortController().signal, - }); - expect.fail(`Expected ${functionName} to throw`); - } catch (err) { - expect((err as Error).message).to.match(new RegExp(error, 'i')); - } - } - }); + const atlasAIServiceTests = [ + { + functionName: 'getQueryFromUserInput', + aiEndpoint: 'mql-query', + responses: { + success: { + content: { query: { filter: "{ test: 'pineapple' }" } }, + }, + invalid: [ + [undefined, 'internal server error'], + [{}, 'unexpected response'], + [{ countent: {} }, 'unexpected response'], + [{ content: { qooery: {} } }, 'unexpected keys'], + [ + { content: { query: { filter: { foo: 1 } } } }, + 'unexpected response', + ], + ], + }, + }, + { + functionName: 'getAggregationFromUserInput', + aiEndpoint: 'mql-aggregation', + responses: { + success: { + content: { + aggregation: { pipeline: "[{ test: 'pineapple' }]" }, + }, + }, + invalid: [ + [undefined, 'internal server error'], + [{}, 'unexpected response'], + [{ content: { aggregation: {} } }, 'unexpected response'], + [{ content: { aggrogation: {} } }, 'unexpected keys'], + [ + { content: { aggregation: { pipeline: true } } }, + 'unexpected response', + ], + ], + }, + }, + ] as const; + + for (const { + functionName, + aiEndpoint, + responses, + } of atlasAIServiceTests) { + describe(functionName, function () { + it('makes a post request with the user input to the endpoint in the environment', async function () { + const fetchStub = sandbox + .stub() + .resolves(makeResponse(responses.success)); + global.fetch = fetchStub; + + const res = await atlasAiService[functionName]( + { + userInput: 'test', + signal: new AbortController().signal, + collectionName: 'jam', + databaseName: 'peanut', + schema: { _id: { types: [{ bsonType: 'ObjectId' }] } }, + sampleDocuments: [ + { _id: new ObjectId('642d766b7300158b1f22e972') }, + ], + requestId: 'abc', + }, + mockConnectionInfo + ); + + expect(fetchStub).to.have.been.calledOnce; + + const { args } = fetchStub.firstCall; + + expect(args[0]).to.eq(expectedEndpoints[aiEndpoint]); + expect(args[1].body).to.eq( + '{"userInput":"test","collectionName":"jam","databaseName":"peanut","schema":{"_id":{"types":[{"bsonType":"ObjectId"}]}},"sampleDocuments":[{"_id":{"$oid":"642d766b7300158b1f22e972"}}]}' + ); + expect(res).to.deep.eq(responses.success); + }); - it('throws if the request would be too much for the ai', async function () { - try { - await atlasAiService[functionName]({ - userInput: 'test', - collectionName: 'test', - databaseName: 'peanut', - sampleDocuments: [{ test: '4'.repeat(5120001) }], - requestId: 'abc', - signal: new AbortController().signal, + it('should fail when response is not matching expected schema', async function () { + for (const [res, error] of responses.invalid) { + const fetchStub = sandbox.stub().resolves(makeResponse(res)); + global.fetch = fetchStub; + + try { + await atlasAiService[functionName]( + { + userInput: 'test', + collectionName: 'test', + databaseName: 'peanut', + requestId: 'abc', + signal: new AbortController().signal, + }, + mockConnectionInfo + ); + expect.fail(`Expected ${functionName} to throw`); + } catch (err) { + expect((err as Error).message).to.match( + new RegExp(error, 'i') + ); + } + } }); - expect.fail(`Expected ${functionName} to throw`); - } catch (err) { - expect(err).to.have.property( - 'message', - 'Sorry, your request is too large. Please use a smaller prompt or try using this feature on a collection with smaller documents.' - ); - } - }); - it('passes fewer documents if the request would be too much for the ai with all of the documents', async function () { - const fetchStub = sandbox - .stub() - .resolves(makeResponse(responses.success)); - global.fetch = fetchStub; + it('throws if the request would be too much for the ai', async function () { + try { + await atlasAiService[functionName]( + { + userInput: 'test', + collectionName: 'test', + databaseName: 'peanut', + sampleDocuments: [{ test: '4'.repeat(5120001) }], + requestId: 'abc', + signal: new AbortController().signal, + }, + mockConnectionInfo + ); + expect.fail(`Expected ${functionName} to throw`); + } catch (err) { + expect(err).to.have.property( + 'message', + 'Sorry, your request is too large. Please use a smaller prompt or try using this feature on a collection with smaller documents.' + ); + } + }); - await atlasAiService[functionName]({ - userInput: 'test', - collectionName: 'test.test', - databaseName: 'peanut', - sampleDocuments: [ - { a: '1' }, - { a: '2' }, - { a: '3' }, - { a: '4'.repeat(5120001) }, - ], - requestId: 'abc', - signal: new AbortController().signal, + it('passes fewer documents if the request would be too much for the ai with all of the documents', async function () { + const fetchStub = sandbox + .stub() + .resolves(makeResponse(responses.success)); + global.fetch = fetchStub; + + await atlasAiService[functionName]( + { + userInput: 'test', + collectionName: 'test.test', + databaseName: 'peanut', + sampleDocuments: [ + { a: '1' }, + { a: '2' }, + { a: '3' }, + { a: '4'.repeat(5120001) }, + ], + requestId: 'abc', + signal: new AbortController().signal, + }, + mockConnectionInfo + ); + + const { args } = fetchStub.firstCall; + + expect(fetchStub).to.have.been.calledOnce; + expect(args[1].body).to.eq( + '{"userInput":"test","collectionName":"test.test","databaseName":"peanut","sampleDocuments":[{"a":"1"}]}' + ); + }); }); + } + }); - const { args } = fetchStub.firstCall; - - expect(fetchStub).to.have.been.calledOnce; - expect(args[1].body).to.eq( - '{"userInput":"test","collectionName":"test.test","databaseName":"peanut","sampleDocuments":[{"a":"1"}]}' - ); + describe('setupAIAccess', function () { + beforeEach(async function () { + await preferences.savePreferences({ + cloudFeatureRolloutAccess: undefined, + }); }); - }); - } - }); - describe('setupAIAccess', function () { - beforeEach(async function () { - await preferences.savePreferences({ - cloudFeatureRolloutAccess: undefined, - }); - }); + it('should set the cloudFeatureRolloutAccess true when returned true', async function () { + const fetchStub = sandbox.stub().resolves( + makeResponse({ + features: { + GEN_AI_COMPASS: { + enabled: true, + }, + }, + }) + ); + global.fetch = fetchStub; - it('should set the cloudFeatureRolloutAccess true when returned true', async function () { - const fetchStub = sandbox.stub().resolves( - makeResponse({ - features: { - GEN_AI_COMPASS: { - enabled: true, - }, - }, - }) - ); - global.fetch = fetchStub; + let currentCloudFeatureRolloutAccess = + preferences.getPreferences().cloudFeatureRolloutAccess; + expect(currentCloudFeatureRolloutAccess).to.equal(undefined); - let currentCloudFeatureRolloutAccess = - preferences.getPreferences().cloudFeatureRolloutAccess; - expect(currentCloudFeatureRolloutAccess).to.equal(undefined); + await atlasAiService['setupAIAccess'](); - await atlasAiService['setupAIAccess'](); + const { args } = fetchStub.firstCall; - const { args } = fetchStub.firstCall; + expect(fetchStub).to.have.been.calledOnce; - expect(fetchStub).to.have.been.calledOnce; - expect(args[0]).to.contain('ai/api/v1/hello/1234'); + expect(args[0]).to.equal(expectedEndpoints['user-access']); - currentCloudFeatureRolloutAccess = - preferences.getPreferences().cloudFeatureRolloutAccess; - expect(currentCloudFeatureRolloutAccess).to.deep.equal({ - GEN_AI_COMPASS: true, - }); - }); + currentCloudFeatureRolloutAccess = + preferences.getPreferences().cloudFeatureRolloutAccess; + expect(currentCloudFeatureRolloutAccess).to.deep.equal({ + GEN_AI_COMPASS: true, + }); + }); - it('should set the cloudFeatureRolloutAccess false when returned false', async function () { - const fetchStub = sandbox.stub().resolves( - makeResponse({ - features: { - GEN_AI_COMPASS: { - enabled: false, - }, - }, - }) - ); - global.fetch = fetchStub; + it('should set the cloudFeatureRolloutAccess false when returned false', async function () { + const fetchStub = sandbox.stub().resolves( + makeResponse({ + features: { + GEN_AI_COMPASS: { + enabled: false, + }, + }, + }) + ); + global.fetch = fetchStub; - let currentCloudFeatureRolloutAccess = - preferences.getPreferences().cloudFeatureRolloutAccess; - expect(currentCloudFeatureRolloutAccess).to.equal(undefined); + let currentCloudFeatureRolloutAccess = + preferences.getPreferences().cloudFeatureRolloutAccess; + expect(currentCloudFeatureRolloutAccess).to.equal(undefined); - await atlasAiService['setupAIAccess'](); + await atlasAiService['setupAIAccess'](); - const { args } = fetchStub.firstCall; + const { args } = fetchStub.firstCall; - expect(fetchStub).to.have.been.calledOnce; - expect(args[0]).to.contain('ai/api/v1/hello/1234'); + expect(fetchStub).to.have.been.calledOnce; + expect(args[0]).to.equal(expectedEndpoints['user-access']); - currentCloudFeatureRolloutAccess = - preferences.getPreferences().cloudFeatureRolloutAccess; - expect(currentCloudFeatureRolloutAccess).to.deep.equal({ - GEN_AI_COMPASS: false, - }); - }); + currentCloudFeatureRolloutAccess = + preferences.getPreferences().cloudFeatureRolloutAccess; + expect(currentCloudFeatureRolloutAccess).to.deep.equal({ + GEN_AI_COMPASS: false, + }); + }); - it('should set the cloudFeatureRolloutAccess false when returned null', async function () { - const fetchStub = sandbox.stub().resolves( - makeResponse({ - features: null, - }) - ); - global.fetch = fetchStub; + it('should set the cloudFeatureRolloutAccess false when returned null', async function () { + const fetchStub = sandbox.stub().resolves( + makeResponse({ + features: null, + }) + ); + global.fetch = fetchStub; - let currentCloudFeatureRolloutAccess = - preferences.getPreferences().cloudFeatureRolloutAccess; - expect(currentCloudFeatureRolloutAccess).to.equal(undefined); + let currentCloudFeatureRolloutAccess = + preferences.getPreferences().cloudFeatureRolloutAccess; + expect(currentCloudFeatureRolloutAccess).to.equal(undefined); - await atlasAiService['setupAIAccess'](); + await atlasAiService['setupAIAccess'](); - const { args } = fetchStub.firstCall; + const { args } = fetchStub.firstCall; - expect(fetchStub).to.have.been.calledOnce; - expect(args[0]).to.contain('ai/api/v1/hello/1234'); + expect(fetchStub).to.have.been.calledOnce; + expect(args[0]).to.equal(expectedEndpoints['user-access']); - currentCloudFeatureRolloutAccess = - preferences.getPreferences().cloudFeatureRolloutAccess; - expect(currentCloudFeatureRolloutAccess).to.deep.equal({ - GEN_AI_COMPASS: false, - }); - }); + currentCloudFeatureRolloutAccess = + preferences.getPreferences().cloudFeatureRolloutAccess; + expect(currentCloudFeatureRolloutAccess).to.deep.equal({ + GEN_AI_COMPASS: false, + }); + }); - it('should not set the cloudFeatureRolloutAccess false when returned false', async function () { - const fetchStub = sandbox.stub().throws(new Error('error')); - global.fetch = fetchStub; + it('should not set the cloudFeatureRolloutAccess false when returned false', async function () { + const fetchStub = sandbox.stub().throws(new Error('error')); + global.fetch = fetchStub; - let currentCloudFeatureRolloutAccess = - preferences.getPreferences().cloudFeatureRolloutAccess; - expect(currentCloudFeatureRolloutAccess).to.equal(undefined); + let currentCloudFeatureRolloutAccess = + preferences.getPreferences().cloudFeatureRolloutAccess; + expect(currentCloudFeatureRolloutAccess).to.equal(undefined); - await atlasAiService['setupAIAccess'](); + await atlasAiService['setupAIAccess'](); - const { args } = fetchStub.firstCall; + const { args } = fetchStub.firstCall; - expect(fetchStub).to.have.been.calledOnce; - expect(args[0]).to.contain('ai/api/v1/hello/1234'); + expect(fetchStub).to.have.been.calledOnce; + expect(args[0]).to.equal(expectedEndpoints['user-access']); - currentCloudFeatureRolloutAccess = - preferences.getPreferences().cloudFeatureRolloutAccess; - expect(currentCloudFeatureRolloutAccess).to.deep.equal(undefined); + currentCloudFeatureRolloutAccess = + preferences.getPreferences().cloudFeatureRolloutAccess; + expect(currentCloudFeatureRolloutAccess).to.deep.equal(undefined); + }); + }); }); - }); + } }); diff --git a/packages/compass-generative-ai/src/atlas-ai-service.ts b/packages/compass-generative-ai/src/atlas-ai-service.ts index 5cc49001be7..b5038c1faf9 100644 --- a/packages/compass-generative-ai/src/atlas-ai-service.ts +++ b/packages/compass-generative-ai/src/atlas-ai-service.ts @@ -5,6 +5,7 @@ import { } from 'compass-preferences-model/provider'; import type { AtlasService } from '@mongodb-js/atlas-service/provider'; import { AtlasServiceError } from '@mongodb-js/atlas-service/renderer'; +import type { ConnectionInfo } from '@mongodb-js/compass-connections/provider'; import type { Document } from 'mongodb'; import type { Logger } from '@mongodb-js/compass-logging'; import { EJSON } from 'bson'; @@ -23,9 +24,6 @@ type GenerativeAiInput = { // want to ensure we're not uploading massive documents (some folks have documents > 1mb). const AI_MAX_REQUEST_SIZE = 5120000; const AI_MIN_SAMPLE_DOCUMENTS = 1; -const USER_AI_URI = (userId: string) => `unauth/ai/api/v1/hello/${userId}`; -const AGGREGATION_URI = 'ai/api/v1/mql-aggregation'; -const QUERY_URI = 'ai/api/v1/mql-query'; type AIAggregation = { content: { @@ -192,17 +190,83 @@ export function validateAIAggregationResponse( } } +const aiURLConfig = { + // There are two different sets of endpoints we use for our requests. + // Down the line we'd like to only use the admin api, however, + // we cannot currently call that from the Atlas UI. Pending CLOUDP-251201 + 'admin-api': { + 'user-access': (userId: string) => `unauth/ai/api/v1/hello/${userId}`, + aggregation: 'ai/api/v1/mql-aggregation', + query: 'ai/api/v1/mql-query', + }, + cloud: { + 'user-access': (userId: string) => `ai/v1/hello/${userId}`, + aggregation: (groupId: string) => `ai/v1/groups/${groupId}/mql-aggregation`, + query: (groupId: string) => `ai/v1/groups/${groupId}/mql-query`, + }, +} as const; +type AIEndpoint = 'user-access' | 'query' | 'aggregation'; + export class AtlasAiService { private initPromise: Promise | null = null; - constructor( - private atlasService: AtlasService, - private preferences: PreferencesAccess, - private logger: Logger - ) { + private apiURLPreset: 'admin-api' | 'cloud'; + private atlasService: AtlasService; + private preferences: PreferencesAccess; + private logger: Logger; + + constructor({ + apiURLPreset, + atlasService, + preferences, + logger, + }: { + apiURLPreset: 'admin-api' | 'cloud'; + atlasService: AtlasService; + preferences: PreferencesAccess; + logger: Logger; + }) { + this.apiURLPreset = apiURLPreset; + this.atlasService = atlasService; + this.preferences = preferences; + this.logger = logger; + this.initPromise = this.setupAIAccess(); } + private getUrlForEndpoint( + urlId: AIEndpoint, + connectionInfo?: ConnectionInfo + ) { + if (this.apiURLPreset === 'cloud') { + if (urlId === 'user-access') { + return this.atlasService.cloudEndpoint( + aiURLConfig[this.apiURLPreset][urlId]( + this.preferences.getPreferencesUser().id + ) + ); + } + + const atlasMetadata = connectionInfo?.atlasMetadata; + if (!atlasMetadata) { + throw new Error( + "Can't perform generative ai request: atlasMetadata is not available" + ); + } + + return this.atlasService.cloudEndpoint( + aiURLConfig[this.apiURLPreset][urlId](atlasMetadata.projectId) + ); + } + const urlConfig = aiURLConfig[this.apiURLPreset][urlId]; + const urlPath = + typeof urlConfig === 'function' + ? urlConfig(this.preferences.getPreferencesUser().id) + : urlConfig; + + return this.atlasService.adminApiEndpoint(urlPath); + } + private throwIfAINotEnabled() { if (process.env.COMPASS_E2E_SKIP_ATLAS_SIGNIN === 'true') { return; @@ -215,8 +279,8 @@ export class AtlasAiService { } private async getAIFeatureEnablement(): Promise { - const userId = this.preferences.getPreferencesUser().id; - const url = this.atlasService.adminApiEndpoint(USER_AI_URI(userId)); + const url = this.getUrlForEndpoint('user-access'); + const res = await this.atlasService.fetch(url, { headers: { Accept: 'application/json', @@ -261,8 +325,16 @@ export class AtlasAiService { } private getQueryOrAggregationFromUserInput = async ( - uri: string, - input: GenerativeAiInput, + { + urlId, + input, + connectionInfo, + }: { + urlId: 'query' | 'aggregation'; + input: GenerativeAiInput; + + connectionInfo?: ConnectionInfo; + }, validationFn: (res: any) => asserts res is T ): Promise => { await this.initPromise; @@ -271,7 +343,10 @@ export class AtlasAiService { const { signal, requestId, ...rest } = input; const msgBody = buildQueryOrAggregationMessageBody(rest); - const url = this.atlasService.adminApiEndpoint(uri, requestId); + const url = `${this.getUrlForEndpoint( + urlId, + connectionInfo + )}?request_id=${encodeURIComponent(requestId)}`; this.logger.log.info( this.logger.mongoLogId(1_001_000_308), @@ -326,18 +401,30 @@ export class AtlasAiService { return data; }; - async getAggregationFromUserInput(input: GenerativeAiInput) { + async getAggregationFromUserInput( + input: GenerativeAiInput, + connectionInfo: ConnectionInfo + ) { return this.getQueryOrAggregationFromUserInput( - AGGREGATION_URI, - input, + { + connectionInfo, + urlId: 'aggregation', + input, + }, validateAIAggregationResponse ); } - async getQueryFromUserInput(input: GenerativeAiInput) { + async getQueryFromUserInput( + input: GenerativeAiInput, + connectionInfo: ConnectionInfo + ) { return this.getQueryOrAggregationFromUserInput( - QUERY_URI, - input, + { + urlId: 'query', + input, + connectionInfo, + }, validateAIQueryResponse ); } diff --git a/packages/compass-generative-ai/src/provider.tsx b/packages/compass-generative-ai/src/provider.tsx index 6df896c8439..05f3bd58db6 100644 --- a/packages/compass-generative-ai/src/provider.tsx +++ b/packages/compass-generative-ai/src/provider.tsx @@ -10,23 +10,31 @@ import { const AtlasAiServiceContext = createContext(null); -export const AtlasAiServiceProvider: React.FC = createServiceProvider( - function AtlasAiServiceProvider({ children }) { - const logger = useLogger('ATLAS-AI-SERVICE'); - const preferences = preferencesLocator(); - const atlasService = atlasServiceLocator(); +export const AtlasAiServiceProvider: React.FC<{ + apiURLPreset: 'admin-api' | 'cloud'; +}> = createServiceProvider(function AtlasAiServiceProvider({ + apiURLPreset, + children, +}) { + const logger = useLogger('ATLAS-AI-SERVICE'); + const preferences = preferencesLocator(); + const atlasService = atlasServiceLocator(); - const aiService = useMemo(() => { - return new AtlasAiService(atlasService, preferences, logger); - }, [preferences, logger, atlasService]); + const aiService = useMemo(() => { + return new AtlasAiService({ + apiURLPreset, + atlasService, + preferences, + logger, + }); + }, [apiURLPreset, preferences, logger, atlasService]); - return ( - - {children} - - ); - } -); + return ( + + {children} + + ); +}); function useAtlasAiServiceContext(): AtlasAiService { const service = useContext(AtlasAiServiceContext); diff --git a/packages/compass-query-bar/src/stores/ai-query-reducer.ts b/packages/compass-query-bar/src/stores/ai-query-reducer.ts index 87a9609175b..9ed8735b0b2 100644 --- a/packages/compass-query-bar/src/stores/ai-query-reducer.ts +++ b/packages/compass-query-bar/src/stores/ai-query-reducer.ts @@ -221,20 +221,23 @@ export const runAIQuery = ( const { collection: collectionName, database: databaseName } = toNS(namespace); - jsonResponse = await atlasAiService.getQueryFromUserInput({ - signal: abortController.signal, - userInput, - collectionName, - databaseName, - schema, - // Provide sample documents when the user has opted in in their settings. - ...(provideSampleDocuments - ? { - sampleDocuments, - } - : undefined), - requestId, - }); + jsonResponse = await atlasAiService.getQueryFromUserInput( + { + signal: abortController.signal, + userInput, + collectionName, + databaseName, + schema, + // Provide sample documents when the user has opted in in their settings. + ...(provideSampleDocuments + ? { + sampleDocuments, + } + : undefined), + requestId, + }, + connectionInfo + ); } catch (err: any) { if (signal.aborted) { // If we already aborted so we ignore the error. diff --git a/packages/compass-web/src/entrypoint.tsx b/packages/compass-web/src/entrypoint.tsx index e2103fe782b..13225a3c833 100644 --- a/packages/compass-web/src/entrypoint.tsx +++ b/packages/compass-web/src/entrypoint.tsx @@ -63,7 +63,9 @@ const WithAtlasProviders: React.FC = ({ children }) => { return ( - {children} + + {children} + ); diff --git a/packages/compass/src/app/components/entrypoint.tsx b/packages/compass/src/app/components/entrypoint.tsx index 1f4c0af4d65..c952028d18f 100644 --- a/packages/compass/src/app/components/entrypoint.tsx +++ b/packages/compass/src/app/components/entrypoint.tsx @@ -61,7 +61,9 @@ export const WithAtlasProviders: React.FC = ({ children }) => { }, }} > - {children} + + {children} + );