diff --git a/include/yardstick_ffi.h b/include/yardstick_ffi.h index 52f6a38..2eba9cf 100644 --- a/include/yardstick_ffi.h +++ b/include/yardstick_ffi.h @@ -177,6 +177,13 @@ void yardstick_free_aggregate_list(YardstickAggregateCallList* list); */ YardstickSelectInfo* yardstick_parse_select(const char* sql); +/** + * Inline SELECT aliases inside complex ORDER BY expressions when they reference + * a SELECT alias whose expression contains a subquery. Returns NULL when no + * rewrite is needed or parsing fails. + */ +char* yardstick_inline_order_by_subquery_aliases(const char* sql); + /** * Free a select info structure. */ diff --git a/src/yardstick_extension.cpp b/src/yardstick_extension.cpp index bcd0bcc..fc879a0 100644 --- a/src/yardstick_extension.cpp +++ b/src/yardstick_extension.cpp @@ -63,6 +63,7 @@ extern "C" { char* (*replace_range)(const char*, uint32_t, uint32_t, const char*), char* (*apply_replacements)(const char*, const YardstickReplacement*, size_t), char* (*qualify_expression)(const char*, const char*), + char* (*inline_order_by_subquery_aliases)(const char*), void (*free_string)(char*), char* (*expand_aggregate_call)(const char*, const char*, const YardstickAtModifier*, size_t, const char*, const char*, const char*, const char* const*, size_t) ); @@ -635,6 +636,7 @@ static void LoadInternal(ExtensionLoader &loader) { yardstick_replace_range, yardstick_apply_replacements, yardstick_qualify_expression, + yardstick_inline_order_by_subquery_aliases, yardstick_free_string, yardstick_expand_aggregate_call ); diff --git a/src/yardstick_parser_ffi.cpp b/src/yardstick_parser_ffi.cpp index 079ec9e..80f5d3e 100644 --- a/src/yardstick_parser_ffi.cpp +++ b/src/yardstick_parser_ffi.cpp @@ -26,6 +26,7 @@ #include "duckdb/parser/expression/window_expression.hpp" #include "duckdb/parser/expression/between_expression.hpp" #include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/parser/tableref.hpp" #include "duckdb/parser/tableref/basetableref.hpp" #include "duckdb/parser/tableref/joinref.hpp" @@ -37,6 +38,8 @@ #include #include #include +#include +#include using namespace duckdb; @@ -57,6 +60,31 @@ static bool IsBoundaryChar(char c) { return !std::isalnum(static_cast(c)) && c != '_'; } +static std::string NormalizeAliasName(const std::string &alias) { + return StringUtil::Lower(alias); +} + +using TableQualifierSet = std::unordered_set; + +static bool IsPotentialOrderAliasRef( + const ColumnRefExpression &colref, + const TableQualifierSet &from_qualifiers, + std::string &alias_name +) { + if (!colref.IsQualified()) { + alias_name = colref.GetColumnName(); + return true; + } + if (colref.column_names.size() == 2 && StringUtil::CIEquals(colref.GetTableName(), "alias")) { + if (from_qualifiers.find(NormalizeAliasName(colref.GetTableName())) != from_qualifiers.end()) { + return false; + } + alias_name = colref.GetColumnName(); + return true; + } + return false; +} + static size_t SkipWhitespaceAndComments(const std::string& sql, size_t idx) { while (idx < sql.size()) { if (std::isspace(static_cast(sql[idx]))) { @@ -498,7 +526,7 @@ static void FindAggregateCalls(ParsedExpression* expr, std::vectorexpression_class) { + switch (expr->GetExpressionClass()) { case ExpressionClass::FUNCTION: { auto* func = static_cast(expr); std::string lower_name = StringUtil::Lower(func->function_name); @@ -510,7 +538,7 @@ static void FindAggregateCalls(ParsedExpression* expr, std::vectorchildren.empty()) { // First argument should be measure name (column ref or string) auto* first_arg = func->children[0].get(); - if (first_arg->expression_class == ExpressionClass::COLUMN_REF) { + if (first_arg->GetExpressionClass() == ExpressionClass::COLUMN_REF) { auto* col = static_cast(first_arg); info.measure_name = col->GetColumnName(); } else { @@ -519,8 +547,9 @@ static void FindAggregateCalls(ParsedExpression* expr, std::vectorquery_location.IsValid()) { - info.start_pos = static_cast(expr->query_location.GetIndex()); + auto query_location = expr->GetQueryLocation(); + if (query_location.IsValid()) { + info.start_pos = static_cast(query_location.GetIndex()); } else { info.start_pos = 0; } @@ -534,7 +563,7 @@ static void FindAggregateCalls(ParsedExpression* expr, std::vectorchildren[i].get(); // Check if this is an AT modifier call - if (arg->expression_class == ExpressionClass::FUNCTION) { + if (arg->GetExpressionClass() == ExpressionClass::FUNCTION) { auto* at_func = static_cast(arg); std::string at_name = StringUtil::Lower(at_func->function_name); @@ -728,7 +757,7 @@ static void CollectTablesFromTableRef(TableRef* ref, std::vectorexpression_class) { + switch (expr->GetExpressionClass()) { case ExpressionClass::FUNCTION: { auto* func = static_cast(expr); if (IsStandardAggregate(func->function_name)) { @@ -800,7 +829,7 @@ static bool ExpressionContainsAggregate(ParsedExpression* expr) { static bool ExpressionContainsMeasureRef(ParsedExpression* expr) { if (!expr) return false; - switch (expr->expression_class) { + switch (expr->GetExpressionClass()) { case ExpressionClass::FUNCTION: { auto* func = static_cast(expr); if (StringUtil::Lower(func->function_name) == "aggregate") { @@ -864,7 +893,7 @@ static bool ExpressionContainsMeasureRef(ParsedExpression* expr) { static void QualifyColumnRefs(ParsedExpression* expr, const std::string& qualifier) { if (!expr) return; - switch (expr->expression_class) { + switch (expr->GetExpressionClass()) { case ExpressionClass::COLUMN_REF: { auto* col = static_cast(expr); if (col->column_names.size() == 1) { @@ -1117,10 +1146,11 @@ extern "C" YardstickSelectInfo* yardstick_parse_select(const char* sql) { for (auto& expr : select_node->select_list) { YardstickSelectItem item; item.expression_sql = safe_strdup(expr->ToString()); - item.alias = expr->alias.empty() ? nullptr : safe_strdup(expr->alias); + item.alias = expr->HasAlias() ? safe_strdup(expr->GetAlias()) : nullptr; - if (expr->query_location.IsValid()) { - item.start_pos = static_cast(expr->query_location.GetIndex()); + auto query_location = expr->GetQueryLocation(); + if (query_location.IsValid()) { + item.start_pos = static_cast(query_location.GetIndex()); } else { item.start_pos = 0; } @@ -1128,7 +1158,7 @@ extern "C" YardstickSelectInfo* yardstick_parse_select(const char* sql) { item.end_pos = static_cast(end_pos); item.is_aggregate = ExpressionContainsAggregate(expr.get()); - item.is_star = expr->expression_class == ExpressionClass::STAR; + item.is_star = expr->GetExpressionClass() == ExpressionClass::STAR; item.is_measure_ref = ExpressionContainsMeasureRef(expr.get()); items.push_back(item); @@ -1219,6 +1249,226 @@ extern "C" void yardstick_free_select_info(YardstickSelectInfo* info) { delete info; } +struct SelectAliasEntry { + ParsedExpression* expression; + bool has_subquery; +}; + +using SelectAliasMap = std::unordered_map; + +static void AddTableQualifier(TableQualifierSet &qualifiers, const std::string &name) { + if (!name.empty()) { + qualifiers.insert(NormalizeAliasName(name)); + } +} + +static void CollectTableQualifiers(const TableRef *ref, TableQualifierSet &qualifiers) { + if (!ref) { + return; + } + + AddTableQualifier(qualifiers, ref->alias); + + switch (ref->type) { + case TableReferenceType::BASE_TABLE: { + auto *base = static_cast(ref); + if (ref->alias.empty()) { + AddTableQualifier(qualifiers, base->table_name); + } + break; + } + + case TableReferenceType::JOIN: { + auto *join = static_cast(ref); + CollectTableQualifiers(join->left.get(), qualifiers); + CollectTableQualifiers(join->right.get(), qualifiers); + break; + } + + default: + break; + } +} + +static bool FindSelectAliasRef(const ParsedExpression &expr, const SelectAliasMap &aliases, + const TableQualifierSet &from_qualifiers, + SelectAliasMap::const_iterator &alias_entry) { + if (expr.GetExpressionClass() != ExpressionClass::COLUMN_REF) { + return false; + } + + std::string alias_name; + auto &colref = expr.Cast(); + if (!IsPotentialOrderAliasRef(colref, from_qualifiers, alias_name)) { + return false; + } + + alias_entry = aliases.find(NormalizeAliasName(alias_name)); + return alias_entry != aliases.end(); +} + +static bool IsSimpleSelectAliasOrder( + const ParsedExpression &expr, + const SelectAliasMap &aliases, + const TableQualifierSet &from_qualifiers +) { + SelectAliasMap::const_iterator alias_entry; + return FindSelectAliasRef(expr, aliases, from_qualifiers, alias_entry); +} + +static void EnumerateOrderAliasScopeChildren( + const ParsedExpression &expr, + const std::function &callback +) { + if (expr.GetExpressionClass() == ExpressionClass::SUBQUERY) { + auto &subquery_expr = expr.Cast(); + if (subquery_expr.child) { + callback(*subquery_expr.child); + } + return; + } + + ParsedExpressionIterator::EnumerateChildren(expr, callback); +} + +static void EnumerateOrderAliasScopeChildren( + ParsedExpression &expr, + const std::function &child)> &callback +) { + if (expr.GetExpressionClass() == ExpressionClass::SUBQUERY) { + auto &subquery_expr = expr.Cast(); + if (subquery_expr.child) { + callback(subquery_expr.child); + } + return; + } + + ParsedExpressionIterator::EnumerateChildren(expr, callback); +} + +static bool ReferencesSubqueryAlias( + const ParsedExpression &expr, + const SelectAliasMap &aliases, + const TableQualifierSet &from_qualifiers +) { + SelectAliasMap::const_iterator alias_entry; + if (FindSelectAliasRef(expr, aliases, from_qualifiers, alias_entry) && alias_entry->second.has_subquery) { + return true; + } + + bool found = false; + EnumerateOrderAliasScopeChildren(expr, [&](const ParsedExpression &child) { + if (!found && ReferencesSubqueryAlias(child, aliases, from_qualifiers)) { + found = true; + } + }); + return found; +} + +static bool InlineSelectAliases( + unique_ptr &expr, + const SelectAliasMap &aliases, + const TableQualifierSet &from_qualifiers +) { + if (!expr) { + return false; + } + + SelectAliasMap::const_iterator alias_entry; + if (FindSelectAliasRef(*expr, aliases, from_qualifiers, alias_entry)) { + if (!alias_entry->second.has_subquery) { + return false; + } + auto replacement = alias_entry->second.expression->Copy(); + replacement->ClearAlias(); + expr = std::move(replacement); + return true; + } + + bool changed = false; + EnumerateOrderAliasScopeChildren(*expr, [&](unique_ptr &child) { + if (InlineSelectAliases(child, aliases, from_qualifiers)) { + changed = true; + } + }); + return changed; +} + +//============================================================================= +// FFI Implementation: yardstick_inline_order_by_subquery_aliases +//============================================================================= + +extern "C" char* yardstick_inline_order_by_subquery_aliases(const char* sql) { + if (!sql) { + return nullptr; + } + + try { + Parser parser; + parser.ParseQuery(sql); + if (parser.statements.empty()) { + return nullptr; + } + + auto& stmt = parser.statements[0]; + if (stmt->type != StatementType::SELECT_STATEMENT) { + return nullptr; + } + + auto* select_stmt = static_cast(stmt.get()); + if (!select_stmt->node || select_stmt->node->type != QueryNodeType::SELECT_NODE) { + return nullptr; + } + + auto* select_node = static_cast(select_stmt->node.get()); + TableQualifierSet from_qualifiers; + CollectTableQualifiers(select_node->from_table.get(), from_qualifiers); + + SelectAliasMap aliases; + bool has_subquery_alias = false; + for (auto& expr : select_node->select_list) { + if (!expr->HasAlias()) { + continue; + } + bool has_subquery = expr->HasSubquery(); + aliases[NormalizeAliasName(expr->GetAlias())] = SelectAliasEntry { expr.get(), has_subquery }; + has_subquery_alias = has_subquery_alias || has_subquery; + } + + if (!has_subquery_alias || aliases.empty()) { + return nullptr; + } + + bool changed = false; + for (auto& modifier : select_node->modifiers) { + if (modifier->type != ResultModifierType::ORDER_MODIFIER) { + continue; + } + + auto& order_modifier = modifier->Cast(); + for (auto& order : order_modifier.orders) { + if (!order.expression) { + continue; + } + if (IsSimpleSelectAliasOrder(*order.expression, aliases, from_qualifiers)) { + continue; + } + if (!ReferencesSubqueryAlias(*order.expression, aliases, from_qualifiers)) { + continue; + } + changed = InlineSelectAliases(order.expression, aliases, from_qualifiers) || changed; + } + } + + if (!changed) { + return nullptr; + } + return safe_strdup(stmt->ToString()); + } catch (...) { + return nullptr; + } +} + //============================================================================= // FFI Implementation: yardstick_parse_expression //============================================================================= @@ -1248,11 +1498,11 @@ extern "C" YardstickExpressionInfo* yardstick_parse_expression(const char* expr_ auto& expr = expressions[0]; result->sql = safe_strdup(expr->ToString()); - result->is_identifier = expr->expression_class == ExpressionClass::COLUMN_REF; + result->is_identifier = expr->GetExpressionClass() == ExpressionClass::COLUMN_REF; result->is_aggregate = ExpressionContainsAggregate(expr.get()); // If it's a simple aggregate function, extract the function name and inner expr - if (expr->expression_class == ExpressionClass::FUNCTION) { + if (expr->GetExpressionClass() == ExpressionClass::FUNCTION) { auto* func = static_cast(expr.get()); if (IsStandardAggregate(func->function_name)) { result->aggregate_func = safe_strdup(StringUtil::Upper(func->function_name)); diff --git a/test/sql/measures.test b/test/sql/measures.test index fd5df95..3a38d47 100644 --- a/test/sql/measures.test +++ b/test/sql/measures.test @@ -114,6 +114,165 @@ FROM sales_v; 2023 EU 225.0 2023 US 225.0 +# ORDER BY expression referencing a named aggregate that expands to a subquery (#28) +query IIRR +SEMANTIC SELECT + year, + region, + AGGREGATE(revenue) AS revenue, + AGGREGATE(revenue) AT (ALL region) AS year_total +FROM sales_v +ORDER BY revenue/year_total, year, region; +---- +2022 EU 50.0 150.0 +2023 EU 75.0 225.0 +2022 US 100.0 150.0 +2023 US 150.0 225.0 + +# Qualified ORDER BY tie-breakers should remain in the original query scope (#28) +query IIRR +SEMANTIC SELECT + o.year, + o.region, + AGGREGATE(revenue) AS revenue, + AGGREGATE(revenue) AT (ALL region) AS year_total +FROM sales_v o +ORDER BY revenue/year_total, o.region, o.year; +---- +2022 EU 50.0 150.0 +2023 EU 75.0 225.0 +2022 US 100.0 150.0 +2023 US 150.0 225.0 + +# Simple subquery aliases in ORDER BY should not force query wrapping (#28) +query IIR +SEMANTIC SELECT + o.year, + o.region, + AGGREGATE(revenue) AT (ALL region) AS year_total +FROM sales_v o +ORDER BY year_total, o.region, o.year; +---- +2022 EU 150.0 +2022 US 150.0 +2023 EU 225.0 +2023 US 225.0 + +# DuckDB's alias namespace should still work when no table qualifier shadows it (#28) +query IIR +SEMANTIC SELECT + o.year, + o.region, + AGGREGATE(revenue) AT (ALL region) AS year_total +FROM sales_v o +ORDER BY alias.year_total + 1, o.region, o.year; +---- +2022 EU 150.0 +2022 US 150.0 +2023 EU 225.0 +2023 US 225.0 + +# A real table qualifier named alias should not be treated as the alias namespace (#28) +statement ok +CREATE TABLE alias_collision_sales(bucket INT, year_total INT, amount DOUBLE); + +statement ok +INSERT INTO alias_collision_sales VALUES (1, 2, 10), (2, 1, 20); + +statement ok +CREATE VIEW alias_collision_v AS +SELECT bucket, year_total, SUM(amount) AS MEASURE revenue +FROM alias_collision_sales; + +query IR +SEMANTIC SELECT + alias.bucket, + AGGREGATE(revenue) AT (ALL bucket) AS year_total +FROM alias_collision_v alias +GROUP BY alias.bucket, alias.year_total +ORDER BY ANY_VALUE(alias.year_total) + 1; +---- +2 20.0 +1 10.0 + +# Non-subquery aliases in mixed ORDER BY expressions should stay alias refs (#28) +statement ok +CREATE TABLE order_alias_rows AS +SELECT i::INT AS id, 1.0 AS amount +FROM range(32) t(i); + +statement ok +CREATE VIEW order_alias_v AS +SELECT id, SUM(amount) AS MEASURE revenue +FROM order_alias_rows; + +query I +SELECT setseed(0.42); +---- +NULL + +query I +SEMANTIC WITH q AS ( + SELECT + random() AS r, + id, + AGGREGATE(revenue) AT (ALL id) AS all_revenue + FROM order_alias_v + ORDER BY r / all_revenue +) +SELECT COUNT(*)::INTEGER +FROM ( + SELECT r, LAG(r) OVER () AS prev_r + FROM q +) +WHERE prev_r IS NOT NULL AND r < prev_r; +---- +0 + +# ORDER BY scalar subquery scopes should not be searched for outer aliases (#28) +statement ok +CREATE TABLE order_subquery_totals(year INT, total DOUBLE); + +statement ok +INSERT INTO order_subquery_totals VALUES (2022, 2), (2023, 1); + +query IIR +SEMANTIC SELECT + o.year, + o.region, + AGGREGATE(revenue) AT (ALL region) AS total +FROM sales_v o +ORDER BY ( + SELECT total + FROM order_subquery_totals st + WHERE st.year = o.year +), o.region; +---- +2023 EU 225.0 +2023 US 225.0 +2022 EU 150.0 +2022 US 150.0 + +# The outer side of IN/ANY still belongs to the ORDER BY alias scope (#28) +statement ok +CREATE TABLE order_allowed_totals(total DOUBLE); + +statement ok +INSERT INTO order_allowed_totals VALUES (225); + +query IIR +SEMANTIC SELECT + o.year, + o.region, + AGGREGATE(revenue) AT (ALL region) AS total +FROM sales_v o +ORDER BY total IN (SELECT total FROM order_allowed_totals), o.region, o.year; +---- +2022 EU 150.0 +2022 US 150.0 +2023 EU 225.0 +2023 US 225.0 + # Lowercase from with line break query IIR rowsort SEMANTIC SELECT year, region, AGGREGATE(revenue) AT (ALL region) AS year_total diff --git a/yardstick-rs/src/parser_ffi.rs b/yardstick-rs/src/parser_ffi.rs index b2b1fb8..ca52b5b 100644 --- a/yardstick-rs/src/parser_ffi.rs +++ b/yardstick-rs/src/parser_ffi.rs @@ -169,6 +169,7 @@ type FnFreeCreateViewInfo = unsafe extern "C" fn(*mut YardstickCreateViewInfo); type FnReplaceRange = unsafe extern "C" fn(*const c_char, u32, u32, *const c_char) -> *mut c_char; type FnApplyReplacements = unsafe extern "C" fn(*const c_char, *const YardstickReplacement, usize) -> *mut c_char; type FnQualifyExpression = unsafe extern "C" fn(*const c_char, *const c_char) -> *mut c_char; +type FnInlineOrderBySubqueryAliases = unsafe extern "C" fn(*const c_char) -> *mut c_char; type FnFreeString = unsafe extern "C" fn(*mut c_char); type FnExpandAggregateCall = unsafe extern "C" fn( *const c_char, *const c_char, *const YardstickAtModifier, usize, @@ -187,6 +188,7 @@ static FN_FREE_CREATE_VIEW_INFO: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut()) static FN_REPLACE_RANGE: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut()); static FN_APPLY_REPLACEMENTS: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut()); static FN_QUALIFY_EXPRESSION: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut()); +static FN_INLINE_ORDER_BY_SUBQUERY_ALIASES: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut()); static FN_FREE_STRING: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut()); static FN_EXPAND_AGGREGATE_CALL: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut()); @@ -204,6 +206,7 @@ pub extern "C" fn yardstick_init_parser_ffi( replace_range: FnReplaceRange, apply_replacements: FnApplyReplacements, qualify_expression: FnQualifyExpression, + inline_order_by_subquery_aliases: FnInlineOrderBySubqueryAliases, free_string: FnFreeString, expand_aggregate_call: FnExpandAggregateCall, ) { @@ -218,6 +221,7 @@ pub extern "C" fn yardstick_init_parser_ffi( FN_REPLACE_RANGE.store(replace_range as *mut (), Ordering::SeqCst); FN_APPLY_REPLACEMENTS.store(apply_replacements as *mut (), Ordering::SeqCst); FN_QUALIFY_EXPRESSION.store(qualify_expression as *mut (), Ordering::SeqCst); + FN_INLINE_ORDER_BY_SUBQUERY_ALIASES.store(inline_order_by_subquery_aliases as *mut (), Ordering::SeqCst); FN_FREE_STRING.store(free_string as *mut (), Ordering::SeqCst); FN_EXPAND_AGGREGATE_CALL.store(expand_aggregate_call as *mut (), Ordering::SeqCst); } @@ -278,6 +282,10 @@ unsafe fn yardstick_free_string(ptr: *mut c_char) { call_ffi!(FN_FREE_STRING, FnFreeString, ptr) } +unsafe fn yardstick_inline_order_by_subquery_aliases(sql: *const c_char) -> *mut c_char { + call_ffi!(FN_INLINE_ORDER_BY_SUBQUERY_ALIASES, FnInlineOrderBySubqueryAliases, sql) +} + unsafe fn yardstick_expand_aggregate_call( measure_name: *const c_char, agg_func: *const c_char, @@ -815,6 +823,25 @@ pub fn qualify_expression(expr: &str, qualifier: &str) -> Result } } +pub fn inline_order_by_subquery_aliases(sql: &str) -> Option { + let fn_ptr = FN_INLINE_ORDER_BY_SUBQUERY_ALIASES.load(Ordering::SeqCst); + if fn_ptr.is_null() { + return None; + } + + let c_sql = CString::new(sql).ok()?; + unsafe { + let result_ptr = yardstick_inline_order_by_subquery_aliases(c_sql.as_ptr()); + if result_ptr.is_null() { + return None; + } + + let result = c_str_to_string(result_ptr).unwrap_or_default(); + yardstick_free_string(result_ptr); + Some(result) + } +} + /// Expand a single AGGREGATE() call to SQL. /// /// Generates a correlated subquery for the measure based on the aggregation function diff --git a/yardstick-rs/src/sql/measures.rs b/yardstick-rs/src/sql/measures.rs index 0cd035e..f4153b1 100644 --- a/yardstick-rs/src/sql/measures.rs +++ b/yardstick-rs/src/sql/measures.rs @@ -5942,6 +5942,14 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { ); } + if let Some(rewritten_sql) = + std::panic::catch_unwind(|| parser_ffi::inline_order_by_subquery_aliases(&result_sql)) + .ok() + .flatten() + { + result_sql = rewritten_sql; + } + AggregateExpandResult { had_aggregate, expanded_sql: result_sql,