diff --git a/Strata/Languages/Laurel/LiftImperativeExpressions.lean b/Strata/Languages/Laurel/LiftImperativeExpressions.lean index e29618fef..f1e560340 100644 --- a/Strata/Languages/Laurel/LiftImperativeExpressions.lean +++ b/Strata/Languages/Laurel/LiftImperativeExpressions.lean @@ -104,9 +104,42 @@ private def freshCondVar : LiftM Identifier := do modify fun s => { s with condCounter := n + 1 } return s!"$c_{n}" -private def addPrepend (stmt : StmtExprMd) : LiftM Unit := +private def prepend (stmt : StmtExprMd) : LiftM Unit := modify fun s => { s with prependedStmts := stmt :: s.prependedStmts } +private def onlyKeepSideEffectStmtsAndLast (stmts : List StmtExprMd) : LiftM (List StmtExprMd) := do + match stmts with + | [] => return [] + | _ => + let last := stmts.getLast! + let nonLast ← stmts.dropLast.flatMapM (fun s => + match s.val with + | .LocalVariable .. => do + -- This addPrepend is a hack to work around Core not having let expressions + -- Otherwise we could keep them in the block + prepend s + pure [] + | .Assert _ => do + -- Hack to work around Core not supporting assert expressions + -- Otherwise we could keep them in the block + prepend s + pure [] + | .Assume _ => do + -- Hack to work around Core not supporting assume expressions + -- Otherwise we could keep them in the block + prepend s + pure [] + + /- + Any other impure StmtExpr, like .Assign, .Exit or .Return, + should already have been processed by translateExpr, + so we can assume this StmtExpr is pure and can be dropped. + TODO: currently .Exit and .Return are not processed by translateExpr, this is a bug + -/ + | _ => pure [] + ) + return nonLast ++ [last] + private def takePrepends : LiftM (List StmtExprMd) := do let stmts := (← get).prependedStmts modify fun s => { s with prependedStmts := [] } @@ -171,7 +204,7 @@ and updates substitutions. The value should already be transformed by the caller private def liftAssignExpr (targets : List StmtExprMd) (seqValue : StmtExprMd) (md : Imperative.MetaData Core.Expression) : LiftM Unit := do -- Prepend the assignment itself - addPrepend (⟨.Assign targets seqValue, md⟩) + prepend (⟨.Assign targets seqValue, md⟩) -- Create a before-snapshot for each target and update substitutions for target in targets do match target.val with @@ -179,7 +212,7 @@ private def liftAssignExpr (targets : List StmtExprMd) (seqValue : StmtExprMd) let snapshotName ← freshTempFor varName let varType ← computeType target -- Snapshot goes before the assignment (cons pushes to front) - addPrepend (⟨.LocalVariable snapshotName varType (some (⟨.Identifier varName, md⟩)), md⟩) + prepend (⟨.LocalVariable snapshotName varType (some (⟨.Identifier varName, md⟩)), md⟩) setSubst varName snapshotName | _ => pure () @@ -200,7 +233,7 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do | .Hole false (some holeType) => -- Nondeterministic typed hole: lift to a fresh variable with no initializer (havoc) let holeVar ← freshCondVar - addPrepend (bare (.LocalVariable holeVar holeType none)) + prepend (bare (.LocalVariable holeVar holeType none)) return bare (.Identifier holeVar) | .Assign targets value => @@ -235,12 +268,13 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do return seqCall else -- Imperative call in expression position: lift it like an assignment - -- Order matters: assign must be prepended first (it's newest-first), - -- so that when reversed the var declaration comes before the call. let callResultVar ← freshCondVar let callResultType ← computeType expr - addPrepend (⟨.Assign [bare (.Identifier callResultVar)] seqCall, md⟩) - addPrepend (bare (.LocalVariable callResultVar callResultType none)) + let liftedCall := [ + ⟨ (.LocalVariable callResultVar callResultType none), md ⟩, + ⟨.Assign [bare (.Identifier callResultVar)] seqCall, md⟩ + ] + modify fun s => { s with prependedStmts := s.prependedStmts ++ liftedCall} return bare (.Identifier callResultVar) | .IfThenElse cond thenBranch elseBranch => @@ -277,8 +311,8 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do let condType ← computeType thenBranch -- IfThenElse added first (cons puts it deeper), then declaration (cons puts it on top) -- Output order: declaration, then if-then-else - addPrepend (⟨.IfThenElse seqCond thenBlock seqElse, md⟩) - addPrepend (bare (.LocalVariable condVar condType none)) + prepend (⟨.IfThenElse seqCond thenBlock seqElse, md⟩) + prepend (bare (.LocalVariable condVar condType none)) return bare (.Identifier condVar) else -- No assignments in branches — recurse normally @@ -289,21 +323,9 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do | none => pure none return ⟨.IfThenElse seqCond seqThen seqElse, md⟩ - | .Block stmts metadata => - -- Block in expression position: lift all but last to prepends - match h_last : stmts.getLast? with - | none => return bare (.Block [] metadata) - | some last => do - have := List.mem_of_getLast? h_last - - -- Process all-but-last as statements and prepend them in order - let mut blockStmts : List StmtExprMd := [] - for nonLastStatement in stmts.dropLast.attach do - have := List.dropLast_subset stmts nonLastStatement.property - blockStmts := blockStmts ++ (← transformStmt nonLastStatement) - for s in blockStmts.reverse do addPrepend s - -- Last element is the expression value - transformExpr last + | .Block stmts labelOption => + let newStmts := (← stmts.reverse.mapM transformExpr).reverse + return ⟨ .Block (← onlyKeepSideEffectStmtsAndLast newStmts) labelOption, md ⟩ | .LocalVariable name ty initializer => -- If the substitution map has an entry for this variable, it was @@ -314,9 +336,9 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do match initializer with | some initExpr => let seqInit ← transformExpr initExpr - addPrepend (⟨.LocalVariable name ty (some seqInit), expr.md⟩) + prepend (⟨.LocalVariable name ty (some seqInit), expr.md⟩) | none => - addPrepend (⟨.LocalVariable name ty none, expr.md⟩) + prepend (⟨.LocalVariable name ty none, expr.md⟩) return ⟨.Identifier (← getSubst name), expr.md⟩ else return expr diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T2_ImpureExpressions.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T2_ImpureExpressions.lean index ce53e09f0..0ec9bf87f 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T2_ImpureExpressions.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T2_ImpureExpressions.lean @@ -97,6 +97,36 @@ procedure imperativeCallInConditionalExpression(b: bool) { assert result == 0 } }; + +function add(x: int, y: int): int +{ + x + y +}; + +procedure repeatedBlockExpressions() { + var x: int := 2; + var y: int := { x := 1; x } + { x := x + 10; x }; + var z: int := add({ x := 1; x }, { x := x + 10; x }); + assert y == 1 + 11; + assert z == 1 + 11 +}; + +procedure addProc(a: int, b: int) returns (r: int) + ensures r == a + b { + return a + b +}; + +procedure addProcCaller(): int { + var x: int := 0; + var y: int := addProc({x := 1; x}, {x := x + 10; x}); + assert y == 11 + + // The next statement is not translated correctly. + // I think it's a bug in the handling of StaticCall + // Where a reference is substituted when it should not be + // var z: int := addProc({x := 1; x}, {x := x + 10; x}) + (x := 3); + // assert z == 14 +}; " #guard_msgs (error, drop all) in diff --git a/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean b/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean index 28bc064cc..f326010bf 100644 --- a/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean +++ b/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean @@ -22,16 +22,6 @@ open Strata.Elab (parseStrataProgramFromDialect) namespace Strata.Laurel def blockStmtLiftingProgram : String := r" -composite Box { - var value: int -} - -procedure heapUpdateInBlockExpr(b: Box) -{ - var x: int := { b#value := b#value + 1; b#value }; - assert x == b#value -}; - procedure assertInBlockExpr() { var x: int := 0; @@ -53,14 +43,10 @@ def parseLaurelAndLift (input : String) : IO Program := do pure (liftExpressionAssignments model program) /-- -info: procedure heapUpdateInBlockExpr(b: Box) returns ⏎ -() -deterministic -{ b#value := b#value + 1; var x: int := b#value; assert x == b#value } -procedure assertInBlockExpr() returns ⏎ +info: procedure assertInBlockExpr() returns ⏎ () deterministic -{ var x: int := 0; assert x == 0; x := 1; var y: int := x; assert y == 1 } +{ var x: int := 0; assert x == 0; var $x_0: int := x; x := 1; var y: int := { x }; assert y == 1 } -/ #guard_msgs in #eval! do