diff --git a/src/services/apis/custom-api.mjs b/src/services/apis/custom-api.mjs
index 002d1b96..b82e0a15 100644
--- a/src/services/apis/custom-api.mjs
+++ b/src/services/apis/custom-api.mjs
@@ -10,6 +10,7 @@ import { fetchSSE } from '../../utils/fetch-sse.mjs'
 import { getConversationPairs } from '../../utils/get-conversation-pairs.mjs'
 import { isEmpty } from 'lodash-es'
 import { pushRecord, setAbortController } from './shared.mjs'
+import { getChatCompletionsTokenParams } from './openai-token-params.mjs'
 
 /**
  * @param {Browser.Runtime.Port} port
@@ -55,7 +56,7 @@ export async function generateAnswersWithCustomApi(
       messages: prompt,
       model: modelName,
       stream: true,
-      max_tokens: config.maxResponseTokenLength,
+      ...getChatCompletionsTokenParams('custom', modelName, config.maxResponseTokenLength),
       temperature: config.temperature,
     }),
     onMessage(message) {
diff --git a/src/services/apis/openai-api.mjs b/src/services/apis/openai-api.mjs
index 2d30c3b1..752a2a21 100644
--- a/src/services/apis/openai-api.mjs
+++ b/src/services/apis/openai-api.mjs
@@ -6,6 +6,7 @@ import { getConversationPairs } from '../../utils/get-conversation-pairs.mjs'
 import { isEmpty } from 'lodash-es'
 import { getCompletionPromptBase, pushRecord, setAbortController } from './shared.mjs'
 import { getModelValue } from '../../utils/model-name-convert.mjs'
+import { getChatCompletionsTokenParams } from './openai-token-params.mjs'
 
 /**
  * @param {Browser.Runtime.Port} port
@@ -103,6 +104,8 @@ export async function generateAnswersWithChatgptApi(port, question, session, api
     question,
     session,
     apiKey,
+    {},
+    'openai',
   )
 }
 
@@ -113,6 +116,7 @@ export async function generateAnswersWithChatgptApiCompat(
   session,
   apiKey,
   extraBody = {},
+  provider = 'compat',
 ) {
   const { controller, messageListener, disconnectListener } = setAbortController(port)
   const model = getModelValue(session)
@@ -123,6 +127,12 @@ export async function generateAnswersWithChatgptApiCompat(
     false,
   )
   prompt.push({ role: 'user', content: question })
+  const tokenParams = getChatCompletionsTokenParams(provider, model, config.maxResponseTokenLength)
+  const conflictingTokenParamKey =
+    'max_completion_tokens' in tokenParams ? 'max_tokens' : 'max_completion_tokens'
+  // Avoid sending both token-limit fields when caller passes extraBody.
+  const safeExtraBody = { ...extraBody }
+  delete safeExtraBody[conflictingTokenParamKey]
 
   let answer = ''
   let finished = false
@@ -143,9 +153,9 @@ export async function generateAnswersWithChatgptApiCompat(
       messages: prompt,
       model,
       stream: true,
-      max_tokens: config.maxResponseTokenLength,
+      ...tokenParams,
       temperature: config.temperature,
-      ...extraBody,
+      ...safeExtraBody,
     }),
     onMessage(message) {
       console.debug('sse message', message)
diff --git a/src/services/apis/openai-token-params.mjs b/src/services/apis/openai-token-params.mjs
new file mode 100644
index 00000000..d5193376
--- /dev/null
+++ b/src/services/apis/openai-token-params.mjs
@@ -0,0 +1,21 @@
+const GPT5_CHAT_COMPLETIONS_MODEL_PATTERN = /(^|\/)gpt-5([.-]|$)/
+
+function shouldUseMaxCompletionTokens(provider, model) {
+  const normalizedProvider = String(provider || '').toLowerCase()
+  const normalizedModel = String(model || '').toLowerCase()
+
+  switch (true) {
+    case normalizedProvider === 'openai' &&
+      GPT5_CHAT_COMPLETIONS_MODEL_PATTERN.test(normalizedModel):
+      return true
+    default:
+      return false
+  }
+}
+
+export function getChatCompletionsTokenParams(provider, model, maxResponseTokenLength) {
+  if (shouldUseMaxCompletionTokens(provider, model))
+    return { max_completion_tokens: maxResponseTokenLength }
+
+  return { max_tokens: maxResponseTokenLength }
+}
diff --git a/src/services/apis/openai-token-params.test.mjs b/src/services/apis/openai-token-params.test.mjs
new file mode 100644
index 00000000..cc8b948f
--- /dev/null
+++ b/src/services/apis/openai-token-params.test.mjs
@@ -0,0 +1,57 @@
+import test from 'node:test'
+import assert from 'node:assert/strict'
+import { getChatCompletionsTokenParams } from './openai-token-params.mjs'
+
+test('uses max_completion_tokens for gpt-5.x chat models', () => {
+  assert.deepEqual(getChatCompletionsTokenParams('openai', 'gpt-5.2-chat-latest', 1024), {
+    max_completion_tokens: 1024,
+  })
+})
+
+test('uses max_completion_tokens for provider-prefixed gpt-5.x models', () => {
+  assert.deepEqual(getChatCompletionsTokenParams('openai', 'openai/gpt-5.2', 2048), {
+    max_completion_tokens: 2048,
+  })
+})
+
+test('uses max_completion_tokens for gpt-5 baseline model name', () => {
+  assert.deepEqual(getChatCompletionsTokenParams('openai', 'gpt-5', 1536), {
+    max_completion_tokens: 1536,
+  })
+})
+
+test('uses max_tokens for non gpt-5 chat models', () => {
+  assert.deepEqual(getChatCompletionsTokenParams('openai', 'gpt-4o', 512), {
+    max_tokens: 512,
+  })
+})
+
+test('uses max_tokens for lookalike model names', () => {
+  assert.deepEqual(getChatCompletionsTokenParams('openai', 'my-gpt-5-clone', 640), {
+    max_tokens: 640,
+  })
+})
+
+test('uses max_tokens for empty model values', () => {
+  assert.deepEqual(getChatCompletionsTokenParams('openai', '', 256), {
+    max_tokens: 256,
+  })
+})
+
+test('uses max_tokens for non OpenAI providers even with gpt-5 models', () => {
+  assert.deepEqual(getChatCompletionsTokenParams('some-proxy-provider', 'openai/gpt-5.2', 257), {
+    max_tokens: 257,
+  })
+})
+
+test('uses max_completion_tokens for mixed-case OpenAI provider and model', () => {
+  assert.deepEqual(getChatCompletionsTokenParams('OpenAI', 'GPT-5.1', 258), {
+    max_completion_tokens: 258,
+  })
+})
+
+test('uses max_tokens when provider is undefined', () => {
+  assert.deepEqual(getChatCompletionsTokenParams(undefined, 'gpt-5.1', 259), {
+    max_tokens: 259,
+  })
+})