From 39cf6e7bdb87671f08894907069a47c92574519a Mon Sep 17 00:00:00 2001 From: Rhys Howell Date: Tue, 29 Oct 2024 16:50:06 -0400 Subject: [PATCH] chore(generative-ai): add url endpoints for compass web usage --- packages/atlas-service/src/atlas-service.ts | 8 +-- .../src/atlas-ai-service.spec.ts | 29 ++++++--- .../src/atlas-ai-service.ts | 64 ++++++++++++++----- .../compass-generative-ai/src/provider.tsx | 64 +++++++++++++------ packages/compass-web/src/entrypoint.tsx | 23 +++++-- .../compass/src/app/components/entrypoint.tsx | 17 ++++- 6 files changed, 149 insertions(+), 56 deletions(-) diff --git a/packages/atlas-service/src/atlas-service.ts b/packages/atlas-service/src/atlas-service.ts index dd80e78bfa3..528e4ef6de5 100644 --- a/packages/atlas-service/src/atlas-service.ts +++ b/packages/atlas-service/src/atlas-service.ts @@ -29,12 +29,8 @@ export class AtlasService { ) { this.config = getAtlasConfig(preferences); } - adminApiEndpoint(path?: string, requestId?: string): string { - const uri = `${this.config.atlasApiBaseUrl}${normalizePath(path)}`; - const query = requestId - ? `?request_id=${encodeURIComponent(requestId)}` - : ''; - return `${uri}${query}`; + adminApiEndpoint(path?: string): string { + return `${this.config.atlasApiBaseUrl}${normalizePath(path)}`; } cloudEndpoint(path?: string): string { return `${this.config.cloudBaseUrl}${normalizePath(path)}`; diff --git a/packages/compass-generative-ai/src/atlas-ai-service.spec.ts b/packages/compass-generative-ai/src/atlas-ai-service.spec.ts index b274e6d0ad7..f825043bfc0 100644 --- a/packages/compass-generative-ai/src/atlas-ai-service.spec.ts +++ b/packages/compass-generative-ai/src/atlas-ai-service.spec.ts @@ -1,6 +1,10 @@ import Sinon from 'sinon'; import { expect } from 'chai'; -import { AtlasAiService } from './atlas-ai-service'; +import { + type AIEndpoint, + aiURLConfig, + AtlasAiService, +} from './atlas-ai-service'; import type { PreferencesAccess } from 'compass-preferences-model'; import { createSandboxFromDefaultPreferences } from 'compass-preferences-model'; import { createNoopLogger } from '@mongodb-js/compass-logging/provider'; @@ -20,13 +24,11 @@ const PREFERENCES_USER = { }; const BASE_URL = 'http://example.com'; +const urlConfig = aiURLConfig['admin-api']; class MockAtlasService { getCurrentUser = () => Promise.resolve(ATLAS_USER); - adminApiEndpoint = (url: string, requestId?: string) => - `${[BASE_URL, url].join('/')}${ - requestId ? `?request_id=${requestId}` : '' - }`; + adminApiEndpoint = (url: string) => `${[BASE_URL, url].join('/')}`; authenticatedFetch = (url: string, init: RequestInit) => { return fetch(url, init); }; @@ -54,11 +56,20 @@ describe('AtlasAiService', function () { preferences = await createSandboxFromDefaultPreferences(); preferences['getPreferencesUser'] = () => PREFERENCES_USER; - atlasAiService = new AtlasAiService( - new MockAtlasService() as any, + const mockAtlasService = new MockAtlasService(); + atlasAiService = new AtlasAiService({ + atlasService: mockAtlasService as any, + getUrlForEndpoint: (urlId: AIEndpoint) => { + const urlPath: string = + urlId === 'user-access' + ? urlConfig[urlId](PREFERENCES_USER.id) + : urlConfig[urlId]; + + return mockAtlasService.adminApiEndpoint(urlPath); + }, preferences, - createNoopLogger() - ); + logger: createNoopLogger(), + }); }); afterEach(function () { diff --git a/packages/compass-generative-ai/src/atlas-ai-service.ts b/packages/compass-generative-ai/src/atlas-ai-service.ts index 5cc49001be7..e6715cddec3 100644 --- a/packages/compass-generative-ai/src/atlas-ai-service.ts +++ b/packages/compass-generative-ai/src/atlas-ai-service.ts @@ -23,9 +23,6 @@ type GenerativeAiInput = { // want to ensure we're not uploading massive documents (some folks have documents > 1mb). const AI_MAX_REQUEST_SIZE = 5120000; const AI_MIN_SAMPLE_DOCUMENTS = 1; -const USER_AI_URI = (userId: string) => `unauth/ai/api/v1/hello/${userId}`; -const AGGREGATION_URI = 'ai/api/v1/mql-aggregation'; -const QUERY_URI = 'ai/api/v1/mql-query'; type AIAggregation = { content: { @@ -192,14 +189,49 @@ export function validateAIAggregationResponse( } } +export const aiURLConfig = { + // There are two different sets of endpoints we use for our requests. + // Down the line we'd like to only use the admin api, however, + // we cannot currently call that from the Atlas UI. Pending CLOUDP-251201 + 'admin-api': { + 'user-access': (userId: string) => `unauth/ai/api/v1/hello/${userId}`, + aggregation: 'ai/api/v1/mql-aggregation', + query: 'ai/api/v1/mql-query', + }, + cloud: { + 'user-access': (groupId: string, userId: string) => + `/ai/v1/groups/${groupId}/hello/${userId}`, + aggregation: (groupId: string) => + `/ai/v1/groups/${groupId}/mql-aggregation`, + query: (groupId: string) => `/ai/v1/groups/${groupId}/mql-query`, + }, +} as const; +export type AIEndpoint = 'user-access' | 'query' | 'aggregation'; + export class AtlasAiService { private initPromise: Promise | null = null; - constructor( - private atlasService: AtlasService, - private preferences: PreferencesAccess, - private logger: Logger - ) { + private atlasService: AtlasService; + private getUrlForEndpoint: (urlId: AIEndpoint) => string; + private preferences: PreferencesAccess; + private logger: Logger; + + constructor({ + atlasService, + getUrlForEndpoint, + preferences, + logger, + }: { + atlasService: AtlasService; + getUrlForEndpoint: (urlId: AIEndpoint) => string; + preferences: PreferencesAccess; + logger: Logger; + }) { + this.atlasService = atlasService; + this.getUrlForEndpoint = getUrlForEndpoint; + this.preferences = preferences; + this.logger = logger; + this.initPromise = this.setupAIAccess(); } @@ -215,8 +247,7 @@ export class AtlasAiService { } private async getAIFeatureEnablement(): Promise { - const userId = this.preferences.getPreferencesUser().id; - const url = this.atlasService.adminApiEndpoint(USER_AI_URI(userId)); + const url = this.getUrlForEndpoint('user-access'); const res = await this.atlasService.fetch(url, { headers: { Accept: 'application/json', @@ -261,7 +292,7 @@ export class AtlasAiService { } private getQueryOrAggregationFromUserInput = async ( - uri: string, + urlId: 'query' | 'aggregation', input: GenerativeAiInput, validationFn: (res: any) => asserts res is T ): Promise => { @@ -271,14 +302,15 @@ export class AtlasAiService { const { signal, requestId, ...rest } = input; const msgBody = buildQueryOrAggregationMessageBody(rest); - const url = this.atlasService.adminApiEndpoint(uri, requestId); + const url = new URL(this.getUrlForEndpoint(urlId)); + url.searchParams.append('request_id', encodeURIComponent(requestId)); this.logger.log.info( this.logger.mongoLogId(1_001_000_308), 'AtlasAIService', 'Running AI query generation request', { - url, + url: url.toString(), userInput: input.userInput, collectionName: input.collectionName, databaseName: input.databaseName, @@ -287,7 +319,7 @@ export class AtlasAiService { } ); - const res = await this.atlasService.authenticatedFetch(url, { + const res = await this.atlasService.authenticatedFetch(url.toString(), { signal, method: 'POST', body: msgBody, @@ -328,7 +360,7 @@ export class AtlasAiService { async getAggregationFromUserInput(input: GenerativeAiInput) { return this.getQueryOrAggregationFromUserInput( - AGGREGATION_URI, + 'aggregation', input, validateAIAggregationResponse ); @@ -336,7 +368,7 @@ export class AtlasAiService { async getQueryFromUserInput(input: GenerativeAiInput) { return this.getQueryOrAggregationFromUserInput( - QUERY_URI, + 'query', input, validateAIQueryResponse ); diff --git a/packages/compass-generative-ai/src/provider.tsx b/packages/compass-generative-ai/src/provider.tsx index 6df896c8439..8af0709c062 100644 --- a/packages/compass-generative-ai/src/provider.tsx +++ b/packages/compass-generative-ai/src/provider.tsx @@ -1,5 +1,5 @@ import React, { createContext, useContext, useMemo } from 'react'; -import { AtlasAiService } from './atlas-ai-service'; +import { type AIEndpoint, AtlasAiService } from './atlas-ai-service'; import { preferencesLocator } from 'compass-preferences-model/provider'; import { useLogger } from '@mongodb-js/compass-logging/provider'; import { atlasServiceLocator } from '@mongodb-js/atlas-service/provider'; @@ -10,23 +10,48 @@ import { const AtlasAiServiceContext = createContext(null); -export const AtlasAiServiceProvider: React.FC = createServiceProvider( - function AtlasAiServiceProvider({ children }) { - const logger = useLogger('ATLAS-AI-SERVICE'); - const preferences = preferencesLocator(); - const atlasService = atlasServiceLocator(); - - const aiService = useMemo(() => { - return new AtlasAiService(atlasService, preferences, logger); - }, [preferences, logger, atlasService]); - - return ( - - {children} - - ); - } -); +export type URLConfig = { + 'user-access': (userId: string) => string; + query: string; + aggregation: string; +}; + +export const AtlasAiServiceProvider: React.FC<{ + apiURLPreset: 'admin-api' | 'cloud'; + urlConfig: URLConfig; +}> = createServiceProvider(function AtlasAiServiceProvider({ + apiURLPreset, + children, + urlConfig, +}) { + const logger = useLogger('ATLAS-AI-SERVICE'); + const preferences = preferencesLocator(); + const atlasService = atlasServiceLocator(); + + const aiService = useMemo(() => { + const userId = preferences.getPreferencesUser().id; + + return new AtlasAiService({ + getUrlForEndpoint: (urlId: AIEndpoint) => { + const urlPath: string = + urlId === 'user-access' ? urlConfig[urlId](userId) : urlConfig[urlId]; + + return apiURLPreset === 'admin-api' + ? atlasService.adminApiEndpoint(urlPath) + : atlasService.cloudEndpoint(urlPath); + }, + atlasService, + preferences, + logger, + }); + }, [apiURLPreset, preferences, logger, atlasService, urlConfig]); + + return ( + + {children} + + ); +}); function useAtlasAiServiceContext(): AtlasAiService { const service = useContext(AtlasAiServiceContext); @@ -40,4 +65,5 @@ export const atlasAiServiceLocator = createServiceLocator( useAtlasAiServiceContext, 'atlasAiServiceLocator' ); -export { AtlasAiService } from './atlas-ai-service'; +export { AtlasAiService, aiURLConfig } from './atlas-ai-service'; +export type { AIEndpoint } from './atlas-ai-service'; diff --git a/packages/compass-web/src/entrypoint.tsx b/packages/compass-web/src/entrypoint.tsx index e2103fe782b..aa511809b17 100644 --- a/packages/compass-web/src/entrypoint.tsx +++ b/packages/compass-web/src/entrypoint.tsx @@ -44,7 +44,10 @@ import { import type { AllPreferences } from 'compass-preferences-model/provider'; import FieldStorePlugin from '@mongodb-js/compass-field-store'; import { AtlasServiceProvider } from '@mongodb-js/atlas-service/provider'; -import { AtlasAiServiceProvider } from '@mongodb-js/compass-generative-ai/provider'; +import { + AtlasAiServiceProvider, + aiURLConfig, +} from '@mongodb-js/compass-generative-ai/provider'; import { LoggerProvider } from '@mongodb-js/compass-logging/provider'; import { TelemetryProvider } from '@mongodb-js/compass-telemetry/provider'; import CompassConnections from '@mongodb-js/compass-connections'; @@ -59,11 +62,23 @@ import { useCompassWebLoggerAndTelemetry } from './logger-and-telemetry'; import { type TelemetryServiceOptions } from '@mongodb-js/compass-telemetry'; import { WorkspaceTab as WelcomeWorkspaceTab } from '@mongodb-js/compass-welcome'; -const WithAtlasProviders: React.FC = ({ children }) => { +const WithAtlasProviders: React.FC<{ + projectId: string; +}> = ({ children, projectId }) => { return ( - {children} + + aiURLConfig.cloud['user-access'](userId, projectId), + query: aiURLConfig.cloud.query(projectId), + aggregation: aiURLConfig.cloud.aggregation(projectId), + }} + > + {children} + ); @@ -315,7 +330,7 @@ const CompassWeb = ({ - + { }, }} > - {children} + + aiURLConfig['admin-api']['user-access'](userId), + query: aiURLConfig['admin-api'].query, + aggregation: aiURLConfig['admin-api'].aggregation, + }} + > + {children} + );