diff --git a/src-tauri/src/drivers/common/query.rs b/src-tauri/src/drivers/common/query.rs index f48befd3..6a412a2b 100644 --- a/src-tauri/src/drivers/common/query.rs +++ b/src-tauri/src/drivers/common/query.rs @@ -48,29 +48,192 @@ pub fn calculate_offset(page: u32, page_size: u32) -> u32 { (page - 1) * page_size } +/// Simple SQL tokenizer that respects: +/// - Single-quoted strings ('...') +/// - Double-quoted identifiers ("...") +/// - Backtick-quoted identifiers (`...`) +/// - Parenthesized groups (treated as single tokens) +/// - Whitespace as delimiter +/// +/// This prevents keywords like LIMIT or OFFSET from being matched +/// inside string literals, quoted identifiers, or table names such as +/// `tapp_appointment_message_event_limit`. +fn tokenize_sql(sql: &str) -> Vec { + let mut tokens = Vec::new(); + let chars: Vec = sql.chars().collect(); + let len = chars.len(); + let mut i = 0; + + while i < len { + if chars[i].is_whitespace() { + i += 1; + continue; + } + + if chars[i] == '\'' { + let mut token = String::new(); + token.push(chars[i]); + i += 1; + while i < len { + token.push(chars[i]); + if chars[i] == '\'' { + if i + 1 < len && chars[i + 1] == '\'' { + i += 1; + token.push(chars[i]); + } else { + i += 1; + break; + } + } + i += 1; + } + tokens.push(token); + continue; + } + + if chars[i] == '"' { + let mut token = String::new(); + token.push(chars[i]); + i += 1; + while i < len { + token.push(chars[i]); + if chars[i] == '"' { + if i + 1 < len && chars[i + 1] == '"' { + i += 1; + token.push(chars[i]); + } else { + i += 1; + break; + } + } + i += 1; + } + tokens.push(token); + continue; + } + + if chars[i] == '`' { + let mut token = String::new(); + token.push(chars[i]); + i += 1; + while i < len { + token.push(chars[i]); + if chars[i] == '`' { + if i + 1 < len && chars[i + 1] == '`' { + i += 1; + token.push(chars[i]); + } else { + i += 1; + break; + } + } + i += 1; + } + tokens.push(token); + continue; + } + + if chars[i] == '(' { + let mut token = String::new(); + let mut depth = 0; + while i < len { + token.push(chars[i]); + if chars[i] == '(' { + depth += 1; + } else if chars[i] == ')' { + depth -= 1; + if depth == 0 { + i += 1; + break; + } + } else if chars[i] == '\'' { + i += 1; + while i < len { + token.push(chars[i]); + if chars[i] == '\'' { + if i + 1 < len && chars[i + 1] == '\'' { + i += 1; + token.push(chars[i]); + } else { + break; + } + } + i += 1; + } + } + i += 1; + } + tokens.push(token); + continue; + } + + let mut token = String::new(); + while i < len + && !chars[i].is_whitespace() + && chars[i] != '(' + && chars[i] != '\'' + && chars[i] != '"' + && chars[i] != '`' + { + token.push(chars[i]); + i += 1; + } + if !token.is_empty() { + tokens.push(token); + } + } + + tokens +} + /// Remove trailing LIMIT and OFFSET clauses from a SQL query. /// -/// Uses `rfind` to locate the last `LIMIT` keyword and strips everything from -/// there onwards (which includes any subsequent OFFSET). Falls back to looking -/// for a standalone `OFFSET` when no LIMIT is present. -pub fn strip_limit_offset(query: &str) -> &str { - let upper = query.to_uppercase(); - if let Some(pos) = upper.rfind("LIMIT") { - query[..pos].trim() - } else if let Some(pos) = upper.rfind("OFFSET") { - query[..pos].trim() - } else { - query.trim() +/// Uses a token-aware scan so that `LIMIT` / `OFFSET` keywords inside +/// string literals, quoted identifiers, parenthesized subqueries, or as +/// part of table names (e.g. `tapp_…_limit`) are never misidentified. +pub fn strip_limit_offset(query: &str) -> String { + let tokens = tokenize_sql(query.trim()); + let mut end = tokens.len(); + + // Scan backwards for OFFSET + if end >= 2 && tokens[end - 2].to_uppercase() == "OFFSET" { + if tokens[end - 1].parse::().is_ok() { + end -= 2; + } + } + + // Scan backwards for LIMIT + if end >= 2 && tokens[end - 2].to_uppercase() == "LIMIT" { + if tokens[end - 1].parse::().is_ok() { + end -= 2; + } } + + tokens[..end].join(" ") } /// Extract the numeric value from a trailing LIMIT clause, if present. +/// +/// Uses a token-aware scan so that `LIMIT` as a substring of a table name +/// (e.g. `tapp_appointment_message_event_limit`) is never misidentified. pub fn extract_user_limit(query: &str) -> Option { - let upper = query.to_uppercase(); - let pos = upper.rfind("LIMIT")?; - let after = query[pos + 5..].trim(); - let num_str: String = after.chars().take_while(|c| c.is_ascii_digit()).collect(); - num_str.parse().ok() + let tokens = tokenize_sql(query.trim()); + let len = tokens.len(); + + // Walk backwards past optional OFFSET + let mut end = len; + if end >= 2 && tokens[end - 2].to_uppercase() == "OFFSET" { + if tokens[end - 1].parse::().is_ok() { + end -= 2; + } + } + + // Check for LIMIT + if end >= 2 && tokens[end - 2].to_uppercase() == "LIMIT" { + return tokens[end - 1].parse().ok(); + } + + None } /// Build a paginated query by stripping any user-supplied LIMIT/OFFSET and diff --git a/src-tauri/src/drivers/common/tests.rs b/src-tauri/src/drivers/common/tests.rs index c01d0d28..745c4aa7 100644 --- a/src-tauri/src/drivers/common/tests.rs +++ b/src-tauri/src/drivers/common/tests.rs @@ -186,6 +186,38 @@ fn test_strip_limit_offset_only_offset() { ); } +#[test] +fn test_strip_limit_offset_table_name_contains_limit() { + assert_eq!( + strip_limit_offset("SELECT * FROM tapp_appointment_message_event_limit ORDER BY id"), + "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id" + ); +} + +#[test] +fn test_strip_limit_offset_table_name_contains_limit_with_real_limit() { + assert_eq!( + strip_limit_offset("SELECT * FROM tapp_appointment_message_event_limit ORDER BY id LIMIT 10"), + "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id" + ); +} + +#[test] +fn test_strip_limit_offset_quoted_identifier() { + assert_eq!( + strip_limit_offset(r#"SELECT * FROM "order_limit_table" WHERE x > 1 LIMIT 5 OFFSET 10"#), + r#"SELECT * FROM "order_limit_table" WHERE x > 1"# + ); +} + +#[test] +fn test_strip_limit_offset_string_literal_with_limit() { + assert_eq!( + strip_limit_offset("SELECT * FROM t WHERE name LIKE '%limit%' LIMIT 10"), + "SELECT * FROM t WHERE name LIKE '%limit%'" + ); +} + #[test] fn test_extract_user_limit_present() { assert_eq!( @@ -210,6 +242,22 @@ fn test_extract_user_limit_absent() { ); } +#[test] +fn test_extract_user_limit_table_name_contains_limit() { + assert_eq!( + super::extract_user_limit("SELECT * FROM tapp_appointment_message_event_limit"), + None + ); +} + +#[test] +fn test_extract_user_limit_table_name_contains_limit_with_real_limit() { + assert_eq!( + super::extract_user_limit("SELECT * FROM tapp_appointment_message_event_limit LIMIT 10"), + Some(10) + ); +} + #[test] fn test_build_paginated_query_no_user_limit() { let q = "SELECT o.id FROM orders o ORDER BY o.created_at DESC"; @@ -244,6 +292,36 @@ fn test_build_paginated_query_user_limit_exhausted() { assert_eq!(result, "SELECT * FROM t LIMIT 0 OFFSET 100"); } +#[test] +fn test_build_paginated_query_table_name_contains_limit() { + let q = "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id"; + let result = build_paginated_query(q, 100, 1); + assert_eq!( + result, + "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id LIMIT 101 OFFSET 0" + ); +} + +#[test] +fn test_build_paginated_query_table_name_contains_limit_with_user_limit() { + let q = "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id LIMIT 10"; + let result = build_paginated_query(q, 100, 1); + assert_eq!( + result, + "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id LIMIT 10 OFFSET 0" + ); +} + +#[test] +fn test_build_paginated_query_subquery_with_limit() { + let q = "SELECT * FROM (SELECT id FROM t ORDER BY id LIMIT 100) sub ORDER BY id LIMIT 5"; + let result = build_paginated_query(q, 100, 1); + assert_eq!( + result, + "SELECT * FROM (SELECT id FROM t ORDER BY id LIMIT 100) sub ORDER BY id LIMIT 5 OFFSET 0" + ); +} + #[test] fn test_encode_blob_full_preserves_all_data() { // 8KB of data — encode_blob would truncate, encode_blob_full must not diff --git a/src-tauri/src/mcp/mod.rs b/src-tauri/src/mcp/mod.rs index 3999a704..763c9851 100644 --- a/src-tauri/src/mcp/mod.rs +++ b/src-tauri/src/mcp/mod.rs @@ -502,12 +502,13 @@ fn handle_list_tools() -> Result { }, Tool { name: "run_query".to_string(), - description: Some("Execute a SQL query on a specific connection".to_string()), + description: Some("Execute a SQL query on a specific connection. If the query already contains a LIMIT clause, it will be respected.".to_string()), input_schema: json!({ "type": "object", "properties": { "connection_id": { "type": "string", "description": "The ID or name of the connection" }, - "query": { "type": "string", "description": "The SQL query to execute" } + "query": { "type": "string", "description": "The SQL query to execute" }, + "limit": { "type": "integer", "description": "Maximum number of rows to return (default: 100). If the query already contains a LIMIT clause smaller than this value, the query's LIMIT takes precedence." } }, "required": ["connection_id", "query"] }), @@ -821,6 +822,11 @@ async fn tool_run_query( data: None, })?; + let max_rows = args + .get("limit") + .and_then(|v| v.as_u64()) + .unwrap_or(100) as u32; + audit.connection_id = Some(conn_id.to_string()); audit.query = Some(query.to_string()); let kind = ai_activity::classify_query_kind(query); @@ -986,11 +992,13 @@ async fn tool_run_query( } let result = match conn.params.driver.as_str() { - "mysql" => mysql::execute_query(&db_params, &effective_query, Some(100), 1, None).await, + "mysql" => { + mysql::execute_query(&db_params, &effective_query, Some(max_rows), 1, None).await + } "postgres" => { - postgres::execute_query(&db_params, &effective_query, Some(100), 1, None).await + postgres::execute_query(&db_params, &effective_query, Some(max_rows), 1, None).await } - "sqlite" => sqlite::execute_query(&db_params, &effective_query, Some(100), 1).await, + "sqlite" => sqlite::execute_query(&db_params, &effective_query, Some(max_rows), 1).await, _ => Err("Unsupported driver".into()), } .map_err(|e| JsonRpcError {