diff --git a/rust/README.md b/rust/README.md index 00e26dbaa..007a5a7eb 100644 --- a/rust/README.md +++ b/rust/README.md @@ -84,6 +84,22 @@ With the default `CliProgram::Resolve`, `Client::start()` resolves the CLI in th Created via `Client::create_session` or `Client::resume_session`. Owns an internal event loop that dispatches CLI callbacks to the focused handler traits you install on `SessionConfig`, and broadcasts session events through `subscribe()`. +#### Cloud sessions + +`Client::create_session` creates a Mission Control–backed cloud session when the config is built with `SessionConfig::with_cloud(...)`. The runtime owns the session ID: do **not** set `session_id` or `provider` on the config (the SDK rejects both with `Error::InvalidConfig`). + +```rust,ignore +use github_copilot_sdk::types::{CloudSessionOptions, CloudSessionRepository, SessionConfig}; + +let cloud = CloudSessionOptions::with_repository( + CloudSessionRepository::new("github", "copilot-sdk").with_branch("main"), +); +let session = client + .create_session(SessionConfig::default().with_cloud(cloud)) + .await?; +println!("cloud session id: {}", session.id()); +``` + ```rust,ignore use github_copilot_sdk::MessageOptions; diff --git a/rust/src/jsonrpc.rs b/rust/src/jsonrpc.rs index 88a9670cd..9ef64d176 100644 --- a/rust/src/jsonrpc.rs +++ b/rust/src/jsonrpc.rs @@ -63,7 +63,6 @@ pub mod error_codes { /// Invalid method parameters (-32602). pub const INVALID_PARAMS: i32 = -32602; /// Internal server error (-32603). - #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")] pub const INTERNAL_ERROR: i32 = -32603; } diff --git a/rust/src/lib.rs b/rust/src/lib.rs index cad6ee629..c6b49b262 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1224,6 +1224,10 @@ impl Client { }), }; client.spawn_lifecycle_dispatcher(); + client + .inner + .router + .start(&client.inner.notification_tx, &client.inner.request_rx); debug!( elapsed_ms = setup_start.elapsed().as_millis(), pid = ?pid, @@ -1580,9 +1584,6 @@ impl Client { &self, session_id: &SessionId, ) -> crate::router::SessionChannels { - self.inner - .router - .ensure_started(&self.inner.notification_tx, &self.inner.request_rx); self.inner.router.register(session_id) } diff --git a/rust/src/router.rs b/rust/src/router.rs index e14630e03..abe11cada 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -21,19 +21,55 @@ struct SessionSenders { requests: mpsc::UnboundedSender, } +#[derive(Default)] +struct SessionRouterState { + sessions: HashMap, +} + +impl SessionRouterState { + fn register(&mut self, session_id: &SessionId, senders: SessionSenders) { + self.sessions.insert(session_id.clone(), senders); + } + + fn route_notification(&mut self, session_id: &str, notification: SessionEventNotification) { + if let Some(sender) = self.sessions.get(session_id) { + let _ = sender.notifications.send(notification); + } + } + + fn route_request(&mut self, request: JsonRpcRequest) { + let Some(session_id) = request + .params + .as_ref() + .and_then(|p| p.get("sessionId")) + .and_then(|v| v.as_str()) + else { + warn!(method = %request.method, "request missing sessionId"); + return; + }; + if let Some(sender) = self.sessions.get(session_id) { + let _ = sender.requests.send(request); + return; + } + warn!( + session_id = session_id, + method = %request.method, + "request for unregistered session" + ); + } +} + /// Routes notifications and requests by sessionId to per-session channels. /// /// Internal to the SDK — consumers interact via `Client::register_session()`. pub(crate) struct SessionRouter { - sessions: Arc>>, - started: Mutex, + state: Arc>, } impl SessionRouter { pub(crate) fn new() -> Self { Self { - sessions: Arc::new(Mutex::new(HashMap::new())), - started: Mutex::new(false), + state: Arc::new(Mutex::new(SessionRouterState::default())), } } @@ -41,8 +77,8 @@ impl SessionRouter { pub(crate) fn register(&self, session_id: &SessionId) -> SessionChannels { let (notif_tx, notif_rx) = mpsc::unbounded_channel(); let (req_tx, req_rx) = mpsc::unbounded_channel(); - self.sessions.lock().insert( - session_id.clone(), + self.state.lock().register( + session_id, SessionSenders { notifications: notif_tx, requests: req_tx, @@ -56,7 +92,7 @@ impl SessionRouter { /// Unregister a session, dropping its channels. pub(crate) fn unregister(&self, session_id: &SessionId) { - self.sessions.lock().remove(session_id.as_str()); + self.state.lock().sessions.remove(session_id.as_str()); } /// Snapshot every currently-registered session ID. @@ -65,7 +101,7 @@ impl SessionRouter { /// sessions for cooperative shutdown without holding the router lock /// across `.await`. pub(crate) fn session_ids(&self) -> Vec { - self.sessions.lock().keys().cloned().collect() + self.state.lock().sessions.keys().cloned().collect() } /// Drop all registered session channels. @@ -73,27 +109,22 @@ impl SessionRouter { /// Used by [`Client::force_stop`](crate::Client::force_stop) to release /// per-session state without waiting for graceful unregistration. pub(crate) fn clear(&self) { - self.sessions.lock().clear(); + self.state.lock().sessions.clear(); } - /// Start the router tasks if not already running. + /// Spawn the notification and request routing tasks. /// - /// Takes the notification broadcast and request channel from the Client. - /// If `request_rx` is `None` (already taken by `take_request_rx()`), - /// only notification routing is available. - pub(crate) fn ensure_started( + /// Called exactly once during [`Client::from_streams`]. Takes the + /// notification broadcast and request channel from the Client. If + /// `request_rx` is `None` (already taken by `take_request_rx()`), only + /// notification routing is available. + pub(crate) fn start( &self, notification_tx: &broadcast::Sender, request_rx: &Mutex>>, ) { - let mut started = self.started.lock(); - if *started { - return; - } - *started = true; - // Notification routing task - let sessions = self.sessions.clone(); + let state = self.state.clone(); let mut notif_rx = notification_tx.subscribe(); tokio::spawn(async move { loop { @@ -110,27 +141,20 @@ impl SessionRouter { continue; }; - let sender = { - let guard = sessions.lock(); - guard.get(session_id).map(|s| s.notifications.clone()) - }; - if let Some(sender) = sender { - match serde_json::from_value::(params.clone()) - { - Ok(event_notification) => { - let _ = sender.send(event_notification); - } - Err(e) => { - warn!( - error = %e, - session_id = session_id, - "failed to deserialize session event notification" - ); - } + match serde_json::from_value::(params.clone()) { + Ok(event_notification) => { + state + .lock() + .route_notification(session_id, event_notification); + } + Err(e) => { + warn!( + error = %e, + session_id = session_id, + "failed to deserialize session event notification" + ); } } - // Unknown session IDs are silently dropped — the session - // may have been unregistered between dispatch and delivery. } Err(broadcast::error::RecvError::Lagged(n)) => { warn!(missed = n, "notification router lagged"); @@ -142,37 +166,85 @@ impl SessionRouter { // Request routing task (if request_rx is available) if let Some(mut rx) = request_rx.lock().take() { - let sessions = self.sessions.clone(); + let state = self.state.clone(); tokio::spawn(async move { while let Some(request) = rx.recv().await { - let session_id = request - .params - .as_ref() - .and_then(|p| p.get("sessionId")) - .and_then(|v| v.as_str()); - - if let Some(sid) = session_id { - let sender = { - let guard = sessions.lock(); - guard.get(sid).map(|s| s.requests.clone()) - }; - if let Some(sender) = sender { - let _ = sender.send(request); - } else { - warn!( - session_id = sid, - method = %request.method, - "request for unregistered session" - ); - } - } else { - warn!( - method = %request.method, - "request missing sessionId" - ); - } + state.lock().route_request(request); } }); } } } + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::*; + use crate::jsonrpc::JsonRpcRequest; + + fn make_notification(session_id: &str, kind: &str) -> SessionEventNotification { + let value = json!({ + "sessionId": session_id, + "event": { + "id": "evt-id", + "timestamp": "1970-01-01T00:00:00Z", + "parentId": null, + "type": kind, + "data": {}, + }, + }); + serde_json::from_value(value).expect("valid session event notification") + } + + fn make_request(id: u64, session_id: &str, method: &str) -> JsonRpcRequest { + JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id, + method: method.to_string(), + params: Some(json!({ "sessionId": session_id })), + } + } + + #[test] + fn drops_unknown_session_notifications() { + let router = SessionRouter::new(); + router + .state + .lock() + .route_notification("ghost", make_notification("ghost", "session.start")); + + let channels = router.register(&SessionId::from("ghost")); + assert!(channels.notifications.is_empty()); + } + + #[test] + fn drops_unknown_session_requests() { + let router = SessionRouter::new(); + router + .state + .lock() + .route_request(make_request(1, "ghost", "userInput.request")); + + let channels = router.register(&SessionId::from("ghost")); + assert!(channels.requests.is_empty()); + } + + #[test] + fn routes_registered_session_messages() { + let router = SessionRouter::new(); + let sid = SessionId::from("remote"); + let mut channels = router.register(&sid); + + { + let mut state = router.state.lock(); + state.route_notification("remote", make_notification("remote", "evt")); + state.route_request(make_request(1, "remote", "userInput.request")); + } + + let notification = channels.notifications.try_recv().expect("notification"); + assert_eq!(notification.event.event_type, "evt"); + let request = channels.requests.try_recv().expect("request"); + assert_eq!(request.id, 1); + } +} diff --git a/rust/src/session.rs b/rust/src/session.rs index 57181459c..39c4031e2 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -28,12 +28,14 @@ use crate::types::{ CommandContext, CommandDefinition, CommandHandler, CreateSessionResult, ElicitationRequest, ElicitationResult, ExitPlanModeData, GetMessagesResponse, MessageOptions, PermissionRequestData, RequestId, ResumeSessionConfig, ResumeSessionResult, SectionOverride, - SessionCapabilities, SessionConfig, SessionEvent, SessionId, SetModelOptions, - SystemMessageConfig, ToolInvocation, ToolResult, ToolResultExpanded, TraceContext, - UiInputOptions, ensure_attachment_display_names, + SessionCapabilities, SessionConfig, SessionConfigRuntime, SessionEvent, SessionId, + SetModelOptions, SystemMessageConfig, ToolInvocation, ToolResult, ToolResultExpanded, + TraceContext, UiInputOptions, ensure_attachment_display_names, }; use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes}; +type CommandHandlerMap = HashMap>; + /// Bundle of the per-session callbacks the SDK dispatches to. Built from a /// [`SessionConfig`] / [`ResumeSessionConfig`] at /// [`Client::create_session`] / [`Client::resume_session`] time. Each @@ -51,6 +53,49 @@ pub(crate) struct SessionHandlers { pub tools: Arc>>, } +/// Bundle of everything `create_session` / `resume_session` need to spawn +/// the per-session event loop, extracted from a `SessionConfigRuntime`. +/// Built by [`prepare_session_runtime`]. +struct PreparedSessionRuntime { + handlers: SessionHandlers, + hooks: Option>, + transforms: Option>, + command_handlers: Arc, + canvas_handler: Option>, + session_fs_provider: Option>, + commands_count: usize, + has_hooks: bool, +} + +enum CreateSessionKind { + Local { session_id: SessionId }, + Cloud, +} + +struct PreparedCreateSession { + kind: CreateSessionKind, + params: serde_json::Value, + runtime: PreparedSessionRuntime, + tools_count: usize, + commands_count: usize, + has_hooks: bool, +} + +struct CreateSessionStats { + tools_count: usize, + commands_count: usize, + has_hooks: bool, + total_start: Instant, +} + +struct RunningSession { + event_loop: JoinHandle<()>, + shutdown: CancellationToken, + idle_waiter: Arc>>, + capabilities: Arc>, + event_tx: tokio::sync::broadcast::Sender, +} + /// Shared state between a [`Session`] and its event loop, used by [`Session::send_and_wait`]. struct IdleWaiter { tx: oneshot::Sender, Error>>, @@ -785,138 +830,34 @@ impl Client { /// Each per-event handler is independently optional. If a handler is /// not installed, the SDK signals the runtime not to emit the matching /// broadcast (and silently skips dispatch if one arrives anyway). - pub async fn create_session(&self, mut config: SessionConfig) -> Result { + /// + /// If [`SessionConfig::with_cloud`] was used, this creates a cloud + /// (Mission Control) session instead of a local session. The runtime + /// assigns the session ID for cloud sessions, so callers must not set + /// [`SessionConfig::session_id`] or [`SessionConfig::provider`]. + pub async fn create_session(&self, config: SessionConfig) -> Result { let total_start = Instant::now(); - let session_id = config - .session_id - .clone() - .unwrap_or_else(|| SessionId::from(uuid::Uuid::new_v4().to_string())); - config.session_id = Some(session_id.clone()); - if config.hooks_handler.is_some() && config.hooks.is_none() { - config.hooks = Some(true); - } - if let Some(transforms) = config.system_message_transform.clone() { - inject_transform_sections(&mut config, transforms.as_ref()); - } - let (wire, mut runtime) = config.into_wire(session_id.clone())?; - - let permission_handler = crate::permission::resolve_handler( - runtime.permission_handler.take(), - runtime.permission_policy.take(), - ); - let handlers = SessionHandlers { - permission: permission_handler, - elicitation: runtime.elicitation_handler.take(), - user_input: runtime.user_input_handler.take(), - exit_plan_mode: runtime.exit_plan_mode_handler.take(), - auto_mode_switch: runtime.auto_mode_switch_handler.take(), - tools: Arc::new(std::mem::take(&mut runtime.tool_handlers)), - }; - let hooks = runtime.hooks_handler.take(); - let transforms = runtime.system_message_transform.take(); - let tools_count = wire.tools.as_ref().map_or(0, Vec::len); - let commands_count = runtime.commands.as_ref().map_or(0, Vec::len); - let has_hooks = hooks.is_some(); - let command_handlers = build_command_handler_map(runtime.commands.as_deref()); - let canvas_handler = runtime.canvas_handler.take(); - let session_fs_provider = runtime.session_fs_provider.take(); - if self.inner.session_fs_configured && session_fs_provider.is_none() { - return Err(Error::Session(SessionError::SessionFsProviderRequired)); - } - if self.inner.session_fs_sqlite_declared - && let Some(ref provider) = session_fs_provider - && provider.sqlite().is_none() - { - return Err(Error::InvalidConfig( - "SessionFs capabilities declare SQLite support but the provider \ - does not implement SessionFsSqliteProvider" - .to_string(), - )); - } - - let mut params = serde_json::to_value(&wire)?; - let trace_ctx = self.resolve_trace_context().await; - inject_trace_context(&mut params, &trace_ctx); - - let setup_start = Instant::now(); - let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default())); - let channels = self.register_session(&session_id); - let idle_waiter = Arc::new(ParkingLotMutex::new(None)); - let shutdown = CancellationToken::new(); - let (event_tx, _) = tokio::sync::broadcast::channel(512); - let event_loop = spawn_event_loop( - session_id.clone(), - self.clone(), - handlers, - hooks, - transforms, - command_handlers, - canvas_handler, - session_fs_provider, - channels, - idle_waiter.clone(), - capabilities.clone(), - event_tx.clone(), - shutdown.clone(), - ); - let mut registration = - PendingSessionRegistration::new(self.clone(), session_id.clone(), shutdown.clone()); - tracing::debug!( - elapsed_ms = setup_start.elapsed().as_millis(), - session_id = %session_id, + let PreparedCreateSession { + kind, + params, + runtime, tools_count, commands_count, has_hooks, - "Client::create_session local setup complete" - ); - - let rpc_start = Instant::now(); - let result = match self.call("session.create", Some(params)).await { - Ok(result) => result, - Err(error) => { - registration.cleanup(event_loop).await; - return Err(error); - } + } = prepare_create_session_request(self, config).await?; + let stats = CreateSessionStats { + tools_count, + commands_count, + has_hooks, + total_start, }; - tracing::debug!( - elapsed_ms = rpc_start.elapsed().as_millis(), - "Client::create_session session creation request completed successfully" - ); - let create_result: CreateSessionResult = match serde_json::from_value(result) { - Ok(result) => result, - Err(error) => { - registration.cleanup(event_loop).await; - return Err(error.into()); + + match kind { + CreateSessionKind::Local { session_id } => { + create_local_session(self, session_id, params, runtime, stats).await } - }; - if create_result.session_id != session_id { - registration.cleanup(event_loop).await; - return Err(Error::Session(SessionError::SessionIdMismatch { - requested: session_id, - returned: create_result.session_id, - })); + CreateSessionKind::Cloud => create_cloud_session(self, params, runtime, stats).await, } - *capabilities.write() = create_result.capabilities.unwrap_or_default(); - - tracing::debug!( - elapsed_ms = total_start.elapsed().as_millis(), - session_id = %session_id, - "Client::create_session complete" - ); - registration.disarm(); - Ok(Session { - id: session_id, - cwd: self.cwd().clone(), - workspace_path: create_result.workspace_path, - remote_url: create_result.remote_url, - client: self.clone(), - event_loop: ParkingLotMutex::new(Some(event_loop)), - shutdown, - idle_waiter, - capabilities, - open_canvases: Arc::new(parking_lot::RwLock::new(Vec::new())), - event_tx, - }) } /// Resume an existing session on the CLI. @@ -938,41 +879,19 @@ impl Client { if let Some(transforms) = config.system_message_transform.clone() { inject_transform_sections_resume(&mut config, transforms.as_ref()); } - let (wire, mut runtime) = config.into_wire()?; - - let permission_handler = crate::permission::resolve_handler( - runtime.permission_handler.take(), - runtime.permission_policy.take(), - ); - let handlers = SessionHandlers { - permission: permission_handler, - elicitation: runtime.elicitation_handler.take(), - user_input: runtime.user_input_handler.take(), - exit_plan_mode: runtime.exit_plan_mode_handler.take(), - auto_mode_switch: runtime.auto_mode_switch_handler.take(), - tools: Arc::new(std::mem::take(&mut runtime.tool_handlers)), - }; - let hooks = runtime.hooks_handler.take(); - let transforms = runtime.system_message_transform.take(); + let (wire, runtime) = config.into_wire()?; let tools_count = wire.tools.as_ref().map_or(0, Vec::len); - let commands_count = runtime.commands.as_ref().map_or(0, Vec::len); - let has_hooks = hooks.is_some(); - let command_handlers = build_command_handler_map(runtime.commands.as_deref()); - let canvas_handler = runtime.canvas_handler.take(); - let session_fs_provider = runtime.session_fs_provider.take(); - if self.inner.session_fs_configured && session_fs_provider.is_none() { - return Err(Error::Session(SessionError::SessionFsProviderRequired)); - } - if self.inner.session_fs_sqlite_declared - && let Some(ref provider) = session_fs_provider - && provider.sqlite().is_none() - { - return Err(Error::InvalidConfig( - "SessionFs capabilities declare SQLite support but the provider \ - does not implement SessionFsSqliteProvider" - .to_string(), - )); - } + + let PreparedSessionRuntime { + handlers, + hooks, + transforms, + command_handlers, + canvas_handler, + session_fs_provider, + commands_count, + has_hooks, + } = prepare_session_runtime(self, runtime)?; let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; @@ -1093,7 +1012,278 @@ impl Client { } } -type CommandHandlerMap = HashMap>; +async fn prepare_create_session_request( + client: &Client, + mut config: SessionConfig, +) -> Result { + let kind = if config.cloud.is_some() { + if config.session_id.is_some() { + return Err(Error::InvalidConfig( + "cloud session creation does not accept a caller-provided \ + session_id; the runtime assigns the session id" + .to_string(), + )); + } + if config.provider.is_some() { + return Err(Error::InvalidConfig( + "cloud session creation does not accept a caller-provided \ + provider; the runtime selects the provider" + .to_string(), + )); + } + CreateSessionKind::Cloud + } else { + let session_id = config + .session_id + .clone() + .unwrap_or_else(|| SessionId::from(uuid::Uuid::new_v4().to_string())); + config.session_id = Some(session_id.clone()); + CreateSessionKind::Local { session_id } + }; + + if config.hooks_handler.is_some() && config.hooks.is_none() { + config.hooks = Some(true); + } + if let Some(transforms) = config.system_message_transform.clone() { + inject_transform_sections(&mut config, transforms.as_ref()); + } + + let (wire, runtime) = match &kind { + CreateSessionKind::Local { session_id } => config.into_wire(session_id.clone())?, + CreateSessionKind::Cloud => config.into_cloud_wire()?, + }; + let tools_count = wire.tools.as_ref().map_or(0, Vec::len); + let runtime = prepare_session_runtime(client, runtime)?; + let commands_count = runtime.commands_count; + let has_hooks = runtime.has_hooks; + + let mut params = serde_json::to_value(&wire)?; + let trace_ctx = client.resolve_trace_context().await; + inject_trace_context(&mut params, &trace_ctx); + + Ok(PreparedCreateSession { + kind, + params, + runtime, + tools_count, + commands_count, + has_hooks, + }) +} + +async fn create_local_session( + client: &Client, + session_id: SessionId, + params: serde_json::Value, + runtime: PreparedSessionRuntime, + stats: CreateSessionStats, +) -> Result { + let setup_start = Instant::now(); + let running = start_registered_session(client, &session_id, runtime); + let mut registration = PendingSessionRegistration::new( + client.clone(), + session_id.clone(), + running.shutdown.clone(), + ); + tracing::debug!( + elapsed_ms = setup_start.elapsed().as_millis(), + session_id = %session_id, + tools_count = stats.tools_count, + commands_count = stats.commands_count, + has_hooks = stats.has_hooks, + "Client::create_session local setup complete" + ); + + let rpc_start = Instant::now(); + let result = match client.call("session.create", Some(params)).await { + Ok(result) => result, + Err(error) => { + registration.cleanup(running.event_loop).await; + return Err(error); + } + }; + tracing::debug!( + elapsed_ms = rpc_start.elapsed().as_millis(), + "Client::create_session session creation request completed successfully" + ); + let create_result: CreateSessionResult = match serde_json::from_value(result) { + Ok(result) => result, + Err(error) => { + registration.cleanup(running.event_loop).await; + return Err(error.into()); + } + }; + if create_result.session_id != session_id { + registration.cleanup(running.event_loop).await; + return Err(Error::Session(SessionError::SessionIdMismatch { + requested: session_id, + returned: create_result.session_id, + })); + } + + tracing::debug!( + elapsed_ms = stats.total_start.elapsed().as_millis(), + session_id = %session_id, + "Client::create_session complete" + ); + registration.disarm(); + Ok(session_from_create_result( + client, + session_id, + create_result, + running, + )) +} + +async fn create_cloud_session( + client: &Client, + params: serde_json::Value, + runtime: PreparedSessionRuntime, + stats: CreateSessionStats, +) -> Result { + let setup_start = Instant::now(); + tracing::debug!( + elapsed_ms = setup_start.elapsed().as_millis(), + tools_count = stats.tools_count, + commands_count = stats.commands_count, + has_hooks = stats.has_hooks, + "Client::create_session cloud setup complete" + ); + + let rpc_start = Instant::now(); + let result = client.call("session.create", Some(params)).await?; + tracing::debug!( + elapsed_ms = rpc_start.elapsed().as_millis(), + "Client::create_session cloud creation request completed successfully" + ); + let create_result = decode_cloud_create_result(client, result).await?; + let session_id = create_result.session_id.clone(); + let running = start_registered_session(client, &session_id, runtime); + + tracing::debug!( + elapsed_ms = stats.total_start.elapsed().as_millis(), + session_id = %session_id, + "Client::create_session cloud complete" + ); + Ok(session_from_create_result( + client, + session_id, + create_result, + running, + )) +} + +async fn decode_cloud_create_result( + client: &Client, + result: serde_json::Value, +) -> Result { + let recovered_session_id = result + .get("sessionId") + .and_then(|value| value.as_str()) + .map(SessionId::from); + match serde_json::from_value(result) { + Ok(result) => Ok(result), + Err(error) => { + if let Some(recovered_id) = recovered_session_id { + if let Err(destroy_err) = client + .call( + "session.destroy", + Some(serde_json::json!({ "sessionId": recovered_id })), + ) + .await + { + tracing::warn!( + session_id = %recovered_id, + error = %destroy_err, + "failed to destroy cloud session after create response decode failed" + ); + } + } else { + tracing::warn!( + "Client::create_session cloud decode failure with no recoverable session id; \ + skipping session.destroy (runtime session may leak)" + ); + } + Err(error.into()) + } + } +} + +fn start_registered_session( + client: &Client, + session_id: &SessionId, + runtime: PreparedSessionRuntime, +) -> RunningSession { + let PreparedSessionRuntime { + handlers, + hooks, + transforms, + command_handlers, + canvas_handler, + session_fs_provider, + commands_count: _, + has_hooks: _, + } = runtime; + + let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default())); + let channels = client.register_session(session_id); + let idle_waiter = Arc::new(ParkingLotMutex::new(None)); + let shutdown = CancellationToken::new(); + let (event_tx, _) = tokio::sync::broadcast::channel(512); + let event_loop = spawn_event_loop( + session_id.clone(), + client.clone(), + handlers, + hooks, + transforms, + command_handlers, + canvas_handler, + session_fs_provider, + channels, + idle_waiter.clone(), + capabilities.clone(), + event_tx.clone(), + shutdown.clone(), + ); + + RunningSession { + event_loop, + shutdown, + idle_waiter, + capabilities, + event_tx, + } +} + +fn session_from_create_result( + client: &Client, + session_id: SessionId, + create_result: CreateSessionResult, + running: RunningSession, +) -> Session { + let RunningSession { + event_loop, + shutdown, + idle_waiter, + capabilities, + event_tx, + } = running; + *capabilities.write() = create_result.capabilities.unwrap_or_default(); + + Session { + id: session_id, + cwd: client.cwd().clone(), + workspace_path: create_result.workspace_path, + remote_url: create_result.remote_url, + client: client.clone(), + event_loop: ParkingLotMutex::new(Some(event_loop)), + shutdown, + idle_waiter, + capabilities, + open_canvases: Arc::new(parking_lot::RwLock::new(Vec::new())), + event_tx, + } +} fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc { let map = match commands { @@ -1107,6 +1297,62 @@ fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc Result { + let SessionConfigRuntime { + permission_handler, + permission_policy, + elicitation_handler, + user_input_handler, + exit_plan_mode_handler, + auto_mode_switch_handler, + hooks_handler, + system_message_transform, + tool_handlers, + canvas_handler, + session_fs_provider, + commands, + } = runtime; + let handlers = SessionHandlers { + permission: crate::permission::resolve_handler(permission_handler, permission_policy), + elicitation: elicitation_handler, + user_input: user_input_handler, + exit_plan_mode: exit_plan_mode_handler, + auto_mode_switch: auto_mode_switch_handler, + tools: Arc::new(tool_handlers), + }; + let commands_count = commands.as_ref().map_or(0, Vec::len); + let has_hooks = hooks_handler.is_some(); + let command_handlers = build_command_handler_map(commands.as_deref()); + + if client.inner.session_fs_configured && session_fs_provider.is_none() { + return Err(Error::Session(SessionError::SessionFsProviderRequired)); + } + if client.inner.session_fs_sqlite_declared + && let Some(ref provider) = session_fs_provider + && provider.sqlite().is_none() + { + return Err(Error::InvalidConfig( + "SessionFs capabilities declare SQLite support but the provider \ + does not implement SessionFsSqliteProvider" + .to_string(), + )); + } + + Ok(PreparedSessionRuntime { + handlers, + hooks: hooks_handler, + transforms: system_message_transform, + command_handlers, + canvas_handler, + session_fs_provider, + commands_count, + has_hooks, + }) +} + #[allow(clippy::too_many_arguments)] fn spawn_event_loop( session_id: SessionId, diff --git a/rust/src/types.rs b/rust/src/types.rs index f454e33ed..16a8eeae4 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1188,8 +1188,12 @@ pub struct SessionConfig { /// enabling remote steering /// - `On` — export to GitHub AND enable remote steering pub remote_session: Option, - /// Creates a remote session in the cloud instead of a local session. - /// The optional repository is associated with the cloud session. + /// Creates a cloud-backed session through + /// [`Client::create_session`](crate::Client::create_session). + /// + /// The runtime assigns the session ID for cloud sessions, so do not set + /// [`session_id`](Self::session_id) or [`provider`](Self::provider) when + /// this field is set. pub cloud: Option, /// Forward sub-agent streaming events to this connection. When false, /// only non-streaming sub-agent events and `subagent.*` lifecycle events @@ -1406,8 +1410,25 @@ impl SessionConfig { /// /// [`SessionCreateWire`]: crate::wire::SessionCreateWire pub(crate) fn into_wire( - mut self, + self, session_id: SessionId, + ) -> Result<(crate::wire::SessionCreateWire, SessionConfigRuntime), crate::Error> { + self.into_create_wire(Some(session_id)) + } + + /// Consume this config to produce the [`SessionCreateWire`] payload for + /// cloud `session.create`. Cloud create follows the runtime contract: + /// the caller does not provide a `sessionId`; the runtime returns the + /// Mission Control task/session ID. + pub(crate) fn into_cloud_wire( + self, + ) -> Result<(crate::wire::SessionCreateWire, SessionConfigRuntime), crate::Error> { + self.into_create_wire(None) + } + + fn into_create_wire( + mut self, + session_id: Option, ) -> Result<(crate::wire::SessionCreateWire, SessionConfigRuntime), crate::Error> { let permission_active = self.permission_handler.is_some() || self.permission_policy.is_some(); @@ -1832,7 +1853,8 @@ impl SessionConfig { self } - /// Create a remote session in the cloud instead of a local session. + /// Create a cloud-backed session through + /// [`Client::create_session`](crate::Client::create_session). pub fn with_cloud(mut self, cloud: CloudSessionOptions) -> Self { self.cloud = Some(cloud); self diff --git a/rust/src/wire.rs b/rust/src/wire.rs index b97aea261..a1d1ec094 100644 --- a/rust/src/wire.rs +++ b/rust/src/wire.rs @@ -42,7 +42,8 @@ pub(crate) struct CommandWireDefinition { #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub(crate) struct SessionCreateWire { - pub session_id: SessionId, + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, #[serde(skip_serializing_if = "Option::is_none")] diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index bb4e602e0..c6b5a0b8c 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -16,9 +16,10 @@ use github_copilot_sdk::handler::{ ExitPlanModeHandler, ExitPlanModeResult, UserInputHandler, UserInputResponse, }; use github_copilot_sdk::types::{ - CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ElicitationRequest, - ElicitationResult, ExitPlanModeData, ExtensionInfo, MessageOptions, RequestId, SessionConfig, - SessionId, Tool, ToolInvocation, ToolResult, + CloudSessionOptions, CloudSessionRepository, CommandContext, CommandDefinition, CommandHandler, + DeliveryMode, ElicitationRequest, ElicitationResult, ExitPlanModeData, ExtensionInfo, + MessageOptions, ProviderConfig, RequestId, SessionConfig, SessionId, Tool, ToolInvocation, + ToolResult, }; use github_copilot_sdk::{Client, tool}; use serde_json::Value; @@ -225,6 +226,22 @@ fn requested_session_id(request: &Value) -> &str { .expect("session request should include sessionId") } +fn cloud_session_config() -> SessionConfig { + SessionConfig::default().with_cloud(CloudSessionOptions::with_repository( + CloudSessionRepository::new("github", "copilot-sdk").with_branch("main"), + )) +} + +fn expect_sdk_error( + result: Result, + message: &str, +) -> github_copilot_sdk::Error { + match result { + Ok(_) => panic!("{message}"), + Err(error) => error, + } +} + #[tokio::test] async fn session_subscribe_yields_events_observe_only() { let (session, mut server) = create_session_pair().await; @@ -372,6 +389,172 @@ async fn create_session_sends_canvas_wire_fields() { timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); } +#[tokio::test] +async fn create_session_sends_cloud_create_without_session_id() { + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { client.create_session(cloud_session_config()).await.unwrap() } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.create"); + assert!(request["params"].get("sessionId").is_none()); + assert_eq!(request["params"]["cloud"]["repository"]["owner"], "github"); + assert_eq!( + request["params"]["cloud"]["repository"]["name"], + "copilot-sdk" + ); + assert_eq!(request["params"]["cloud"]["repository"]["branch"], "main"); + assert!(request["params"].get("provider").is_none()); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { + "sessionId": "remote-cloud-session", + "remoteUrl": "https://copilot.example.test/agents/remote-cloud-session", + "capabilities": { "ui": { "elicitation": true } } + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + assert_eq!(session.id(), "remote-cloud-session"); + assert_eq!( + session.remote_url(), + Some("https://copilot.example.test/agents/remote-cloud-session") + ); + assert_eq!( + session.capabilities().ui.and_then(|ui| ui.elicitation), + Some(true) + ); +} + +#[tokio::test] +async fn create_session_rejects_cloud_session_id_and_provider() { + let (client, _server_read, _server_write) = make_client(); + + let error = expect_sdk_error( + client + .create_session(cloud_session_config().with_session_id("caller-id")) + .await, + "cloud create should reject caller session id", + ); + assert!( + matches!(error, github_copilot_sdk::Error::InvalidConfig(ref message) if message.contains("session_id")), + "unexpected error: {error:?}", + ); + + let mut config = cloud_session_config(); + config.provider = Some(ProviderConfig::new("https://api.example.test/v1")); + let error = expect_sdk_error( + client.create_session(config).await, + "cloud create should reject provider", + ); + assert!( + matches!(error, github_copilot_sdk::Error::InvalidConfig(ref message) if message.contains("provider")), + "unexpected error: {error:?}", + ); +} + +#[tokio::test] +async fn create_session_cloud_request_flags_follow_handlers() { + struct InputHandler; + #[async_trait] + impl UserInputHandler for InputHandler { + async fn handle( + &self, + _session_id: SessionId, + _question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + None + } + } + + struct ExitHandler; + #[async_trait] + impl ExitPlanModeHandler for ExitHandler { + async fn handle( + &self, + _session_id: SessionId, + _data: ExitPlanModeData, + ) -> ExitPlanModeResult { + ExitPlanModeResult::default() + } + } + + struct AutoHandler; + #[async_trait] + impl AutoModeSwitchHandler for AutoHandler { + async fn handle( + &self, + _session_id: SessionId, + _error_code: Option, + _retry_after_seconds: Option, + ) -> AutoModeSwitchResponse { + AutoModeSwitchResponse::No + } + } + + struct ElicitHandler; + #[async_trait] + impl ElicitationHandler for ElicitHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: ElicitationRequest, + ) -> ElicitationResult { + ElicitationResult { + action: "cancel".to_string(), + content: None, + } + } + } + + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + cloud_session_config() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_user_input_handler(Arc::new(InputHandler)) + .with_exit_plan_mode_handler(Arc::new(ExitHandler)) + .with_auto_mode_switch_handler(Arc::new(AutoHandler)) + .with_elicitation_handler(Arc::new(ElicitHandler)), + ) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.create"); + assert_eq!(request["params"]["requestPermission"], true); + assert_eq!(request["params"]["requestUserInput"], true); + assert_eq!(request["params"]["requestExitPlanMode"], true); + assert_eq!(request["params"]["requestAutoModeSwitch"], true); + assert_eq!(request["params"]["requestElicitation"], true); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": "remote-cloud-session" }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + #[tokio::test] async fn provider_canvas_dispatch_routes_direct_canvas_action_requests() { let (session, mut server) = create_session_pair_with_config(|cfg| {