Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ET-1441: add support for configurable maximum input and output token limits #6

Merged
merged 9 commits into from
Nov 28, 2023
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ OCO_EMOJI=<boolean, add GitMoji>
OCO_LANGUAGE=<locale, scroll to the bottom to see options>
OCO_MESSAGE_TEMPLATE_PLACEHOLDER=<message template placeholder, default: '$msg'>
OCO_PROMPT_MODULE=<either conventional-commit or @commitlint, default: conventional-commit>
OCO_OPENAI_MAX_TOKENS=<max response tokens (default: 500)>
OCO_TOKENS_MAX_INPUT=<max input (prompt) tokens (default: 4096)>
OCO_TOKENS_MAX_OUTPUT=<max output (response) tokens (default: 500)>
OCO_ISSUE_ENABLED=<boolean, issue ID included within commit message - default to true if issue prefix has been set>
OCO_ISSUE_PREFIX=<optional prefix for issue ID, eg. 'ABC-'>
```
Expand Down Expand Up @@ -372,7 +373,8 @@ jobs:
OCO_OPENAI_API_KEY: ${{ secrets.OCO_OPENAI_API_KEY }}

# customization
OCO_OPENAI_MAX_TOKENS: 500
OCO_TOKENS_MAX_INPUT: 4096
OCO_TOKENS_MAX_OUTPUT: 500
mattsalt123 marked this conversation as resolved.
Show resolved Hide resolved
OCO_OPENAI_BASE_PATH: ''
OCO_DESCRIPTION: false
OCO_EMOJI: false
Expand Down
10 changes: 5 additions & 5 deletions src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import { intro, outro } from '@clack/prompts';
import {
CONFIG_MODES,
AI_TYPE,
DEFAULT_MODEL_TOKEN_LIMIT,
getConfig
} from './commands/config';
import { GenerateCommitMessageErrorEnum } from './generateCommitMessageFromGitDiff';
Expand All @@ -21,7 +20,8 @@ import { IDENTITY } from './prompts';

const config = getConfig();

const MAX_TOKENS = config?.OCO_OPENAI_MAX_TOKENS;
const MAX_TOKENS_OUTPUT = config?.OCO_TOKENS_MAX_OUTPUT || 500;
const MAX_TOKENS_INPUT = config?.OCO_TOKENS_MAX_INPUT || 4096;
const BASE_PATH = config?.OCO_OPENAI_BASE_PATH;
const API_KEY = config?.OCO_OPENAI_API_KEY;
const API_TYPE = config?.OCO_OPENAI_API_TYPE || AI_TYPE.OPENAI;
Expand Down Expand Up @@ -83,7 +83,7 @@ class OpenAi {
messages,
temperature: 0,
top_p: 0.1,
max_tokens: MAX_TOKENS || 500
max_tokens: MAX_TOKENS_OUTPUT
};
try {
const completionReponse = await this.openAI.createChatCompletion({
Expand Down Expand Up @@ -148,14 +148,14 @@ class OpenAi {
messages,
temperature: 0,
top_p: 0.1,
max_tokens: MAX_TOKENS || 500
max_tokens: MAX_TOKENS_OUTPUT
};
try {
const REQUEST_TOKENS = messages
.map((msg) => tokenCount(msg.content) + 4)
.reduce((a, b) => a + b, 0);

if (REQUEST_TOKENS > DEFAULT_MODEL_TOKEN_LIMIT - MAX_TOKENS) {
if (REQUEST_TOKENS > MAX_TOKENS_INPUT - MAX_TOKENS_OUTPUT) {
throw new Error(GenerateCommitMessageErrorEnum.tooMuchTokens);
}

Expand Down
37 changes: 29 additions & 8 deletions src/commands/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ dotenv.config();

export enum CONFIG_KEYS {
OCO_OPENAI_API_KEY = 'OCO_OPENAI_API_KEY',
OCO_OPENAI_MAX_TOKENS = 'OCO_OPENAI_MAX_TOKENS',
OCO_TOKENS_MAX_INPUT = 'OCO_TOKENS_MAX_INPUT',
OCO_TOKENS_MAX_OUTPUT = 'OCO_TOKENS_MAX_OUTPUT',
OCO_OPENAI_BASE_PATH = 'OCO_OPENAI_BASE_PATH',
OCO_OPENAI_API_TYPE = 'OCO_OPENAI_API_TYPE',
OCO_DESCRIPTION = 'OCO_DESCRIPTION',
Expand All @@ -35,8 +36,6 @@ export enum AI_TYPE {
AZURE = 'azure'
}

export const DEFAULT_MODEL_TOKEN_LIMIT = 4096;

export enum CONFIG_MODES {
get = 'get',
set = 'set'
Expand Down Expand Up @@ -85,18 +84,37 @@ export const configValidators = {
return value;
},

[CONFIG_KEYS.OCO_OPENAI_MAX_TOKENS](value: any) {
[CONFIG_KEYS.OCO_TOKENS_MAX_INPUT](value: any) {
// If the value is a string, convert it to a number.
if (typeof value === 'string') {
value = parseInt(value);
validateConfig(
CONFIG_KEYS.OCO_TOKENS_MAX_INPUT,
!isNaN(value),
'Must be a number'
);
}
validateConfig(
CONFIG_KEYS.OCO_TOKENS_MAX_INPUT,
value ? typeof value === 'number' : undefined,
'Must be a number'
);

return value;
},

[CONFIG_KEYS.OCO_TOKENS_MAX_OUTPUT](value: any) {
// If the value is a string, convert it to a number.
if (typeof value === 'string') {
value = parseInt(value);
validateConfig(
CONFIG_KEYS.OCO_OPENAI_MAX_TOKENS,
CONFIG_KEYS.OCO_TOKENS_MAX_OUTPUT,
!isNaN(value),
'Must be a number'
);
}
validateConfig(
CONFIG_KEYS.OCO_OPENAI_MAX_TOKENS,
CONFIG_KEYS.OCO_TOKENS_MAX_OUTPUT,
value ? typeof value === 'number' : undefined,
'Must be a number'
);
Expand Down Expand Up @@ -226,8 +244,11 @@ const configPath = pathJoin(homedir(), '.opencommit');
export const getConfig = (): ConfigType | null => {
const configFromEnv = {
OCO_OPENAI_API_KEY: process.env.OCO_OPENAI_API_KEY,
OCO_OPENAI_MAX_TOKENS: process.env.OCO_OPENAI_MAX_TOKENS
? Number(process.env.OCO_OPENAI_MAX_TOKENS)
OCO_TOKENS_MAX_INPUT: process.env.OCO_TOKENS_MAX_INPUT
? Number(process.env.OCO_TOKENS_MAX_INPUT)
: undefined,
OCO_TOKENS_MAX_OUTPUT: process.env.OCO_TOKENS_MAX_OUTPUT
? Number(process.env.OCO_TOKENS_MAX_OUTPUT)
: undefined,
OCO_OPENAI_BASE_PATH: process.env.OCO_OPENAI_BASE_PATH,
OCO_OPENAI_API_TYPE: process.env.OCO_OPENAI_API_TYPE || 'openai',
Expand Down
8 changes: 5 additions & 3 deletions src/generateCommitMessageFromGitDiff.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import {
} from 'openai';

import { api } from './api';
import { DEFAULT_MODEL_TOKEN_LIMIT, getConfig } from './commands/config';
import { getConfig } from './commands/config';
import { getMainCommitPrompt } from './prompts';
import { mergeDiffs } from './utils/mergeDiffs';
import { tokenCount } from './utils/tokenCount';

const config = getConfig();

// Token budgets for the model request, with fallbacks when unconfigured.
// MAX_TOKENS_INPUT caps the prompt size; MAX_TOKENS_OUTPUT caps the completion.
const MAX_TOKENS_INPUT = config?.OCO_TOKENS_MAX_INPUT || 4096;
// Fixed copy-paste bug: output limit previously read OCO_TOKENS_MAX_INPUT,
// so setting the input limit silently overrode the output budget.
const MAX_TOKENS_OUTPUT = config?.OCO_TOKENS_MAX_OUTPUT || 500;

const generateCommitMessageChatCompletionPrompt = async (
diff: string,
Expand Down Expand Up @@ -47,10 +49,10 @@ export const generateCommitMessageByDiff = async (
).reduce((a, b) => a + b, 0);

const MAX_REQUEST_TOKENS =
DEFAULT_MODEL_TOKEN_LIMIT -
MAX_TOKENS_INPUT -
ADJUSTMENT_FACTOR -
INIT_MESSAGES_PROMPT_LENGTH -
config?.OCO_OPENAI_MAX_TOKENS;
MAX_TOKENS_OUTPUT;

if (tokenCount(diff) >= MAX_REQUEST_TOKENS) {
const commitMessagePromises = await getCommitMsgsPromisesFromFileDiffs(
Expand Down