From 8f4508b6464e65124c79c27edb6f6508c9a13216 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 27 May 2026 07:32:16 -0400 Subject: [PATCH] Lazily initialize sessions on session.create lifecycle events For cloud session creation, defer CopilotSession construction until the session.created lifecycle event arrives with the server-assigned session ID. This avoids registering a session under a provisional client-generated ID that will be replaced by the server, and ensures the session object is immediately keyed by its real ID. Changes across all SDKs (Node, .NET, Go, Python, Rust): - Extract a configureSession/CreateConfiguredSession helper for session setup - Add pendingCloudSessionCreates map to track deferred session factories - On session.created lifecycle event with a clientSessionId, invoke the pending factory to materialize the session under the real ID - Surface clientSessionId on SessionLifecycleEvent types - Mark cloud session APIs as experimental where applicable Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dotnet/src/Client.cs | 149 +++++++++++++++------------ dotnet/src/Types.cs | 9 ++ go/client.go | 213 +++++++++++++++++++++++++++------------ go/types.go | 14 ++- nodejs/src/client.ts | 139 ++++++++++++++++++------- nodejs/src/types.ts | 10 ++ python/copilot/client.py | 65 ++++++++++-- rust/src/lib.rs | 18 +++- rust/src/router.rs | 35 ++++++- rust/src/session.rs | 197 +++++++++++++++++++++++++++++------- rust/src/types.rs | 13 ++- rust/src/wire.rs | 3 +- 12 files changed, 641 insertions(+), 224 deletions(-) diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 4a65780bd..cc611c895 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -66,6 +66,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable /// that has not been explicitly disposed or removed. /// internal readonly ConcurrentDictionary _sessions = new(); + private ConcurrentDictionary>? _pendingCloudCreates; private readonly CopilotClientOptions _options; private readonly RuntimeConnection _connection; @@ -137,6 +138,7 @@ public CopilotClient(CopilotClientOptions? options = null) { throw new ArgumentException("ConnectionToken must be a non-empty string or null.", nameof(options)); } + // Auto-generate a connection token when the SDK spawns the runtime over TCP // so the loopback listener is safe by default. tcp.ConnectionToken ??= Guid.NewGuid().ToString(); @@ -164,6 +166,11 @@ public CopilotClient(CopilotClientOptions? options = null) _onListModels = _options.OnListModels; } + private ConcurrentDictionary> GetOrCreatePendingCloudCreates() => + _pendingCloudCreates ?? + System.Threading.Interlocked.CompareExchange(ref _pendingCloudCreates, new(), null) ?? + _pendingCloudCreates!; + /// /// Parses a runtime URL into a URI with host and port. /// @@ -535,42 +542,22 @@ public async Task CreateSessionAsync(SessionConfig config, Cance var (wireSystemMessage, transformCallbacks) = ExtractTransformCallbacks(config.SystemMessage); + var isCloudCreate = config.Cloud != null; var sessionId = config.SessionId ?? Guid.NewGuid().ToString(); // Create and register the session before issuing the RPC so that // events emitted by the CLI (e.g. session.start) are not dropped. var setupTimestamp = Stopwatch.GetTimestamp(); - var session = new CopilotSession( - sessionId, - connection.Rpc, - _logger, - this); - session.RegisterTools(config.Tools ?? []); - session.RegisterPermissionHandler(config.OnPermissionRequest); - session.RegisterCommands(config.Commands); - session.RegisterElicitationHandler(config.OnElicitationRequest); - session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest); - session.RegisterAutoModeSwitchHandler(config.OnAutoModeSwitchRequest); - if (config.OnUserInputRequest != null) + CopilotSession? session = null; + if (isCloudCreate) { - session.RegisterUserInputHandler(config.OnUserInputRequest); - } - if (config.Hooks != null) - { - session.RegisterHooks(config.Hooks); - } - if (transformCallbacks != null) - { - session.RegisterTransformCallbacks(transformCallbacks); + GetOrCreatePendingCloudCreates()[sessionId] = + createdSessionId => CreateConfiguredSession(createdSessionId, connection, config, transformCallbacks); } - if (config.OnEvent != null) + else { - session.On(config.OnEvent); + session = CreateConfiguredSession(sessionId, connection, config, transformCallbacks); } - ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider); - session.SetCanvasHandler(config.CanvasHandler); - RegisterSession(session); - session.StartProcessingEvents(); LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.CreateSessionAsync local setup complete. Elapsed={Elapsed}, SessionId={SessionId}, Tools={ToolsCount}, Commands={CommandsCount}, Hooks={HasHooks}", setupTimestamp, @@ -632,7 +619,20 @@ public async Task CreateSessionAsync(SessionConfig config, Cance LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.CreateSessionAsync session creation request completed successfully. Elapsed={Elapsed}, SessionId={SessionId}", rpcTimestamp, - sessionId); + response.SessionId); + + if (session is null) + { + _sessions.TryGetValue(response.SessionId, out session); + if (session is null) + { + throw new InvalidOperationException($"Session was not registered: {response.SessionId}"); + } + } + if (isCloudCreate) + { + _pendingCloudCreates?.TryRemove(sessionId, out _); + } session.WorkspacePath = response.WorkspacePath; session.SetCapabilities(response.Capabilities); @@ -640,13 +640,14 @@ public async Task CreateSessionAsync(SessionConfig config, Cance } catch (Exception ex) { - session.RemoveFromClient(); + _pendingCloudCreates?.TryRemove(sessionId, out _); + session?.RemoveFromClient(); if (ex is not OperationCanceledException) { LoggingHelpers.LogTiming(_logger, LogLevel.Warning, ex, "CopilotClient.CreateSessionAsync failed. Elapsed={Elapsed}, SessionId={SessionId}", totalTimestamp, - sessionId); + sessionId ?? ""); } throw; } @@ -654,7 +655,7 @@ public async Task CreateSessionAsync(SessionConfig config, Cance LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.CreateSessionAsync complete. Elapsed={Elapsed}, SessionId={SessionId}", totalTimestamp, - sessionId); + session.SessionId); return session; } @@ -703,40 +704,8 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes var (wireSystemMessage, transformCallbacks) = ExtractTransformCallbacks(config.SystemMessage); - // Create and register the session before issuing the RPC so that - // events emitted by the CLI (e.g. session.start) are not dropped. var setupTimestamp = Stopwatch.GetTimestamp(); - var session = new CopilotSession( - sessionId, - connection.Rpc, - _logger, - client: this); - session.RegisterTools(config.Tools ?? []); - session.RegisterPermissionHandler(config.OnPermissionRequest); - session.RegisterCommands(config.Commands); - session.RegisterElicitationHandler(config.OnElicitationRequest); - session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest); - session.RegisterAutoModeSwitchHandler(config.OnAutoModeSwitchRequest); - if (config.OnUserInputRequest != null) - { - session.RegisterUserInputHandler(config.OnUserInputRequest); - } - if (config.Hooks != null) - { - session.RegisterHooks(config.Hooks); - } - if (transformCallbacks != null) - { - session.RegisterTransformCallbacks(transformCallbacks); - } - if (config.OnEvent != null) - { - session.On(config.OnEvent); - } - ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider); - session.SetCanvasHandler(config.CanvasHandler); - RegisterSession(session); - session.StartProcessingEvents(); + var session = CreateConfiguredSession(sessionId, connection, config, transformCallbacks); LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.ResumeSessionAsync local setup complete. Elapsed={Elapsed}, SessionId={SessionId}, Tools={ToolsCount}, Commands={CommandsCount}, Hooks={HasHooks}", setupTimestamp, @@ -826,6 +795,46 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes return session; } + private CopilotSession CreateConfiguredSession( + string sessionId, + Connection connection, + SessionConfigBase config, + Dictionary>>? transformCallbacks) + { + var session = new CopilotSession( + sessionId, + connection.Rpc, + _logger, + this); + session.RegisterTools(config.Tools ?? []); + session.RegisterPermissionHandler(config.OnPermissionRequest); + session.RegisterCommands(config.Commands); + session.RegisterElicitationHandler(config.OnElicitationRequest); + session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest); + session.RegisterAutoModeSwitchHandler(config.OnAutoModeSwitchRequest); + if (config.OnUserInputRequest != null) + { + session.RegisterUserInputHandler(config.OnUserInputRequest); + } + if (config.Hooks != null) + { + session.RegisterHooks(config.Hooks); + } + if (transformCallbacks != null) + { + session.RegisterTransformCallbacks(transformCallbacks); + } + if (config.OnEvent != null) + { + session.On(config.OnEvent); + } + ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider); + session.SetCanvasHandler(config.CanvasHandler); + RegisterSession(session); + session.StartProcessingEvents(); + return session; + } + /// /// Validates the health of the connection by sending a ping request. /// @@ -1736,7 +1745,7 @@ public void OnSessionEvent(string sessionId, JsonElement? @event) } } - public void OnSessionLifecycle(string type, string sessionId, JsonElement? metadata) + public void OnSessionLifecycle(string type, string sessionId, JsonElement? metadata, string? clientSessionId = null) { SessionLifecycleEvent evt = type switch { @@ -1750,6 +1759,7 @@ public void OnSessionLifecycle(string type, string sessionId, JsonElement? metad evt.Type = type; evt.SessionId = sessionId; + evt.ClientSessionId = clientSessionId; if (metadata is not null) { evt.Metadata = JsonSerializer.Deserialize( @@ -1757,6 +1767,15 @@ public void OnSessionLifecycle(string type, string sessionId, JsonElement? metad TypesJsonContext.Default.SessionLifecycleEventMetadata); } + if (type == "session.created" && + !client._sessions.ContainsKey(sessionId) && + clientSessionId is not null && + client._pendingCloudCreates is not null && + client._pendingCloudCreates.TryRemove(clientSessionId, out var initializeFunc)) + { + initializeFunc(sessionId); + } + client.DispatchLifecycleEvent(evt); } diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 54c0f71b6..8237766b6 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -2230,6 +2230,7 @@ public sealed class CloudSessionRepository /// /// Options for creating a remote session in the cloud. /// +[Experimental(Diagnostics.Experimental)] public sealed class CloudSessionOptions { /// @@ -2546,6 +2547,7 @@ private SessionConfig(SessionConfig? other) : base(other) /// Creates a remote session in the cloud instead of a local session. /// The optional repository is associated with the cloud session. /// + [Experimental(Diagnostics.Experimental)] public CloudSessionOptions? Cloud { get; set; } /// @@ -3030,6 +3032,13 @@ public class SessionLifecycleEvent [JsonPropertyName("sessionId")] public string SessionId { get; set; } = string.Empty; + /// + /// Provisional client-generated session ID supplied for cloud session creation. + /// + [Experimental(Diagnostics.Experimental)] + [JsonPropertyName("clientSessionId")] + public string? ClientSessionId { get; set; } + /// /// Metadata associated with the session lifecycle event. /// diff --git a/go/client.go b/go/client.go index ae89128a1..696148d23 100644 --- a/go/client.go +++ b/go/client.go @@ -88,17 +88,18 @@ func validateSessionFsConfig(config *SessionFsConfig) error { // } // defer client.Stop() type Client struct { - options ClientOptions - process *exec.Cmd - client *jsonrpc2.Client - actualPort int - actualHost string - state connectionState - sessions map[string]*Session - sessionsMux sync.Mutex - isExternalServer bool - conn net.Conn // stores net.Conn for external TCP connections - useStdio bool // resolved value from options + options ClientOptions + process *exec.Cmd + client *jsonrpc2.Client + actualPort int + actualHost string + state connectionState + sessions map[string]*Session + pendingCloudCreates map[string]func(string) (*Session, error) + sessionsMux sync.Mutex + isExternalServer bool + conn net.Conn // stores net.Conn for external TCP connections + useStdio bool // resolved value from options // resolved process options for the spawned runtime (zero values for UriConnection) cliPath string cliArgs []string @@ -156,12 +157,13 @@ func NewClient(options *ClientOptions) *Client { opts := ClientOptions{} client := &Client{ - options: opts, - state: stateDisconnected, - sessions: make(map[string]*Session), - actualHost: "localhost", - isExternalServer: false, - useStdio: true, + options: opts, + state: stateDisconnected, + sessions: make(map[string]*Session), + pendingCloudCreates: make(map[string]func(string) (*Session, error)), + actualHost: "localhost", + isExternalServer: false, + useStdio: true, } if options != nil { @@ -676,73 +678,114 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses req.Traceparent = traceparent req.Tracestate = tracestate + isCloudCreate := config.Cloud != nil + sessionID := config.SessionID if sessionID == "" { sessionID = uuid.New().String() } req.SessionID = sessionID - // Create and register the session before issuing the RPC so that - // events emitted by the CLI (e.g. session.start) are not dropped. - session := newSession(sessionID, c.client, "") + var materializedSessionID string + var materializedMux sync.Mutex + materializeSession := func(createdSessionID string) (*Session, error) { + c.sessionsMux.Lock() + if existing := c.sessions[createdSessionID]; existing != nil { + c.sessionsMux.Unlock() + materializedMux.Lock() + materializedSessionID = createdSessionID + materializedMux.Unlock() + return existing, nil + } + c.sessionsMux.Unlock() - session.registerTools(config.Tools) - session.registerPermissionHandler(config.OnPermissionRequest) - if config.OnUserInputRequest != nil { - session.registerUserInputHandler(config.OnUserInputRequest) - } - if config.Hooks != nil { - session.registerHooks(config.Hooks) - } - if transformCallbacks != nil { - session.registerTransformCallbacks(transformCallbacks) - } - if config.OnEvent != nil { - session.On(config.OnEvent) - } - if len(config.Commands) > 0 { - session.registerCommands(config.Commands) - } - if config.OnElicitationRequest != nil { - session.registerElicitationHandler(config.OnElicitationRequest) - } - if config.OnExitPlanModeRequest != nil { - session.registerExitPlanModeHandler(config.OnExitPlanModeRequest) - } - if config.OnAutoModeSwitchRequest != nil { - session.registerAutoModeSwitchHandler(config.OnAutoModeSwitchRequest) - } - if config.CanvasHandler != nil { - session.registerCanvasHandler(config.CanvasHandler) - } + session := newSession(createdSessionID, c.client, "") + session.registerTools(config.Tools) + session.registerPermissionHandler(config.OnPermissionRequest) + if config.OnUserInputRequest != nil { + session.registerUserInputHandler(config.OnUserInputRequest) + } + if config.Hooks != nil { + session.registerHooks(config.Hooks) + } + if transformCallbacks != nil { + session.registerTransformCallbacks(transformCallbacks) + } + if config.OnEvent != nil { + session.On(config.OnEvent) + } + if len(config.Commands) > 0 { + session.registerCommands(config.Commands) + } + if config.OnElicitationRequest != nil { + session.registerElicitationHandler(config.OnElicitationRequest) + } + if config.OnExitPlanModeRequest != nil { + session.registerExitPlanModeHandler(config.OnExitPlanModeRequest) + } + if config.OnAutoModeSwitchRequest != nil { + session.registerAutoModeSwitchHandler(config.OnAutoModeSwitchRequest) + } + if config.CanvasHandler != nil { + session.registerCanvasHandler(config.CanvasHandler) + } - c.sessionsMux.Lock() - c.sessions[sessionID] = session - c.sessionsMux.Unlock() + c.sessionsMux.Lock() + c.sessions[createdSessionID] = session + c.sessionsMux.Unlock() + materializedMux.Lock() + materializedSessionID = createdSessionID + materializedMux.Unlock() - if c.options.SessionFs != nil { - if config.CreateSessionFsProvider == nil { - c.sessionsMux.Lock() - delete(c.sessions, sessionID) - c.sessionsMux.Unlock() - return nil, fmt.Errorf("CreateSessionFsProvider is required in session config when SessionFs is enabled in client options") - } - provider := config.CreateSessionFsProvider(session) - if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite { - if _, ok := provider.(SessionFsSqliteProvider); !ok { + if c.options.SessionFs != nil { + if config.CreateSessionFsProvider == nil { c.sessionsMux.Lock() - delete(c.sessions, sessionID) + delete(c.sessions, createdSessionID) c.sessionsMux.Unlock() - return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider") + return nil, fmt.Errorf("CreateSessionFsProvider is required in session config when SessionFs is enabled in client options") + } + provider := config.CreateSessionFsProvider(session) + if c.options.SessionFs.Capabilities != nil && c.options.SessionFs.Capabilities.Sqlite { + if _, ok := provider.(SessionFsSqliteProvider); !ok { + c.sessionsMux.Lock() + delete(c.sessions, createdSessionID) + c.sessionsMux.Unlock() + return nil, fmt.Errorf("SessionFs capabilities declare SQLite support but the provider does not implement SessionFsSqliteProvider") + } } + session.clientSessionApis.SessionFs = newSessionFsAdapter(provider) + } + + return session, nil + } + + var session *Session + if isCloudCreate { + c.sessionsMux.Lock() + c.pendingCloudCreates[sessionID] = materializeSession + c.sessionsMux.Unlock() + } else { + var err error + session, err = materializeSession(sessionID) + if err != nil { + return nil, err } - session.clientSessionApis.SessionFs = newSessionFsAdapter(provider) } result, err := c.client.Request("session.create", req) if err != nil { c.sessionsMux.Lock() - delete(c.sessions, sessionID) + delete(c.pendingCloudCreates, sessionID) + materializedMux.Lock() + if materializedSessionID != "" { + delete(c.sessions, materializedSessionID) + } + materializedMux.Unlock() + if session != nil { + delete(c.sessions, session.SessionID) + } else if sessionID != "" { + delete(c.sessions, sessionID) + } c.sessionsMux.Unlock() return nil, fmt.Errorf("failed to create session: %w", err) } @@ -750,11 +793,31 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses var response createSessionResponse if err := json.Unmarshal(result, &response); err != nil { c.sessionsMux.Lock() - delete(c.sessions, sessionID) + delete(c.pendingCloudCreates, sessionID) + materializedMux.Lock() + if materializedSessionID != "" { + delete(c.sessions, materializedSessionID) + } + materializedMux.Unlock() + if session != nil { + delete(c.sessions, session.SessionID) + } else if sessionID != "" { + delete(c.sessions, sessionID) + } c.sessionsMux.Unlock() return nil, fmt.Errorf("failed to unmarshal response: %w", err) } + if isCloudCreate { + c.sessionsMux.Lock() + session = c.sessions[response.SessionID] + delete(c.pendingCloudCreates, sessionID) + c.sessionsMux.Unlock() + if session == nil { + return nil, fmt.Errorf("cloud session was not registered: %s", response.SessionID) + } + } + session.workspacePath = response.WorkspacePath session.setCapabilities(response.Capabilities) @@ -1242,6 +1305,22 @@ func (c *Client) OnEventType(eventType SessionLifecycleEventType, handler Sessio // handleLifecycleEvent dispatches a lifecycle event to all registered handlers func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) { + if event.Type == SessionLifecycleCreated && event.ClientSessionID != "" { + c.sessionsMux.Lock() + _, alreadyRegistered := c.sessions[event.SessionID] + var materialize func(string) (*Session, error) + if !alreadyRegistered { + materialize = c.pendingCloudCreates[event.ClientSessionID] + delete(c.pendingCloudCreates, event.ClientSessionID) + } + c.sessionsMux.Unlock() + if materialize != nil { + if _, err := materialize(event.SessionID); err != nil { + fmt.Printf("Error materializing cloud session %s: %v\n", event.SessionID, err) + } + } + } + c.lifecycleHandlersMux.Lock() // Copy handlers to avoid holding lock during callbacks typedHandlers := make([]SessionLifecycleHandler, 0) diff --git a/go/types.go b/go/types.go index 193c673c5..f5477e80d 100644 --- a/go/types.go +++ b/go/types.go @@ -140,6 +140,8 @@ type CloudSessionRepository struct { } // CloudSessionOptions configures creation of a remote session in the cloud. +// +// Experimental: this API is not stable and may change or be removed. type CloudSessionOptions struct { Repository *CloudSessionRepository `json:"repository,omitempty"` } @@ -976,6 +978,8 @@ type SessionConfig struct { RemoteSession rpc.RemoteSessionMode // Cloud creates a remote session in the cloud instead of a local session. // The optional repository is associated with the cloud session. + // + // Experimental: this API is not stable and may change or be removed. Cloud *CloudSessionOptions // Canvases declares canvases this session provides. Sent over the wire on // `session.create`. CanvasHandler must be set when this is non-empty (the @@ -1427,9 +1431,13 @@ const ( // SessionLifecycleEvent represents a session lifecycle notification type SessionLifecycleEvent struct { - Type SessionLifecycleEventType `json:"type"` - SessionID string `json:"sessionId"` - Metadata *SessionLifecycleEventMetadata `json:"metadata,omitempty"` + Type SessionLifecycleEventType `json:"type"` + SessionID string `json:"sessionId"` + // ClientSessionID is a provisional client-generated session ID supplied for cloud session creation. + // + // Experimental: this API is not stable and may change or be removed. + ClientSessionID string `json:"clientSessionId,omitempty"` + Metadata *SessionLifecycleEventMetadata `json:"metadata,omitempty"` } // SessionLifecycleEventMetadata contains optional metadata for lifecycle events diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 11e6131cb..3c9724d59 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -251,6 +251,8 @@ export class CopilotClient { private actualHost: string = "localhost"; private state: "disconnected" | "connecting" | "connected" | "error" = "disconnected"; private sessions: Map = new Map(); + private pendingCloudSessionCreates: Map CopilotSession> = + new Map(); private stderrBuffer: string = ""; // Captures CLI stderr for error messages /** Resolved connection mode chosen in the constructor. */ private connectionConfig: InternalRuntimeConnection; @@ -485,6 +487,39 @@ export class CopilotClient { session.clientSessionApis.sessionFs = createSessionFsAdapter(provider); } + private configureSession( + session: CopilotSession, + config: SessionConfig, + transformCallbacks: Map | undefined + ): void { + session.registerTools(config.tools); + session.registerCanvases(config.canvases); + session.registerCommands(config.commands); + session.registerPermissionHandler(config.onPermissionRequest); + if (config.onUserInputRequest) { + session.registerUserInputHandler(config.onUserInputRequest); + } + if (config.onElicitationRequest) { + session.registerElicitationHandler(config.onElicitationRequest); + } + if (config.onExitPlanModeRequest) { + session.registerExitPlanModeHandler(config.onExitPlanModeRequest); + } + if (config.onAutoModeSwitchRequest) { + session.registerAutoModeSwitchHandler(config.onAutoModeSwitchRequest); + } + if (config.hooks) { + session.registerHooks(config.hooks); + } + if (transformCallbacks) { + session.registerTransformCallbacks(transformCallbacks); + } + if (config.onEvent) { + session.on(config.onEvent); + } + this.setupSessionFs(session, config); + } + /** * Starts the CLI server and establishes a connection. * @@ -794,49 +829,49 @@ export class CopilotClient { await this.start(); } - const sessionId = config.sessionId ?? randomUUID(); - - // Create and register the session before issuing the RPC so that - // events emitted by the CLI (e.g. session.start) are not dropped. - const session = new CopilotSession( - sessionId, - this.connection!, - undefined, - this.onGetTraceContext - ); - session.registerTools(config.tools); - session.registerCanvases(config.canvases); - session.registerCommands(config.commands); - session.registerPermissionHandler(config.onPermissionRequest); - if (config.onUserInputRequest) { - session.registerUserInputHandler(config.onUserInputRequest); - } - if (config.onElicitationRequest) { - session.registerElicitationHandler(config.onElicitationRequest); - } - if (config.onExitPlanModeRequest) { - session.registerExitPlanModeHandler(config.onExitPlanModeRequest); - } - if (config.onAutoModeSwitchRequest) { - session.registerAutoModeSwitchHandler(config.onAutoModeSwitchRequest); - } - if (config.hooks) { - session.registerHooks(config.hooks); - } - // Extract transform callbacks from system message config before serialization. const { wirePayload: wireSystemMessage, transformCallbacks } = extractTransformCallbacks( config.systemMessage ); - if (transformCallbacks) { - session.registerTransformCallbacks(transformCallbacks); + + const isCloudCreate = config.cloud !== undefined; + if (isCloudCreate && config.sessionId) { + // In cloud mode this is a provisional client-side ID, not the final server ID. } - if (config.onEvent) { - session.on(config.onEvent); + const sessionId = config.sessionId ?? randomUUID(); + let session: CopilotSession | undefined; + + if (isCloudCreate) { + const materializeCloudSession = (createdSessionId: string) => { + const existing = this.sessions.get(createdSessionId); + if (existing) { + return existing; + } + const createdSession = new CopilotSession( + createdSessionId, + this.connection!, + undefined, + this.onGetTraceContext + ); + this.configureSession(createdSession, config, transformCallbacks); + this.sessions.set(createdSessionId, createdSession); + session = createdSession; + return createdSession; + }; + this.pendingCloudSessionCreates.set(sessionId, materializeCloudSession); + } else { + // Create and register the session before issuing the RPC so that + // events emitted by the CLI (e.g. session.start) are not dropped. + session = new CopilotSession( + sessionId!, + this.connection!, + undefined, + this.onGetTraceContext + ); + this.configureSession(session, config, transformCallbacks); + this.sessions.set(sessionId!, session); } - this.sessions.set(sessionId, session); - this.setupSessionFs(session, config); try { const response = await this.connection!.sendRequest("session.create", { @@ -891,15 +926,33 @@ export class CopilotClient { cloud: config.cloud, }); - const { workspacePath, capabilities } = response as { + const createResponse = response as { sessionId: string; workspacePath?: string; capabilities?: SessionCapabilities; }; + const { workspacePath, capabilities } = createResponse; + if (isCloudCreate) { + session = this.sessions.get(createResponse.sessionId); + if (!session) { + throw new Error( + `Cloud session was not registered: ${createResponse.sessionId}` + ); + } + this.pendingCloudSessionCreates.delete(sessionId); + } + if (!session) { + throw new Error("Session was not registered"); + } session["_workspacePath"] = workspacePath; session.setCapabilities(capabilities); } catch (e) { - this.sessions.delete(sessionId); + this.pendingCloudSessionCreates.delete(sessionId); + if (session) { + this.sessions.delete(session.sessionId); + } else if (sessionId !== undefined) { + this.sessions.delete(sessionId); + } throw e; } @@ -1944,6 +1997,7 @@ export class CopilotClient { const raw = notification as { type: SessionLifecycleEventType; sessionId: string; + clientSessionId?: string; metadata?: { startTime?: string; modifiedTime?: string; summary?: string }; }; @@ -1959,9 +2013,18 @@ export class CopilotClient { const event = { type: raw.type, sessionId: raw.sessionId, + clientSessionId: raw.clientSessionId, metadata, } as SessionLifecycleEvent; + if (raw.type === "session.created" && raw.clientSessionId && !this.sessions.has(raw.sessionId)) { + const materialize = this.pendingCloudSessionCreates.get(raw.clientSessionId); + if (materialize) { + materialize(raw.sessionId); + this.pendingCloudSessionCreates.delete(raw.clientSessionId); + } + } + // Dispatch to typed handlers for this specific event type const typedHandlers = this.typedLifecycleHandlers.get(event.type); if (typedHandlers) { diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 938a7f2fc..30e52c388 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -324,6 +324,8 @@ export interface CloudSessionRepository { /** * Options for creating a remote session in the cloud. + * + * @experimental This API is not stable and may change or be removed. */ export interface CloudSessionOptions { repository?: CloudSessionRepository; @@ -1768,6 +1770,8 @@ export interface SessionConfig extends SessionConfigBase { /** * Creates a remote session in the cloud instead of a local session. * The optional repository is associated with the cloud session. + * + * @experimental This API is not stable and may change or be removed. */ cloud?: CloudSessionOptions; } @@ -2154,6 +2158,12 @@ export interface SessionLifecycleEventMetadata { interface SessionLifecycleEventBase { /** ID of the session this event relates to. */ sessionId: string; + /** + * Provisional client-generated session ID supplied for cloud session creation. + * + * @experimental This API is not stable and may change or be removed. + */ + clientSessionId?: string; /** Session metadata (not included for `session.deleted`). */ metadata?: SessionLifecycleEventMetadata; } diff --git a/python/copilot/client.py b/python/copilot/client.py index 4386adb08..255452504 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -100,7 +100,10 @@ class CloudSessionRepository: @dataclass class CloudSessionOptions: - """Options for creating a remote session in the cloud.""" + """Options for creating a remote session in the cloud. + + Experimental: this API is not stable and may change or be removed. + """ repository: CloudSessionRepository | None = None @@ -872,6 +875,8 @@ class SessionLifecycleEventBase: session_id: str metadata: SessionLifecycleEventMetadata | None = None + # Experimental: this API is not stable and may change or be removed. + client_session_id: str | None = None @dataclass @@ -924,18 +929,29 @@ def _session_lifecycle_event_from_dict(data: dict) -> SessionLifecycleEvent: if "metadata" in data and data["metadata"]: metadata = SessionLifecycleEventMetadata.from_dict(data["metadata"]) session_id = data.get("sessionId", "") + client_session_id = data.get("clientSessionId") event_type = data.get("type") if event_type == "session.created": - return SessionCreatedEvent(session_id=session_id, metadata=metadata) + return SessionCreatedEvent( + session_id=session_id, metadata=metadata, client_session_id=client_session_id + ) if event_type == "session.deleted": - return SessionDeletedEvent(session_id=session_id, metadata=metadata) + return SessionDeletedEvent( + session_id=session_id, metadata=metadata, client_session_id=client_session_id + ) if event_type == "session.foreground": - return SessionForegroundEvent(session_id=session_id, metadata=metadata) + return SessionForegroundEvent( + session_id=session_id, metadata=metadata, client_session_id=client_session_id + ) if event_type == "session.background": - return SessionBackgroundEvent(session_id=session_id, metadata=metadata) + return SessionBackgroundEvent( + session_id=session_id, metadata=metadata, client_session_id=client_session_id + ) # Default to ``session.updated`` for unknown event types so consumers # keep working across server upgrades. - return SessionUpdatedEvent(session_id=session_id, metadata=metadata) + return SessionUpdatedEvent( + session_id=session_id, metadata=metadata, client_session_id=client_session_id + ) SessionLifecycleHandler = Callable[[SessionLifecycleEvent], None] @@ -1187,6 +1203,7 @@ def __init__( self._state: _ConnectionState = "disconnected" self._sessions: dict[str, CopilotSession] = {} self._sessions_lock = threading.Lock() + self._pending_cloud_creates: dict[str, CopilotSession] = {} self._models_cache: list[ModelInfo] | None = None self._models_cache_lock = asyncio.Lock() self._lifecycle_handlers: list[SessionLifecycleHandler] = [] @@ -1619,7 +1636,8 @@ async def create_session( infinite_sessions: Infinite session configuration. cloud: Creates a remote session in the cloud instead of a local session. Optionally associates repository metadata with the - cloud session. + cloud session. Experimental: this API is not stable and may + change or be removed. on_event: Callback for session events. Returns: @@ -1806,6 +1824,8 @@ async def create_session( raise RuntimeError("Client not connected") total_start = time.perf_counter() + is_cloud_create = cloud is not None + actual_session_id = session_id or str(uuid.uuid4()) payload["sessionId"] = actual_session_id @@ -1855,6 +1875,8 @@ async def create_session( session.on(on_event) with self._sessions_lock: self._sessions[actual_session_id] = session + if is_cloud_create: + self._pending_cloud_creates[actual_session_id] = session log_timing( logger, logging.DEBUG, @@ -1869,19 +1891,33 @@ async def create_session( try: rpc_start = time.perf_counter() response = await self._client.request("session.create", payload) + response_session_id = response.get("sessionId") + if is_cloud_create and response_session_id: + with self._sessions_lock: + materialized = self._sessions.get(response_session_id) + if materialized is None: + self._pending_cloud_creates.pop(actual_session_id, None) + raise RuntimeError( + f"Cloud session was not registered: {response_session_id}" + ) + self._pending_cloud_creates.pop(actual_session_id, None) + session = materialized log_timing( logger, logging.DEBUG, "CopilotClient.create_session session creation request completed successfully", rpc_start, - session_id=actual_session_id, + session_id=session.session_id, ) session._workspace_path = response.get("workspacePath") capabilities = response.get("capabilities") session._set_capabilities(capabilities) except BaseException as exc: with self._sessions_lock: + self._sessions.pop(session.session_id, None) self._sessions.pop(actual_session_id, None) + if is_cloud_create: + self._pending_cloud_creates.pop(actual_session_id, None) if not isinstance(exc, asyncio.CancelledError): log_timing( logger, @@ -1892,7 +1928,6 @@ async def create_session( session_id=actual_session_id, ) raise - log_timing( logger, logging.DEBUG, @@ -2643,6 +2678,18 @@ def unsubscribe_typed() -> None: def _dispatch_lifecycle_event(self, event: SessionLifecycleEvent) -> None: """Dispatch a lifecycle event to all registered handlers.""" + if event.type == "session.created" and event.client_session_id: + with self._sessions_lock: + session = ( + None + if event.session_id in self._sessions + else self._pending_cloud_creates.pop(event.client_session_id, None) + ) + if session is not None: + self._sessions.pop(session.session_id, None) + session.session_id = event.session_id + self._sessions[event.session_id] = session + with self._lifecycle_handlers_lock: # Copy handlers to avoid holding lock during callbacks typed_handlers = list(self._typed_lifecycle_handlers.get(event.type, [])) diff --git a/rust/src/lib.rs b/rust/src/lib.rs index cad6ee629..3492228e3 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -71,6 +71,8 @@ mod sdk_protocol_version; pub use sdk_protocol_version::{SDK_PROTOCOL_VERSION, get_sdk_protocol_version}; pub use subscription::{EventSubscription, Lagged, LifecycleSubscription, RecvError}; +pub(crate) type PendingCloudCreate = Box; + /// Minimum protocol version this SDK can communicate with. const MIN_PROTOCOL_VERSION: u32 = 3; @@ -894,6 +896,8 @@ struct ClientInner { negotiated_protocol_version: OnceLock, state: parking_lot::Mutex, lifecycle_tx: broadcast::Sender, + pending_cloud_creates: + Arc>>, on_list_models: Option>, models_cache: parking_lot::Mutex>>>, session_fs_configured: bool, @@ -1215,6 +1219,9 @@ impl Client { negotiated_protocol_version: OnceLock::new(), state: parking_lot::Mutex::new(ConnectionState::Connected), lifecycle_tx: broadcast::channel(256).0, + pending_cloud_creates: Arc::new(parking_lot::Mutex::new( + std::collections::HashMap::new(), + )), on_list_models, models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())), session_fs_configured, @@ -1580,9 +1587,11 @@ impl Client { &self, session_id: &SessionId, ) -> crate::router::SessionChannels { - self.inner - .router - .ensure_started(&self.inner.notification_tx, &self.inner.request_rx); + self.inner.router.ensure_started( + &self.inner.notification_tx, + &self.inner.request_rx, + &self.inner.pending_cloud_creates, + ); self.inner.router.register(session_id) } @@ -2601,6 +2610,9 @@ mod tests { negotiated_protocol_version: OnceLock::new(), state: parking_lot::Mutex::new(ConnectionState::Connected), lifecycle_tx: broadcast::channel(16).0, + pending_cloud_creates: Arc::new(parking_lot::Mutex::new( + std::collections::HashMap::new(), + )), on_list_models: Some(handler), models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())), session_fs_configured: false, diff --git a/rust/src/router.rs b/rust/src/router.rs index e14630e03..3d46cfd40 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -5,8 +5,11 @@ use parking_lot::Mutex; use tokio::sync::{broadcast, mpsc}; use tracing::warn; +use crate::PendingCloudCreate; use crate::jsonrpc::{JsonRpcNotification, JsonRpcRequest}; -use crate::types::{SessionEventNotification, SessionId}; +use crate::types::{ + SessionEventNotification, SessionId, SessionLifecycleEvent, SessionLifecycleEventType, +}; /// Per-session channels created by the router during session registration. pub(crate) struct SessionChannels { @@ -85,6 +88,7 @@ impl SessionRouter { &self, notification_tx: &broadcast::Sender, request_rx: &Mutex>>, + pending_cloud_creates: &Arc>>, ) { let mut started = self.started.lock(); if *started { @@ -94,11 +98,40 @@ impl SessionRouter { // Notification routing task let sessions = self.sessions.clone(); + let pending_cloud_creates = Arc::clone(pending_cloud_creates); let mut notif_rx = notification_tx.subscribe(); tokio::spawn(async move { loop { match notif_rx.recv().await { Ok(notification) => { + if notification.method == "session.lifecycle" { + let Some(params) = notification.params else { + continue; + }; + let event: SessionLifecycleEvent = match serde_json::from_value(params) + { + Ok(event) => event, + Err(e) => { + warn!(error = %e, "failed to deserialize session.lifecycle notification"); + continue; + } + }; + if event.event_type == SessionLifecycleEventType::Created + && let Some(client_session_id) = event.client_session_id + { + let already_registered = + sessions.lock().contains_key(&event.session_id); + let materialize = if already_registered { + None + } else { + pending_cloud_creates.lock().remove(&client_session_id) + }; + if let Some(materialize) = materialize { + materialize(event.session_id); + } + } + continue; + } if notification.method != "session.event" { continue; } diff --git a/rust/src/session.rs b/rust/src/session.rs index 57181459c..13322b13a 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -109,6 +109,15 @@ impl PendingSessionRegistration { } } +struct SessionRuntimeParts { + session_id: SessionId, + event_loop: JoinHandle<()>, + shutdown: CancellationToken, + idle_waiter: Arc>>, + capabilities: Arc>, + event_tx: tokio::sync::broadcast::Sender, +} + impl Drop for PendingSessionRegistration { fn drop(&mut self) { if !self.disarmed { @@ -787,6 +796,7 @@ impl Client { /// broadcast (and silently skips dispatch if one arrives anyway). pub async fn create_session(&self, mut config: SessionConfig) -> Result { let total_start = Instant::now(); + let is_cloud_create = config.cloud.is_some(); let session_id = config .session_id .clone() @@ -798,7 +808,10 @@ impl Client { if let Some(transforms) = config.system_message_transform.clone() { inject_transform_sections(&mut config, transforms.as_ref()); } - let (wire, mut runtime) = config.into_wire(session_id.clone())?; + let (mut wire, mut runtime) = config.into_wire(session_id.clone())?; + if is_cloud_create { + wire.session_id = None; + } let permission_handler = crate::permission::resolve_handler( runtime.permission_handler.take(), @@ -839,28 +852,58 @@ impl Client { inject_trace_context(&mut params, &trace_ctx); let setup_start = Instant::now(); - let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default())); - let channels = self.register_session(&session_id); - let idle_waiter = Arc::new(ParkingLotMutex::new(None)); - let shutdown = CancellationToken::new(); - let (event_tx, _) = tokio::sync::broadcast::channel(512); - let event_loop = spawn_event_loop( - session_id.clone(), - self.clone(), - handlers, - hooks, - transforms, - command_handlers, - canvas_handler, - session_fs_provider, - channels, - idle_waiter.clone(), - capabilities.clone(), - event_tx.clone(), - shutdown.clone(), - ); - let mut registration = - PendingSessionRegistration::new(self.clone(), session_id.clone(), shutdown.clone()); + let materialized_cloud_session: Arc>> = + Arc::new(ParkingLotMutex::new(None)); + let mut local_registration: Option = None; + + if is_cloud_create { + self.inner.router.ensure_started( + &self.inner.notification_tx, + &self.inner.request_rx, + &self.inner.pending_cloud_creates, + ); + let materialized = Arc::clone(&materialized_cloud_session); + let client = self.clone(); + let handlers = handlers.clone(); + let hooks = hooks.clone(); + let transforms = transforms.clone(); + let command_handlers = Arc::clone(&command_handlers); + let canvas_handler = canvas_handler.clone(); + let session_fs_provider = session_fs_provider.clone(); + self.inner.pending_cloud_creates.lock().insert( + session_id.clone(), + Box::new(move |created_session_id| { + let parts = materialize_session_runtime( + client, + created_session_id, + handlers, + hooks, + transforms, + command_handlers, + canvas_handler, + session_fs_provider, + ); + *materialized.lock() = Some(parts); + }), + ); + } else { + let parts = materialize_session_runtime( + self.clone(), + session_id.clone(), + handlers, + hooks, + transforms, + command_handlers, + canvas_handler, + session_fs_provider, + ); + local_registration = Some(PendingSessionRegistration::new( + self.clone(), + parts.session_id.clone(), + parts.shutdown.clone(), + )); + *materialized_cloud_session.lock() = Some(parts); + } tracing::debug!( elapsed_ms = setup_start.elapsed().as_millis(), session_id = %session_id, @@ -874,7 +917,15 @@ impl Client { let result = match self.call("session.create", Some(params)).await { Ok(result) => result, Err(error) => { - registration.cleanup(event_loop).await; + if is_cloud_create { + self.inner.pending_cloud_creates.lock().remove(&session_id); + } + if let Some(parts) = materialized_cloud_session.lock().take() { + cleanup_session_runtime(self, parts).await; + if let Some(registration) = local_registration.as_mut() { + registration.disarm(); + } + } return Err(error); } }; @@ -885,37 +936,63 @@ impl Client { let create_result: CreateSessionResult = match serde_json::from_value(result) { Ok(result) => result, Err(error) => { - registration.cleanup(event_loop).await; + if is_cloud_create { + self.inner.pending_cloud_creates.lock().remove(&session_id); + } + if let Some(parts) = materialized_cloud_session.lock().take() { + cleanup_session_runtime(self, parts).await; + if let Some(registration) = local_registration.as_mut() { + registration.disarm(); + } + } return Err(error.into()); } }; - if create_result.session_id != session_id { - registration.cleanup(event_loop).await; + if !is_cloud_create && create_result.session_id != session_id { + if let Some(parts) = materialized_cloud_session.lock().take() { + cleanup_session_runtime(self, parts).await; + if let Some(registration) = local_registration.as_mut() { + registration.disarm(); + } + } return Err(Error::Session(SessionError::SessionIdMismatch { requested: session_id, returned: create_result.session_id, })); } - *capabilities.write() = create_result.capabilities.unwrap_or_default(); + + if is_cloud_create && materialized_cloud_session.lock().is_none() { + self.inner.pending_cloud_creates.lock().remove(&session_id); + } + + let Some(parts) = materialized_cloud_session.lock().take() else { + self.inner.pending_cloud_creates.lock().remove(&session_id); + return Err(Error::InvalidConfig( + "cloud session was not registered".to_string(), + )); + }; + *parts.capabilities.write() = create_result.capabilities.unwrap_or_default(); tracing::debug!( elapsed_ms = total_start.elapsed().as_millis(), - session_id = %session_id, + session_id = %parts.session_id, "Client::create_session complete" ); - registration.disarm(); + if let Some(registration) = local_registration.as_mut() { + registration.disarm(); + } Ok(Session { - id: session_id, + id: parts.session_id, cwd: self.cwd().clone(), workspace_path: create_result.workspace_path, remote_url: create_result.remote_url, client: self.clone(), - event_loop: ParkingLotMutex::new(Some(event_loop)), - shutdown, - idle_waiter, - capabilities, + event_loop: ParkingLotMutex::new(Some(parts.event_loop)), + shutdown: parts.shutdown, + idle_waiter: parts.idle_waiter, + capabilities: parts.capabilities, open_canvases: Arc::new(parking_lot::RwLock::new(Vec::new())), - event_tx, + event_tx: parts.event_tx, }) } @@ -1107,6 +1184,54 @@ fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc>, + transforms: Option>, + command_handlers: Arc, + canvas_handler: Option>, + session_fs_provider: Option>, +) -> SessionRuntimeParts { + let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default())); + let channels = client.register_session(&session_id); + let idle_waiter = Arc::new(ParkingLotMutex::new(None)); + let shutdown = CancellationToken::new(); + let (event_tx, _) = tokio::sync::broadcast::channel(512); + let event_loop = spawn_event_loop( + session_id.clone(), + client.clone(), + handlers, + hooks, + transforms, + command_handlers, + canvas_handler, + session_fs_provider, + channels, + idle_waiter.clone(), + capabilities.clone(), + event_tx.clone(), + shutdown.clone(), + ); + + SessionRuntimeParts { + session_id, + event_loop, + shutdown, + idle_waiter, + capabilities, + event_tx, + } +} + +async fn cleanup_session_runtime(client: &Client, parts: SessionRuntimeParts) { + parts.shutdown.cancel(); + let _ = parts.event_loop.await; + client.unregister_session(&parts.session_id); +} + #[allow(clippy::too_many_arguments)] fn spawn_event_loop( session_id: SessionId, diff --git a/rust/src/types.rs b/rust/src/types.rs index f454e33ed..f10edc0be 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -91,6 +91,11 @@ pub struct SessionLifecycleEvent { /// Identifier of the session this event refers to. #[serde(rename = "sessionId")] pub session_id: SessionId, + /// Provisional client-generated session ID supplied for cloud session creation. + /// + /// **Experimental.** This API is not stable and may change or be removed. + #[serde(rename = "clientSessionId", skip_serializing_if = "Option::is_none")] + pub client_session_id: Option, /// Optional metadata describing the session at the time of the event. #[serde(skip_serializing_if = "Option::is_none")] pub metadata: Option, @@ -757,6 +762,8 @@ impl CloudSessionRepository { } /// Options for creating a remote session in the cloud. +/// +/// **Experimental.** This API is not stable and may change or be removed. #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[non_exhaustive] @@ -1190,6 +1197,8 @@ pub struct SessionConfig { pub remote_session: Option, /// Creates a remote session in the cloud instead of a local session. /// The optional repository is associated with the cloud session. + /// + /// **Experimental.** This API is not stable and may change or be removed. pub cloud: Option, /// Forward sub-agent streaming events to this connection. When false, /// only non-streaming sub-agent events and `subagent.*` lifecycle events @@ -1443,7 +1452,7 @@ impl SessionConfig { let canvas_handler = self.canvas_handler.clone(); let wire = crate::wire::SessionCreateWire { - session_id, + session_id: Some(session_id), model: self.model, client_name: self.client_name, reasoning_effort: self.reasoning_effort, @@ -1833,6 +1842,8 @@ impl SessionConfig { } /// Create a remote session in the cloud instead of a local session. + /// + /// **Experimental.** This API is not stable and may change or be removed. pub fn with_cloud(mut self, cloud: CloudSessionOptions) -> Self { self.cloud = Some(cloud); self diff --git a/rust/src/wire.rs b/rust/src/wire.rs index b97aea261..a1d1ec094 100644 --- a/rust/src/wire.rs +++ b/rust/src/wire.rs @@ -42,7 +42,8 @@ pub(crate) struct CommandWireDefinition { #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub(crate) struct SessionCreateWire { - pub session_id: SessionId, + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, #[serde(skip_serializing_if = "Option::is_none")]