diff --git a/Cslib/Foundations/Control/Monad/Free.lean b/Cslib/Foundations/Control/Monad/Free.lean index b0a828c1..43790d78 100644 --- a/Cslib/Foundations/Control/Monad/Free.lean +++ b/Cslib/Foundations/Control/Monad/Free.lean @@ -51,8 +51,6 @@ The `FreeM` monad is defined using an inductive type with constructors `.pure` a We implement `Functor` and `Monad` instances, and prove the corresponding `LawfulFunctor` and `LawfulMonad` instances. -For now we choose to make the constructors the simp-normal form, as opposed to the standard -monad notation. The file `Free/Effects.lean` demonstrates practical applications by implementing State, Writer, and Continuations monads using `FreeM` with appropriate effect signatures. @@ -71,6 +69,9 @@ Free monad, state monad namespace Cslib +-- Disable generation of unneeded lemmas which the simpNF linter would complain about. +set_option genInjectivity false in +set_option genSizeOfSpec false in /-- The Free monad over a type constructor `F`. A `FreeM F a` is a tree of operations from the type constructor `F`, with leaves of type `a`. @@ -94,83 +95,98 @@ universe u v w w' w'' namespace FreeM variable {F : Type u → Type v} {ι : Type u} {α : Type w} {β : Type w'} {γ : Type w''} +section notations + instance : Pure (FreeM F) where pure := .pure @[simp] -theorem pure_eq_pure : (pure : α → FreeM F α) = FreeM.pure := rfl +theorem pure_eq_pure : FreeM.pure = (pure : α → FreeM F α) := rfl + +/-- Bind operation for the `FreeM` monad. -/-- Bind operation for the `FreeM` monad. -/ +The builtin `>>=` notation should be preferred when `α` and `β` are in the same universe. -/ protected def bind (x : FreeM F α) (f : α → FreeM F β) : FreeM F β := match x with | .pure a => f a | .liftBind op cont => .liftBind op fun z => FreeM.bind (cont z) f -protected theorem bind_assoc (x : FreeM F α) (f : α → FreeM F β) (g : β → FreeM F γ) : - (x.bind f).bind g = x.bind (fun x => (f x).bind g) := by - induction x with - | pure a => rfl - | liftBind op cont ih => - simp [FreeM.bind] at * - simp [ih] - instance : Bind (FreeM F) where bind := .bind +/-- Note that this lemma does not always apply, as it is universe-constrained by `Bind.bind`. -/ @[simp] -theorem bind_eq_bind {α β : Type w} : Bind.bind = (FreeM.bind : FreeM F α → _ → FreeM F β) := rfl +theorem bind_eq_bind {α β : Type w} : (FreeM.bind : FreeM F α → _ → FreeM F β) = Bind.bind := rfl -/-- Map a function over a `FreeM` monad. -/ -@[simp] +/-- Map a function over a `FreeM` monad. + +The builtin `<$>` notation should be preferred when `α` and `β` are in the same universe. -/ def map (f : α → β) : FreeM F α → FreeM F β | .pure a => .pure (f a) | .liftBind op cont => .liftBind op fun z => FreeM.map f (cont z) -@[simp] -theorem id_map : ∀ x : FreeM F α, map id x = x - | .pure a => rfl - | .liftBind op cont => by simp_all [map, id_map] - -theorem comp_map (h : β → γ) (g : α → β) : ∀ x : FreeM F α, map (h ∘ g) x = map h (map g x) - | .pure a => rfl - | .liftBind op cont => by simp_all [map, comp_map] - instance : Functor (FreeM F) where map := .map +/-- Note that this lemma does not always apply, as it is universe-constrained by `Functor.map`. -/ @[simp] -theorem map_eq_map {α β : Type w} : Functor.map = FreeM.map (F := F) (α := α) (β := β) := rfl +theorem map_eq_map {α β : Type w} : FreeM.map (F := F) (α := α) (β := β) = Functor.map := rfl /-- Lift an operation from the effect signature `F` into the `FreeM F` monad. -/ def lift (op : F ι) : FreeM F ι := - .liftBind op .pure + .liftBind op pure -/-- Rewrite `lift` to the constructor form so that simplification stays in constructor normal -form. -/ @[simp] -lemma lift_def (op : F ι) : - (lift op : FreeM F ι) = liftBind op .pure := rfl +lemma liftBind_eq (op : F ι) : + liftBind op cont = (lift op : FreeM F ι).bind cont := + rfl -@[simp] -lemma map_lift (f : ι → α) (op : F ι) : - map f (lift op : FreeM F ι) = liftBind op (fun z => (.pure (f z) : FreeM F α)) := rfl +set_option linter.unusedVariables false in +/-- An override for the default induction principle that is in simp-normal form. -/ +@[induction_eliminator] +protected theorem induction {motive : FreeM F α → Prop} + (pure : ∀ a, motive (pure a)) + (lift_bind : ∀ {ι} (op : F ι) (cont : ι → FreeM F α) (ih : ∀ i, motive (cont i)), + motive ((lift op).bind cont)) : ∀ x, motive x + | .pure a => pure a + | liftBind _ _ => lift_bind _ _ fun _ => FreeM.induction pure lift_bind _ + +end notations + +protected theorem bind_assoc (x : FreeM F α) (f : α → FreeM F β) (g : β → FreeM F γ) : + (x.bind f).bind g = x.bind (fun x => (f x).bind g) := by + induction x with + | pure a => rfl + | lift_bind op cont ih => simp [← liftBind_eq, FreeM.bind, ih] at * /-- `.pure a` followed by `bind` collapses immediately. -/ @[simp] -lemma pure_bind (a : α) (f : α → FreeM F β) : (.pure a : FreeM F α).bind f = f a := rfl +lemma pure_bind (a : α) (f : α → FreeM F β) : (pure a : FreeM F α).bind f = f a := rfl @[simp] -lemma bind_pure : ∀ x : FreeM F α, x.bind (.pure) = x +lemma bind_pure : ∀ x : FreeM F α, x.bind pure = x | .pure a => rfl - | liftBind op k => by simp [FreeM.bind, bind_pure] + | liftBind op k => by simp [FreeM.bind, bind_pure, -bind_eq_bind] @[simp] -lemma bind_pure_comp (f : α → β) : ∀ x : FreeM F α, x.bind (.pure ∘ f) = map f x +lemma bind_pure_comp (f : α → β) : ∀ x : FreeM F α, x.bind (pure ∘ f) = map f x | .pure a => rfl | liftBind op k => by simp only [FreeM.bind, map, bind_pure_comp] -/-- Collapse a `.bind` that follows a `liftBind` into a single `liftBind` -/ @[simp] -lemma liftBind_bind (op : F ι) (cont : ι → FreeM F α) (f : α → FreeM F β) : - (liftBind op cont).bind f = liftBind op fun x => (cont x).bind f := rfl +theorem map_pure (f : α → β) (x : α) : map f (pure x : FreeM F α) = pure (f x) := rfl + +@[simp] +theorem map_bind (f : β → γ) (x : FreeM F α) (c : α → FreeM F β) : + map f (x.bind c) = x.bind fun a => (c a).map f := by + simp_rw [← bind_pure_comp, FreeM.bind_assoc] + +@[simp] +theorem id_map : ∀ x : FreeM F α, map id x = x + | .pure a => rfl + | .liftBind op cont => by simp_all [map, id_map] + +theorem comp_map (h : β → γ) (g : α → β) : ∀ x : FreeM F α, map (h ∘ g) x = map h (map g x) + | .pure a => rfl + | .liftBind op cont => by simp_all [map, comp_map] instance : LawfulFunctor (FreeM F) where map_const := rfl @@ -202,16 +218,16 @@ protected def liftM (interp : {ι : Type u} → F ι → m ι) : FreeM F α → @[simp] lemma liftM_pure (interp : {ι : Type u} → F ι → m ι) (a : α) : - (.pure a : FreeM F α).liftM interp = pure a := rfl + (pure a : FreeM F α).liftM interp = pure a := rfl @[simp] -lemma liftM_liftBind (interp : {ι : Type u} → F ι → m ι) (op : F β) (cont : β → FreeM F α) : - (liftBind op cont).liftM interp = (do let b ← interp op; (cont b).liftM interp) := by +lemma liftM_lift_bind (interp : {ι : Type u} → F ι → m ι) (op : F β) (cont : β → FreeM F α) : + ((lift op) >>= cont).liftM interp = (do let b ← interp op; (cont b).liftM interp) := by rfl lemma liftM_lift [LawfulMonad m] (interp : {ι : Type u} → F ι → m ι) (op : F β) : (lift op).liftM interp = interp op := by - simp_rw [lift_def, liftM_liftBind, liftM_pure, _root_.bind_pure] + simp_rw [lift, FreeM.liftM, _root_.bind_pure] @[simp] lemma liftM_bind [LawfulMonad m] @@ -219,9 +235,7 @@ lemma liftM_bind [LawfulMonad m] (x.bind f : FreeM F β).liftM interp = (do let a ← x.liftM interp; (f a).liftM interp) := by induction x generalizing f with | pure a => simp only [pure_bind, liftM_pure, LawfulMonad.pure_bind] - | liftBind op cont ih => - rw [FreeM.bind, liftM_liftBind, liftM_liftBind, bind_assoc] - simp_rw [ih] + | lift_bind op cont ih => simp [← ih] /-- A predicate stating that `interp : FreeM F α → m α` is an interpreter for the effect @@ -237,8 +251,8 @@ Formally, `interp` satisfies the two equations: -/ structure Interprets (handler : {ι : Type u} → F ι → m ι) (interp : FreeM F α → m α) : Prop where apply_pure (a : α) : interp (.pure a) = pure a - apply_liftBind {ι : Type u} (op : F ι) (cont : ι → FreeM F α) : - interp (liftBind op cont) = handler op >>= fun x => interp (cont x) + apply_lift_bind {ι : Type u} (op : F ι) (cont : ι → FreeM F α) : + interp (lift op >>= cont) = handler op >>= fun x => interp (cont x) theorem Interprets.eq {handler : {ι : Type u} → F ι → m ι} {interp : FreeM F α → m α} (h : Interprets handler interp) : @@ -246,14 +260,13 @@ theorem Interprets.eq {handler : {ι : Type u} → F ι → m ι} {interp : Free ext x induction x with | pure a => exact h.apply_pure a - | liftBind op cont ih => - rw [liftM_liftBind, h.apply_liftBind] - simp [ih] + | lift_bind op cont ih => + simp [h.apply_lift_bind, ih] theorem Interprets.liftM (handler : {ι : Type u} → F ι → m ι) : Interprets handler (·.liftM handler : FreeM F α → _) where apply_pure _ := rfl - apply_liftBind _ _ := rfl + apply_lift_bind _ _ := rfl /-- The universal property of the free monad `FreeM`. diff --git a/Cslib/Foundations/Control/Monad/Free/Effects.lean b/Cslib/Foundations/Control/Monad/Free/Effects.lean index e1675c01..a08aeb39 100644 --- a/Cslib/Foundations/Control/Monad/Free/Effects.lean +++ b/Cslib/Foundations/Control/Monad/Free/Effects.lean @@ -110,15 +110,27 @@ theorem run_toStateM {α : Type u} (comp : FreeState σ α) : @[simp] lemma run_pure (a : α) (s₀ : σ) : - run (.pure a : FreeState σ α) s₀ = (a, s₀) := rfl + run (pure a : FreeState σ α) s₀ = (a, s₀) := rfl @[simp] -lemma run_get (k : σ → FreeState σ α) (s₀ : σ) : - run (liftBind .get k) s₀ = run (k s₀) s₀ := rfl +lemma run_get (s₀ : σ) : + run (lift .get) s₀ = (s₀, s₀) := rfl @[simp] -lemma run_set (s' : σ) (k : PUnit → FreeState σ α) (s₀ : σ) : - run (liftBind (.set s') k) s₀ = run (k .unit) s' := rfl +lemma run_set (s' : σ) (s₀ : σ) : + run (lift (.set s')) s₀ = (.unit, s') := rfl + +lemma run_lift_bind (f : ι → FreeState σ β) (s₀ : σ) : + run ((lift op).bind f) s₀ = let p := run (lift op) s₀; (f p.1).run p.2 := by + cases op <;> simp [← liftBind_eq, run] + +@[simp] +lemma run_bind (x : FreeState σ α) (f : α → FreeState σ β) (s₀ : σ) : + run (x.bind f) s₀ = let p := x.run s₀; (f p.1).run p.2 := by + induction x using FreeM.induction generalizing f s₀ with + | pure => simp + | lift_bind op cont ih => + simp_rw [FreeM.bind_assoc, run_lift_bind, ih] /-- Run a state computation, returning only the result. -/ def run' (c : FreeState σ α) (s₀ : σ) : α := (run c s₀).1 @@ -132,15 +144,15 @@ theorem run'_toStateM {α : Type u} (comp : FreeState σ α) : @[simp] lemma run'_pure (a : α) (s₀ : σ) : - run' (.pure a : FreeState σ α) s₀ = a := rfl + run' (pure a : FreeState σ α) s₀ = a := rfl @[simp] -lemma run'_get (k : σ → FreeState σ α) (s₀ : σ) : - run' (liftBind .get k) s₀ = run' (k s₀) s₀ := rfl +lemma run'_get (s₀ : σ) : + run' (lift .get) s₀ = s₀ := rfl @[simp] -lemma run'_set (s' : σ) (k : PUnit → FreeState σ α) (s₀ : σ) : - run' (liftBind (.set s') k) s₀ = run' (k .unit) s' := rfl +lemma run'_set (s' : σ) (s₀ : σ) : + run' (lift (.set s')) s₀ = .unit := rfl end FreeState @@ -164,6 +176,7 @@ open WriterF variable {ω : Type u} {α : Type u} /-- Interpret `WriterF` operations into `WriterT`. -/ +@[simp] def writerInterp {α : Type u} : WriterF ω α → WriterT ω Id α | .tell w => MonadWriter.tell w @@ -198,24 +211,49 @@ def run [Monoid ω] : FreeWriter ω α → α × ω @[simp] lemma run_pure [Monoid ω] (a : α) : - run (.pure a : FreeWriter ω α) = (a, 1) := rfl + run (pure a : FreeWriter ω α) = (a, 1) := rfl + +@[simp] +lemma run_lift_tell [Monoid ω] (w : ω) : + run (lift (.tell w)) = (.unit, w) := Prod.ext rfl <| mul_one _ @[simp] -lemma run_liftBind_tell [Monoid ω] (w : ω) (k : PUnit → FreeWriter ω α) : - run (liftBind (.tell w) k) = (let (a, w') := run (k .unit); (a, w * w')) := rfl +lemma run_lift_bind [Monoid ω] (op) (f : ι → FreeWriter ω β) : + run (lift op >>= f) = let p := run (lift op); ((f p.1).run.1, p.2 * (f p.1).run.2) := by + cases op; simp [← bind_eq_bind, ← liftBind_eq, run] + +-- https://github.com/leanprover-community/mathlib4/pull/36497 +section missing_from_mathlib + +@[simp] +theorem _root_.WriterT.run_pure [Monoid ω] [Monad M] (a : α) : + WriterT.run (pure a : WriterT ω M α) = pure (a, 1) := rfl + +@[simp] +theorem _root_.WriterT.run_bind [Monoid ω] [Monad M] (x : WriterT ω M α) (f : α → WriterT ω M β) : + WriterT.run (x >>= f) = x.run >>= fun (a, w₁) => (fun (b, w₂) => (b, w₁ * w₂)) <$> (f a).run := + rfl + +@[simp] +theorem _root_.WriterT.run_tell [Monad M] (w : ω) : + WriterT.run (MonadWriter.tell w : WriterT ω M PUnit) = pure (.unit, w) := rfl + +end missing_from_mathlib /-- The canonical interpreter `toWriterT` derived from `liftM` agrees with the hand-written recursive interpreter `run` for `FreeWriter`. -/ @[simp] -theorem run_toWriterT {α : Type u} [Monoid ω] : - ∀ comp : FreeWriter ω α, (toWriterT comp).run = run comp - | .pure _ => by simp only [toWriterT, liftM_pure, run_pure, pure, WriterT.run] - | liftBind (.tell w) cont => by - simp only [toWriterT, liftM_liftBind, run_liftBind_tell] at * - rw [← run_toWriterT] - congr +theorem run_toWriterT {α : Type u} [Monoid ω] (comp : FreeWriter ω α) : + (toWriterT comp).run = pure (run comp) := by + induction comp using FreeM.induction with + | pure _ => simp [toWriterT] + | lift_bind op cont ih => + ext : 1 + cases op + simp only [toWriterT] at * + simp [ih] /-- `listen` captures the log produced by a subcomputation incrementally. It traverses the computation, @@ -230,13 +268,13 @@ def listen [Monoid ω] : FreeWriter ω α → FreeWriter ω (α × ω) @[simp] lemma listen_pure [Monoid ω] (a : α) : - listen (.pure a : FreeWriter ω α) = .pure (a, 1) := rfl + listen (pure a : FreeWriter ω α) = .pure (a, 1) := rfl @[simp] -lemma listen_liftBind_tell [Monoid ω] (w : ω) +lemma listen_lift_tell_bind [Monoid ω] (w : ω) (k : PUnit → FreeWriter ω α) : - listen (liftBind (.tell w) k) = - liftBind (.tell w) (fun _ => + listen (lift (.tell w) >>= k) = + lift (.tell w) >>= (fun _ => listen (k .unit) >>= fun (a, w') => pure (a, w * w')) := by rfl @@ -315,23 +353,21 @@ theorem run_toContT {α : Type u} (comp : FreeCont r α) : @[simp] lemma run_pure (a : α) (k : α → r) : - run (.pure a : FreeCont r α) k = k a := rfl + run (pure a : FreeCont r α) k = k a := rfl @[simp] -lemma run_liftBind_callCC (g : (α → r) → r) +lemma run_lift_callCC_bind (g : (α → r) → r) (cont : α → FreeCont r β) (k : β → r) : - run (liftBind (.callCC g) cont) k = g (fun a => run (cont a) k) := rfl + run (lift (.callCC g) |>.bind cont) k = g (fun a => run (cont a) k) := rfl + +@[simp] +lemma run_lift_callCC (g : (α → r) → r) (k : α → r) : + run (lift (.callCC g)) k = g k := rfl /-- Call with current continuation for the Free continuation monad. -/ def callCC (f : MonadCont.Label α (FreeCont r) β → FreeCont r α) : FreeCont r α := - liftBind (.callCC fun k => run (f ⟨fun x => liftBind (.callCC fun _ => k x) pure⟩) k) pure - -@[simp] -lemma callCC_def (f : MonadCont.Label α (FreeCont r) β → FreeCont r α) : - callCC f = - liftBind (.callCC fun k => run (f ⟨fun x => liftBind (.callCC fun _ => k x) pure⟩) k) pure := - rfl + lift (.callCC fun k => run (f ⟨fun x => lift (.callCC fun _ => k x)⟩) k) instance : MonadCont (FreeCont r) where callCC := .callCC @@ -339,8 +375,8 @@ instance : MonadCont (FreeCont r) where /-- `run` of a `callCC` node simplifies to running the handler with the current continuation. -/ @[simp] lemma run_callCC (f : MonadCont.Label α (FreeCont r) β → FreeCont r α) (k : α → r) : - run (callCC f) k = run (f ⟨fun x => liftBind (.callCC fun _ => k x) pure⟩) k := by - simp [callCC, run_liftBind_callCC] + run (callCC f) k = run (f ⟨fun x => lift (.callCC fun _ => k x)⟩) k := by + simp [callCC] end FreeCont @@ -395,11 +431,11 @@ theorem run_toReaderM {α : Type u} (comp : FreeReader σ α) (s : σ) : @[simp] lemma run_pure (a : α) (s₀ : σ) : - run (.pure a : FreeReader σ α) s₀ = a := rfl + run (pure a : FreeReader σ α) s₀ = a := rfl @[simp] lemma run_read (k : σ → FreeReader σ α) (s₀ : σ) : - run (liftBind .read k) s₀ = run (k s₀) s₀ := rfl + run (lift .read >>= k) s₀ = run (k s₀) s₀ := rfl instance instMonadWithReaderOf : MonadWithReaderOf σ (FreeReader σ) where withReader {α} f m := diff --git a/Cslib/Foundations/Control/Monad/Free/Fold.lean b/Cslib/Foundations/Control/Monad/Free/Fold.lean index 99201914..6fd8f589 100644 --- a/Cslib/Foundations/Control/Monad/Free/Fold.lean +++ b/Cslib/Foundations/Control/Monad/Free/Fold.lean @@ -59,20 +59,36 @@ def foldFreeM theorem foldFreeM_pure (onValue : α → β) (onEffect : {ι : Type u} → F ι → (ι → β) → β) - (a : α) : foldFreeM onValue onEffect (.pure a) = onValue a := rfl + (a : α) : foldFreeM onValue onEffect (pure a) = onValue a := rfl @[simp] -theorem foldFreeM_liftBind +theorem foldFreeM_lift_bind (onValue : α → β) (onEffect : {ι : Type u} → F ι → (ι → β) → β) (op : F ι) (k : ι → FreeM F α) : - foldFreeM onValue onEffect (.liftBind op k) + foldFreeM onValue onEffect ((lift op).bind k) = onEffect op (fun x => foldFreeM onValue onEffect (k x)) := rfl +@[simp] +theorem foldFreeM_lift_bind' {F : Type w → Type v} {ι : Type w} + (onValue : α → β) + (onEffect : {ι : Type w} → F ι → (ι → β) → β) + (op : F ι) (k : ι → FreeM F α) : + foldFreeM onValue onEffect ((lift op) >>= k) + = onEffect op (fun x => foldFreeM onValue onEffect (k x)) := rfl + +@[simp] +theorem foldFreeM_lift + (onValue : ι → β) + (onEffect : {ι : Type u} → F ι → (ι → β) → β) + (op : F ι) : + foldFreeM onValue onEffect (lift op) = onEffect op onValue := + rfl + /-- **Universal Property**: If `h : FreeM F α → β` satisfies: -* `h (.pure a) = onValue a` -* `h (.liftBind op k) = onEffect op (fun x => h (k x))` +* `h (pure a) = onValue a` +* `h ((lift op).bind k) = onEffect op (fun x => h (k x))` then `h` is equal to `foldFreeM onValue onEffect`. -/ @@ -80,16 +96,16 @@ theorem foldFreeM_unique (onValue : α → β) (onEffect : {ι : Type u} → F ι → (ι → β) → β) (h : FreeM F α → β) - (h_pure : ∀ a, h (.pure a) = onValue a) + (h_pure : ∀ a, h (pure a) = onValue a) (h_liftBind : ∀ {ι} (op : F ι) (k : ι → FreeM F α), - h (.liftBind op k) = onEffect op (fun x => h (k x))) : + h ((lift op).bind k) = onEffect op (fun x => h (k x))) : h = foldFreeM onValue onEffect := by funext x induction x with | pure a => rw [foldFreeM_pure, h_pure] - | liftBind op k ih => - rw [foldFreeM_liftBind, h_liftBind] + | lift_bind op k ih => + rw [foldFreeM_lift_bind, h_liftBind] grind end FreeM