Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,20 @@ private static ValDerivationNode foldUnary(ValDerivationNode node) {
*/
private static ValDerivationNode foldIte(ValDerivationNode node) {
Ite iteExp = (Ite) node.getValue();
DerivationNode parent = node.getOrigin();

ValDerivationNode condNode = fold(new ValDerivationNode(iteExp.getCondition(), null));
ValDerivationNode thenNode = fold(new ValDerivationNode(iteExp.getThen(), null));
ValDerivationNode elseNode = fold(new ValDerivationNode(iteExp.getElse(), null));
ValDerivationNode condNode;
ValDerivationNode thenNode;
ValDerivationNode elseNode;
if (parent instanceof IteDerivationNode iteOrigin) {
condNode = fold(iteOrigin.getCondition());
thenNode = fold(iteOrigin.getThenBranch());
elseNode = fold(iteOrigin.getElseBranch());
} else {
condNode = fold(new ValDerivationNode(iteExp.getCondition(), null));
thenNode = fold(new ValDerivationNode(iteExp.getThen(), null));
elseNode = fold(new ValDerivationNode(iteExp.getElse(), null));
}

Expression condition = condNode.getValue();
Expression thenExp = thenNode.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import liquidjava.rj_language.ast.Enum;
import liquidjava.rj_language.ast.Expression;
import liquidjava.rj_language.ast.FunctionInvocation;
import liquidjava.rj_language.ast.GroupExpression;
import liquidjava.rj_language.ast.Ite;
import liquidjava.rj_language.ast.UnaryExpression;
import liquidjava.rj_language.ast.Var;
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
Expand Down Expand Up @@ -101,6 +103,28 @@ private static ValDerivationNode propagateRecursive(Expression exp, Map<String,
: new ValDerivationNode(cloned, null);
}

// lift ternary origin
if (exp instanceof Ite ite) {
ValDerivationNode condition = propagateRecursive(ite.getCondition(), subs, varOrigins);
ValDerivationNode thenBranch = propagateRecursive(ite.getThen(), subs, varOrigins);
ValDerivationNode elseBranch = propagateRecursive(ite.getElse(), subs, varOrigins);
Ite cloned = (Ite) ite.clone();
cloned.setChild(0, condition.getValue());
cloned.setChild(1, thenBranch.getValue());
cloned.setChild(2, elseBranch.getValue());

return (condition.getOrigin() != null || thenBranch.getOrigin() != null || elseBranch.getOrigin() != null)
? new ValDerivationNode(cloned, new IteDerivationNode(condition, thenBranch, elseBranch))
: new ValDerivationNode(cloned, null);
}

if (exp instanceof GroupExpression group && group.getChildren().size() == 1) {
ValDerivationNode child = propagateRecursive(group.getExpression(), subs, varOrigins);
GroupExpression cloned = (GroupExpression) group.clone();
cloned.setChild(0, child.getValue());
return new ValDerivationNode(cloned, child.getOrigin());
}

// recursively propagate children
if (exp.hasChildren()) {
Expression propagated = exp.clone();
Expand Down Expand Up @@ -163,4 +187,4 @@ private static void extractVarOrigins(ValDerivationNode node, Map<String, Deriva
extractVarOrigins(valOrigin, varOrigins);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,22 @@ void testIteConditionUsesEqualityFromConjunction() {
"mode == 1 should make the mode == 2 ternary condition false");
}

@Test
void testIteConditionKeepsPropagatedVariableOrigin() {
Expression expr = parse("mode == 1 && (mode == 2 ? explicit(param) : start(param))");
ValDerivationNode result = ExpressionSimplifier.simplify(expr);

assertNotNull(result.getOrigin(), "ITE simplification should record the selected branch");
IteDerivationNode iteOrigin = (IteDerivationNode) result.getOrigin();
ValDerivationNode condition = iteOrigin.getCondition();
BinaryDerivationNode equality = (BinaryDerivationNode) condition.getOrigin();
ValDerivationNode left = equality.getLeft();

assertEquals("1", left.getValue().toString());
assertDerivationEquals(new VarDerivationNode("mode"), left.getOrigin(),
"Propagated condition value should come from the mode parameter");
}

@Test
void testByteAliasExpansion() {
String sut = "Byte(b)";
Expand Down
Loading