Skip to content

Commit

Permalink
chore(generative-ai): add url endpoints for compass web usage
Browse files Browse the repository at this point in the history
  • Loading branch information
Anemy committed Oct 29, 2024
1 parent 39d7222 commit 39cf6e7
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 56 deletions.
8 changes: 2 additions & 6 deletions packages/atlas-service/src/atlas-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,8 @@ export class AtlasService {
) {
this.config = getAtlasConfig(preferences);
}
adminApiEndpoint(path?: string, requestId?: string): string {
const uri = `${this.config.atlasApiBaseUrl}${normalizePath(path)}`;
const query = requestId
? `?request_id=${encodeURIComponent(requestId)}`
: '';
return `${uri}${query}`;
adminApiEndpoint(path?: string): string {
return `${this.config.atlasApiBaseUrl}${normalizePath(path)}`;
}
cloudEndpoint(path?: string): string {
return `${this.config.cloudBaseUrl}${normalizePath(path)}`;
Expand Down
29 changes: 20 additions & 9 deletions packages/compass-generative-ai/src/atlas-ai-service.spec.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import Sinon from 'sinon';
import { expect } from 'chai';
import { AtlasAiService } from './atlas-ai-service';
import {
type AIEndpoint,
aiURLConfig,
AtlasAiService,
} from './atlas-ai-service';
import type { PreferencesAccess } from 'compass-preferences-model';
import { createSandboxFromDefaultPreferences } from 'compass-preferences-model';
import { createNoopLogger } from '@mongodb-js/compass-logging/provider';
Expand All @@ -20,13 +24,11 @@ const PREFERENCES_USER = {
};

const BASE_URL = 'http://example.com';
const urlConfig = aiURLConfig['admin-api'];

class MockAtlasService {
getCurrentUser = () => Promise.resolve(ATLAS_USER);
adminApiEndpoint = (url: string, requestId?: string) =>
`${[BASE_URL, url].join('/')}${
requestId ? `?request_id=${requestId}` : ''
}`;
adminApiEndpoint = (url: string) => `${[BASE_URL, url].join('/')}`;
authenticatedFetch = (url: string, init: RequestInit) => {
return fetch(url, init);
};
Expand Down Expand Up @@ -54,11 +56,20 @@ describe('AtlasAiService', function () {
preferences = await createSandboxFromDefaultPreferences();
preferences['getPreferencesUser'] = () => PREFERENCES_USER;

atlasAiService = new AtlasAiService(
new MockAtlasService() as any,
const mockAtlasService = new MockAtlasService();
atlasAiService = new AtlasAiService({
atlasService: mockAtlasService as any,
getUrlForEndpoint: (urlId: AIEndpoint) => {
const urlPath: string =
urlId === 'user-access'
? urlConfig[urlId](PREFERENCES_USER.id)
: urlConfig[urlId];

return mockAtlasService.adminApiEndpoint(urlPath);
},
preferences,
createNoopLogger()
);
logger: createNoopLogger(),
});
});

afterEach(function () {
Expand Down
64 changes: 48 additions & 16 deletions packages/compass-generative-ai/src/atlas-ai-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ type GenerativeAiInput = {
// want to ensure we're not uploading massive documents (some folks have documents > 1mb).
const AI_MAX_REQUEST_SIZE = 5120000;
const AI_MIN_SAMPLE_DOCUMENTS = 1;
const USER_AI_URI = (userId: string) => `unauth/ai/api/v1/hello/${userId}`;
const AGGREGATION_URI = 'ai/api/v1/mql-aggregation';
const QUERY_URI = 'ai/api/v1/mql-query';

type AIAggregation = {
content: {
Expand Down Expand Up @@ -192,14 +189,49 @@ export function validateAIAggregationResponse(
}
}

export const aiURLConfig = {
// There are two different sets of endpoints we use for our requests.
// Down the line we'd like to only use the admin api, however,
// we cannot currently call that from the Atlas UI. Pending CLOUDP-251201
'admin-api': {
'user-access': (userId: string) => `unauth/ai/api/v1/hello/${userId}`,
aggregation: 'ai/api/v1/mql-aggregation',
query: 'ai/api/v1/mql-query',
},
cloud: {
'user-access': (groupId: string, userId: string) =>
`/ai/v1/groups/${groupId}/hello/${userId}`,
aggregation: (groupId: string) =>
`/ai/v1/groups/${groupId}/mql-aggregation`,
query: (groupId: string) => `/ai/v1/groups/${groupId}/mql-query`,
},
} as const;
export type AIEndpoint = 'user-access' | 'query' | 'aggregation';

export class AtlasAiService {
private initPromise: Promise<void> | null = null;

constructor(
private atlasService: AtlasService,
private preferences: PreferencesAccess,
private logger: Logger
) {
private atlasService: AtlasService;
private getUrlForEndpoint: (urlId: AIEndpoint) => string;
private preferences: PreferencesAccess;
private logger: Logger;

constructor({
atlasService,
getUrlForEndpoint,
preferences,
logger,
}: {
atlasService: AtlasService;
getUrlForEndpoint: (urlId: AIEndpoint) => string;
preferences: PreferencesAccess;
logger: Logger;
}) {
this.atlasService = atlasService;
this.getUrlForEndpoint = getUrlForEndpoint;
this.preferences = preferences;
this.logger = logger;

this.initPromise = this.setupAIAccess();
}

Expand All @@ -215,8 +247,7 @@ export class AtlasAiService {
}

private async getAIFeatureEnablement(): Promise<AIFeatureEnablement> {
const userId = this.preferences.getPreferencesUser().id;
const url = this.atlasService.adminApiEndpoint(USER_AI_URI(userId));
const url = this.getUrlForEndpoint('user-access');
const res = await this.atlasService.fetch(url, {
headers: {
Accept: 'application/json',
Expand Down Expand Up @@ -261,7 +292,7 @@ export class AtlasAiService {
}

private getQueryOrAggregationFromUserInput = async <T>(
uri: string,
urlId: 'query' | 'aggregation',
input: GenerativeAiInput,
validationFn: (res: any) => asserts res is T
): Promise<T> => {
Expand All @@ -271,14 +302,15 @@ export class AtlasAiService {
const { signal, requestId, ...rest } = input;
const msgBody = buildQueryOrAggregationMessageBody(rest);

const url = this.atlasService.adminApiEndpoint(uri, requestId);
const url = new URL(this.getUrlForEndpoint(urlId));
url.searchParams.append('request_id', encodeURIComponent(requestId));

this.logger.log.info(
this.logger.mongoLogId(1_001_000_308),
'AtlasAIService',
'Running AI query generation request',
{
url,
url: url.toString(),
userInput: input.userInput,
collectionName: input.collectionName,
databaseName: input.databaseName,
Expand All @@ -287,7 +319,7 @@ export class AtlasAiService {
}
);

const res = await this.atlasService.authenticatedFetch(url, {
const res = await this.atlasService.authenticatedFetch(url.toString(), {
signal,
method: 'POST',
body: msgBody,
Expand Down Expand Up @@ -328,15 +360,15 @@ export class AtlasAiService {

async getAggregationFromUserInput(input: GenerativeAiInput) {
return this.getQueryOrAggregationFromUserInput(
AGGREGATION_URI,
'aggregation',
input,
validateAIAggregationResponse
);
}

async getQueryFromUserInput(input: GenerativeAiInput) {
return this.getQueryOrAggregationFromUserInput(
QUERY_URI,
'query',
input,
validateAIQueryResponse
);
Expand Down
64 changes: 45 additions & 19 deletions packages/compass-generative-ai/src/provider.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React, { createContext, useContext, useMemo } from 'react';
import { AtlasAiService } from './atlas-ai-service';
import { type AIEndpoint, AtlasAiService } from './atlas-ai-service';
import { preferencesLocator } from 'compass-preferences-model/provider';
import { useLogger } from '@mongodb-js/compass-logging/provider';
import { atlasServiceLocator } from '@mongodb-js/atlas-service/provider';
Expand All @@ -10,23 +10,48 @@ import {

const AtlasAiServiceContext = createContext<AtlasAiService | null>(null);

export const AtlasAiServiceProvider: React.FC = createServiceProvider(
function AtlasAiServiceProvider({ children }) {
const logger = useLogger('ATLAS-AI-SERVICE');
const preferences = preferencesLocator();
const atlasService = atlasServiceLocator();

const aiService = useMemo(() => {
return new AtlasAiService(atlasService, preferences, logger);
}, [preferences, logger, atlasService]);

return (
<AtlasAiServiceContext.Provider value={aiService}>
{children}
</AtlasAiServiceContext.Provider>
);
}
);
export type URLConfig = {
'user-access': (userId: string) => string;
query: string;
aggregation: string;
};

export const AtlasAiServiceProvider: React.FC<{
apiURLPreset: 'admin-api' | 'cloud';
urlConfig: URLConfig;
}> = createServiceProvider(function AtlasAiServiceProvider({
apiURLPreset,
children,
urlConfig,
}) {
const logger = useLogger('ATLAS-AI-SERVICE');
const preferences = preferencesLocator();
const atlasService = atlasServiceLocator();

const aiService = useMemo(() => {
const userId = preferences.getPreferencesUser().id;

return new AtlasAiService({
getUrlForEndpoint: (urlId: AIEndpoint) => {
const urlPath: string =
urlId === 'user-access' ? urlConfig[urlId](userId) : urlConfig[urlId];

return apiURLPreset === 'admin-api'
? atlasService.adminApiEndpoint(urlPath)
: atlasService.cloudEndpoint(urlPath);
},
atlasService,
preferences,
logger,
});
}, [apiURLPreset, preferences, logger, atlasService, urlConfig]);

return (
<AtlasAiServiceContext.Provider value={aiService}>
{children}
</AtlasAiServiceContext.Provider>
);
});

function useAtlasAiServiceContext(): AtlasAiService {
const service = useContext(AtlasAiServiceContext);
Expand All @@ -40,4 +65,5 @@ export const atlasAiServiceLocator = createServiceLocator(
useAtlasAiServiceContext,
'atlasAiServiceLocator'
);
export { AtlasAiService } from './atlas-ai-service';
export { AtlasAiService, aiURLConfig } from './atlas-ai-service';
export type { AIEndpoint } from './atlas-ai-service';
23 changes: 19 additions & 4 deletions packages/compass-web/src/entrypoint.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ import {
import type { AllPreferences } from 'compass-preferences-model/provider';
import FieldStorePlugin from '@mongodb-js/compass-field-store';
import { AtlasServiceProvider } from '@mongodb-js/atlas-service/provider';
import { AtlasAiServiceProvider } from '@mongodb-js/compass-generative-ai/provider';
import {
AtlasAiServiceProvider,
aiURLConfig,
} from '@mongodb-js/compass-generative-ai/provider';
import { LoggerProvider } from '@mongodb-js/compass-logging/provider';
import { TelemetryProvider } from '@mongodb-js/compass-telemetry/provider';
import CompassConnections from '@mongodb-js/compass-connections';
Expand All @@ -59,11 +62,23 @@ import { useCompassWebLoggerAndTelemetry } from './logger-and-telemetry';
import { type TelemetryServiceOptions } from '@mongodb-js/compass-telemetry';
import { WorkspaceTab as WelcomeWorkspaceTab } from '@mongodb-js/compass-welcome';

const WithAtlasProviders: React.FC = ({ children }) => {
const WithAtlasProviders: React.FC<{
projectId: string;
}> = ({ children, projectId }) => {
return (
<AtlasCloudAuthServiceProvider>
<AtlasServiceProvider>
<AtlasAiServiceProvider>{children}</AtlasAiServiceProvider>
<AtlasAiServiceProvider
apiURLPreset="cloud"
urlConfig={{
'user-access': (userId: string) =>
aiURLConfig.cloud['user-access'](userId, projectId),
query: aiURLConfig.cloud.query(projectId),
aggregation: aiURLConfig.cloud.aggregation(projectId),
}}
>
{children}
</AtlasAiServiceProvider>
</AtlasServiceProvider>
</AtlasCloudAuthServiceProvider>
);
Expand Down Expand Up @@ -315,7 +330,7 @@ const CompassWeb = ({
<PreferencesProvider value={preferencesAccess.current}>
<LoggerProvider value={logger}>
<TelemetryProvider options={telemetryOptions.current}>
<WithAtlasProviders>
<WithAtlasProviders projectId={projectId}>
<AtlasCloudConnectionStorageProvider
orgId={orgId}
projectId={projectId}
Expand Down
17 changes: 15 additions & 2 deletions packages/compass/src/app/components/entrypoint.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ import {
AtlasAuthServiceProvider,
AtlasServiceProvider,
} from '@mongodb-js/atlas-service/provider';
import { AtlasAiServiceProvider } from '@mongodb-js/compass-generative-ai/provider';
import {
AtlasAiServiceProvider,
aiURLConfig,
} from '@mongodb-js/compass-generative-ai/provider';
import {
CompassFavoriteQueryStorage,
CompassPipelineStorage,
Expand Down Expand Up @@ -61,7 +64,17 @@ export const WithAtlasProviders: React.FC = ({ children }) => {
},
}}
>
<AtlasAiServiceProvider>{children}</AtlasAiServiceProvider>
<AtlasAiServiceProvider
apiURLPreset="admin-api"
urlConfig={{
'user-access': (userId: string) =>
aiURLConfig['admin-api']['user-access'](userId),
query: aiURLConfig['admin-api'].query,
aggregation: aiURLConfig['admin-api'].aggregation,
}}
>
{children}
</AtlasAiServiceProvider>
</AtlasServiceProvider>
</AtlasAuthServiceProvider>
);
Expand Down

0 comments on commit 39cf6e7

Please sign in to comment.