diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go
index f3eb5d038..e35a342c6 100644
--- a/pkg/model/provider/openai/client.go
+++ b/pkg/model/provider/openai/client.go
@@ -26,6 +26,7 @@ import (
 	"github.com/docker/docker-agent/pkg/model/provider/oaistream"
 	"github.com/docker/docker-agent/pkg/model/provider/options"
 	"github.com/docker/docker-agent/pkg/modelinfo"
+	"github.com/docker/docker-agent/pkg/modelsdev"
 	"github.com/docker/docker-agent/pkg/rag/prompts"
 	"github.com/docker/docker-agent/pkg/rag/types"
 	"github.com/docker/docker-agent/pkg/tools"
@@ -178,10 +179,20 @@ func (c *Client) Close() {
 	}
 }
 
-// convertMessages converts chat.Message to openai.ChatCompletionMessageParamUnion
-// using the shared oaistream implementation.
+// convertMessages converts chat.Message to OpenAI chat-completions message params.
+// Custom OpenAI-compatible providers may target local model runners that reject
+// consecutive system or user messages, so we normalize those prompts to match
+// the DMR provider behavior.
+func convertMessages(ctx context.Context, cfg *latest.ModelConfig, id modelsdev.ID, store *modelsdev.Store, messages []chat.Message) []openai.ChatCompletionMessageParamUnion {
+	openaiMessages := oaistream.ConvertMessages(ctx, messages, id, store)
+	if isCustomProvider(cfg) {
+		return oaistream.MergeConsecutiveMessages(openaiMessages)
+	}
+	return openaiMessages
+}
+
 func (c *Client) convertMessages(ctx context.Context, messages []chat.Message) []openai.ChatCompletionMessageParamUnion {
-	return oaistream.ConvertMessages(ctx, messages, c.ID(), c.ModelOptions.ModelsDevStore())
+	return convertMessages(ctx, &c.ModelConfig, c.ID(), c.ModelOptions.ModelsDevStore(), messages)
 }
 
 // CreateChatCompletionStream creates a streaming chat completion request
diff --git a/pkg/model/provider/openai/client_test.go b/pkg/model/provider/openai/client_test.go
index 6bc1dd49b..945306ff1 100644
--- a/pkg/model/provider/openai/client_test.go
+++ b/pkg/model/provider/openai/client_test.go
@@ -7,6 +7,8 @@ import (
 	"github.com/stretchr/testify/require"
 
 	"github.com/docker/docker-agent/pkg/chat"
+	"github.com/docker/docker-agent/pkg/config/latest"
+	"github.com/docker/docker-agent/pkg/modelsdev"
 	"github.com/docker/docker-agent/pkg/tools"
 )
 
@@ -82,7 +84,7 @@ func TestConvertMessagesToResponseInput_AssistantTextWithToolCalls(t *testing.T)
 }
 
 func TestConvertMessagesToResponseInput_NoOrphans(t *testing.T) {
-	// All tool calls have matching results — no placeholder needed.
+	// All tool calls have matching results - no placeholder needed.
 	messages := []chat.Message{
 		{Role: chat.MessageRoleUser, Content: "hello"},
 		{
@@ -104,3 +106,39 @@ func TestConvertMessagesToResponseInput_NoOrphans(t *testing.T) {
 	}
 	assert.Equal(t, 1, outputCount, "should not inject extra outputs when all calls have results")
 }
+
+// With the api_type provider option set, the three system messages must collapse into one system message followed by the user message.
+func TestConvertMessages_MergesConsecutiveSystemMessagesForCustomProviders(t *testing.T) {
+	cfg := &latest.ModelConfig{
+		ProviderOpts: map[string]any{"api_type": "openai_chatcompletions"},
+	}
+	messages := []chat.Message{
+		{Role: chat.MessageRoleSystem, Content: "You are Bob, a coding expert"},
+		{Role: chat.MessageRoleSystem, Content: "## Custom Shell Tools\n\n### execute_command"},
+		{Role: chat.MessageRoleSystem, Content: "<available_skills>\n what-time-is-it\n</available_skills>"},
+		{Role: chat.MessageRoleUser, Content: "what is your favourite colour?"},
+	}
+
+	result := convertMessages(t.Context(), cfg, modelsdev.ID{}, nil, messages)
+	require.Len(t, result, 2)
+	require.NotNil(t, result[0].OfSystem)
+	assert.Contains(t, result[0].OfSystem.Content.OfString.Value, "You are Bob, a coding expert")
+	assert.Contains(t, result[0].OfSystem.Content.OfString.Value, "Custom Shell Tools")
+	assert.Contains(t, result[0].OfSystem.Content.OfString.Value, "available_skills")
+	assert.NotNil(t, result[1].OfUser)
+}
+
+// With an empty ModelConfig, the two system messages and the user message must come back as three separate entries.
+func TestConvertMessages_PreservesConsecutiveSystemMessagesForOpenAIProvider(t *testing.T) {
+	messages := []chat.Message{
+		{Role: chat.MessageRoleSystem, Content: "System 1"},
+		{Role: chat.MessageRoleSystem, Content: "System 2"},
+		{Role: chat.MessageRoleUser, Content: "hello"},
+	}
+
+	result := convertMessages(t.Context(), &latest.ModelConfig{}, modelsdev.ID{}, nil, messages)
+	require.Len(t, result, 3)
+	assert.NotNil(t, result[0].OfSystem)
+	assert.NotNil(t, result[1].OfSystem)
+	assert.NotNil(t, result[2].OfUser)
+}