diff --git a/MIGRATION_0.26_0.27.md b/MIGRATION_0.26_0.27.md index a4dac7b..4cef265 100644 --- a/MIGRATION_0.26_0.27.md +++ b/MIGRATION_0.26_0.27.md @@ -14,6 +14,9 @@ agent or client implementation into a connection: notification params are available as `ctx.params`. Agent handlers use `ctx.client` for outbound calls to the client. Client handlers use `ctx.agent` for outbound calls to the agent. +- Long-lived connection handles also expose the peer context. `agent.connect(...)` + returns an `AgentConnection` with `connection.client`, and `client.connect(...)` + returns a `ClientConnection` with `connection.agent`. `AgentSideConnection` and `ClientSideConnection` still exist as deprecated compatibility wrappers, but new code should use the app API. @@ -24,8 +27,8 @@ compatibility wrappers, but new code should use the app API. | -------------------------------------------------------------- | --------------------------------------------------------------------------- | | `new AgentSideConnection((conn) => new MyAgent(conn), stream)` | `acp.agent({ name }).onRequest(...).onNotification(...).connect(stream)` | | `new ClientSideConnection((_agent) => client, stream)` | `acp.client({ name }).onNotification(...).connectWith(stream, async ...)` | -| Store `AgentSideConnection` on your agent class | Use `ctx.client` in agent handlers | -| Store/use `ClientSideConnection` for outgoing agent calls | Use the `ctx` passed to `connectWith` | +| Store `AgentSideConnection` on your agent class | Use `ctx.client`, `connection.client`, or `agent.onConnect(...)` | +| Store/use `ClientSideConnection` for outgoing agent calls | Use the `ctx` passed to `connectWith`, or `connection.agent` | | Return a response from an `Agent` or `Client` method | Return a response from the app request handler | | Throw from implementation methods for JSON-RPC errors | Throw from an app handler | | Manually create session and prompt requests | Prefer `ctx.buildSession(...).withSession(...)` for common prompt workflows | @@ -34,6 +37,11 @@ Both `connect(...)` and `connectWith(...)` accept either a `Stream` or the app for the other side of the connection. Use streams for production transports and direct app connections for tests or in-process examples. +Use `connectWith(...)` for a scoped workflow where the callback owns the +connection lifetime. Use `connect(...)` when the connection should stay open +independently of one operation; the returned connection can observe closure, +close the transport, and call the peer. + ## Migrating an Agent Previously, an agent usually implemented `acp.Agent`, stored the @@ -135,6 +143,51 @@ acp .connect(stream); ``` +If your agent keeps connection-scoped state or sends notifications from +background work, use the connection handle returned by `connect(...)`: + +```ts +class MyAgent { + private client?: acp.AgentContext; + + bindClient(client?: acp.AgentContext): void { + this.client = client; + } + + async sendBackgroundUpdate(sessionId: acp.SessionId): Promise { + await this.client?.notify(acp.methods.client.session.update, { + sessionId, + update: { + sessionUpdate: "agent_message_chunk", + content: { type: "text", text: "Still working..." }, + }, + }); + } +} + +const implementation = new MyAgent(); +const app = acp.agent({ name: "my-agent" }); + +const connection = app.connect(stream); +implementation.bindClient(connection.client); +``` + +For server-owned connections, such as an app passed to `AcpServer`, register +`onConnect(...)` instead. The hook receives the same connection-scoped client +context. `AcpServer` runs the hook after the client has completed `initialize`, +so messages sent by the hook cannot replace the initialize response: + +```ts +const implementation = new MyAgent(); + +const app = acp.agent({ name: "my-agent" }).onConnect((connection) => { + implementation.bindClient(connection.client); + connection.signal.addEventListener("abort", () => { + implementation.bindClient(undefined); + }); +}); +``` + For JSON-RPC errors, throw from the handler: ```ts @@ -232,7 +285,26 @@ const prompt = await acp `connectWith` owns the connection lifetime for the callback. When the callback finishes or throws, the connection is closed. If you need the connection to stay open independently of one operation, call `connect(stream)` and keep the -returned `AcpConnection`. +returned `ClientConnection`: + +```ts +const connection = acp + .client({ name: "my-client" }) + .onNotification(acp.methods.client.session.update, (ctx) => + client.sessionUpdate(ctx.params), + ) + .connect(stream); + +try { + await connection.agent.request(acp.methods.agent.initialize, { + protocolVersion: acp.PROTOCOL_VERSION, + clientCapabilities: {}, + }); +} finally { + connection.close(); + await connection.closed; +} +``` All protocol paths should be absolute. That includes `cwd`, `additionalDirectories`, file-system request paths, terminal/tool-call @@ -298,6 +370,19 @@ acp.client().onRequest(acp.methods.client.session.requestPermission, (ctx) => { Agent handler contexts include `params` and `client`. Client handler contexts include `params` and `agent`. +Connection handles expose those same peer contexts for connection-scoped work: + +```ts +const agentConnection = acp.agent({ name: "my-agent" }).connect(stream); +await agentConnection.client.notify(acp.methods.client.session.update, update); + +const clientConnection = acp.client({ name: "my-client" }).connect(stream); +await clientConnection.agent.request(acp.methods.agent.session.new, { + cwd: "/workspace/project", + mcpServers: [], +}); +``` + The `connectWith` callback receives a `ClientContext`, usually named `ctx`, with `request(...)` and `notify(...)` for talking to the agent: diff --git a/src/acp.test.ts b/src/acp.test.ts index 1eca7f4..cfeca98 100644 --- a/src/acp.test.ts +++ b/src/acp.test.ts @@ -824,6 +824,150 @@ describe("Connection", () => { ]); }); + it("returns peer contexts from app connection handles", async () => { + const events: string[] = []; + + const appAgent = createAgent({ name: "peer-handle-agent" }) + .onRequest(AGENT_METHODS.initialize, (c) => { + events.push(`initialize:${c.params.protocolVersion}`); + return { + protocolVersion: c.params.protocolVersion, + agentCapabilities: { loadSession: false }, + authMethods: [], + }; + }) + .onNotification( + "vendor/agent/notify", + (params) => params as { message: string }, + (c) => { + events.push(`agent-notify:${c.params.message}`); + }, + ); + + const appClient = createClient({ name: "peer-handle-client" }) + .onRequest(CLIENT_METHODS.fs_read_text_file, (c) => { + events.push(`read:${c.params.path}`); + return { content: "client file" }; + }) + .onNotification(CLIENT_METHODS.session_update, (c) => { + events.push(`update:${c.params.sessionId}`); + }); + + const agentConnection = appAgent.connect(appClient); + try { + const readResponse = await agentConnection.client.request( + CLIENT_METHODS.fs_read_text_file, + { + sessionId: "peer-session", + path: "/peer/file.txt", + }, + ); + await agentConnection.client.notify(CLIENT_METHODS.session_update, { + sessionId: "peer-session", + update: { + sessionUpdate: "agent_message_chunk", + content: { type: "text", text: "from connection" }, + }, + }); + + expect(readResponse.content).toBe("client file"); + await vi.waitFor(() => { + expect(events).toContain("read:/peer/file.txt"); + expect(events).toContain("update:peer-session"); + }); + } finally { + agentConnection.close(); + await agentConnection.closed; + } + + const clientConnection = appClient.connect(appAgent); + try { + const initializeResponse = await clientConnection.agent.request( + AGENT_METHODS.initialize, + { + protocolVersion: PROTOCOL_VERSION, + clientCapabilities: {}, + }, + ); + await clientConnection.agent.notify("vendor/agent/notify", { + message: "from-client-connection", + }); + + expect(initializeResponse.protocolVersion).toBe(PROTOCOL_VERSION); + await vi.waitFor(() => { + expect(events).toContain(`initialize:${PROTOCOL_VERSION}`); + expect(events).toContain("agent-notify:from-client-connection"); + }); + } finally { + clientConnection.close(); + await clientConnection.closed; + } + }); + + it("runs app connection hooks with peer-callable handles", async () => { + const events: string[] = []; + let agentHookConnection: unknown; + let clientHookConnection: unknown; + + const appAgent = createAgent({ name: "hook-agent" }) + .onConnect(async (connection) => { + agentHookConnection = connection; + events.push("agent-connect"); + connection.signal.addEventListener("abort", () => { + events.push("agent-close"); + }); + await connection.client.notify(CLIENT_METHODS.session_update, { + sessionId: "hook-session", + update: { + sessionUpdate: "agent_message_chunk", + content: { type: "text", text: "from agent hook" }, + }, + }); + }) + .onNotification( + "vendor/agent/notify", + (params) => params as { message: string }, + (c) => { + events.push(`agent-notify:${c.params.message}`); + }, + ); + + const appClient = createClient({ name: "hook-client" }) + .onConnect(async (connection) => { + clientHookConnection = connection; + events.push("client-connect"); + connection.signal.addEventListener("abort", () => { + events.push("client-close"); + }); + await connection.agent.notify("vendor/agent/notify", { + message: "from-client-hook", + }); + }) + .onNotification(CLIENT_METHODS.session_update, (c) => { + events.push(`update:${c.params.sessionId}`); + }); + + const connection = appAgent.connect(appClient); + try { + expect(agentHookConnection).toBe(connection); + await vi.waitFor(() => { + expect(clientHookConnection).toBeDefined(); + expect(events).toContain("agent-connect"); + expect(events).toContain("client-connect"); + expect(events).toContain("update:hook-session"); + expect(events).toContain("agent-notify:from-client-hook"); + }); + } finally { + connection.close(); + await connection.closed; + } + + await vi.waitFor(() => { + expect(events).toContain("agent-close"); + expect(events).toContain("client-close"); + }); + }); + it("normalizes app built-in empty-object handler responses before sending", async () => { const appAgent = createAgent({ name: "empty-agent-responses" }) .onRequest(AGENT_METHODS.session_load, () => {}) diff --git a/src/acp.ts b/src/acp.ts index fbf4198..25e188a 100644 --- a/src/acp.ts +++ b/src/acp.ts @@ -58,18 +58,6 @@ function memoryStreamPair(): [Stream, Stream] { ]; } -function connectInProcess( - connectThis: (stream: Stream) => Connection, - connectPeer: (stream: Stream) => Connection, -): Connection { - const [thisStream, peerStream] = memoryStreamPair(); - const peerConnection = connectPeer(peerStream); - const connection = connectThis(thisStream); - void connection.closed.then(() => peerConnection.close()); - void peerConnection.closed.then(() => connection.close()); - return connection; -} - /** * ACP method-name constants. * @@ -164,6 +152,32 @@ export interface AcpConnection { close(error?: unknown): void; } +/** + * Agent-side connection returned by `AgentApp.connect(...)`. + * + * Use `client` to call client-side ACP methods for the lifetime of the + * connection. + */ +export interface AgentConnection extends AcpConnection { + /** + * Context for calling client-side ACP methods. + */ + readonly client: AgentContext; +} + +/** + * Client-side connection returned by `ClientApp.connect(...)`. + * + * Use `agent` to call agent-side ACP methods and session helpers for the + * lifetime of the connection. + */ +export interface ClientConnection extends AcpConnection { + /** + * Context for calling agent-side ACP methods. + */ + readonly agent: ClientContext; +} + class AcpContext { /** @internal */ constructor(private readonly cx: ConnectionContext) {} @@ -361,6 +375,88 @@ export class ClientContext extends AcpContext { } } +class AcpConnectionHandle implements AcpConnection { + constructor(private readonly connection: Connection) {} + + get signal(): AbortSignal { + return this.connection.signal; + } + + get closed(): Promise { + return this.connection.closed; + } + + close(error?: unknown): void { + this.connection.close(error); + } +} + +class AgentConnectionHandle + extends AcpConnectionHandle + implements AgentConnection +{ + readonly client: AgentContext; + private didStartConnectHandlers = false; + + constructor( + connection: Connection, + private readonly connectHandlers: readonly AgentConnectHandler[] = [], + ) { + super(connection); + this.client = AgentContext.create(connection.getContext()); + } + + /** @internal */ + startConnectHandlers(): void { + if (this.didStartConnectHandlers) { + return; + } + + this.didStartConnectHandlers = true; + runConnectHandlers(this, this.connectHandlers); + } +} + +class ClientConnectionHandle + extends AcpConnectionHandle + implements ClientConnection +{ + readonly agent: ClientContext; + private didStartConnectHandlers = false; + + constructor( + connection: Connection, + private readonly connectHandlers: readonly ClientConnectHandler[] = [], + ) { + super(connection); + this.agent = ClientContext.create(connection.getContext()); + } + + /** @internal */ + startConnectHandlers(): void { + if (this.didStartConnectHandlers) { + return; + } + + this.didStartConnectHandlers = true; + runConnectHandlers(this, this.connectHandlers); + } +} + +function agentConnection( + connection: Connection, + connectHandlers: readonly AgentConnectHandler[] = [], +): AgentConnection { + return new AgentConnectionHandle(connection, connectHandlers); +} + +function clientConnection( + connection: Connection, + connectHandlers: readonly ClientConnectHandler[] = [], +): ClientConnection { + return new ClientConnectionHandle(connection, connectHandlers); +} + type AsyncQueueEntry = | { kind: "value"; @@ -826,6 +922,20 @@ export type ClientNotificationHandler = ( context: ClientHandlerContext, ) => MaybePromise; +/** + * Handler called when an `AgentApp` opens a connection. + */ +export type AgentConnectHandler = ( + connection: AgentConnection, +) => MaybePromise; + +/** + * Handler called when a `ClientApp` opens a connection. + */ +export type ClientConnectHandler = ( + connection: ClientConnection, +) => MaybePromise; + function parseParams( parser: ParamsParser | undefined, params: unknown, @@ -1492,7 +1602,42 @@ function sessionUpdateRouter(cx: ConnectionContext): SessionUpdateRouter { return router; } +function runConnectHandlers( + connection: ConnectionHandle, + handlers: ReadonlyArray<(connection: ConnectionHandle) => MaybePromise>, +): void { + for (const handler of handlers) { + let result: MaybePromise; + try { + result = handler(connection); + } catch (error) { + connection.close(error); + throw error; + } + + void Promise.resolve(result).catch((error) => { + connection.close(error); + }); + } +} + const appBuilder = Symbol("appBuilder"); +const runAgentConnectHandlers = Symbol("runAgentConnectHandlers"); +const runClientConnectHandlers = Symbol("runClientConnectHandlers"); + +type AppConnectOptions = { + readonly deferConnectHandlers?: boolean; +}; + +type AgentConnectionState = { + rawConnection: Connection; + connection: AgentConnection; +}; + +type ClientConnectionState = { + rawConnection: Connection; + connection: ClientConnection; +}; /** * Creates an agent-side app. @@ -1514,6 +1659,7 @@ export function agent(options?: AppOptions): AgentApp { */ export class AgentApp { private readonly builder = Connection.builder(); + private readonly connectHandlers: AgentConnectHandler[] = []; constructor(options: AppOptions = {}) { if (options.name) { @@ -1526,19 +1672,29 @@ export class AgentApp { return this.builder; } + /** @internal */ + [runAgentConnectHandlers](connection: AgentConnection): void { + runConnectHandlers(connection, this.connectHandlers); + } + /** * Connects this agent app to a transport stream. */ - connect(stream: Stream): AcpConnection; + connect(stream: Stream): AgentConnection; + /** @internal */ + connect(stream: Stream, options: AppConnectOptions): AgentConnection; /** * Connects this agent app directly to a client app. * * This is useful for tests and in-process examples that do not need a * transport. */ - connect(client: ClientApp): AcpConnection; - connect(target: Stream | ClientApp): AcpConnection { - return this.connectTarget(target); + connect(client: ClientApp): AgentConnection; + connect( + target: Stream | ClientApp, + options: AppConnectOptions = {}, + ): AgentConnection { + return this.connectConnection(target, options).connection; } /** @@ -1562,9 +1718,19 @@ export class AgentApp { target: Stream | ClientApp, op: (context: AgentContext) => MaybePromise, ): Promise { - return this.connectTarget(target).runUntil((cx) => - op(AgentContext.create(cx)), - ); + const { rawConnection, connection } = this.connectConnection(target); + return rawConnection.runUntil(() => op(connection.client)); + } + + /** + * Registers a handler that runs when this agent app opens a connection. + * + * Use this for connection-scoped work that needs to call client-side ACP + * methods outside an inbound request handler. + */ + onConnect(handler: AgentConnectHandler): this { + this.connectHandlers.push(handler); + return this; } /** @@ -1678,15 +1844,41 @@ export class AgentApp { return this; } - private connectTarget(target: Stream | ClientApp): Connection { + private connectConnection( + target: Stream | ClientApp, + options: AppConnectOptions = {}, + ): AgentConnectionState { if (isStream(target)) { - return this.builder.connect(target); + const state = this.openStreamConnection(target); + if (!options.deferConnectHandlers) { + this[runAgentConnectHandlers](state.connection); + } + return state; } - return connectInProcess( - (stream) => this.builder.connect(stream), - (stream) => target[appBuilder]().connect(stream), - ); + const [thisStream, peerStream] = memoryStreamPair(); + const peerRawConnection = target[appBuilder]().connect(peerStream); + const peerConnection = clientConnection(peerRawConnection); + const state = this.openStreamConnection(thisStream); + void state.rawConnection.closed.then(() => peerConnection.close()); + void peerRawConnection.closed.then(() => state.connection.close()); + try { + target[runClientConnectHandlers](peerConnection); + this[runAgentConnectHandlers](state.connection); + } catch (error) { + peerConnection.close(error); + state.connection.close(error); + throw error; + } + return state; + } + + private openStreamConnection(stream: Stream): AgentConnectionState { + const rawConnection = this.builder.connect(stream); + return { + rawConnection, + connection: agentConnection(rawConnection, this.connectHandlers), + }; } } @@ -1710,6 +1902,7 @@ export function client(options?: AppOptions): ClientApp { */ export class ClientApp { private readonly builder = Connection.builder(); + private readonly connectHandlers: ClientConnectHandler[] = []; constructor(options: AppOptions = {}) { if (options.name) { @@ -1727,19 +1920,24 @@ export class ClientApp { return this.builder; } + /** @internal */ + [runClientConnectHandlers](connection: ClientConnection): void { + runConnectHandlers(connection, this.connectHandlers); + } + /** * Connects this client app to a transport stream. */ - connect(stream: Stream): AcpConnection; + connect(stream: Stream): ClientConnection; /** * Connects this client app directly to an agent app. * * This is useful for tests and in-process examples that do not need a * transport. */ - connect(agent: AgentApp): AcpConnection; - connect(target: Stream | AgentApp): AcpConnection { - return this.connectTarget(target); + connect(agent: AgentApp): ClientConnection; + connect(target: Stream | AgentApp): ClientConnection { + return this.connectConnection(target).connection; } /** @@ -1763,9 +1961,19 @@ export class ClientApp { target: Stream | AgentApp, op: (context: ClientContext) => MaybePromise, ): Promise { - return this.connectTarget(target).runUntil((cx) => - op(ClientContext.create(cx)), - ); + const { rawConnection, connection } = this.connectConnection(target); + return rawConnection.runUntil(() => op(connection.agent)); + } + + /** + * Registers a handler that runs when this client app opens a connection. + * + * Use this for connection-scoped work that needs to call agent-side ACP + * methods outside an inbound request handler. + */ + onConnect(handler: ClientConnectHandler): this { + this.connectHandlers.push(handler); + return this; } /** @@ -1879,15 +2087,36 @@ export class ClientApp { return this; } - private connectTarget(target: Stream | AgentApp): Connection { + private connectConnection(target: Stream | AgentApp): ClientConnectionState { if (isStream(target)) { - return this.builder.connect(target); + const state = this.openStreamConnection(target); + this[runClientConnectHandlers](state.connection); + return state; } - return connectInProcess( - (stream) => this.builder.connect(stream), - (stream) => target[appBuilder]().connect(stream), - ); + const [thisStream, peerStream] = memoryStreamPair(); + const peerRawConnection = target[appBuilder]().connect(peerStream); + const peerConnection = agentConnection(peerRawConnection); + const state = this.openStreamConnection(thisStream); + void state.rawConnection.closed.then(() => peerConnection.close()); + void peerRawConnection.closed.then(() => state.connection.close()); + try { + target[runAgentConnectHandlers](peerConnection); + this[runClientConnectHandlers](state.connection); + } catch (error) { + peerConnection.close(error); + state.connection.close(error); + throw error; + } + return state; + } + + private openStreamConnection(stream: Stream): ClientConnectionState { + const rawConnection = this.builder.connect(stream); + return { + rawConnection, + connection: clientConnection(rawConnection, this.connectHandlers), + }; } } diff --git a/src/connection.ts b/src/connection.ts index 7f77e61..34e8cb3 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -8,8 +8,20 @@ import { import type { AnyMessage, AnyResponse } from "./jsonrpc.js"; import type { Stream } from "./stream.js"; +export interface AgentConnectOptions { + readonly deferConnectHandlers?: boolean; +} + +export interface AgentConnectionLifecycle { + readonly closed?: Promise; + startConnectHandlers?(): void; +} + export interface AgentConnector { - connect(stream: Stream): unknown; + connect( + stream: Stream, + options?: AgentConnectOptions, + ): AgentConnectionLifecycle | unknown; } export type ResponseRoute = "connection" | { readonly session: string }; @@ -93,15 +105,21 @@ export class ConnectionState { readonly sessionStreams = new Map(); readonly pendingRoutes = new Map(); readonly clientResponseRoutes = new Map(); + readonly closed: Promise; + private readonly agentConnection: AgentConnectionLifecycle | unknown; private hasStartedRouter = false; private inboundWriteChain: Promise = Promise.resolve(); private initialReader: ReadableStreamDefaultReader | undefined; private outboundReader: ReadableStreamDefaultReader | undefined; private shutdownPromise: Promise | undefined; + private resolveClosed: () => void = () => {}; constructor(agent: AgentConnector) { this.connectionId = globalThis.crypto.randomUUID(); + this.closed = new Promise((resolve) => { + this.resolveClosed = resolve; + }); const inbound = new TransformStream(); const outbound = new TransformStream(); @@ -113,7 +131,10 @@ export class ConnectionState { writable: outbound.writable, }; - agent.connect(stream); + this.agentConnection = agent.connect(stream, { + deferConnectHandlers: true, + }); + this.observeAgentConnection(); } async recvInitial(initializeId: string | number): Promise { @@ -162,6 +183,17 @@ export class ConnectionState { void this.runRouter(); } + startConnectHandlers(): void { + if ( + typeof this.agentConnection === "object" && + this.agentConnection !== null && + "startConnectHandlers" in this.agentConnection && + typeof this.agentConnection.startConnectHandlers === "function" + ) { + this.agentConnection.startConnectHandlers(); + } + } + ensureSession(sessionId: string): OutboundStream { const existing = this.sessionStreams.get(sessionId); if (existing) { @@ -183,21 +215,40 @@ export class ConnectionState { } private async runShutdown(): Promise { - this.connectionStream.close(); - this.allOutbound.close(); + try { + this.connectionStream.close(); + this.allOutbound.close(); + + for (const stream of this.sessionStreams.values()) { + stream.close(); + } - for (const stream of this.sessionStreams.values()) { - stream.close(); + this.sessionStreams.clear(); + this.pendingRoutes.clear(); + this.clientResponseRoutes.clear(); + + await Promise.allSettled([ + this.inboundTx.close(), + this.cancelOutboundReader(), + ]); + } finally { + this.resolveClosed(); } + } - this.sessionStreams.clear(); - this.pendingRoutes.clear(); - this.clientResponseRoutes.clear(); + private observeAgentConnection(): void { + if ( + typeof this.agentConnection !== "object" || + this.agentConnection === null || + !("closed" in this.agentConnection) || + !this.agentConnection.closed + ) { + return; + } - await Promise.allSettled([ - this.inboundTx.close(), - this.cancelOutboundReader(), - ]); + void Promise.resolve(this.agentConnection.closed).finally(() => { + void this.shutdown(); + }); } private cancelOutboundReader(): Promise { @@ -320,12 +371,14 @@ export class ConnectionRegistry { createConnection(agent: AgentConnector): ConnectionState { const connection = new ConnectionState(agent); this.connections.set(connection.connectionId, connection); + this.trackConnectionClose(connection); return connection; } createPendingConnection(agent: AgentConnector): ConnectionState { const connection = new ConnectionState(agent); this.pendingConnections.set(connection.connectionId, connection); + this.trackConnectionClose(connection); return connection; } @@ -377,6 +430,17 @@ export class ConnectionRegistry { Array.from(connections, (connection) => connection.shutdown()), ); } + + private trackConnectionClose(connection: ConnectionState): void { + void connection.closed.then(() => { + if (this.connections.get(connection.connectionId) === connection) { + this.connections.delete(connection.connectionId); + } + if (this.pendingConnections.get(connection.connectionId) === connection) { + this.pendingConnections.delete(connection.connectionId); + } + }); + } } class OutboundSubscriber { diff --git a/src/jsonrpc.ts b/src/jsonrpc.ts index d2169b4..d0d62df 100644 --- a/src/jsonrpc.ts +++ b/src/jsonrpc.ts @@ -655,6 +655,11 @@ export class Connection { return this.closedPromise; } + /** @internal */ + getContext(): ConnectionContext { + return this.context; + } + /** * Sends a JSON-RPC request. * diff --git a/src/server-websocket-upgrade.test.ts b/src/server-websocket-upgrade.test.ts index 4a18d37..9e43b28 100644 --- a/src/server-websocket-upgrade.test.ts +++ b/src/server-websocket-upgrade.test.ts @@ -1,6 +1,11 @@ import { describe, expect, it } from "vitest"; -import { AgentSideConnection, PROTOCOL_VERSION } from "./acp.js"; +import { + AgentSideConnection, + PROTOCOL_VERSION, + agent as createAgentApp, + methods, +} from "./acp.js"; import { ConnectionRegistry } from "./connection.js"; import { HEADER_CONNECTION_ID, JSON_MIME_TYPE } from "./protocol.js"; import { AcpServer } from "./server.js"; @@ -57,6 +62,88 @@ describe("AcpServer prepared WebSocket upgrades", () => { } }); + it("sends initialize before agent app connect hook messages", async () => { + const server = new AcpServer({ + agent: createAgentApp({ name: "ws-connect-hook-agent" }) + .onConnect((connection) => + connection.client.notify("vendor/connect-ready", { ready: true }), + ) + .onRequest(methods.agent.initialize, (c) => ({ + protocolVersion: c.params.protocolVersion, + agentCapabilities: { + loadSession: false, + }, + authMethods: [], + })), + }); + const socket = new FakeServerSocket(); + + try { + server.prepareWebSocketUpgrade().accept(socket); + socket.receive(JSON.stringify(initializeRequest)); + + await expect(readSentMessage(socket)).resolves.toMatchObject({ + jsonrpc: "2.0", + id: initializeRequest.id, + result: { + protocolVersion: PROTOCOL_VERSION, + }, + }); + await expect(readSentMessage(socket)).resolves.toMatchObject({ + jsonrpc: "2.0", + method: "vendor/connect-ready", + params: { ready: true }, + }); + } finally { + socket.close(); + await server.close(); + } + }); + + it("forwards deferred connect hooks through WebSocket agent factories", async () => { + let connectHookRuns = 0; + const server = new AcpServer({ + createAgent: () => + createAgentApp({ name: "ws-factory-connect-hook-agent" }) + .onConnect((connection) => { + connectHookRuns += 1; + return connection.client.notify("vendor/connect-ready", { + source: "factory", + }); + }) + .onRequest(methods.agent.initialize, (c) => ({ + protocolVersion: c.params.protocolVersion, + agentCapabilities: { + loadSession: false, + }, + authMethods: [], + })), + }); + const socket = new FakeServerSocket(); + + try { + server.prepareWebSocketUpgrade().accept(socket); + socket.receive(JSON.stringify(initializeRequest)); + + await expect(readSentMessage(socket)).resolves.toMatchObject({ + jsonrpc: "2.0", + id: initializeRequest.id, + result: { + protocolVersion: PROTOCOL_VERSION, + }, + }); + expect(connectHookRuns).toBe(1); + await expect(readSentMessage(socket)).resolves.toMatchObject({ + jsonrpc: "2.0", + method: "vendor/connect-ready", + params: { source: "factory" }, + }); + } finally { + socket.close(); + await server.close(); + } + }); + it("accepts a deprecated legacy agent factory for prepared WebSocket upgrades", async () => { const connections: AgentSideConnection[] = []; const server = new AcpServer({ diff --git a/src/server.test.ts b/src/server.test.ts index 83da293..cb41ec8 100644 --- a/src/server.test.ts +++ b/src/server.test.ts @@ -114,6 +114,192 @@ describe("AcpServer", () => { } }); + it("runs agent app connect hooks after direct HTTP initialize", async () => { + const appAgent = createAgentApp({ name: "http-connect-hook-agent" }) + .onConnect((connection) => + connection.client.notify("vendor/connect-ready", { ready: true }), + ) + .onRequest(methods.agent.initialize, (c) => ({ + protocolVersion: c.params.protocolVersion, + agentCapabilities: { + loadSession: false, + }, + authMethods: [], + })); + const server = new AcpServer({ agent: appAgent }); + + try { + const response = await server.handleRequest( + jsonRequest(initializeRequest), + ); + const body = await response.json(); + const connectionId = response.headers.get(HEADER_CONNECTION_ID) ?? ""; + + expect(response.status).toBe(200); + expect(body).toMatchObject({ + jsonrpc: "2.0", + id: initializeRequest.id, + result: { + protocolVersion: PROTOCOL_VERSION, + }, + }); + + const sseResponse = await server.handleRequest( + new Request("http://127.0.0.1/acp", { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + }, + }), + ); + + expect(sseResponse.status).toBe(200); + await expect(readFirstSseMessage(sseResponse)).resolves.toMatchObject({ + jsonrpc: "2.0", + method: "vendor/connect-ready", + params: { ready: true }, + }); + } finally { + await server.close(); + } + }); + + it("forwards deferred connect hooks through HTTP agent factories", async () => { + let connectHookRuns = 0; + const server = new AcpServer({ + createAgent: () => + createAgentApp({ name: "http-factory-connect-hook-agent" }) + .onConnect((connection) => { + connectHookRuns += 1; + return connection.client.notify("vendor/connect-ready", { + source: "factory", + }); + }) + .onRequest(methods.agent.initialize, (c) => ({ + protocolVersion: c.params.protocolVersion, + agentCapabilities: { + loadSession: false, + }, + authMethods: [], + })), + }); + + try { + const response = await server.handleRequest( + jsonRequest(initializeRequest), + ); + const body = await response.json(); + const connectionId = response.headers.get(HEADER_CONNECTION_ID) ?? ""; + + expect(response.status).toBe(200); + expect(body).toMatchObject({ + jsonrpc: "2.0", + id: initializeRequest.id, + result: { + protocolVersion: PROTOCOL_VERSION, + }, + }); + expect(connectHookRuns).toBe(1); + + const sseResponse = await server.handleRequest( + new Request("http://127.0.0.1/acp", { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + }, + }), + ); + + expect(sseResponse.status).toBe(200); + await expect(readFirstSseMessage(sseResponse)).resolves.toMatchObject({ + jsonrpc: "2.0", + method: "vendor/connect-ready", + params: { source: "factory" }, + }); + } finally { + await server.close(); + } + }); + + it("removes HTTP connections when connect hooks close the handle", async () => { + const server = new AcpServer({ + agent: createAgentApp({ name: "http-connect-hook-close-agent" }) + .onConnect((connection) => { + connection.close(new Error("connect hook closed")); + }) + .onRequest(methods.agent.initialize, (c) => ({ + protocolVersion: c.params.protocolVersion, + agentCapabilities: { + loadSession: false, + }, + authMethods: [], + })), + }); + + try { + const response = await server.handleRequest( + jsonRequest(initializeRequest), + ); + const body = await response.json(); + const connectionId = response.headers.get(HEADER_CONNECTION_ID) ?? ""; + + expect(response.status).toBe(200); + expect(body).toMatchObject({ + jsonrpc: "2.0", + id: initializeRequest.id, + result: { + protocolVersion: PROTOCOL_VERSION, + }, + }); + + await waitForConnectionNotFound(server, connectionId); + } finally { + await server.close(); + } + }); + + it("removes HTTP connections when async connect hooks reject", async () => { + let connectHookRuns = 0; + const server = new AcpServer({ + agent: createAgentApp({ name: "http-connect-hook-reject-agent" }) + .onConnect(() => { + connectHookRuns += 1; + return Promise.reject(new Error("connect hook rejected")); + }) + .onRequest(methods.agent.initialize, (c) => ({ + protocolVersion: c.params.protocolVersion, + agentCapabilities: { + loadSession: false, + }, + authMethods: [], + })), + }); + + try { + const response = await server.handleRequest( + jsonRequest(initializeRequest), + ); + const body = await response.json(); + const connectionId = response.headers.get(HEADER_CONNECTION_ID) ?? ""; + + expect(response.status).toBe(200); + expect(body).toMatchObject({ + jsonrpc: "2.0", + id: initializeRequest.id, + result: { + protocolVersion: PROTOCOL_VERSION, + }, + }); + expect(connectHookRuns).toBe(1); + + await waitForConnectionNotFound(server, connectionId); + } finally { + await server.close(); + } + }); + it("accepts a deprecated legacy agent factory for direct HTTP initialize requests", async () => { const connections: AgentSideConnection[] = []; const server = new AcpServer({ @@ -1095,6 +1281,36 @@ async function waitFor(callback: () => boolean): Promise { } } +async function waitForConnectionNotFound( + server: AcpServer, + connectionId: string, +): Promise { + const deadline = Date.now() + 1_000; + + for (;;) { + const response = await server.handleRequest( + new Request("http://127.0.0.1/acp", { + method: "GET", + headers: { + Accept: EVENT_STREAM_MIME_TYPE, + [HEADER_CONNECTION_ID]: connectionId, + }, + }), + ); + + if (response.status === 404) { + return; + } + + await response.body?.cancel(); + if (Date.now() > deadline) { + throw new Error("Timed out waiting for connection to be removed"); + } + + await new Promise((resolve) => setTimeout(resolve, 1)); + } +} + function createDeferred(): { readonly promise: Promise; readonly resolve: (value: T | PromiseLike) => void; diff --git a/src/server.ts b/src/server.ts index 41bba99..ff11fee 100644 --- a/src/server.ts +++ b/src/server.ts @@ -316,6 +316,7 @@ export class AcpServer { } connection.startRouter(); + connection.startConnectHandlers(); return jsonResponse(initialResponse, 200, { [HEADER_CONNECTION_ID]: connection.connectionId, @@ -398,7 +399,8 @@ function resolveAgent(options: AgentOptions): AgentConnector { if (options.createAgent) { return { - connect: (stream) => options.createAgent!().connect(stream), + connect: (stream, connectOptions) => + options.createAgent!().connect(stream, connectOptions ?? {}), }; } diff --git a/src/ws-server.ts b/src/ws-server.ts index 7e85609..0321203 100644 --- a/src/ws-server.ts +++ b/src/ws-server.ts @@ -190,6 +190,7 @@ class WebSocketServerSession implements WebSocketServerSessionHandle { this.connection = connection; this.options.registry.register(connection); connection.startRouter(); + connection.startConnectHandlers(); this.send(initialResponse); this.startOutboundPump(connection);