Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 179 additions & 16 deletions src-tauri/src/drivers/common/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,192 @@ pub fn calculate_offset(page: u32, page_size: u32) -> u32 {
(page - 1) * page_size
}

/// Splits a SQL string into coarse, whitespace-delimited tokens.
///
/// Each of the following is kept together as a single token:
/// - Single-quoted strings ('...')
/// - Double-quoted identifiers ("...")
/// - Backtick-quoted identifiers
/// - Parenthesized groups, including any groups nested inside them
///
/// Inside a quoted run, a doubled quote character is treated as an escaped
/// quote rather than the end of the run. Quoted runs that appear inside a
/// parenthesized group are skipped as a unit, so a `)` or another quote
/// character inside them cannot close or derail the group. An unterminated
/// quote or group extends to the end of the input.
///
/// Every token is an exact substring of `sql`; only the whitespace between
/// tokens is dropped.
///
/// This prevents keywords like LIMIT or OFFSET from being matched inside
/// string literals, quoted identifiers, or table names such as
/// `tapp_appointment_message_event_limit`.
///
/// SQL comments and backslash escapes are not given any special treatment.
fn tokenize_sql(sql: &str) -> Vec<String> {
    let chars: Vec<char> = sql.chars().collect();
    let mut tokens: Vec<String> = Vec::new();
    let mut pos = 0;

    while pos < chars.len() {
        let current = chars[pos];
        if current.is_whitespace() {
            pos += 1;
            continue;
        }

        let end = match current {
            '\'' | '"' | '`' => quoted_run_end(&chars, pos),
            '(' => paren_group_end(&chars, pos),
            // Bare word: runs until whitespace or the start of a quoted run / group.
            // `chars[pos]` itself is none of those, so the token is never empty.
            _ => chars[pos..]
                .iter()
                .position(|&c| c.is_whitespace() || matches!(c, '(' | '\'' | '"' | '`'))
                .map_or(chars.len(), |offset| pos + offset),
        };

        tokens.push(chars[pos..end].iter().collect());
        pos = end;
    }

    tokens
}

/// Returns the index just past the quoted run opening at `chars[start]`.
///
/// The quote character is whatever sits at `start`; a doubled quote inside the
/// run is an escape. Returns `chars.len()` when the run is never closed.
fn quoted_run_end(chars: &[char], start: usize) -> usize {
    let quote = chars[start];
    let mut pos = start + 1;

    while pos < chars.len() {
        if chars[pos] == quote {
            if chars.get(pos + 1) == Some(&quote) {
                // Doubled quote: escaped, stay inside the run.
                pos += 2;
                continue;
            }
            return pos + 1;
        }
        pos += 1;
    }

    chars.len()
}

/// Returns the index just past the `)` that balances the `(` at `chars[start]`.
///
/// Quoted runs inside the group are skipped whole. Returns `chars.len()` when
/// the group is never closed.
fn paren_group_end(chars: &[char], start: usize) -> usize {
    let mut depth = 0usize;
    let mut pos = start;

    while pos < chars.len() {
        match chars[pos] {
            '(' => {
                depth += 1;
                pos += 1;
            }
            ')' => {
                // `depth` is at least 1 here because the scan starts on a `(`.
                depth -= 1;
                pos += 1;
                if depth == 0 {
                    return pos;
                }
            }
            '\'' | '"' | '`' => pos = quoted_run_end(chars, pos),
            _ => pos += 1,
        }
    }

    chars.len()
}

