Skip to content

Commit

Permalink
Add ability to set modelParams on getGenerativeModelFromCachedContent…
Browse files Browse the repository at this point in the history
…() (#254)
  • Loading branch information
hsubox76 authored Sep 11, 2024
1 parent ce49f34 commit fc008a1
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 5 deletions.
5 changes: 5 additions & 0 deletions .changeset/tame-lizards-kiss.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@google/generative-ai": minor
---

Add ability to set modelParams (generationConfig, safetySettings) on getGenerativeModelFromCachedContent().
2 changes: 1 addition & 1 deletion common/api-review/generative-ai.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ export class GoogleGenerativeAI {
// (undocumented)
apiKey: string;
getGenerativeModel(modelParams: ModelParams, requestOptions?: RequestOptions): GenerativeModel;
getGenerativeModelFromCachedContent(cachedContent: CachedContent, requestOptions?: RequestOptions): GenerativeModel;
getGenerativeModelFromCachedContent(cachedContent: CachedContent, modelParams?: Partial<ModelParams>, requestOptions?: RequestOptions): GenerativeModel;
}

// @public
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ Creates a [GenerativeModel](./generative-ai.generativemodel.md) instance from pr
**Signature:**

```typescript
getGenerativeModelFromCachedContent(cachedContent: CachedContent, requestOptions?: RequestOptions): GenerativeModel;
getGenerativeModelFromCachedContent(cachedContent: CachedContent, modelParams?: Partial<ModelParams>, requestOptions?: RequestOptions): GenerativeModel;
```

## Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| cachedContent | [CachedContent](./generative-ai.cachedcontent.md) | |
| modelParams | Partial&lt;[ModelParams](./generative-ai.modelparams.md)<!-- -->&gt; | _(Optional)_ |
| requestOptions | [RequestOptions](./generative-ai.requestoptions.md) | _(Optional)_ |

**Returns:**
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/main/generative-ai.googlegenerativeai.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ export declare class GoogleGenerativeAI
| Method | Modifiers | Description |
| --- | --- | --- |
| [getGenerativeModel(modelParams, requestOptions)](./generative-ai.googlegenerativeai.getgenerativemodel.md) | | Gets a [GenerativeModel](./generative-ai.generativemodel.md) instance for the provided model name. |
| [getGenerativeModelFromCachedContent(cachedContent, requestOptions)](./generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md) | | Creates a [GenerativeModel](./generative-ai.generativemodel.md) instance from provided content cache. |
| [getGenerativeModelFromCachedContent(cachedContent, modelParams, requestOptions)](./generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md) | | Creates a [GenerativeModel](./generative-ai.generativemodel.md) instance from provided content cache. |

91 changes: 89 additions & 2 deletions src/gen-ai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,104 @@ import { ModelParams } from "../types";
import { GenerativeModel, GoogleGenerativeAI } from "./gen-ai";
import { expect } from "chai";

const fakeContents = [{ role: "user", parts: [{ text: "hello" }] }];

const fakeCachedContent = {
model: "my-model",
name: "mycachename",
contents: fakeContents,
};

describe("GoogleGenerativeAI", () => {
it("genGenerativeInstance throws if no model is provided", () => {
it("getGenerativeModel throws if no model is provided", () => {
const genAI = new GoogleGenerativeAI("apikey");
expect(() => genAI.getGenerativeModel({} as ModelParams)).to.throw(
"Must provide a model name",
);
});
it("genGenerativeInstance gets a GenerativeModel", () => {
it("getGenerativeModel gets a GenerativeModel", () => {
const genAI = new GoogleGenerativeAI("apikey");
const genModel = genAI.getGenerativeModel({ model: "my-model" });
expect(genModel).to.be.an.instanceOf(GenerativeModel);
expect(genModel.model).to.equal("models/my-model");
});
it("getGenerativeModelFromCachedContent gets a GenerativeModel", () => {
const genAI = new GoogleGenerativeAI("apikey");
const genModel =
genAI.getGenerativeModelFromCachedContent(fakeCachedContent);
expect(genModel).to.be.an.instanceOf(GenerativeModel);
expect(genModel.model).to.equal("models/my-model");
expect(genModel.cachedContent).to.eql(fakeCachedContent);
});
it("getGenerativeModelFromCachedContent gets a GenerativeModel merged with modelParams", () => {
const genAI = new GoogleGenerativeAI("apikey");
const genModel = genAI.getGenerativeModelFromCachedContent(
fakeCachedContent,
{ generationConfig: { temperature: 0 } },
);
expect(genModel).to.be.an.instanceOf(GenerativeModel);
expect(genModel.model).to.equal("models/my-model");
expect(genModel.generationConfig.temperature).to.equal(0);
expect(genModel.cachedContent).to.eql(fakeCachedContent);
});
it("getGenerativeModelFromCachedContent gets a GenerativeModel merged with modelParams with overlapping keys", () => {
const genAI = new GoogleGenerativeAI("apikey");
const genModel = genAI.getGenerativeModelFromCachedContent(
fakeCachedContent,
{ model: "my-model", generationConfig: { temperature: 0 } },
);
expect(genModel).to.be.an.instanceOf(GenerativeModel);
expect(genModel.model).to.equal("models/my-model");
expect(genModel.generationConfig.temperature).to.equal(0);
expect(genModel.cachedContent).to.eql(fakeCachedContent);
});
it("getGenerativeModelFromCachedContent throws if no name", () => {
const genAI = new GoogleGenerativeAI("apikey");
expect(() =>
genAI.getGenerativeModelFromCachedContent({
model: "my-model",
contents: fakeContents,
}),
).to.throw("Cached content must contain a `name` field.");
});
it("getGenerativeModelFromCachedContent throws if no model", () => {
const genAI = new GoogleGenerativeAI("apikey");
expect(() =>
genAI.getGenerativeModelFromCachedContent({
name: "cachename",
contents: fakeContents,
}),
).to.throw("Cached content must contain a `model` field.");
});
it("getGenerativeModelFromCachedContent throws if mismatched model", () => {
const genAI = new GoogleGenerativeAI("apikey");
expect(() =>
genAI.getGenerativeModelFromCachedContent(
{
name: "cachename",
model: "my-model",
contents: fakeContents,
},
{ model: "your-model" },
),
).to.throw(
`Different value for "model" specified in modelParams (your-model) and cachedContent (my-model)`,
);
});
it("getGenerativeModelFromCachedContent throws if mismatched systemInstruction", () => {
const genAI = new GoogleGenerativeAI("apikey");
expect(() =>
genAI.getGenerativeModelFromCachedContent(
{
name: "cachename",
model: "my-model",
contents: fakeContents,
systemInstruction: "hi",
},
{ model: "models/my-model", systemInstruction: "yo" },
),
).to.throw(
`Different value for "systemInstruction" specified in modelParams (yo) and cachedContent (hi)`,
);
});
});
34 changes: 34 additions & 0 deletions src/gen-ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export class GoogleGenerativeAI {
*/
getGenerativeModelFromCachedContent(
cachedContent: CachedContent,
modelParams?: Partial<ModelParams>,
requestOptions?: RequestOptions,
): GenerativeModel {
if (!cachedContent.name) {
Expand All @@ -65,7 +66,40 @@ export class GoogleGenerativeAI {
"Cached content must contain a `model` field.",
);
}

/**
* Not checking tools and toolConfig for now as it would require a deep
* equality comparison and isn't likely to be a common case.
*/
const disallowedDuplicates: Array<keyof ModelParams & keyof CachedContent> =
["model", "systemInstruction"];

for (const key of disallowedDuplicates) {
if (
modelParams?.[key] &&
cachedContent[key] &&
modelParams?.[key] !== cachedContent[key]
) {
if (key === "model") {
const modelParamsComp = modelParams.model.startsWith("models/")
? modelParams.model.replace("models/", "")
: modelParams.model;
const cachedContentComp = cachedContent.model.startsWith("models/")
? cachedContent.model.replace("models/", "")
: cachedContent.model;
if (modelParamsComp === cachedContentComp) {
continue;
}
}
throw new GoogleGenerativeAIRequestInputError(
`Different value for "${key}" specified in modelParams` +
` (${modelParams[key]}) and cachedContent (${cachedContent[key]})`,
);
}
}

const modelParamsFromCache: ModelParams = {
...modelParams,
model: cachedContent.model,
tools: cachedContent.tools,
toolConfig: cachedContent.toolConfig,
Expand Down

0 comments on commit fc008a1

Please sign in to comment.