Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 49 additions & 27 deletions Strata/Languages/Laurel/LiftImperativeExpressions.lean
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,42 @@ private def freshCondVar : LiftM Identifier := do
modify fun s => { s with condCounter := n + 1 }
return s!"$c_{n}"

private def addPrepend (stmt : StmtExprMd) : LiftM Unit :=
private def prepend (stmt : StmtExprMd) : LiftM Unit :=
modify fun s => { s with prependedStmts := stmt :: s.prependedStmts }

private def onlyKeepSideEffectStmtsAndLast (stmts : List StmtExprMd) : LiftM (List StmtExprMd) := do
match stmts with
| [] => return []
| _ =>
let last := stmts.getLast!
let nonLast ← stmts.dropLast.flatMapM (fun s =>
match s.val with
| .LocalVariable .. => do
-- This addPrepend is a hack to work around Core not having let expressions
-- Otherwise we could keep them in the block
prepend s
pure []
| .Assert _ => do
-- Hack to work around Core not supporting assert expressions
-- Otherwise we could keep them in the block
prepend s
pure []
| .Assume _ => do
-- Hack to work around Core not supporting assume expressions
-- Otherwise we could keep them in the block
prepend s
pure []

/-
Any other impure StmtExpr, like .Assign, .Exit or .Return,
should already have been processed by translateExpr,
so we can assume this StmtExpr is pure and can be dropped.
TODO: currently .Exit and .Return are not processed by translateExpr, this is a bug
-/
| _ => pure []
)
return nonLast ++ [last]

private def takePrepends : LiftM (List StmtExprMd) := do
let stmts := (← get).prependedStmts
modify fun s => { s with prependedStmts := [] }
Expand Down Expand Up @@ -171,15 +204,15 @@ and updates substitutions. The value should already be transformed by the caller
private def liftAssignExpr (targets : List StmtExprMd) (seqValue : StmtExprMd)
(md : Imperative.MetaData Core.Expression) : LiftM Unit := do
-- Prepend the assignment itself
addPrepend (⟨.Assign targets seqValue, md⟩)
prepend (⟨.Assign targets seqValue, md⟩)
-- Create a before-snapshot for each target and update substitutions
for target in targets do
match target.val with
| .Identifier varName =>
let snapshotName ← freshTempFor varName
let varType ← computeType target
-- Snapshot goes before the assignment (cons pushes to front)
addPrepend (⟨.LocalVariable snapshotName varType (some (⟨.Identifier varName, md⟩)), md⟩)
prepend (⟨.LocalVariable snapshotName varType (some (⟨.Identifier varName, md⟩)), md⟩)
setSubst varName snapshotName
| _ => pure ()

Expand All @@ -200,7 +233,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
| .Hole false (some holeType) =>
-- Nondeterministic typed hole: lift to a fresh variable with no initializer (havoc)
let holeVar ← freshCondVar
addPrepend (bare (.LocalVariable holeVar holeType none))
prepend (bare (.LocalVariable holeVar holeType none))
return bare (.Identifier holeVar)

