Task: Add usage to all models using tiktoken
zya committed Oct 11, 2023
1 parent c2117b9 commit ae80f83
Showing 7 changed files with 75 additions and 7 deletions.
28 changes: 28 additions & 0 deletions package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions package.json
@@ -30,6 +30,7 @@
"typescript": "^5.2.2"
},
"dependencies": {
"js-tiktoken": "^1.0.7",
"openai": "^4.11.1"
},
"engines": {
6 changes: 5 additions & 1 deletion src/handlers/anthropic.ts
@@ -1,4 +1,5 @@
import Anthropic from '@anthropic-ai/sdk';

const anthropic = new Anthropic();

import {
@@ -13,6 +14,7 @@ import {
} from '../types';
import { combinePrompts } from '../utils/combinePrompts';
import { getUnixTimestamp } from '../utils/getUnixTimestamp';
import { toUsage } from '../utils/toUsage';

function toAnthropicPrompt(messages: Message[]): string {
const textsCombined = combinePrompts(messages);
@@ -29,10 +31,12 @@ function toFinishReson(string: string): FinishReason {

function toResponse(
anthropicResponse: Anthropic.Completion,
prompt: string,
): ResultNotStreaming {
return {
model: anthropicResponse.model,
created: getUnixTimestamp(),
usage: toUsage(prompt, anthropicResponse.completion),
choices: [
{
message: {
@@ -103,5 +107,5 @@ export async function AnthropicHandler(

const completion = await anthropic.completions.create(anthropicParams);

return toResponse(completion);
return toResponse(completion, prompt);
}
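
With this change, the non-streaming Anthropic result carries an OpenAI-style usage block alongside model, created, and choices. A hedged sketch of the object toResponse now builds (fields inside choices are assumed to follow the OpenAI chat-completion shape; all values are illustrative, not taken from the source):

// Illustrative shape only; token counts depend on the actual prompt and completion.
const exampleResponse = {
  model: 'claude-2',                    // whatever anthropicResponse.model reports
  created: 1696979000,                  // unix timestamp from getUnixTimestamp()
  usage: { prompt_tokens: 9, completion_tokens: 42, total_tokens: 51 },
  choices: [
    {
      message: { role: 'assistant', content: '...completion text...' },
      finish_reason: 'stop',            // assumed; mapped via toFinishReson()
      index: 0,                         // assumed
    },
  ],
};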
6 changes: 5 additions & 1 deletion src/handlers/cohere.ts
@@ -11,15 +11,18 @@ import {
import { cohereResponse, generateResponse } from 'cohere-ai/dist/models';
import { combinePrompts } from '../utils/combinePrompts';
import { getUnixTimestamp } from '../utils/getUnixTimestamp';
import { toUsage } from '../utils/toUsage';

// eslint-disable-next-line @typescript-eslint/require-await
async function* toStream(
response: cohereResponse<generateResponse>,
model: string,
prompt: string,
): AsyncIterable<StreamingChunk> {
yield {
model: model,
created: getUnixTimestamp(),
usage: toUsage(prompt, response.body.generations[0].text),
choices: [
{
delta: {
@@ -61,12 +64,13 @@ export async function CohereHandler(
const response = await cohere.generate(config);

if (params.stream) {
return toStream(response, params.model);
return toStream(response, params.model, textsCombined);
}

return {
model: params.model,
created: getUnixTimestamp(),
usage: toUsage(textsCombined, response.body.generations[0].text),
choices: [
{
message: {
19 changes: 14 additions & 5 deletions src/handlers/ollama.ts
@@ -8,6 +8,7 @@ import {
} from '../types';
import { combinePrompts } from '../utils/combinePrompts';
import { getUnixTimestamp } from '../utils/getUnixTimestamp';
import { toUsage } from '../utils/toUsage';

interface OllamaResponseChunk {
model: string;
@@ -19,10 +20,12 @@ interface OllamaResponseChunk {
function toStreamingChunk(
ollamaResponse: OllamaResponseChunk,
model: string,
prompt: string,
): StreamingChunk {
return {
model: model,
created: getUnixTimestamp(),
usage: toUsage(prompt, ollamaResponse.response),
choices: [
{
delta: { content: ollamaResponse.response, role: 'assistant' },
@@ -33,10 +36,15 @@ function toStreamingChunk(
};
}

function toResponse(content: string, model: string): ResultNotStreaming {
function toResponse(
content: string,
model: string,
prompt: string,
): ResultNotStreaming {
return {
model: model,
created: getUnixTimestamp(),
usage: toUsage(prompt, content),
choices: [
{
message: { content, role: 'assistant' },
@@ -50,6 +58,7 @@ function toResponse(content: string, model: string): ResultNotStreaming {
async function* iterateResponse(
response: Response,
model: string,
prompt: string,
): AsyncIterable<StreamingChunk> {
const reader = response.body?.getReader();
let done = false;
@@ -60,7 +69,7 @@ async function* iterateResponse(
const decoded = new TextDecoder().decode(next.value);
done = next.done;
const ollamaResponse = JSON.parse(decoded) as OllamaResponseChunk;
yield toStreamingChunk(ollamaResponse, model);
yield toStreamingChunk(ollamaResponse, model, prompt);
} else {
done = true;
}
@@ -107,18 +116,18 @@ export async function OllamaHandler(
const res = await getOllamaResponse(model, prompt, baseUrl);

if (params.stream) {
return iterateResponse(res, model);
return iterateResponse(res, model, prompt);
}

const chunks: StreamingChunk[] = [];

for await (const chunk of iterateResponse(res, model)) {
for await (const chunk of iterateResponse(res, model, prompt)) {
chunks.push(chunk);
}

const message = chunks.reduce((acc: string, chunk: StreamingChunk) => {
return (acc += chunk.choices[0].delta.content);
}, '');

return toResponse(message, model);
return toResponse(message, model, prompt);
}
3 changes: 3 additions & 0 deletions src/utils/encoders.ts
@@ -0,0 +1,3 @@
import { getEncoding } from 'js-tiktoken';

export const encoderCl100K = getEncoding('cl100k_base');
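
For context, a minimal standalone sketch of how this shared encoder is used for counting tokens (the js-tiktoken calls are the same ones this commit relies on; the sample string is illustrative). Note that cl100k_base is an OpenAI encoding, so counts for Anthropic, Cohere, and Ollama completions are approximations rather than exact provider token counts:

import { getEncoding } from 'js-tiktoken';

// Same encoding as src/utils/encoders.ts; encode() returns an array of token ids.
const encoder = getEncoding('cl100k_base');
const tokenIds = encoder.encode('How many tokens is this sentence?');
console.log(tokenIds.length); // the token count is simply the array length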
19 changes: 19 additions & 0 deletions src/utils/toUsage.ts
@@ -0,0 +1,19 @@
import { ConsistentResponseUsage } from '../types';
import { encoderCl100K } from './encoders';

export function toUsage(
prompt: string,
completion: string | undefined,
): ConsistentResponseUsage | undefined {
if (!completion) {
return undefined;
}

const promptTokens = encoderCl100K.encode(prompt);
const completionTokens = encoderCl100K.encode(completion);
return {
prompt_tokens: promptTokens.length,
completion_tokens: completionTokens.length,
total_tokens: promptTokens.concat(completionTokens).length,
};
}
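
A brief usage sketch of the helper above (prompt and completion strings are illustrative): when the completion is missing it returns undefined, otherwise it reports OpenAI-style token counts.

import { toUsage } from './toUsage';

// Returns an object like { prompt_tokens, completion_tokens, total_tokens },
// with counts produced by the shared cl100k_base encoder.
const usage = toUsage('What is the capital of France?', 'The capital of France is Paris.');
console.log(usage);

// No completion yet (for example, nothing was generated) — yields undefined.
console.log(toUsage('Some prompt', undefined));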
