Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 84 additions & 65 deletions dotnet/src/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable
/// <see cref="CopilotClient"/> that has not been explicitly disposed or removed.
/// </remarks>
internal readonly ConcurrentDictionary<string, CopilotSession> _sessions = new();
private ConcurrentDictionary<string, Func<string, CopilotSession>>? _pendingCloudCreates;

private readonly CopilotClientOptions _options;
private readonly RuntimeConnection _connection;
Expand Down Expand Up @@ -137,6 +138,7 @@ public CopilotClient(CopilotClientOptions? options = null)
{
throw new ArgumentException("ConnectionToken must be a non-empty string or null.", nameof(options));
}

// Auto-generate a connection token when the SDK spawns the runtime over TCP
// so the loopback listener is safe by default.
tcp.ConnectionToken ??= Guid.NewGuid().ToString();
Expand Down Expand Up @@ -164,6 +166,11 @@ public CopilotClient(CopilotClientOptions? options = null)
_onListModels = _options.OnListModels;
}

private ConcurrentDictionary<string, Func<string, CopilotSession>> GetOrCreatePendingCloudCreates() =>
_pendingCloudCreates ??
System.Threading.Interlocked.CompareExchange(ref _pendingCloudCreates, new(), null) ??
_pendingCloudCreates!;

/// <summary>
/// Parses a runtime URL into a URI with host and port.
/// </summary>
Expand Down Expand Up @@ -535,42 +542,22 @@ public async Task<CopilotSession> CreateSessionAsync(SessionConfig config, Cance

var (wireSystemMessage, transformCallbacks) = ExtractTransformCallbacks(config.SystemMessage);

var isCloudCreate = config.Cloud != null;
var sessionId = config.SessionId ?? Guid.NewGuid().ToString();

// Create and register the session before issuing the RPC so that
// events emitted by the CLI (e.g. session.start) are not dropped.
var setupTimestamp = Stopwatch.GetTimestamp();
var session = new CopilotSession(
sessionId,
connection.Rpc,
_logger,
this);
session.RegisterTools(config.Tools ?? []);
session.RegisterPermissionHandler(config.OnPermissionRequest);
session.RegisterCommands(config.Commands);
session.RegisterElicitationHandler(config.OnElicitationRequest);
session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest);
session.RegisterAutoModeSwitchHandler(config.OnAutoModeSwitchRequest);
if (config.OnUserInputRequest != null)
CopilotSession? session = null;
if (isCloudCreate)
{
session.RegisterUserInputHandler(config.OnUserInputRequest);
}
if (config.Hooks != null)
{
session.RegisterHooks(config.Hooks);
}
if (transformCallbacks != null)
{
session.RegisterTransformCallbacks(transformCallbacks);
GetOrCreatePendingCloudCreates()[sessionId] =
createdSessionId => CreateConfiguredSession(createdSessionId, connection, config, transformCallbacks);
}
if (config.OnEvent != null)
else
{
session.On<SessionEvent>(config.OnEvent);
session = CreateConfiguredSession(sessionId, connection, config, transformCallbacks);
}
ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider);
session.SetCanvasHandler(config.CanvasHandler);
RegisterSession(session);
session.StartProcessingEvents();
LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null,
"CopilotClient.CreateSessionAsync local setup complete. Elapsed={Elapsed}, SessionId={SessionId}, Tools={ToolsCount}, Commands={CommandsCount}, Hooks={HasHooks}",
setupTimestamp,
Expand Down Expand Up @@ -632,29 +619,43 @@ public async Task<CopilotSession> CreateSessionAsync(SessionConfig config, Cance
LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null,
"CopilotClient.CreateSessionAsync session creation request completed successfully. Elapsed={Elapsed}, SessionId={SessionId}",
rpcTimestamp,
sessionId);
response.SessionId);

if (session is null)
{
_sessions.TryGetValue(response.SessionId, out session);
if (session is null)
{
throw new InvalidOperationException($"Session was not registered: {response.SessionId}");
}
}
if (isCloudCreate)
{
_pendingCloudCreates?.TryRemove(sessionId, out _);
}
Comment on lines +624 to +635

session.WorkspacePath = response.WorkspacePath;
session.SetCapabilities(response.Capabilities);
session.SetOpenCanvases(response.OpenCanvases);
}
catch (Exception ex)
{
session.RemoveFromClient();
_pendingCloudCreates?.TryRemove(sessionId, out _);
session?.RemoveFromClient();
if (ex is not OperationCanceledException)
Comment on lines 641 to 645
{
LoggingHelpers.LogTiming(_logger, LogLevel.Warning, ex,
"CopilotClient.CreateSessionAsync failed. Elapsed={Elapsed}, SessionId={SessionId}",
totalTimestamp,
sessionId);
sessionId ?? "<pending-cloud-session>");
}
throw;
}

LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null,
"CopilotClient.CreateSessionAsync complete. Elapsed={Elapsed}, SessionId={SessionId}",
totalTimestamp,
sessionId);
session.SessionId);
return session;
}

