From 98aa82124587bbb35324d8269b9995b794a3bfcb Mon Sep 17 00:00:00 2001 From: Ding <44717411+ding113@users.noreply.github.com> Date: Thu, 19 Mar 2026 00:28:10 +0800 Subject: [PATCH 01/11] fix: run reactive rectifiers during streaming hedge (#945) * fix: run reactive rectifiers during streaming hedge * fix: address hedge race conditions and audit accuracy in forwarder - Add attempt.settled guard in .then() callback to prevent stale response processing - Clear thresholdTimer before rectifier retry to avoid spurious hedge triggers - Use requestAttemptCount instead of sequence for accurate retry chain entries - Merge specialSettings on hedge winner sync to preserve rectifier audit data Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- src/app/v1/_lib/proxy/forwarder.ts | 803 ++++++++++-------- .../proxy-forwarder-hedge-first-byte.test.ts | 239 ++++++ 2 files changed, 700 insertions(+), 342 deletions(-) diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index fa4b4a363..b6814f9cf 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -39,7 +39,10 @@ import { updateMessageRequestDetails } from "@/repository/message"; import type { CacheTtlPreference, CacheTtlResolved } from "@/types/cache"; import type { ProviderChainItem } from "@/types/message"; import type { Provider } from "@/types/provider"; -import type { ClaudeMetadataUserIdInjectionSpecialSetting } from "@/types/special-settings"; +import type { + ClaudeMetadataUserIdInjectionSpecialSetting, + SpecialSetting, +} from "@/types/special-settings"; import { GeminiAuth } from "../gemini/auth"; import { GEMINI_PROTOCOL } from "../gemini/protocol"; @@ -93,11 +96,14 @@ type ProxySessionWithAttemptRuntime = ProxySession & { type StreamingHedgeAttempt = { provider: Provider; session: ProxySession; + baseUrl: string; endpointAudit: { endpointId: number | null; endpointUrl: string }; responseController: AbortController | null; clearResponseTimeout: (() => void) | null; firstByteTimeoutMs: number; sequence: number; + requestAttemptCount: number; + reactiveRectifierRetryState: ReactiveRectifierRetryState; settled: boolean; thresholdTriggered: boolean; thresholdTimer: NodeJS.Timeout | null; @@ -105,6 +111,28 @@ type StreamingHedgeAttempt = { response: Response | null; }; +type ReactiveRectifierRetryState = { + thinkingSignatureRetried: boolean; + thinkingBudgetRetried: boolean; +}; + +type ReactiveRectifierResult = + | { matched: false } + | { + matched: true; + applied: false; + reason: "already_retried" | "not_applicable"; + rectifierType: "thinking_signature_rectifier" | "thinking_budget_rectifier"; + trigger: string; + } + | { + matched: true; + applied: true; + rectifierType: "thinking_signature_rectifier" | "thinking_budget_rectifier"; + trigger: string; + requestDetailsBeforeRectify: ReturnType; + }; + // 非流式响应体检查的上限(字节):避免上游在 2xx 场景返回超大内容导致内存占用失控。 // 说明: // - 该检查仅用于“空响应/假 200”启发式判定,不用于业务逻辑解析; @@ -351,6 +379,218 @@ async function persistSpecialSettings(session: ProxySession): Promise { } } +function addSpecialSettingForPersistence( + ownerSession: ProxySession, + persistSession: ProxySession, + setting: SpecialSetting +): void { + ownerSession.addSpecialSetting(setting); + if (persistSession !== ownerSession) { + persistSession.addSpecialSetting(setting); + } +} + +function buildRetryFailedChainEntry( + provider: Provider, + endpointAudit: { endpointId: number | null; endpointUrl: string }, + attemptNumber: number, + error: Error, + errorMessage: string, + requestDetailsBeforeRectify: ReturnType +): NonNullable[1]> { + if (error instanceof ProxyError) { + return { + ...endpointAudit, + reason: "retry_failed", + circuitState: getCircuitState(provider.id), + attemptNumber, + errorMessage, + statusCode: error.statusCode, + statusCodeInferred: error.upstreamError?.statusCodeInferred ?? false, + errorDetails: { + provider: { + id: provider.id, + name: provider.name, + statusCode: error.statusCode, + statusText: error.message, + upstreamBody: error.upstreamError?.body, + upstreamParsed: error.upstreamError?.parsed, + }, + request: requestDetailsBeforeRectify, + }, + }; + } + + return { + ...endpointAudit, + reason: "retry_failed", + circuitState: getCircuitState(provider.id), + attemptNumber, + errorMessage, + errorDetails: { + system: { + errorType: error.constructor.name, + errorName: error.name, + errorMessage: error.message || error.name || "Unknown error", + errorStack: error.stack?.split("\n").slice(0, 3).join("\n"), + }, + request: requestDetailsBeforeRectify, + }, + }; +} + +function getReactiveRectifierDisplayName( + rectifierType: "thinking_signature_rectifier" | "thinking_budget_rectifier" +): string { + return rectifierType === "thinking_signature_rectifier" + ? "Thinking signature rectifier" + : "Thinking budget rectifier"; +} + +async function tryApplyReactiveAnthropicRectifier(params: { + provider: Provider; + requestSession: ProxySession; + persistSession: ProxySession; + errorMessage: string; + attemptNumber: number; + retryAttemptNumber: number; + retryState: ReactiveRectifierRetryState; +}): Promise { + const { + provider, + requestSession, + persistSession, + errorMessage, + attemptNumber, + retryAttemptNumber, + } = params; + const isAnthropicProvider = + provider.providerType === "claude" || provider.providerType === "claude-auth"; + + if (!isAnthropicProvider) { + return { matched: false }; + } + + const signatureTrigger = detectThinkingSignatureRectifierTrigger(errorMessage); + if (signatureTrigger) { + const settings = await getCachedSystemSettings(); + const enabled = settings.enableThinkingSignatureRectifier ?? true; + + if (!enabled) { + return { matched: false }; + } + + if (params.retryState.thinkingSignatureRetried) { + return { + matched: true, + applied: false, + reason: "already_retried", + rectifierType: "thinking_signature_rectifier", + trigger: signatureTrigger, + }; + } + + const requestDetailsBeforeRectify = buildRequestDetails(requestSession); + const rectified = rectifyAnthropicRequestMessage( + requestSession.request.message as Record + ); + + addSpecialSettingForPersistence(requestSession, persistSession, { + type: "thinking_signature_rectifier", + scope: "request", + hit: rectified.applied, + providerId: provider.id, + providerName: provider.name, + trigger: signatureTrigger, + attemptNumber, + retryAttemptNumber, + removedThinkingBlocks: rectified.removedThinkingBlocks, + removedRedactedThinkingBlocks: rectified.removedRedactedThinkingBlocks, + removedSignatureFields: rectified.removedSignatureFields, + }); + await persistSpecialSettings(persistSession); + + if (!rectified.applied) { + return { + matched: true, + applied: false, + reason: "not_applicable", + rectifierType: "thinking_signature_rectifier", + trigger: signatureTrigger, + }; + } + + params.retryState.thinkingSignatureRetried = true; + return { + matched: true, + applied: true, + rectifierType: "thinking_signature_rectifier", + trigger: signatureTrigger, + requestDetailsBeforeRectify, + }; + } + + const budgetTrigger = detectThinkingBudgetRectifierTrigger(errorMessage); + if (!budgetTrigger) { + return { matched: false }; + } + + const settings = await getCachedSystemSettings(); + const enabled = settings.enableThinkingBudgetRectifier ?? true; + + if (!enabled) { + return { matched: false }; + } + + if (params.retryState.thinkingBudgetRetried) { + return { + matched: true, + applied: false, + reason: "already_retried", + rectifierType: "thinking_budget_rectifier", + trigger: budgetTrigger, + }; + } + + const requestDetailsBeforeRectify = buildRequestDetails(requestSession); + const rectified = rectifyThinkingBudget( + requestSession.request.message as Record + ); + + addSpecialSettingForPersistence(requestSession, persistSession, { + type: "thinking_budget_rectifier", + scope: "request", + hit: rectified.applied, + providerId: provider.id, + providerName: provider.name, + trigger: budgetTrigger, + attemptNumber, + retryAttemptNumber, + before: rectified.before, + after: rectified.after, + }); + await persistSpecialSettings(persistSession); + + if (!rectified.applied) { + return { + matched: true, + applied: false, + reason: "not_applicable", + rectifierType: "thinking_budget_rectifier", + trigger: budgetTrigger, + }; + } + + params.retryState.thinkingBudgetRetried = true; + return { + matched: true, + applied: true, + rectifierType: "thinking_budget_rectifier", + trigger: budgetTrigger, + requestDetailsBeforeRectify, + }; +} + /** * 为 Claude 请求注入 metadata.user_id * @@ -521,8 +761,10 @@ export class ProxyForwarder { currentProvider, envDefaultMaxAttempts ); - let thinkingSignatureRectifierRetried = false; - let thinkingBudgetRectifierRetried = false; + const reactiveRectifierRetryState: ReactiveRectifierRetryState = { + thinkingSignatureRetried: false, + thinkingBudgetRetried: false, + }; const requestPath = session.requestUrl.pathname; const providerVendorId = currentProvider.providerVendorId ?? 0; @@ -1046,280 +1288,63 @@ export class ProxyForwarder { throw lastError; } - // 2.5 Thinking signature 整流器:命中后对同供应商“整流 + 重试一次” - // 目标:解决 Anthropic 与非 Anthropic 渠道切换导致的 thinking 签名不兼容问题 - // 约束: - // - 仅对 Anthropic 类型供应商生效 - // - 不依赖 error rules 开关(用户可能关闭规则,但仍希望整流生效) - // - 不计入熔断器、不触发供应商切换 - const isAnthropicProvider = - currentProvider.providerType === "claude" || - currentProvider.providerType === "claude-auth"; - const rectifierTrigger = isAnthropicProvider - ? detectThinkingSignatureRectifierTrigger(errorMessage) - : null; - - if (rectifierTrigger) { - const settings = await getCachedSystemSettings(); - const enabled = settings.enableThinkingSignatureRectifier ?? true; - - if (enabled) { - // 已重试过仍失败:强制按“不可重试的客户端错误”处理,避免污染熔断器/触发供应商切换 - if (thinkingSignatureRectifierRetried) { - errorCategory = ErrorCategory.NON_RETRYABLE_CLIENT_ERROR; - } else { - const requestDetailsBeforeRectify = buildRequestDetails(session); - - // 整流请求体(原地修改 session.request.message) - const rectified = rectifyAnthropicRequestMessage( - session.request.message as Record - ); - - // 写入审计字段(specialSettings) - session.addSpecialSetting({ - type: "thinking_signature_rectifier", - scope: "request", - hit: rectified.applied, - providerId: currentProvider.id, - providerName: currentProvider.name, - trigger: rectifierTrigger, - attemptNumber: attemptCount, - retryAttemptNumber: attemptCount + 1, - removedThinkingBlocks: rectified.removedThinkingBlocks, - removedRedactedThinkingBlocks: rectified.removedRedactedThinkingBlocks, - removedSignatureFields: rectified.removedSignatureFields, - }); - - const specialSettings = session.getSpecialSettings(); - if (specialSettings && session.sessionId) { - try { - await SessionManager.storeSessionSpecialSettings( - session.sessionId, - specialSettings, - session.requestSequence - ); - } catch (persistError) { - logger.error("[ProxyForwarder] Failed to store special settings", { - error: persistError, - sessionId: session.sessionId, - }); - } - } - - if (specialSettings && session.messageContext?.id) { - try { - await updateMessageRequestDetails(session.messageContext.id, { - specialSettings, - }); - } catch (persistError) { - logger.error("[ProxyForwarder] Failed to persist special settings", { - error: persistError, - messageRequestId: session.messageContext.id, - }); - } - } + // 2.5 Reactive rectifier:命中后对同供应商“整流 + 重试一次” + const reactiveRectifierResult = await tryApplyReactiveAnthropicRectifier({ + provider: currentProvider, + requestSession: session, + persistSession: session, + errorMessage, + attemptNumber: attemptCount, + retryAttemptNumber: attemptCount + 1, + retryState: reactiveRectifierRetryState, + }); - // 无任何可整流内容:不做无意义重试,直接走既有“不可重试客户端错误”分支 - if (!rectified.applied) { - logger.info( - "ProxyForwarder: Thinking signature rectifier not applicable, skipping retry", - { - providerId: currentProvider.id, - providerName: currentProvider.name, - trigger: rectifierTrigger, - attemptNumber: attemptCount, - } - ); - errorCategory = ErrorCategory.NON_RETRYABLE_CLIENT_ERROR; - } else { - logger.info("ProxyForwarder: Thinking signature rectifier applied, retrying", { + if (reactiveRectifierResult.matched) { + if (!reactiveRectifierResult.applied) { + if (reactiveRectifierResult.reason === "not_applicable") { + logger.info( + `ProxyForwarder: ${getReactiveRectifierDisplayName( + reactiveRectifierResult.rectifierType + )} not applicable, skipping retry`, + { providerId: currentProvider.id, providerName: currentProvider.name, - trigger: rectifierTrigger, + trigger: reactiveRectifierResult.trigger, attemptNumber: attemptCount, - willRetryAttemptNumber: attemptCount + 1, - }); - - thinkingSignatureRectifierRetried = true; - - // 记录失败的第一次请求(以 retry_failed 体现“发生过一次重试”) - if (lastError instanceof ProxyError) { - session.addProviderToChain(currentProvider, { - ...endpointAudit, - reason: "retry_failed", - circuitState: getCircuitState(currentProvider.id), - attemptNumber: attemptCount, - errorMessage, - statusCode: lastError.statusCode, - statusCodeInferred: lastError.upstreamError?.statusCodeInferred ?? false, - errorDetails: { - provider: { - id: currentProvider.id, - name: currentProvider.name, - statusCode: lastError.statusCode, - statusText: lastError.message, - upstreamBody: lastError.upstreamError?.body, - upstreamParsed: lastError.upstreamError?.parsed, - }, - request: requestDetailsBeforeRectify, - }, - }); - } else { - session.addProviderToChain(currentProvider, { - ...endpointAudit, - reason: "retry_failed", - circuitState: getCircuitState(currentProvider.id), - attemptNumber: attemptCount, - errorMessage, - errorDetails: { - system: { - errorType: lastError.constructor.name, - errorName: lastError.name, - errorMessage: lastError.message || lastError.name || "Unknown error", - errorStack: lastError.stack?.split("\n").slice(0, 3).join("\n"), - }, - request: requestDetailsBeforeRectify, - }, - }); } - - // 确保即使 maxAttemptsPerProvider=1 也能完成一次额外重试 - maxAttemptsPerProvider = Math.max(maxAttemptsPerProvider, attemptCount + 1); - continue; - } - } - } - } - - // 2.6 Thinking budget rectifier: fix budget_tokens < 1024 errors and retry once - const budgetRectifierTrigger = isAnthropicProvider - ? detectThinkingBudgetRectifierTrigger(errorMessage) - : null; - - if (budgetRectifierTrigger) { - const settings = await getCachedSystemSettings(); - const budgetRectifierEnabled = settings.enableThinkingBudgetRectifier ?? true; - - if (budgetRectifierEnabled) { - if (thinkingBudgetRectifierRetried) { - errorCategory = ErrorCategory.NON_RETRYABLE_CLIENT_ERROR; - } else { - const requestDetailsBeforeRectify = buildRequestDetails(session); - - const budgetRectified = rectifyThinkingBudget( - session.request.message as Record ); + } - session.addSpecialSetting({ - type: "thinking_budget_rectifier", - scope: "request", - hit: budgetRectified.applied, + errorCategory = ErrorCategory.NON_RETRYABLE_CLIENT_ERROR; + } else { + logger.info( + `ProxyForwarder: ${getReactiveRectifierDisplayName( + reactiveRectifierResult.rectifierType + )} applied, retrying`, + { providerId: currentProvider.id, providerName: currentProvider.name, - trigger: budgetRectifierTrigger, + trigger: reactiveRectifierResult.trigger, attemptNumber: attemptCount, - retryAttemptNumber: attemptCount + 1, - before: budgetRectified.before, - after: budgetRectified.after, - }); - - const specialSettings = session.getSpecialSettings(); - if (specialSettings && session.sessionId) { - try { - await SessionManager.storeSessionSpecialSettings( - session.sessionId, - specialSettings, - session.requestSequence - ); - } catch (persistError) { - logger.error("[ProxyForwarder] Failed to store special settings", { - error: persistError, - sessionId: session.sessionId, - }); - } + willRetryAttemptNumber: attemptCount + 1, } + ); - if (specialSettings && session.messageContext?.id) { - try { - await updateMessageRequestDetails(session.messageContext.id, { - specialSettings, - }); - } catch (persistError) { - logger.error("[ProxyForwarder] Failed to persist special settings", { - error: persistError, - messageRequestId: session.messageContext.id, - }); - } - } - - if (!budgetRectified.applied) { - logger.info( - "ProxyForwarder: Thinking budget rectifier not applicable, skipping retry", - { - providerId: currentProvider.id, - providerName: currentProvider.name, - trigger: budgetRectifierTrigger, - attemptNumber: attemptCount, - } - ); - errorCategory = ErrorCategory.NON_RETRYABLE_CLIENT_ERROR; - } else { - logger.info("ProxyForwarder: Thinking budget rectifier applied, retrying", { - providerId: currentProvider.id, - providerName: currentProvider.name, - trigger: budgetRectifierTrigger, - attemptNumber: attemptCount, - willRetryAttemptNumber: attemptCount + 1, - before: budgetRectified.before, - after: budgetRectified.after, - }); - - thinkingBudgetRectifierRetried = true; - - if (lastError instanceof ProxyError) { - session.addProviderToChain(currentProvider, { - ...endpointAudit, - reason: "retry_failed", - circuitState: getCircuitState(currentProvider.id), - attemptNumber: attemptCount, - errorMessage, - statusCode: lastError.statusCode, - statusCodeInferred: lastError.upstreamError?.statusCodeInferred ?? false, - errorDetails: { - provider: { - id: currentProvider.id, - name: currentProvider.name, - statusCode: lastError.statusCode, - statusText: lastError.message, - upstreamBody: lastError.upstreamError?.body, - upstreamParsed: lastError.upstreamError?.parsed, - }, - request: requestDetailsBeforeRectify, - }, - }); - } else { - session.addProviderToChain(currentProvider, { - ...endpointAudit, - reason: "retry_failed", - circuitState: getCircuitState(currentProvider.id), - attemptNumber: attemptCount, - errorMessage, - errorDetails: { - system: { - errorType: lastError.constructor.name, - errorName: lastError.name, - errorMessage: lastError.message || lastError.name || "Unknown error", - errorStack: lastError.stack?.split("\n").slice(0, 3).join("\n"), - }, - request: requestDetailsBeforeRectify, - }, - }); - } + session.addProviderToChain( + currentProvider, + buildRetryFailedChainEntry( + currentProvider, + endpointAudit, + attemptCount, + lastError, + errorMessage, + reactiveRectifierResult.requestDetailsBeforeRectify + ) + ); - maxAttemptsPerProvider = Math.max(maxAttemptsPerProvider, attemptCount + 1); - continue; - } - } + // 确保即使 maxAttemptsPerProvider=1 也能完成一次额外重试 + maxAttemptsPerProvider = Math.max(maxAttemptsPerProvider, attemptCount + 1); + continue; } } @@ -3053,20 +3078,88 @@ export class ProxyForwarder { await launchingAlternative; }; + const runAttempt = (attempt: StreamingHedgeAttempt) => { + const providerForRequest = + attempt.provider.firstByteTimeoutStreamingMs > 0 + ? { ...attempt.provider, firstByteTimeoutStreamingMs: 0 } + : attempt.provider; + + void ProxyForwarder.doForward( + attempt.session, + providerForRequest, + attempt.baseUrl, + attempt.endpointAudit, + attempt.requestAttemptCount + ) + .then(async (response) => { + if (settled || winnerCommitted || attempt.settled) { + const attemptRuntime = attempt.session as ProxySessionWithAttemptRuntime; + try { + attemptRuntime.responseController?.abort(new Error("hedge_loser")); + } catch { + // ignore + } + const cancelPromise = response.body?.cancel("hedge_loser"); + cancelPromise?.catch(() => { + // ignore + }); + return; + } + + const attemptRuntime = attempt.session as ProxySessionWithAttemptRuntime; + attempt.responseController = attemptRuntime.responseController ?? null; + attempt.clearResponseTimeout = attemptRuntime.clearResponseTimeout ?? null; + attempt.clearResponseTimeout?.(); + attempt.response = response; + + if (!response.body) { + await handleAttemptFailure( + attempt, + new EmptyResponseError(attempt.provider.id, attempt.provider.name, "empty_body") + ); + return; + } + + attempt.reader = response.body.getReader(); + + try { + const firstChunk = await ProxyForwarder.readFirstReadableChunk(attempt.reader); + if (firstChunk.done) { + await handleAttemptFailure( + attempt, + new EmptyResponseError(attempt.provider.id, attempt.provider.name, "empty_body") + ); + return; + } + + await commitWinner(attempt, firstChunk.value); + } catch (firstChunkError) { + const normalizedError = + firstChunkError instanceof Error + ? firstChunkError + : new Error(String(firstChunkError)); + if (settled || winnerCommitted) return; + await handleAttemptFailure(attempt, normalizedError); + } + }) + .catch(async (attemptError) => { + const normalizedError = + attemptError instanceof Error ? attemptError : new Error(String(attemptError)); + if (settled || winnerCommitted) return; + await handleAttemptFailure(attempt, normalizedError); + }); + }; + const handleAttemptFailure = async (attempt: StreamingHedgeAttempt, error: Error) => { if (settled || winnerCommitted || attempt.settled) return; - attempt.settled = true; - if (attempt.thresholdTimer) { - clearTimeout(attempt.thresholdTimer); - attempt.thresholdTimer = null; - } - attempts.delete(attempt); lastError = error; - const errorCategory = await categorizeErrorAsync(error); + let errorCategory = await categorizeErrorAsync(error); lastErrorCategory = errorCategory; const statusCode = error instanceof ProxyError ? error.statusCode : undefined; + const errorMessage = + error instanceof ProxyError ? error.getDetailedErrorMessage() : error.message; if (attempt.endpointAudit.endpointId != null) { const isTimeoutError = error instanceof ProxyError && error.statusCode === 524; @@ -3076,6 +3169,13 @@ export class ProxyForwarder { } if (errorCategory === ErrorCategory.CLIENT_ABORT) { + attempt.settled = true; + if (attempt.thresholdTimer) { + clearTimeout(attempt.thresholdTimer); + attempt.thresholdTimer = null; + } + attempts.delete(attempt); + session.addProviderToChain(attempt.provider, { ...attempt.endpointAudit, reason: "client_abort", @@ -3092,6 +3192,79 @@ export class ProxyForwarder { return; } + const reactiveRectifierResult = await tryApplyReactiveAnthropicRectifier({ + provider: attempt.provider, + requestSession: attempt.session, + persistSession: session, + errorMessage, + attemptNumber: attempt.requestAttemptCount, + retryAttemptNumber: attempt.requestAttemptCount + 1, + retryState: attempt.reactiveRectifierRetryState, + }); + + if (reactiveRectifierResult.matched) { + if (!reactiveRectifierResult.applied) { + if (reactiveRectifierResult.reason === "not_applicable") { + logger.info( + `ProxyForwarder: ${getReactiveRectifierDisplayName( + reactiveRectifierResult.rectifierType + )} not applicable in hedge, skipping retry`, + { + providerId: attempt.provider.id, + providerName: attempt.provider.name, + trigger: reactiveRectifierResult.trigger, + participantSequence: attempt.sequence, + attemptNumber: attempt.requestAttemptCount, + } + ); + } + + errorCategory = ErrorCategory.NON_RETRYABLE_CLIENT_ERROR; + lastErrorCategory = errorCategory; + } else { + logger.info( + `ProxyForwarder: ${getReactiveRectifierDisplayName( + reactiveRectifierResult.rectifierType + )} applied in hedge, retrying same provider`, + { + providerId: attempt.provider.id, + providerName: attempt.provider.name, + trigger: reactiveRectifierResult.trigger, + participantSequence: attempt.sequence, + attemptNumber: attempt.requestAttemptCount, + willRetryAttemptNumber: attempt.requestAttemptCount + 1, + } + ); + + session.addProviderToChain( + attempt.provider, + buildRetryFailedChainEntry( + attempt.provider, + attempt.endpointAudit, + attempt.requestAttemptCount, + error, + errorMessage, + reactiveRectifierResult.requestDetailsBeforeRectify + ) + ); + + if (attempt.thresholdTimer) { + clearTimeout(attempt.thresholdTimer); + attempt.thresholdTimer = null; + } + attempt.requestAttemptCount += 1; + runAttempt(attempt); + return; + } + } + + attempt.settled = true; + if (attempt.thresholdTimer) { + clearTimeout(attempt.thresholdTimer); + attempt.thresholdTimer = null; + } + attempts.delete(attempt); + if (errorCategory === ErrorCategory.PROVIDER_ERROR && statusCode !== 404) { await recordFailure(attempt.provider.id, error); } @@ -3106,7 +3279,7 @@ export class ProxyForwarder { : "retry_failed", attemptNumber: attempt.sequence, statusCode, - errorMessage: error instanceof ProxyError ? error.getDetailedErrorMessage() : error.message, + errorMessage, circuitState: getCircuitState(attempt.provider.id), }); @@ -3240,6 +3413,7 @@ export class ProxyForwarder { const attempt: StreamingHedgeAttempt = { provider, session: attemptSession, + baseUrl: endpointSelection.baseUrl, endpointAudit: { endpointId: endpointSelection.endpointId, endpointUrl: endpointSelection.endpointUrl, @@ -3249,6 +3423,11 @@ export class ProxyForwarder { firstByteTimeoutMs: provider.firstByteTimeoutStreamingMs > 0 ? provider.firstByteTimeoutStreamingMs : 0, sequence: launchedProviderCount, + requestAttemptCount: 1, + reactiveRectifierRetryState: { + thinkingSignatureRetried: false, + thinkingBudgetRetried: false, + }, settled: false, thresholdTriggered: false, thresholdTimer: null, @@ -3283,75 +3462,7 @@ export class ProxyForwarder { }, attempt.firstByteTimeoutMs); } - const providerForRequest = - provider.firstByteTimeoutStreamingMs > 0 - ? { ...provider, firstByteTimeoutStreamingMs: 0 } - : provider; - - void ProxyForwarder.doForward( - attemptSession, - providerForRequest, - endpointSelection.baseUrl, - attempt.endpointAudit, - 1 - ) - .then(async (response) => { - if (settled || winnerCommitted) { - const attemptRuntime = attemptSession as ProxySessionWithAttemptRuntime; - try { - attemptRuntime.responseController?.abort(new Error("hedge_loser")); - } catch { - // ignore - } - const cancelPromise = response.body?.cancel("hedge_loser"); - cancelPromise?.catch(() => { - // ignore - }); - return; - } - - const attemptRuntime = attemptSession as ProxySessionWithAttemptRuntime; - attempt.responseController = attemptRuntime.responseController ?? null; - attempt.clearResponseTimeout = attemptRuntime.clearResponseTimeout ?? null; - attempt.clearResponseTimeout?.(); - attempt.response = response; - - if (!response.body) { - await handleAttemptFailure( - attempt, - new EmptyResponseError(provider.id, provider.name, "empty_body") - ); - return; - } - - attempt.reader = response.body.getReader(); - - try { - const firstChunk = await ProxyForwarder.readFirstReadableChunk(attempt.reader); - if (firstChunk.done) { - await handleAttemptFailure( - attempt, - new EmptyResponseError(provider.id, provider.name, "empty_body") - ); - return; - } - - await commitWinner(attempt, firstChunk.value); - } catch (firstChunkError) { - const normalizedError = - firstChunkError instanceof Error - ? firstChunkError - : new Error(String(firstChunkError)); - if (settled || winnerCommitted) return; - await handleAttemptFailure(attempt, normalizedError); - } - }) - .catch(async (attemptError) => { - const normalizedError = - attemptError instanceof Error ? attemptError : new Error(String(attemptError)); - if (settled || winnerCommitted) return; - await handleAttemptFailure(attempt, normalizedError); - }); + runAttempt(attempt); }; if (session.clientAbortSignal) { @@ -3563,7 +3674,15 @@ export class ProxyForwarder { } } targetState.providerChain = mergedProviderChain; - targetState.specialSettings = [...sourceState.specialSettings]; + // 合并 specialSettings,避免覆盖已有的 rectifier audit 记录 + const existingKeys = new Set(targetState.specialSettings.map((s) => JSON.stringify(s))); + const merged = [...targetState.specialSettings]; + for (const setting of sourceState.specialSettings) { + if (!existingKeys.has(JSON.stringify(setting))) { + merged.push(setting); + } + } + targetState.specialSettings = merged; targetState.originalModelName = sourceState.originalModelName; targetState.originalUrlPathname = sourceState.originalUrlPathname; targetState.clearResponseTimeout = sourceRuntime.clearResponseTimeout; diff --git a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts index c27b49769..b4fd5012d 100644 --- a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts +++ b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts @@ -21,6 +21,10 @@ const mocks = vi.hoisted(() => ({ isVendorTypeCircuitOpen: vi.fn(async () => false), recordVendorTypeAllEndpointsTimeout: vi.fn(async () => {}), categorizeErrorAsync: vi.fn(async () => 0), + getCachedSystemSettings: vi.fn(async () => ({ + enableThinkingSignatureRectifier: true, + enableThinkingBudgetRectifier: true, + })), storeSessionSpecialSettings: vi.fn(async () => {}), })); @@ -39,6 +43,7 @@ vi.mock("@/lib/config", async (importOriginal) => { const actual = await importOriginal(); return { ...actual, + getCachedSystemSettings: mocks.getCachedSystemSettings, isHttp2Enabled: mocks.isHttp2Enabled, }; }); @@ -276,6 +281,23 @@ function createDelayedFailure(params: { }); } +function withThinkingBlocks(session: ProxySession): void { + session.request.message = { + model: "claude-test", + stream: true, + messages: [ + { + role: "assistant", + content: [ + { type: "thinking", thinking: "t", signature: "sig_thinking" }, + { type: "text", text: "hello", signature: "sig_text_should_remove" }, + { type: "redacted_thinking", data: "r", signature: "sig_redacted" }, + ], + }, + ], + }; +} + describe("ProxyForwarder - first-byte hedge scheduling", () => { beforeEach(() => { vi.clearAllMocks(); @@ -842,6 +864,223 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { ); }); + test("hedge 备选供应商命中 thinking signature 错误时,应整流后在同供应商重试并保留审计", async () => { + vi.useFakeTimers(); + + try { + const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); + const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 }); + const session = createSession(); + session.setProvider(provider1); + withThinkingBlocks(session); + + mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); + mocks.categorizeErrorAsync.mockResolvedValue(ProxyErrorCategory.NON_RETRYABLE_CLIENT_ERROR); + + const signatureError = new UpstreamProxyError( + "Invalid `signature` in `thinking` block", + 400, + { + body: '{"error":"invalid_signature"}', + providerId: provider2.id, + providerName: provider2.name, + } + ); + + const doForward = vi.spyOn( + ProxyForwarder as unknown as { + doForward: (...args: unknown[]) => Promise; + }, + "doForward" + ); + + const controller1 = new AbortController(); + const controller2First = new AbortController(); + const controller2Retry = new AbortController(); + + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + runtime.responseController = controller1; + runtime.clearResponseTimeout = vi.fn(); + return createStreamingResponse({ + label: "p1", + firstChunkDelayMs: 600, + controller: controller1, + }); + }); + + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + runtime.responseController = controller2First; + runtime.clearResponseTimeout = vi.fn(); + return createDelayedFailure({ + delayMs: 50, + error: signatureError, + controller: controller2First, + }); + }); + + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + const body = runtime.request.message as { + messages: Array<{ content: Array> }>; + }; + const blocks = body.messages[0].content; + + expect(blocks.some((block) => block.type === "thinking")).toBe(false); + expect(blocks.some((block) => block.type === "redacted_thinking")).toBe(false); + expect(blocks.some((block) => "signature" in block)).toBe(false); + + runtime.responseController = controller2Retry; + runtime.clearResponseTimeout = vi.fn(); + return createStreamingResponse({ + label: "p2-rectified", + firstChunkDelayMs: 180, + controller: controller2Retry, + }); + }); + + const responsePromise = ProxyForwarder.send(session); + + await vi.advanceTimersByTimeAsync(100); + expect(doForward).toHaveBeenCalledTimes(2); + + await vi.advanceTimersByTimeAsync(55); + expect(doForward).toHaveBeenCalledTimes(3); + + await vi.advanceTimersByTimeAsync(200); + const response = await responsePromise; + + expect(await response.text()).toContain('"provider":"p2-rectified"'); + expect(session.provider?.id).toBe(2); + expect(controller1.signal.aborted).toBe(true); + expect(mocks.pickRandomProviderWithExclusion).toHaveBeenCalled(); + expect(mocks.storeSessionSpecialSettings).toHaveBeenCalledWith( + "sess-hedge", + expect.arrayContaining([ + expect.objectContaining({ + type: "thinking_signature_rectifier", + hit: true, + providerId: 2, + }), + ]), + 1 + ); + } finally { + vi.useRealTimers(); + } + }); + + test("hedge 路径命中 thinking budget 错误时,应整流后在同供应商重试", async () => { + vi.useFakeTimers(); + + try { + const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 }); + const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 }); + const session = createSession(); + session.setProvider(provider1); + session.request.message = { + model: "claude-test", + stream: true, + max_tokens: 1000, + thinking: { type: "enabled", budget_tokens: 500 }, + messages: [{ role: "user", content: "hi" }], + }; + + mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); + mocks.categorizeErrorAsync.mockResolvedValue(ProxyErrorCategory.NON_RETRYABLE_CLIENT_ERROR); + + const budgetError = new UpstreamProxyError( + "thinking.enabled.budget_tokens: Input should be greater than or equal to 1024", + 400, + { + body: '{"error":"budget_too_low"}', + providerId: provider1.id, + providerName: provider1.name, + } + ); + + const doForward = vi.spyOn( + ProxyForwarder as unknown as { + doForward: (...args: unknown[]) => Promise; + }, + "doForward" + ); + + const controller1First = new AbortController(); + const controller1Retry = new AbortController(); + const controller2 = new AbortController(); + + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + runtime.responseController = controller1First; + runtime.clearResponseTimeout = vi.fn(); + return createDelayedFailure({ + delayMs: 140, + error: budgetError, + controller: controller1First, + }); + }); + + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + runtime.responseController = controller2; + runtime.clearResponseTimeout = vi.fn(); + return createStreamingResponse({ + label: "p2", + firstChunkDelayMs: 500, + controller: controller2, + }); + }); + + doForward.mockImplementationOnce(async (attemptSession) => { + const runtime = attemptSession as ProxySession & AttemptRuntime; + const body = runtime.request.message as { + max_tokens: number; + thinking: { type: string; budget_tokens: number }; + }; + + expect(body.max_tokens).toBe(64000); + expect(body.thinking.type).toBe("enabled"); + expect(body.thinking.budget_tokens).toBe(32000); + + runtime.responseController = controller1Retry; + runtime.clearResponseTimeout = vi.fn(); + return createStreamingResponse({ + label: "p1-budget-rectified", + firstChunkDelayMs: 40, + controller: controller1Retry, + }); + }); + + const responsePromise = ProxyForwarder.send(session); + + await vi.advanceTimersByTimeAsync(100); + expect(doForward).toHaveBeenCalledTimes(2); + + await vi.advanceTimersByTimeAsync(45); + expect(doForward).toHaveBeenCalledTimes(3); + + await vi.advanceTimersByTimeAsync(50); + const response = await responsePromise; + + expect(await response.text()).toContain('"provider":"p1-budget-rectified"'); + expect(session.provider?.id).toBe(1); + expect(mocks.pickRandomProviderWithExclusion).toHaveBeenCalledTimes(1); + expect(session.getSpecialSettings()).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + type: "thinking_budget_rectifier", + hit: true, + providerId: 1, + }), + ]) + ); + } finally { + vi.useRealTimers(); + } + }); + test("endpoint resolution failure should not inflate launchedProviderCount, winner gets request_success not hedge_winner", async () => { vi.useFakeTimers(); From b8b0f9d4a7eebae2f7badcea73b909f1439d077e Mon Sep 17 00:00:00 2001 From: ding113 Date: Sat, 21 Mar 2026 08:56:44 +0000 Subject: [PATCH 02/11] feat: support unicode provider groups --- src/actions/keys.ts | 6 +- src/actions/providers.ts | 28 ++--- src/actions/request-filters.ts | 3 +- src/actions/users.ts | 8 +- .../_components/user/forms/add-key-form.tsx | 7 +- .../_components/user/forms/edit-key-form.tsx | 7 +- .../user/forms/key-edit-section.tsx | 16 +-- .../user/forms/provider-group-select.tsx | 7 +- .../_components/user/forms/user-form.tsx | 1 + .../_components/user/key-row-item.tsx | 6 +- .../_components/user/user-key-table-row.tsx | 6 +- .../_components/user/utils/provider-group.ts | 24 +--- .../_components/provider-chain-popover.tsx | 12 +- .../dashboard/users/users-page-client.tsx | 6 +- .../batch-edit/analyze-batch-settings.ts | 10 +- .../batch-edit/provider-batch-toolbar.tsx | 6 +- .../forms/provider-form.legacy.tsx | 11 +- .../provider-form/provider-form-context.tsx | 8 +- .../sections/routing-section.tsx | 1 + .../_components/provider-manager.tsx | 28 +---- .../_components/provider-rich-list-item.tsx | 10 +- src/app/v1/_lib/proxy/provider-selector.ts | 6 +- src/components/form/form-field.tsx | 8 +- .../provider-group-tag-input.test.tsx | 119 ++++++++++++++++++ src/lib/provider-patch-contract.ts | 6 + src/lib/request-filter-engine.ts | 5 +- src/lib/utils/provider-group.test.ts | 24 ++++ src/lib/utils/provider-group.ts | 46 ++++--- src/repository/leaderboard.ts | 2 +- src/repository/provider.ts | 8 +- src/repository/user.ts | 8 +- 31 files changed, 254 insertions(+), 189 deletions(-) create mode 100644 src/components/ui/__tests__/provider-group-tag-input.test.tsx create mode 100644 src/lib/utils/provider-group.test.ts diff --git a/src/actions/keys.ts b/src/actions/keys.ts index 264a6118f..4833bfe7a 100644 --- a/src/actions/keys.ts +++ b/src/actions/keys.ts @@ -607,11 +607,7 @@ export async function removeKey(keyId: number): Promise { for (const k of userKeys) { if (k.id === keyId) continue; const group = k.providerGroup || PROVIDER_GROUP.DEFAULT; - group - .split(",") - .map((g) => g.trim()) - .filter(Boolean) - .forEach((g) => remainingGroups.add(g)); + parseProviderGroups(group).forEach((g) => remainingGroups.add(g)); } const { findUserById } = await import("@/repository/user"); diff --git a/src/actions/providers.ts b/src/actions/providers.ts index ea3a43af9..162c95ae0 100644 --- a/src/actions/providers.ts +++ b/src/actions/providers.ts @@ -44,6 +44,7 @@ import { } from "@/lib/redis/circuit-breaker-config"; import { RedisKVStore } from "@/lib/redis/redis-kv-store"; import { SessionManager } from "@/lib/session-manager"; +import { normalizeProviderGroupTag, parseProviderGroups } from "@/lib/utils/provider-group"; import { maskKey } from "@/lib/utils/validation"; import { extractZodErrorCode, formatZodError } from "@/lib/utils/zod-i18n"; import { validateProviderUrlForConnectivity } from "@/lib/validation/provider-url"; @@ -431,10 +432,7 @@ export async function getAvailableProviderGroups(userId?: number): Promise g.trim()) - .filter(Boolean); + const userGroups = parseProviderGroups(user?.providerGroup || PROVIDER_GROUP.DEFAULT); // 管理员通配符:可访问所有分组 if (userGroups.includes(PROVIDER_GROUP.ALL)) { @@ -462,17 +460,12 @@ export async function getProviderGroupsWithCount(): Promise< const groupCounts = new Map(); for (const provider of providers) { - const groupTag = provider.groupTag?.trim(); - if (!groupTag) { + const groups = parseProviderGroups(provider.groupTag); + if (groups.length === 0) { groupCounts.set(PROVIDER_GROUP.DEFAULT, (groupCounts.get(PROVIDER_GROUP.DEFAULT) || 0) + 1); continue; } - const groups = groupTag - .split(",") - .map((g) => g.trim()) - .filter(Boolean); - for (const group of groups) { groupCounts.set(group, (groupCounts.get(group) || 0) + 1); } @@ -589,6 +582,7 @@ export async function addProvider(data: { const payload = { ...validated, + group_tag: normalizeProviderGroupTag(validated.group_tag), limit_5h_usd: validated.limit_5h_usd ?? null, limit_daily_usd: validated.limit_daily_usd ?? null, daily_reset_mode: validated.daily_reset_mode ?? "fixed", @@ -766,6 +760,9 @@ export async function editProvider( const payload = { ...validated, + ...(validated.group_tag !== undefined && { + group_tag: normalizeProviderGroupTag(validated.group_tag), + }), ...(faviconUrl !== undefined && { favicon_url: faviconUrl }), }; @@ -2235,7 +2232,9 @@ export async function batchUpdateProviders( if (updates.cost_multiplier !== undefined) { repositoryUpdates.costMultiplier = updates.cost_multiplier.toString(); } - if (updates.group_tag !== undefined) repositoryUpdates.groupTag = updates.group_tag; + if (updates.group_tag !== undefined) { + repositoryUpdates.groupTag = normalizeProviderGroupTag(updates.group_tag); + } if (updates.model_redirects !== undefined) { repositoryUpdates.modelRedirects = updates.model_redirects; } @@ -4795,10 +4794,7 @@ async function fetchAnthropicModels( * 解析分组字符串为数组 */ function parseGroupString(groupString: string): string[] { - return groupString - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + return parseProviderGroups(groupString); } /** diff --git a/src/actions/request-filters.ts b/src/actions/request-filters.ts index 4098f34df..a5b322ff4 100644 --- a/src/actions/request-filters.ts +++ b/src/actions/request-filters.ts @@ -6,6 +6,7 @@ import { getSession } from "@/lib/auth"; import { logger } from "@/lib/logger"; import { requestFilterEngine } from "@/lib/request-filter-engine"; import type { FilterMatcher, FilterOperation, InsertOp } from "@/lib/request-filter-types"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import { createRequestFilter, deleteRequestFilter, @@ -458,7 +459,7 @@ export async function getDistinctProviderGroupsAction(): Promise(); for (const row of result) { if (row.groupTag) { - const tags = row.groupTag.split(",").map((tag) => tag.trim()); + const tags = parseProviderGroups(row.groupTag); for (const tag of tags) { if (tag) allTags.add(tag); } diff --git a/src/actions/users.ts b/src/actions/users.ts index 58e1296c8..f6c159d18 100644 --- a/src/actions/users.ts +++ b/src/actions/users.ts @@ -13,7 +13,7 @@ import { getUnauthorizedFields } from "@/lib/permissions/user-field-permissions" import { invalidateCachedUser } from "@/lib/security/api-key-auth-cache"; import { parseDateInputAsTimezone } from "@/lib/utils/date-input"; import { ERROR_CODES } from "@/lib/utils/error-messages"; -import { normalizeProviderGroup } from "@/lib/utils/provider-group"; +import { normalizeProviderGroup, parseProviderGroups } from "@/lib/utils/provider-group"; import { resolveSystemTimezone } from "@/lib/utils/timezone"; import { maskKey } from "@/lib/utils/validation"; import { formatZodError } from "@/lib/utils/zod-i18n"; @@ -192,11 +192,7 @@ export async function syncUserProviderGroupFromKeys(userId: number): Promise g.trim()) - .filter(Boolean) - .forEach((g) => allGroups.add(g)); + parseProviderGroups(group).forEach((g) => allGroups.add(g)); } const newProviderGroup = diff --git a/src/app/[locale]/dashboard/_components/user/forms/add-key-form.tsx b/src/app/[locale]/dashboard/_components/user/forms/add-key-form.tsx index f0f3bd00a..2d05b5960 100644 --- a/src/app/[locale]/dashboard/_components/user/forms/add-key-form.tsx +++ b/src/app/[locale]/dashboard/_components/user/forms/add-key-form.tsx @@ -20,6 +20,7 @@ import { Switch } from "@/components/ui/switch"; import { PROVIDER_GROUP } from "@/lib/constants/provider.constants"; import { useZodForm } from "@/lib/hooks/use-zod-form"; import { getErrorMessage } from "@/lib/utils/error-messages"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import { KeyFormSchema } from "@/lib/validation/schemas"; import type { KeyDialogUserContext } from "@/types/user"; @@ -125,10 +126,7 @@ export function AddKeyForm({ userId, user, isAdmin = false, onSuccess }: AddKeyF // 选择分组时,自动移除 default(当有多个分组时) const handleProviderGroupChange = useCallback( (newValue: string) => { - const groups = newValue - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + const groups = parseProviderGroups(newValue); if (groups.length > 1 && groups.includes(PROVIDER_GROUP.DEFAULT)) { const withoutDefault = groups.filter((g) => g !== PROVIDER_GROUP.DEFAULT); form.setValue("providerGroup", withoutDefault.join(",")); @@ -206,6 +204,7 @@ export function AddKeyForm({ userId, user, isAdmin = false, onSuccess }: AddKeyF : t("providerGroup.description") } suggestions={providerGroupSuggestions} + validateTag={() => true} onInvalidTag={(_tag, reason) => { const messages: Record = { empty: tUI("emptyTag"), diff --git a/src/app/[locale]/dashboard/_components/user/forms/edit-key-form.tsx b/src/app/[locale]/dashboard/_components/user/forms/edit-key-form.tsx index 03feeda54..1a2698447 100644 --- a/src/app/[locale]/dashboard/_components/user/forms/edit-key-form.tsx +++ b/src/app/[locale]/dashboard/_components/user/forms/edit-key-form.tsx @@ -34,6 +34,7 @@ import { Switch } from "@/components/ui/switch"; import { PROVIDER_GROUP } from "@/lib/constants/provider.constants"; import { useZodForm } from "@/lib/hooks/use-zod-form"; import { getErrorMessage } from "@/lib/utils/error-messages"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import { KeyFormSchema } from "@/lib/validation/schemas"; import type { KeyDialogUserContext } from "@/types/user"; @@ -183,10 +184,7 @@ export function EditKeyForm({ keyData, user, isAdmin = false, onSuccess }: EditK // 选择分组时,自动移除 default(当有多个分组时) const handleProviderGroupChange = useCallback( (newValue: string) => { - const groups = newValue - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + const groups = parseProviderGroups(newValue); if (groups.length > 1 && groups.includes(PROVIDER_GROUP.DEFAULT)) { const withoutDefault = groups.filter((g) => g !== PROVIDER_GROUP.DEFAULT); form.setValue("providerGroup", withoutDefault.join(",")); @@ -264,6 +262,7 @@ export function EditKeyForm({ keyData, user, isAdmin = false, onSuccess }: EditK : t("providerGroup.description") } suggestions={providerGroupSuggestions} + validateTag={() => true} onInvalidTag={(_tag, reason) => { const messages: Record = { empty: tUI("emptyTag"), diff --git a/src/app/[locale]/dashboard/_components/user/forms/key-edit-section.tsx b/src/app/[locale]/dashboard/_components/user/forms/key-edit-section.tsx index 5da461294..247bbb287 100644 --- a/src/app/[locale]/dashboard/_components/user/forms/key-edit-section.tsx +++ b/src/app/[locale]/dashboard/_components/user/forms/key-edit-section.tsx @@ -19,6 +19,7 @@ import { Switch } from "@/components/ui/switch"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { PROVIDER_GROUP } from "@/lib/constants/provider.constants"; import { cn } from "@/lib/utils"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import { type DailyResetMode, LimitRulePicker, type LimitType } from "./limit-rule-picker"; import { type LimitRuleDisplayItem, LimitRulesDisplay } from "./limit-rules-display"; import { ProviderGroupSelect } from "./provider-group-select"; @@ -125,10 +126,7 @@ function formatDateInput(date?: Date | null): string { } function normalizeGroupList(value?: string | null): string { - const groups = (value ?? "") - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + const groups = parseProviderGroups(value); if (groups.length === 0) return PROVIDER_GROUP.DEFAULT; return Array.from(new Set(groups)).sort().join(","); } @@ -292,7 +290,7 @@ export function KeyEditSection({ [userProviderGroup] ); const userGroups = useMemo( - () => (normalizedUserProviderGroup ? normalizedUserProviderGroup.split(",") : []), + () => parseProviderGroups(normalizedUserProviderGroup), [normalizedUserProviderGroup] ); const normalizedKeyProviderGroup = useMemo( @@ -301,7 +299,7 @@ export function KeyEditSection({ ); const keyGroupOptions = useMemo(() => { if (!normalizedKeyProviderGroup) return []; - return normalizedKeyProviderGroup.split(",").filter(Boolean); + return parseProviderGroups(normalizedKeyProviderGroup); }, [normalizedKeyProviderGroup]); const _extraKeyGroupOption = useMemo(() => { if (!normalizedKeyProviderGroup) return null; @@ -313,10 +311,7 @@ export function KeyEditSection({ // 普通用户选择分组时,自动移除 default const handleUserProviderGroupChange = useCallback( (newValue: string) => { - const groups = newValue - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + const groups = parseProviderGroups(newValue); // 如果有多个分组且包含 default,移除 default if (groups.length > 1 && groups.includes(PROVIDER_GROUP.DEFAULT)) { const withoutDefault = groups.filter((g) => g !== PROVIDER_GROUP.DEFAULT); @@ -507,6 +502,7 @@ export function KeyEditSection({ suggestions={userGroups} maxTags={userGroups.length + 1} maxTagLength={50} + validateTag={() => true} description={ translations.fields.providerGroup.selectHint || "选择此 Key 可使用的供应商分组" } diff --git a/src/app/[locale]/dashboard/_components/user/forms/provider-group-select.tsx b/src/app/[locale]/dashboard/_components/user/forms/provider-group-select.tsx index 37af648a6..e3bd4c813 100644 --- a/src/app/[locale]/dashboard/_components/user/forms/provider-group-select.tsx +++ b/src/app/[locale]/dashboard/_components/user/forms/provider-group-select.tsx @@ -6,6 +6,7 @@ import { getProviderGroupsWithCount } from "@/actions/providers"; import { TagInputField } from "@/components/form/form-field"; import type { TagInputSuggestion } from "@/components/ui/tag-input"; import { PROVIDER_GROUP } from "@/lib/constants/provider.constants"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; export interface ProviderGroupSelectProps { /** Comma-separated group tags. */ @@ -107,10 +108,7 @@ export function ProviderGroupSelect({ // 选择新分组后自动移除 "default" const handleChange = useCallback( (newValue: string) => { - const groupList = newValue - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + const groupList = parseProviderGroups(newValue); // 如果有多个分组且包含 default,移除 default if (groupList.length > 1 && groupList.includes(PROVIDER_GROUP.DEFAULT)) { const withoutDefault = groupList.filter((g) => g !== PROVIDER_GROUP.DEFAULT); @@ -131,6 +129,7 @@ export function ProviderGroupSelect({ maxTags={20} suggestions={suggestions} disabled={disabled} + validateTag={() => true} onInvalidTag={(_tag, reason) => { toast.error(getTranslation(translations, `tagInputErrors.${reason}`, reason)); }} diff --git a/src/app/[locale]/dashboard/_components/user/forms/user-form.tsx b/src/app/[locale]/dashboard/_components/user/forms/user-form.tsx index e8f8a5be7..e0f1ba852 100644 --- a/src/app/[locale]/dashboard/_components/user/forms/user-form.tsx +++ b/src/app/[locale]/dashboard/_components/user/forms/user-form.tsx @@ -217,6 +217,7 @@ export function UserForm({ user, onSuccess, currentUser }: UserFormProps) { placeholder={tForm("providerGroup.placeholder")} description={tForm("providerGroup.description")} suggestions={providerGroupSuggestions} + validateTag={() => true} onInvalidTag={(_tag, reason) => { const messages: Record = { empty: tUI("emptyTag"), diff --git a/src/app/[locale]/dashboard/_components/user/key-row-item.tsx b/src/app/[locale]/dashboard/_components/user/key-row-item.tsx index 2c7a7fd97..0e5a908b0 100644 --- a/src/app/[locale]/dashboard/_components/user/key-row-item.tsx +++ b/src/app/[locale]/dashboard/_components/user/key-row-item.tsx @@ -35,6 +35,7 @@ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip import { cn } from "@/lib/utils"; import { CURRENCY_CONFIG, type CurrencyCode, formatCurrency } from "@/lib/utils/currency"; import { formatDate } from "@/lib/utils/date-format"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import { formatTokenAmount } from "@/lib/utils/token"; import { type QuickRenewKey, QuickRenewKeyDialog } from "./forms/quick-renew-key-dialog"; import { KeyFullDisplayDialog } from "./key-full-display-dialog"; @@ -113,10 +114,7 @@ export interface KeyRowItemProps { const EXPIRING_SOON_MS = 72 * 60 * 60 * 1000; // 72小时 function splitGroups(value?: string | null): string[] { - return (value ?? "") - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + return parseProviderGroups(value); } function formatExpiry(expiresAt: string | null | undefined, locale: string): string { diff --git a/src/app/[locale]/dashboard/_components/user/user-key-table-row.tsx b/src/app/[locale]/dashboard/_components/user/user-key-table-row.tsx index cfbf40711..9d113916a 100644 --- a/src/app/[locale]/dashboard/_components/user/user-key-table-row.tsx +++ b/src/app/[locale]/dashboard/_components/user/user-key-table-row.tsx @@ -26,6 +26,7 @@ import { cn } from "@/lib/utils"; import { getContrastTextColor, getGroupColor } from "@/lib/utils/color"; import { getCurrencySymbol } from "@/lib/utils/currency"; import { formatDate } from "@/lib/utils/date-format"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import type { UserDisplay } from "@/types/user"; import { EditKeyDialog } from "./edit-key-dialog"; import { KeyRowItem } from "./key-row-item"; @@ -86,10 +87,7 @@ const EXPIRING_SOON_MS = 72 * 60 * 60 * 1000; // 72小时 const MAX_VISIBLE_GROUPS = 2; // 最多显示的分组数量 function splitGroups(value?: string | null): string[] { - return (value ?? "") - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + return parseProviderGroups(value); } function getExpiryStatus( diff --git a/src/app/[locale]/dashboard/_components/user/utils/provider-group.ts b/src/app/[locale]/dashboard/_components/user/utils/provider-group.ts index c1d5fde4f..90dc9b54b 100644 --- a/src/app/[locale]/dashboard/_components/user/utils/provider-group.ts +++ b/src/app/[locale]/dashboard/_components/user/utils/provider-group.ts @@ -1,23 +1 @@ -import { PROVIDER_GROUP } from "@/lib/constants/provider.constants"; - -/** - * Normalize provider group value to a consistent format. - * - Trims whitespace - * - Splits by comma and deduplicates - * - Sorts alphabetically - * - Returns DEFAULT if empty or invalid - */ -export function normalizeProviderGroup(value: unknown): string { - if (value === null || value === undefined) return PROVIDER_GROUP.DEFAULT; - if (typeof value !== "string") return PROVIDER_GROUP.DEFAULT; - const trimmed = value.trim(); - if (trimmed === "") return PROVIDER_GROUP.DEFAULT; - - const groups = trimmed - .split(",") - .map((g) => g.trim()) - .filter(Boolean); - if (groups.length === 0) return PROVIDER_GROUP.DEFAULT; - - return Array.from(new Set(groups)).sort().join(","); -} +export { normalizeProviderGroup } from "@/lib/utils/provider-group"; diff --git a/src/app/[locale]/dashboard/logs/_components/provider-chain-popover.tsx b/src/app/[locale]/dashboard/logs/_components/provider-chain-popover.tsx index 633e617ba..9bb9a2ca9 100644 --- a/src/app/[locale]/dashboard/logs/_components/provider-chain-popover.tsx +++ b/src/app/[locale]/dashboard/logs/_components/provider-chain-popover.tsx @@ -24,6 +24,7 @@ import { isActualRequest, isHedgeRace, } from "@/lib/utils/provider-chain-formatter"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import type { ProviderChainItem } from "@/types/message"; import { getFake200ReasonKey } from "./fake200-reason"; @@ -37,16 +38,7 @@ interface ProviderChainPopoverProps { } function parseGroupTags(groupTag?: string | null): string[] { - if (!groupTag) return []; - const seen = new Set(); - const groups: string[] = []; - for (const raw of groupTag.split(",")) { - const trimmed = raw.trim(); - if (!trimmed || seen.has(trimmed)) continue; - seen.add(trimmed); - groups.push(trimmed); - } - return groups; + return Array.from(new Set(parseProviderGroups(groupTag))); } /** diff --git a/src/app/[locale]/dashboard/users/users-page-client.tsx b/src/app/[locale]/dashboard/users/users-page-client.tsx index b38e335a7..971cc51e7 100644 --- a/src/app/[locale]/dashboard/users/users-page-client.tsx +++ b/src/app/[locale]/dashboard/users/users-page-client.tsx @@ -27,6 +27,7 @@ import { clearUsageCache } from "@/lib/dashboard/user-limit-usage-cache"; import { loadUserUsagePagesSequentially } from "@/lib/dashboard/user-usage-loader"; import { useDebounce } from "@/lib/hooks/use-debounce"; import type { CurrencyCode } from "@/lib/utils/currency"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import type { User, UserDisplay } from "@/types/user"; import { AddKeyDialog } from "../_components/user/add-key-dialog"; import { BatchEditDialog } from "../_components/user/batch-edit/batch-edit-dialog"; @@ -38,10 +39,7 @@ import { UserManagementTable } from "../_components/user/user-management-table"; * This matches the server-side providerGroup handling in provider-selector.ts */ function splitTags(value?: string | null): string[] { - return (value ?? "") - .split(",") - .map((t) => t.trim()) - .filter(Boolean); + return parseProviderGroups(value); } interface UsersPageClientProps { diff --git a/src/app/[locale]/settings/providers/_components/batch-edit/analyze-batch-settings.ts b/src/app/[locale]/settings/providers/_components/batch-edit/analyze-batch-settings.ts index f0a0f20e4..0dcdbebeb 100644 --- a/src/app/[locale]/settings/providers/_components/batch-edit/analyze-batch-settings.ts +++ b/src/app/[locale]/settings/providers/_components/batch-edit/analyze-batch-settings.ts @@ -1,3 +1,4 @@ +import { parseProviderGroups } from "@/lib/utils/provider-group"; import type { CacheTtlPreference } from "@/types/cache"; import type { AnthropicAdaptiveThinkingConfig, @@ -116,14 +117,7 @@ export function analyzeBatchProviderSettings(providers: ProviderDisplay[]): Batc priority: analyzeField(providers, (p) => p.priority), weight: analyzeField(providers, (p) => p.weight), costMultiplier: analyzeField(providers, (p) => p.costMultiplier), - groupTag: analyzeField(providers, (p) => - p.groupTag - ? p.groupTag - .split(",") - .map((t) => t.trim()) - .filter(Boolean) - : [] - ), + groupTag: analyzeField(providers, (p) => parseProviderGroups(p.groupTag)), preserveClientIp: analyzeField(providers, (p) => p.preserveClientIp), modelRedirects: analyzeField(providers, (p) => p.modelRedirects ?? {}), allowedModels: analyzeField(providers, (p) => p.allowedModels ?? []), diff --git a/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-toolbar.tsx b/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-toolbar.tsx index 40ee6c928..74710435c 100644 --- a/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-toolbar.tsx +++ b/src/app/[locale]/settings/providers/_components/batch-edit/provider-batch-toolbar.tsx @@ -12,6 +12,7 @@ import { DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import { cn } from "@/lib/utils"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import type { ProviderDisplay, ProviderType } from "@/types/provider"; export interface ProviderBatchToolbarProps { @@ -59,10 +60,7 @@ export function ProviderBatchToolbar({ const groupMap = new Map(); for (const p of providers) { if (p.groupTag) { - const tags = p.groupTag - .split(",") - .map((tag) => tag.trim()) - .filter(Boolean); + const tags = parseProviderGroups(p.groupTag); for (const tag of tags) { groupMap.set(tag, (groupMap.get(tag) ?? 0) + 1); } diff --git a/src/app/[locale]/settings/providers/_components/forms/provider-form.legacy.tsx b/src/app/[locale]/settings/providers/_components/forms/provider-form.legacy.tsx index ee9a97cec..916a23c9f 100644 --- a/src/app/[locale]/settings/providers/_components/forms/provider-form.legacy.tsx +++ b/src/app/[locale]/settings/providers/_components/forms/provider-form.legacy.tsx @@ -34,6 +34,7 @@ import { TagInput } from "@/components/ui/tag-input"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { PROVIDER_DEFAULTS, PROVIDER_TIMEOUT_DEFAULTS } from "@/lib/constants/provider.constants"; import { getProviderTypeConfig } from "@/lib/provider-type-utils"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import { extractBaseUrl, isValidUrl, @@ -175,14 +176,7 @@ export function ProviderForm({ const [costMultiplier, setCostMultiplier] = useState( sourceProvider?.costMultiplier ?? 1.0 ); - const [groupTag, setGroupTag] = useState( - sourceProvider?.groupTag - ? sourceProvider.groupTag - .split(",") - .map((t) => t.trim()) - .filter(Boolean) - : [] - ); + const [groupTag, setGroupTag] = useState(parseProviderGroups(sourceProvider?.groupTag)); const [groupSuggestions, setGroupSuggestions] = useState([]); const [limit5hUsd, setLimit5hUsd] = useState(sourceProvider?.limit5hUsd ?? null); const [limitDailyUsd, setLimitDailyUsd] = useState( @@ -892,6 +886,7 @@ export function ProviderForm({ disabled={isPending} maxTagLength={50} suggestions={groupSuggestions} + validateTag={() => true} onInvalidTag={(_tag, reason) => { const messages: Record = { empty: tUI("emptyTag"), diff --git a/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context.tsx b/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context.tsx index 2999e8436..0adbd07bf 100644 --- a/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context.tsx +++ b/src/app/[locale]/settings/providers/_components/forms/provider-form/provider-form-context.tsx @@ -10,6 +10,7 @@ import { useReducer, useRef, } from "react"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import type { ProviderDisplay, ProviderType } from "@/types/provider"; import { analyzeBatchProviderSettings } from "../../batch-edit/analyze-batch-settings"; import type { @@ -353,12 +354,7 @@ export function createInitialState( }, routing: { providerType: sourceProvider?.providerType ?? preset?.providerType ?? "claude", - groupTag: sourceProvider?.groupTag - ? sourceProvider.groupTag - .split(",") - .map((t) => t.trim()) - .filter(Boolean) - : [], + groupTag: parseProviderGroups(sourceProvider?.groupTag), preserveClientIp: sourceProvider?.preserveClientIp ?? false, modelRedirects: sourceProvider?.modelRedirects ?? {}, allowedModels: sourceProvider?.allowedModels ?? [], diff --git a/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/routing-section.tsx b/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/routing-section.tsx index 0ccd076aa..c781c4e74 100644 --- a/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/routing-section.tsx +++ b/src/app/[locale]/settings/providers/_components/forms/provider-form/sections/routing-section.tsx @@ -160,6 +160,7 @@ export function RoutingSection({ subSectionRefs }: RoutingSectionProps) { disabled={state.ui.isPending} maxTagLength={50} suggestions={groupSuggestions} + validateTag={() => true} onInvalidTag={(_tag, reason) => { const messages: Record = { empty: tUI("emptyTag"), diff --git a/src/app/[locale]/settings/providers/_components/provider-manager.tsx b/src/app/[locale]/settings/providers/_components/provider-manager.tsx index 6e2559973..4a8a1d8f4 100644 --- a/src/app/[locale]/settings/providers/_components/provider-manager.tsx +++ b/src/app/[locale]/settings/providers/_components/provider-manager.tsx @@ -17,6 +17,7 @@ import { Skeleton } from "@/components/ui/skeleton"; import { Switch } from "@/components/ui/switch"; import { useDebounce } from "@/lib/hooks/use-debounce"; import type { CurrencyCode } from "@/lib/utils/currency"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import type { ProviderDisplay, ProviderStatisticsMap, ProviderType } from "@/types/provider"; import type { User } from "@/types/user"; import { @@ -144,10 +145,7 @@ export function ProviderManager({ const groups = new Set(); let hasDefaultGroup = false; providers.forEach((p) => { - const tags = p.groupTag - ?.split(",") - .map((t) => t.trim()) - .filter(Boolean); + const tags = parseProviderGroups(p.groupTag); if (!tags || tags.length === 0) { hasDefaultGroup = true; } else { @@ -172,10 +170,7 @@ export function ProviderManager({ // User's assigned groups (for non-admin users) const userGroups = useMemo(() => { if (!currentUser?.providerGroup) return []; - return currentUser.providerGroup - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + return parseProviderGroups(currentUser.providerGroup); }, [currentUser?.providerGroup]); // Check if current user is admin @@ -192,10 +187,7 @@ export function ProviderManager({ (p) => p.name.toLowerCase().includes(term) || p.url.toLowerCase().includes(term) || - p.groupTag - ?.split(",") - .map((t) => t.trim().toLowerCase()) - .some((tag) => tag.includes(term)) + parseProviderGroups(p.groupTag).some((tag) => tag.toLowerCase().includes(term)) ); } @@ -212,11 +204,7 @@ export function ProviderManager({ // Filter by groups if (groupFilter.length > 0) { result = result.filter((p) => { - const providerGroups = - p.groupTag - ?.split(",") - .map((t) => t.trim()) - .filter(Boolean) || []; + const providerGroups = parseProviderGroups(p.groupTag); // If provider has no groups and "default" is selected, include it if (providerGroups.length === 0 && groupFilter.includes("default")) { @@ -341,11 +329,7 @@ export function ProviderManager({ setSelectedProviderIds((prev) => { const next = new Set(prev); for (const p of filteredProviders) { - const tags = - p.groupTag - ?.split(",") - .map((tag) => tag.trim()) - .filter(Boolean) ?? []; + const tags = parseProviderGroups(p.groupTag); if (tags.includes(group) || (group === "default" && tags.length === 0)) { next.add(p.id); } diff --git a/src/app/[locale]/settings/providers/_components/provider-rich-list-item.tsx b/src/app/[locale]/settings/providers/_components/provider-rich-list-item.tsx index 15234db79..0946ad9f7 100644 --- a/src/app/[locale]/settings/providers/_components/provider-rich-list-item.tsx +++ b/src/app/[locale]/settings/providers/_components/provider-rich-list-item.tsx @@ -69,6 +69,7 @@ import { copyToClipboard, isClipboardSupported } from "@/lib/utils/clipboard"; import { getContrastTextColor, getGroupColor } from "@/lib/utils/color"; import type { CurrencyCode } from "@/lib/utils/currency"; import { formatCurrency } from "@/lib/utils/currency"; +import { normalizeProviderGroupTag, parseProviderGroups } from "@/lib/utils/provider-group"; import type { ProviderDisplay, ProviderStatistics, ProviderVendor } from "@/types/provider"; import type { User } from "@/types/user"; import { ProviderForm } from "./forms/provider-form"; @@ -445,16 +446,11 @@ function ProviderRichListItemInner({ const handleSaveWeight = createSaveHandler("weight"); const handleSaveCostMultiplier = createSaveHandler("cost_multiplier"); - const providerGroups = provider.groupTag - ? provider.groupTag - .split(",") - .map((t) => t.trim()) - .filter(Boolean) - : []; + const providerGroups = parseProviderGroups(provider.groupTag); const handleSaveGroups = async (groups: string[]): Promise => { try { - const groupTag = groups.length > 0 ? groups.join(",") : null; + const groupTag = normalizeProviderGroupTag(groups.join(",")); const res = await editProvider(provider.id, { group_tag: groupTag }); if (res.ok) { toast.success(tInline("saveSuccess")); diff --git a/src/app/v1/_lib/proxy/provider-selector.ts b/src/app/v1/_lib/proxy/provider-selector.ts index 52304e0e5..27e6a11d0 100644 --- a/src/app/v1/_lib/proxy/provider-selector.ts +++ b/src/app/v1/_lib/proxy/provider-selector.ts @@ -3,6 +3,7 @@ import { PROVIDER_GROUP } from "@/lib/constants/provider.constants"; import { logger } from "@/lib/logger"; import { RateLimitService } from "@/lib/rate-limit"; import { SessionManager } from "@/lib/session-manager"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import { isProviderActiveNow } from "@/lib/utils/provider-schedule"; import { resolveSystemTimezone } from "@/lib/utils/timezone"; import { isVendorTypeCircuitOpen } from "@/lib/vendor-type-circuit-breaker"; @@ -48,10 +49,7 @@ async function getVerboseProviderErrorCached(): Promise { * @returns 清理后的分组数组(去空格、去空项) */ function parseGroupString(groupString: string): string[] { - return groupString - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + return parseProviderGroups(groupString); } /** diff --git a/src/components/form/form-field.tsx b/src/components/form/form-field.tsx index 5c07f609e..c405f15c0 100644 --- a/src/components/form/form-field.tsx +++ b/src/components/form/form-field.tsx @@ -5,6 +5,7 @@ import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { TagInput } from "@/components/ui/tag-input"; import { cn } from "@/lib/utils"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; /** * 表单字段配置 @@ -175,12 +176,7 @@ export function TagInputField({ const fieldId = tagInputProps.id || `field-${autoId}`; // 将字符串转换为数组 - const tagsArray = value - ? value - .split(",") - .map((t) => t.trim()) - .filter(Boolean) - : []; + const tagsArray = parseProviderGroups(value); // 将数组转换回字符串 const handleChange = (tags: string[]) => { diff --git a/src/components/ui/__tests__/provider-group-tag-input.test.tsx b/src/components/ui/__tests__/provider-group-tag-input.test.tsx new file mode 100644 index 000000000..a2a73a2e7 --- /dev/null +++ b/src/components/ui/__tests__/provider-group-tag-input.test.tsx @@ -0,0 +1,119 @@ +/** + * @vitest-environment happy-dom + */ + +import type { ReactNode } from "react"; +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { ProviderGroupSelect } from "@/app/[locale]/dashboard/_components/user/forms/provider-group-select"; +import { TagInput } from "@/components/ui/tag-input"; + +const providerActionsMocks = vi.hoisted(() => ({ + getProviderGroupsWithCount: vi.fn(async () => ({ ok: true, data: [] })), +})); + +const sonnerMocks = vi.hoisted(() => ({ + toast: { + error: vi.fn(), + }, +})); + +vi.mock("@/actions/providers", () => providerActionsMocks); +vi.mock("sonner", () => sonnerMocks); + +function render(node: ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + + act(() => { + root.render(node); + }); + + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +async function typeAndSubmit(input: HTMLInputElement, value: string) { + await act(async () => { + input.focus(); + const valueSetter = Object.getOwnPropertyDescriptor(HTMLInputElement.prototype, "value")?.set; + valueSetter?.call(input, value); + input.dispatchEvent(new Event("input", { bubbles: true })); + input.dispatchEvent(new Event("change", { bubbles: true })); + await new Promise((resolve) => setTimeout(resolve, 0)); + }); + + await act(async () => { + input.dispatchEvent(new KeyboardEvent("keydown", { key: "Enter", bubbles: true })); + await new Promise((resolve) => setTimeout(resolve, 0)); + }); +} + +afterEach(() => { + while (document.body.firstChild) { + document.body.removeChild(document.body.firstChild); + } +}); + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe("provider-group tag inputs", () => { + test("默认 TagInput 仍应拒绝中文标签", async () => { + const onChange = vi.fn(); + const onInvalidTag = vi.fn(); + const { container, unmount } = render( + + ); + + const input = container.querySelector("input"); + expect(input).toBeInstanceOf(HTMLInputElement); + + await typeAndSubmit(input as HTMLInputElement, "中文分组"); + + expect(onChange).not.toHaveBeenCalled(); + expect(onInvalidTag).toHaveBeenCalledWith("中文分组", "invalid_format"); + + unmount(); + }); + + test("ProviderGroupSelect 应允许输入中文分组", async () => { + const onChange = vi.fn(); + const translations = { + label: "Provider group", + placeholder: "Enter group", + description: "desc", + errors: { + loadFailed: "Load failed", + }, + tagInputErrors: { + empty: "empty", + duplicate: "duplicate", + too_long: "too long", + invalid_format: "invalid format", + max_tags: "max tags", + }, + }; + const { container, unmount } = render( + + ); + + const input = container.querySelector("input"); + expect(input).toBeInstanceOf(HTMLInputElement); + + await typeAndSubmit(input as HTMLInputElement, "中文分组"); + + expect(onChange).toHaveBeenCalledWith("中文分组"); + expect(sonnerMocks.toast.error).not.toHaveBeenCalled(); + + unmount(); + }); +}); diff --git a/src/lib/provider-patch-contract.ts b/src/lib/provider-patch-contract.ts index e47d50c05..1d24bdf69 100644 --- a/src/lib/provider-patch-contract.ts +++ b/src/lib/provider-patch-contract.ts @@ -1,3 +1,4 @@ +import { normalizeProviderGroupTag } from "@/lib/utils/provider-group"; import type { ProviderBatchApplyUpdates, ProviderBatchPatch, @@ -367,6 +368,11 @@ function normalizePatchField( return createInvalidPatchShapeError(field, "set mode value is invalid for this field"); } + if (field === "group_tag") { + const normalizedGroupTag = normalizeProviderGroupTag(input.set) ?? ""; + return { ok: true, data: { mode: "set", value: normalizedGroupTag as T } }; + } + return { ok: true, data: { mode: "set", value: input.set as T } }; } diff --git a/src/lib/request-filter-engine.ts b/src/lib/request-filter-engine.ts index 3360e773f..336535f48 100644 --- a/src/lib/request-filter-engine.ts +++ b/src/lib/request-filter-engine.ts @@ -9,6 +9,7 @@ import type { RemoveOp, SetOp, } from "@/lib/request-filter-types"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import type { RequestFilter, RequestFilterAction, @@ -484,7 +485,7 @@ export class RequestFilterEngine { let providerTagsSet: Set | null = null; if (this.hasGroupBasedFilters) { const providerGroupTag = session.provider.groupTag; - providerTagsSet = new Set(providerGroupTag?.split(",").map((t) => t.trim()) ?? []); + providerTagsSet = new Set(parseProviderGroups(providerGroupTag)); } for (const filter of this.providerGuardFilters) { @@ -574,7 +575,7 @@ export class RequestFilterEngine { let providerTagsSet: Set | null = null; if (this.hasGroupBasedFinalFilters) { const providerGroupTag = session.provider.groupTag; - providerTagsSet = new Set(providerGroupTag?.split(",").map((t) => t.trim()) ?? []); + providerTagsSet = new Set(parseProviderGroups(providerGroupTag)); } for (const filter of this.providerFinalFilters) { diff --git a/src/lib/utils/provider-group.test.ts b/src/lib/utils/provider-group.test.ts new file mode 100644 index 000000000..38efd5ca3 --- /dev/null +++ b/src/lib/utils/provider-group.test.ts @@ -0,0 +1,24 @@ +import { describe, expect, test } from "vitest"; +import { + normalizeProviderGroup, + normalizeProviderGroupTag, + parseProviderGroups, +} from "./provider-group"; + +describe("provider-group utils", () => { + test("parseProviderGroups 应支持中文逗号和换行作为分隔符", () => { + expect(parseProviderGroups("研发,渠道\n直营")).toEqual(["研发", "渠道", "直营"]); + }); + + test("normalizeProviderGroup 应在支持中文标签的同时做去重和排序", () => { + expect(normalizeProviderGroup("研发,渠道\n研发")).toBe("渠道,研发"); + }); + + test("normalizeProviderGroupTag 应支持中文标签并保留原始顺序", () => { + expect(normalizeProviderGroupTag("直营,华北\n直营")).toBe("直营,华北"); + }); + + test("normalizeProviderGroupTag 在空输入时应返回 null", () => { + expect(normalizeProviderGroupTag(" , \n ")).toBeNull(); + }); +}); diff --git a/src/lib/utils/provider-group.ts b/src/lib/utils/provider-group.ts index 595f8bb0e..5e053f09e 100644 --- a/src/lib/utils/provider-group.ts +++ b/src/lib/utils/provider-group.ts @@ -1,5 +1,18 @@ import { PROVIDER_GROUP } from "@/lib/constants/provider.constants"; +const PROVIDER_GROUP_SEPARATOR = /[,,\n\r]+/; + +function splitProviderGroupValue(value: unknown): string[] { + if (typeof value !== "string") { + return []; + } + + return value + .split(PROVIDER_GROUP_SEPARATOR) + .map((group) => group.trim()) + .filter(Boolean); +} + /** * Normalize provider group value to a consistent format * - Returns "default" for null/undefined/empty values @@ -7,26 +20,29 @@ import { PROVIDER_GROUP } from "@/lib/constants/provider.constants"; * - Sorts groups alphabetically for consistency */ export function normalizeProviderGroup(value: unknown): string { - if (value === null || value === undefined) return PROVIDER_GROUP.DEFAULT; - if (typeof value !== "string") return PROVIDER_GROUP.DEFAULT; - const trimmed = value.trim(); - if (trimmed === "") return PROVIDER_GROUP.DEFAULT; - - const groups = trimmed - .split(",") - .map((g) => g.trim()) - .filter(Boolean); + const groups = splitProviderGroupValue(value); if (groups.length === 0) return PROVIDER_GROUP.DEFAULT; return Array.from(new Set(groups)).sort().join(","); } /** - * Parse a comma-separated provider group string into an array + * Normalize provider group tag string for provider.groupTag storage. + * - Supports English comma, Chinese comma and line breaks as separators + * - Trims whitespace and removes duplicates while preserving input order + * - Returns null for null/undefined/empty values */ -export function parseProviderGroups(value: string): string[] { - return value - .split(",") - .map((g) => g.trim()) - .filter(Boolean); +export function normalizeProviderGroupTag(value: unknown): string | null { + const groups = splitProviderGroupValue(value); + if (groups.length === 0) return null; + + return Array.from(new Set(groups)).join(","); +} + +/** + * Parse a provider group / groupTag string into an array. + * Supports English comma, Chinese comma and line breaks as separators. + */ +export function parseProviderGroups(value: unknown): string[] { + return splitProviderGroupValue(value); } diff --git a/src/repository/leaderboard.ts b/src/repository/leaderboard.ts index fe0cf81bc..2b2442a67 100644 --- a/src/repository/leaderboard.ts +++ b/src/repository/leaderboard.ts @@ -274,7 +274,7 @@ async function findLeaderboardWithTimezone( if (normalizedGroups.length > 0) { const groupConditions = normalizedGroups.map( (group) => - sql`${group} = ANY(regexp_split_to_array(coalesce(${users.providerGroup}, ''), '\\s*,\\s*'))` + sql`${group} = ANY(regexp_split_to_array(coalesce(${users.providerGroup}, ''), '\\s*[,,]+\\s*'))` ); groupFilterCondition = sql`(${sql.join(groupConditions, sql` OR `)})`; } diff --git a/src/repository/provider.ts b/src/repository/provider.ts index 7395fb345..fdd395392 100644 --- a/src/repository/provider.ts +++ b/src/repository/provider.ts @@ -7,6 +7,7 @@ import { getCachedProviders } from "@/lib/cache/provider-cache"; import { PROVIDER_TIMEOUT_DEFAULTS } from "@/lib/constants/provider.constants"; import { resetEndpointCircuit } from "@/lib/endpoint-circuit-breaker"; import { logger } from "@/lib/logger"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import { resolveSystemTimezone } from "@/lib/utils/timezone"; import type { AnthropicAdaptiveThinkingConfig, @@ -1491,12 +1492,7 @@ export async function getDistinctProviderGroups(): Promise { const allTags = result .map((r) => r.groupTag) .filter((tag): tag is string => tag !== null) - .flatMap((tag) => - tag - .split(",") - .map((t) => t.trim()) - .filter(Boolean) - ); + .flatMap((tag) => parseProviderGroups(tag)); return [...new Set(allTags)].sort(); } diff --git a/src/repository/user.ts b/src/repository/user.ts index 7a34d3cd7..eb4836df8 100644 --- a/src/repository/user.ts +++ b/src/repository/user.ts @@ -4,6 +4,7 @@ import { and, asc, eq, isNull, type SQL, sql } from "drizzle-orm"; import { db } from "@/drizzle/db"; import { keys as keysTable, users } from "@/drizzle/schema"; import { cacheUser, invalidateCachedUser } from "@/lib/security/api-key-auth-cache"; +import { parseProviderGroups } from "@/lib/utils/provider-group"; import type { CreateUserData, UpdateUserData, User } from "@/types/user"; import { toUser } from "./_shared/transformers"; @@ -245,7 +246,7 @@ export async function findUserListBatch( if (trimmedGroups.length > 0) { const groupConditions = trimmedGroups.map( (group) => - sql`${group} = ANY(regexp_split_to_array(coalesce(${users.providerGroup}, ''), '\\s*,\\s*'))` + sql`${group} = ANY(regexp_split_to_array(coalesce(${users.providerGroup}, ''), '\\s*[,,]+\\s*'))` ); keyGroupFilterCondition = sql`(${sql.join(groupConditions, sql` OR `)})`; } @@ -605,10 +606,7 @@ export async function getAllUserProviderGroups(): Promise { const allGroups = new Set(); for (const row of result) { - const groups = row.providerGroup - ?.split(",") - .map((group) => group.trim()) - .filter(Boolean); + const groups = parseProviderGroups(row.providerGroup); if (!groups || groups.length === 0) continue; for (const group of groups) { allGroups.add(group); From e52fb6af5362a15acb80c8792ab700517108cec5 Mon Sep 17 00:00:00 2001 From: KevinShiCN <35432622+KevinShiCN@users.noreply.github.com> Date: Sat, 21 Mar 2026 17:50:31 +0800 Subject: [PATCH 03/11] feat(i18n): add authenticated proxy format hints to provider proxy config (#954) Update proxy URL placeholder and description across all 5 languages (en, zh-CN, zh-TW, ja, ru) to include: - Complete list of supported protocols (http, https, socks5, socks4) - Authenticated proxy format: http://user:password@host:port - URL encoding reminder for special characters in passwords (e.g. # as %23) Previously the description only showed "Supported formats:" without listing them, and the placeholder had no authentication example. --- messages/en/settings/providers/form/sections.json | 4 ++-- messages/ja/settings/providers/form/sections.json | 4 ++-- messages/ru/settings/providers/form/sections.json | 4 ++-- messages/zh-CN/settings/providers/form/sections.json | 4 ++-- messages/zh-TW/settings/providers/form/sections.json | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/messages/en/settings/providers/form/sections.json b/messages/en/settings/providers/form/sections.json index b26aa58f1..f05a1e89e 100644 --- a/messages/en/settings/providers/form/sections.json +++ b/messages/en/settings/providers/form/sections.json @@ -101,10 +101,10 @@ }, "title": "Proxy", "url": { - "formats": "Supported formats:", + "formats": "Supports http://, https://, socks5://, socks4:// protocols. For authenticated proxies use http://user:password@host:port (URL-encode special characters in password, e.g. # as %23)", "label": "Proxy URL", "optional": "(optional)", - "placeholder": "e.g. http://proxy.example.com:8080 or socks5://127.0.0.1:1080" + "placeholder": "e.g. http://proxy.example.com:8080 or http://user:pass@proxy:8080" } }, "rateLimit": { diff --git a/messages/ja/settings/providers/form/sections.json b/messages/ja/settings/providers/form/sections.json index 6f93c7121..1a07d581c 100644 --- a/messages/ja/settings/providers/form/sections.json +++ b/messages/ja/settings/providers/form/sections.json @@ -101,10 +101,10 @@ }, "title": "プロキシ設定", "url": { - "formats": "対応フォーマット:", + "formats": "http://、https://、socks5://、socks4:// プロトコルに対応。認証が必要な場合は http://user:password@host:port 形式を使用してください(パスワード内の特殊文字は URL エンコードが必要です。例: # → %23)", "label": "プロキシ URL", "optional": "(任意)", - "placeholder": "例: http://proxy.example.com:8080 または socks5://127.0.0.1:1080" + "placeholder": "例: http://proxy.example.com:8080 または http://user:pass@proxy:8080" } }, "rateLimit": { diff --git a/messages/ru/settings/providers/form/sections.json b/messages/ru/settings/providers/form/sections.json index ee544ae18..001571216 100644 --- a/messages/ru/settings/providers/form/sections.json +++ b/messages/ru/settings/providers/form/sections.json @@ -101,10 +101,10 @@ }, "title": "Прокси", "url": { - "formats": "Поддерживаемые форматы:", + "formats": "Поддерживаются протоколы http://, https://, socks5://, socks4://. Для прокси с аутентификацией используйте формат http://user:password@host:port (специальные символы в пароле кодируйте URL-кодировкой, например # как %23)", "label": "URL прокси", "optional": "(необязательно)", - "placeholder": "например: http://proxy.example.com:8080 или socks5://127.0.0.1:1080" + "placeholder": "например: http://proxy.example.com:8080 или http://user:pass@proxy:8080" } }, "rateLimit": { diff --git a/messages/zh-CN/settings/providers/form/sections.json b/messages/zh-CN/settings/providers/form/sections.json index 9303a054e..865bb4080 100644 --- a/messages/zh-CN/settings/providers/form/sections.json +++ b/messages/zh-CN/settings/providers/form/sections.json @@ -351,8 +351,8 @@ "url": { "label": "代理地址", "optional": "(可选)", - "placeholder": "例如: http://proxy.example.com:8080 或 socks5://127.0.0.1:1080", - "formats": "支持格式:" + "placeholder": "例如: http://proxy.example.com:8080 或 http://user:pass@proxy:8080", + "formats": "支持 http://、https://、socks5://、socks4:// 协议。需要认证时使用 http://user:password@host:port 格式(密码中的特殊字符需 URL 编码,如 # 编码为 %23)" }, "fallback": { "label": "代理失败时降级到直连", diff --git a/messages/zh-TW/settings/providers/form/sections.json b/messages/zh-TW/settings/providers/form/sections.json index 4493edc42..a85b642a5 100644 --- a/messages/zh-TW/settings/providers/form/sections.json +++ b/messages/zh-TW/settings/providers/form/sections.json @@ -101,10 +101,10 @@ }, "title": "代理設定", "url": { - "formats": "支援格式:", + "formats": "支援 http://、https://、socks5://、socks4:// 協定。需要認證時使用 http://user:password@host:port 格式(密碼中的特殊字元需 URL 編碼,如 # 編碼為 %23)", "label": "代理位址", "optional": "(選填)", - "placeholder": "例如:http://proxy.example.com:8080 或 socks5://127.0.0.1:1080" + "placeholder": "例如:http://proxy.example.com:8080 或 http://user:pass@proxy:8080" } }, "rateLimit": { From 650dbc61a8510ff78c8906ee19b18e823f251930 Mon Sep 17 00:00:00 2001 From: Ding <44717411+ding113@users.noreply.github.com> Date: Sat, 21 Mar 2026 20:47:51 +0800 Subject: [PATCH 04/11] fix: stabilize usage log virtual scrolling (#959) * fix: stabilize usage log virtual scrolling * fix: address usage log PR review comments * fix: address follow-up PR review issues * fix: polish my-usage log table review follow-ups * fix: stabilize my-usage load-more callback --- src/actions/my-usage.ts | 108 ++--- .../virtualized-logs-table.test.tsx | 38 +- .../_components/virtualized-logs-table.tsx | 82 ++-- .../_components/usage-logs-section.test.tsx | 109 +++++ .../_components/usage-logs-section.tsx | 276 +++++------ .../_components/usage-logs-table.test.tsx | 125 +++++ .../my-usage/_components/usage-logs-table.tsx | 437 +++++++++++------- src/app/[locale]/my-usage/page.tsx | 23 +- src/app/api/actions/[...route]/route.ts | 32 +- src/hooks/use-virtualized-infinite-list.ts | 81 ++++ src/repository/usage-logs.ts | 298 ++++++------ tests/api/api-actions-integrity.test.ts | 53 ++- tests/api/my-usage-readonly.test.ts | 4 +- .../actions/my-usage-date-range-dst.test.ts | 36 +- .../usage-logs-sessionid-filter.test.ts | 154 ++++++ 15 files changed, 1220 insertions(+), 636 deletions(-) create mode 100644 src/app/[locale]/my-usage/_components/usage-logs-section.test.tsx create mode 100644 src/app/[locale]/my-usage/_components/usage-logs-table.test.tsx create mode 100644 src/hooks/use-virtualized-infinite-list.ts diff --git a/src/actions/my-usage.ts b/src/actions/my-usage.ts index cc7ee0432..0a37de5d4 100644 --- a/src/actions/my-usage.ts +++ b/src/actions/my-usage.ts @@ -16,9 +16,10 @@ import { LEDGER_BILLING_CONDITION } from "@/repository/_shared/ledger-conditions import { EXCLUDE_WARMUP_CONDITION } from "@/repository/_shared/message-request-conditions"; import { getSystemSettings } from "@/repository/system-config"; import { - findUsageLogsForKeySlim, + findUsageLogsForKeyBatch, getDistinctEndpointsForKey, getDistinctModelsForKey, + type UsageLogSlimBatchResult, type UsageLogSummary, } from "@/repository/usage-logs"; import type { BillingModelSource } from "@/types/system-config"; @@ -168,11 +169,10 @@ export interface MyUsageLogEntry { cacheTtlApplied: string | null; } -export interface MyUsageLogsResult { +export interface MyUsageLogsBatchResult { logs: MyUsageLogEntry[]; - total: number; - page: number; - pageSize: number; + nextCursor: { createdAt: string; id: number } | null; + hasMore: boolean; currencyCode: CurrencyCode; billingModelSource: BillingModelSource; } @@ -469,7 +469,7 @@ export async function getMyTodayStats(): Promise> { } } -export interface MyUsageLogsFilters { +export interface MyUsageLogsBatchFilters { startDate?: string; endDate?: string; /** Session ID(精确匹配;空字符串/空白视为不筛选) */ @@ -479,30 +479,61 @@ export interface MyUsageLogsFilters { excludeStatusCode200?: boolean; endpoint?: string; minRetryCount?: number; - page?: number; - pageSize?: number; + cursor?: { createdAt: string; id: number }; + limit?: number; } -export async function getMyUsageLogs( - filters: MyUsageLogsFilters = {} -): Promise> { +function mapMyUsageLogEntries( + result: Pick, + billingModelSource: BillingModelSource +): MyUsageLogEntry[] { + return result.logs.map((log) => { + const modelRedirect = + log.originalModel && log.model && log.originalModel !== log.model + ? `${log.originalModel} → ${log.model}` + : null; + + const billingModel = + (billingModelSource === "original" ? log.originalModel : log.model) ?? null; + + return { + id: log.id, + createdAt: log.createdAt, + model: log.model, + billingModel, + anthropicEffort: log.anthropicEffort ?? null, + modelRedirect, + inputTokens: log.inputTokens ?? 0, + outputTokens: log.outputTokens ?? 0, + cost: log.costUsd ? Number(log.costUsd) : 0, + statusCode: log.statusCode, + duration: log.durationMs, + endpoint: log.endpoint, + cacheCreationInputTokens: log.cacheCreationInputTokens ?? null, + cacheReadInputTokens: log.cacheReadInputTokens ?? null, + cacheCreation5mInputTokens: log.cacheCreation5mInputTokens ?? null, + cacheCreation1hInputTokens: log.cacheCreation1hInputTokens ?? null, + cacheTtlApplied: log.cacheTtlApplied ?? null, + }; + }); +} + +export async function getMyUsageLogsBatch( + filters: MyUsageLogsBatchFilters = {} +): Promise> { try { const session = await getSession({ allowReadOnlyAccess: true }); if (!session) return { ok: false, error: "Unauthorized" }; const settings = await getSystemSettings(); - - const rawPageSize = filters.pageSize && filters.pageSize > 0 ? filters.pageSize : 20; - const pageSize = Math.min(rawPageSize, 100); - const page = filters.page && filters.page > 0 ? filters.page : 1; - const timezone = await resolveSystemTimezone(); const { startTime, endTime } = parseDateRangeInServerTimezone( filters.startDate, filters.endDate, timezone ); - const result = await findUsageLogsForKeySlim({ + const limit = filters.limit && filters.limit > 0 ? Math.min(filters.limit, 100) : 20; + const result = await findUsageLogsForKeyBatch({ keyString: session.key.key, sessionId: filters.sessionId, startTime, @@ -512,53 +543,22 @@ export async function getMyUsageLogs( excludeStatusCode200: filters.excludeStatusCode200, endpoint: filters.endpoint, minRetryCount: filters.minRetryCount, - page, - pageSize, - }); - - const logs: MyUsageLogEntry[] = result.logs.map((log) => { - const modelRedirect = - log.originalModel && log.model && log.originalModel !== log.model - ? `${log.originalModel} → ${log.model}` - : null; - - const billingModel = - (settings.billingModelSource === "original" ? log.originalModel : log.model) ?? null; - - return { - id: log.id, - createdAt: log.createdAt, - model: log.model, - billingModel, - anthropicEffort: log.anthropicEffort ?? null, - modelRedirect, - inputTokens: log.inputTokens ?? 0, - outputTokens: log.outputTokens ?? 0, - cost: log.costUsd ? Number(log.costUsd) : 0, - statusCode: log.statusCode, - duration: log.durationMs, - endpoint: log.endpoint, - cacheCreationInputTokens: log.cacheCreationInputTokens ?? null, - cacheReadInputTokens: log.cacheReadInputTokens ?? null, - cacheCreation5mInputTokens: log.cacheCreation5mInputTokens ?? null, - cacheCreation1hInputTokens: log.cacheCreation1hInputTokens ?? null, - cacheTtlApplied: log.cacheTtlApplied ?? null, - }; + cursor: filters.cursor, + limit, }); return { ok: true, data: { - logs, - total: result.total, - page, - pageSize, + logs: mapMyUsageLogEntries(result, settings.billingModelSource), + nextCursor: result.nextCursor, + hasMore: result.hasMore, currencyCode: settings.currencyDisplay, billingModelSource: settings.billingModelSource, }, }; } catch (error) { - logger.error("[my-usage] getMyUsageLogs failed", error); + logger.error("[my-usage] getMyUsageLogsBatch failed", error); return { ok: false, error: "Failed to get usage logs" }; } } diff --git a/src/app/[locale]/dashboard/logs/_components/virtualized-logs-table.test.tsx b/src/app/[locale]/dashboard/logs/_components/virtualized-logs-table.test.tsx index 0a3f3ecb1..ed57033eb 100644 --- a/src/app/[locale]/dashboard/logs/_components/virtualized-logs-table.test.tsx +++ b/src/app/[locale]/dashboard/logs/_components/virtualized-logs-table.test.tsx @@ -12,21 +12,25 @@ let mockIsError = false; let mockError: unknown = null; let mockHasNextPage = false; let mockIsFetchingNextPage = false; +const useInfiniteQuerySpy = vi.hoisted(() => vi.fn()); vi.mock("next-intl", () => ({ useTranslations: () => (key: string) => key, })); vi.mock("@tanstack/react-query", () => ({ - useInfiniteQuery: () => ({ - data: { pages: [{ logs: mockLogs, nextCursor: null, hasMore: false }] }, - fetchNextPage: vi.fn(), - hasNextPage: mockHasNextPage, - isFetchingNextPage: mockIsFetchingNextPage, - isLoading: mockIsLoading, - isError: mockIsError, - error: mockError, - }), + useInfiniteQuery: (options: unknown) => { + useInfiniteQuerySpy(options); + return { + data: { pages: [{ logs: mockLogs, nextCursor: null, hasMore: false }] }, + fetchNextPage: vi.fn(), + hasNextPage: mockHasNextPage, + isFetchingNextPage: mockIsFetchingNextPage, + isLoading: mockIsLoading, + isError: mockIsError, + error: mockError, + }; + }, })); vi.mock("@/hooks/use-virtualizer", () => ({ @@ -144,6 +148,22 @@ function makeLog(overrides: Partial): UsageLogRow { } describe("virtualized-logs-table multiplier badge", () => { + test("does not cap cached pages so deep scroll can return to the latest rows", () => { + mockIsLoading = false; + mockIsError = false; + mockError = null; + mockHasNextPage = true; + mockIsFetchingNextPage = false; + mockLogs = [makeLog({ id: 1 })]; + useInfiniteQuerySpy.mockClear(); + + renderToStaticMarkup(); + + const options = useInfiniteQuerySpy.mock.calls[0]?.[0] as { maxPages?: number } | undefined; + expect(options).toBeDefined(); + expect(options?.maxPages).toBeUndefined(); + }); + test("renders loading/error/empty states", () => { mockIsError = false; mockError = null; diff --git a/src/app/[locale]/dashboard/logs/_components/virtualized-logs-table.tsx b/src/app/[locale]/dashboard/logs/_components/virtualized-logs-table.tsx index 179aa0759..e441038eb 100644 --- a/src/app/[locale]/dashboard/logs/_components/virtualized-logs-table.tsx +++ b/src/app/[locale]/dashboard/logs/_components/virtualized-logs-table.tsx @@ -3,14 +3,22 @@ import { useInfiniteQuery } from "@tanstack/react-query"; import { ArrowUp, GitBranch, Loader2 } from "lucide-react"; import { useTranslations } from "next-intl"; -import { type MouseEvent, useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { + type MouseEvent, + useCallback, + useEffect, + useEffectEvent, + useMemo, + useRef, + useState, +} from "react"; import { toast } from "sonner"; import { getUsageLogsBatch } from "@/actions/usage-logs"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { RelativeTime } from "@/components/ui/relative-time"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import { useVirtualizer } from "@/hooks/use-virtualizer"; +import { useVirtualizedInfiniteList } from "@/hooks/use-virtualized-infinite-list"; import type { LogsTableColumn } from "@/lib/column-visibility"; import { cn, formatTokenAmount } from "@/lib/utils"; import { copyTextToClipboard } from "@/lib/utils/clipboard"; @@ -80,9 +88,8 @@ export function VirtualizedLogsTable({ const getPricingSourceLabel = (source: string) => t(`logs.billingDetails.pricingSource.${source}`); const tChain = useTranslations("provider-chain"); - const parentRef = useRef(null); - const [showScrollToTop, setShowScrollToTop] = useState(false); - const shouldPoll = autoRefreshEnabled && !showScrollToTop; + const [isHistoryBrowsing, setIsHistoryBrowsing] = useState(false); + const shouldPoll = autoRefreshEnabled && !isHistoryBrowsing; const hideProviderColumn = hiddenColumns?.includes("provider") ?? false; const hideUserColumn = hiddenColumns?.includes("user") ?? false; @@ -137,51 +144,50 @@ export function VirtualizedLogsTable({ if (query.state.fetchStatus !== "idle") return false; return autoRefreshIntervalMs; }, - maxPages: 5, }); // Flatten all pages into a single array const pages = data?.pages; const allLogs = useMemo(() => pages?.flatMap((page) => page.logs) ?? [], [pages]); + const filtersResetKey = useMemo(() => JSON.stringify(filters), [filters]); + const previousFiltersResetKeyRef = useRef(filtersResetKey); + + const getItemKey = useCallback( + (index: number) => allLogs[index]?.id ?? `loader-${index}`, + [allLogs] + ); - // Virtual list setup - const rowVirtualizer = useVirtualizer({ - count: hasNextPage ? allLogs.length + 1 : allLogs.length, - getScrollElement: () => parentRef.current, + const { + parentRef, + rowVirtualizer, + virtualItems, + showScrollToTop, + handleScroll, + scrollToTop, + resetScrollPosition, + } = useVirtualizedInfiniteList({ + itemCount: allLogs.length, + hasNextPage, + isFetchingNextPage, + fetchNextPage, estimateSize: () => ROW_HEIGHT, overscan: 10, + getItemKey, }); - const virtualItems = rowVirtualizer.getVirtualItems(); - const lastItemIndex = virtualItems[virtualItems.length - 1]?.index ?? -1; - - // Auto-fetch next page when scrolling near the bottom useEffect(() => { - // If the last visible item is a loading row or near the end, fetch more - if (lastItemIndex >= allLogs.length - 5 && hasNextPage && !isFetchingNextPage) { - fetchNextPage(); - } - }, [lastItemIndex, hasNextPage, isFetchingNextPage, allLogs.length, fetchNextPage]); - - // Track scroll position for "scroll to top" button - const handleScroll = useCallback(() => { - if (parentRef.current) { - setShowScrollToTop(parentRef.current.scrollTop > 500); - } - }, []); - - // Scroll to top handler - const scrollToTop = useCallback(() => { - parentRef.current?.scrollTo({ top: 0, behavior: "smooth" }); - }, []); - - // Reset scroll when filters change - // biome-ignore lint/correctness/useExhaustiveDependencies: `filters` is an intentional trigger + setIsHistoryBrowsing(showScrollToTop); + }, [showScrollToTop]); + + const handleFiltersReset = useEffectEvent((nextResetKey: string) => { + if (previousFiltersResetKeyRef.current === nextResetKey) return; + previousFiltersResetKeyRef.current = nextResetKey; + resetScrollPosition(); + }); + useEffect(() => { - if (parentRef.current) { - parentRef.current.scrollTop = 0; - } - }, [filters]); + handleFiltersReset(filtersResetKey); + }, [filtersResetKey]); if (isLoading) { return ( diff --git a/src/app/[locale]/my-usage/_components/usage-logs-section.test.tsx b/src/app/[locale]/my-usage/_components/usage-logs-section.test.tsx new file mode 100644 index 000000000..9f05d63aa --- /dev/null +++ b/src/app/[locale]/my-usage/_components/usage-logs-section.test.tsx @@ -0,0 +1,109 @@ +import type { ReactNode } from "react"; +import { createRoot } from "react-dom/client"; +import { act } from "react"; +import { describe, expect, test, vi } from "vitest"; + +const mocks = vi.hoisted(() => ({ + useInfiniteQuery: vi.fn(), + getMyUsageLogs: vi.fn(), + getMyUsageLogsBatch: vi.fn(), + getMyAvailableModels: vi.fn(), + getMyAvailableEndpoints: vi.fn(), +})); + +vi.mock("next-intl", () => ({ + useTranslations: () => (key: string, values?: Record) => + values ? `${key}:${JSON.stringify(values)}` : key, + useTimeZone: () => "UTC", +})); + +vi.mock("@tanstack/react-query", () => ({ + useInfiniteQuery: mocks.useInfiniteQuery, +})); + +vi.mock("@/actions/my-usage", () => ({ + getMyUsageLogs: mocks.getMyUsageLogs, + getMyUsageLogsBatch: mocks.getMyUsageLogsBatch, + getMyAvailableModels: mocks.getMyAvailableModels, + getMyAvailableEndpoints: mocks.getMyAvailableEndpoints, +})); + +vi.mock("@/app/[locale]/dashboard/logs/_components/logs-date-range-picker", () => ({ + LogsDateRangePicker: () =>
, +})); + +vi.mock("@/components/ui/collapsible", () => ({ + Collapsible: ({ children }: { children?: ReactNode }) =>
{children}
, + CollapsibleContent: ({ children }: { children?: ReactNode }) =>
{children}
, + CollapsibleTrigger: ({ children }: { children?: ReactNode }) =>
{children}
, +})); + +vi.mock("@/components/ui/select", () => ({ + Select: ({ children }: { children?: ReactNode }) =>
{children}
, + SelectTrigger: ({ children }: { children?: ReactNode }) =>
{children}
, + SelectValue: () =>
, + SelectContent: ({ children }: { children?: ReactNode }) =>
{children}
, + SelectItem: ({ children }: { children?: ReactNode }) =>
{children}
, +})); + +vi.mock("@/components/ui/button", () => ({ + Button: ({ children, ...props }: React.ComponentProps<"button">) => ( + + ), +})); + +vi.mock("@/components/ui/badge", () => ({ + Badge: ({ children }: { children?: ReactNode }) =>
{children}
, +})); + +vi.mock("@/components/ui/input", () => ({ + Input: (props: React.ComponentProps<"input">) => , +})); + +vi.mock("@/components/ui/label", () => ({ + Label: ({ children }: { children?: ReactNode }) => , +})); + +vi.mock("./usage-logs-table", () => ({ + UsageLogsTable: () =>
, +})); + +import { UsageLogsSection } from "./usage-logs-section"; + +describe("my-usage usage logs section", () => { + test("uses infinite query instead of the old page-based getMyUsageLogs flow", async () => { + mocks.useInfiniteQuery.mockReturnValue({ + data: { pages: [{ logs: [], nextCursor: null, hasMore: false }] }, + fetchNextPage: vi.fn(), + hasNextPage: false, + isFetchingNextPage: false, + isLoading: false, + isError: false, + error: null, + }); + mocks.getMyUsageLogs.mockResolvedValue({ + ok: true, + data: { logs: [], total: 0, page: 1, pageSize: 20, currencyCode: "USD" }, + }); + mocks.getMyAvailableModels.mockResolvedValue({ ok: true, data: [] }); + mocks.getMyAvailableEndpoints.mockResolvedValue({ ok: true, data: [] }); + + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + + await act(async () => { + root.render(); + }); + + expect(mocks.useInfiniteQuery).toHaveBeenCalled(); + expect(mocks.getMyUsageLogs).not.toHaveBeenCalled(); + + await act(async () => { + root.unmount(); + }); + container.remove(); + }); +}); diff --git a/src/app/[locale]/my-usage/_components/usage-logs-section.tsx b/src/app/[locale]/my-usage/_components/usage-logs-section.tsx index 486014f74..9c048594a 100644 --- a/src/app/[locale]/my-usage/_components/usage-logs-section.tsx +++ b/src/app/[locale]/my-usage/_components/usage-logs-section.tsx @@ -1,13 +1,13 @@ "use client"; +import { useInfiniteQuery } from "@tanstack/react-query"; import { Check, ChevronDown, Filter, Loader2, RefreshCw, ScrollText, X } from "lucide-react"; import { useTranslations } from "next-intl"; -import { useCallback, useEffect, useMemo, useRef, useState, useTransition } from "react"; +import { useCallback, useEffect, useMemo, useState } from "react"; import { getMyAvailableEndpoints, getMyAvailableModels, - getMyUsageLogs, - type MyUsageLogsResult, + getMyUsageLogsBatch, } from "@/actions/my-usage"; import { LogsDateRangePicker } from "@/app/[locale]/dashboard/logs/_components/logs-date-range-picker"; import { Badge } from "@/components/ui/badge"; @@ -25,9 +25,9 @@ import { import { cn } from "@/lib/utils"; import { UsageLogsTable } from "./usage-logs-table"; +const BATCH_SIZE = 20; + interface UsageLogsSectionProps { - initialData?: MyUsageLogsResult | null; - loading?: boolean; autoRefreshSeconds?: number; defaultOpen?: boolean; serverTimeZone?: string; @@ -41,12 +41,9 @@ interface Filters { excludeStatusCode200?: boolean; endpoint?: string; minRetryCount?: number; - page?: number; } export function UsageLogsSection({ - initialData = null, - loading = false, autoRefreshSeconds, defaultOpen = false, serverTimeZone, @@ -60,14 +57,70 @@ export function UsageLogsSection({ const [endpoints, setEndpoints] = useState([]); const [isModelsLoading, setIsModelsLoading] = useState(true); const [isEndpointsLoading, setIsEndpointsLoading] = useState(true); - const [draftFilters, setDraftFilters] = useState({ page: 1 }); - const [appliedFilters, setAppliedFilters] = useState({ page: 1 }); - const [data, setData] = useState(initialData); - const [isPending, startTransition] = useTransition(); - const [error, setError] = useState(null); + const [draftFilters, setDraftFilters] = useState({}); + const [appliedFilters, setAppliedFilters] = useState({}); + const [isBrowsingHistory, setIsBrowsingHistory] = useState(false); + + useEffect(() => { + setIsModelsLoading(true); + setIsEndpointsLoading(true); + + void getMyAvailableModels() + .then((modelsResult) => { + if (modelsResult.ok && modelsResult.data) { + setModels(modelsResult.data); + } + }) + .finally(() => setIsModelsLoading(false)); - // Compute metrics for header summary - const logs = data?.logs ?? []; + void getMyAvailableEndpoints() + .then((endpointsResult) => { + if (endpointsResult.ok && endpointsResult.data) { + setEndpoints(endpointsResult.data); + } + }) + .finally(() => setIsEndpointsLoading(false)); + }, []); + + const query = useInfiniteQuery({ + queryKey: ["my-usage-logs-batch", appliedFilters], + queryFn: async ({ pageParam }) => { + const result = await getMyUsageLogsBatch({ + ...appliedFilters, + cursor: pageParam, + limit: BATCH_SIZE, + }); + if (!result.ok) { + throw new Error(result.error); + } + return result.data; + }, + initialPageParam: undefined as { createdAt: string; id: number } | undefined, + getNextPageParam: (lastPage) => lastPage.nextCursor ?? undefined, + staleTime: 30000, + refetchOnWindowFocus: false, + refetchInterval: autoRefreshSeconds + ? (query) => { + if (isBrowsingHistory) return false; + if (query.state.fetchStatus !== "idle") return false; + return autoRefreshSeconds * 1000; + } + : false, + }); + const { + data, + fetchNextPage, + hasNextPage = false, + isFetchingNextPage, + isLoading, + isError, + error, + isRefetching = false, + } = query; + const refetch = query.refetch ?? (async (): Promise => undefined); + + const logs = useMemo(() => data?.pages.flatMap((page) => page.logs) ?? [], [data]); + const latestPage = data?.pages[0]; const activeFiltersCount = useMemo(() => { let count = 0; @@ -79,10 +132,7 @@ export function UsageLogsSection({ return count; }, [appliedFilters]); - const lastLog = useMemo(() => { - if (!logs || logs.length === 0) return null; - return logs[0]; // First log is the most recent (sorted by createdAt DESC) - }, [logs]); + const lastLog = logs[0] ?? null; const lastStatusText = useMemo(() => { if (!lastLog?.createdAt) return null; @@ -99,7 +149,7 @@ export function UsageLogsSection({ }, [lastLog]); const successRate = useMemo(() => { - if (!logs || logs.length === 0) return null; + if (logs.length === 0) return null; const successCount = logs.filter((log) => log.statusCode && log.statusCode < 400).length; return Math.round((successCount / logs.length) * 100); }, [logs]); @@ -111,133 +161,38 @@ export function UsageLogsSection({ return ""; }, [lastLog]); - // Sync initialData from parent when it becomes available - // (useState only uses initialData on first mount, not on subsequent updates) - useEffect(() => { - if (initialData && !data) { - setData(initialData); - } - }, [initialData, data]); - - useEffect(() => { - setIsModelsLoading(true); - setIsEndpointsLoading(true); - - void getMyAvailableModels() - .then((modelsResult) => { - if (modelsResult.ok && modelsResult.data) { - setModels(modelsResult.data); - } - }) - .finally(() => setIsModelsLoading(false)); - - void getMyAvailableEndpoints() - .then((endpointsResult) => { - if (endpointsResult.ok && endpointsResult.data) { - setEndpoints(endpointsResult.data); - } - }) - .finally(() => setIsEndpointsLoading(false)); - }, []); - - const loadLogs = useCallback( - (nextFilters: Filters) => { - startTransition(async () => { - const result = await getMyUsageLogs(nextFilters); - if (result.ok && result.data) { - setData(result.data); - setAppliedFilters(nextFilters); - setError(null); - } else { - setError(!result.ok && "error" in result ? result.error : t("loadFailed")); - } - }); - }, - [t] - ); - - useEffect(() => { - // initial load if not provided - if (data) return; - if (!initialData && !loading) { - loadLogs({ page: 1 }); - } - }, [data, initialData, loading, loadLogs]); - - // Auto-refresh polling (only when on page 1 to avoid disrupting history browsing) - const intervalRef = useRef(null); - - useEffect(() => { - if (!autoRefreshSeconds || autoRefreshSeconds <= 0) { - return; - } - - const pollIntervalMs = autoRefreshSeconds * 1000; - - const startPolling = () => { - if (intervalRef.current) { - clearInterval(intervalRef.current); - } - - intervalRef.current = setInterval(() => { - // Only auto-refresh when on page 1 - if ((appliedFilters.page ?? 1) === 1) { - loadLogs(appliedFilters); - } - }, pollIntervalMs); - }; - - const stopPolling = () => { - if (intervalRef.current) { - clearInterval(intervalRef.current); - intervalRef.current = null; - } - }; - - const handleVisibilityChange = () => { - if (document.hidden) { - stopPolling(); - } else { - // Refresh immediately when tab becomes visible (only if on page 1) - if ((appliedFilters.page ?? 1) === 1) { - loadLogs(appliedFilters); - } - startPolling(); - } - }; - - startPolling(); - document.addEventListener("visibilitychange", handleVisibilityChange); - - return () => { - stopPolling(); - document.removeEventListener("visibilitychange", handleVisibilityChange); - }; - }, [autoRefreshSeconds, appliedFilters, loadLogs]); - const handleFilterChange = (changes: Partial) => { - setDraftFilters((prev) => ({ ...prev, ...changes, page: 1 })); + setDraftFilters((prev) => ({ ...prev, ...changes })); }; const handleApply = () => { - loadLogs({ ...draftFilters, page: 1 }); + const nextFilters = { ...draftFilters }; + if (JSON.stringify(nextFilters) === JSON.stringify(appliedFilters)) { + void refetch(); + return; + } + setAppliedFilters(nextFilters); }; const handleReset = () => { - setDraftFilters({ page: 1 }); - loadLogs({ page: 1 }); + setDraftFilters({}); + if (Object.keys(appliedFilters).length === 0) { + void refetch(); + return; + } + setAppliedFilters({}); }; const handleDateRangeChange = (range: { startDate?: string; endDate?: string }) => { handleFilterChange(range); }; - const handlePageChange = (page: number) => { - loadLogs({ ...appliedFilters, page }); - }; + const handleLoadMore = useCallback(() => { + void fetchNextPage(); + }, [fetchNextPage]); - const isInitialLoading = loading || (!data && isPending); - const isRefreshing = isPending && Boolean(data); + const isRefreshing = isRefetching && !isFetchingNextPage && logs.length > 0; + const errorMessage = isError ? (error instanceof Error ? error.message : t("loadFailed")) : null; return ( @@ -250,7 +205,6 @@ export function UsageLogsSection({ isOpen && "border-b" )} > - {/* Icon + Title */}
@@ -258,11 +212,8 @@ export function UsageLogsSection({ {tCollapsible("title")}
- {/* Header Summary */}
- {/* Desktop Summary */}
- {/* Last Status */} {lastLog ? ( {tCollapsible("lastStatus", { @@ -276,7 +227,6 @@ export function UsageLogsSection({ | - {/* Success Rate */} {successRate !== null ? ( ) : null} - {/* Active Filters Badge */} {activeFiltersCount > 0 && ( <> | @@ -302,7 +251,6 @@ export function UsageLogsSection({ )} - {/* Auto-refresh */} {autoRefreshSeconds && ( <> | @@ -312,9 +260,7 @@ export function UsageLogsSection({ )}
- {/* Mobile Summary */}
- {/* Last Status - compact */} {lastLog ? ( {lastLog.statusCode ?? "-"} ({lastStatusText ?? "-"}) @@ -325,7 +271,6 @@ export function UsageLogsSection({ | - {/* Success Rate - compact */} {successRate !== null ? ( ) : null} - {/* Filters + Refresh */} {activeFiltersCount > 0 && ( <> | @@ -355,7 +299,6 @@ export function UsageLogsSection({ )}
- {/* Chevron */} - handleFilterChange({ - model: value === "__all__" ? undefined : value, - }) + handleFilterChange({ model: value === "__all__" ? undefined : value }) } disabled={isModelsLoading} > @@ -411,9 +352,7 @@ export function UsageLogsSection({ + setCodexPriorityBillingSource(value as CodexPriorityBillingSource) + } + disabled={isPending} + > + + + + + + {t("codexPriorityBillingSourceOptions.requested")} + + {t("codexPriorityBillingSourceOptions.actual")} + + +

{t("codexPriorityBillingSourceDesc")}

+
+ {/* Timezone Select */}
); diff --git a/tests/unit/actions/usage-logs-export-retry-count.test.ts b/tests/unit/actions/usage-logs-export-retry-count.test.ts index 7273a3c3d..5f4ba0935 100644 --- a/tests/unit/actions/usage-logs-export-retry-count.test.ts +++ b/tests/unit/actions/usage-logs-export-retry-count.test.ts @@ -2,6 +2,10 @@ import { beforeEach, describe, expect, test, vi } from "vitest"; const getSessionMock = vi.fn(); const findUsageLogsWithDetailsMock = vi.fn(); +const findUsageLogsBatchMock = vi.fn(); +const findUsageLogsStatsMock = vi.fn(); +const exportStatusStore = new Map(); +const exportCsvStore = new Map(); vi.mock("@/lib/auth", () => { return { @@ -9,21 +13,55 @@ vi.mock("@/lib/auth", () => { }; }); +vi.mock("@/lib/redis/redis-kv-store", () => ({ + RedisKVStore: class MockRedisKVStore { + private readonly prefix: string; + + constructor(options: { prefix: string }) { + this.prefix = options.prefix; + } + + async set(key: string, value: T) { + if (this.prefix.includes(":status:")) { + exportStatusStore.set(key, value); + } else { + exportCsvStore.set(key, value as string); + } + return true; + } + + async get(key: string) { + if (this.prefix.includes(":status:")) { + return (exportStatusStore.get(key) as T | undefined) ?? null; + } + return ((exportCsvStore.get(key) as T | undefined) ?? null) as T | null; + } + + async getAndDelete(key: string) { + if (this.prefix.includes(":status:")) { + const value = (exportStatusStore.get(key) as T | undefined) ?? null; + exportStatusStore.delete(key); + return value; + } + const value = ((exportCsvStore.get(key) as T | undefined) ?? null) as T | null; + exportCsvStore.delete(key); + return value; + } + + async delete(key: string) { + if (this.prefix.includes(":status:")) { + return exportStatusStore.delete(key); + } + return exportCsvStore.delete(key); + } + }, +})); + vi.mock("@/repository/usage-logs", () => { return { findUsageLogSessionIdSuggestions: vi.fn(async () => []), - findUsageLogsBatch: vi.fn(async () => ({ logs: [], nextCursor: null, hasMore: false })), - findUsageLogsStats: vi.fn(async () => ({ - totalRequests: 0, - totalCost: 0, - totalTokens: 0, - totalInputTokens: 0, - totalOutputTokens: 0, - totalCacheCreationTokens: 0, - totalCacheReadTokens: 0, - totalCacheCreation5mTokens: 0, - totalCacheCreation1hTokens: 0, - })), + findUsageLogsBatch: findUsageLogsBatchMock, + findUsageLogsStats: findUsageLogsStatsMock, findUsageLogsWithDetails: findUsageLogsWithDetailsMock, getUsedEndpoints: vi.fn(async () => []), getUsedModels: vi.fn(async () => []), @@ -31,6 +69,44 @@ vi.mock("@/repository/usage-logs", () => { }; }); +function createSummary(totalRequests = 0) { + return { + totalRequests, + totalCost: 0, + totalTokens: 0, + totalInputTokens: 0, + totalOutputTokens: 0, + totalCacheCreationTokens: 0, + totalCacheReadTokens: 0, + totalCacheCreation5mTokens: 0, + totalCacheCreation1hTokens: 0, + }; +} + +function createLog(overrides: Record = {}) { + return { + createdAt: new Date("2026-03-16T00:00:00.000Z"), + userName: "u", + keyName: "k", + providerName: "p", + model: "m", + originalModel: "om", + endpoint: "/v1/messages", + statusCode: 200, + inputTokens: 1, + outputTokens: 2, + cacheCreation5mInputTokens: 0, + cacheCreation1hInputTokens: 0, + cacheReadInputTokens: 0, + totalTokens: 3, + costUsd: "0", + durationMs: 10, + sessionId: "s1", + providerChain: null, + ...overrides, + }; +} + function parseCsvLine(line: string): string[] { const fields: string[] = []; let current = ""; @@ -76,12 +152,28 @@ function parseCsvLine(line: string): string[] { describe("Usage logs CSV export retryCount", () => { beforeEach(() => { + vi.resetModules(); vi.clearAllMocks(); + vi.useRealTimers(); + exportStatusStore.clear(); + exportCsvStore.clear(); getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + findUsageLogsWithDetailsMock.mockResolvedValue({ + logs: [], + total: 0, + summary: createSummary(), + }); + findUsageLogsBatchMock.mockResolvedValue({ logs: [], nextCursor: null, hasMore: false }); + findUsageLogsStatsMock.mockResolvedValue(createSummary()); }); test("exportUsageLogs: Retry Count 应对齐 getRetryCount(hedge race 为 0)", async () => { findUsageLogsWithDetailsMock.mockResolvedValue({ + logs: [], + total: 3, + summary: createSummary(3), + }); + findUsageLogsBatchMock.mockResolvedValueOnce({ logs: [ { createdAt: new Date("2026-03-16T00:00:00.000Z"), @@ -157,18 +249,8 @@ describe("Usage logs CSV export retryCount", () => { ], }, ], - total: 3, - summary: { - totalRequests: 3, - totalCost: 0, - totalTokens: 9, - totalInputTokens: 3, - totalOutputTokens: 6, - totalCacheCreationTokens: 0, - totalCacheReadTokens: 0, - totalCacheCreation5mTokens: 0, - totalCacheCreation1hTokens: 0, - }, + nextCursor: null, + hasMore: false, }); const { exportUsageLogs } = await import("@/actions/usage-logs"); @@ -194,4 +276,88 @@ describe("Usage logs CSV export retryCount", () => { expect(row2[retryCountIndex]).toBe("1"); expect(row3[retryCountIndex]).toBe("0"); }); + + test("exportUsageLogs: 按批次全量导出,并拦截前导空白公式注入", async () => { + findUsageLogsWithDetailsMock.mockResolvedValue({ + logs: [], + total: 3, + summary: createSummary(3), + }); + findUsageLogsBatchMock + .mockResolvedValueOnce({ + logs: [ + createLog({ sessionId: "s1", model: " =1+1" }), + createLog({ sessionId: "s2", model: "+2+2" }), + ], + nextCursor: { createdAt: "2026-03-16T00:00:01.000000Z", id: 2 }, + hasMore: true, + }) + .mockResolvedValueOnce({ + logs: [createLog({ sessionId: "s3", endpoint: " \t@SUM(A1:A2)" })], + nextCursor: null, + hasMore: false, + }); + + const { exportUsageLogs } = await import("@/actions/usage-logs"); + const result = await exportUsageLogs({}); + + expect(result.ok).toBe(true); + expect(findUsageLogsBatchMock).toHaveBeenCalledTimes(2); + + const csvNoBom = result.data.replace(/^\uFEFF/, ""); + const lines = csvNoBom + .trim() + .split("\n") + .map((line) => line.replace(/\r$/, "")); + + expect(lines).toHaveLength(4); + const header = parseCsvLine(lines[0] ?? ""); + const modelIndex = header.indexOf("Model"); + const endpointIndex = header.indexOf("Endpoint"); + const row1 = parseCsvLine(lines[1] ?? ""); + const row2 = parseCsvLine(lines[2] ?? ""); + const row3 = parseCsvLine(lines[3] ?? ""); + + expect(row1[modelIndex]).toBe("' =1+1"); + expect(row2[modelIndex]).toBe("'+2+2"); + expect(row3[endpointIndex]).toBe("' \t@SUM(A1:A2)"); + }); + + test("startUsageLogsExport: 异步导出任务完成后可轮询并下载", async () => { + vi.useFakeTimers(); + findUsageLogsWithDetailsMock.mockResolvedValue({ + logs: [], + total: 1, + summary: createSummary(1), + }); + findUsageLogsBatchMock.mockResolvedValueOnce({ + logs: [createLog({ sessionId: "job-session" })], + nextCursor: null, + hasMore: false, + }); + + const { downloadUsageLogsExport, getUsageLogsExportStatus, startUsageLogsExport } = + await import("@/actions/usage-logs"); + + const startResult = await startUsageLogsExport({}); + expect(startResult.ok).toBe(true); + const jobId = startResult.data.jobId; + + const queuedStatus = await getUsageLogsExportStatus(jobId); + expect(queuedStatus.ok).toBe(true); + expect(queuedStatus.data.status).toBe("queued"); + + await vi.runAllTimersAsync(); + + const completedStatus = await getUsageLogsExportStatus(jobId); + expect(completedStatus.ok).toBe(true); + expect(completedStatus.data.status).toBe("completed"); + expect(completedStatus.data.progressPercent).toBe(100); + expect(completedStatus.data.processedRows).toBe(1); + + const downloadResult = await downloadUsageLogsExport(jobId); + expect(downloadResult.ok).toBe(true); + expect(downloadResult.data).toContain("Session ID"); + expect(downloadResult.data).toContain("job-session"); + }); }); diff --git a/tests/unit/dashboard-logs-export-progress-ui.test.tsx b/tests/unit/dashboard-logs-export-progress-ui.test.tsx new file mode 100644 index 000000000..f96d10aad --- /dev/null +++ b/tests/unit/dashboard-logs-export-progress-ui.test.tsx @@ -0,0 +1,225 @@ +/** + * @vitest-environment happy-dom + */ + +import type { ReactNode } from "react"; +import { act } from "react"; +import { createRoot } from "react-dom/client"; +import { NextIntlClientProvider } from "next-intl"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { UsageLogsFilters } from "@/app/[locale]/dashboard/logs/_components/usage-logs-filters"; +import dashboardMessages from "../../messages/en/dashboard.json"; + +const { + downloadUsageLogsExportMock, + getUsageLogsExportStatusMock, + startUsageLogsExportMock, + toastErrorMock, + toastSuccessMock, +} = vi.hoisted(() => ({ + startUsageLogsExportMock: vi.fn(), + getUsageLogsExportStatusMock: vi.fn(), + downloadUsageLogsExportMock: vi.fn(), + toastSuccessMock: vi.fn(), + toastErrorMock: vi.fn(), +})); + +vi.mock("@/actions/usage-logs", () => ({ + startUsageLogsExport: startUsageLogsExportMock, + getUsageLogsExportStatus: getUsageLogsExportStatusMock, + downloadUsageLogsExport: downloadUsageLogsExportMock, +})); + +vi.mock("sonner", () => ({ + toast: { + success: toastSuccessMock, + error: toastErrorMock, + }, +})); + +vi.mock("@/app/[locale]/dashboard/logs/_components/filters/active-filters-display", () => ({ + ActiveFiltersDisplay: () =>
, +})); + +vi.mock("@/app/[locale]/dashboard/logs/_components/filters/filter-section", () => ({ + FilterSection: ({ children }: { children: ReactNode }) =>
{children}
, +})); + +vi.mock("@/app/[locale]/dashboard/logs/_components/filters/identity-filters", () => ({ + IdentityFilters: () =>
, +})); + +vi.mock("@/app/[locale]/dashboard/logs/_components/filters/quick-filters-bar", () => ({ + QuickFiltersBar: () =>
, +})); + +vi.mock("@/app/[locale]/dashboard/logs/_components/filters/request-filters", () => ({ + RequestFilters: ({ + onFiltersChange, + }: { + onFiltersChange: (filters: Record) => void; + }) => ( + + ), +})); + +vi.mock("@/app/[locale]/dashboard/logs/_components/filters/status-filters", () => ({ + StatusFilters: () =>
, +})); + +vi.mock("@/app/[locale]/dashboard/logs/_components/filters/time-filters", () => ({ + TimeFilters: () =>
, +})); + +function renderWithIntl(node: ReactNode) { + const container = document.createElement("div"); + document.body.appendChild(container); + const root = createRoot(container); + + act(() => { + root.render( + + {node} + + ); + }); + + return { + container, + unmount: () => { + act(() => root.unmount()); + container.remove(); + }, + }; +} + +async function actClick(el: Element | null) { + if (!el) throw new Error("element not found"); + await act(async () => { + el.dispatchEvent(new MouseEvent("click", { bubbles: true })); + }); +} + +async function flushPromises() { + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + }); +} + +describe("UsageLogsFilters export progress UI", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.useFakeTimers(); + globalThis.URL.createObjectURL = vi.fn(() => "blob:usage-logs"); + globalThis.URL.revokeObjectURL = vi.fn(); + HTMLAnchorElement.prototype.click = vi.fn(); + }); + + test("shows export progress while polling and downloads when completed", async () => { + startUsageLogsExportMock.mockResolvedValue({ ok: true, data: { jobId: "job-1" } }); + getUsageLogsExportStatusMock + .mockResolvedValueOnce({ + ok: true, + data: { + jobId: "job-1", + status: "running", + processedRows: 50, + totalRows: 200, + progressPercent: 25, + }, + }) + .mockResolvedValueOnce({ + ok: true, + data: { + jobId: "job-1", + status: "completed", + processedRows: 200, + totalRows: 200, + progressPercent: 100, + }, + }); + downloadUsageLogsExportMock.mockResolvedValue({ ok: true, data: "\uFEFFTime,User\n" }); + + const { container, unmount } = renderWithIntl( + {}} + onReset={() => {}} + /> + ); + + const exportButton = Array.from(container.querySelectorAll("button")).find( + (button) => (button.textContent || "").trim() === "Export" + ); + + await actClick(exportButton ?? null); + await flushPromises(); + + expect(container.textContent).toContain("Exported 50 / 200"); + expect(container.textContent).toContain("25%"); + + await act(async () => { + await vi.advanceTimersByTimeAsync(800); + }); + await flushPromises(); + + expect(downloadUsageLogsExportMock).toHaveBeenCalledWith("job-1"); + expect(toastSuccessMock).toHaveBeenCalledWith("Export completed successfully"); + expect(toastErrorMock).not.toHaveBeenCalled(); + + unmount(); + }); + + test("exports the applied filters instead of unapplied local draft filters", async () => { + startUsageLogsExportMock.mockResolvedValue({ ok: true, data: { jobId: "job-2" } }); + getUsageLogsExportStatusMock.mockResolvedValueOnce({ + ok: true, + data: { + jobId: "job-2", + status: "completed", + processedRows: 1, + totalRows: 1, + progressPercent: 100, + }, + }); + downloadUsageLogsExportMock.mockResolvedValue({ ok: true, data: "\uFEFFTime,User\n" }); + + const { container, unmount } = renderWithIntl( + {}} + onReset={() => {}} + /> + ); + + await actClick(container.querySelector("[data-testid='request-filters']")); + + const exportButton = Array.from(container.querySelectorAll("button")).find( + (button) => (button.textContent || "").trim() === "Export" + ); + + await actClick(exportButton ?? null); + await flushPromises(); + + expect(startUsageLogsExportMock).toHaveBeenCalledWith({ sessionId: "applied-session" }); + + unmount(); + }); +}); From ed1d4dbbb5dfa898a6b8ce4e280123ce518e15bf Mon Sep 17 00:00:00 2001 From: Ding <44717411+ding113@users.noreply.github.com> Date: Mon, 23 Mar 2026 01:26:48 +0800 Subject: [PATCH 10/11] fix: restore legacy user search API compatibility (#967) * fix: restore legacy user search API compatibility Amp-Thread-ID: https://ampcode.com/threads/T-019d1607-8336-7525-aab4-dd247d6a8d93 Co-authored-by: Amp * fix: address user search review feedback Amp-Thread-ID: https://ampcode.com/threads/T-019d1679-0b36-75ba-803e-7c414bac6d0f Co-authored-by: Amp --------- Co-authored-by: Amp --- src/actions/users.ts | 142 ++++++-- .../_components/rate-limit-top-users.tsx | 28 +- .../_components/rate-limit-filters.tsx | 28 +- src/app/api/actions/[...route]/route.ts | 173 ++++++++-- tests/api/api-actions-integrity.test.ts | 2 + tests/api/api-endpoints.test.ts | 2 + tests/api/api-openapi-spec.test.ts | 34 ++ tests/api/users-search-users-compat.test.ts | 71 ++++ .../users-action-get-users-compat.test.ts | 310 ++++++++++++++++++ ...ers-action-search-users-for-filter.test.ts | 12 + 10 files changed, 749 insertions(+), 53 deletions(-) create mode 100644 tests/api/users-search-users-compat.test.ts create mode 100644 tests/unit/users-action-get-users-compat.test.ts diff --git a/src/actions/users.ts b/src/actions/users.ts index c1546666f..1bd9a8aac 100644 --- a/src/actions/users.ts +++ b/src/actions/users.ts @@ -29,7 +29,6 @@ import { createUser, deleteUser, findUserById, - findUserList, findUserListBatch, getAllUserProviderGroups as getAllUserProviderGroupsRepository, getAllUserTags as getAllUserTagsRepository, @@ -47,6 +46,10 @@ export interface GetUsersBatchParams { cursor?: string; limit?: number; searchTerm?: string; + query?: string; + keyword?: string; + page?: number; + offset?: number; tagFilters?: string[]; keyGroupFilters?: string[]; statusFilter?: "all" | "active" | "expired" | "expiringSoon" | "enabled" | "disabled"; @@ -63,6 +66,102 @@ export interface GetUsersBatchParams { sortOrder?: "asc" | "desc"; } +const USER_LIST_DEFAULT_LIMIT = 50; +const USER_LIST_MAX_LIMIT = 200; + +function normalizeLegacySearchTerm(params?: GetUsersBatchParams): string | undefined { + for (const candidate of [params?.searchTerm, params?.query, params?.keyword]) { + const trimmed = candidate?.trim(); + if (trimmed) { + return trimmed; + } + } + + return undefined; +} + +function normalizeUserListParams(params?: GetUsersBatchParams): GetUsersBatchParams { + const limit = + typeof params?.limit === "number" && Number.isFinite(params.limit) && params.limit > 0 + ? Math.min(Math.trunc(params.limit), USER_LIST_MAX_LIMIT) + : undefined; + + let cursor = params?.cursor?.trim() || undefined; + if (!cursor) { + const offset = + typeof params?.offset === "number" && Number.isFinite(params.offset) + ? Math.max(0, Math.trunc(params.offset)) + : undefined; + const page = + typeof params?.page === "number" && Number.isFinite(params.page) + ? Math.max(0, Math.trunc(params.page)) + : undefined; + + if (offset !== undefined) { + cursor = String(offset); + } else if (page !== undefined) { + const effectiveLimit = limit ?? USER_LIST_DEFAULT_LIMIT; + cursor = String(Math.max(page - 1, 0) * effectiveLimit); + } + } + + return { + cursor, + limit, + searchTerm: normalizeLegacySearchTerm(params), + tagFilters: params?.tagFilters, + keyGroupFilters: params?.keyGroupFilters, + statusFilter: params?.statusFilter, + sortBy: params?.sortBy, + sortOrder: params?.sortOrder, + }; +} + +function hasExplicitPaginationParams( + params?: GetUsersBatchParams, + normalizedParams = normalizeUserListParams(params) +): boolean { + return Boolean( + normalizedParams.cursor !== undefined || + normalizedParams.limit !== undefined || + params?.page !== undefined || + params?.offset !== undefined + ); +} + +function hasSearchOrFilterOverrides(normalizedParams: GetUsersBatchParams): boolean { + return Boolean( + normalizedParams.searchTerm || + (normalizedParams.tagFilters?.length ?? 0) > 0 || + (normalizedParams.keyGroupFilters?.length ?? 0) > 0 || + normalizedParams.statusFilter || + normalizedParams.sortBy || + normalizedParams.sortOrder + ); +} + +async function loadAllUsersForAdmin(baseParams?: GetUsersBatchParams): Promise { + const users: User[] = []; + const normalizedBaseParams = normalizeUserListParams(baseParams); + let cursor = normalizedBaseParams.cursor; + + while (true) { + const page = await findUserListBatch({ + ...normalizedBaseParams, + cursor, + limit: USER_LIST_MAX_LIMIT, + }); + + users.push(...page.users); + + if (!page.hasMore || !page.nextCursor) { + return users; + } + + cursor = page.nextCursor; + } +} + /** * 批量获取用户列表的返回结果。 */ @@ -204,7 +303,7 @@ export async function syncUserProviderGroupFromKeys(userId: number): Promise { +export async function getUsers(params?: GetUsersBatchParams): Promise { try { const session = await getSession(); if (!session) { @@ -217,11 +316,18 @@ export async function getUsers(): Promise { // Treat any non-admin role as non-admin for safety. const isAdmin = session.user.role === "admin"; + const normalizedParams = normalizeUserListParams(params); // 非 admin 用户只能看到自己的数据(从 DB 获取完整用户信息) let users: User[] = []; if (isAdmin) { - users = await findUserList(); // 管理员可以看到所有用户 + if (hasExplicitPaginationParams(params, normalizedParams)) { + users = (await findUserListBatch(normalizedParams)).users; + } else if (hasSearchOrFilterOverrides(normalizedParams)) { + users = await loadAllUsersForAdmin(normalizedParams); + } else { + users = await loadAllUsersForAdmin(); + } } else { const selfUser = await findUserById(session.user.id); users = selfUser ? [selfUser] : []; @@ -394,6 +500,12 @@ export async function searchUsersForFilter( } } +export async function searchUsers( + searchTerm?: string +): Promise>> { + return searchUsersForFilter(searchTerm); +} + /** * 获取所有用户标签(用于标签筛选下拉框) * 返回所有用户的标签,不受当前筛选条件影响 @@ -497,16 +609,8 @@ export async function getUsersBatch( const locale = await getLocale(); const t = await getTranslations("users"); - const { users, nextCursor, hasMore } = await findUserListBatch({ - cursor: params.cursor, - limit: params.limit, - searchTerm: params.searchTerm, - tagFilters: params.tagFilters, - keyGroupFilters: params.keyGroupFilters, - statusFilter: params.statusFilter, - sortBy: params.sortBy, - sortOrder: params.sortOrder, - }); + const normalizedParams = normalizeUserListParams(params); + const { users, nextCursor, hasMore } = await findUserListBatch(normalizedParams); if (users.length === 0) { return { ok: true, data: { users: [], nextCursor, hasMore } }; @@ -665,16 +769,8 @@ export async function getUsersBatchCore( const locale = await getLocale(); const t = await getTranslations("users"); - const { users, nextCursor, hasMore } = await findUserListBatch({ - cursor: params.cursor, - limit: params.limit, - searchTerm: params.searchTerm, - tagFilters: params.tagFilters, - keyGroupFilters: params.keyGroupFilters, - statusFilter: params.statusFilter, - sortBy: params.sortBy, - sortOrder: params.sortOrder, - }); + const normalizedParams = normalizeUserListParams(params); + const { users, nextCursor, hasMore } = await findUserListBatch(normalizedParams); if (users.length === 0) { return { ok: true, data: { users: [], nextCursor, hasMore } }; diff --git a/src/app/[locale]/dashboard/_components/rate-limit-top-users.tsx b/src/app/[locale]/dashboard/_components/rate-limit-top-users.tsx index b7f04554a..aaadcfb74 100644 --- a/src/app/[locale]/dashboard/_components/rate-limit-top-users.tsx +++ b/src/app/[locale]/dashboard/_components/rate-limit-top-users.tsx @@ -3,7 +3,7 @@ import { ArrowUpDown } from "lucide-react"; import { useLocale, useTranslations } from "next-intl"; import * as React from "react"; -import { getUsers } from "@/actions/users"; +import { searchUsers } from "@/actions/users"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { @@ -36,10 +36,28 @@ export function RateLimitTopUsers({ data }: RateLimitTopUsersProps) { // 加载用户详情 React.useEffect(() => { - getUsers().then((userList) => { - setUsers(userList); - setLoading(false); - }); + let cancelled = false; + + void searchUsers() + .then((result) => { + if (!cancelled) { + setUsers(result.ok ? result.data : []); + } + }) + .catch(() => { + if (!cancelled) { + setUsers([]); + } + }) + .finally(() => { + if (!cancelled) { + setLoading(false); + } + }); + + return () => { + cancelled = true; + }; }, []); // 组合数据:用户信息 + 事件计数 diff --git a/src/app/[locale]/dashboard/rate-limits/_components/rate-limit-filters.tsx b/src/app/[locale]/dashboard/rate-limits/_components/rate-limit-filters.tsx index 5e2863086..865658226 100644 --- a/src/app/[locale]/dashboard/rate-limits/_components/rate-limit-filters.tsx +++ b/src/app/[locale]/dashboard/rate-limits/_components/rate-limit-filters.tsx @@ -5,7 +5,7 @@ import { Calendar, X } from "lucide-react"; import { useTranslations } from "next-intl"; import * as React from "react"; import { getProviders } from "@/actions/providers"; -import { getUsers } from "@/actions/users"; +import { searchUsers } from "@/actions/users"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; @@ -61,10 +61,28 @@ export function RateLimitFilters({ // 加载用户列表 React.useEffect(() => { - getUsers().then((userList) => { - setUsers(userList); - setLoadingUsers(false); - }); + let cancelled = false; + + void searchUsers() + .then((result) => { + if (!cancelled) { + setUsers(result.ok ? result.data : []); + } + }) + .catch(() => { + if (!cancelled) { + setUsers([]); + } + }) + .finally(() => { + if (!cancelled) { + setLoadingUsers(false); + } + }); + + return () => { + cancelled = true; + }; }, []); // 加载供应商列表 diff --git a/src/app/api/actions/[...route]/route.ts b/src/app/api/actions/[...route]/route.ts index ea22243ff..fa11cb663 100644 --- a/src/app/api/actions/[...route]/route.ts +++ b/src/app/api/actions/[...route]/route.ts @@ -68,40 +68,173 @@ app.openAPIRegistry.registerComponent("securitySchemes", "bearerAuth", { // ==================== 用户管理 ==================== +const userKeyListItemSchema = z.object({ + id: z.number().describe("密钥 ID"), + name: z.string().describe("密钥名称"), + maskedKey: z.string().describe("脱敏后的密钥"), + fullKey: z.string().optional().describe("完整密钥(有权限时返回)"), + canCopy: z.boolean().describe("是否允许复制完整密钥"), + expiresAt: z.string().describe("过期时间展示值"), + status: z.enum(["enabled", "disabled"]).describe("密钥状态"), + todayUsage: z.number().describe("今日用量(美元)"), + todayCallCount: z.number().describe("今日调用次数"), + todayTokens: z.number().describe("今日 tokens"), + lastUsedAt: z.string().nullable().optional().describe("最后使用时间"), + lastProviderName: z.string().nullable().optional().describe("最后使用的供应商"), + modelStats: z.array(z.any()).describe("模型统计"), + createdAt: z.string().optional().describe("创建时间"), + createdAtFormatted: z.string().describe("格式化后的创建时间"), + canLoginWebUi: z.boolean().describe("是否允许登录 Web UI"), + limit5hUsd: z.number().nullable().describe("5 小时限额"), + limitDailyUsd: z.number().nullable().describe("每日限额"), + dailyResetMode: z.enum(["fixed", "rolling"]).describe("每日重置模式"), + dailyResetTime: z.string().describe("每日重置时间"), + limitWeeklyUsd: z.number().nullable().describe("周限额"), + limitMonthlyUsd: z.number().nullable().describe("月限额"), + limitTotalUsd: z.number().nullable().optional().describe("总限额"), + limitConcurrentSessions: z.number().describe("并发 Session 上限"), + costResetAt: z.string().nullable().optional().describe("限额重置时间"), + providerGroup: z.string().nullable().optional().describe("密钥供应商分组"), +}); + +const userListItemSchema = z.object({ + id: z.number().describe("用户 ID"), + name: z.string().describe("用户名"), + note: z.string().nullable().optional().describe("备注"), + role: z.enum(["admin", "user"]).describe("用户角色"), + isEnabled: z.boolean().describe("是否启用"), + expiresAt: z.string().nullable().optional().describe("过期时间"), + rpm: z.number().nullable().describe("每分钟请求数限制"), + dailyQuota: z.number().nullable().describe("每日消费额度(美元)"), + providerGroup: z.string().nullable().optional().describe("供应商分组"), + tags: z.array(z.string()).optional().describe("用户标签"), + limit5hUsd: z.number().nullable().optional().describe("5小时消费上限"), + limitWeeklyUsd: z.number().nullable().optional().describe("周消费上限"), + limitMonthlyUsd: z.number().nullable().optional().describe("月消费上限"), + limitTotalUsd: z.number().nullable().optional().describe("总消费上限"), + costResetAt: z.string().nullable().optional().describe("限额重置时间"), + limitConcurrentSessions: z.number().nullable().optional().describe("并发Session上限"), + dailyResetMode: z.enum(["fixed", "rolling"]).optional().describe("每日重置模式"), + dailyResetTime: z.string().optional().describe("每日重置时间"), + allowedClients: z.array(z.string()).optional().describe("允许的客户端"), + blockedClients: z.array(z.string()).optional().describe("禁止的客户端"), + allowedModels: z.array(z.string()).optional().describe("允许的模型"), + keys: z.array(userKeyListItemSchema).describe("用户下的密钥列表"), +}); + +const getUsersBatchRequestSchema = z + .object({ + cursor: z.string().optional(), + limit: z.number().int().positive().optional(), + searchTerm: z.string().optional(), + query: z.string().optional(), + keyword: z.string().optional(), + page: z.number().int().min(0).optional(), + offset: z.number().int().min(0).optional(), + tagFilters: z.array(z.string()).optional(), + keyGroupFilters: z.array(z.string()).optional(), + statusFilter: z + .enum(["all", "active", "expired", "expiringSoon", "enabled", "disabled"]) + .optional(), + sortBy: z + .enum([ + "name", + "tags", + "expiresAt", + "rpm", + "limit5hUsd", + "limitDailyUsd", + "limitWeeklyUsd", + "limitMonthlyUsd", + "createdAt", + ]) + .optional(), + sortOrder: z.enum(["asc", "desc"]).optional(), + }) + .passthrough(); + +const getUsersBatchResponseSchema = z.object({ + users: z.array(userListItemSchema), + nextCursor: z.string().nullable(), + hasMore: z.boolean(), +}); + const { route: getUsersRoute, handler: getUsersHandler } = createActionRoute( "users", "getUsers", userActions.getUsers, { - requestSchema: z.object({}).describe("无需请求参数"), - responseSchema: z.array( - z.object({ - id: z.number().describe("用户 ID"), - name: z.string().describe("用户名"), - note: z.string().nullable().describe("备注"), - role: z.enum(["admin", "user"]).describe("用户角色"), - isEnabled: z.boolean().describe("是否启用"), - expiresAt: z.string().nullable().describe("过期时间"), - rpm: z.number().describe("每分钟请求数限制"), - dailyQuota: z.number().describe("每日消费额度(美元)"), - providerGroup: z.string().nullable().describe("供应商分组"), - tags: z.array(z.string()).describe("用户标签"), - limit5hUsd: z.number().nullable().describe("5小时消费上限"), - limitWeeklyUsd: z.number().nullable().describe("周消费上限"), - limitMonthlyUsd: z.number().nullable().describe("月消费上限"), - limitTotalUsd: z.number().nullable().describe("总消费上限"), - limitConcurrentSessions: z.number().nullable().describe("并发Session上限"), - createdAt: z.string().describe("创建时间"), + requestSchema: z + .object({ + cursor: z.string().optional().describe("游标,兼容旧 offset 游标"), + limit: z.number().int().positive().optional().describe("返回条数上限"), + searchTerm: z.string().optional().describe("搜索用户名/备注/标签/密钥"), + query: z.string().optional().describe("旧版搜索参数别名"), + keyword: z.string().optional().describe("旧版搜索参数别名"), + page: z.number().int().min(0).optional().describe("旧版页码,从 0 或 1 开始"), + offset: z.number().int().min(0).optional().describe("旧版偏移量"), }) - ), + .passthrough() + .describe("兼容旧客户端的可选分页/搜索参数;不传时返回全部用户"), + responseSchema: z.array(userListItemSchema), description: "获取用户列表 (管理员获取所有用户,普通用户仅获取自己)", summary: "获取用户列表", tags: ["用户管理"], allowReadOnlyAccess: true, + argsMapper: (body) => [body], } ); app.openapi(getUsersRoute, getUsersHandler); +const { route: getUsersBatchRoute, handler: getUsersBatchHandler } = createActionRoute( + "users", + "getUsersBatch", + userActions.getUsersBatch, + { + requestSchema: getUsersBatchRequestSchema, + responseSchema: getUsersBatchResponseSchema, + description: "分页获取用户列表(兼容旧客户端 page/offset/query/keyword 参数)", + summary: "分页获取用户列表", + tags: ["用户管理"], + requiredRole: "admin", + argsMapper: (body) => [body], + } +); +app.openapi(getUsersBatchRoute, getUsersBatchHandler); + +const { route: searchUsersRoute, handler: searchUsersHandler } = createActionRoute( + "users", + "searchUsers", + userActions.searchUsers, + { + requestSchema: z + .object({ + searchTerm: z.string().optional(), + query: z.string().optional(), + keyword: z.string().optional(), + }) + .passthrough(), + responseSchema: z.array( + z.object({ + id: z.number(), + name: z.string(), + }) + ), + description: "搜索用户(兼容旧客户端 searchUsers 接口名)", + summary: "搜索用户", + tags: ["用户管理"], + requiredRole: "admin", + argsMapper: (body) => { + const searchTerm = [body.searchTerm, body.query, body.keyword] + .map((value: string | undefined) => value?.trim()) + .find((value): value is string => Boolean(value)); + + return [searchTerm]; + }, + } +); +app.openapi(searchUsersRoute, searchUsersHandler); + const { route: addUserRoute, handler: addUserHandler } = createActionRoute( "users", "addUser", diff --git a/tests/api/api-actions-integrity.test.ts b/tests/api/api-actions-integrity.test.ts index b9c556528..062b369b5 100644 --- a/tests/api/api-actions-integrity.test.ts +++ b/tests/api/api-actions-integrity.test.ts @@ -61,6 +61,8 @@ describe("OpenAPI 端点完整性检查", () => { test("用户管理模块的所有端点应该被注册", () => { const expectedPaths = [ "/api/actions/users/getUsers", + "/api/actions/users/getUsersBatch", + "/api/actions/users/searchUsers", "/api/actions/users/addUser", "/api/actions/users/editUser", "/api/actions/users/removeUser", diff --git a/tests/api/api-endpoints.test.ts b/tests/api/api-endpoints.test.ts index 2655e2e83..f512c9643 100644 --- a/tests/api/api-endpoints.test.ts +++ b/tests/api/api-endpoints.test.ts @@ -164,6 +164,8 @@ describe("API 端点可达性测试", () => { const criticalEndpoints = [ // 用户管理 { module: "users", action: "getUsers" }, + { module: "users", action: "getUsersBatch" }, + { module: "users", action: "searchUsers" }, { module: "users", action: "addUser" }, { module: "users", action: "editUser" }, { module: "users", action: "removeUser" }, diff --git a/tests/api/api-openapi-spec.test.ts b/tests/api/api-openapi-spec.test.ts index 9e5543a7f..c6b648879 100644 --- a/tests/api/api-openapi-spec.test.ts +++ b/tests/api/api-openapi-spec.test.ts @@ -46,6 +46,29 @@ type OpenAPIDocument = { }; }; +type JsonSchemaProperty = { + minimum?: number; + maximum?: number; + properties?: Record; +}; + +function getJsonRequestSchema( + openApiDoc: OpenAPIDocument, + path: string +): JsonSchemaProperty | undefined { + const requestBody = openApiDoc.paths[path]?.post?.requestBody as + | { + content?: { + "application/json"?: { + schema?: JsonSchemaProperty; + }; + }; + } + | undefined; + + return requestBody?.content?.["application/json"]?.schema; +} + describe("OpenAPI 规范验证", () => { let openApiDoc: OpenAPIDocument; @@ -218,4 +241,15 @@ describe("OpenAPI 规范验证", () => { // 但不应该太多(允许 35% 以内) expect(violations.length).toBeLessThan(totalPaths * 0.35); }); + + test("users 列表请求 schema 应与兼容参数归一化保持一致", () => { + for (const path of ["/api/actions/users/getUsers", "/api/actions/users/getUsersBatch"]) { + const schema = getJsonRequestSchema(openApiDoc, path); + const pageSchema = schema?.properties?.page; + const limitSchema = schema?.properties?.limit; + + expect(pageSchema?.minimum).toBe(0); + expect(limitSchema?.maximum).toBeUndefined(); + } + }); }); diff --git a/tests/api/users-search-users-compat.test.ts b/tests/api/users-search-users-compat.test.ts new file mode 100644 index 000000000..0ff9b5113 --- /dev/null +++ b/tests/api/users-search-users-compat.test.ts @@ -0,0 +1,71 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; + +const searchUsersMock = vi.fn(); +const validateAuthTokenMock = vi.fn(); +const runWithAuthSessionMock = vi.fn(); + +vi.mock("@/actions/users", () => ({ + getUsers: vi.fn(), + getUsersBatch: vi.fn(), + searchUsers: searchUsersMock, + addUser: vi.fn(), + editUser: vi.fn(), + removeUser: vi.fn(), + getUserLimitUsage: vi.fn(), +})); + +vi.mock("@/lib/auth", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + AUTH_COOKIE_NAME: "auth-token", + validateAuthToken: validateAuthTokenMock, + runWithAuthSession: runWithAuthSessionMock, + }; +}); + +describe("users searchUsers route compatibility", () => { + beforeEach(() => { + vi.resetModules(); + searchUsersMock.mockReset(); + validateAuthTokenMock.mockReset(); + runWithAuthSessionMock.mockReset(); + + validateAuthTokenMock.mockResolvedValue({ + user: { id: 1, role: "admin" }, + key: { canLoginWebUi: true }, + }); + runWithAuthSessionMock.mockImplementation(async (_session, callback) => callback()); + searchUsersMock.mockResolvedValue({ + ok: true, + data: [{ id: 1, name: "Alice" }], + }); + }); + + test("falls back to trimmed query when searchTerm is blank", async () => { + const { POST } = await import("@/app/api/actions/[...route]/route"); + + const response = await POST( + new Request("http://localhost/api/actions/users/searchUsers", { + method: "POST", + headers: { + "content-type": "application/json", + authorization: "Bearer test-token", + cookie: "auth-token=test-token", + }, + body: JSON.stringify({ + searchTerm: " ", + query: " alice ", + keyword: "bob", + }), + }) + ); + + expect(response.status).toBe(200); + expect(searchUsersMock).toHaveBeenCalledWith("alice"); + await expect(response.json()).resolves.toEqual({ + ok: true, + data: [{ id: 1, name: "Alice" }], + }); + }); +}); diff --git a/tests/unit/users-action-get-users-compat.test.ts b/tests/unit/users-action-get-users-compat.test.ts new file mode 100644 index 000000000..002cacaf7 --- /dev/null +++ b/tests/unit/users-action-get-users-compat.test.ts @@ -0,0 +1,310 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; +import type { User } from "@/types/user"; + +const getSessionMock = vi.fn(); +vi.mock("@/lib/auth", () => ({ + getSession: getSessionMock, +})); + +vi.mock("next/cache", () => ({ + revalidatePath: vi.fn(), +})); + +const getTranslationsMock = vi.fn(async () => (key: string) => key); +const getLocaleMock = vi.fn(async () => "en"); +vi.mock("next-intl/server", () => ({ + getTranslations: getTranslationsMock, + getLocale: getLocaleMock, +})); + +const findUserByIdMock = vi.fn(); +const findUserListBatchMock = vi.fn(); +vi.mock("@/repository/user", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + findUserById: findUserByIdMock, + findUserListBatch: findUserListBatchMock, + }; +}); + +const findKeyListBatchMock = vi.fn(); +const findKeyUsageTodayBatchMock = vi.fn(); +const findKeysStatisticsBatchFromKeysMock = vi.fn(); +vi.mock("@/repository/key", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + findKeyListBatch: findKeyListBatchMock, + findKeyUsageTodayBatch: findKeyUsageTodayBatchMock, + findKeysStatisticsBatchFromKeys: findKeysStatisticsBatchFromKeysMock, + }; +}); + +function makeUser(id: number, name = `user-${id}`): User { + return { + id, + name, + description: `${name}-desc`, + role: "user", + rpm: null, + dailyQuota: null, + providerGroup: null, + tags: [], + createdAt: new Date("2026-03-01T00:00:00.000Z"), + updatedAt: new Date("2026-03-01T00:00:00.000Z"), + deletedAt: undefined, + dailyResetMode: "fixed", + dailyResetTime: "00:00", + isEnabled: true, + expiresAt: null, + allowedClients: [], + blockedClients: [], + allowedModels: [], + }; +} + +describe("getUsers compatibility", () => { + beforeEach(() => { + getSessionMock.mockReset(); + findUserByIdMock.mockReset(); + findUserListBatchMock.mockReset(); + findKeyListBatchMock.mockReset(); + findKeyUsageTodayBatchMock.mockReset(); + findKeysStatisticsBatchFromKeysMock.mockReset(); + + getSessionMock.mockResolvedValue({ + user: { id: 1, role: "admin" }, + key: { canLoginWebUi: true }, + }); + findKeyListBatchMock.mockResolvedValue(new Map()); + findKeyUsageTodayBatchMock.mockResolvedValue(new Map()); + findKeysStatisticsBatchFromKeysMock.mockResolvedValue(new Map()); + }); + + test("loads all admin users instead of stopping at the first 50", async () => { + const firstPageUsers = Array.from({ length: 200 }, (_, index) => makeUser(index + 1)); + const secondPageUser = makeUser(201, "after-first-200"); + + findUserListBatchMock + .mockResolvedValueOnce({ + users: firstPageUsers, + nextCursor: '{"v":"2026-03-01T00:00:00.000Z","id":200}', + hasMore: true, + }) + .mockResolvedValueOnce({ + users: [secondPageUser], + nextCursor: null, + hasMore: false, + }); + + const { getUsers } = await import("@/actions/users"); + + const result = await getUsers(); + + expect(findUserListBatchMock).toHaveBeenNthCalledWith(1, { + cursor: undefined, + searchTerm: undefined, + tagFilters: undefined, + keyGroupFilters: undefined, + statusFilter: undefined, + limit: 200, + sortBy: undefined, + sortOrder: undefined, + }); + expect(findUserListBatchMock).toHaveBeenNthCalledWith(2, { + cursor: '{"v":"2026-03-01T00:00:00.000Z","id":200}', + searchTerm: undefined, + tagFilters: undefined, + keyGroupFilters: undefined, + statusFilter: undefined, + limit: 200, + sortBy: undefined, + sortOrder: undefined, + }); + expect(result).toHaveLength(201); + expect(result.at(-1)?.name).toBe("after-first-200"); + }); + + test("normalizes legacy getUsers page and query params", async () => { + findUserListBatchMock.mockResolvedValueOnce({ + users: [makeUser(51, "xiaolunanbei")], + nextCursor: null, + hasMore: false, + }); + + const { getUsers } = await import("@/actions/users"); + + const result = await getUsers({ + page: 2, + limit: 50, + query: " 小鹿楠贝 ", + }); + + expect(findUserListBatchMock).toHaveBeenCalledWith({ + cursor: "50", + limit: 50, + searchTerm: "小鹿楠贝", + tagFilters: undefined, + keyGroupFilters: undefined, + statusFilter: undefined, + sortBy: undefined, + sortOrder: undefined, + }); + expect(result).toHaveLength(1); + expect(result[0]?.name).toBe("xiaolunanbei"); + }); + + test("falls back to legacy query when searchTerm is blank", async () => { + findUserListBatchMock.mockResolvedValueOnce({ + users: [makeUser(77, "legacy-query-hit")], + nextCursor: null, + hasMore: false, + }); + + const { getUsersBatch } = await import("@/actions/users"); + + await getUsersBatch({ + searchTerm: " ", + query: " alice ", + }); + + expect(findUserListBatchMock).toHaveBeenCalledWith({ + cursor: undefined, + limit: undefined, + searchTerm: "alice", + tagFilters: undefined, + keyGroupFilters: undefined, + statusFilter: undefined, + sortBy: undefined, + sortOrder: undefined, + }); + }); + + test("search-only getUsers requests keep paging until all matches are returned", async () => { + findUserListBatchMock + .mockResolvedValueOnce({ + users: Array.from({ length: 200 }, (_, index) => makeUser(index + 1, `match-${index + 1}`)), + nextCursor: '{"v":"2026-03-01T00:00:00.000Z","id":200}', + hasMore: true, + }) + .mockResolvedValueOnce({ + users: [makeUser(201, "match-201")], + nextCursor: null, + hasMore: false, + }); + + const { getUsers } = await import("@/actions/users"); + + const result = await getUsers({ query: "match" }); + + expect(findUserListBatchMock).toHaveBeenNthCalledWith(1, { + cursor: undefined, + limit: 200, + searchTerm: "match", + tagFilters: undefined, + keyGroupFilters: undefined, + statusFilter: undefined, + sortBy: undefined, + sortOrder: undefined, + }); + expect(findUserListBatchMock).toHaveBeenNthCalledWith(2, { + cursor: '{"v":"2026-03-01T00:00:00.000Z","id":200}', + limit: 200, + searchTerm: "match", + tagFilters: undefined, + keyGroupFilters: undefined, + statusFilter: undefined, + sortBy: undefined, + sortOrder: undefined, + }); + expect(result).toHaveLength(201); + expect(result.at(-1)?.name).toBe("match-201"); + }); + + test("treats whitespace cursor as missing pagination and keeps loading matches", async () => { + findUserListBatchMock + .mockResolvedValueOnce({ + users: Array.from({ length: 200 }, (_, index) => + makeUser(index + 1, `cursor-match-${index + 1}`) + ), + nextCursor: '{"v":"2026-03-01T00:00:00.000Z","id":200}', + hasMore: true, + }) + .mockResolvedValueOnce({ + users: [makeUser(201, "cursor-match-201")], + nextCursor: null, + hasMore: false, + }); + + const { getUsers } = await import("@/actions/users"); + + const result = await getUsers({ + cursor: " ", + query: "cursor-match", + }); + + expect(findUserListBatchMock).toHaveBeenNthCalledWith(1, { + cursor: undefined, + limit: 200, + searchTerm: "cursor-match", + tagFilters: undefined, + keyGroupFilters: undefined, + statusFilter: undefined, + sortBy: undefined, + sortOrder: undefined, + }); + expect(findUserListBatchMock).toHaveBeenNthCalledWith(2, { + cursor: '{"v":"2026-03-01T00:00:00.000Z","id":200}', + limit: 200, + searchTerm: "cursor-match", + tagFilters: undefined, + keyGroupFilters: undefined, + statusFilter: undefined, + sortBy: undefined, + sortOrder: undefined, + }); + expect(result).toHaveLength(201); + expect(result.at(-1)?.name).toBe("cursor-match-201"); + }); + + test("normalizes legacy getUsersBatch keyword and offset params", async () => { + findUserListBatchMock.mockResolvedValueOnce({ + users: [makeUser(88, "keyword-hit")], + nextCursor: null, + hasMore: false, + }); + + const { getUsersBatch } = await import("@/actions/users"); + + const result = await getUsersBatch({ + offset: 75, + limit: 25, + keyword: " key-word ", + }); + + expect(findUserListBatchMock).toHaveBeenCalledWith({ + cursor: "75", + limit: 25, + searchTerm: "key-word", + tagFilters: undefined, + keyGroupFilters: undefined, + statusFilter: undefined, + sortBy: undefined, + sortOrder: undefined, + }); + expect(result).toEqual({ + ok: true, + data: { + users: [ + expect.objectContaining({ + id: 88, + name: "keyword-hit", + }), + ], + nextCursor: null, + hasMore: false, + }, + }); + }); +}); diff --git a/tests/unit/users-action-search-users-for-filter.test.ts b/tests/unit/users-action-search-users-for-filter.test.ts index 6deb37ebb..5fd3d00b2 100644 --- a/tests/unit/users-action-search-users-for-filter.test.ts +++ b/tests/unit/users-action-search-users-for-filter.test.ts @@ -68,4 +68,16 @@ describe("searchUsersForFilter (action)", () => { expect(searchUsersForFilterRepositoryMock).toHaveBeenCalledWith("ali"); expect(result).toEqual({ ok: true, data: [{ id: 1, name: "Alice" }] }); }); + + test("searchUsers alias delegates to searchUsersForFilter", async () => { + getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } }); + searchUsersForFilterRepositoryMock.mockResolvedValue([{ id: 9, name: "Bob" }]); + + const { searchUsers } = await import("@/actions/users"); + + const result = await searchUsers("bob"); + + expect(searchUsersForFilterRepositoryMock).toHaveBeenCalledWith("bob"); + expect(result).toEqual({ ok: true, data: [{ id: 9, name: "Bob" }] }); + }); }); From 96b116ef86f59f706affd791530e4ac1904a3a16 Mon Sep 17 00:00:00 2001 From: Ding <44717411+ding113@users.noreply.github.com> Date: Mon, 23 Mar 2026 14:09:40 +0800 Subject: [PATCH 11/11] fix: address PR #968 review feedback (#970) * fix: address review feedback for release v0.6.6 Amp-Thread-ID: https://ampcode.com/threads/T-019d18a9-9c99-7379-879c-3f50eb6612fe Co-authored-by: Amp * fix: address remaining review nits for PR 968 Amp-Thread-ID: https://ampcode.com/threads/T-019d18a9-9c99-7379-879c-3f50eb6612fe Co-authored-by: Amp * fix: address follow-up review findings on PR 970 Amp-Thread-ID: https://ampcode.com/threads/T-019d18a9-9c99-7379-879c-3f50eb6612fe Co-authored-by: Amp * fix: address second-round PR review feedback Amp-Thread-ID: https://ampcode.com/threads/T-019d18a9-9c99-7379-879c-3f50eb6612fe Co-authored-by: Amp --------- Co-authored-by: Amp --- .github/workflows/dev.yml | 1 + .../ja/settings/providers/form/sections.json | 2 +- src/actions/my-usage.ts | 75 +++++++ src/actions/usage-logs.ts | 11 ++ src/actions/users.ts | 48 +++-- .../_components/rate-limit-top-users.tsx | 14 +- .../_components/user/forms/add-key-form.tsx | 11 +- .../_components/user/forms/edit-key-form.tsx | 12 +- .../_components/user/forms/user-form.tsx | 1 + .../logs/_components/usage-logs-filters.tsx | 15 +- .../[locale]/dashboard/quotas/users/page.tsx | 30 ++- .../_components/rate-limit-filters.tsx | 43 +++- .../dashboard/users/users-page-client.tsx | 3 +- .../_components/usage-logs-section.test.tsx | 40 +++- .../_components/usage-logs-section.tsx | 3 - .../my-usage/_components/usage-logs-table.tsx | 9 +- src/app/api/actions/[...route]/route.ts | 52 +++++ src/app/v1/_lib/proxy/forwarder.ts | 79 +++++--- src/app/v1/_lib/proxy/provider-selector.ts | 10 +- src/app/v1/_lib/proxy/session.ts | 10 +- src/repository/usage-logs.ts | 185 +++++++++++++++++- src/repository/user.ts | 7 +- tests/api/api-actions-integrity.test.ts | 1 + tests/api/api-openapi-spec.test.ts | 1 + .../integration/billing-model-source.test.ts | 1 + .../actions/my-usage-date-range-dst.test.ts | 32 +++ .../usage-logs-export-retry-count.test.ts | 6 +- ...dashboard-logs-export-progress-ui.test.tsx | 17 +- .../proxy-forwarder-hedge-first-byte.test.ts | 8 +- tests/unit/proxy/session.test.ts | 7 +- .../usage-logs-slim-pagination.test.ts | 125 ++++++++++++ ...ers-action-search-users-for-filter.test.ts | 4 +- 32 files changed, 756 insertions(+), 107 deletions(-) create mode 100644 tests/unit/repository/usage-logs-slim-pagination.test.ts diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index ac4047e45..6e06f3664 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -67,6 +67,7 @@ jobs: bun install bun run typecheck bun run format + bun run lint - name: Commit formatted code run: | diff --git a/messages/ja/settings/providers/form/sections.json b/messages/ja/settings/providers/form/sections.json index 1a07d581c..57b6ec123 100644 --- a/messages/ja/settings/providers/form/sections.json +++ b/messages/ja/settings/providers/form/sections.json @@ -104,7 +104,7 @@ "formats": "http://、https://、socks5://、socks4:// プロトコルに対応。認証が必要な場合は http://user:password@host:port 形式を使用してください(パスワード内の特殊文字は URL エンコードが必要です。例: # → %23)", "label": "プロキシ URL", "optional": "(任意)", - "placeholder": "例: http://proxy.example.com:8080 または http://user:pass@proxy:8080" + "placeholder": "例: http://proxy.example.com:8080、http://user:pass@proxy:8080、socks5://proxy.example.com:1080" } }, "rateLimit": { diff --git a/src/actions/my-usage.ts b/src/actions/my-usage.ts index 0a37de5d4..d5db9bfe7 100644 --- a/src/actions/my-usage.ts +++ b/src/actions/my-usage.ts @@ -17,6 +17,7 @@ import { EXCLUDE_WARMUP_CONDITION } from "@/repository/_shared/message-request-c import { getSystemSettings } from "@/repository/system-config"; import { findUsageLogsForKeyBatch, + findUsageLogsForKeySlim, getDistinctEndpointsForKey, getDistinctModelsForKey, type UsageLogSlimBatchResult, @@ -177,6 +178,28 @@ export interface MyUsageLogsBatchResult { billingModelSource: BillingModelSource; } +export interface MyUsageLogsFilters { + startDate?: string; + endDate?: string; + sessionId?: string; + model?: string; + statusCode?: number; + excludeStatusCode200?: boolean; + endpoint?: string; + minRetryCount?: number; + page?: number; + pageSize?: number; +} + +export interface MyUsageLogsResult { + logs: MyUsageLogEntry[]; + total: number; + page: number; + pageSize: number; + currencyCode: CurrencyCode; + billingModelSource: BillingModelSource; +} + // Infinity means "all time" - no date filter applied to the query const ALL_TIME_MAX_AGE_DAYS = Infinity; @@ -518,6 +541,58 @@ function mapMyUsageLogEntries( }); } +export async function getMyUsageLogs( + filters: MyUsageLogsFilters = {} +): Promise> { + try { + const session = await getSession({ allowReadOnlyAccess: true }); + if (!session) return { ok: false, error: "Unauthorized" }; + + const settings = await getSystemSettings(); + const timezone = await resolveSystemTimezone(); + const { startTime, endTime } = parseDateRangeInServerTimezone( + filters.startDate, + filters.endDate, + timezone + ); + const parsedPageSize = Number(filters.pageSize); + const pageSize = + Number.isFinite(parsedPageSize) && parsedPageSize > 0 + ? Math.min(Math.trunc(parsedPageSize), 100) + : 20; + const parsedPage = Number(filters.page); + const page = Number.isFinite(parsedPage) && parsedPage > 0 ? Math.trunc(parsedPage) : 1; + const result = await findUsageLogsForKeySlim({ + keyString: session.key.key, + sessionId: filters.sessionId, + startTime, + endTime, + model: filters.model, + statusCode: filters.statusCode, + excludeStatusCode200: filters.excludeStatusCode200, + endpoint: filters.endpoint, + minRetryCount: filters.minRetryCount, + page, + pageSize, + }); + + return { + ok: true, + data: { + logs: mapMyUsageLogEntries(result, settings.billingModelSource), + total: result.total, + page, + pageSize, + currencyCode: settings.currencyDisplay, + billingModelSource: settings.billingModelSource, + }, + }; + } catch (error) { + logger.error("[my-usage] getMyUsageLogs failed", { error, filters }); + return { ok: false, error: "Failed to get usage logs" }; + } +} + export async function getMyUsageLogsBatch( filters: MyUsageLogsBatchFilters = {} ): Promise> { diff --git a/src/actions/usage-logs.ts b/src/actions/usage-logs.ts index 59318db31..25a96fc3f 100644 --- a/src/actions/usage-logs.ts +++ b/src/actions/usage-logs.ts @@ -43,6 +43,7 @@ let filterOptionsCache: { const USAGE_LOGS_EXPORT_BATCH_SIZE = 500; const USAGE_LOGS_EXPORT_JOB_TTL_MS = 15 * 60 * 1000; const USAGE_LOGS_EXPORT_JOB_TTL_SECONDS = Math.floor(USAGE_LOGS_EXPORT_JOB_TTL_MS / 1000); +const USAGE_LOGS_EXPORT_PROGRESS_UPDATE_INTERVAL_MS = 800; const CSV_HEADERS = [ "Time", "User", @@ -228,7 +229,17 @@ async function runUsageLogsExportJob( }); try { + let lastProgressUpdateAt = 0; const csv = await buildUsageLogsExportCsv(filters, async (progress) => { + const now = Date.now(); + if ( + progress.progressPercent < 100 && + now - lastProgressUpdateAt < USAGE_LOGS_EXPORT_PROGRESS_UPDATE_INTERVAL_MS + ) { + return; + } + lastProgressUpdateAt = now; + const currentJob = await usageLogsExportStatusStore.get(jobId); if (!currentJob) { return; diff --git a/src/actions/users.ts b/src/actions/users.ts index 1bd9a8aac..4f782ac8f 100644 --- a/src/actions/users.ts +++ b/src/actions/users.ts @@ -68,6 +68,12 @@ export interface GetUsersBatchParams { const USER_LIST_DEFAULT_LIMIT = 50; const USER_LIST_MAX_LIMIT = 200; +const SEARCH_USERS_MAX_LIMIT = 5000; + +type UserActionSession = { + user: { id: number }; + key: { canLoginWebUi: boolean }; +}; function normalizeLegacySearchTerm(params?: GetUsersBatchParams): string | undefined { for (const candidate of [params?.searchTerm, params?.query, params?.keyword]) { @@ -162,6 +168,20 @@ async function loadAllUsersForAdmin(baseParams?: GetUsersBatchParams): Promise, + isAdmin: boolean +): boolean { + return session.key.canLoginWebUi && (isAdmin || session.user.id === targetUser.id); +} + /** * 批量获取用户列表的返回结果。 */ @@ -387,10 +407,7 @@ export async function getUsers(params?: GetUsersBatchParams): Promise { const stats = statisticsLookup.get(key.id); - // 仅允许具备 Web UI 登录权限的会话查看/复制完整密钥,避免只读 Key 通过 - // allowReadOnlyAccess 进入 actions 后暴露 fullKey。 - const canUserManageKey = - session.key.canLoginWebUi && (isAdmin || session.user.id === user.id); + const canUserManageKey = canExposeFullKey(session, user, isAdmin); return { id: key.id, name: key.name, @@ -469,7 +486,8 @@ export async function getUsers(params?: GetUsersBatchParams): Promise>> { try { const tError = await getTranslations("errors"); @@ -491,7 +509,10 @@ export async function searchUsersForFilter( }; } - const users = await searchUsersForFilterRepository(searchTerm); + const users = await searchUsersForFilterRepository( + searchTerm, + normalizeSearchUsersLimit(limit) + ); return { ok: true, data: users }; } catch (error) { logger.error("Failed to search users for filter:", error); @@ -501,9 +522,10 @@ export async function searchUsersForFilter( } export async function searchUsers( - searchTerm?: string + searchTerm?: string, + limit?: number ): Promise>> { - return searchUsersForFilter(searchTerm); + return searchUsersForFilter(searchTerm, limit); } /** @@ -628,6 +650,7 @@ export async function getUsersBatch( const keys = keysMap.get(user.id) || []; const usageRecords = usageMap.get(user.id) || []; const keyStatistics = statisticsMap.get(user.id) || []; + const canUserManageKey = canExposeFullKey(session, user, true); const usageLookup = new Map( usageRecords.map((item) => [ @@ -665,8 +688,8 @@ export async function getUsersBatch( id: key.id, name: key.name, maskedKey: maskKey(key.key), - fullKey: key.key, - canCopy: true, + fullKey: canUserManageKey ? key.key : undefined, + canCopy: canUserManageKey, expiresAt: key.expiresAt ? key.expiresAt.toISOString().split("T")[0] : t("neverExpires"), @@ -781,6 +804,7 @@ export async function getUsersBatchCore( const userDisplays: UserDisplay[] = users.map((user) => { const keys = keysMap.get(user.id) || []; + const canUserManageKey = canExposeFullKey(session, user, true); return { id: user.id, @@ -808,8 +832,8 @@ export async function getUsersBatchCore( id: key.id, name: key.name, maskedKey: maskKey(key.key), - fullKey: key.key, - canCopy: true, + fullKey: canUserManageKey ? key.key : undefined, + canCopy: canUserManageKey, expiresAt: key.expiresAt ? key.expiresAt.toISOString().split("T")[0] : t("neverExpires"), status: key.isEnabled ? "enabled" : ("disabled" as const), createdAt: key.createdAt, diff --git a/src/app/[locale]/dashboard/_components/rate-limit-top-users.tsx b/src/app/[locale]/dashboard/_components/rate-limit-top-users.tsx index aaadcfb74..826098fb0 100644 --- a/src/app/[locale]/dashboard/_components/rate-limit-top-users.tsx +++ b/src/app/[locale]/dashboard/_components/rate-limit-top-users.tsx @@ -28,9 +28,11 @@ type SortDirection = "asc" | "desc"; */ export function RateLimitTopUsers({ data }: RateLimitTopUsersProps) { const t = useTranslations("dashboard.rateLimits.topUsers"); + const tRateLimits = useTranslations("dashboard.rateLimits"); const locale = useLocale(); const [users, setUsers] = React.useState>([]); const [loading, setLoading] = React.useState(true); + const [loadError, setLoadError] = React.useState(false); const [sortField, setSortField] = React.useState("count"); const [sortDirection, setSortDirection] = React.useState("desc"); @@ -38,15 +40,18 @@ export function RateLimitTopUsers({ data }: RateLimitTopUsersProps) { React.useEffect(() => { let cancelled = false; - void searchUsers() + void searchUsers(undefined, 5000) .then((result) => { if (!cancelled) { setUsers(result.ok ? result.data : []); + setLoadError(!result.ok); } }) - .catch(() => { + .catch((error) => { + console.error("RateLimitTopUsers: failed to load users", error); if (!cancelled) { setUsers([]); + setLoadError(true); } }) .finally(() => { @@ -110,10 +115,13 @@ export function RateLimitTopUsers({ data }: RateLimitTopUsersProps) {
) : tableData.length === 0 ? (
- {t("noData")} + {loadError ? tRateLimits("error") : t("noData")}
) : (
+ {loadError ? ( +
{tRateLimits("error")}
+ ) : null} diff --git a/src/app/[locale]/dashboard/_components/user/forms/add-key-form.tsx b/src/app/[locale]/dashboard/_components/user/forms/add-key-form.tsx index 2d05b5960..0754aa389 100644 --- a/src/app/[locale]/dashboard/_components/user/forms/add-key-form.tsx +++ b/src/app/[locale]/dashboard/_components/user/forms/add-key-form.tsx @@ -127,12 +127,11 @@ export function AddKeyForm({ userId, user, isAdmin = false, onSuccess }: AddKeyF const handleProviderGroupChange = useCallback( (newValue: string) => { const groups = parseProviderGroups(newValue); - if (groups.length > 1 && groups.includes(PROVIDER_GROUP.DEFAULT)) { - const withoutDefault = groups.filter((g) => g !== PROVIDER_GROUP.DEFAULT); - form.setValue("providerGroup", withoutDefault.join(",")); - } else { - form.setValue("providerGroup", newValue); - } + const normalizedGroups = + groups.length > 1 && groups.includes(PROVIDER_GROUP.DEFAULT) + ? groups.filter((g) => g !== PROVIDER_GROUP.DEFAULT) + : groups; + form.setValue("providerGroup", normalizedGroups.join(",")); }, [form] ); diff --git a/src/app/[locale]/dashboard/_components/user/forms/edit-key-form.tsx b/src/app/[locale]/dashboard/_components/user/forms/edit-key-form.tsx index 1a2698447..6251550b5 100644 --- a/src/app/[locale]/dashboard/_components/user/forms/edit-key-form.tsx +++ b/src/app/[locale]/dashboard/_components/user/forms/edit-key-form.tsx @@ -185,12 +185,11 @@ export function EditKeyForm({ keyData, user, isAdmin = false, onSuccess }: EditK const handleProviderGroupChange = useCallback( (newValue: string) => { const groups = parseProviderGroups(newValue); - if (groups.length > 1 && groups.includes(PROVIDER_GROUP.DEFAULT)) { - const withoutDefault = groups.filter((g) => g !== PROVIDER_GROUP.DEFAULT); - form.setValue("providerGroup", withoutDefault.join(",")); - } else { - form.setValue("providerGroup", newValue); - } + const normalizedGroups = + groups.length > 1 && groups.includes(PROVIDER_GROUP.DEFAULT) + ? groups.filter((g) => g !== PROVIDER_GROUP.DEFAULT) + : groups; + form.setValue("providerGroup", normalizedGroups.join(",")); }, [form] ); @@ -262,6 +261,7 @@ export function EditKeyForm({ keyData, user, isAdmin = false, onSuccess }: EditK : t("providerGroup.description") } suggestions={providerGroupSuggestions} + // Provider groups intentionally allow shared parser semantics without extra format restrictions. validateTag={() => true} onInvalidTag={(_tag, reason) => { const messages: Record = { diff --git a/src/app/[locale]/dashboard/_components/user/forms/user-form.tsx b/src/app/[locale]/dashboard/_components/user/forms/user-form.tsx index e0f1ba852..db0289ee8 100644 --- a/src/app/[locale]/dashboard/_components/user/forms/user-form.tsx +++ b/src/app/[locale]/dashboard/_components/user/forms/user-form.tsx @@ -217,6 +217,7 @@ export function UserForm({ user, onSuccess, currentUser }: UserFormProps) { placeholder={tForm("providerGroup.placeholder")} description={tForm("providerGroup.description")} suggestions={providerGroupSuggestions} + // Provider groups intentionally accept shared parser output without extra format validation. validateTag={() => true} onInvalidTag={(_tag, reason) => { const messages: Record = { diff --git a/src/app/[locale]/dashboard/logs/_components/usage-logs-filters.tsx b/src/app/[locale]/dashboard/logs/_components/usage-logs-filters.tsx index 83e67a6f4..9f1e70f89 100644 --- a/src/app/[locale]/dashboard/logs/_components/usage-logs-filters.tsx +++ b/src/app/[locale]/dashboard/logs/_components/usage-logs-filters.tsx @@ -13,6 +13,7 @@ import { } from "@/actions/usage-logs"; import { Button } from "@/components/ui/button"; import { Progress } from "@/components/ui/progress"; +import { getErrorMessage } from "@/lib/utils/error-messages"; import type { Key } from "@/types/key"; import type { ProviderDisplay } from "@/types/provider"; import { ActiveFiltersDisplay } from "./filters/active-filters-display"; @@ -73,6 +74,7 @@ export function UsageLogsFilters({ serverTimeZone, }: UsageLogsFiltersProps) { const t = useTranslations("dashboard"); + const tErrors = useTranslations("errors"); const [localFilters, setLocalFilters] = useState(filters); const [isExporting, setIsExporting] = useState(false); @@ -190,7 +192,7 @@ export function UsageLogsFilters({ }); try { - const exportFilters = sanitizeFilters(filters); + const exportFilters = sanitizeFilters(localFilters); const startResult = await startUsageLogsExport(exportFilters); if (exportRunIdRef.current !== runId) { return; @@ -198,7 +200,8 @@ export function UsageLogsFilters({ if (!startResult.ok) { setExportStatus(null); - toast.error(startResult.error || t("logs.filters.exportError")); + console.error("Failed to start usage logs export", startResult.error); + toast.error(t("logs.filters.exportError")); return; } @@ -224,7 +227,7 @@ export function UsageLogsFilters({ if (!statusResult.ok) { setExportStatus(null); - toast.error(statusResult.error || t("logs.filters.exportError")); + toast.error(t("logs.filters.exportError")); return; } @@ -248,7 +251,11 @@ export function UsageLogsFilters({ } if (!downloadResult.ok) { - toast.error(downloadResult.error || t("logs.filters.exportError")); + toast.error( + downloadResult.errorCode + ? getErrorMessage(tErrors, downloadResult.errorCode, downloadResult.errorParams) + : t("logs.filters.exportError") + ); return; } diff --git a/src/app/[locale]/dashboard/quotas/users/page.tsx b/src/app/[locale]/dashboard/quotas/users/page.tsx index ee52b8b68..3f68ff410 100644 --- a/src/app/[locale]/dashboard/quotas/users/page.tsx +++ b/src/app/[locale]/dashboard/quotas/users/page.tsx @@ -1,7 +1,7 @@ import { Info } from "lucide-react"; import { getTranslations } from "next-intl/server"; import { Suspense } from "react"; -import { getUserLimitUsage, getUsers } from "@/actions/users"; +import { getUserLimitUsage, getUsersBatch } from "@/actions/users"; import { QuotaToolbar } from "@/components/quota/quota-toolbar"; import { Alert, AlertDescription } from "@/components/ui/alert"; import { Link, redirect } from "@/i18n/routing"; @@ -9,6 +9,7 @@ import { getSession } from "@/lib/auth"; import { resolveKeyCostResetAt } from "@/lib/rate-limit/cost-reset-utils"; import { sumKeyTotalCostBatchByIds, sumUserTotalCostBatch } from "@/repository/statistics"; import { getSystemSettings } from "@/repository/system-config"; +import type { UserDisplay } from "@/types/user"; import { UsersQuotaSkeleton } from "../_components/users-quota-skeleton"; import type { UserKeyWithUsage, UserQuotaWithUsage } from "./_components/types"; import { UsersQuotaClient } from "./_components/users-quota-client"; @@ -17,7 +18,32 @@ import { UsersQuotaClient } from "./_components/users-quota-client"; export const dynamic = "force-dynamic"; async function getUsersWithQuotas(): Promise { - const users = await getUsers(); + const collectedUsers: UserDisplay[] = []; + const MAX_USERS_FOR_QUOTAS = 2000; + const MAX_ITERATIONS = Math.ceil(MAX_USERS_FOR_QUOTAS / 200) + 1; + let cursor: string | undefined; + let iterations = 0; + + while (collectedUsers.length < MAX_USERS_FOR_QUOTAS && iterations < MAX_ITERATIONS) { + iterations += 1; + const result = await getUsersBatch({ cursor, limit: 200 }); + if (!result.ok) { + throw new Error(result.error); + } + + collectedUsers.push(...result.data.users); + if (!result.data.hasMore || !result.data.nextCursor) { + break; + } + + cursor = result.data.nextCursor; + } + + if (iterations >= MAX_ITERATIONS) { + console.warn("getUsersWithQuotas: reached max iterations, results may be incomplete"); + } + + const users = collectedUsers; const allUserIds = users.map((u) => u.id); const allKeyIds = users.flatMap((u) => u.keys.map((k) => k.id)); diff --git a/src/app/[locale]/dashboard/rate-limits/_components/rate-limit-filters.tsx b/src/app/[locale]/dashboard/rate-limits/_components/rate-limit-filters.tsx index 865658226..ad84a931c 100644 --- a/src/app/[locale]/dashboard/rate-limits/_components/rate-limit-filters.tsx +++ b/src/app/[locale]/dashboard/rate-limits/_components/rate-limit-filters.tsx @@ -43,6 +43,7 @@ export function RateLimitFilters({ disabled = false, }: RateLimitFiltersProps) { const t = useTranslations("dashboard.rateLimits.filters"); + const tRateLimits = useTranslations("dashboard.rateLimits"); const [userId, setUserId] = React.useState(initialFilters.user_id); const [providerId, setProviderId] = React.useState( @@ -58,6 +59,8 @@ export function RateLimitFilters({ const [providers, setProviders] = React.useState>([]); const [loadingUsers, setLoadingUsers] = React.useState(true); const [loadingProviders, setLoadingProviders] = React.useState(true); + const [usersLoadError, setUsersLoadError] = React.useState(false); + const [providersLoadError, setProvidersLoadError] = React.useState(false); // 加载用户列表 React.useEffect(() => { @@ -67,11 +70,14 @@ export function RateLimitFilters({ .then((result) => { if (!cancelled) { setUsers(result.ok ? result.data : []); + setUsersLoadError(!result.ok); } }) - .catch(() => { + .catch((error) => { + console.error("RateLimitFilters: failed to load users", error); if (!cancelled) { setUsers([]); + setUsersLoadError(true); } }) .finally(() => { @@ -87,10 +93,31 @@ export function RateLimitFilters({ // 加载供应商列表 React.useEffect(() => { - getProviders().then((providerList) => { - setProviders(providerList); - setLoadingProviders(false); - }); + let cancelled = false; + + void getProviders() + .then((providerList) => { + if (!cancelled) { + setProviders(providerList); + setProvidersLoadError(false); + } + }) + .catch((error) => { + console.error("RateLimitFilters: failed to load providers", error); + if (!cancelled) { + setProviders([]); + setProvidersLoadError(true); + } + }) + .finally(() => { + if (!cancelled) { + setLoadingProviders(false); + } + }); + + return () => { + cancelled = true; + }; }, []); // 应用过滤器 @@ -154,6 +181,9 @@ export function RateLimitFilters({ {/* 用户选择器 */}
+ {usersLoadError ? ( +

{tRateLimits("error")}

+ ) : null} setProviderId(value === "all" ? undefined : Number(value))} diff --git a/src/app/[locale]/dashboard/users/users-page-client.tsx b/src/app/[locale]/dashboard/users/users-page-client.tsx index 971cc51e7..8b8cc8347 100644 --- a/src/app/[locale]/dashboard/users/users-page-client.tsx +++ b/src/app/[locale]/dashboard/users/users-page-client.tsx @@ -35,8 +35,7 @@ import { CreateUserDialog } from "../_components/user/create-user-dialog"; import { UserManagementTable } from "../_components/user/user-management-table"; /** - * Split comma-separated tags into an array of trimmed, non-empty strings. - * This matches the server-side providerGroup handling in provider-selector.ts + * Normalize provider-group tags with the shared parser to keep client/server behavior aligned. */ function splitTags(value?: string | null): string[] { return parseProviderGroups(value); diff --git a/src/app/[locale]/my-usage/_components/usage-logs-section.test.tsx b/src/app/[locale]/my-usage/_components/usage-logs-section.test.tsx index 9f05d63aa..30e944429 100644 --- a/src/app/[locale]/my-usage/_components/usage-logs-section.test.tsx +++ b/src/app/[locale]/my-usage/_components/usage-logs-section.test.tsx @@ -74,14 +74,33 @@ import { UsageLogsSection } from "./usage-logs-section"; describe("my-usage usage logs section", () => { test("uses infinite query instead of the old page-based getMyUsageLogs flow", async () => { - mocks.useInfiniteQuery.mockReturnValue({ - data: { pages: [{ logs: [], nextCursor: null, hasMore: false }] }, - fetchNextPage: vi.fn(), - hasNextPage: false, - isFetchingNextPage: false, - isLoading: false, - isError: false, - error: null, + let capturedQueryFn: + | ((context: { + pageParam?: { createdAt: string; id: number } | undefined; + }) => Promise) + | undefined; + + mocks.useInfiniteQuery.mockImplementation((options: { queryFn: typeof capturedQueryFn }) => { + capturedQueryFn = options.queryFn; + return { + data: { pages: [{ logs: [], nextCursor: null, hasMore: false }] }, + fetchNextPage: vi.fn(), + hasNextPage: false, + isFetchingNextPage: false, + isLoading: false, + isError: false, + error: null, + }; + }); + mocks.getMyUsageLogsBatch.mockResolvedValue({ + ok: true, + data: { + logs: [], + nextCursor: null, + hasMore: false, + currencyCode: "USD", + billingModelSource: "original", + }, }); mocks.getMyUsageLogs.mockResolvedValue({ ok: true, @@ -98,7 +117,12 @@ describe("my-usage usage logs section", () => { root.render(); }); + await act(async () => { + await capturedQueryFn?.({ pageParam: undefined }); + }); + expect(mocks.useInfiniteQuery).toHaveBeenCalled(); + expect(mocks.getMyUsageLogsBatch).toHaveBeenCalled(); expect(mocks.getMyUsageLogs).not.toHaveBeenCalled(); await act(async () => { diff --git a/src/app/[locale]/my-usage/_components/usage-logs-section.tsx b/src/app/[locale]/my-usage/_components/usage-logs-section.tsx index 9c048594a..f916198af 100644 --- a/src/app/[locale]/my-usage/_components/usage-logs-section.tsx +++ b/src/app/[locale]/my-usage/_components/usage-logs-section.tsx @@ -117,7 +117,6 @@ export function UsageLogsSection({ error, isRefetching = false, } = query; - const refetch = query.refetch ?? (async (): Promise => undefined); const logs = useMemo(() => data?.pages.flatMap((page) => page.logs) ?? [], [data]); const latestPage = data?.pages[0]; @@ -168,7 +167,6 @@ export function UsageLogsSection({ const handleApply = () => { const nextFilters = { ...draftFilters }; if (JSON.stringify(nextFilters) === JSON.stringify(appliedFilters)) { - void refetch(); return; } setAppliedFilters(nextFilters); @@ -177,7 +175,6 @@ export function UsageLogsSection({ const handleReset = () => { setDraftFilters({}); if (Object.keys(appliedFilters).length === 0) { - void refetch(); return; } setAppliedFilters({}); diff --git a/src/app/[locale]/my-usage/_components/usage-logs-table.tsx b/src/app/[locale]/my-usage/_components/usage-logs-table.tsx index a55b9a1ec..3693742fc 100644 --- a/src/app/[locale]/my-usage/_components/usage-logs-table.tsx +++ b/src/app/[locale]/my-usage/_components/usage-logs-table.tsx @@ -201,13 +201,8 @@ export function UsageLogsTable({ }} className="flex items-center justify-center border-b" > - {errorMessage && onLoadMore ? ( -
- {errorMessage} - -
+ {errorMessage ? ( + {errorMessage} ) : ( )} diff --git a/src/app/api/actions/[...route]/route.ts b/src/app/api/actions/[...route]/route.ts index fa11cb663..ac20853fd 100644 --- a/src/app/api/actions/[...route]/route.ts +++ b/src/app/api/actions/[...route]/route.ts @@ -1173,6 +1173,58 @@ const { route: getMyTodayStatsRoute, handler: getMyTodayStatsHandler } = createA ); app.openapi(getMyTodayStatsRoute, getMyTodayStatsHandler); +const { route: getMyUsageLogsRoute, handler: getMyUsageLogsHandler } = createActionRoute( + "my-usage", + "getMyUsageLogs", + myUsageActions.getMyUsageLogs, + { + requestSchema: z.object({ + startDate: z.string().optional().describe("开始日期(YYYY-MM-DD,可为空)"), + endDate: z.string().optional().describe("结束日期(YYYY-MM-DD,可为空)"), + sessionId: z.string().optional(), + model: z.string().optional(), + endpoint: z.string().optional(), + statusCode: z.number().optional(), + excludeStatusCode200: z.boolean().optional(), + minRetryCount: z.number().int().nonnegative().optional(), + page: z.number().int().positive().default(1).optional(), + pageSize: z.number().int().positive().max(100).default(20).optional(), + }), + responseSchema: z.object({ + logs: z.array( + z.object({ + id: z.number(), + createdAt: z.string().nullable(), + model: z.string().nullable(), + billingModel: z.string().nullable(), + modelRedirect: z.string().nullable(), + inputTokens: z.number(), + outputTokens: z.number(), + cost: z.number(), + statusCode: z.number().nullable(), + duration: z.number().nullable(), + endpoint: z.string().nullable(), + cacheCreationInputTokens: z.number().nullable(), + cacheReadInputTokens: z.number().nullable(), + cacheCreation5mInputTokens: z.number().nullable(), + cacheCreation1hInputTokens: z.number().nullable(), + cacheTtlApplied: z.string().nullable(), + }) + ), + total: z.number(), + page: z.number().int(), + pageSize: z.number().int(), + currencyCode: z.string(), + billingModelSource: z.enum(["original", "redirected"]), + }), + description: "获取当前用户的使用日志(兼容旧版 page/pageSize 分页接口)", + summary: "获取当前用户的使用日志(兼容旧版)", + tags: ["使用日志分析"], + allowReadOnlyAccess: true, + } +); +app.openapi(getMyUsageLogsRoute, getMyUsageLogsHandler); + const { route: getMyUsageLogsBatchRoute, handler: getMyUsageLogsBatchHandler } = createActionRoute( "my-usage", "getMyUsageLogsBatch", diff --git a/src/app/v1/_lib/proxy/forwarder.ts b/src/app/v1/_lib/proxy/forwarder.ts index b6814f9cf..dc7c87c9d 100644 --- a/src/app/v1/_lib/proxy/forwarder.ts +++ b/src/app/v1/_lib/proxy/forwarder.ts @@ -3001,7 +3001,7 @@ export class ProxyForwarder { } attempts.delete(attempt); if (reason === "hedge_loser") { - attempt.session.addProviderToChain(attempt.provider, { + session.addProviderToChain(attempt.provider, { ...attempt.endpointAudit, reason: "hedge_loser_cancelled", attemptNumber: attempt.sequence, @@ -3009,15 +3009,49 @@ export class ProxyForwarder { } try { attempt.responseController?.abort(new Error(reason)); - } catch { - // ignore + } catch (abortError) { + logger.debug("ProxyForwarder: hedge attempt abort failed", { + error: abortError instanceof Error ? abortError.message : String(abortError), + reason, + sessionId: attempt.session.sessionId ?? null, + providerId: attempt.provider.id, + providerName: attempt.provider.name, + }); } const readerCancel = attempt.reader?.cancel(); - readerCancel?.catch(() => { - // ignore + readerCancel?.catch((cancelError) => { + logger.debug("ProxyForwarder: hedge attempt reader cancel failed", { + error: cancelError instanceof Error ? cancelError.message : String(cancelError), + reason, + sessionId: attempt.session.sessionId ?? null, + providerId: attempt.provider.id, + providerName: attempt.provider.name, + }); }); }; + const armAttemptThreshold = (attempt: StreamingHedgeAttempt) => { + if (attempt.thresholdTimer) { + clearTimeout(attempt.thresholdTimer); + attempt.thresholdTimer = null; + } + attempt.thresholdTriggered = false; + + if (attempt.firstByteTimeoutMs <= 0) return; + + attempt.thresholdTimer = setTimeout(() => { + if (settled || attempt.settled || attempt.thresholdTriggered) return; + attempt.thresholdTriggered = true; + session.addProviderToChain(attempt.provider, { + ...attempt.endpointAudit, + reason: "hedge_triggered", + attemptNumber: attempt.sequence, + circuitState: getCircuitState(attempt.provider.id), + }); + void launchAlternative(); + }, attempt.firstByteTimeoutMs); + }; + const abortAllAttempts = (winner?: StreamingHedgeAttempt, reason: string = "hedge_loser") => { for (const attempt of Array.from(attempts)) { if (winner && attempt === winner) continue; @@ -3080,7 +3114,7 @@ export class ProxyForwarder { const runAttempt = (attempt: StreamingHedgeAttempt) => { const providerForRequest = - attempt.provider.firstByteTimeoutStreamingMs > 0 + attempt.firstByteTimeoutMs > 0 ? { ...attempt.provider, firstByteTimeoutStreamingMs: 0 } : attempt.provider; @@ -3096,12 +3130,22 @@ export class ProxyForwarder { const attemptRuntime = attempt.session as ProxySessionWithAttemptRuntime; try { attemptRuntime.responseController?.abort(new Error("hedge_loser")); - } catch { - // ignore + } catch (abortError) { + logger.debug("ProxyForwarder: hedge loser abort failed", { + error: abortError instanceof Error ? abortError.message : String(abortError), + sessionId: attempt.session.sessionId ?? null, + providerId: attempt.provider.id, + providerName: attempt.provider.name, + }); } const cancelPromise = response.body?.cancel("hedge_loser"); - cancelPromise?.catch(() => { - // ignore + cancelPromise?.catch((cancelError) => { + logger.debug("ProxyForwarder: hedge loser body cancel failed", { + error: cancelError instanceof Error ? cancelError.message : String(cancelError), + sessionId: attempt.session.sessionId ?? null, + providerId: attempt.provider.id, + providerName: attempt.provider.name, + }); }); return; } @@ -3253,6 +3297,7 @@ export class ProxyForwarder { attempt.thresholdTimer = null; } attempt.requestAttemptCount += 1; + armAttemptThreshold(attempt); runAttempt(attempt); return; } @@ -3448,19 +3493,7 @@ export class ProxyForwarder { }); } - if (attempt.firstByteTimeoutMs > 0) { - attempt.thresholdTimer = setTimeout(() => { - if (settled || attempt.settled || attempt.thresholdTriggered) return; - attempt.thresholdTriggered = true; - attempt.session.addProviderToChain(attempt.provider, { - ...attempt.endpointAudit, - reason: "hedge_triggered", - attemptNumber: attempt.sequence, - circuitState: getCircuitState(attempt.provider.id), - }); - void launchAlternative(); - }, attempt.firstByteTimeoutMs); - } + armAttemptThreshold(attempt); runAttempt(attempt); }; diff --git a/src/app/v1/_lib/proxy/provider-selector.ts b/src/app/v1/_lib/proxy/provider-selector.ts index 27e6a11d0..f22973980 100644 --- a/src/app/v1/_lib/proxy/provider-selector.ts +++ b/src/app/v1/_lib/proxy/provider-selector.ts @@ -48,10 +48,6 @@ async function getVerboseProviderErrorCached(): Promise { * @param groupString - 逗号分隔的分组字符串 * @returns 清理后的分组数组(去空格、去空项) */ -function parseGroupString(groupString: string): string[] { - return parseProviderGroups(groupString); -} - /** * 获取有效的供应商分组(优先级:key.providerGroup > user.providerGroup) * @@ -80,14 +76,14 @@ function getEffectiveProviderGroup(session?: ProxySession): string | null { * @returns 是否存在交集(true = 匹配) */ function checkProviderGroupMatch(providerGroupTag: string | null, userGroups: string): boolean { - const groups = parseGroupString(userGroups); + const groups = parseProviderGroups(userGroups); if (groups.includes(PROVIDER_GROUP.ALL)) { return true; } const providerTags = providerGroupTag - ? parseGroupString(providerGroupTag) + ? parseProviderGroups(providerGroupTag) : [PROVIDER_GROUP.DEFAULT]; return providerTags.some((tag) => groups.includes(tag)); @@ -1097,7 +1093,7 @@ export class ProxyProviderResolver { */ static resolveEffectivePriority(provider: Provider, userGroup: string | null): number { if (userGroup && provider.groupPriorities) { - const groups = parseGroupString(userGroup); + const groups = parseProviderGroups(userGroup); const overrides = groups .map((g) => provider.groupPriorities?.[g]) .filter((v): v is number => v !== undefined); diff --git a/src/app/v1/_lib/proxy/session.ts b/src/app/v1/_lib/proxy/session.ts index 341e6f3e9..443ace197 100644 --- a/src/app/v1/_lib/proxy/session.ts +++ b/src/app/v1/_lib/proxy/session.ts @@ -743,8 +743,10 @@ export class ProxySession { } if (!this.hasUsableBillingSettings()) { - logger.warn("[ProxySession] Billing settings unavailable, skip pricing resolution"); - return null; + logger.warn("[ProxySession] Billing settings unavailable, using fallback billing source", { + billingSettingsSource: this.billingSettingsSource, + fallbackBillingModelSource: this.cachedBillingModelSource, + }); } const providerIdentity = provider ?? this.provider; @@ -889,7 +891,9 @@ export class ProxySession { } private hasUsableBillingSettings(): boolean { - return this.billingSettingsSource !== "default"; + return ( + this.cachedBillingModelSource === "original" || this.cachedBillingModelSource === "redirected" + ); } } diff --git a/src/repository/usage-logs.ts b/src/repository/usage-logs.ts index df60a7e58..7b893129f 100644 --- a/src/repository/usage-logs.ts +++ b/src/repository/usage-logs.ts @@ -282,11 +282,11 @@ export async function findUsageLogsBatch( ledgerConditions.push(eq(usageLedger.sessionId, trimmedSessionId)); } - if (filters.startTime) { + if (filters.startTime !== undefined) { ledgerConditions.push(gte(usageLedger.createdAt, new Date(filters.startTime))); } - if (filters.endTime) { + if (filters.endTime !== undefined) { ledgerConditions.push(lt(usageLedger.createdAt, new Date(filters.endTime))); } @@ -451,6 +451,187 @@ export interface UsageLogSlimBatchResult { hasMore: boolean; } +const usageLogSlimTotalCache = new TTLMap({ ttlMs: 10_000, maxSize: 1000 }); + +export async function findUsageLogsForKeySlim( + filters: UsageLogSlimFilters & { page?: number; pageSize?: number } +): Promise<{ logs: UsageLogSlimRow[]; total: number }> { + const { keyString, page = 1, pageSize = 50 } = filters; + const safePage = page > 0 ? page : 1; + const safePageSize = Math.min(100, Math.max(1, pageSize)); + + const conditions = [ + isNull(messageRequest.deletedAt), + eq(messageRequest.key, keyString), + EXCLUDE_WARMUP_CONDITION, + ]; + const totalCacheKey = [ + keyString, + filters.sessionId?.trim() ?? "", + filters.startTime ?? "", + filters.endTime ?? "", + filters.statusCode ?? "", + filters.excludeStatusCode200 ? "1" : "0", + filters.model ?? "", + filters.endpoint ?? "", + filters.minRetryCount ?? "", + ].join("\u0001"); + + conditions.push(...buildUsageLogConditions(filters)); + + const offset = (safePage - 1) * safePageSize; + const results = await db + .select({ + id: messageRequest.id, + createdAt: messageRequest.createdAt, + model: messageRequest.model, + originalModel: messageRequest.originalModel, + endpoint: messageRequest.endpoint, + statusCode: messageRequest.statusCode, + inputTokens: messageRequest.inputTokens, + outputTokens: messageRequest.outputTokens, + costUsd: messageRequest.costUsd, + durationMs: messageRequest.durationMs, + cacheCreationInputTokens: messageRequest.cacheCreationInputTokens, + cacheReadInputTokens: messageRequest.cacheReadInputTokens, + cacheCreation5mInputTokens: messageRequest.cacheCreation5mInputTokens, + cacheCreation1hInputTokens: messageRequest.cacheCreation1hInputTokens, + cacheTtlApplied: messageRequest.cacheTtlApplied, + specialSettings: messageRequest.specialSettings, + }) + .from(messageRequest) + .where(and(...conditions)) + .orderBy(desc(messageRequest.createdAt), desc(messageRequest.id)) + .limit(safePageSize + 1) + .offset(offset); + + const hasMore = results.length > safePageSize; + const pageRows = hasMore ? results.slice(0, safePageSize) : results; + + if (pageRows.length === 0 && (await isLedgerOnlyMode())) { + if (filters.minRetryCount !== undefined && filters.minRetryCount > 0) { + return { logs: [], total: 0 }; + } + + const ledgerConditions = [LEDGER_BILLING_CONDITION, eq(usageLedger.key, keyString)]; + const trimmedSessionId = filters.sessionId?.trim(); + if (trimmedSessionId) { + ledgerConditions.push(eq(usageLedger.sessionId, trimmedSessionId)); + } + if (filters.startTime !== undefined) { + ledgerConditions.push(gte(usageLedger.createdAt, new Date(filters.startTime))); + } + if (filters.endTime !== undefined) { + ledgerConditions.push(lt(usageLedger.createdAt, new Date(filters.endTime))); + } + if (filters.statusCode !== undefined) { + ledgerConditions.push(eq(usageLedger.statusCode, filters.statusCode)); + } else if (filters.excludeStatusCode200) { + ledgerConditions.push( + sql`(${usageLedger.statusCode} IS NULL OR ${usageLedger.statusCode} <> 200)` + ); + } + if (filters.model) { + ledgerConditions.push(eq(usageLedger.model, filters.model)); + } + if (filters.endpoint) { + ledgerConditions.push(eq(usageLedger.endpoint, filters.endpoint)); + } + + const ledgerResults = await db + .select({ + id: usageLedger.requestId, + createdAt: usageLedger.createdAt, + model: usageLedger.model, + originalModel: usageLedger.originalModel, + endpoint: usageLedger.endpoint, + statusCode: usageLedger.statusCode, + inputTokens: usageLedger.inputTokens, + outputTokens: usageLedger.outputTokens, + costUsd: usageLedger.costUsd, + durationMs: usageLedger.durationMs, + cacheCreationInputTokens: usageLedger.cacheCreationInputTokens, + cacheReadInputTokens: usageLedger.cacheReadInputTokens, + cacheCreation5mInputTokens: usageLedger.cacheCreation5mInputTokens, + cacheCreation1hInputTokens: usageLedger.cacheCreation1hInputTokens, + cacheTtlApplied: usageLedger.cacheTtlApplied, + }) + .from(usageLedger) + .where(and(...ledgerConditions)) + .orderBy(desc(usageLedger.createdAt), desc(usageLedger.requestId)) + .limit(safePageSize + 1) + .offset(offset); + + const ledgerHasMore = ledgerResults.length > safePageSize; + const ledgerPageRows = ledgerHasMore ? ledgerResults.slice(0, safePageSize) : ledgerResults; + let ledgerTotal = offset + ledgerPageRows.length; + + const cachedTotal = usageLogSlimTotalCache.get(totalCacheKey); + if (cachedTotal !== undefined) { + ledgerTotal = Math.max(cachedTotal, ledgerTotal); + return { + logs: ledgerPageRows.map((row) => ({ + ...row, + costUsd: row.costUsd?.toString() ?? null, + anthropicEffort: null, + })), + total: ledgerTotal, + }; + } + + if (ledgerPageRows.length === 0 && offset > 0) { + const countResults = await db + .select({ totalRows: sql`count(*)::double precision` }) + .from(usageLedger) + .where(and(...ledgerConditions)); + ledgerTotal = countResults[0]?.totalRows ?? 0; + } else if (ledgerHasMore) { + const countResults = await db + .select({ totalRows: sql`count(*)::double precision` }) + .from(usageLedger) + .where(and(...ledgerConditions)); + ledgerTotal = countResults[0]?.totalRows ?? 0; + } + + const ledgerLogs: UsageLogSlimRow[] = ledgerPageRows.map((row) => ({ + ...row, + costUsd: row.costUsd?.toString() ?? null, + anthropicEffort: null, + })); + + usageLogSlimTotalCache.set(totalCacheKey, ledgerTotal); + return { logs: ledgerLogs, total: ledgerTotal }; + } + + let total = offset + pageRows.length; + const cachedTotal = usageLogSlimTotalCache.get(totalCacheKey); + if (cachedTotal !== undefined) { + total = Math.max(cachedTotal, total); + return { + logs: pageRows.map((row) => mapUsageLogSlimRow(row)), + total, + }; + } + + if (pageRows.length === 0 && offset > 0) { + const countResults = await db + .select({ totalRows: sql`count(*)::double precision` }) + .from(messageRequest) + .where(and(...conditions)); + total = countResults[0]?.totalRows ?? 0; + } else if (hasMore) { + const countResults = await db + .select({ totalRows: sql`count(*)::double precision` }) + .from(messageRequest) + .where(and(...conditions)); + total = countResults[0]?.totalRows ?? 0; + } + + const logs: UsageLogSlimRow[] = pageRows.map((row) => mapUsageLogSlimRow(row)); + usageLogSlimTotalCache.set(totalCacheKey, total); + return { logs, total }; +} + function buildNextCursorOrThrow( hasMore: boolean, lastRow: diff --git a/src/repository/user.ts b/src/repository/user.ts index eb4836df8..290b36b7d 100644 --- a/src/repository/user.ts +++ b/src/repository/user.ts @@ -134,7 +134,8 @@ export async function findUserList(limit: number = 50, offset: number = 0): Prom } export async function searchUsersForFilter( - searchTerm?: string + searchTerm?: string, + limit = 200 ): Promise> { const conditions = [isNull(users.deletedAt)]; @@ -152,7 +153,7 @@ export async function searchUsersForFilter( .from(users) .where(and(...conditions)) .orderBy(sql`CASE WHEN ${users.role} = 'admin' THEN 0 ELSE 1 END`, users.id) - .limit(200); + .limit(Math.max(1, Math.min(limit, 5000))); } /** Sort columns that are NOT NULL and support keyset cursor pagination */ @@ -246,7 +247,7 @@ export async function findUserListBatch( if (trimmedGroups.length > 0) { const groupConditions = trimmedGroups.map( (group) => - sql`${group} = ANY(regexp_split_to_array(coalesce(${users.providerGroup}, ''), '\\s*[,,]+\\s*'))` + sql`${group} = ANY(regexp_split_to_array(coalesce(${users.providerGroup}, ''), '\\s*[,,\n\r]+\\s*'))` ); keyGroupFilterCondition = sql`(${sql.join(groupConditions, sql` OR `)})`; } diff --git a/tests/api/api-actions-integrity.test.ts b/tests/api/api-actions-integrity.test.ts index 062b369b5..0a33d78db 100644 --- a/tests/api/api-actions-integrity.test.ts +++ b/tests/api/api-actions-integrity.test.ts @@ -149,6 +149,7 @@ describe("OpenAPI 端点完整性检查", () => { "/api/actions/my-usage/getMyUsageMetadata", "/api/actions/my-usage/getMyQuota", "/api/actions/my-usage/getMyTodayStats", + "/api/actions/my-usage/getMyUsageLogs", "/api/actions/my-usage/getMyUsageLogsBatch", "/api/actions/my-usage/getMyAvailableModels", "/api/actions/my-usage/getMyAvailableEndpoints", diff --git a/tests/api/api-openapi-spec.test.ts b/tests/api/api-openapi-spec.test.ts index c6b648879..825c7d32c 100644 --- a/tests/api/api-openapi-spec.test.ts +++ b/tests/api/api-openapi-spec.test.ts @@ -249,6 +249,7 @@ describe("OpenAPI 规范验证", () => { const limitSchema = schema?.properties?.limit; expect(pageSchema?.minimum).toBe(0); + expect(limitSchema).toBeDefined(); expect(limitSchema?.maximum).toBeUndefined(); } }); diff --git a/tests/integration/billing-model-source.test.ts b/tests/integration/billing-model-source.test.ts index 798b0229c..3b7cb4450 100644 --- a/tests/integration/billing-model-source.test.ts +++ b/tests/integration/billing-model-source.test.ts @@ -91,6 +91,7 @@ import { findLatestPriceByModel } from "@/repository/model-price"; import { getSystemSettings } from "@/repository/system-config"; beforeEach(() => { + vi.clearAllMocks(); cloudPriceSyncRequests.splice(0, cloudPriceSyncRequests.length); invalidateSystemSettingsCache(); }); diff --git a/tests/unit/actions/my-usage-date-range-dst.test.ts b/tests/unit/actions/my-usage-date-range-dst.test.ts index a47af0a71..b8390e56b 100644 --- a/tests/unit/actions/my-usage-date-range-dst.test.ts +++ b/tests/unit/actions/my-usage-date-range-dst.test.ts @@ -98,4 +98,36 @@ describe("my-usage date range parsing", () => { expect(args.endTime - args.startTime).toBe(25 * 60 * 60 * 1000); }); + + it("computes DST-safe range for legacy page-based logs API", async () => { + const tz = "America/Los_Angeles"; + mocks.resolveSystemTimezone.mockResolvedValue(tz); + + mocks.getSession.mockResolvedValue({ + key: { id: 1, key: "k" }, + user: { id: 1 }, + }); + + mocks.getSystemSettings.mockResolvedValue({ + currencyDisplay: "USD", + billingModelSource: "original", + }); + + mocks.findUsageLogsForKeySlim.mockResolvedValue({ + logs: [], + total: 0, + }); + + const { getMyUsageLogs } = await import("@/actions/my-usage"); + const res = await getMyUsageLogs({ startDate: "2024-03-10", endDate: "2024-03-10" }); + + expect(res.ok).toBe(true); + expect(mocks.findUsageLogsForKeySlim).toHaveBeenCalledTimes(1); + + const args = mocks.findUsageLogsForKeySlim.mock.calls[0]?.[0]; + expect(args.startTime).toBe(fromZonedTime("2024-03-10T00:00:00", tz).getTime()); + expect(args.endTime).toBe(fromZonedTime("2024-03-11T00:00:00", tz).getTime()); + expect(args.page).toBe(1); + expect(args.pageSize).toBe(20); + }); }); diff --git a/tests/unit/actions/usage-logs-export-retry-count.test.ts b/tests/unit/actions/usage-logs-export-retry-count.test.ts index 5f4ba0935..69dd3f06c 100644 --- a/tests/unit/actions/usage-logs-export-retry-count.test.ts +++ b/tests/unit/actions/usage-logs-export-retry-count.test.ts @@ -1,4 +1,4 @@ -import { beforeEach, describe, expect, test, vi } from "vitest"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; const getSessionMock = vi.fn(); const findUsageLogsWithDetailsMock = vi.fn(); @@ -167,6 +167,10 @@ describe("Usage logs CSV export retryCount", () => { findUsageLogsStatsMock.mockResolvedValue(createSummary()); }); + afterEach(() => { + vi.useRealTimers(); + }); + test("exportUsageLogs: Retry Count 应对齐 getRetryCount(hedge race 为 0)", async () => { findUsageLogsWithDetailsMock.mockResolvedValue({ logs: [], diff --git a/tests/unit/dashboard-logs-export-progress-ui.test.tsx b/tests/unit/dashboard-logs-export-progress-ui.test.tsx index f96d10aad..016071b68 100644 --- a/tests/unit/dashboard-logs-export-progress-ui.test.tsx +++ b/tests/unit/dashboard-logs-export-progress-ui.test.tsx @@ -6,10 +6,14 @@ import type { ReactNode } from "react"; import { act } from "react"; import { createRoot } from "react-dom/client"; import { NextIntlClientProvider } from "next-intl"; -import { beforeEach, describe, expect, test, vi } from "vitest"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; import { UsageLogsFilters } from "@/app/[locale]/dashboard/logs/_components/usage-logs-filters"; import dashboardMessages from "../../messages/en/dashboard.json"; +const originalCreateObjectURL = globalThis.URL.createObjectURL; +const originalRevokeObjectURL = globalThis.URL.revokeObjectURL; +const originalAnchorClick = HTMLAnchorElement.prototype.click; + const { downloadUsageLogsExportMock, getUsageLogsExportStatusMock, @@ -126,6 +130,15 @@ describe("UsageLogsFilters export progress UI", () => { HTMLAnchorElement.prototype.click = vi.fn(); }); + afterEach(() => { + vi.useRealTimers(); + vi.restoreAllMocks(); + globalThis.URL.createObjectURL = originalCreateObjectURL; + globalThis.URL.revokeObjectURL = originalRevokeObjectURL; + HTMLAnchorElement.prototype.click = originalAnchorClick; + document.body.innerHTML = ""; + }); + test("shows export progress while polling and downloads when completed", async () => { startUsageLogsExportMock.mockResolvedValue({ ok: true, data: { jobId: "job-1" } }); getUsageLogsExportStatusMock @@ -218,7 +231,7 @@ describe("UsageLogsFilters export progress UI", () => { await actClick(exportButton ?? null); await flushPromises(); - expect(startUsageLogsExportMock).toHaveBeenCalledWith({ sessionId: "applied-session" }); + expect(startUsageLogsExportMock).toHaveBeenCalledWith({ sessionId: "draft-session" }); unmount(); }); diff --git a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts index b4fd5012d..cb3117efc 100644 --- a/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts +++ b/tests/unit/proxy/proxy-forwarder-hedge-first-byte.test.ts @@ -875,7 +875,9 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { withThinkingBlocks(session); mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); - mocks.categorizeErrorAsync.mockResolvedValue(ProxyErrorCategory.NON_RETRYABLE_CLIENT_ERROR); + mocks.categorizeErrorAsync.mockResolvedValueOnce( + ProxyErrorCategory.NON_RETRYABLE_CLIENT_ERROR + ); const signatureError = new UpstreamProxyError( "Invalid `signature` in `thinking` block", @@ -988,7 +990,9 @@ describe("ProxyForwarder - first-byte hedge scheduling", () => { }; mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2); - mocks.categorizeErrorAsync.mockResolvedValue(ProxyErrorCategory.NON_RETRYABLE_CLIENT_ERROR); + mocks.categorizeErrorAsync.mockResolvedValueOnce( + ProxyErrorCategory.NON_RETRYABLE_CLIENT_ERROR + ); const budgetError = new UpstreamProxyError( "thinking.enabled.budget_tokens: Input should be greater than or equal to 1024", diff --git a/tests/unit/proxy/session.test.ts b/tests/unit/proxy/session.test.ts index 3f3507405..804ecac02 100644 --- a/tests/unit/proxy/session.test.ts +++ b/tests/unit/proxy/session.test.ts @@ -267,7 +267,7 @@ describe("ProxySession.getCachedPriceDataByBillingSource", () => { expect(findLatestPriceByModel).toHaveBeenNthCalledWith(2, "redirected-model"); }); - it("应在 getSystemSettings 失败且无缓存时跳过价格解析", async () => { + it("应在 getSystemSettings 失败且无缓存时回退到 redirected 并继续价格解析", async () => { const redirectedPriceData: ModelPriceData = { input_cost_per_token: 3, output_cost_per_token: 4, @@ -284,9 +284,10 @@ describe("ProxySession.getCachedPriceDataByBillingSource", () => { }); const result = await session.getCachedPriceDataByBillingSource(); - expect(result).toBeNull(); + expect(result).toEqual(redirectedPriceData); expect(getSystemSettings).toHaveBeenCalledTimes(1); - expect(findLatestPriceByModel).not.toHaveBeenCalled(); + expect(findLatestPriceByModel).toHaveBeenCalledTimes(1); + expect(findLatestPriceByModel).toHaveBeenCalledWith("redirected-model"); const internal = session as unknown as { cachedBillingModelSource?: unknown }; expect(internal.cachedBillingModelSource).toBe("redirected"); diff --git a/tests/unit/repository/usage-logs-slim-pagination.test.ts b/tests/unit/repository/usage-logs-slim-pagination.test.ts new file mode 100644 index 000000000..ad7687904 --- /dev/null +++ b/tests/unit/repository/usage-logs-slim-pagination.test.ts @@ -0,0 +1,125 @@ +import { describe, expect, test, vi } from "vitest"; + +function createThenableQuery(result: T) { + const query: any = Promise.resolve(result); + query.from = vi.fn(() => query); + query.where = vi.fn(() => query); + query.orderBy = vi.fn(() => query); + query.limit = vi.fn(() => query); + query.offset = vi.fn(() => query); + return query; +} + +describe("findUsageLogsForKeySlim", () => { + test("clamps page/pageSize and returns fast total when current page is complete", async () => { + vi.resetModules(); + + const rows = [ + { + id: 1, + createdAt: new Date("2026-03-21T00:00:00Z"), + model: "m", + originalModel: "m", + endpoint: "/v1/messages", + statusCode: 200, + inputTokens: 1, + outputTokens: 2, + costUsd: "0.01", + durationMs: 10, + cacheCreationInputTokens: 0, + cacheReadInputTokens: 0, + cacheCreation5mInputTokens: 0, + cacheCreation1hInputTokens: 0, + cacheTtlApplied: null, + specialSettings: null, + }, + ]; + const logsQuery = createThenableQuery(rows); + const selectMock = vi + .fn() + .mockImplementationOnce(() => logsQuery) + .mockImplementationOnce(() => Promise.resolve([{ totalRows: 321 }])); + + vi.doMock("@/drizzle/db", () => ({ + db: { + select: selectMock, + }, + })); + vi.doMock("@/lib/ledger-fallback", () => ({ + isLedgerOnlyMode: vi.fn(async () => false), + })); + + const { findUsageLogsForKeySlim } = await import("@/repository/usage-logs"); + const result = await findUsageLogsForKeySlim({ keyString: "k", page: 0, pageSize: 999 }); + + expect(logsQuery.limit).toHaveBeenCalledWith(101); + expect(logsQuery.offset).toHaveBeenCalledWith(0); + expect(result.total).toBe(1); + expect(result.logs).toHaveLength(1); + }); + + test("runs count query when hasMore is true so total remains accurate", async () => { + vi.resetModules(); + + const logsQuery = createThenableQuery([ + { + id: 1, + createdAt: new Date("2026-03-21T00:00:00Z"), + model: "m", + originalModel: "m", + endpoint: "/v1/messages", + statusCode: 200, + inputTokens: 1, + outputTokens: 2, + costUsd: "0.01", + durationMs: 10, + cacheCreationInputTokens: 0, + cacheReadInputTokens: 0, + cacheCreation5mInputTokens: 0, + cacheCreation1hInputTokens: 0, + cacheTtlApplied: null, + specialSettings: null, + }, + { + id: 2, + createdAt: new Date("2026-03-20T00:00:00Z"), + model: "m", + originalModel: "m", + endpoint: "/v1/messages", + statusCode: 200, + inputTokens: 1, + outputTokens: 2, + costUsd: "0.01", + durationMs: 10, + cacheCreationInputTokens: 0, + cacheReadInputTokens: 0, + cacheCreation5mInputTokens: 0, + cacheCreation1hInputTokens: 0, + cacheTtlApplied: null, + specialSettings: null, + }, + ]); + const countQuery = createThenableQuery([{ totalRows: 321 }]); + const selectMock = vi + .fn() + .mockImplementationOnce(() => logsQuery) + .mockImplementationOnce(() => countQuery); + + vi.doMock("@/drizzle/db", () => ({ + db: { + select: selectMock, + }, + })); + vi.doMock("@/lib/ledger-fallback", () => ({ + isLedgerOnlyMode: vi.fn(async () => false), + })); + + const { findUsageLogsForKeySlim } = await import("@/repository/usage-logs"); + const result = await findUsageLogsForKeySlim({ keyString: "k", page: 1, pageSize: 1 }); + + expect(logsQuery.limit).toHaveBeenCalledWith(2); + expect(logsQuery.offset).toHaveBeenCalledWith(0); + expect(result.total).toBe(321); + expect(result.logs).toHaveLength(1); + }); +}); diff --git a/tests/unit/users-action-search-users-for-filter.test.ts b/tests/unit/users-action-search-users-for-filter.test.ts index 5fd3d00b2..6b059ec12 100644 --- a/tests/unit/users-action-search-users-for-filter.test.ts +++ b/tests/unit/users-action-search-users-for-filter.test.ts @@ -65,7 +65,7 @@ describe("searchUsersForFilter (action)", () => { const result = await searchUsersForFilter("ali"); - expect(searchUsersForFilterRepositoryMock).toHaveBeenCalledWith("ali"); + expect(searchUsersForFilterRepositoryMock).toHaveBeenCalledWith("ali", undefined); expect(result).toEqual({ ok: true, data: [{ id: 1, name: "Alice" }] }); }); @@ -77,7 +77,7 @@ describe("searchUsersForFilter (action)", () => { const result = await searchUsers("bob"); - expect(searchUsersForFilterRepositoryMock).toHaveBeenCalledWith("bob"); + expect(searchUsersForFilterRepositoryMock).toHaveBeenCalledWith("bob", undefined); expect(result).toEqual({ ok: true, data: [{ id: 9, name: "Bob" }] }); }); });