diff --git a/libs/@local/hashql/core/src/graph/algorithms/color/mod.rs b/libs/@local/hashql/core/src/graph/algorithms/color/mod.rs new file mode 100644 index 00000000000..3b85ab3bcfb --- /dev/null +++ b/libs/@local/hashql/core/src/graph/algorithms/color/mod.rs @@ -0,0 +1,238 @@ +//! Three-color depth-first search for directed graphs. +//! +//! Implements a DFS where each node transitions through three states: +//! +//! - **White** (unvisited): not yet encountered. +//! - **Gray** (in the `gray` set): discovered but not yet finished; still on the DFS stack. +//! - **Black** (in the `black` set): all successors have been processed. +//! +//! The color of a node when it is re-encountered determines the edge classification: +//! +//! | Re-encounter color | Meaning | +//! |--------------------|------------------| +//! | `None` (white) | Tree edge | +//! | `Some(Gray)` | Back edge (cycle)| +//! | `Some(Black)` | Cross/forward | +//! +//! This is an iterative (stack-based) implementation. The visitor receives callbacks +//! at two points: when a node is first examined ([`node_examined`]) and when all its +//! successors are finished ([`node_finished`]). The `node_finished` callback fires in +//! postorder. +//! +//! [`node_examined`]: TriColorVisitor::node_examined +//! [`node_finished`]: TriColorVisitor::node_finished + +use alloc::alloc::Global; +use core::{alloc::Allocator, ops::Try}; + +use crate::{ + graph::{DirectedGraph, Successors}, + id::bit_vec::DenseBitSet, +}; + +#[cfg(test)] +mod tests; + +/// DFS node state. +/// +/// Passed to [`TriColorVisitor::node_examined`] as the `before` parameter to indicate +/// what state a node was in when it was re-encountered. A value of `None` means the node +/// was white (first discovery). +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum NodeColor { + /// On the current DFS path. Re-encountering a gray node means a back edge (cycle). + Gray, + /// Fully processed. Re-encountering a black node means a cross or forward edge. + Black, +} + +/// Internal event pushed onto the DFS stack. +/// +/// Each node generates two events: `Gray` on discovery (explore successors) +/// and `Black` when all successors are done (finish the node). +struct Event { + node: N, + next: NodeColor, +} + +/// Iterative three-color DFS over a directed graph. +/// +/// Reusable across multiple `run` calls. Each call to [`run`](Self::run) resets all +/// internal state before starting from the given root. +/// +/// The graph's full node domain is used to size the internal bitsets, so node IDs +/// from the graph can be used directly without remapping. +pub struct TriColorDepthFirstSearch<'graph, G: ?Sized, N, A: Allocator = Global> { + graph: &'graph G, + stack: Vec, A>, + + /// Nodes that have been discovered (entered the DFS stack). + gray: DenseBitSet, + /// Nodes whose successors have all been processed. + black: DenseBitSet, +} + +impl<'graph, G: DirectedGraph + ?Sized> TriColorDepthFirstSearch<'graph, G, G::NodeId, Global> { + #[inline] + pub fn new(graph: &'graph G) -> Self { + Self::new_in(graph, Global) + } +} + +impl<'graph, G: DirectedGraph + ?Sized, A: Allocator> + TriColorDepthFirstSearch<'graph, G, G::NodeId, A> +{ + pub fn new_in(graph: &'graph G, alloc: A) -> Self { + Self { + graph, + stack: Vec::new_in(alloc), + gray: DenseBitSet::new_empty(graph.node_count()), + black: DenseBitSet::new_empty(graph.node_count()), + } + } + + /// Clears all traversal state (gray set, black set, stack). + /// + /// Call this before a sequence of [`run_from`](Self::run_from) calls to start + /// with a clean slate. + pub fn reset(&mut self) { + self.stack.clear(); + self.gray.clear(); + self.black.clear(); + } + + /// Run a DFS from `root`, resetting all state first. + /// + /// Equivalent to calling [`reset`](Self::reset) followed by + /// [`run_from`](Self::run_from). Use this when each DFS should be independent. + pub fn run(&mut self, root: G::NodeId, visitor: &mut V) -> V::Result + where + V: TriColorVisitor, + G: Successors, + { + self.reset(); + self.run_from(root, visitor) + } + + /// Run a DFS from `root` without resetting state. + /// + /// Nodes already in the gray or black sets from previous calls are treated as + /// previously visited. This allows running DFS from multiple roots while + /// accumulating state: a node finished (black) by an earlier root is skipped, + /// so each connected component is explored at most once. + /// + /// Stops early if the visitor returns a residual (e.g., `Err` or + /// `ControlFlow::Break`). Edges for which [`TriColorVisitor::ignore_edge`] + /// returns `true` are not followed. + pub fn run_from(&mut self, root: G::NodeId, visitor: &mut V) -> V::Result + where + V: TriColorVisitor, + G: Successors, + { + self.stack.push(Event { + node: root, + next: NodeColor::Gray, + }); + + while let Some(Event { node, next }) = self.stack.pop() { + match next { + NodeColor::Black => { + let not_previously_finished = self.black.insert(node); + debug_assert!( + not_previously_finished, + "a node should be finished exactly once" + ); + + visitor.node_finished(node)?; + } + NodeColor::Gray => { + let newly_discovered = self.gray.insert(node); + let previous_color = if newly_discovered { + None + } else if self.black.contains(node) { + Some(NodeColor::Black) + } else { + Some(NodeColor::Gray) + }; + + visitor.node_examined(node, previous_color)?; + + // Already visited through another path: nothing more to do. + if previous_color.is_some() { + continue; + } + + self.stack.push(Event { + node, + next: NodeColor::Black, + }); + for successor in self.graph.successors(node) { + if !visitor.ignore_edge(node, successor) { + self.stack.push(Event { + node: successor, + next: NodeColor::Gray, + }); + } + } + } + } + } + + Try::from_output(()) + } +} + +/// Callbacks for [`TriColorDepthFirstSearch`]. +/// +/// All methods have default no-op implementations, so visitors only need to +/// override the events they care about. +pub trait TriColorVisitor { + /// The control-flow type returned by each callback. + /// + /// Use `Result<(), E>` or `ControlFlow` to support early termination. + type Result: Try; + + /// Called when a node is encountered during DFS. + /// + /// `before` indicates the node's color at the time of re-encounter: + /// - `None`: first discovery (white to gray transition). + /// - `Some(Gray)`: back edge, indicating a cycle. + /// - `Some(Black)`: cross or forward edge. + #[expect(unused_variables)] + fn node_examined(&mut self, node: G::NodeId, before: Option) -> Self::Result { + Try::from_output(()) + } + + /// Called after all successors of `node` have been fully processed. + /// + /// Fires in postorder: a node finishes only after all its descendants finish. + #[expect(unused_variables)] + fn node_finished(&mut self, node: G::NodeId) -> Self::Result { + Try::from_output(()) + } + + /// Return `true` to skip this edge during traversal. + /// + /// Allows restricting the DFS to a subgraph without constructing a + /// separate graph data structure. + #[expect(unused_variables)] + fn ignore_edge(&mut self, source: G::NodeId, target: G::NodeId) -> bool { + false + } +} + +/// A [`TriColorVisitor`] that detects cycles. +/// +/// Returns `Err(())` as soon as a back edge (re-encounter of a gray node) is found. +pub struct CycleDetector; + +impl TriColorVisitor for CycleDetector { + type Result = Result<(), ()>; + + fn node_examined(&mut self, _: G::NodeId, before: Option) -> Self::Result { + match before { + Some(NodeColor::Gray) => Err(()), + _ => Ok(()), + } + } +} diff --git a/libs/@local/hashql/core/src/graph/algorithms/color/tests.rs b/libs/@local/hashql/core/src/graph/algorithms/color/tests.rs new file mode 100644 index 00000000000..fccbed31f01 --- /dev/null +++ b/libs/@local/hashql/core/src/graph/algorithms/color/tests.rs @@ -0,0 +1,250 @@ +use core::ops::ControlFlow; + +use super::{NodeColor, TriColorDepthFirstSearch, TriColorVisitor}; +use crate::{ + graph::{DirectedGraph as _, NodeId, tests::TestGraph}, + id::Id as _, +}; + +macro_rules! n { + ($id:expr) => { + NodeId::from_usize($id) + }; +} + +struct CycleDetector; + +impl TriColorVisitor for CycleDetector { + type Result = ControlFlow; + + fn node_examined(&mut self, node: NodeId, before: Option) -> Self::Result { + match before { + Some(NodeColor::Gray) => ControlFlow::Break(node), + _ => ControlFlow::Continue(()), + } + } +} + +fn has_cycle(graph: &TestGraph) -> bool { + let mut search = TriColorDepthFirstSearch::new(graph); + (0..graph.node_count()).any(|i| search.run(n!(i), &mut CycleDetector).is_break()) +} + +fn cycle_target(graph: &TestGraph) -> Option { + let mut search = TriColorDepthFirstSearch::new(graph); + for i in 0..graph.node_count() { + if let ControlFlow::Break(target) = search.run(n!(i), &mut CycleDetector) { + return Some(target); + } + } + None +} + +struct PostOrderCollector { + order: Vec, +} + +impl TriColorVisitor for PostOrderCollector { + type Result = ControlFlow<()>; + + fn node_finished(&mut self, node: NodeId) -> Self::Result { + self.order.push(node); + ControlFlow::Continue(()) + } +} + +fn postorder(graph: &TestGraph, root: usize) -> Vec { + let mut search = TriColorDepthFirstSearch::new(graph); + let mut collector = PostOrderCollector { order: Vec::new() }; + let _: ControlFlow<()> = search.run(n!(root), &mut collector); + collector.order +} + +#[test] +fn self_loop_is_cyclic() { + let graph = TestGraph::new(&[(0, 0)]); + assert!(has_cycle(&graph)); + assert_eq!(cycle_target(&graph), Some(n!(0))); +} + +#[test] +fn two_node_cycle() { + let graph = TestGraph::new(&[(0, 1), (1, 0)]); + assert!(has_cycle(&graph)); +} + +#[test] +fn three_node_cycle() { + let graph = TestGraph::new(&[(0, 1), (1, 2), (2, 0)]); + assert!(has_cycle(&graph)); +} + +#[test] +fn linear_chain_no_cycle() { + let graph = TestGraph::new(&[(0, 1), (1, 2), (2, 3)]); + assert!(!has_cycle(&graph)); +} + +#[test] +fn diamond_no_cycle() { + let graph = TestGraph::new(&[(0, 1), (0, 2), (1, 3), (2, 3)]); + assert!(!has_cycle(&graph)); +} + +#[test] +fn diamond_with_back_edge_is_cyclic() { + let graph = TestGraph::new(&[(0, 1), (0, 2), (1, 3), (2, 3), (3, 0)]); + assert!(has_cycle(&graph)); +} + +#[test] +fn disconnected_with_cycle_in_second_component() { + // Component 1: 0 -> 1 (no cycle) + // Component 2: 2 -> 3 -> 2 (cycle) + let graph = TestGraph::new(&[(0, 1), (2, 3), (3, 2)]); + assert!(has_cycle(&graph)); +} + +#[test] +fn disconnected_no_cycle() { + let graph = TestGraph::new(&[(0, 1), (2, 3)]); + assert!(!has_cycle(&graph)); +} + +#[test] +fn isolated_node_no_cycle() { + // Single node, no edges (TestGraph needs at least one edge to set node_count, + // so use two disconnected nodes with one edge). + let graph = TestGraph::new(&[(0, 1)]); + assert!(!has_cycle(&graph)); +} + +#[test] +fn postorder_linear_chain() { + // 0 -> 1 -> 2 + let graph = TestGraph::new(&[(0, 1), (1, 2)]); + let order = postorder(&graph, 0); + assert_eq!(order, [n!(2), n!(1), n!(0)]); +} + +#[test] +fn postorder_diamond() { + // 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3 + let graph = TestGraph::new(&[(0, 1), (0, 2), (1, 3), (2, 3)]); + let order = postorder(&graph, 0); + + // 3 must come before both 1 and 2; 1 and 2 must come before 0. + assert_eq!(order.len(), 4); + assert_eq!(*order.last().expect("non-empty"), n!(0)); + + let pos = |id: usize| order.iter().position(|&n| n == n!(id)).expect("non-empty"); + assert!(pos(3) < pos(1)); + assert!(pos(3) < pos(2)); + assert!(pos(1) < pos(0)); + assert!(pos(2) < pos(0)); +} + +#[test] +fn postorder_unreachable_node_not_visited() { + // 0 -> 1, node 2 exists but is unreachable from 0 + let graph = TestGraph::new(&[(0, 1), (2, 2)]); + let order = postorder(&graph, 0); + + // Only nodes reachable from root 0 + assert_eq!(order, [n!(1), n!(0)]); +} + +struct FilteredCycleDetector { + ignored: (usize, usize), +} + +impl TriColorVisitor for FilteredCycleDetector { + type Result = ControlFlow<()>; + + fn node_examined(&mut self, _: NodeId, before: Option) -> Self::Result { + match before { + Some(NodeColor::Gray) => ControlFlow::Break(()), + _ => ControlFlow::Continue(()), + } + } + + fn ignore_edge(&mut self, source: NodeId, target: NodeId) -> bool { + source == n!(self.ignored.0) && target == n!(self.ignored.1) + } +} + +#[test] +fn ignore_edge_breaks_cycle() { + // 0 -> 1 -> 2 -> 0 (cycle); ignoring 2 -> 0 removes the cycle + let graph = TestGraph::new(&[(0, 1), (1, 2), (2, 0)]); + + let mut search = TriColorDepthFirstSearch::new(&graph); + let mut visitor = FilteredCycleDetector { ignored: (2, 0) }; + let result = search.run(n!(0), &mut visitor); + assert!(result.is_continue()); +} + +#[test] +fn ignore_edge_wrong_edge_keeps_cycle() { + // 0 -> 1 -> 2 -> 0 (cycle); ignoring 0 -> 1 still leaves 1 -> 2 -> 0 reachable + // from 0? No: if 0 -> 1 is ignored, DFS from 0 has no successors, no cycle found. + // But the cycle B -> C -> A still exists if we start from 1. + let graph = TestGraph::new(&[(0, 1), (1, 2), (2, 0)]); + + let mut search = TriColorDepthFirstSearch::new(&graph); + let mut visitor = FilteredCycleDetector { ignored: (0, 1) }; + + // From node 0: no successors after filtering, no cycle + assert!(search.run(n!(0), &mut visitor).is_continue()); + + // From node 1: 1 -> 2 -> 0 -> (0->1 ignored) -> done, no back edge to gray + // Wait: 0's successor 1 is ignored, so from 0 we go nowhere. But from 1: 1->2->0, + // then 0 has no unignored successors. 0 finishes. No cycle. + assert!(search.run(n!(1), &mut visitor).is_continue()); +} + +#[test] +fn run_resets_between_calls() { + let graph = TestGraph::new(&[(0, 1), (1, 0)]); + let mut search = TriColorDepthFirstSearch::new(&graph); + + // First run: finds cycle + assert!(search.run(n!(0), &mut CycleDetector).is_break()); + + // Second run on same search: state is reset, should find cycle again + assert!(search.run(n!(0), &mut CycleDetector).is_break()); +} + +#[test] +fn run_from_accumulates_state() { + // 0->1->2, 3->1 (node 1 reachable from both roots) + // Without accumulation, run_from(3) would re-explore 1->2 and emit them again. + // With accumulation, nodes 1 and 2 are already black after run_from(0). + let graph = TestGraph::new(&[(0, 1), (1, 2), (3, 1)]); + let mut search = TriColorDepthFirstSearch::new(&graph); + let mut collector = PostOrderCollector { order: Vec::new() }; + + search.reset(); + let _: ControlFlow<()> = search.run_from(n!(0), &mut collector); + let _: ControlFlow<()> = search.run_from(n!(3), &mut collector); + + // Nodes 1 and 2 finished during first run_from; second run_from only finishes 3. + assert_eq!(collector.order, [n!(2), n!(1), n!(0), n!(3)]); +} + +#[test] +fn run_from_skips_already_finished_nodes() { + // 0->1->2, 3->2 (shared sink at 2) + let graph = TestGraph::new(&[(0, 1), (1, 2), (3, 2)]); + let mut search = TriColorDepthFirstSearch::new(&graph); + let mut collector = PostOrderCollector { order: Vec::new() }; + + search.reset(); + let _: ControlFlow<()> = search.run_from(n!(0), &mut collector); + let _: ControlFlow<()> = search.run_from(n!(3), &mut collector); + + // Node 2 should appear exactly once (finished during first run_from), + // not re-emitted when reached from node 3. + assert_eq!(collector.order.iter().filter(|&&n| n == n!(2)).count(), 1); + assert_eq!(collector.order.len(), 4); // 2, 1, 0, 3 +} diff --git a/libs/@local/hashql/core/src/graph/algorithms/mod.rs b/libs/@local/hashql/core/src/graph/algorithms/mod.rs index 5509a2f47aa..bfd5647551d 100644 --- a/libs/@local/hashql/core/src/graph/algorithms/mod.rs +++ b/libs/@local/hashql/core/src/graph/algorithms/mod.rs @@ -25,6 +25,7 @@ //! # assert_eq!(visited, [n1, n2]); //! ``` +pub mod color; pub mod dominators; pub mod tarjan; @@ -32,6 +33,7 @@ use alloc::collections::VecDeque; use core::iter::FusedIterator; pub use self::{ + color::{CycleDetector, TriColorDepthFirstSearch, TriColorVisitor}, dominators::{ DominanceFrontier, DominatorFrontiers, Dominators, IteratedDominanceFrontier, dominance_frontiers, dominators, iterated_dominance_frontier, diff --git a/libs/@local/hashql/core/src/graph/algorithms/tarjan/mod.rs b/libs/@local/hashql/core/src/graph/algorithms/tarjan/mod.rs index edda7708ba7..3a1bd412feb 100644 --- a/libs/@local/hashql/core/src/graph/algorithms/tarjan/mod.rs +++ b/libs/@local/hashql/core/src/graph/algorithms/tarjan/mod.rs @@ -203,6 +203,59 @@ where pub fn of(&self, id: S) -> &[N] { self.as_slice().of(id) } + + #[inline] + pub fn iter(&self) -> impl ExactSizeIterator + DoubleEndedIterator { + self.sccs().map(|scc| (scc, self.of(scc))) + } + + // TODO: miri tests + #[expect(unsafe_code)] + pub fn iter_mut( + &mut self, + ) -> impl ExactSizeIterator + DoubleEndedIterator + '_ { + let ptr = self.nodes.as_mut_ptr(); + let offsets = &self.offsets; + + offsets.ids().take(self.offsets.len() - 1).map(move |scc| { + let start = offsets[scc]; + let end = offsets[scc.plus(1)]; + + // SAFETY: The start and end indices are valid for the nodes slice, and members is + // non-overlapping by construction + (scc, unsafe { + core::slice::from_raw_parts_mut(ptr.add(start), end - start) + }) + }) + } +} + +impl<'this, N, S, A: Allocator> IntoIterator for &'this Members +where + S: Id, +{ + type Item = (S, &'this [N]); + + type IntoIter = impl ExactSizeIterator + DoubleEndedIterator; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'this, N, S, A: Allocator> IntoIterator for &'this mut Members +where + S: Id, +{ + type Item = (S, &'this mut [N]); + + type IntoIter = impl ExactSizeIterator + DoubleEndedIterator; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter_mut() + } } /// Storage for the computed SCCs and their relationships. diff --git a/libs/@local/hashql/eval/src/postgres/continuation.rs b/libs/@local/hashql/eval/src/postgres/continuation.rs index d30b3f2a445..6f9b51a93b2 100644 --- a/libs/@local/hashql/eval/src/postgres/continuation.rs +++ b/libs/@local/hashql/eval/src/postgres/continuation.rs @@ -1,7 +1,7 @@ //! Naming conventions and helpers for the continuation LATERAL subqueries. //! //! Each postgres island in a filter body produces a `CROSS JOIN LATERAL` subquery -//! that evaluates its CASE tree once per row (via `OFFSET 0`) and returns a +//! that evaluates its CASE tree once per row and returns a //! composite `continuation` value. This module provides the identifiers, column //! names, and expression helpers used to construct and reference those subqueries. @@ -80,7 +80,7 @@ impl From for ContinuationColumn { pub(crate) enum ContinuationColumn { /// The composite `continuation` value column in the LATERAL subquery. /// - /// The LATERAL is `(SELECT ::continuation AS c OFFSET 0) AS f0`, + /// The LATERAL is `(SELECT ::continuation AS c) AS f0`, /// so field access is `(f0."c")."filter"`. Entry, /// The filter boolean. `NULL` means passthrough, `true` keeps, `false` rejects. diff --git a/libs/@local/hashql/eval/src/postgres/mod.rs b/libs/@local/hashql/eval/src/postgres/mod.rs index 71743748071..43744c10948 100644 --- a/libs/@local/hashql/eval/src/postgres/mod.rs +++ b/libs/@local/hashql/eval/src/postgres/mod.rs @@ -17,9 +17,6 @@ //! - **`block`** (`int`): next basic block when leaving the island. //! - **`locals`** (`int[]`) and **`values`** (`jsonb[]`): parallel arrays carrying live-out locals. //! -//! Continuation subqueries are forced to materialise once per row using `OFFSET 0` to prevent -//! PostgreSQL from inlining the subquery and duplicating the island's `CASE` tree per field access. -//! //! ## Parameters and projections //! //! Parameters are deduplicated by identity and referenced by index (rendered as `$N` in SQL). @@ -387,7 +384,6 @@ impl<'eval, 'ctx, 'heap, A: Allocator, S: BumpAllocator> expression, alias: Some(ContinuationColumn::Entry.identifier()), }]) - .offset(0) .build(); let subquery = query::FromItem::Subquery { diff --git a/libs/@local/hashql/eval/tests/ui/postgres/comparison-no-cast.stdout b/libs/@local/hashql/eval/tests/ui/postgres/comparison-no-cast.stdout index ff997416812..0cb7132d278 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/comparison-no-cast.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/comparison-no-cast.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_4_0"."row")."block" AS "continuation_4_0_block", ("continuation_4_0"."row")."locals" AS "continuation_4_0_locals", ("continuation_4_0"."row")."values" AS "continuation_4_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((($3::jsonb) > ($4::jsonb))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_4_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((($3::jsonb) > ($4::jsonb))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_4_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_4_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.stdout b/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.stdout index 0df73c4dc88..2b60f69695d 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_1_0"."row")."block" AS "continuation_1_0_block", ("continuation_1_0"."row")."locals" AS "continuation_1_0_locals", ("continuation_1_0"."row")."values" AS "continuation_1_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((1)::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_1_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((1)::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_1_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_1_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/dict-construction.stdout b/libs/@local/hashql/eval/tests/ui/postgres/dict-construction.stdout index f649dadb5f0..e446d253309 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/dict-construction.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/dict-construction.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_4_0"."row")."block" AS "continuation_4_0_block", ("continuation_4_0"."row")."locals" AS "continuation_4_0_locals", ("continuation_4_0"."row")."values" AS "continuation_4_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb(jsonb_build_object("entity_temporal_metadata_0_0_0"."entity_uuid", "entity_temporal_metadata_0_0_0"."web_id")) = to_jsonb(jsonb_build_object(($3::jsonb), ($4::jsonb))))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_4_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb(jsonb_build_object("entity_temporal_metadata_0_0_0"."entity_uuid", "entity_temporal_metadata_0_0_0"."web_id")) = to_jsonb(jsonb_build_object(($3::jsonb), ($4::jsonb))))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_4_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_4_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/entity-archived-check.stdout b/libs/@local/hashql/eval/tests/ui/postgres/entity-archived-check.stdout index 3ca5b2553d6..f3f9ebc2bff 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/entity-archived-check.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/entity-archived-check.stdout @@ -4,8 +4,7 @@ SELECT ("continuation_1_0"."row")."block" AS "continuation_1_0_block", ("continu FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" INNER JOIN "entity_editions" AS "entity_editions_0_0_1" ON "entity_editions_0_0_1"."entity_edition_id" = "entity_temporal_metadata_0_0_0"."entity_edition_id" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((NOT("entity_editions_0_0_1"."archived"))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_1_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((NOT("entity_editions_0_0_1"."archived"))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_1_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_1_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/entity-type-ids-lateral.stdout b/libs/@local/hashql/eval/tests/ui/postgres/entity-type-ids-lateral.stdout index ecbbf9d0668..886d669d8d6 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/entity-type-ids-lateral.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/entity-type-ids-lateral.stdout @@ -7,8 +7,7 @@ FROM "entity_is_of_type_ids" AS "eit" CROSS JOIN LATERAL UNNEST("eit"."base_urls", ("eit"."versions"::text[])) AS "u"("b", "v") WHERE "eit"."entity_edition_id" = "entity_temporal_metadata_0_0_0"."entity_edition_id") AS "entity_is_of_type_ids_0_0_1" ON TRUE -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_is_of_type_ids_0_0_1"."entity_type_ids") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_2_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_is_of_type_ids_0_0_1"."entity_type_ids") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_2_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_2_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/entity-uuid-equality.stdout b/libs/@local/hashql/eval/tests/ui/postgres/entity-uuid-equality.stdout index 30bad1b3f91..26e0650762f 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/entity-uuid-equality.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/entity-uuid-equality.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_7_0"."row")."block" AS "continuation_7_0_block", ("continuation_7_0"."row")."locals" AS "continuation_7_0_locals", ("continuation_7_0"."row")."values" AS "continuation_7_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::text)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_7_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::text)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_7_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_7_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/entity-web-id-equality.stdout b/libs/@local/hashql/eval/tests/ui/postgres/entity-web-id-equality.stdout index fe5d03119a3..b45d2ccdd65 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/entity-web-id-equality.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/entity-web-id-equality.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_2_0"."row")."block" AS "continuation_2_0_block", ("continuation_2_0"."row")."locals" AS "continuation_2_0_locals", ("continuation_2_0"."row")."values" AS "continuation_2_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."web_id") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_2_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."web_id") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_2_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_2_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/env-captured-variable.stdout b/libs/@local/hashql/eval/tests/ui/postgres/env-captured-variable.stdout index aced634948f..6b7d9e35e75 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/env-captured-variable.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/env-captured-variable.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_0_0"."row")."block" AS "continuation_0_0_block", ("continuation_0_0"."row")."locals" AS "continuation_0_0_locals", ("continuation_0_0"."row")."values" AS "continuation_0_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_0_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_0_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_0_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/filter/property_mask.snap b/libs/@local/hashql/eval/tests/ui/postgres/filter/property_mask.snap index fa02cfcd116..62499ece400 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/filter/property_mask.snap +++ b/libs/@local/hashql/eval/tests/ui/postgres/filter/property_mask.snap @@ -141,13 +141,13 @@ LEFT OUTER JOIN "entity_has_right_entity" AS "entity_has_right_entity_0_0_5" = "entity_temporal_metadata_0_0_0"."web_id" AND "entity_has_right_entity_0_0_5"."entity_uuid" = "entity_temporal_metadata_0_0_0"."entity_uuid" -CROSS JOIN LATERAL ( - SELECT - ( - row(NULL, 1, ARRAY[]::int [], ARRAY[]::jsonb [])::continuation - ) AS "row" - OFFSET 0 -) AS "continuation_0_0" +CROSS JOIN LATERAL + ( + SELECT + ( + row(NULL, 1, ARRAY[]::int [], ARRAY[]::jsonb [])::continuation + ) AS "row" + ) AS "continuation_0_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) diff --git a/libs/@local/hashql/eval/tests/ui/postgres/filter/temporal_decision_time_interval.snap b/libs/@local/hashql/eval/tests/ui/postgres/filter/temporal_decision_time_interval.snap index fbd4395d67e..8b48f327b89 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/filter/temporal_decision_time_interval.snap +++ b/libs/@local/hashql/eval/tests/ui/postgres/filter/temporal_decision_time_interval.snap @@ -29,13 +29,13 @@ SELECT ("continuation_0_0"."row")."locals" AS "continuation_0_0_locals", ("continuation_0_0"."row")."values" AS "continuation_0_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL ( - SELECT - ( - ROW(NULL, 1, ARRAY[]::int [], ARRAY[]::jsonb [])::continuation - ) AS "row" - OFFSET 0 -) AS "continuation_0_0" +CROSS JOIN LATERAL + ( + SELECT + ( + ROW(NULL, 1, ARRAY[]::int [], ARRAY[]::jsonb [])::continuation + ) AS "row" + ) AS "continuation_0_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) diff --git a/libs/@local/hashql/eval/tests/ui/postgres/if-input-branches.stdout b/libs/@local/hashql/eval/tests/ui/postgres/if-input-branches.stdout index 87a155ae195..6ba1aaff3e0 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/if-input-branches.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/if-input-branches.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_2_0"."row")."block" AS "continuation_2_0_block", ("continuation_2_0"."row")."locals" AS "continuation_2_0_locals", ("continuation_2_0"."row")."values" AS "continuation_2_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT CASE WHEN ((($3::jsonb))::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 0 THEN (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($4::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 1 THEN (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($5::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) END AS "row" -OFFSET 0) AS "continuation_2_0" +CROSS JOIN LATERAL (SELECT CASE WHEN ((($3::jsonb))::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 0 THEN (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($4::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 1 THEN (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($5::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) END AS "row") AS "continuation_2_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_2_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/input-parameter-exists.stdout b/libs/@local/hashql/eval/tests/ui/postgres/input-parameter-exists.stdout index c486a8c507f..7f05969b7b5 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/input-parameter-exists.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/input-parameter-exists.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_3_0"."row")."block" AS "continuation_3_0_block", ("continuation_3_0"."row")."locals" AS "continuation_3_0_locals", ("continuation_3_0"."row")."values" AS "continuation_3_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT CASE WHEN ((($3::jsonb) IS NOT NULL)::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb) IS NOT NULL)::int) = 0 THEN (ROW(COALESCE(((1)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb) IS NOT NULL)::int) = 1 THEN (ROW(COALESCE(((($3::jsonb))::boolean), FALSE), NULL, NULL, NULL)::continuation) END AS "row" -OFFSET 0) AS "continuation_3_0" +CROSS JOIN LATERAL (SELECT CASE WHEN ((($3::jsonb) IS NOT NULL)::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb) IS NOT NULL)::int) = 0 THEN (ROW(COALESCE(((1)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb) IS NOT NULL)::int) = 1 THEN (ROW(COALESCE(((($3::jsonb))::boolean), FALSE), NULL, NULL, NULL)::continuation) END AS "row") AS "continuation_3_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_3_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/input-parameter-load.stdout b/libs/@local/hashql/eval/tests/ui/postgres/input-parameter-load.stdout index c4ada3afc65..a717f5af892 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/input-parameter-load.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/input-parameter-load.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_2_0"."row")."block" AS "continuation_2_0_block", ("continuation_2_0"."row")."locals" AS "continuation_2_0_locals", ("continuation_2_0"."row")."values" AS "continuation_2_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_2_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_2_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_2_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/let-binding-propagation.stdout b/libs/@local/hashql/eval/tests/ui/postgres/let-binding-propagation.stdout index c4ada3afc65..a717f5af892 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/let-binding-propagation.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/let-binding-propagation.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_2_0"."row")."block" AS "continuation_2_0_block", ("continuation_2_0"."row")."locals" AS "continuation_2_0_locals", ("continuation_2_0"."row")."values" AS "continuation_2_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_2_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_2_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_2_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/list-construction.stdout b/libs/@local/hashql/eval/tests/ui/postgres/list-construction.stdout index 1d748250778..e14d3546843 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/list-construction.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/list-construction.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_3_0"."row")."block" AS "continuation_3_0_block", ("continuation_3_0"."row")."locals" AS "continuation_3_0_locals", ("continuation_3_0"."row")."values" AS "continuation_3_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb(jsonb_build_array("entity_temporal_metadata_0_0_0"."entity_uuid", ($3::jsonb))) = to_jsonb(jsonb_build_array(($4::jsonb), "entity_temporal_metadata_0_0_0"."entity_uuid")))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_3_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb(jsonb_build_array("entity_temporal_metadata_0_0_0"."entity_uuid", ($3::jsonb))) = to_jsonb(jsonb_build_array(($4::jsonb), "entity_temporal_metadata_0_0_0"."entity_uuid")))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_3_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_3_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/logical-and-inputs.stdout b/libs/@local/hashql/eval/tests/ui/postgres/logical-and-inputs.stdout index 294cd646a29..40f26fdfacc 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/logical-and-inputs.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/logical-and-inputs.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_3_0"."row")."block" AS "continuation_3_0_block", ("continuation_3_0"."row")."locals" AS "continuation_3_0_locals", ("continuation_3_0"."row")."values" AS "continuation_3_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT CASE WHEN ((($3::jsonb))::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 0 THEN (ROW(COALESCE(((0)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 1 THEN (ROW(COALESCE(((($4::jsonb))::boolean), FALSE), NULL, NULL, NULL)::continuation) END AS "row" -OFFSET 0) AS "continuation_3_0" +CROSS JOIN LATERAL (SELECT CASE WHEN ((($3::jsonb))::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 0 THEN (ROW(COALESCE(((0)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 1 THEN (ROW(COALESCE(((($4::jsonb))::boolean), FALSE), NULL, NULL, NULL)::continuation) END AS "row") AS "continuation_3_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_3_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/minimal-select-no-extra-joins.stdout b/libs/@local/hashql/eval/tests/ui/postgres/minimal-select-no-extra-joins.stdout index e5c6bb7f053..e4fde5b7414 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/minimal-select-no-extra-joins.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/minimal-select-no-extra-joins.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_2_0"."row")."block" AS "continuation_2_0_block", ("continuation_2_0"."row")."locals" AS "continuation_2_0_locals", ("continuation_2_0"."row")."values" AS "continuation_2_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."web_id") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_2_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."web_id") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_2_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_2_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/mixed-sources-filter.stdout b/libs/@local/hashql/eval/tests/ui/postgres/mixed-sources-filter.stdout index 8a609b8857f..726b00ebc07 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/mixed-sources-filter.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/mixed-sources-filter.stdout @@ -4,10 +4,8 @@ SELECT ("continuation_0_0"."row")."block" AS "continuation_0_0_block", ("continu FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" INNER JOIN "entity_editions" AS "entity_editions_0_0_1" ON "entity_editions_0_0_1"."entity_edition_id" = "entity_temporal_metadata_0_0_0"."entity_edition_id" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((NOT("entity_editions_0_0_1"."archived"))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_1_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((NOT("entity_editions_0_0_1"."archived"))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_0_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_1_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_0_0"."row")."filter" IS NOT FALSE AND ("continuation_1_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/multiple-filters.stdout b/libs/@local/hashql/eval/tests/ui/postgres/multiple-filters.stdout index 10036eef5ac..de36124593b 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/multiple-filters.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/multiple-filters.stdout @@ -2,10 +2,8 @@ SELECT ("continuation_3_0"."row")."block" AS "continuation_3_0_block", ("continuation_3_0"."row")."locals" AS "continuation_3_0_locals", ("continuation_3_0"."row")."values" AS "continuation_3_0_values", ("continuation_4_0"."row")."block" AS "continuation_4_0_block", ("continuation_4_0"."row")."locals" AS "continuation_4_0_locals", ("continuation_4_0"."row")."values" AS "continuation_4_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_3_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."web_id") = to_jsonb(($4::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_4_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_3_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."web_id") = to_jsonb(($4::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_4_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_3_0"."row")."filter" IS NOT FALSE AND ("continuation_4_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/nested-if-input-branches.stdout b/libs/@local/hashql/eval/tests/ui/postgres/nested-if-input-branches.stdout index acbe1e351eb..0206650b4f0 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/nested-if-input-branches.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/nested-if-input-branches.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_2_0"."row")."block" AS "continuation_2_0_block", ("continuation_2_0"."row")."locals" AS "continuation_2_0_locals", ("continuation_2_0"."row")."values" AS "continuation_2_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT CASE WHEN ((($3::jsonb))::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 0 THEN (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($4::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 1 THEN CASE WHEN ((($5::jsonb))::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($5::jsonb))::int) = 0 THEN (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($6::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($5::jsonb))::int) = 1 THEN (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($7::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) END END AS "row" -OFFSET 0) AS "continuation_2_0" +CROSS JOIN LATERAL (SELECT CASE WHEN ((($3::jsonb))::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 0 THEN (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($4::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($3::jsonb))::int) = 1 THEN CASE WHEN ((($5::jsonb))::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($5::jsonb))::int) = 0 THEN (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($6::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($5::jsonb))::int) = 1 THEN (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($7::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) END END AS "row") AS "continuation_2_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_2_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/opaque-passthrough.stdout b/libs/@local/hashql/eval/tests/ui/postgres/opaque-passthrough.stdout index d53b113737a..929a5af3f3a 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/opaque-passthrough.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/opaque-passthrough.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_8_0"."row")."block" AS "continuation_8_0_block", ("continuation_8_0"."row")."locals" AS "continuation_8_0_locals", ("continuation_8_0"."row")."values" AS "continuation_8_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_8_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb("entity_temporal_metadata_0_0_0"."entity_uuid") = to_jsonb(($3::jsonb)))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_8_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_8_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/struct-construction.stdout b/libs/@local/hashql/eval/tests/ui/postgres/struct-construction.stdout index 82cb6f0d73f..39824bc3905 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/struct-construction.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/struct-construction.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_4_0"."row")."block" AS "continuation_4_0_block", ("continuation_4_0"."row")."locals" AS "continuation_4_0_locals", ("continuation_4_0"."row")."values" AS "continuation_4_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb(jsonb_build_object(($3::text), "entity_temporal_metadata_0_0_0"."entity_uuid", ($4::text), "entity_temporal_metadata_0_0_0"."web_id")) = to_jsonb(jsonb_build_object(($3::text), ($5::jsonb), ($4::text), ($6::jsonb))))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_4_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb(jsonb_build_object(($3::text), "entity_temporal_metadata_0_0_0"."entity_uuid", ($4::text), "entity_temporal_metadata_0_0_0"."web_id")) = to_jsonb(jsonb_build_object(($3::text), ($5::jsonb), ($4::text), ($6::jsonb))))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_4_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_4_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/tuple-construction.stdout b/libs/@local/hashql/eval/tests/ui/postgres/tuple-construction.stdout index 545cec8b6a8..1e3cfc00abb 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/tuple-construction.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/tuple-construction.stdout @@ -2,8 +2,7 @@ SELECT ("continuation_4_0"."row")."block" AS "continuation_4_0_block", ("continuation_4_0"."row")."locals" AS "continuation_4_0_locals", ("continuation_4_0"."row")."values" AS "continuation_4_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb(jsonb_build_array("entity_temporal_metadata_0_0_0"."entity_uuid", "entity_temporal_metadata_0_0_0"."web_id")) = to_jsonb(jsonb_build_array(($3::jsonb), ($4::jsonb))))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" -OFFSET 0) AS "continuation_4_0" +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((to_jsonb(jsonb_build_array("entity_temporal_metadata_0_0_0"."entity_uuid", "entity_temporal_metadata_0_0_0"."web_id")) = to_jsonb(jsonb_build_array(($3::jsonb), ($4::jsonb))))::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row") AS "continuation_4_0" WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_4_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/mir/src/lib.rs b/libs/@local/hashql/mir/src/lib.rs index dcbb5b2831a..ee755063446 100644 --- a/libs/@local/hashql/mir/src/lib.rs +++ b/libs/@local/hashql/mir/src/lib.rs @@ -23,12 +23,13 @@ iter_collect_into, likely_unlikely, maybe_uninit_fill, + maybe_uninit_uninit_array_transpose, option_into_flat_iter, step_trait, temporary_niche_types, try_trait_v2, variant_count, - maybe_uninit_uninit_array_transpose + iterator_try_reduce )] #![cfg_attr(test, feature( // Library Features diff --git a/libs/@local/hashql/mir/src/pass/analysis/callgraph/mod.rs b/libs/@local/hashql/mir/src/pass/analysis/callgraph/mod.rs index 8c26ca917bf..028911c12dc 100644 --- a/libs/@local/hashql/mir/src/pass/analysis/callgraph/mod.rs +++ b/libs/@local/hashql/mir/src/pass/analysis/callgraph/mod.rs @@ -228,6 +228,17 @@ impl CallGraph<'_, A> { Some(DefId::new(edge.source().as_u32())) } + + #[inline] + pub fn callers(&self, def: DefId) -> impl Iterator { + let node = NodeId::from_usize(def.as_usize()); + + self.inner.incoming_edges(node).map(move |edge| CallSite { + caller: DefId::new(edge.source().as_u32()), + kind: edge.data, + target: def, + }) + } } impl fmt::Display for CallGraph<'_, A> { diff --git a/libs/@local/hashql/mir/src/pass/analysis/data_dependency/resolve.rs b/libs/@local/hashql/mir/src/pass/analysis/data_dependency/resolve.rs index 254394474da..f57dc651b3d 100644 --- a/libs/@local/hashql/mir/src/pass/analysis/data_dependency/resolve.rs +++ b/libs/@local/hashql/mir/src/pass/analysis/data_dependency/resolve.rs @@ -181,15 +181,31 @@ fn traverse<'heap, A: Allocator + Clone>( } } -/// Attempts to resolve a block parameter by checking all predecessor edges. +/// Attempts to resolve a block parameter by checking all predecessor values. /// -/// A block parameter may receive values from multiple predecessor blocks. This function -/// traverses all [`Param`] edges and checks whether they resolve to the same source. -/// If all predecessors agree, resolution continues through that common source. +/// A block parameter may receive values from multiple predecessor blocks, either as +/// graph edges (when arguments are places) or constant bindings (when arguments are +/// constants). This function checks whether all non-cyclic predecessors, from both +/// sources, resolve to the same value. /// -/// Handles cycle detection: if we encounter a local already in the `visited` set, -/// we return [`Backtrack`] to unwind. The cycle root (where `visited` was first -/// initialized) catches the backtrack and returns [`Incomplete`]. +/// # Projection-aware consensus +/// +/// When the queried place has a projection suffix (e.g., resolving `x.0` where `x` is a +/// block parameter), consensus is checked on the *fully resolved* result per predecessor, +/// not on the partially resolved predecessor bases. This is necessary because different +/// predecessor locals can still agree on a projected field. +/// +/// For example, if predecessor A passes `(42, u)` and predecessor B passes `(42, v)`, +/// the bases disagree but `A.0 == B.0 == 42`. The algorithm resolves each predecessor +/// through the full projection suffix before comparing, so this case correctly yields +/// `Resolved(42)` rather than `Incomplete(x.0)`. +/// +/// # Cycle handling +/// +/// Cyclic predecessors ([`Backtrack`]) are filtered out before consensus checking. +/// Since [`Param`] edges are identity transfers, the value is fully determined by +/// the non-cyclic init edges. If only cyclic predecessors exist (no external source), +/// the cycle root returns [`Incomplete`] and non-root nodes propagate [`Backtrack`]. /// /// [`Param`]: EdgeKind::Param /// [`Backtrack`]: ResolutionResult::Backtrack @@ -198,54 +214,120 @@ fn resolve_params<'heap, A: Allocator + Clone>( mut state: ResolutionState<'_, '_, 'heap, A>, place: PlaceRef<'_, 'heap>, ) -> ControlFlow, Local> { - let mut edges = state.graph.outgoing_edges(place.local); - let Some(head) = edges.next() else { - unreachable!("caller must guarantee that at least one Param edge exists") - }; + let graph = state.graph; + + // Check whether graph Param edges exist (cycle detection is only relevant for graph edges, + // which are the only source of back-edges). + let has_graph_edges = graph.outgoing_edges(place.local).next().is_some(); // Cycle detection: if we've already visited this local, backtrack. - if let Some(visited) = &mut state.visited + if has_graph_edges + && let Some(visited) = &mut state.visited && !visited.insert(place.local) { return ControlFlow::Break(ResolutionResult::Backtrack); } - // Initialize cycle tracking if this is the first Param traversal. + // Initialize cycle tracking if this is the first Param traversal with graph edges. let mut owned_visited = None; - let visited_ref = state.visited.as_deref_mut().or_else(|| { - let mut set = DenseBitSet::new_empty(state.graph.graph.node_count()); - set.insert(place.local); - - owned_visited = Some(set); - owned_visited.as_mut() - }); + let visited_ref = if has_graph_edges { + state.visited.as_deref_mut().or_else(|| { + let mut set = DenseBitSet::new_empty(graph.graph.node_count()); + set.insert(place.local); + + owned_visited = Some(set); + owned_visited.as_mut() + }) + } else { + state.visited.as_deref_mut() + }; let mut rec_state = ResolutionState { - graph: state.graph, + graph, interner: state.interner, alloc: state.alloc.clone(), visited: visited_ref, }; - let first = traverse(rec_state.cloned(), place, head); + // Resolve all predecessor candidates and check consensus. + // + // When the queried place has projections (e.g., `x.field`), each predecessor is resolved + // through the full projection suffix before consensus comparison. If `traverse` returns + // `Continue(local)` (predecessor base resolved to a bare local), we call `resolve` on + // `local.projections` to complete the resolution. This ensures consensus is checked on + // the final value, not intermediate bases that may differ structurally but agree on the + // projected component. + // + // Cyclic predecessors (Backtrack) are skipped: since Param edges are identity transfers, + // the value is fully determined by the non-cyclic init edges. If only cyclic predecessors + // exist, we cannot resolve (the value has no external source). + let graph_edges = graph.outgoing_edges(place.local).map(|edge| { + let result = traverse(rec_state.cloned(), place, edge); + + match result { + // Predecessor resolved to a bare local, but the query has remaining projections. + // Finish resolving through the projection suffix so consensus compares final values. + ControlFlow::Continue(local) if !place.projections.is_empty() => { + ControlFlow::Break(resolve( + rec_state.cloned(), + PlaceRef { + local, + projections: place.projections, + }, + )) + } + ControlFlow::Continue(_) | ControlFlow::Break(_) => result, + } + }); + let constant_edges = graph + .constant_bindings + .iter_by_kind(place.local, EdgeKind::Param) + .map(|constant| { + ControlFlow::Break(ResolutionResult::Resolved(Operand::Constant(constant))) + }); - // Check consensus: all predecessors must resolve to the same result. - let all_agree = edges.all(|edge| traverse(rec_state.cloned(), place, edge) == first); + // `try_reduce` returns: + // `Some(Some(v))` when all predecessors agree on `v` + // `Some(None)` when the iterator is empty (unreachable: caller guarantees predecessors) + // `None` when the closure short-circuits (predecessors disagree) + let mut backtrack_occurred = false; + let consensus = graph_edges + .chain(constant_edges) + .filter(|candidate| { + if matches!(candidate, ControlFlow::Break(ResolutionResult::Backtrack)) { + backtrack_occurred = true; + return false; + } - if all_agree { - // If we initiated backtracking (owned_visited is Some) and got Backtrack, - // we are the cycle root and should treat this as incomplete. - let is_cycle_root = - first == ControlFlow::Break(ResolutionResult::Backtrack) && owned_visited.is_some(); + true + }) + .try_reduce(|lhs, rhs| (lhs == rhs).then_some(lhs)); - if !is_cycle_root { + match consensus { + // Predecessors agree on a value. + Some(Some(consensus)) => { // Clean up visited state before returning. if let Some(visited) = state.visited { visited.remove(place.local); } - return first; + return consensus; + } + + // All candidates were cyclic (no non-cyclic predecessors to determine the value). + // If we're not the cycle root, propagate Backtrack so the root can handle it. + Some(None) if backtrack_occurred && owned_visited.is_none() => { + if let Some(visited) = &mut state.visited { + visited.remove(place.local); + } + + return ControlFlow::Break(ResolutionResult::Backtrack); } + // Pure cycle at root: fall through to Incomplete. + Some(None) if backtrack_occurred => {} + Some(None) => unreachable!("caller must guarantee at least one Param predecessor exists"), + // Predecessors disagree. + None => {} } // Clean up visited state before returning incomplete. @@ -253,7 +335,7 @@ fn resolve_params<'heap, A: Allocator + Clone>( visited.remove(place.local); } - // Predecessors diverge or a cycle was detected; cannot resolve through this param. + // Non-cyclic predecessors diverge, or pure cycle at root. let mut projections = VecDeque::new_in(state.alloc.clone()); projections.extend(place.projections); @@ -263,53 +345,17 @@ fn resolve_params<'heap, A: Allocator + Clone>( })) } -/// Attempts to resolve a block parameter by checking constant bindings from all predecessors. -/// -/// This handles the case where a block parameter receives constant values from predecessor -/// blocks, but has no graph edges (only constant bindings with [`Param`] kind). The function -/// checks whether all predecessors provide the same constant value. -/// -/// Unlike [`resolve_params`], this function does not need cycle detection because it only -/// examines constant bindings, not graph edges that could form back-edges. -/// -/// # Returns -/// -/// - [`Resolved(Constant)`] if all predecessor constants agree on the same value -/// - [`Resolved(Place)`] if predecessors diverge (the place remains valid but has no constant) -/// -/// [`Param`]: EdgeKind::Param -/// [`Resolved(Constant)`]: ResolutionResult::Resolved -/// [`Resolved(Place)`]: ResolutionResult::Resolved -fn resolve_params_const<'heap, A: Allocator + Clone>( - state: &ResolutionState<'_, '_, 'heap, A>, - place: PlaceRef<'_, 'heap>, -) -> ResolutionResult<'heap, A> { - debug_assert!(place.projections.is_empty()); - let mut constants = state - .graph - .constant_bindings - .iter_by_kind(place.local, EdgeKind::Param); - let Some(head) = constants.next() else { - unreachable!("caller must guarantee that at least one Param edge exists") - }; - - let all_agree = constants.all(|constant| constant == head); - if all_agree { - ResolutionResult::Resolved(Operand::Constant(head)) - } else { - // We have finished (we have terminated on a param, which is divergent, therefore the place - // is still valid, just doesn't have a constant value) - ResolutionResult::Resolved(Operand::Place(Place::local(place.local))) - } -} - /// Resolves a place to its ultimate data source by traversing the dependency graph. /// /// Starting from `place`, this function follows edges in the dependency graph to find where /// the data ultimately originates. The algorithm handles three types of edges: /// /// - **[`Load`]**: Always followed transitively (a load has exactly one source) -/// - **[`Param`]**: Followed only if all predecessors agree on the same source (consensus) +/// - **[`Param`]**: Followed only if all predecessors agree on the same source (consensus). +/// Consensus is checked on fully resolved results: when the queried place has projections, each +/// predecessor is resolved through the complete projection suffix before comparison. This allows +/// resolution through φ-nodes where predecessor bases differ but the projected component agrees +/// (e.g., `(42, a)` and `(42, b)` agree on field `.0`). /// - **[`Index`]/[`Field`]**: Matched against projections to trace through aggregates /// /// Resolution terminates with: @@ -332,13 +378,10 @@ pub(crate) fn resolve<'heap, A: Allocator + Clone>( mut place: PlaceRef<'_, 'heap>, ) -> ResolutionResult<'heap, A> { // Scan outgoing edges to find Load and count Param edges. - let mut edges = 0_usize; let mut params = 0_usize; let mut load_edge = None; for edge in state.graph.outgoing_edges(place.local) { - edges += 1; - match edge.data.kind { EdgeKind::Load => load_edge = Some(edge), EdgeKind::Param => params += 1, @@ -355,26 +398,15 @@ pub(crate) fn resolve<'heap, A: Allocator + Clone>( } // Attempt to resolve through Param edges, if all predecessors agree. - // There are fundamentally two cases: - // - Either all graph edges are Param edges, or - // - all constant bindings are Param edges - if edges == 0 - && state - .graph - .constant_bindings - .find_by_kind(place.local, EdgeKind::Param) - .is_some() - { - return resolve_params_const(&state, place); - } + // Predecessors may arrive as graph edges (place arguments), constant bindings + // (constant arguments), or a mix of both. All sources are checked for consensus. + let has_param_constants = state + .graph + .constant_bindings + .find_by_kind(place.local, EdgeKind::Param) + .is_some(); - if params > 0 - && state - .graph - .constant_bindings - .find_by_kind(place.local, EdgeKind::Param) - .is_none() - { + if params > 0 || has_param_constants { place.local = tri!(resolve_params(state.cloned(), place)); } diff --git a/libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs b/libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs index b840c7435f0..ee724ed3cf6 100644 --- a/libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs +++ b/libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs @@ -300,6 +300,137 @@ fn param_cycle_detection() { ); } +/// Tests that a loop-carried parameter resolves through the non-cyclic init edge +/// when the back-edge just passes the value through unchanged. +/// +/// The init edge provides constant 42, the back-edge creates a cycle (x depends on x). +/// Since cyclic predecessors are identity transfers, the non-cyclic init edge determines +/// the value: x should resolve to 42. +#[test] +fn param_cycle_with_const_init() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl x: Int, cond: Int; + + bb0() { + cond = input.load! "cond"; + goto bb1(42); + }, + bb1(x) { + if cond then bb1(x) else bb2(x); + }, + bb2(x) { + return x; + } + }); + + assert_data_dependency( + "param_cycle_with_const_init", + &body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that a multi-node cycle with a constant init edge resolves correctly, +/// even when the node with the init edge is not the cycle root. +/// +/// The cycle is x -> y -> x (through bb1 -> bb2 -> bb1). The init edge provides +/// constant 42 to x from bb0. During resolution of y, x is encountered as a non-root +/// participant in the cycle. x must skip the cyclic Backtrack from y and use its +/// non-cyclic constant init edge to resolve to 42, which then propagates through y +/// and out to the consumer (result). +#[test] +fn param_cycle_multi_node_with_const_init() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl x: Int, y: Int, cond: Int, result: Int; + + bb0() { + cond = input.load! "cond"; + goto bb1(42); + }, + bb1(x) { + goto bb2(x); + }, + bb2(y) { + if cond then bb1(y) else bb3(y); + }, + bb3(result) { + return result; + } + }); + + assert_data_dependency( + "param_cycle_multi_node_with_const_init", + &body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that the visited set is cleaned up when non-cyclic predecessors disagree +/// inside another node's cycle resolution. +/// +/// y has a self-loop (creating a cycle). When resolving y, the cycle root tracks +/// visited locals. x is resolved inside y's resolution and has disagreeing predecessors +/// (constant 42 from bb0, opaque `input` from bb1). x must remove itself from the +/// visited set before returning Incomplete, otherwise later resolutions would see +/// false cycle detections. +#[test] +fn param_cycle_visited_cleanup_on_diverge() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl input: Int, x: Int, y: Int, cond: Int; + + bb0() { + input = input.load! "x"; + cond = input.load! "cond"; + goto bb3(42); + }, + bb1() { + goto bb3(input); + }, + bb3(x) { + goto bb4(x); + }, + bb4(y) { + if cond then bb4(y) else bb5(y); + }, + bb5(y) { + return y; + } + }); + + assert_data_dependency( + "param_cycle_visited_cleanup_on_diverge", + &body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + /// Tests constant propagation through edges. #[test] fn constant_propagation() { @@ -544,3 +675,194 @@ fn projection_prepending_opaque_source() { }, ); } + +/// Tests mixed Param resolution through nested tuple wrapping where predecessors provide +/// a mix of constants and projections that all resolve to the same value. +#[test] +fn load_param_mixed() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl _1: (Int), _3: Int, _4: (Int), _5: Int; + @proj _1_0 = _1.0: Int, _4_0 = _4.0: Int; + + bb0() { + goto bb2(42); + }, + bb1() { + _1 = tuple 42; + goto bb2(_1_0); + }, + bb2(_3) { + _4 = tuple _3; + goto bb4(_4_0); + }, + bb3() { + goto bb4(42); + }, + bb4(_5) { + return _5; + } + }); + + assert_data_dependency( + "load_param_mixed", + &body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that Param consensus resolves through projections when predecessors are +/// different tuples but the queried field is the same constant. +/// +/// Both paths construct different tuples (`a = (42, u)`, `b = (42, v)`) but the +/// `.0` field is the same constant `42` in both. Current algorithm compares the +/// tuple bases (`a` vs `b`), which disagree, so it returns `Incomplete(x.0)`. +/// Correct behavior: resolve `a.0` and `b.0` individually, find they both yield +/// `42`, and return `Resolved(42)`. +#[test] +fn param_consensus_projected_field_const() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl u: Int, v: Int, a: (Int, Int), b: (Int, Int), cond: Int, x: (Int, Int), result: Int; + @proj x_0 = x.0: Int; + + bb0() { + u = input.load! "u"; + v = input.load! "v"; + cond = input.load! "cond"; + a = tuple 42, u; + b = tuple 42, v; + if cond then bb1() else bb2(); + }, + bb1() { + goto bb3(a); + }, + bb2() { + goto bb3(b); + }, + bb3(x) { + result = load x_0; + return result; + } + }); + + assert_data_dependency( + "param_consensus_projected_field_const", + &body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that Param consensus resolves through projections when predecessors are +/// different tuples but the queried field is the same place. +/// +/// Both paths construct different tuples (`a = (src, u)`, `b = (src, v)`) but the +/// `.0` field is the same local `src` in both. Current algorithm compares the +/// tuple bases (`a` vs `b`), which disagree, so it returns `Incomplete(x.0)`. +/// Correct behavior: resolve `a.0` and `b.0` individually, find they both yield +/// `src`, and return `Resolved(src)`. +#[test] +fn param_consensus_projected_field_place() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl src: Int, u: Int, v: Int, a: (Int, Int), b: (Int, Int), cond: Int, x: (Int, Int), result: Int; + @proj x_0 = x.0: Int; + + bb0() { + src = input.load! "src"; + u = input.load! "u"; + v = input.load! "v"; + cond = input.load! "cond"; + a = tuple src, u; + b = tuple src, v; + if cond then bb1() else bb2(); + }, + bb1() { + goto bb3(a); + }, + bb2() { + goto bb3(b); + }, + bb3(x) { + result = load x_0; + return result; + } + }); + + assert_data_dependency( + "param_consensus_projected_field_place", + &body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} + +/// Tests that a cycle with a loop-invariant projected field resolves correctly. +/// +/// The cycle is `x -> x` via the back-edge in `bb1`. The init edge provides +/// `init = (src, other)`. The back-edge reconstructs `t = (x.0, other)`, +/// preserving `x.0` across iterations. So `x.0` is loop-invariant and should +/// resolve to `src`. Current algorithm compares `init` vs `t` as bases, which +/// disagree, yielding `Incomplete(x.0)`. Correct behavior: resolve the full +/// `init.0 = src` and see that `t.0 = x.0` is a cyclic identity, so the +/// non-cyclic init determines the answer. +#[test] +fn param_cycle_invariant_projected_field() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl src: Int, other: Int, init: (Int, Int), x: (Int, Int), t: (Int, Int), cond: Int, result: Int; + @proj x_0 = x.0: Int; + + bb0() { + src = input.load! "src"; + other = input.load! "other"; + cond = input.load! "cond"; + init = tuple src, other; + goto bb1(init); + }, + bb1(x) { + t = tuple x_0, other; + if cond then bb1(t) else bb2(x_0); + }, + bb2(result) { + return result; + } + }); + + assert_data_dependency( + "param_cycle_invariant_projected_field", + &body, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + ); +} diff --git a/libs/@local/hashql/mir/src/pass/analysis/dataflow/framework.rs b/libs/@local/hashql/mir/src/pass/analysis/dataflow/framework.rs index 481ee5bd027..69440bcdc84 100644 --- a/libs/@local/hashql/mir/src/pass/analysis/dataflow/framework.rs +++ b/libs/@local/hashql/mir/src/pass/analysis/dataflow/framework.rs @@ -63,6 +63,16 @@ pub enum Direction { Backward, } +impl Direction { + #[must_use] + pub const fn reverse(self) -> Self { + match self { + Self::Forward => Self::Backward, + Self::Backward => Self::Forward, + } + } +} + /// The results of a dataflow analysis after reaching a fixed point. /// /// Contains the computed abstract state at both entry and exit of each basic block. diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/mod.rs b/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/mod.rs index 2480b0c4aec..59fd581633b 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/mod.rs @@ -114,6 +114,21 @@ impl CyclicPlacementRegion<'_> { } } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) enum ConstraintSatisfactionMode { + Initial, + Adjustment, +} + +impl ConstraintSatisfactionMode { + const fn config(self) -> CostEstimationConfig { + match self { + Self::Initial => CostEstimationConfig::LOOP, + Self::Adjustment => CostEstimationConfig::TRIVIAL, + } + } +} + /// CSP solver for assigning targets within a cyclic placement region. /// /// Borrows the parent [`PlacementSolver`] for cost estimation and target resolution. @@ -124,6 +139,7 @@ pub(crate) struct ConstraintSatisfaction<'ctx, 'parent, 'alloc, A: Allocator, S: pub region: CyclicPlacementRegion<'alloc>, pub depth: usize, + pub mode: ConstraintSatisfactionMode, // Branch-and-bound state (only used when members.len() <= BNB_CUTOFF) cost_deltas: [ApproxCost; BNB_CUTOFF], @@ -136,6 +152,7 @@ impl<'ctx, 'parent, 'alloc, A: Allocator, S: BumpAllocator> /// Creates a new CSP solver for the given cyclic `region`. pub(crate) const fn new( solver: &'ctx mut PlacementSolver<'parent, 'alloc, A, S>, + mode: ConstraintSatisfactionMode, id: PlacementRegionId, region: CyclicPlacementRegion<'alloc>, ) -> Self { @@ -144,6 +161,7 @@ impl<'ctx, 'parent, 'alloc, A: Allocator, S: BumpAllocator> id, region, depth: 0, + mode, cost_deltas: [ApproxCost::ZERO; BNB_CUTOFF], cost_so_far: ApproxCost::ZERO, } @@ -332,7 +350,7 @@ impl<'ctx, 'parent, 'alloc, A: Allocator, S: BumpAllocator> self.region.blocks.swap(self.depth, self.depth + offset); let mut heap = CostEstimation { - config: CostEstimationConfig::LOOP, + config: self.mode.config(), solver: self.solver, determine_target: |block| { if let Some(member) = self.region.find_block(block) { @@ -373,7 +391,7 @@ impl<'ctx, 'parent, 'alloc, A: Allocator, S: BumpAllocator> target: TargetId, ) -> ApproxCost { let estimator = CostEstimation { - config: CostEstimationConfig::LOOP, + config: self.mode.config(), solver: self.solver, determine_target: |block| { self.region.find_block(block).map_or_else( @@ -520,7 +538,7 @@ impl<'ctx, 'parent, 'alloc, A: Allocator, S: BumpAllocator> self.region.blocks.swap(self.depth, self.depth + offset); let heap = CostEstimation { - config: CostEstimationConfig::LOOP, + config: self.mode.config(), solver: self.solver, determine_target: |block| { if let Some(member) = self.region.find_block(block) { diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/tests.rs b/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/tests.rs index e993a2a9474..bb7ebd4ec23 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/tests.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/tests.rs @@ -15,7 +15,7 @@ use crate::{ placement::solve::{ PlacementRegionId, PlacementSolverContext, condensation::PlacementRegionKind, - csp::ConstraintSatisfaction, + csp::{self, ConstraintSatisfaction}, tests::{ all_targets, bb, fix_block, make_block_costs, stmt_costs, target_set, terminators, }, @@ -87,7 +87,12 @@ fn narrow_restricts_successor_domain() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); fix_block(&mut csp, bb(0), I); @@ -137,7 +142,12 @@ fn narrow_restricts_predecessor_domain() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); fix_block(&mut csp, bb(0), I); @@ -187,7 +197,12 @@ fn narrow_to_empty_domain() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); fix_block(&mut csp, bb(0), I); @@ -240,7 +255,12 @@ fn narrow_multiple_edges_intersect() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); // Assign bb0 = I and narrow @@ -299,7 +319,12 @@ fn replay_narrowing_resets_then_repropagates() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); // Step 1: assign bb0 = I, narrow @@ -377,7 +402,12 @@ fn lower_bound_min_block_cost_per_block() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); // Fix bb0 at depth 0 @@ -432,7 +462,12 @@ fn lower_bound_min_transition_cost_per_edge() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); fix_block(&mut csp, bb(0), I); @@ -486,7 +521,12 @@ fn lower_bound_skips_self_loop_edges() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); csp.depth = 0; @@ -539,7 +579,12 @@ fn lower_bound_fixed_successor_uses_concrete_target() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); // Fix bb0 and bb2 (target=P), leaving bb1 unfixed @@ -595,7 +640,12 @@ fn lower_bound_all_fixed_returns_zero() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); // Fix both blocks @@ -650,7 +700,12 @@ fn mrv_selects_smallest_domain() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); csp.depth = 0; @@ -700,7 +755,12 @@ fn mrv_tiebreak_by_constraint_degree() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); csp.depth = 0; @@ -753,7 +813,12 @@ fn mrv_skips_fixed_blocks() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); // Fix bb0 at position 0 @@ -810,7 +875,12 @@ fn greedy_solves_two_block_loop() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); assert!(csp.run_greedy(&body)); @@ -864,7 +934,12 @@ fn greedy_rollback_finds_alternative() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); assert!(csp.run_greedy(&body)); @@ -921,7 +996,12 @@ fn greedy_fails_when_infeasible() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); assert!(!csp.run_greedy(&body)); @@ -978,7 +1058,12 @@ fn bnb_finds_optimal() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); assert!(csp.solve(&body)); // all-I = stmt(10+1+1) + trans(0) = 12 @@ -1033,7 +1118,12 @@ fn bnb_retains_ranked_solutions() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); assert!(csp.solve(&body)); @@ -1109,7 +1199,12 @@ fn bnb_pruning_preserves_optimal() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); assert!(csp.solve(&body)); // All blocks should get the same target (cost = 4) @@ -1165,7 +1260,12 @@ fn retry_returns_ranked_solutions_in_order() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); assert!(csp.solve(&body)); @@ -1235,7 +1335,12 @@ fn retry_exhausts_then_perturbs() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); assert!(csp.solve(&body)); // Only same-target transitions allowed, so valid assignments are (I,I) and (P,P). @@ -1302,7 +1407,12 @@ fn greedy_rollback_on_empty_heap() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); csp.seed(); assert!(csp.run_greedy(&body)); @@ -1372,7 +1482,12 @@ fn retry_perturbation_after_ranked_exhaustion() { }; let mut solver = data.build_in(&body, &heap); let (region_id, region) = take_cyclic(&mut solver); - let mut csp = ConstraintSatisfaction::new(&mut solver, region_id, region); + let mut csp = ConstraintSatisfaction::new( + &mut solver, + csp::ConstraintSatisfactionMode::Initial, + region_id, + region, + ); // solve() uses BnB (2 blocks ≤ BNB_CUTOFF=12), applies best solution assert!(csp.solve(&body)); diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/solve/mod.rs b/libs/@local/hashql/mir/src/pass/execution/placement/solve/mod.rs index 690b2033352..502ba34ac7e 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/solve/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/solve/mod.rs @@ -4,12 +4,14 @@ //! single-block regions and cyclic multi-block SCCs. [`CostEstimation`] ranks candidate targets //! for trivial regions, while [`ConstraintSatisfaction`] handles cyclic ones. //! -//! The forward pass processes regions in topological order, the backward pass in reverse for -//! refinement. When assignment fails, [`PlacementSolver::rewind`] walks backward to find a region -//! that can change its assignment. +//! The forward pass processes regions in topological order, assigning targets greedily. +//! Adjustment passes then alternate direction (backward, forward, backward, ...) until no +//! assignment changes, converging to a local minimum in which no single-region change reduces +//! cost. When assignment fails during the forward pass, [`PlacementSolver::rewind`] walks +//! backward to find a region that can change its assignment. //! //! Entry point: [`PlacementSolverContext::build_in`] constructs a [`PlacementSolver`], then -//! [`PlacementSolver::run_in`] executes both passes. +//! [`PlacementSolver::run_in`] executes the forward pass and iterates adjustment passes. use core::{alloc::Allocator, mem}; @@ -27,9 +29,12 @@ use crate::{ basic_block::{BasicBlockId, BasicBlockSlice, BasicBlockVec}, }, context::MirContext, - pass::execution::{ - ApproxCost, cost::BasicBlockCostVec, target::TargetId, - terminator_placement::TerminatorTransitionCostVec, + pass::{ + analysis::dataflow::framework::Direction, + execution::{ + ApproxCost, cost::BasicBlockCostVec, target::TargetId, + terminator_placement::TerminatorTransitionCostVec, + }, }, }; @@ -125,9 +130,9 @@ impl<'ctx, A: Allocator> PlacementSolverContext<'ctx, A> { /// Assigns execution targets to basic blocks by solving over the condensation graph. /// -/// Uses a two-pass approach: the forward pass assigns targets in topological order, the backward -/// pass refines them with full boundary context. Rewind-based backtracking recovers from -/// assignment failures in the forward pass. +/// The forward pass assigns targets in topological order. Adjustment passes then alternate +/// direction until convergence, refining assignments with progressively fuller boundary context. +/// Rewind-based backtracking recovers from assignment failures in the forward pass. // We need two allocators here, because the `BumpAllocator` trait does not carry a lifetime, but we // move `Copy` data into the bump allocator. pub(crate) struct PlacementSolver<'ctx, 'alloc, S1: Allocator, S2: BumpAllocator> { @@ -142,8 +147,8 @@ pub(crate) struct PlacementSolver<'ctx, 'alloc, S1: Allocator, S2: BumpAllocator } impl<'alloc, S1: Allocator, S: BumpAllocator> PlacementSolver<'_, 'alloc, S1, S> { - /// Runs the forward and backward passes, returning the chosen [`TargetId`] for each basic - /// block. + /// Runs the forward pass and iterates adjustment passes until convergence, returning the + /// chosen [`TargetId`] for each basic block. pub(crate) fn run_in<'heap, A: Allocator>( &mut self, context: &mut MirContext<'_, 'heap>, @@ -166,9 +171,16 @@ impl<'alloc, S1: Allocator, S: BumpAllocator> PlacementSolver<'_, 'alloc, S1, S> .diagnostics .push(unsatisfiable_placement(body.span, block_span, &failure)); } else { - // Only run the backward refinement pass if the forward pass succeeded — - // there is nothing to refine when blocks remain unassigned. - self.run_backwards_loop(body, ®ions); + // Iterate adjustment passes in alternating directions until convergence. + // Only entered when the forward pass succeeded — nothing to refine with + // unassigned blocks. + let mut has_changed = true; + let mut direction = Direction::Backward; + + while has_changed { + has_changed = self.run_adjustment(direction, body, ®ions); + direction = direction.reverse(); + } } // Collect the final assignments into the output vec. Unassigned blocks (from a @@ -221,7 +233,12 @@ impl<'alloc, S1: Allocator, S: BumpAllocator> PlacementSolver<'_, 'alloc, S1, S> self.condensation[region_id].kind = kind; } PlacementRegionKind::Cyclic(cyclic) => { - let mut csp = ConstraintSatisfaction::new(self, region_id, cyclic); + let mut csp = ConstraintSatisfaction::new( + self, + csp::ConstraintSatisfactionMode::Initial, + region_id, + cyclic, + ); if csp.retry(body) { // Found a perturbation — flush the new assignments, and resume. @@ -327,7 +344,12 @@ impl<'alloc, S1: Allocator, S: BumpAllocator> PlacementSolver<'_, 'alloc, S1, S> } PlacementRegionKind::Cyclic(cyclic) => { let members = cyclic.members; - let mut csp = ConstraintSatisfaction::new(self, region_id, cyclic); + let mut csp = ConstraintSatisfaction::new( + self, + csp::ConstraintSatisfactionMode::Initial, + region_id, + cyclic, + ); if !csp.solve(body) { let region = PlacementRegionKind::Cyclic(csp.region); @@ -362,27 +384,39 @@ impl<'alloc, S1: Allocator, S: BumpAllocator> PlacementSolver<'_, 'alloc, S1, S> Ok(()) } - /// Re-evaluates assignments in reverse topological order for refinement. - /// - /// Delegates to [`adjust_trivial`](Self::adjust_trivial) and - /// [`adjust_cyclic`](Self::adjust_cyclic). - fn run_backwards_loop(&mut self, body: &Body<'_>, regions: &[PlacementRegionId]) { + /// Re-evaluates assignments in the given `direction`, returning whether any assignment + /// changed. Called in alternating directions until convergence. + fn run_adjustment( + &mut self, + direction: Direction, + body: &Body<'_>, + regions: &[PlacementRegionId], + ) -> bool { debug_assert!(!regions.is_empty(), "at least the start block must exist"); - let mut ptr = regions.len(); - - while ptr > 0 { - ptr -= 1; + let mut iter = regions.iter(); + let mut changed = false; + + loop { + let Some(®ion_id) = (match direction { + Direction::Forward => iter.next(), + Direction::Backward => iter.next_back(), + }) else { + break changed; + }; - let region_id = regions[ptr]; let region = &mut self.condensation[region_id]; let kind = mem::replace(&mut region.kind, PlacementRegionKind::Unassigned); let kind = match kind { kind @ PlacementRegionKind::Trivial(TrivialPlacementRegion { block }) => { - self.adjust_trivial(body, region_id, block); + changed |= self.adjust_trivial(body, region_id, block); + kind + } + PlacementRegionKind::Cyclic(cyclic) => { + let (cyclic_changed, kind) = self.adjust_cyclic(body, region_id, cyclic); + changed |= cyclic_changed; kind } - PlacementRegionKind::Cyclic(cyclic) => self.adjust_cyclic(body, region_id, cyclic), PlacementRegionKind::Unassigned => { unreachable!( "previous iteration has not returned region {region_id:?} into the graph" @@ -402,7 +436,7 @@ impl<'alloc, S1: Allocator, S: BumpAllocator> PlacementSolver<'_, 'alloc, S1, S> body: &Body<'_>, region_id: PlacementRegionId, block: BasicBlockId, - ) { + ) -> bool { let estimator = CostEstimation { config: CostEstimationConfig::TRIVIAL, solver: self, @@ -422,11 +456,14 @@ impl<'alloc, S1: Allocator, S: BumpAllocator> PlacementSolver<'_, 'alloc, S1, S> let Some(elem) = heap.pop() else { // Re-estimation (unlikely) found no viable targets — keep the current assignment - return; + return false; }; if prev > elem.cost { self.targets[block] = Some(elem); + true + } else { + false } } @@ -439,19 +476,24 @@ impl<'alloc, S1: Allocator, S: BumpAllocator> PlacementSolver<'_, 'alloc, S1, S> body: &Body<'_>, region_id: PlacementRegionId, cyclic: CyclicPlacementRegion<'alloc>, - ) -> PlacementRegionKind<'alloc> { + ) -> (bool, PlacementRegionKind<'alloc>) { // Re-run with full boundary context — neighbor assignments may have changed since the // forward pass. - let mut csp = ConstraintSatisfaction::new(self, region_id, cyclic); + let mut csp = ConstraintSatisfaction::new( + self, + csp::ConstraintSatisfactionMode::Adjustment, + region_id, + cyclic, + ); if !csp.solve(body) { // New solve found nothing better — keep the forward-pass assignment - return PlacementRegionKind::Cyclic(csp.region); + return (false, PlacementRegionKind::Cyclic(csp.region)); } let region = csp.region; let prev_estimator = CostEstimation { - config: CostEstimationConfig::LOOP, + config: CostEstimationConfig::TRIVIAL, solver: self, determine_target: |block: BasicBlockId| self.targets[block], }; @@ -473,7 +515,7 @@ impl<'alloc, S1: Allocator, S: BumpAllocator> PlacementSolver<'_, 'alloc, S1, S> .sum(); let next_estimator = CostEstimation { - config: CostEstimationConfig::LOOP, + config: CostEstimationConfig::TRIVIAL, solver: self, determine_target: |block: BasicBlockId| { // Resolve SCC members from the candidate solution, everything else @@ -502,12 +544,14 @@ impl<'alloc, S1: Allocator, S: BumpAllocator> PlacementSolver<'_, 'alloc, S1, S> }) .sum(); + let mut changed = false; if prev_total_cost > next_total_cost { + changed = true; for block in &*region.blocks { self.targets[block.id] = Some(block.target); } } - PlacementRegionKind::Cyclic(region) + (changed, PlacementRegionKind::Cyclic(region)) } } diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/solve/tests.rs b/libs/@local/hashql/mir/src/pass/execution/placement/solve/tests.rs index ea875739fd1..d40b22a5e86 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/solve/tests.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/solve/tests.rs @@ -1031,7 +1031,7 @@ fn backward_pass_keeps_assignment_when_csp_fails() { let PlacementRegionKind::Cyclic(cyclic) = kind else { panic!("expected cyclic region for bb1"); }; - let result_kind = solver.adjust_cyclic(&body, scc_region_id, cyclic); + let (_, result_kind) = solver.adjust_cyclic(&body, scc_region_id, cyclic); solver.condensation[scc_region_id].kind = result_kind; // Targets must be unchanged — adjust_cyclic kept the existing assignment diff --git a/libs/@local/hashql/mir/src/pass/transform/inline/find.rs b/libs/@local/hashql/mir/src/pass/transform/inline/find.rs index 66efbc1cc12..103a3a97aa0 100644 --- a/libs/@local/hashql/mir/src/pass/transform/inline/find.rs +++ b/libs/@local/hashql/mir/src/pass/transform/inline/find.rs @@ -23,10 +23,8 @@ use crate::{ /// /// A callsite is eligible if: /// - It's a direct call (function is a constant `FnPtr`). -/// - Its target SCC has not already been inlined into this caller. -/// -/// The SCC check prevents cycles: once we've inlined a function (or any function -/// in its SCC) into a filter, we won't inline it again. +/// - It's not a self-call. +/// - Its target is not a loop breaker. pub(crate) struct FindCallsiteVisitor<'ctx, 'state, 'env, 'heap, A: Allocator> { /// The filter function we're finding callsites in. pub caller: DefId, @@ -53,10 +51,10 @@ impl<'heap, A: Allocator> Visitor<'heap> for FindCallsiteVisitor<'_, '_, '_, 'he return Ok(()); }; - let target_component = self.state.components.scc(ptr); - - // Skip if we've already inlined this SCC into this caller. - if self.state.inlined.contains(self.caller, target_component) { + // Skip self-calls and calls to loop breakers. Breakers are the cycle + // cut points: inlining them would reintroduce the recursion that + // breaker selection removed. + if ptr == self.caller || self.state.loop_breakers.contains(ptr) { return Ok(()); } diff --git a/libs/@local/hashql/mir/src/pass/transform/inline/loop_breaker.rs b/libs/@local/hashql/mir/src/pass/transform/inline/loop_breaker.rs new file mode 100644 index 00000000000..4f8dc405935 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/transform/inline/loop_breaker.rs @@ -0,0 +1,413 @@ +//! Loop-breaker selection for recursive SCCs. +//! +//! When functions form a mutually recursive group (SCC), the inliner cannot inline all +//! calls without diverging. This module selects which functions to mark as loop breakers: +//! calls to a breaker within its SCC are skipped, while calls from a breaker to non-breakers +//! are still inlined. This flattens most of the call chain without infinite expansion. +//! +//! The approach follows GHC's loop-breaker strategy (Peyton Jones & Marlow 2002): select +//! which *nodes* to mark as non-inlineable rather than which *edges* to cut. All edges +//! targeting a loop breaker become non-inlineable. This reduces the problem from feedback +//! arc set (NP-hard) to feedback vertex set, which is tractable for the small SCCs (at most +//! ~12 nodes) that appear in practice. +//! +//! # Algorithm +//! +//! [`LoopBreaker::run_in`] processes every non-trivial SCC (size > 1) in the call graph: +//! +//! 1. **Score** each member by inverse inlining value via [`LoopBreaker::score`]. Higher score = +//! less valuable to inline = better breaker candidate. +//! 2. **Select** breakers greedily: pick the highest-scored member (least valuable to inline), mark +//! it as a breaker, then check if the remaining members still contain a cycle +//! ([`LoopBreaker::has_cycle`]). Repeat until the remaining subgraph is acyclic. This produces a +//! sufficient (not necessarily minimal) feedback vertex set. +//! 3. **Reorder** the SCC members via [`LoopBreaker::order`]: non-breakers appear in DFS postorder +//! (callees before callers), followed by breakers. This ordering ensures that when a function is +//! processed, its non-breaker callees within the same SCC have already been optimized. +//! +//! The members slice is mutated in place so the caller can iterate it directly. +//! +//! # Scoring +//! +//! The breaker score (see [`InlineLoopBreakerConfig`]) combines: +//! +//! - **Body cost** (positive contribution): large functions are expensive to duplicate. +//! - **Caller count** (negative): functions with many call sites lose more inlining opportunities +//! when chosen as breakers. +//! - **Unique callsite** (negative): a single call site means zero duplication on inline. +//! - **Leaf status** (negative): leaves are safe, cheap inlining targets. +//! - **Inline directive**: `Never` maps to `+inf` (ideal breaker), `Always` to `-inf` (avoided +//! unless every other candidate has been exhausted). +//! +//! # Cycle detection +//! +//! After each breaker is selected, the remaining non-breaker subgraph is checked +//! for cycles using three-color DFS ([`TriColorDepthFirstSearch`]). The DFS runs +//! on the full [`CallGraph`] with an [`ignore_edge`] filter that restricts traversal +//! to non-breaker SCC members. State is accumulated across roots via +//! [`run_from`](TriColorDepthFirstSearch::run_from) so disconnected components +//! (which appear when breaker removal splits the subgraph) are all covered. +//! +//! # Postorder computation +//! +//! Once breakers are selected, the non-breaker members form a DAG. Their processing +//! order is computed as DFS postorder over a [`CallSubgraph`] that filters the +//! call graph to non-breaker members. Breaker members are appended after the +//! non-breakers. +//! +//! [`ignore_edge`]: TriColorVisitor::ignore_edge + +use core::{alloc::Allocator, iter, ops::ControlFlow}; + +use hashql_core::{ + graph::{ + DirectedGraph, Successors, + algorithms::{ + DepthFirstForestPostOrder, TriColorDepthFirstSearch, TriColorVisitor, + color::NodeColor, + tarjan::{Members, SccId}, + }, + }, + heap::BumpAllocator, + id::bit_vec::DenseBitSet, +}; + +use super::analysis::{BodyProperties, InlineDirective}; +use crate::{ + def::{DefId, DefIdSlice}, + pass::analysis::{CallGraph, CallKind}, +}; + +/// Configuration for loop-breaker selection within recursive SCCs. +/// +/// Controls the scoring function that determines which SCC members are selected +/// as loop breakers. Higher breaker scores indicate better breaker candidates +/// (less valuable to inline). +/// +/// # Scoring Formula +/// +/// ```text +/// score = cost_weight * body_cost +/// - caller_penalty * apply_caller_count +/// - unique_callsite_penalty (if exactly one callsite targets this function) +/// - leaf_penalty (if function has no outgoing calls) +/// ``` +/// +/// Functions with `InlineDirective::Never` get score `+inf` (ideal breakers). +/// Functions with `InlineDirective::Always` get score `-inf` (avoided unless +/// every other candidate has been exhausted). +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct InlineLoopBreakerConfig { + /// Weight applied to body cost. + /// + /// Large functions are expensive to duplicate at each call site, making them + /// good breaker candidates. + /// + /// Default: `1.0`. + pub cost_weight: f32, + + /// Penalty per apply-callsite caller. + /// + /// Functions called from many sites provide more inlining opportunities. + /// Selecting them as breakers loses those opportunities for every caller. + /// + /// Default: `5.0`. + pub caller_penalty: f32, + + /// Penalty for functions with exactly one callsite. + /// + /// A unique callsite means inlining causes zero code duplication, making + /// the function a poor breaker choice. + /// + /// Default: `15.0`. + pub unique_callsite_penalty: f32, + + /// Penalty for leaf functions. + /// + /// Leaves have no outgoing calls (except intrinsics) and cannot trigger + /// further inlining cascades, making them safe and valuable to inline. + /// + /// Default: `10.0`. + pub leaf_penalty: f32, +} + +impl Default for InlineLoopBreakerConfig { + fn default() -> Self { + Self { + cost_weight: 1.0, + caller_penalty: 5.0, + unique_callsite_penalty: 15.0, + leaf_penalty: 10.0, + } + } +} + +/// A view of the [`CallGraph`] induced on the non-breaker members of a single SCC. +/// +/// Both source and target are filtered: a node outside the non-breaker member set +/// has no successors, and edges targeting nodes outside it are dropped. +/// +/// [`node_count`](DirectedGraph::node_count) returns the full call graph domain so +/// that traversal algorithms size their bitsets correctly for the global `DefId` space. +struct CallSubgraph<'ctx, 'heap, A: Allocator> { + inner: &'ctx CallGraph<'heap, A>, + members: &'ctx [DefId], + breakers: &'ctx DenseBitSet, +} + +impl DirectedGraph for CallSubgraph<'_, '_, A> { + type Edge<'this> + = (DefId, DefId) + where + Self: 'this; + type EdgeId = (DefId, DefId); + type Node<'this> + = DefId + where + Self: 'this; + type NodeId = DefId; + + fn node_count(&self) -> usize { + // Must match the full DefId domain so that DenseBitSet/MixedBitSet + // in traversal algorithms are sized correctly for any DefId index. + self.inner.node_count() + } + + fn edge_count(&self) -> usize { + self.inner.edge_count() + } + + #[expect(unreachable_code)] + fn iter_nodes(&self) -> impl ExactSizeIterator> + DoubleEndedIterator { + unimplemented!(); + iter::empty() + } + + #[expect(unreachable_code)] + fn iter_edges(&self) -> impl ExactSizeIterator> + DoubleEndedIterator { + unimplemented!(); + iter::empty() + } +} + +impl Successors for CallSubgraph<'_, '_, A> { + type SuccIter<'this> + = impl Iterator + where + Self: 'this; + + fn successors(&self, node: Self::NodeId) -> Self::SuccIter<'_> { + let in_subgraph = self.members.contains(&node) && !self.breakers.contains(node); + + self.inner.successors(node).filter(move |&succ| { + in_subgraph && self.members.contains(&succ) && !self.breakers.contains(succ) + }) + } +} + +/// Entry point for loop-breaker selection and SCC reordering. +pub(crate) struct LoopBreaker<'ctx, 'heap, A: Allocator> { + pub config: InlineLoopBreakerConfig, + pub graph: &'ctx CallGraph<'heap, A>, + pub properties: &'ctx DefIdSlice>, + pub search: TriColorDepthFirstSearch<'ctx, CallGraph<'heap, A>, DefId, A>, +} + +impl LoopBreaker<'_, '_, A> { + /// Select loop breakers and reorder members for every non-trivial SCC. + /// + /// After this call, for each non-trivial SCC: + /// - A sufficient set of breakers has been selected to make the remainder acyclic. + /// - The member slice is reordered: non-breaker callees before their callers, breakers last. + /// + /// Returns a bitset of all selected breakers across every SCC. + pub(crate) fn run_in( + &mut self, + members: &mut Members, + scratch: &S, + ) -> DenseBitSet { + let mut breakers = DenseBitSet::new_empty(self.properties.len()); + + for (_, members) in members { + if members.len() < 2 { + continue; + } + + self.select_in(members, &mut breakers, scratch); + + #[expect( + clippy::debug_assert_with_mut_call, + reason = "the call only resets and uses the search state, therefore is safe to be \ + mut" + )] + { + debug_assert!( + !self.has_cycle(members, &breakers), + "select_in must produce an acyclic remainder" + ); + } + + let postorder = self.order(members, &breakers, scratch); + members.copy_from_slice(postorder); + } + + breakers + } + + /// Greedily select breakers for a single non-trivial SCC. + /// + /// Postcondition: the non-breaker remainder of `members` is acyclic. + fn select_in( + &mut self, + members: &[DefId], + breakers: &mut DenseBitSet, + scratch: &B, + ) { + // Sort descending: highest breaker score (least valuable to inline) first. + let scored = scratch + .allocate_slice_uninit(members.len()) + .write_with(|index| (members[index], self.score(members[index]))); + scored.sort_by(|(_, lhs_score), (_, rhs_score)| lhs_score.total_cmp(rhs_score).reverse()); + + // The full SCC is cyclic by definition, so we always need at least one breaker. + for &(candidate, _) in &*scored { + breakers.insert(candidate); + + if !self.has_cycle(members, breakers) { + break; + } + } + } + + /// Returns whether the non-breaker members still contain a cycle. + fn has_cycle(&mut self, members: &[DefId], breakers: &DenseBitSet) -> bool { + struct SubgraphCycleDetector<'ctx> { + members: &'ctx [DefId], + breakers: &'ctx DenseBitSet, + } + + impl TriColorVisitor for SubgraphCycleDetector<'_> + where + G: DirectedGraph, + { + type Result = ControlFlow<()>; + + fn node_examined(&mut self, _: DefId, before: Option) -> Self::Result { + match before { + Some(NodeColor::Gray) => ControlFlow::Break(()), + _ => ControlFlow::Continue(()), + } + } + + fn ignore_edge(&mut self, source: DefId, target: DefId) -> bool { + self.breakers.contains(source) + || self.breakers.contains(target) + || !self.members.contains(&source) + || !self.members.contains(&target) + } + } + + let mut detector = SubgraphCycleDetector { members, breakers }; + + // Accumulate visited state across roots: breaker removal can disconnect + // the subgraph, and a cycle in an unreachable component would be missed + // by a single-root search. + self.search.reset(); + for &member in members { + if breakers.contains(member) { + continue; + } + + if self.search.run_from(member, &mut detector).is_break() { + return true; + } + } + + false + } + + /// Compute the breaker score for a single function. + /// + /// Higher score = better breaker candidate (less valuable to inline). + /// See [`InlineLoopBreakerConfig`] for the formula and weight descriptions. + #[expect(clippy::cast_precision_loss)] + fn score(&self, body: DefId) -> f32 { + let props = &self.properties[body]; + + match props.directive { + InlineDirective::Never => return f32::INFINITY, + InlineDirective::Always => return f32::NEG_INFINITY, + InlineDirective::Heuristic => {} + } + + let caller_count = self + .graph + .callers(body) + .filter(|cs| matches!(cs.kind, CallKind::Apply(_))) + .count(); + + let mut score = self.config.cost_weight * props.cost; + score = self + .config + .caller_penalty + .mul_add(-(caller_count as f32), score); + + if self.graph.unique_caller(body).is_some() { + score -= self.config.unique_callsite_penalty; + } + + if props.is_leaf { + score -= self.config.leaf_penalty; + } + + score + } + + /// Compute the processing order for a non-trivial SCC. + /// + /// Returns non-breaker members ordered so that callees appear before their + /// callers, followed by breaker members. + #[expect(unsafe_code)] + fn order<'alloc, S: BumpAllocator>( + &self, + members: &[DefId], + breakers: &DenseBitSet, + alloc: &'alloc S, + ) -> &'alloc [DefId] { + let subgraph = CallSubgraph { + inner: self.graph, + members, + breakers, + }; + + let mut index = 0; + let order = alloc.allocate_slice_uninit(members.len()); + + // The forest traversal covers the full DefId domain (since node_count + // must match the DefId index space for bitset sizing). Non-member nodes + // have no successors in the induced subgraph and yield as isolated + // nodes, so we filter them out. + for node in DepthFirstForestPostOrder::new(&subgraph) { + if !breakers.contains(node) && members.contains(&node) { + order[index].write(node); + index += 1; + } + } + + // Breakers last, in original order. + for &member in members { + if breakers.contains(member) { + order[index].write(member); + index += 1; + } + } + + debug_assert_eq!(index, members.len()); + + // SAFETY: All `members.len()` elements are initialized: + // - The forest traversal yields every non-breaker member exactly once (reachable in the + // full domain, filtered to SCC members). + // - The final loop writes all breaker members. + unsafe { order.assume_init_mut() } + } +} diff --git a/libs/@local/hashql/mir/src/pass/transform/inline/mod.rs b/libs/@local/hashql/mir/src/pass/transform/inline/mod.rs index ce26d8bc987..fca2f6875f2 100644 --- a/libs/@local/hashql/mir/src/pass/transform/inline/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/inline/mod.rs @@ -17,12 +17,15 @@ //! # Normal Phase //! //! For non-filter functions, the normal phase: -//! 1. Processes SCCs in dependency order (callees before callers). -//! 2. For each callsite, computes a score using [`InlineHeuristics::score`]. -//! 3. Selects candidates with positive scores, limited by per-caller budget. -//! 4. Updates caller costs after inlining to prevent cascade explosions. -//! -//! Recursive calls (same SCC) are never inlined to prevent infinite expansion. +//! 1. Selects loop breakers for each non-trivial SCC (see [`loop_breaker`]). +//! 2. Processes SCCs in dependency order (callees before callers). Within non-trivial SCCs, +//! non-breaker members are processed in postorder of the breaker-removed DAG, then breaker +//! members. +//! 3. For each callsite, computes a score using [`InlineHeuristics::score`]. +//! 4. Calls to loop breakers within their SCC are skipped. Calls to non-breakers within the same +//! SCC are eligible for inlining. +//! 5. Selects candidates with positive scores, limited by per-caller budget. +//! 6. Updates caller costs after inlining to prevent cascade explosions. //! //! # Aggressive Phase //! @@ -30,7 +33,7 @@ //! aggressive inlining to fully flatten the filter logic. The aggressive phase: //! 1. Iterates up to `aggressive_inline_cutoff` times per filter. //! 2. On each iteration, inlines all eligible callsites found in the filter. -//! 3. Tracks which SCCs have been inlined to prevent cycles. +//! 3. Calls to loop breakers and self-calls are skipped to prevent cycles. //! 4. Emits a diagnostic if the cutoff is reached. //! //! # Budget System @@ -50,26 +53,24 @@ use alloc::collections::BinaryHeap; use core::{alloc::Allocator, cmp, mem}; use hashql_core::{ - graph::{ - DirectedGraph as _, - algorithms::{ - Tarjan, - tarjan::{SccId, StronglyConnectedComponents}, - }, + graph::algorithms::{ + Tarjan, TriColorDepthFirstSearch, + tarjan::{Members, SccId, StronglyConnectedComponents}, }, heap::{BumpAllocator, Heap}, - id::{ - Id as _, IdSlice, - bit_vec::{DenseBitSet, SparseBitMatrix}, - }, + id::{Id as _, IdSlice, bit_vec::DenseBitSet}, span::SpanId, }; -pub use self::{analysis::InlineCostEstimationConfig, heuristics::InlineHeuristicsConfig}; +pub use self::{ + analysis::InlineCostEstimationConfig, heuristics::InlineHeuristicsConfig, + loop_breaker::InlineLoopBreakerConfig, +}; use self::{ analysis::{BodyAnalysis, BodyProperties, CostEstimationResidual}, find::FindCallsiteVisitor, heuristics::InlineHeuristics, + loop_breaker::LoopBreaker, rename::RenameVisitor, }; use crate::{ @@ -100,6 +101,7 @@ mod find; mod heuristics; mod rename; +mod loop_breaker; #[cfg(test)] mod tests; @@ -141,9 +143,11 @@ pub struct InlineConfig { pub cost: InlineCostEstimationConfig, /// Thresholds and bonuses for scoring callsites. pub heuristics: InlineHeuristicsConfig, + /// Configuration for loop-breaker selection in recursive SCCs. + pub loop_breaker: InlineLoopBreakerConfig, /// Multiplier for computing per-caller budget. /// - /// Budget = `heuristics.max × budget_multiplier`. + /// Budget = `heuristics.max * budget_multiplier`. /// Limits how much code can be inlined into a single function. /// /// Default: `2.0` (budget of 120 with default max of 60). @@ -163,6 +167,7 @@ impl Default for InlineConfig { Self { cost: InlineCostEstimationConfig::default(), heuristics: InlineHeuristicsConfig::default(), + loop_breaker: InlineLoopBreakerConfig::default(), budget_multiplier: 2.0, aggressive_inline_cutoff: 16, } @@ -198,39 +203,38 @@ struct InlineState<'ctx, 'state, 'env, 'heap, A: Allocator> { /// Functions that require aggressive inlining (filter closures). filters: DenseBitSet, - /// Tracks which SCCs have been inlined into each function. + /// Functions selected as loop breakers within their SCC. /// - /// Used to prevent cycles during aggressive inlining: once an SCC - /// has been inlined into a filter, it won't be inlined again. - inlined: SparseBitMatrix, + /// Calls to a breaker within its SCC are skipped during inlining. + /// Calls from a breaker to non-breakers are still inlined. + loop_breakers: DenseBitSet, // cost estimation properties costs: CostEstimationResidual<'heap, A>, /// SCC membership for cycle detection. components: StronglyConnectedComponents, + component_members: Option>, global: &'ctx mut GlobalTransformState<'state>, } impl<'heap, A: Allocator> InlineState<'_, '_, '_, 'heap, A> { - /// Collect all non-recursive callsites for aggressive inlining. + /// Collect all callsites for aggressive inlining. /// /// Used for filter functions which bypass normal heuristics. - /// Records inlined SCCs to prevent cycles in subsequent iterations. - fn collect_all_callsites(&mut self, body: DefId, mem: &mut InlineStateMemory) { + /// Self-calls are excluded to prevent panics in `get_disjoint_mut`. + fn collect_all_callsites(&self, body: DefId, mem: &mut InlineStateMemory) { let component = self.components.scc(body); self.graph .apply_callsites(body) - .filter(|callsite| self.components.scc(callsite.target) != component) + .filter(|callsite| { + callsite.target != body + && (self.components.scc(callsite.target) != component + || !self.loop_breakers.contains(callsite.target)) + }) .collect_into(&mut mem.callsites); - - self.inlined.insert(body, component); - for callsite in &mem.callsites { - self.inlined - .insert(body, self.components.scc(callsite.target)); - } } /// Collect callsites using heuristic scoring and budget. @@ -260,7 +264,13 @@ impl<'heap, A: Allocator> InlineState<'_, '_, '_, 'heap, A> { let candidates = &mut mem.candidates; for callsite in self.graph.apply_callsites(body) { - if self.components.scc(callsite.target) == component { + // Within an SCC, only skip calls to loop breakers (they break the cycle). + // Calls to non-breakers within the SCC are eligible because we're now inside of a DAG. + let same_scc = self.components.scc(callsite.target) == component; + if same_scc && self.loop_breakers.contains(callsite.target) { + continue; + } + if callsite.target == body { continue; } @@ -536,15 +546,25 @@ impl Inline { let tarjan = Tarjan::new_in(&graph, &self.alloc); let components = tarjan.run(); + let mut component_members = components.members_in(&self.alloc); + + let mut loop_breaker = LoopBreaker { + config: self.config.loop_breaker, + graph: &graph, + properties: &costs.properties, + search: TriColorDepthFirstSearch::new_in(&graph, &self.alloc), + }; + let loop_breakers = loop_breaker.run_in(&mut component_members, &self.alloc); InlineState { config: self.config, filters, - inlined: SparseBitMatrix::new_in(components.node_count(), &self.alloc), + loop_breakers, interner, graph, costs, components, + component_members: Some(component_members), global: state, } } @@ -552,18 +572,22 @@ impl Inline { /// Run the normal inlining phase. /// /// Processes SCCs in dependency order (callees before callers) so that - /// cost updates propagate correctly. + /// cost updates propagate correctly. Within non-trivial SCCs, non-breaker + /// members are processed in postorder (callees before callers in the + /// breaker-removed DAG), followed by breaker members. fn normal<'heap, 'alloc>( - &self, state: &mut InlineState<'_, '_, '_, 'heap, &'alloc A>, bodies: &mut IdSlice>, mem: &mut InlineStateMemory<&'alloc A>, ) -> Changed { - let members = state.components.members_in(&self.alloc); - let mut any_changed = Changed::No; - for scc in members.sccs() { - for &id in members.of(scc) { + let component_members = state + .component_members + .take() + .unwrap_or_else(|| panic!("scc component members have been taken twice")); + + for (_, scc_members) in &component_members { + for &id in scc_members { let changed = state.run(bodies, id, mem); any_changed |= changed; state.global.mark(id, changed); @@ -605,9 +629,6 @@ impl Inline { mem.callsites .sort_unstable_by(|lhs, rhs| lhs.kind.cmp(&rhs.kind).reverse()); for callsite in mem.callsites.drain(..) { - let target_component = state.components.scc(callsite.target); - state.inlined.insert(filter, target_component); - state.inline(bodies, callsite); } @@ -637,7 +658,7 @@ impl<'env, 'heap, A: BumpAllocator> GlobalTransformPass<'env, 'heap> for Inline< let mut mem = InlineStateMemory::new(&self.alloc); let mut changed = Changed::No; - changed |= self.normal(&mut state, bodies, &mut mem); + changed |= Self::normal(&mut state, bodies, &mut mem); changed |= self.aggressive(context, &mut state, bodies, &mut mem); changed } diff --git a/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs b/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs index fca219e0be6..632256db6af 100644 --- a/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs @@ -6,7 +6,9 @@ use std::path::PathBuf; use bstr::ByteVec as _; use hashql_core::{ + graph::algorithms::{Tarjan, TriColorDepthFirstSearch, tarjan::SccId}, heap::Heap, + id::{Id as _, bit_vec::DenseBitSet}, pretty::Formatter, symbol::sym, r#type::{TypeFormatter, TypeFormatterOptions, environment::Environment}, @@ -16,6 +18,7 @@ use insta::{Settings, assert_snapshot}; use super::{ BodyAnalysis, Inline, InlineConfig, InlineCostEstimationConfig, InlineHeuristicsConfig, + InlineLoopBreakerConfig, loop_breaker::LoopBreaker, }; use crate::{ body::{Body, Source, basic_block::BasicBlockId, location::Location}, @@ -1113,3 +1116,616 @@ fn heuristics_no_unique_callsite_bonus_multiple_calls() { // 10 + 5 - 30 * 0.875 = 15.0 - 26.25 = -11.25 assert!((heuristics.score(default_callsite()) - (-11.25)).abs() < f32::EPSILON); } + +/// Two mutually recursive functions A and B, with a caller C. +/// +/// A is large (many statements), B is small (single return). The loop breaker +/// should select A (high cost = good breaker), leaving B as a non-breaker. +/// When the inliner processes the SCC: +/// - B's call to A (the breaker) is skipped. +/// - A's call to B (non-breaker) is inlined into A. +/// - C's call to either is cross-SCC and inlined normally. +#[test] +fn loop_breaker_mutual_recursion() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + let c_id = DefId::new(2); + + // A: large function that calls B + let a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, cond: Bool, tmp1: Int, tmp2: Int, tmp3: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { goto bb3(n); }, + bb2() { + tmp1 = bin.+ n 1; + tmp2 = bin.+ tmp1 2; + tmp3 = apply (b_id), tmp2; + goto bb3(tmp3); + }, + bb3(result) { return result; } + }); + + // B: small function that calls A + let b = body!(interner, env; fn@b_id/1 -> Int { + decl x: Int, result: Int; + bb0() { + result = apply (a_id), x; + return result; + } + }); + + // C: external caller + let c = simple_caller(&interner, &env, c_id, b_id); + + let mut bodies = [a, b, c]; + + assert_inline_pass( + "loop_breaker_mutual_recursion", + &mut bodies, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + InlineConfig::default(), + ); +} + +/// Verifies breaker selection picks the highest-cost member. +/// +/// Given SCC {A, B} where A has high cost and B has low cost, +/// A should be selected as the breaker (high cost = good breaker). +#[test] +fn loop_breaker_selects_highest_cost() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + + // A: expensive, calls B + let a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, t3: Int, t4: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { goto bb3(n); }, + bb2() { + t1 = bin.+ n 1; + t2 = bin.+ t1 2; + t3 = bin.- t2 3; + t4 = apply (b_id), t3; + goto bb3(t4); + }, + bb3(result) { return result; } + }); + + // B: cheap, calls A + let b = body!(interner, env; fn@b_id/1 -> Int { + decl x: Int, result: Int; + bb0() { + result = apply (a_id), x; + return result; + } + }); + + let bodies = [a, b]; + let bodies_slice = DefIdSlice::from_raw(&bodies); + + let graph = CallGraph::analyze_in(bodies_slice, &heap); + let mut analysis = BodyAnalysis::new( + &graph, + bodies_slice, + InlineCostEstimationConfig::default(), + &heap, + ); + for body in &bodies { + analysis.run(body); + } + let costs = analysis.finish(); + + let tarjan: Tarjan<_, _, SccId, _, _> = Tarjan::new_in(&graph, &heap); + let components = tarjan.run(); + let mut members = components.members_in(&heap); + + let mut breaker = LoopBreaker { + config: InlineLoopBreakerConfig::default(), + graph: &graph, + properties: &costs.properties, + search: TriColorDepthFirstSearch::new_in(&graph, &heap), + }; + let breakers = breaker.run_in(&mut members, &heap); + + // A has higher cost, so A should be the breaker. + assert!( + breakers.contains(a_id), + "expected A (high cost) to be selected as breaker" + ); + assert!( + !breakers.contains(b_id), + "expected B (low cost) to not be a breaker" + ); +} + +/// Three-way mutual recursion: A -> B -> C -> A. +/// +/// One breaker should be sufficient to break the single cycle. +/// The member with highest cost should be selected. +#[test] +fn loop_breaker_three_way_cycle() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + let c_id = DefId::new(2); + + // A: expensive + let a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, t3: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { goto bb3(n); }, + bb2() { + t1 = bin.+ n 1; + t2 = bin.+ t1 2; + t3 = apply (b_id), t2; + goto bb3(t3); + }, + bb3(result) { return result; } + }); + + // B: medium + let b = body!(interner, env; fn@b_id/1 -> Int { + decl x: Int, t1: Int, result: Int; + bb0() { + t1 = bin.- x 1; + result = apply (c_id), t1; + return result; + } + }); + + // C: cheap, calls A + let c = body!(interner, env; fn@c_id/1 -> Int { + decl x: Int, result: Int; + bb0() { + result = apply (a_id), x; + return result; + } + }); + + let bodies = [a, b, c]; + let bodies_slice = DefIdSlice::from_raw(&bodies); + + let graph = CallGraph::analyze_in(bodies_slice, &heap); + let mut analysis = BodyAnalysis::new( + &graph, + bodies_slice, + InlineCostEstimationConfig::default(), + &heap, + ); + for body in &bodies { + analysis.run(body); + } + let costs = analysis.finish(); + + let tarjan: Tarjan<_, _, SccId, _, _> = Tarjan::new_in(&graph, &heap); + let components = tarjan.run(); + let mut members = components.members_in(&heap); + + let mut breaker = LoopBreaker { + config: InlineLoopBreakerConfig::default(), + graph: &graph, + properties: &costs.properties, + search: TriColorDepthFirstSearch::new_in(&graph, &heap), + }; + let breakers = breaker.run_in(&mut members, &heap); + + // Exactly one breaker should suffice for a single cycle. + assert_eq!( + breakers.count(), + 1, + "expected exactly 1 breaker for a 3-node cycle, got {}", + breakers.count() + ); + + // A has the highest cost. + assert!( + breakers.contains(a_id), + "expected A (highest cost) to be the breaker" + ); +} + +/// Helper: run loop-breaker selection on a set of bodies and return the breaker bitset +/// and the reordered members. +fn run_loop_breaker<'heap>( + bodies: &[Body<'heap>], + heap: &'heap Heap, +) -> (DenseBitSet, Vec>) { + let bodies_slice = DefIdSlice::from_raw(bodies); + + let graph = CallGraph::analyze_in(bodies_slice, heap); + let mut analysis = BodyAnalysis::new( + &graph, + bodies_slice, + InlineCostEstimationConfig::default(), + heap, + ); + for body in bodies { + analysis.run(body); + } + let costs = analysis.finish(); + + let tarjan: Tarjan<_, _, SccId, _, _> = Tarjan::new_in(&graph, heap); + let components = tarjan.run(); + let mut members = components.members_in(heap); + + let mut breaker = LoopBreaker { + config: InlineLoopBreakerConfig::default(), + graph: &graph, + properties: &costs.properties, + search: TriColorDepthFirstSearch::new_in(&graph, heap), + }; + let breakers = breaker.run_in(&mut members, heap); + + let scc_orders: Vec> = members + .iter() + .filter(|(_, m)| m.len() > 1) + .map(|(_, m)| m.to_vec()) + .collect(); + + (breakers, scc_orders) +} + +use core::ops::ControlFlow; + +use hashql_core::graph::algorithms::color::NodeColor; + +struct RemainderCycleDetector<'a> { + members: &'a [DefId], + breakers: &'a DenseBitSet, +} + +impl> + hashql_core::graph::algorithms::TriColorVisitor for RemainderCycleDetector<'_> +{ + type Result = ControlFlow<()>; + + fn node_examined(&mut self, _: DefId, before: Option) -> Self::Result { + match before { + Some(NodeColor::Gray) => ControlFlow::Break(()), + _ => ControlFlow::Continue(()), + } + } + + fn ignore_edge(&mut self, source: DefId, target: DefId) -> bool { + self.breakers.contains(source) + || self.breakers.contains(target) + || !self.members.contains(&source) + || !self.members.contains(&target) + } +} + +/// SCC with two independent 2-cycles joined into one component. +/// Requires at least two breakers. +/// +/// Structure: A <-> B, C <-> D, with B -> C and D -> A connecting them. +#[test] +fn loop_breaker_multi_breaker_scc() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + let c_id = DefId::new(2); + let d_id = DefId::new(3); + + // A: calls B + let a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (b_id), n; + return result; + } + }); + + // B: calls A and C + let b = body!(interner, env; fn@b_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { + t1 = apply (a_id), n; + goto bb3(t1); + }, + bb2() { + t2 = apply (c_id), n; + goto bb3(t2); + }, + bb3(result) { return result; } + }); + + // C: calls D + let c = body!(interner, env; fn@c_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (d_id), n; + return result; + } + }); + + // D: calls C and A (completing both sub-cycles) + let d = body!(interner, env; fn@d_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { + t1 = apply (c_id), n; + goto bb3(t1); + }, + bb2() { + t2 = apply (a_id), n; + goto bb3(t2); + }, + bb3(result) { return result; } + }); + + let bodies = [a, b, c, d]; + let (breakers, _) = run_loop_breaker(&bodies, &heap); + + // Two overlapping sub-cycles (A<->B and C<->D) need exactly 2 breakers: + // no single node participates in both cycles. + assert_eq!( + breakers.count(), + 2, + "expected exactly 2 breakers, got {}", + breakers.count() + ); + + // Verify the remainder is actually acyclic. + let bodies_slice = DefIdSlice::from_raw(&bodies); + let graph = CallGraph::analyze_in(bodies_slice, &heap); + let mut search = TriColorDepthFirstSearch::new_in(&graph, &heap); + let mut cycle_found = false; + + let all_members: Vec = (0..bodies.len()).map(DefId::from_usize).collect(); + let mut detector = RemainderCycleDetector { + members: &all_members, + breakers: &breakers, + }; + + search.reset(); + for &member in &all_members { + if !breakers.contains(member) && search.run_from(member, &mut detector).is_break() { + cycle_found = true; + break; + } + } + + assert!( + !cycle_found, + "remainder after breaker selection must be acyclic" + ); +} + +/// Ordering test with 3 non-breakers forming a chain. +/// +/// SCC: {breaker, X, Y, W} where breaker removal leaves X -> Y -> W. +/// Postorder must satisfy: W before Y, Y before X. +#[test] +fn loop_breaker_ordering_chain() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let breaker_id = DefId::new(0); + let x_id = DefId::new(1); + let y_id = DefId::new(2); + let w_id = DefId::new(3); + + // breaker: expensive, calls X, completes the cycle from W + let breaker_fn = body!(interner, env; fn@breaker_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, t3: Int, t4: Int, t5: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { goto bb3(n); }, + bb2() { + t1 = bin.+ n 1; + t2 = bin.+ t1 2; + t3 = bin.- t2 3; + t4 = bin.+ t3 4; + t5 = apply (x_id), t4; + goto bb3(t5); + }, + bb3(result) { return result; } + }); + + // X: calls Y + let x = body!(interner, env; fn@x_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (y_id), n; + return result; + } + }); + + // Y: calls W + let y = body!(interner, env; fn@y_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (w_id), n; + return result; + } + }); + + // W: calls breaker (closing the cycle) + let w = body!(interner, env; fn@w_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (breaker_id), n; + return result; + } + }); + + let bodies = [breaker_fn, x, y, w]; + let (breakers, scc_orders) = run_loop_breaker(&bodies, &heap); + + assert_eq!(breakers.count(), 1); + assert!(breakers.contains(breaker_id)); + + // There should be exactly one non-trivial SCC. + assert_eq!(scc_orders.len(), 1); + let order = &scc_orders[0]; + assert_eq!(order.len(), 4); + + // Non-breakers in postorder: W before Y before X. + let pos = |id: DefId| { + order + .iter() + .position(|&node| node == id) + .expect("node exists inside of order") + }; + assert!( + pos(w_id) < pos(y_id), + "W (leaf) must come before Y in postorder" + ); + assert!(pos(y_id) < pos(x_id), "Y must come before X in postorder"); + // Breaker is last. + assert!( + pos(x_id) < pos(breaker_id), + "all non-breakers must come before the breaker" + ); +} + +/// All members have `Always` directive. The algorithm must still select a breaker +/// to break the cycle, even though all candidates score `-inf`. +#[test] +fn loop_breaker_all_always_directive() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + + // Both are constructors (Always directive) + let mut a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (b_id), n; + return result; + } + }); + a.source = Source::Ctor(hashql_core::symbol::sym::Some); + + let mut b = body!(interner, env; fn@b_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (a_id), n; + return result; + } + }); + b.source = Source::Ctor(hashql_core::symbol::sym::None); + + let bodies = [a, b]; + let (breakers, _) = run_loop_breaker(&bodies, &heap); + + // A 2-node cycle needs exactly 1 breaker, even when both score -inf. + assert_eq!( + breakers.count(), + 1, + "expected exactly 1 breaker for a 2-node cycle, got {}", + breakers.count() + ); + // Both are Always with equal cost, so either is a valid choice. + assert!( + breakers.contains(a_id) || breakers.contains(b_id), + "the selected breaker must be one of the SCC members" + ); +} + +/// A filter function that calls into a mutually recursive SCC. +/// +/// The aggressive phase should inline non-breaker B into the filter, but the +/// breaker A (visible after B's inlining) must not be expanded. Without the +/// unconditional breaker check in `FindCallsiteVisitor`, the aggressive phase +/// would re-expand A on each iteration until the cutoff. +#[test] +fn loop_breaker_aggressive_filter_with_recursive_scc() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + let filter_id = DefId::new(2); + + // A: expensive, calls B (will be selected as breaker) + let a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, t3: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { goto bb3(n); }, + bb2() { + t1 = bin.+ n 1; + t2 = bin.+ t1 2; + t3 = apply (b_id), t2; + goto bb3(t3); + }, + bb3(result) { return result; } + }); + + // B: cheap, calls A + let b = body!(interner, env; fn@b_id/1 -> Int { + decl x: Int, result: Int; + bb0() { + result = apply (a_id), x; + return result; + } + }); + + // Filter: calls B + let filter = body!(interner, env; [graph::read::filter]@filter_id/1 -> Int { + decl x: Int, result: Int; + bb0() { + result = apply (b_id), x; + return result; + } + }); + + let mut bodies = [a, b, filter]; + + assert_inline_pass( + "loop_breaker_aggressive_filter", + &mut bodies, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + InlineConfig::default(), + ); +} diff --git a/libs/@local/hashql/mir/src/pass/transform/inst_simplify/mod.rs b/libs/@local/hashql/mir/src/pass/transform/inst_simplify/mod.rs index ade4c6e0626..2fc24b17e1c 100644 --- a/libs/@local/hashql/mir/src/pass/transform/inst_simplify/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/inst_simplify/mod.rs @@ -309,6 +309,8 @@ impl<'heap, A: Allocator> InstSimplifyVisitor<'_, 'heap, A> { (BinOp::BitAnd, 0) if is_bool => { Some(RValue::Load(Operand::Constant(Constant::Int(false.into())))) } + // 0 & rhs => 0 (annihilator) + (BinOp::BitAnd, 0) => Some(RValue::Load(Operand::Constant(Constant::Int(0.into())))), (BinOp::BitAnd, _) => None, // 0 | rhs => rhs (identity) (BinOp::BitOr, 0) => Some(RValue::Load(Operand::Place(rhs))), @@ -369,6 +371,8 @@ impl<'heap, A: Allocator> InstSimplifyVisitor<'_, 'heap, A> { (BinOp::BitAnd, 0) if is_bool => { Some(RValue::Load(Operand::Constant(Constant::Int(false.into())))) } + // 0 & lhs => 0 (annihilator) + (BinOp::BitAnd, 0) => Some(RValue::Load(Operand::Constant(Constant::Int(0.into())))), (BinOp::BitAnd, _) => None, // lhs | 0 => lhs (identity) (BinOp::BitOr, 0) => Some(RValue::Load(Operand::Place(lhs))), diff --git a/libs/@local/hashql/mir/src/pass/transform/mod.rs b/libs/@local/hashql/mir/src/pass/transform/mod.rs index c6819a45d9f..ddda2fa2369 100644 --- a/libs/@local/hashql/mir/src/pass/transform/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/mod.rs @@ -22,7 +22,10 @@ pub use self::{ dle::DeadLocalElimination, dse::DeadStoreElimination, forward_substitution::ForwardSubstitution, - inline::{Inline, InlineConfig, InlineCostEstimationConfig, InlineHeuristicsConfig}, + inline::{ + Inline, InlineConfig, InlineCostEstimationConfig, InlineHeuristicsConfig, + InlineLoopBreakerConfig, + }, inst_simplify::InstSimplify, post_inline::PostInline, pre_inline::PreInline, diff --git a/libs/@local/hashql/mir/tests/ui/pass/data-dependency/load_param_mixed.snap b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/load_param_mixed.snap new file mode 100644 index 00000000000..d641ecea391 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/load_param_mixed.snap @@ -0,0 +1,21 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs +assertion_line: 32 +expression: "format!(\"{graph}\\n\\n=====\\n\\n{transient}\")" +--- +%1 -> %0 [Param, projections: .0] +%2 -> %1 [Index(FieldIndex(0))] +%3 -> %2 [Param, projections: .0] +%0 -> 42 [Index(FieldIndex(0))] +%1 -> 42 [Param] +%3 -> 42 [Param] + + +===== + +%0 -> 42 [Index(FieldIndex(0))] +%1 -> 42 [Param] +%1 -> 42 [Param] +%2 -> 42 [Index(FieldIndex(0))] +%3 -> 42 [Param] +%3 -> 42 [Param] diff --git a/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_consensus_projected_field_const.snap b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_consensus_projected_field_const.snap new file mode 100644 index 00000000000..7efaeecb79c --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_consensus_projected_field_const.snap @@ -0,0 +1,22 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs +expression: "format!(\"{graph}\\n\\n=====\\n\\n{transient}\")" +--- +%2 -> %0 [Index(FieldIndex(1))] +%3 -> %1 [Index(FieldIndex(1))] +%5 -> %2 [Param] +%5 -> %3 [Param] +%6 -> %5 [Load, projections: .0] +%2 -> 42 [Index(FieldIndex(0))] +%3 -> 42 [Index(FieldIndex(0))] + + +===== + +%2 -> %0 [Index(FieldIndex(1))] +%3 -> %1 [Index(FieldIndex(1))] +%5 -> %2 [Param] +%5 -> %3 [Param] +%2 -> 42 [Index(FieldIndex(0))] +%3 -> 42 [Index(FieldIndex(0))] +%6 -> 42 [Load] diff --git a/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_consensus_projected_field_place.snap b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_consensus_projected_field_place.snap new file mode 100644 index 00000000000..d9c5b6d6de4 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_consensus_projected_field_place.snap @@ -0,0 +1,22 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs +expression: "format!(\"{graph}\\n\\n=====\\n\\n{transient}\")" +--- +%3 -> %0 [Index(FieldIndex(0))] +%3 -> %1 [Index(FieldIndex(1))] +%4 -> %0 [Index(FieldIndex(0))] +%4 -> %2 [Index(FieldIndex(1))] +%6 -> %3 [Param] +%6 -> %4 [Param] +%7 -> %6 [Load, projections: .0] + + +===== + +%3 -> %0 [Index(FieldIndex(0))] +%3 -> %1 [Index(FieldIndex(1))] +%4 -> %0 [Index(FieldIndex(0))] +%4 -> %2 [Index(FieldIndex(1))] +%6 -> %3 [Param] +%6 -> %4 [Param] +%7 -> %0 [Load] diff --git a/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_detection.snap b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_detection.snap index 78e219ea059..e174f7581df 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_detection.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_detection.snap @@ -1,5 +1,6 @@ --- source: libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs +assertion_line: 32 expression: "format!(\"{graph}\\n\\n=====\\n\\n{transient}\")" --- %1 -> %0 [Param] @@ -10,5 +11,5 @@ expression: "format!(\"{graph}\\n\\n=====\\n\\n{transient}\")" ===== %1 -> %0 [Param] -%3 -> %1 [Param] -%1 -> %1 [Param] +%3 -> %0 [Param] +%1 -> %0 [Param] diff --git a/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_invariant_projected_field.snap b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_invariant_projected_field.snap new file mode 100644 index 00000000000..601e7e15f75 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_invariant_projected_field.snap @@ -0,0 +1,22 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs +expression: "format!(\"{graph}\\n\\n=====\\n\\n{transient}\")" +--- +%2 -> %0 [Index(FieldIndex(0))] +%2 -> %1 [Index(FieldIndex(1))] +%3 -> %2 [Param] +%4 -> %3 [Index(FieldIndex(0)), projections: .0] +%4 -> %1 [Index(FieldIndex(1))] +%6 -> %3 [Param, projections: .0] +%3 -> %4 [Param] + + +===== + +%2 -> %0 [Index(FieldIndex(0))] +%2 -> %1 [Index(FieldIndex(1))] +%3 -> %2 [Param] +%4 -> %0 [Index(FieldIndex(0))] +%4 -> %1 [Index(FieldIndex(1))] +%6 -> %0 [Param] +%3 -> %4 [Param] diff --git a/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_multi_node_with_const_init.snap b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_multi_node_with_const_init.snap new file mode 100644 index 00000000000..5627aa46982 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_multi_node_with_const_init.snap @@ -0,0 +1,17 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs +assertion_line: 32 +expression: "format!(\"{graph}\\n\\n=====\\n\\n{transient}\")" +--- +%1 -> %0 [Param] +%3 -> %1 [Param] +%0 -> %1 [Param] +%0 -> 42 [Param] + + +===== + +%0 -> 42 [Param] +%0 -> 42 [Param] +%1 -> 42 [Param] +%3 -> 42 [Param] diff --git a/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_visited_cleanup_on_diverge.snap b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_visited_cleanup_on_diverge.snap new file mode 100644 index 00000000000..4417a007b7c --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_visited_cleanup_on_diverge.snap @@ -0,0 +1,19 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs +assertion_line: 32 +expression: "format!(\"{graph}\\n\\n=====\\n\\n{transient}\")" +--- +%1 -> %0 [Param] +%2 -> %1 [Param] +%2 -> %2 [Param] +%2 -> %2 [Param] +%1 -> 42 [Param] + + +===== + +%1 -> %0 [Param] +%2 -> %1 [Param] +%2 -> %1 [Param] +%2 -> %1 [Param] +%1 -> 42 [Param] diff --git a/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_with_const_init.snap b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_with_const_init.snap new file mode 100644 index 00000000000..3e000a7998f --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/data-dependency/param_cycle_with_const_init.snap @@ -0,0 +1,15 @@ +--- +source: libs/@local/hashql/mir/src/pass/analysis/data_dependency/tests.rs +assertion_line: 32 +expression: "format!(\"{graph}\\n\\n=====\\n\\n{transient}\")" +--- +%0 -> %0 [Param] +%0 -> %0 [Param] +%0 -> 42 [Param] + + +===== + +%0 -> 42 [Param] +%0 -> 42 [Param] +%0 -> 42 [Param] diff --git a/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_aggressive_filter.snap b/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_aggressive_filter.snap new file mode 100644 index 00000000000..e219ef754d7 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_aggressive_filter.snap @@ -0,0 +1,130 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/inline/tests.rs +assertion_line: 136 +expression: output +--- +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Boolean + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + + bb0(): { + %1 = %0 == 0 + + switchInt(%1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + %2 = %0 + 1 + %3 = %2 + 2 + %4 = apply ({def@1} as FnPtr) %3 + + goto -> bb3(%4) + } + + bb3(%5): { + return %5 + } +} + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Integer + + bb0(): { + %1 = apply ({def@0} as FnPtr) %0 + + return %1 + } +} + +fn {graph::read::filter@4294967040}(%0: Integer) -> Integer { + let %1: Integer + + bb0(): { + %1 = apply ({def@1} as FnPtr) %0 + + return %1 + } +} + +================= After Inlining ================= + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Boolean + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + let %6: Integer + let %7: Integer + + bb0(): { + %1 = %0 == 0 + + switchInt(%1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + %2 = %0 + 1 + %3 = %2 + 2 + %6 = %3 + + goto -> bb5() + } + + bb3(%5): { + return %5 + } + + bb4(%4): { + goto -> bb3(%4) + } + + bb5(): { + %7 = apply ({def@0} as FnPtr) %6 + + goto -> bb4(%7) + } +} + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Integer + + bb0(): { + %1 = apply ({def@0} as FnPtr) %0 + + return %1 + } +} + +fn {graph::read::filter@4294967040}(%0: Integer) -> Integer { + let %1: Integer + let %2: Integer + let %3: Integer + + bb0(): { + %2 = %0 + + goto -> bb2() + } + + bb1(%1): { + return %1 + } + + bb2(): { + %3 = apply ({def@0} as FnPtr) %2 + + goto -> bb1(%3) + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_mutual_recursion.snap b/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_mutual_recursion.snap new file mode 100644 index 00000000000..635712c8756 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_mutual_recursion.snap @@ -0,0 +1,130 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/inline/tests.rs +assertion_line: 136 +expression: output +--- +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Boolean + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + + bb0(): { + %1 = %0 == 0 + + switchInt(%1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + %2 = %0 + 1 + %3 = %2 + 2 + %4 = apply ({def@1} as FnPtr) %3 + + goto -> bb3(%4) + } + + bb3(%5): { + return %5 + } +} + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Integer + + bb0(): { + %1 = apply ({def@0} as FnPtr) %0 + + return %1 + } +} + +thunk {thunk@4294967040}() -> Integer { + let %0: Integer + + bb0(): { + %0 = apply ({def@1} as FnPtr) 1 + + return %0 + } +} + +================= After Inlining ================= + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Boolean + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + let %6: Integer + let %7: Integer + + bb0(): { + %1 = %0 == 0 + + switchInt(%1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + %2 = %0 + 1 + %3 = %2 + 2 + %6 = %3 + + goto -> bb5() + } + + bb3(%5): { + return %5 + } + + bb4(%4): { + goto -> bb3(%4) + } + + bb5(): { + %7 = apply ({def@0} as FnPtr) %6 + + goto -> bb4(%7) + } +} + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Integer + + bb0(): { + %1 = apply ({def@0} as FnPtr) %0 + + return %1 + } +} + +thunk {thunk@4294967040}() -> Integer { + let %0: Integer + let %1: Integer + let %2: Integer + + bb0(): { + %1 = 1 + + goto -> bb2() + } + + bb1(%0): { + return %0 + } + + bb2(): { + %2 = apply ({def@0} as FnPtr) %1 + + goto -> bb1(%2) + } +}