/// Remove trailing LIMIT and OFFSET clauses from a SQL query.
///
/// Uses `rfind` to locate the last `LIMIT` keyword and strips everything from
/// there onwards (which includes any subsequent OFFSET). Falls back to looking
/// for a standalone `OFFSET` when no LIMIT is present.
pub fn strip_limit_offset(query: &str) -> &str {
let upper = query.to_uppercase();
if let Some(pos) = upper.rfind("LIMIT") {
query[..pos].trim()
} else if let Some(pos) = upper.rfind("OFFSET") {
query[..pos].trim()
} else {
query.trim()
/// Uses a token-aware scan so that `LIMIT` / `OFFSET` keywords inside
/// string literals, quoted identifiers, parenthesized subqueries, or as
/// part of table names (e.g. `tapp_…_limit`) are never misidentified.
pub fn strip_limit_offset(query: &str) -> String {
let tokens = tokenize_sql(query.trim());
let mut end = tokens.len();

// Scan backwards for OFFSET <n>
if end >= 2 && tokens[end - 2].to_uppercase() == "OFFSET" {
if tokens[end - 1].parse::<u64>().is_ok() {
end -= 2;
}
}

// Scan backwards for LIMIT <n>
if end >= 2 && tokens[end - 2].to_uppercase() == "LIMIT" {
if tokens[end - 1].parse::<u64>().is_ok() {
end -= 2;
}
}

tokens[..end].join(" ")
}

/// Extract the numeric value from a trailing LIMIT clause, if present.
///
/// Uses a token-aware scan so that `LIMIT` as a substring of a table name
/// (e.g. `tapp_appointment_message_event_limit`) is never misidentified.
pub fn extract_user_limit(query: &str) -> Option<u32> {
let upper = query.to_uppercase();
let pos = upper.rfind("LIMIT")?;
let after = query[pos + 5..].trim();
let num_str: String = after.chars().take_while(|c| c.is_ascii_digit()).collect();
num_str.parse().ok()
let tokens = tokenize_sql(query.trim());
let len = tokens.len();

// Walk backwards past optional OFFSET <n>
let mut end = len;
if end >= 2 && tokens[end - 2].to_uppercase() == "OFFSET" {
if tokens[end - 1].parse::<u64>().is_ok() {
end -= 2;
}
}

// Check for LIMIT <n>
if end >= 2 && tokens[end - 2].to_uppercase() == "LIMIT" {
return tokens[end - 1].parse().ok();
}

None
}

/// Build a paginated query by stripping any user-supplied LIMIT/OFFSET and
Expand Down
78 changes: 78 additions & 0 deletions src-tauri/src/drivers/common/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,38 @@ fn test_strip_limit_offset_only_offset() {
);
}

/// A table name that merely ends in `_limit` must not be taken for a LIMIT clause.
#[test]
fn test_strip_limit_offset_table_name_contains_limit() {
    let stripped =
        strip_limit_offset("SELECT * FROM tapp_appointment_message_event_limit ORDER BY id");
    // No trailing LIMIT/OFFSET clause is present, so the text is expected back as-is.
    let expected = "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id";
    assert_eq!(stripped, expected);
}

/// Only the genuine trailing `LIMIT 10` is removed; the `_limit` table name stays.
#[test]
fn test_strip_limit_offset_table_name_contains_limit_with_real_limit() {
    let stripped = strip_limit_offset(
        "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id LIMIT 10",
    );
    let expected = "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id";
    assert_eq!(stripped, expected);
}

/// A double-quoted identifier containing `limit` is kept while the trailing
/// `LIMIT 5 OFFSET 10` is stripped.
#[test]
fn test_strip_limit_offset_quoted_identifier() {
    let stripped =
        strip_limit_offset(r#"SELECT * FROM "order_limit_table" WHERE x > 1 LIMIT 5 OFFSET 10"#);
    let expected = r#"SELECT * FROM "order_limit_table" WHERE x > 1"#;
    assert_eq!(stripped, expected);
}

/// `limit` inside a string literal is left alone; only the trailing clause goes.
#[test]
fn test_strip_limit_offset_string_literal_with_limit() {
    let stripped = strip_limit_offset("SELECT * FROM t WHERE name LIKE '%limit%' LIMIT 10");
    let expected = "SELECT * FROM t WHERE name LIKE '%limit%'";
    assert_eq!(stripped, expected);
}

#[test]
fn test_extract_user_limit_present() {
assert_eq!(
Expand All @@ -210,6 +242,22 @@ fn test_extract_user_limit_absent() {
);
}

/// `limit` appearing only as a table-name suffix yields no user limit.
#[test]
fn test_extract_user_limit_table_name_contains_limit() {
    let detected =
        super::extract_user_limit("SELECT * FROM tapp_appointment_message_event_limit");
    assert_eq!(detected, None);
}

/// The real trailing `LIMIT 10` is reported even though the table name also
/// contains `limit`.
#[test]
fn test_extract_user_limit_table_name_contains_limit_with_real_limit() {
    let detected = super::extract_user_limit(
        "SELECT * FROM tapp_appointment_message_event_limit LIMIT 10",
    );
    assert_eq!(detected, Some(10));
}

#[test]
fn test_build_paginated_query_no_user_limit() {
let q = "SELECT o.id FROM orders o ORDER BY o.created_at DESC";
Expand Down Expand Up @@ -244,6 +292,36 @@ fn test_build_paginated_query_user_limit_exhausted() {
assert_eq!(result, "SELECT * FROM t LIMIT 0 OFFSET 100");
}

/// With no user LIMIT, pagination is appended after the untouched query even
/// though the table name contains `limit`.
#[test]
fn test_build_paginated_query_table_name_contains_limit() {
    let paginated = build_paginated_query(
        "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id",
        100,
        1,
    );
    let expected =
        "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id LIMIT 101 OFFSET 0";
    assert_eq!(paginated, expected);
}

/// The user's `LIMIT 10` is carried into the paginated query; the `_limit`
/// table name is not mistaken for it.
#[test]
fn test_build_paginated_query_table_name_contains_limit_with_user_limit() {
    let paginated = build_paginated_query(
        "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id LIMIT 10",
        100,
        1,
    );
    let expected =
        "SELECT * FROM tapp_appointment_message_event_limit ORDER BY id LIMIT 10 OFFSET 0";
    assert_eq!(paginated, expected);
}

/// A LIMIT inside a parenthesized subquery is preserved; only the outer
/// trailing `LIMIT 5` takes part in pagination.
#[test]
fn test_build_paginated_query_subquery_with_limit() {
    let paginated = build_paginated_query(
        "SELECT * FROM (SELECT id FROM t ORDER BY id LIMIT 100) sub ORDER BY id LIMIT 5",
        100,
        1,
    );
    let expected =
        "SELECT * FROM (SELECT id FROM t ORDER BY id LIMIT 100) sub ORDER BY id LIMIT 5 OFFSET 0";
    assert_eq!(paginated, expected);
}

#[test]
fn test_encode_blob_full_preserves_all_data() {
// 8KB of data — encode_blob would truncate, encode_blob_full must not
Expand Down
18 changes: 13 additions & 5 deletions src-tauri/src/mcp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,13 @@ fn handle_list_tools() -> Result<Value, JsonRpcError> {
},
Tool {
name: "run_query".to_string(),
description: Some("Execute a SQL query on a specific connection".to_string()),
description: Some("Execute a SQL query on a specific connection. If the query already contains a LIMIT clause, it will be respected.".to_string()),
input_schema: json!({
"type": "object",
"properties": {
"connection_id": { "type": "string", "description": "The ID or name of the connection" },
"query": { "type": "string", "description": "The SQL query to execute" }
"query": { "type": "string", "description": "The SQL query to execute" },
"limit": { "type": "integer", "description": "Maximum number of rows to return (default: 100). If the query already contains a LIMIT clause smaller than this value, the query's LIMIT takes precedence." }
},
"required": ["connection_id", "query"]
}),
Expand Down Expand Up @@ -821,6 +822,11 @@ async fn tool_run_query(
data: None,
})?;

let max_rows = args
.get("limit")
.and_then(|v| v.as_u64())
.unwrap_or(100) as u32;

audit.connection_id = Some(conn_id.to_string());
audit.query = Some(query.to_string());
let kind = ai_activity::classify_query_kind(query);
Expand Down Expand Up @@ -986,11 +992,13 @@ async fn tool_run_query(
}

let result = match conn.params.driver.as_str() {
"mysql" => mysql::execute_query(&db_params, &effective_query, Some(100), 1, None).await,
"mysql" => {
mysql::execute_query(&db_params, &effective_query, Some(max_rows), 1, None).await
}
"postgres" => {
postgres::execute_query(&db_params, &effective_query, Some(100), 1, None).await
postgres::execute_query(&db_params, &effective_query, Some(max_rows), 1, None).await
}
"sqlite" => sqlite::execute_query(&db_params, &effective_query, Some(100), 1).await,
"sqlite" => sqlite::execute_query(&db_params, &effective_query, Some(max_rows), 1).await,
_ => Err("Unsupported driver".into()),
}
.map_err(|e| JsonRpcError {
Expand Down