Support system prompt natively for Gemini 1.5 #22

Merged · 2 commits · May 6, 2024
1 change: 1 addition & 0 deletions docs/plugin-authoring.md
@@ -128,6 +128,7 @@ export const myPlugin = genkitPlugin('my-plugin', async (options: {apiKey?: stri
multiturn: true, // true if your model supports conversations
media: true, // true if your model supports multimodal input
tools: true, // true if your model supports tool/function calling
systemRole: true, // true if your model supports the system role
output: ['text', 'media', 'json'], // types of output your model supports
},
// Zod schema for your model's custom configuration
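For context, the snippet above sits inside the model-definition example in the plugin authoring guide. Below is a minimal sketch of how a plugin author might advertise the new capability; the model name, import path, and config schema are placeholders for illustration, not anything from this PR.

```ts
import { defineModel } from '@genkit-ai/ai/model';
import * as z from 'zod';

// Hypothetical model definition advertising native system-role support.
export const myModel = defineModel(
  {
    name: 'my-plugin/my-model', // placeholder name
    supports: {
      multiturn: true,
      media: false,
      tools: false,
      systemRole: true, // the flag introduced by this PR
      output: ['text'],
    },
    configSchema: z.object({ temperature: z.number().optional() }),
  },
  async (request) => {
    // Call the underlying model API here and map its response into
    // Genkit's candidate format.
    return { candidates: [] };
  }
);
```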
2 changes: 2 additions & 0 deletions js/ai/src/model.ts
@@ -124,6 +124,8 @@ export const ModelInfoSchema = z.object({
media: z.boolean().optional(),
/** Model can perform tool calls. */
tools: z.boolean().optional(),
/** Model can accept messages with role "system". */
systemRole: z.boolean().optional(),
/** Model can output this type of data. */
output: z.array(OutputFormatSchema).optional(),
})
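The new field is optional, so existing ModelInfo values keep validating. A small sketch of how calling code might read it, mirroring the `model?.info?.supports?.media` check used later in this PR (the helper itself is illustrative, not part of the change):

```ts
import { ModelReference } from '@genkit-ai/ai/model';
import * as z from 'zod';

// Treat an absent flag as "not supported" so callers fall back to
// simulating the system prompt via middleware.
function supportsSystemRole(model: ModelReference<z.ZodTypeAny>): boolean {
  return model.info?.supports?.systemRole === true;
}
```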
94 changes: 74 additions & 20 deletions js/plugins/googleai/src/gemini.ts
@@ -22,6 +22,7 @@ import {
MediaPart,
MessageData,
ModelAction,
ModelMiddleware,
modelRef,
ModelReference,
Part,
@@ -72,6 +73,7 @@ export const geminiPro = modelRef({
multiturn: true,
media: false,
tools: false,
systemRole: true,
},
},
configSchema: GeminiConfigSchema,
@@ -87,6 +89,7 @@ export const geminiProVision = modelRef({
multiturn: true,
media: true,
tools: false,
systemRole: false,
},
},
configSchema: GeminiConfigSchema,
@@ -100,6 +103,7 @@ export const gemini15Pro = modelRef({
multiturn: true,
media: true,
tools: true,
systemRole: true,
},
},
configSchema: GeminiConfigSchema,
@@ -114,39 +118,52 @@ export const geminiUltra = modelRef({
multiturn: true,
media: false,
tools: false,
systemRole: true,
},
},
configSchema: GeminiConfigSchema,
});

export const V1_SUPPORTED_MODELS: Record<
export const SUPPORTED_V1_MODELS: Record<
string,
ModelReference<z.ZodTypeAny>
> = {
'gemini-pro': geminiPro,
'gemini-pro-vision': geminiProVision,
// 'gemini-ultra': geminiUltra,
};

export const V1_BETA_SUPPORTED_MODELS: Record<
export const SUPPORTED_V15_MODELS: Record<
string,
ModelReference<z.ZodTypeAny>
> = {
'gemini-1.5-pro-latest': gemini15Pro,
};

const SUPPORTED_MODELS = {
...V1_SUPPORTED_MODELS,
...V1_BETA_SUPPORTED_MODELS,
...SUPPORTED_V1_MODELS,
...SUPPORTED_V15_MODELS,
};

function toGeminiRole(role: MessageData['role']): string {
function toGeminiRole(
role: MessageData['role'],
model?: ModelReference<z.ZodTypeAny>
): string {
switch (role) {
case 'user':
return 'user';
case 'model':
return 'model';
case 'system':
throw new Error('system role is not supported');
if (model && SUPPORTED_V15_MODELS[model.name]) {
// We should have already pulled out the supported system messages,
// anything remaining is unsupported; throw an error.
throw new Error(
'system role is only supported for a single message in the first position'
);
} else {
throw new Error('system role is not supported');
}
case 'tool':
return 'function';
default:
@@ -195,9 +212,19 @@ function fromGeminiPart(part: GeminiPart): Part {
throw new Error('Only support text for the moment.');
}

function toGeminiMessage(message: MessageData): GeminiMessage {
function toGeminiMessage(
message: MessageData,
model?: ModelReference<z.ZodTypeAny>
): GeminiMessage {
return {
role: toGeminiRole(message.role, model),
parts: message.content.map(toGeminiPart),
};
}

function toGeminiSystemInstruction(message: MessageData): GeminiMessage {
return {
role: toGeminiRole(message.role),
role: 'user',
parts: message.content.map(toGeminiPart),
};
}
@@ -245,25 +272,34 @@ export function googleAIModel(
baseUrl?: string
): ModelAction {
const modelName = `googleai/${name}`;
if (!apiKey)

if (!apiKey) {
apiKey = process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY;
if (!apiKey)
}
if (!apiKey) {
throw new Error(
'please pass in the API key or set the GOOGLE_GENAI_API_KEY or GOOGLE_API_KEY environment variable'
);
}

const model: ModelReference<z.ZodTypeAny> = SUPPORTED_MODELS[name];
if (!model) throw new Error(`Unsupported model: ${name}`);

const middleware: ModelMiddleware[] = [];
if (SUPPORTED_V1_MODELS[name]) {
middleware.push(simulateSystemPrompt());
}
if (model?.info?.supports?.media) {
// the gemini api doesn't support downloading media from http(s)
middleware.push(downloadRequestMedia({ maxBytes: 1024 * 1024 * 10 }));
}

return defineModel(
{
name: modelName,
...model.info,
configSchema: model.configSchema,
use: [
// simulate a system prompt since no native one is supported
simulateSystemPrompt(),
// since gemini api doesn't support downloading media from http(s)
downloadRequestMedia({ maxBytes: 1024 * 1024 * 10 }),
],
use: middleware,
},
async (request, streamingCallback) => {
const options: RequestOptions = { apiClient: GENKIT_CLIENT_HEADER };
@@ -279,10 +315,27 @@
},
options
);
const messages = request.messages.map(toGeminiMessage);

// make a copy so that modifying the request will not produce side-effects
const messages = [...request.messages];
if (messages.length === 0) throw new Error('No messages provided.');

// Gemini does not support messages with role system and instead expects
// systemInstructions to be provided as a separate input. The first
// message detected with role=system will be used for systemInstructions.
// Any additional system messages may be considered to be "exceptional".
let systemInstruction: GeminiMessage | undefined = undefined;
const systemMessage = messages.find((m) => m.role === 'system');
if (systemMessage) {
messages.splice(messages.indexOf(systemMessage), 1);
systemInstruction = toGeminiSystemInstruction(systemMessage);
}

const chatRequest = {
history: messages.slice(0, messages.length - 1),
systemInstruction,
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
generationConfig: {
candidateCount: request.candidates || undefined,
temperature: request.config?.temperature,
@@ -293,10 +346,11 @@
},
safetySettings: request.config?.safetySettings,
} as StartChatParams;
const msg = toGeminiMessage(messages[messages.length - 1], model);
if (streamingCallback) {
const result = await client
.startChat(chatRequest)
.sendMessageStream(messages[messages.length - 1].parts);
.sendMessageStream(msg.parts);
for await (const item of result.stream) {
(item as GenerateContentResponse).candidates?.forEach((candidate) => {
const c = fromGeminiCandidate(candidate);
@@ -317,7 +371,7 @@ export function googleAIModel(
} else {
const result = await client
.startChat(chatRequest)
.sendMessage(messages[messages.length - 1].parts);
.sendMessage(msg.parts);
if (!result.response.candidates?.length)
throw new Error('No valid candidates returned.');
const responseCandidates =
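In short, Gemini 1.5 models now receive the leading system message as `systemInstruction` on the chat request, while v1 models (gemini-pro, gemini-pro-vision) keep the `simulateSystemPrompt()` middleware. The sketch below restates that extraction step outside the diff; it is a condensed paraphrase for readability, not additional code from the PR, and it reuses `MessageData`, `GeminiMessage`, and `toGeminiSystemInstruction` from gemini.ts.

```ts
// Paraphrase of the new request handling: pull out the first system
// message, convert it with toGeminiSystemInstruction, and split the rest
// into chat history plus the final message to send.
function splitSystemInstruction(messages: MessageData[]): {
  systemInstruction?: GeminiMessage;
  history: MessageData[];
  lastMessage: MessageData;
} {
  const copy = [...messages]; // avoid mutating the caller's request
  const systemMessage = copy.find((m) => m.role === 'system');
  let systemInstruction: GeminiMessage | undefined;
  if (systemMessage) {
    copy.splice(copy.indexOf(systemMessage), 1);
    systemInstruction = toGeminiSystemInstruction(systemMessage);
  }
  return {
    systemInstruction,
    history: copy.slice(0, -1),
    lastMessage: copy[copy.length - 1],
  };
}
```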
8 changes: 4 additions & 4 deletions js/plugins/googleai/src/index.ts
@@ -24,8 +24,8 @@ import {
geminiPro,
geminiProVision,
googleAIModel,
V1_BETA_SUPPORTED_MODELS,
V1_SUPPORTED_MODELS,
SUPPORTED_V15_MODELS,
SUPPORTED_V1_MODELS,
} from './gemini.js';
export { gemini15Pro, geminiPro, geminiProVision };

@@ -52,14 +52,14 @@ export const googleAI: Plugin<[PluginOptions] | []> = genkitPlugin(
if (apiVersions.includes('v1beta')) {
(embedders = []),
(models = [
...Object.keys(V1_BETA_SUPPORTED_MODELS).map((name) =>
...Object.keys(SUPPORTED_V15_MODELS).map((name) =>
googleAIModel(name, options?.apiKey, 'v1beta', options?.baseUrl)
),
]);
}
if (apiVersions.includes('v1')) {
models = [
...Object.keys(V1_SUPPORTED_MODELS).map((name) =>
...Object.keys(SUPPORTED_V1_MODELS).map((name) =>
googleAIModel(name, options?.apiKey, undefined, options?.baseUrl)
),
];
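The renamed maps make the registration clearer: SUPPORTED_V1_MODELS are registered under 'v1' and SUPPORTED_V15_MODELS under 'v1beta'. A rough configuration sketch follows; `configureGenkit` is standard Genkit setup, but the exact option name for selecting API versions is an assumption inferred from the `apiVersions` checks in the diff and should be confirmed against PluginOptions.

```ts
import { configureGenkit } from '@genkit-ai/core';
import { googleAI } from '@genkit-ai/googleai';

configureGenkit({
  plugins: [
    // Registering both API surfaces: v1 models keep the simulated system
    // prompt middleware, while v1beta (Gemini 1.5) models pass the leading
    // system message natively as systemInstruction.
    googleAI({ apiVersions: ['v1', 'v1beta'] }), // option name assumed
  ],
});
```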
3 changes: 3 additions & 0 deletions js/plugins/vertexai/src/anthropic.ts
@@ -44,6 +44,7 @@ export const claude3Sonnet = modelRef({
multiturn: true,
media: true,
tools: false,
systemRole: true,
output: ['text'],
},
},
@@ -59,6 +60,7 @@ export const claude3Haiku = modelRef({
multiturn: true,
media: true,
tools: false,
systemRole: true,
output: ['text'],
},
},
@@ -74,6 +76,7 @@ export const claude3Opus = modelRef({
multiturn: true,
media: true,
tools: false,
systemRole: true,
output: ['text'],
},
},