diff --git a/packages/app/e2e/session/session-restart.spec.ts b/packages/app/e2e/session/session-restart.spec.ts new file mode 100644 index 00000000000..d9935f2ccc6 --- /dev/null +++ b/packages/app/e2e/session/session-restart.spec.ts @@ -0,0 +1,87 @@ +import type { Page } from "@playwright/test" +import { test, expect } from "../fixtures" +import { runPromptSlash, withSession } from "../actions" +import { createSdk } from "../utils" +import { promptSelector } from "../selectors" + +async function seedConversation(input: { + page: Page + sdk: ReturnType + sessionID: string + token: string +}) { + const messages = async () => + await input.sdk.session.messages({ sessionID: input.sessionID, limit: 100 }).then((r) => r.data ?? []) + const seeded = await messages() + const userIDs = new Set(seeded.filter((m) => m.info.role === "user").map((m) => m.info.id)) + + await input.sdk.session.promptAsync({ + sessionID: input.sessionID, + noReply: true, + parts: [{ type: "text", text: input.token }], + }) + + let userMessageID: string | undefined + await expect + .poll( + async () => { + const users = (await messages()).filter( + (m) => + !userIDs.has(m.info.id) && + m.info.role === "user" && + m.parts.filter((p) => p.type === "text").some((p) => p.text.includes(input.token)), + ) + if (users.length === 0) return false + + const user = users[users.length - 1] + if (!user) return false + userMessageID = user.info.id + return true + }, + { timeout: 90_000, intervals: [250, 500, 1_000] }, + ) + .toBe(true) + + if (!userMessageID) throw new Error("Expected a user message id") + return userMessageID +} + +test("slash restart opens a new session draft with the initial user prompt", async ({ page, withProject }) => { + test.setTimeout(120_000) + + const firstToken = `restart_first_${Date.now()}` + const secondToken = `restart_second_${Date.now()}` + + await withProject(async (project) => { + const sdk = createSdk(project.directory) + + await withSession(sdk, `e2e restart ${Date.now()}`, async (session) => { + await project.gotoSession(session.id) + + const first = await seedConversation({ + page, + sdk, + sessionID: session.id, + token: firstToken, + }) + const second = await seedConversation({ + page, + sdk, + sessionID: session.id, + token: secondToken, + }) + + expect(first).not.toBe(second) + + const prompt = page.locator(promptSelector) + await expect(prompt).toBeVisible() + + await runPromptSlash(page, { id: "session.restart", text: "/restart", prompt }) + + await expect(page).toHaveURL(new RegExp(`/${project.slug}/session(?:[?#]|$)`), { timeout: 30_000 }) + + await expect(prompt).toContainText(firstToken) + await expect(prompt).not.toContainText(secondToken) + }) + }) +}) diff --git a/packages/app/src/components/dialog-fork.tsx b/packages/app/src/components/dialog-fork.tsx index 9e1b896fa8f..ff19b41accf 100644 --- a/packages/app/src/components/dialog-fork.tsx +++ b/packages/app/src/components/dialog-fork.tsx @@ -2,15 +2,15 @@ import { Component, createMemo } from "solid-js" import { useNavigate, useParams } from "@solidjs/router" import { useSync } from "@/context/sync" import { useSDK } from "@/context/sdk" -import { usePrompt } from "@/context/prompt" +import { usePrompt, type Prompt } from "@/context/prompt" import { useDialog } from "@opencode-ai/ui/context/dialog" import { Dialog } from "@opencode-ai/ui/dialog" import { List } from "@opencode-ai/ui/list" import { showToast } from "@opencode-ai/ui/toast" -import { extractPromptFromParts } from "@/utils/prompt" import type { TextPart as SDKTextPart } from "@opencode-ai/sdk/v2/client" import { base64Encode } from "@opencode-ai/util/encode" import { useLanguage } from "@/context/language" +import { extractPromptFromParts } from "@/utils/prompt" interface ForkableMessage { id: string @@ -22,6 +22,37 @@ function formatTime(date: Date): string { return date.toLocaleTimeString(undefined, { timeStyle: "short" }) } +async function fork(opts: { + fork: (input: { sessionID: string; messageID: string }) => Promise<{ data?: { id: string } }> + sessionID: string + messageID: string + prompt: Prompt + directory: string + fail: (message?: string) => void + navigate: (href: string) => void + set: (prompt: Prompt, next: { dir: string; id: string }) => void + done?: () => void +}) { + const dir = base64Encode(opts.directory) + + await opts + .fork({ sessionID: opts.sessionID, messageID: opts.messageID }) + .then((res) => { + const id = res.data?.id + if (!id) { + opts.fail() + return + } + opts.done?.() + opts.set(opts.prompt, { dir, id }) + opts.navigate(`/${dir}/session/${id}`) + }) + .catch((err: unknown) => { + const message = err instanceof Error ? err.message : String(err) + opts.fail(message) + }) +} + export const DialogFork: Component = () => { const params = useParams() const navigate = useNavigate() @@ -60,29 +91,22 @@ export const DialogFork: Component = () => { const sessionID = params.id if (!sessionID) return - - const parts = sync.data.part[item.id] ?? [] - const restored = extractPromptFromParts(parts, { + const value = extractPromptFromParts(sync.data.part[item.id] ?? [], { directory: sdk.directory, attachmentName: language.t("common.attachment"), }) - const dir = base64Encode(sdk.directory) - sdk.client.session - .fork({ sessionID, messageID: item.id }) - .then((forked) => { - if (!forked.data) { - showToast({ title: language.t("common.requestFailed") }) - return - } - dialog.close() - prompt.set(restored, undefined, { dir, id: forked.data.id }) - navigate(`/${dir}/session/${forked.data.id}`) - }) - .catch((err: unknown) => { - const message = err instanceof Error ? err.message : String(err) - showToast({ title: language.t("common.requestFailed"), description: message }) - }) + void fork({ + fork: sdk.client.session.fork, + sessionID, + messageID: item.id, + prompt: value, + directory: sdk.directory, + fail: (message) => showToast({ title: language.t("common.requestFailed"), description: message }), + navigate, + set: (value, next) => prompt.set(value, undefined, next), + done: dialog.close, + }) } return ( diff --git a/packages/app/src/i18n/en.ts b/packages/app/src/i18n/en.ts index 72caed40ad9..a59fe27c1c8 100644 --- a/packages/app/src/i18n/en.ts +++ b/packages/app/src/i18n/en.ts @@ -81,6 +81,8 @@ export const dict = { "command.session.redo.description": "Redo the last undone message", "command.session.compact": "Compact session", "command.session.compact.description": "Summarize the session to reduce context size", + "command.session.restart": "Restart from first prompt", + "command.session.restart.description": "Fork a new session from the user's initial query", "command.session.fork": "Fork from message", "command.session.fork.description": "Create a new session from a previous message", "command.session.share": "Share session", diff --git a/packages/app/src/pages/session/use-session-commands.tsx b/packages/app/src/pages/session/use-session-commands.tsx index 1a2e777f522..bb8c619d63e 100644 --- a/packages/app/src/pages/session/use-session-commands.tsx +++ b/packages/app/src/pages/session/use-session-commands.tsx @@ -15,6 +15,7 @@ import { DialogSelectFile } from "@/components/dialog-select-file" import { DialogSelectModel } from "@/components/dialog-select-model" import { DialogSelectMcp } from "@/components/dialog-select-mcp" import { DialogFork } from "@/components/dialog-fork" +import { promptLength } from "@/components/prompt-input/history" import { showToast } from "@opencode-ai/ui/toast" import { findLast } from "@opencode-ai/util/array" import { createSessionTabs } from "@/pages/session/helpers" @@ -94,6 +95,20 @@ export const useSessionCommands = (actions: SessionCommandContext) => { layout.fileTree.setTab("all") } + const restart = async () => { + const dir = params.dir + if (!dir) return + const msg = userMessages()[0] + if (!msg) return + const value = extractPromptFromParts(sync.data.part[msg.id] ?? [], { + directory: sdk.directory, + attachmentName: language.t("common.attachment"), + }) + + prompt.set(value, promptLength(value), { dir }) + navigate(`/${dir}/session`) + } + const selectionPreview = (path: string, selection: FileSelection) => { const content = file.get(path)?.content?.content if (!content) return undefined @@ -481,6 +496,14 @@ export const useSessionCommands = (actions: SessionCommandContext) => { }) }, }), + sessionCommand({ + id: "session.restart", + title: language.t("command.session.restart"), + description: language.t("command.session.restart.description"), + slash: "restart", + disabled: !params.id || userMessages().length === 0, + onSelect: restart, + }), sessionCommand({ id: "session.fork", title: language.t("command.session.fork"),