diff --git a/packages/types/src/providers/gemini.ts b/packages/types/src/providers/gemini.ts index 4734606d5d9..ac5adf46c76 100644 --- a/packages/types/src/providers/gemini.ts +++ b/packages/types/src/providers/gemini.ts @@ -283,4 +283,46 @@ export const geminiModels = { supportsReasoningBudget: true, maxThinkingTokens: 24_576, }, + // Gemma 4 models + // https://ai.google.dev/gemma/docs/core + "gemma-4-31b-it": { + maxTokens: 8_192, + contextWindow: 131_072, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemma-4-12b-it": { + maxTokens: 8_192, + contextWindow: 131_072, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemma-4-6b-it": { + maxTokens: 8_192, + contextWindow: 32_768, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemma-4-4b-it": { + maxTokens: 8_192, + contextWindow: 32_768, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, + "gemma-4-1b-it": { + maxTokens: 8_192, + contextWindow: 32_768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0, + outputPrice: 0, + }, } as const satisfies Record diff --git a/src/api/providers/__tests__/gemini.spec.ts b/src/api/providers/__tests__/gemini.spec.ts index 47ee79dd0d6..3872733a5b2 100644 --- a/src/api/providers/__tests__/gemini.spec.ts +++ b/src/api/providers/__tests__/gemini.spec.ts @@ -165,13 +165,33 @@ describe("GeminiHandler", () => { expect(modelInfo.info).toBeDefined() }) - it("should return default model if invalid model specified", () => { + it("should preserve unknown model ID instead of silently falling back to default", () => { const invalidHandler = new GeminiHandler({ apiModelId: "invalid-model", geminiApiKey: "test-key", }) const modelInfo = invalidHandler.getModel() - expect(modelInfo.id).toBe(geminiDefaultModelId) // Default model + expect(modelInfo.id).toBe("invalid-model") // Preserves user-provided ID + expect(modelInfo.info).toBeDefined() // Falls back to default model info + }) + + it("should use default model when no model ID is provided", () => { + const noModelHandler = new GeminiHandler({ + geminiApiKey: "test-key", + }) + const modelInfo = noModelHandler.getModel() + expect(modelInfo.id).toBe(geminiDefaultModelId) + }) + + it("should recognize Gemma 4 models natively", () => { + const gemmaHandler = new GeminiHandler({ + apiModelId: "gemma-4-31b-it", + geminiApiKey: "test-key", + }) + const modelInfo = gemmaHandler.getModel() + expect(modelInfo.id).toBe("gemma-4-31b-it") + expect(modelInfo.info.contextWindow).toBe(131_072) + expect(modelInfo.info.supportsImages).toBe(true) }) it("should exclude apply_diff and include edit in tool preferences", () => { diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index a49073ea334..0408695bece 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -348,8 +348,8 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl override getModel() { const modelId = this.options.apiModelId - let id = modelId && modelId in geminiModels ? (modelId as GeminiModelId) : geminiDefaultModelId - let info: ModelInfo = geminiModels[id] + const id: string = modelId ?? geminiDefaultModelId + let info: ModelInfo = geminiModels[id as GeminiModelId] ?? geminiModels[geminiDefaultModelId] const params = getModelParams({ format: "gemini", diff --git a/src/api/providers/vertex.ts b/src/api/providers/vertex.ts index fd318d9b19a..3480e5b426b 100644 --- a/src/api/providers/vertex.ts +++ b/src/api/providers/vertex.ts @@ -14,8 +14,8 @@ export class VertexHandler extends GeminiHandler implements SingleCompletionHand override getModel() { const modelId = this.options.apiModelId - let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId - let info: ModelInfo = vertexModels[id] + const id: string = modelId ?? vertexDefaultModelId + let info: ModelInfo = vertexModels[id as VertexModelId] ?? vertexModels[vertexDefaultModelId] const params = getModelParams({ format: "gemini", modelId: id,