Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/core/request/error-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,25 @@ export class ErrorHandler {
let isPermanent = false
try {
const errorBody = await response.text()
const errorData = JSON.parse(errorBody)
if (errorData.reason === 'INVALID_MODEL_ID') {
throw new Error(`Invalid model: ${errorData.message}`)
}
if (errorData.reason === 'TEMPORARILY_SUSPENDED') {
errorReason = 'Account Suspended'
isPermanent = true
try {
const errorData = JSON.parse(errorBody)
if (errorData.reason === 'INVALID_MODEL_ID') {
throw new Error(`Invalid model: ${errorData.message}`)
}
if (errorData.reason === 'TEMPORARILY_SUSPENDED') {
errorReason = 'Account Suspended'
isPermanent = true
} else if (errorData.reason || errorData.message) {
const detail = errorData.reason
? `${errorData.reason}${errorData.message ? `: ${errorData.message}` : ''}`
: errorData.message
errorReason = `${errorReason} (${detail})`
}
} catch (parseError) {
if (errorBody) {
const trimmed = errorBody.replace(/\s+/g, ' ').trim().slice(0, 160)
if (trimmed) errorReason = `${errorReason} (${trimmed})`
}
}
} catch (e) {
if (e instanceof Error && e.message.includes('Invalid model')) {
Expand Down
29 changes: 23 additions & 6 deletions src/core/request/request-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ export class RequestHandler {

try {
const res = await fetch(prep.url, prep.init)

if (apiTimestamp) {
this.logResponse(res, prep, apiTimestamp)
}
Expand Down Expand Up @@ -135,7 +134,7 @@ export class RequestHandler {
continue
}

this.logError(prep, res, acc, apiTimestamp)
await this.logError(prep, res, acc, apiTimestamp)
throw new Error(`Kiro Error: ${res.status}`)
} catch (e) {
const networkResult = await this.errorHandler.handleNetworkError(
Expand Down Expand Up @@ -189,11 +188,12 @@ export class RequestHandler {
try {
b = prep.init.body ? JSON.parse(prep.init.body as string) : null
} catch {}
const headers = this.redactHeaders(prep.init.headers)
logger.logApiRequest(
{
url: prep.url,
method: prep.init.method,
headers: prep.init.headers,
headers,
body: b,
conversationId: prep.conversationId,
model: prep.effectiveModel,
Expand All @@ -220,20 +220,28 @@ export class RequestHandler {
)
}

private logError(
private async logError(
prep: PreparedRequest,
res: Response,
acc: ManagedAccount,
apiTimestamp: string | null
): void {
): Promise<void> {
const h: any = {}
res.headers.forEach((v, k) => {
h[k] = v
})
let errorBody: string | undefined
try {
errorBody = await res.text()
if (errorBody) {
errorBody = errorBody.replace(/\s+/g, ' ').trim().slice(0, 1000)
}
} catch {}
const rData = {
status: res.status,
statusText: res.statusText,
headers: h,
body: errorBody,
error: `Kiro Error: ${res.status}`,
conversationId: prep.conversationId,
model: prep.effectiveModel
Expand All @@ -242,12 +250,13 @@ export class RequestHandler {
try {
lastB = prep.init.body ? JSON.parse(prep.init.body as string) : null
} catch {}
const headers = this.redactHeaders(prep.init.headers)
if (!this.config.enable_log_api_request) {
logger.logApiError(
{
url: prep.url,
method: prep.init.method,
headers: prep.init.headers,
headers,
body: lastB,
conversationId: prep.conversationId,
model: prep.effectiveModel,
Expand All @@ -259,6 +268,14 @@ export class RequestHandler {
}
}

private redactHeaders(headers: any): any {
if (!headers || typeof headers !== 'object') return headers
const clone = { ...headers }
if ('Authorization' in clone) clone.Authorization = 'REDACTED'
if ('authorization' in clone) clone.authorization = 'REDACTED'
return clone
}

private allAccountsPermanentlyUnhealthy(): boolean {
const accounts = this.accountManager.getAccounts()
if (accounts.length === 0) {
Expand Down
27 changes: 25 additions & 2 deletions src/plugin/storage/locked-operations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,25 @@ export function mergeAccounts(
const existingAcc = accountMap.get(acc.id)

if (existingAcc) {
const refreshChanged =
typeof acc.refreshToken === 'string' && acc.refreshToken !== existingAcc.refreshToken
const accessChanged =
typeof acc.accessToken === 'string' && acc.accessToken !== existingAcc.accessToken
const clientIdChanged =
typeof acc.clientId === 'string' && acc.clientId !== existingAcc.clientId
const clientSecretChanged =
typeof acc.clientSecret === 'string' && acc.clientSecret !== existingAcc.clientSecret
const incomingIsFresh = (acc.lastSync || 0) >= (existingAcc.lastSync || 0)
const allowRecovery =
refreshChanged ||
accessChanged ||
clientIdChanged ||
clientSecretChanged ||
(acc.isHealthy && incomingIsFresh)

const hasPermanentError =
isPermanentError(existingAcc.unhealthyReason) || isPermanentError(acc.unhealthyReason)
!allowRecovery &&
(isPermanentError(existingAcc.unhealthyReason) || isPermanentError(acc.unhealthyReason))

accountMap.set(acc.id, {
...existingAcc,
Expand All @@ -77,7 +94,13 @@ export function mergeAccounts(
acc.rateLimitResetTime || 0
),
isHealthy: hasPermanentError ? false : existingAcc.isHealthy || acc.isHealthy,
failCount: Math.max(existingAcc.failCount || 0, acc.failCount || 0),
unhealthyReason: hasPermanentError
? existingAcc.unhealthyReason || acc.unhealthyReason
: acc.unhealthyReason,
recoveryTime: hasPermanentError ? existingAcc.recoveryTime : acc.recoveryTime,
failCount: hasPermanentError
? Math.max(existingAcc.failCount || 0, acc.failCount || 0)
: acc.failCount || 0,
lastSync: Math.max(existingAcc.lastSync || 0, acc.lastSync || 0)
})
} else {
Expand Down
27 changes: 19 additions & 8 deletions src/plugin/storage/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,25 @@ export class KiroDatabase {
is_healthy, unhealthy_reason, recovery_time, fail_count, last_used,
used_count, limit_count, last_sync
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(refresh_token) DO UPDATE SET
id=excluded.id, email=excluded.email, auth_method=excluded.auth_method,
region=excluded.region, client_id=excluded.client_id, client_secret=excluded.client_secret,
profile_arn=excluded.profile_arn, access_token=excluded.access_token, expires_at=excluded.expires_at,
rate_limit_reset=excluded.rate_limit_reset, is_healthy=excluded.is_healthy,
unhealthy_reason=excluded.unhealthy_reason, recovery_time=excluded.recovery_time,
fail_count=excluded.fail_count, last_used=excluded.last_used,
used_count=excluded.used_count, limit_count=excluded.limit_count, last_sync=excluded.last_sync
ON CONFLICT(id) DO UPDATE SET
email=excluded.email,
auth_method=excluded.auth_method,
region=excluded.region,
client_id=excluded.client_id,
client_secret=excluded.client_secret,
profile_arn=excluded.profile_arn,
refresh_token=excluded.refresh_token,
access_token=excluded.access_token,
expires_at=excluded.expires_at,
rate_limit_reset=excluded.rate_limit_reset,
is_healthy=excluded.is_healthy,
unhealthy_reason=excluded.unhealthy_reason,
recovery_time=excluded.recovery_time,
fail_count=excluded.fail_count,
last_used=excluded.last_used,
used_count=excluded.used_count,
limit_count=excluded.limit_count,
last_sync=excluded.last_sync
`
)
.run(
Expand Down
13 changes: 13 additions & 0 deletions src/plugin/sync/idc-region.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
let idcRegion: string | undefined

export function setIdcRegionFromState(region: string | undefined): void {
if (typeof region === 'string' && region.trim()) {
idcRegion = region.trim()
return
}
idcRegion = undefined
}

export function getIdcRegionFromState(): string | undefined {
return idcRegion
}
102 changes: 93 additions & 9 deletions src/plugin/sync/kiro-cli.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { Database } from 'bun:sqlite'
import { existsSync } from 'node:fs'
import { normalizeRegion } from '../../constants.js'
import { createDeterministicAccountId } from '../accounts'
import * as logger from '../logger'
import { kiroDb } from '../storage/sqlite'
import { fetchUsageLimits } from '../usage'
import { setIdcRegionFromState } from './idc-region'
import {
findClientCredsRecursive,
getCliDbPath,
Expand All @@ -12,6 +14,19 @@ import {
safeJsonParse
} from './kiro-cli-parser'

function extractProfileArnFromAccessToken(accessToken: string | undefined): string | undefined {
if (!accessToken || !accessToken.includes('.')) return undefined
const parts = accessToken.split('.')
if (parts.length < 2 || !parts[1]) return undefined
try {
const payload = Buffer.from(parts[1], 'base64').toString('utf8')
const data = JSON.parse(payload)
return data.profileArn || data.profile_arn || data['profile_arn'] || undefined
} catch {
return undefined
}
}

export async function syncFromKiroCli() {
const dbPath = getCliDbPath()
if (!existsSync(dbPath)) return
Expand All @@ -20,26 +35,81 @@ export async function syncFromKiroCli() {
cliDb.run('PRAGMA busy_timeout = 5000')
const rows = cliDb.prepare('SELECT key, value FROM auth_kv').all() as any[]

const deviceRegRow = rows.find(
let profileArnFromState: string | undefined
try {
const idcRegionRow = cliDb
.prepare('SELECT value FROM state WHERE key = ?')
.get('auth.idc.region') as { value?: string } | undefined
const parsedRegion = safeJsonParse(idcRegionRow?.value)
if (typeof parsedRegion === 'string') {
setIdcRegionFromState(parsedRegion)
}
const profileRow = cliDb
.prepare('SELECT value FROM state WHERE key = ?')
.get('api.codewhisperer.profile') as { value?: string } | undefined
const profile = safeJsonParse(profileRow?.value)
if (profile && typeof profile.arn === 'string') {
profileArnFromState = profile.arn
}
} catch {
setIdcRegionFromState(undefined)
}

const tokenRows = rows.filter((r) => typeof r?.key === 'string' && r.key.includes(':token'))
const parsedTokens = tokenRows
.map((row) => {
const data = safeJsonParse(row.value)
const expiresAt = normalizeExpiresAt(data?.expires_at ?? data?.expiresAt)
return { row, data, expiresAt }
})
.filter((t) => t.data)

const now = Date.now()
const validTokens = parsedTokens.filter((t) => t.expiresAt > now)
const candidates = validTokens.length ? validTokens : parsedTokens

let tokenRowsToImport = tokenRows
if (candidates.length > 0) {
const maxExpiresAt = Math.max(...candidates.map((t) => t.expiresAt || 0))
tokenRowsToImport = candidates.filter((t) => t.expiresAt === maxExpiresAt).map((t) => t.row)
}

const deviceRegRows = rows.filter(
(r) => typeof r?.key === 'string' && r.key.includes('device-registration')
)
const deviceReg = safeJsonParse(deviceRegRow?.value)
const regCreds = deviceReg ? findClientCredsRecursive(deviceReg) : {}
const deviceRegByKey = new Map<string, { clientId?: string; clientSecret?: string }>()
for (const row of deviceRegRows) {
const deviceReg = safeJsonParse(row.value)
const regCreds = deviceReg ? findClientCredsRecursive(deviceReg) : {}
if (regCreds.clientId && regCreds.clientSecret) {
const baseKey = row.key.replace(':device-registration', '')
deviceRegByKey.set(baseKey, regCreds)
}
}

const importedIds = new Set<string>()

for (const row of rows) {
for (const row of tokenRowsToImport) {
if (row.key.includes(':token')) {
const data = safeJsonParse(row.value)
if (!data) continue

const isIdc = row.key.includes('odic')
const isIdc = row.key.includes('odic') || row.key.includes('oidc')
const authMethod = isIdc ? 'idc' : 'desktop'
const region = data.region || 'us-east-1'
const profileArn = data.profile_arn || data.profileArn

const accessToken = data.access_token || data.accessToken || ''
const profileArn = data.profile_arn || data.profileArn || profileArnFromState
const regionFromProfile = profileArn?.split(':')[3]
const region = normalizeRegion(regionFromProfile || data.region)
const refreshToken = data.refresh_token || data.refreshToken
if (!refreshToken) continue

const baseKey = row.key.replace(':token', '')
const regCreds =
deviceRegByKey.get(baseKey) ||
deviceRegByKey.get(baseKey.replace('kirocli', 'codewhisperer')) ||
deviceRegByKey.get(baseKey.replace('codewhisperer', 'kirocli')) ||
{}

const clientId = data.client_id || data.clientId || (isIdc ? regCreds.clientId : undefined)
const clientSecret =
data.client_secret || data.clientSecret || (isIdc ? regCreds.clientSecret : undefined)
Expand Down Expand Up @@ -108,7 +178,8 @@ export async function syncFromKiroCli() {
if (
existingById &&
existingById.is_healthy === 1 &&
existingById.expires_at >= cliExpiresAt
existingById.expires_at >= cliExpiresAt &&
existingById.region === region
)
continue

Expand Down Expand Up @@ -165,6 +236,19 @@ export async function syncFromKiroCli() {
limitCount,
lastSync: Date.now()
})
importedIds.add(id)
}
}

const existing = kiroDb.getAccounts()
for (const acc of existing) {
if (
typeof acc?.email === 'string' &&
acc.email.endsWith('@awsapps.local') &&
acc.auth_method === 'idc' &&
!importedIds.has(acc.id)
) {
await kiroDb.deleteAccount(acc.id)
}
}
cliDb.close()
Expand Down
Loading