| .Assign targets value =>
Expand Down Expand Up @@ -235,12 +268,13 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
return seqCall
else
-- Imperative call in expression position: lift it like an assignment
-- Order matters: assign must be prepended first (it's newest-first),
-- so that when reversed the var declaration comes before the call.
let callResultVar ← freshCondVar
let callResultType ← computeType expr
addPrepend (⟨.Assign [bare (.Identifier callResultVar)] seqCall, md⟩)
addPrepend (bare (.LocalVariable callResultVar callResultType none))
let liftedCall := [
⟨ (.LocalVariable callResultVar callResultType none), md ⟩,
⟨.Assign [bare (.Identifier callResultVar)] seqCall, md⟩
]
modify fun s => { s with prependedStmts := s.prependedStmts ++ liftedCall}
return bare (.Identifier callResultVar)

| .IfThenElse cond thenBranch elseBranch =>
Expand Down Expand Up @@ -277,8 +311,8 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
let condType ← computeType thenBranch
-- IfThenElse added first (cons puts it deeper), then declaration (cons puts it on top)
-- Output order: declaration, then if-then-else
addPrepend (⟨.IfThenElse seqCond thenBlock seqElse, md⟩)
addPrepend (bare (.LocalVariable condVar condType none))
prepend (⟨.IfThenElse seqCond thenBlock seqElse, md⟩)
prepend (bare (.LocalVariable condVar condType none))
return bare (.Identifier condVar)
else
-- No assignments in branches — recurse normally
Expand All @@ -289,21 +323,9 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
| none => pure none
return ⟨.IfThenElse seqCond seqThen seqElse, md⟩

| .Block stmts metadata =>
-- Block in expression position: lift all but last to prepends
match h_last : stmts.getLast? with
| none => return bare (.Block [] metadata)
| some last => do
have := List.mem_of_getLast? h_last

-- Process all-but-last as statements and prepend them in order
let mut blockStmts : List StmtExprMd := []
for nonLastStatement in stmts.dropLast.attach do
have := List.dropLast_subset stmts nonLastStatement.property
blockStmts := blockStmts ++ (← transformStmt nonLastStatement)
for s in blockStmts.reverse do addPrepend s
-- Last element is the expression value
transformExpr last
| .Block stmts labelOption =>
let newStmts := (← stmts.reverse.mapM transformExpr).reverse
return ⟨ .Block (← onlyKeepSideEffectStmtsAndLast newStmts) labelOption, md ⟩

| .LocalVariable name ty initializer =>
-- If the substitution map has an entry for this variable, it was
Expand All @@ -314,9 +336,9 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do
match initializer with
| some initExpr =>
let seqInit ← transformExpr initExpr
addPrepend (⟨.LocalVariable name ty (some seqInit), expr.md⟩)
prepend (⟨.LocalVariable name ty (some seqInit), expr.md⟩)
| none =>
addPrepend (⟨.LocalVariable name ty none, expr.md⟩)
prepend (⟨.LocalVariable name ty none, expr.md⟩)
return ⟨.Identifier (← getSubst name), expr.md⟩
else
return expr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,36 @@ procedure imperativeCallInConditionalExpression(b: bool) {
assert result == 0
}
};

function add(x: int, y: int): int
{
x + y
};

procedure repeatedBlockExpressions() {
var x: int := 2;
var y: int := { x := 1; x } + { x := x + 10; x };
var z: int := add({ x := 1; x }, { x := x + 10; x });
assert y == 1 + 11;
assert z == 1 + 11
};

procedure addProc(a: int, b: int) returns (r: int)
ensures r == a + b {
return a + b
};

procedure addProcCaller(): int {
var x: int := 0;
var y: int := addProc({x := 1; x}, {x := x + 10; x});
assert y == 11

// The next statement is not translated correctly.
// I think it's a bug in the handling of StaticCall
// Where a reference is substituted when it should not be
// var z: int := addProc({x := 1; x}, {x := x + 10; x}) + (x := 3);
// assert z == 14
};
"

#guard_msgs (error, drop all) in
Expand Down
18 changes: 2 additions & 16 deletions StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,6 @@ open Strata.Elab (parseStrataProgramFromDialect)
namespace Strata.Laurel

def blockStmtLiftingProgram : String := r"
composite Box {
var value: int
}

procedure heapUpdateInBlockExpr(b: Box)
{
var x: int := { b#value := b#value + 1; b#value };
assert x == b#value
};

procedure assertInBlockExpr()
{
var x: int := 0;
Expand All @@ -53,14 +43,10 @@ def parseLaurelAndLift (input : String) : IO Program := do
pure (liftExpressionAssignments model program)

/--
info: procedure heapUpdateInBlockExpr(b: Box) returns ⏎
()
deterministic
{ b#value := b#value + 1; var x: int := b#value; assert x == b#value }
procedure assertInBlockExpr() returns ⏎
info: procedure assertInBlockExpr() returns ⏎
()
deterministic
{ var x: int := 0; assert x == 0; x := 1; var y: int := x; assert y == 1 }
{ var x: int := 0; assert x == 0; var $x_0: int := x; x := 1; var y: int := { x }; assert y == 1 }
-/
#guard_msgs in
#eval! do
Expand Down
Loading