Expand Down Expand Up @@ -703,40 +704,8 @@ public async Task<CopilotSession> ResumeSessionAsync(string sessionId, ResumeSes

var (wireSystemMessage, transformCallbacks) = ExtractTransformCallbacks(config.SystemMessage);

// Create and register the session before issuing the RPC so that
// events emitted by the CLI (e.g. session.start) are not dropped.
var setupTimestamp = Stopwatch.GetTimestamp();
var session = new CopilotSession(
sessionId,
connection.Rpc,
_logger,
client: this);
session.RegisterTools(config.Tools ?? []);
session.RegisterPermissionHandler(config.OnPermissionRequest);
session.RegisterCommands(config.Commands);
session.RegisterElicitationHandler(config.OnElicitationRequest);
session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest);
session.RegisterAutoModeSwitchHandler(config.OnAutoModeSwitchRequest);
if (config.OnUserInputRequest != null)
{
session.RegisterUserInputHandler(config.OnUserInputRequest);
}
if (config.Hooks != null)
{
session.RegisterHooks(config.Hooks);
}
if (transformCallbacks != null)
{
session.RegisterTransformCallbacks(transformCallbacks);
}
if (config.OnEvent != null)
{
session.On<SessionEvent>(config.OnEvent);
}
ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider);
session.SetCanvasHandler(config.CanvasHandler);
RegisterSession(session);
session.StartProcessingEvents();
var session = CreateConfiguredSession(sessionId, connection, config, transformCallbacks);
LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null,
"CopilotClient.ResumeSessionAsync local setup complete. Elapsed={Elapsed}, SessionId={SessionId}, Tools={ToolsCount}, Commands={CommandsCount}, Hooks={HasHooks}",
setupTimestamp,
Expand Down Expand Up @@ -826,6 +795,46 @@ public async Task<CopilotSession> ResumeSessionAsync(string sessionId, ResumeSes
return session;
}

private CopilotSession CreateConfiguredSession(
string sessionId,
Connection connection,
SessionConfigBase config,
Dictionary<string, Func<string, Task<string>>>? transformCallbacks)
{
var session = new CopilotSession(
sessionId,
connection.Rpc,
_logger,
this);
session.RegisterTools(config.Tools ?? []);
session.RegisterPermissionHandler(config.OnPermissionRequest);
session.RegisterCommands(config.Commands);
session.RegisterElicitationHandler(config.OnElicitationRequest);
session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest);
session.RegisterAutoModeSwitchHandler(config.OnAutoModeSwitchRequest);
if (config.OnUserInputRequest != null)
{
session.RegisterUserInputHandler(config.OnUserInputRequest);
}
if (config.Hooks != null)
{
session.RegisterHooks(config.Hooks);
}
if (transformCallbacks != null)
{
session.RegisterTransformCallbacks(transformCallbacks);
}
if (config.OnEvent != null)
{
session.On<SessionEvent>(config.OnEvent);
}
ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider);
session.SetCanvasHandler(config.CanvasHandler);
RegisterSession(session);
session.StartProcessingEvents();
return session;
}

/// <summary>
/// Validates the health of the connection by sending a ping request.
/// </summary>
Expand Down Expand Up @@ -1736,7 +1745,7 @@ public void OnSessionEvent(string sessionId, JsonElement? @event)
}
}

public void OnSessionLifecycle(string type, string sessionId, JsonElement? metadata)
public void OnSessionLifecycle(string type, string sessionId, JsonElement? metadata, string? clientSessionId = null)
{
SessionLifecycleEvent evt = type switch
{
Expand All @@ -1750,13 +1759,23 @@ public void OnSessionLifecycle(string type, string sessionId, JsonElement? metad

evt.Type = type;
evt.SessionId = sessionId;
evt.ClientSessionId = clientSessionId;
if (metadata is not null)
{
evt.Metadata = JsonSerializer.Deserialize(
metadata.Value.GetRawText(),
TypesJsonContext.Default.SessionLifecycleEventMetadata);
}

if (type == "session.created" &&
!client._sessions.ContainsKey(sessionId) &&
clientSessionId is not null &&
client._pendingCloudCreates is not null &&
client._pendingCloudCreates.TryRemove(clientSessionId, out var initializeFunc))
{
initializeFunc(sessionId);
}

client.DispatchLifecycleEvent(evt);
}

Expand Down
9 changes: 9 additions & 0 deletions dotnet/src/Types.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2230,6 +2230,7 @@ public sealed class CloudSessionRepository
/// <summary>
/// Options for creating a remote session in the cloud.
/// </summary>
[Experimental(Diagnostics.Experimental)]
public sealed class CloudSessionOptions
{
/// <summary>
Expand Down Expand Up @@ -2546,6 +2547,7 @@ private SessionConfig(SessionConfig? other) : base(other)
/// Creates a remote session in the cloud instead of a local session.
/// The optional repository is associated with the cloud session.
/// </summary>
[Experimental(Diagnostics.Experimental)]
public CloudSessionOptions? Cloud { get; set; }

/// <summary>
Expand Down Expand Up @@ -3030,6 +3032,13 @@ public class SessionLifecycleEvent
[JsonPropertyName("sessionId")]
public string SessionId { get; set; } = string.Empty;

/// <summary>
/// Provisional client-generated session ID supplied for cloud session creation.
/// </summary>
[Experimental(Diagnostics.Experimental)]
[JsonPropertyName("clientSessionId")]
public string? ClientSessionId { get; set; }

/// <summary>
/// Metadata associated with the session lifecycle event.
/// </summary>
Expand Down
Loading
Loading