Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src-tauri/src/agents/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ async fn run_scheduled_agent_with_fallback(
// synthetic `scheduled:{instance_id}:{uuid}` string, which meant
// `assistant_cancel_run(run.id)` couldn't find the token and
// scheduled runs were effectively un-cancellable from the UI.
let cancel_token = runtime::register_run(&run.id);
let run_registration = runtime::register_run(&run.id);
let cancel_token = run_registration.token();
let input = RunTurnInput {
session_id: session.id.clone(),
run_id: Some(run.id.clone()),
Expand All @@ -720,7 +721,7 @@ async fn run_scheduled_agent_with_fallback(
trigger_message_id: None,
};
let result = engine::run_session_turn(&deps, input).await;
runtime::unregister_run(&run.id);
drop(run_registration);

match result {
Ok(()) => {
Expand Down
17 changes: 17 additions & 0 deletions src-tauri/src/assistant/compaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,23 @@ pub async fn compact_session_history(
}))
}

pub async fn compact_for_context_limit_recovery(
pool: &DbPool,
session: &AssistantSession,
connection: &ProviderConnection,
run_id: &str,
) -> Result<Option<CompactionOutcome>, String> {
compact_session_history(
pool,
session,
connection,
CompactionTrigger::ErrorRecovery,
Some(run_id),
true,
)
.await
}

fn provider_history_messages_with_compaction(
messages: &[AssistantMessage],
latest: Option<&AssistantCompaction>,
Expand Down
8 changes: 2 additions & 6 deletions src-tauri/src/assistant/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use crate::assistant::providers;
use crate::assistant::providers::types::ProviderError;
use crate::assistant::repository;
use crate::assistant::repository::{CreateMessageParams, CreateRunParams, CreateToolCallParams};
use crate::assistant::runtime;
use crate::assistant::tools::{self, ToolExecutionContext};
use crate::assistant::types::{
AssistantMessage, CompactionTrigger, CompletionRequest, ContentPart, MessageRole,
Expand Down Expand Up @@ -319,13 +318,11 @@ pub async fn run_session_turn(
&& compaction::is_context_limit_error(&e.to_string())
{
retried_after_context_compaction = true;
match compaction::compact_session_history(
match compaction::compact_for_context_limit_recovery(
&deps.pool,
&session,
&connection,
CompactionTrigger::ErrorRecovery,
Some(&run_id),
true,
&run_id,
)
.await
{
Expand Down Expand Up @@ -3153,6 +3150,5 @@ async fn cancel_run(
Some(run_id),
AssistantUiEvent::RunCancelled { run },
);
runtime::unregister_run(run_id);
Ok(())
}
6 changes: 2 additions & 4 deletions src-tauri/src/assistant/local_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,11 @@ pub async fn run_session_turn(
"{} reported a context limit; compacting local history and restarting with a fresh session",
provider_runtime.display_name()
);
match compaction::compact_session_history(
match compaction::compact_for_context_limit_recovery(
&deps.pool,
&session,
&connection,
CompactionTrigger::ErrorRecovery,
Some(&run_id),
true,
&run_id,
)
.await
{
Expand Down
17 changes: 17 additions & 0 deletions src-tauri/src/assistant/repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,23 @@ pub async fn session_has_active_run(pool: &DbPool, session_id: &str) -> Result<b
Ok(count > 0)
}

/// Whether the workspace database contains any non-terminal run. Workspace
/// deletion uses this as a coarse guard before removing the root directory.
pub async fn workspace_has_active_run(pool: &DbPool) -> Result<bool, String> {
let (count,): (i64,) = sqlx::query_as(
r#"
SELECT COUNT(*)
FROM assistant_runs
WHERE status IN ('"queued"', '"running"', '"waiting_for_tool"')
"#,
)
.fetch_one(pool)
.await
.map_err(|e| format!("Failed to check workspace for active runs: {}", e))?;

Ok(count > 0)
}

pub async fn get_active_run(
pool: &DbPool,
session_id: &str,
Expand Down
61 changes: 61 additions & 0 deletions src-tauri/src/assistant/repository_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,67 @@ async fn test_get_active_run_ignores_terminal_runs() {
);
}

#[tokio::test]
async fn test_workspace_has_active_run_tracks_any_session() {
let pool = setup_test_pool().await;

let first = create_session(
&pool,
CreateSessionParams {
kind: SessionKind::Interactive,
title: Some("first".to_string()),
context: sample_context(),
},
)
.await
.unwrap();
let second = create_session(
&pool,
CreateSessionParams {
kind: SessionKind::Interactive,
title: Some("second".to_string()),
context: sample_context(),
},
)
.await
.unwrap();

assert!(!workspace_has_active_run(&pool).await.unwrap());

create_run(
&pool,
CreateRunParams {
session_id: first.id.clone(),
status: RunStatus::Completed,
trigger: RunTrigger::UserMessage,
connection_id: "conn-1".to_string(),
provider_id: "openai".to_string(),
model_id: "gpt-4".to_string(),
usage: None,
error: None,
},
)
.await
.unwrap();
create_run(
&pool,
CreateRunParams {
session_id: second.id.clone(),
status: RunStatus::WaitingForTool,
trigger: RunTrigger::UserMessage,
connection_id: "conn-1".to_string(),
provider_id: "openai".to_string(),
model_id: "gpt-4".to_string(),
usage: None,
error: None,
},
)
.await
.unwrap();

assert!(workspace_has_active_run(&pool).await.unwrap());
}

// ---------------------------------------------------------------------------
// Run CRUD
// ---------------------------------------------------------------------------
Expand Down
89 changes: 69 additions & 20 deletions src-tauri/src/assistant/runtime.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,79 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Mutex, OnceLock};

use tokio_util::sync::CancellationToken;

type ActiveRuns = HashMap<String, CancellationToken>;
#[derive(Clone)]
struct ActiveRunEntry {
token: CancellationToken,
generation: u64,
}

type ActiveRuns = HashMap<String, ActiveRunEntry>;

static ACTIVE_RUNS: OnceLock<Mutex<ActiveRuns>> = OnceLock::new();
static NEXT_GENERATION: AtomicU64 = AtomicU64::new(1);

pub struct RunRegistration {
run_id: String,
token: CancellationToken,
generation: u64,
}

impl RunRegistration {
pub fn token(&self) -> CancellationToken {
self.token.clone()
}
}

impl Drop for RunRegistration {
fn drop(&mut self) {
let mut active = active_runs().lock().unwrap();
if active
.get(&self.run_id)
.is_some_and(|entry| entry.generation == self.generation)
{
active.remove(&self.run_id);
}
}
}

fn active_runs() -> &'static Mutex<ActiveRuns> {
ACTIVE_RUNS.get_or_init(|| Mutex::new(HashMap::new()))
}

pub fn register_run(run_id: &str) -> CancellationToken {
pub fn register_run(run_id: &str) -> RunRegistration {
let token = CancellationToken::new();
active_runs()
.lock()
.unwrap()
.insert(run_id.to_string(), token.clone());
token
let generation = NEXT_GENERATION.fetch_add(1, Ordering::Relaxed);
active_runs().lock().unwrap().insert(
run_id.to_string(),
ActiveRunEntry {
token: token.clone(),
generation,
},
);
RunRegistration {
run_id: run_id.to_string(),
token,
generation,
}
}

pub fn cancel_run(run_id: &str) -> bool {
if let Some(token) = active_runs().lock().unwrap().get(run_id).cloned() {
if let Some(token) = active_runs()
.lock()
.unwrap()
.get(run_id)
.map(|entry| entry.token.clone())
{
token.cancel();
return true;
}

false
}

pub fn unregister_run(run_id: &str) {
active_runs().lock().unwrap().remove(run_id);
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -45,7 +86,8 @@ mod tests {
#[test]
fn register_then_cancel_propagates_to_token() {
let id = "runtime-test-register-then-cancel";
let token = register_run(id);
let registration = register_run(id);
let token = registration.token();
assert!(!token.is_cancelled(), "fresh token must start uncancelled");

let was_found = cancel_run(id);
Expand All @@ -54,8 +96,6 @@ mod tests {
token.is_cancelled(),
"the original token handle must observe the cancel"
);

unregister_run(id);
}

#[test]
Expand All @@ -71,8 +111,8 @@ mod tests {
#[test]
fn unregister_removes_from_active_set() {
let id = "runtime-test-unregister-removes";
let _token = register_run(id);
unregister_run(id);
let registration = register_run(id);
drop(registration);

let was_found = cancel_run(id);
assert!(!was_found, "unregistered ids must no longer be cancellable");
Expand All @@ -86,8 +126,10 @@ mod tests {
// registration wins. Pin the behavior so a future refactor
// doesn't accidentally start de-duping or asserting.
let id = "runtime-test-double-register";
let first = register_run(id);
let second = register_run(id);
let first_registration = register_run(id);
let first = first_registration.token();
let second_registration = register_run(id);
let second = second_registration.token();

// Cancelling now should signal the second (current) token.
assert!(cancel_run(id));
Expand All @@ -96,6 +138,13 @@ mod tests {
// cancel_run never reaches it.
assert!(!first.is_cancelled());

unregister_run(id);
drop(first_registration);
assert!(
cancel_run(id),
"dropping an older guard must not unregister the newer token"
);

drop(second_registration);
assert!(!cancel_run(id));
}
}
Loading
Loading