Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 181 additions & 47 deletions DatapathVerification/BitHeap/BitHeap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ import Std.Data.HashMap
import Std.Data.HashSet
import DatapathVerification.BitHeap.Circuit
import DatapathVerification.BitHeap.Column
import Mathlib.Tactic.SplitIfs
import Std.Data.HashMap.Lemmas
import Std.Tactic.Do

structure BitHeap where
width : Nat
Expand All @@ -26,6 +29,12 @@ Evaluate a bit-heap, to compute the final sum of all the bits in the heap.
def eval (h : BitHeap) (env : BitEnv) : Int :=
(h.columns.fold (init := 0) (fun acc w col => acc + (2 ^ w) * col.eval env))

def eval' (h : BitHeap) (env : BitEnv) : Int := Id.run do
let mut acc := 0
for (col, val) in h.columns do
acc := acc + (2 ^ col) * val.eval env
return acc

/--
Evaluate a bit-heap modulo 2^width, to compute the final sum of all the bits in the heap.
-/
Expand All @@ -40,6 +49,10 @@ structure AdderResult where
def get (h : BitHeap) (column : Nat) : Column :=
h.columns.getD column (Column.empty)

theorem get_eq_getD (h : BitHeap) (column : Nat) :
h.get column = (h.columns[column]?).getD Column.empty := by
simp [get, Std.HashMap.getD_eq_getD_getElem?]

instance : Membership Circuit BitHeap where
mem h c :=
∃ (col : Nat), c ∈ h.get col
Expand All @@ -57,15 +70,35 @@ def highestColumn (h : BitHeap) : Option Nat :=
h.columns.toList.findSome? (fun (idx, col) => if col.height == target then some idx else none)

/--
Add a bit into the bit heap, returning a new bit heap. If the bit already exists in the column, remove it and add it to the next column.
Add a bit into the bit heap, returning a new bit heap.
If the bit already exists in the column, remove it and add it to the next column.
Stops carrying when the column exceeds the width of the bit heap.
-/
partial def addBit (column : Nat) (c : Circuit) (h : BitHeap) : BitHeap :=
let col := h.columns.getD column (Column.empty)
if col.contains c then
let h := h.removeBit column c
addBit (column + 1) c h
else
def addBit (column : Nat) (c : Circuit) (h : BitHeap) : BitHeap :=
if column >= h.width then h else
let col := h.get column
if !col.contains c then
⟨h.width, h.columns.insert column (col.insert c)⟩
else addBit (column + 1) c (h.removeBit column c)
termination_by h.width - column
decreasing_by
have hw : (removeBit column c h).width = h.width := by rfl
rw [hw]
omega

@[simp]
theorem removeBit_width (column : Nat) (c : Circuit) (h : BitHeap) :
(removeBit column c h).width = h.width := by rfl

@[simp]
theorem addBit_width (column : Nat) (c : Circuit) (h : BitHeap) :
(addBit column c h).width = h.width := by
fun_induction addBit with
| case1 => rfl
| case2 => rfl
| case3 _ _ _ _ _ ih =>
rw [removeBit_width] at ih
rw [ih]

def halfAdder (column : Nat) (i j : Circuit) (h : BitHeap) : AdderResult :=
let h := h.removeBit column i
Expand All @@ -87,25 +120,141 @@ def fullAdder (column : Nat) (i j k : Circuit) (h : BitHeap) : AdderResult :=
⟨h, sum, carry⟩

@[simp]
theorem eval_heap_addBit (column : Nat) (c : Circuit) (h : BitHeap) (env : BitEnv) :
(h.addBit column c).eval env = h.eval env + 2^column * (c.eval env).toInt := by
theorem evalMod_heap_removeBit (column : Nat) (c : Circuit) (h : BitHeap) (env : BitEnv) (h1 : c ∈ h.get column) :
(h.removeBit column c).evalMod env = (h.evalMod env - 2^(column) * (c.eval env).toInt) % 2^(h.width) := by
unfold evalMod
rw [removeBit_width]
simp [eval, removeBit]
have : (h.get column |>.erase c).eval env = (h.get column).eval env - 2 ^ column * (c.eval env).toInt := by
sorry
-- have : (h.columns.modify column fun col => col.erase c) = h.columns - 2 ^ column * (c.eval env).toInt := by sorry
repeat rw [Std.HashMap.fold_eq_foldl_toList]
sorry

@[simp]
theorem eval_heap_removeBit (column : Nat) (c : Circuit) (h : BitHeap) (env : BitEnv) (h1 : c ∈ h.get column) :
(h.removeBit column c).eval env = h.eval env - 2^(column) * (c.eval env).toInt := by
simp [BitHeap.eval, BitHeap.removeBit]
theorem by_pow2_of_zero_eval (h : BitHeap) (h1 : col ≥ h.width) :
(2 : Int) ^ h.width ∣ (2 : Int) ^ col := by
sorry
-- exact Nat.pow_dvd_pow_iff_le_right'.mpr h1 -> this works for Nat.

/--
Relate BitHeap.env to sum of a list. (Nat x Column) comes from Std.HashMap.toList, since it returns (Key x Value) pairs.
-/
theorem foldl_sum (l : List (Nat × Column)) (env : BitEnv) (a : Int) :
l.foldl (fun acc (p : Nat × Column) => acc + 2 ^ p.1 * (p.2.eval env : Int)) a =
a + (l.map (fun p => 2 ^ p.1 * (p.2.eval env : Int))).sum := by
induction l generalizing a with
| nil => simp
| cons p ps ih =>
grind

@[grind => ]
theorem eval_insertColumn_eq_eval_add (h : BitHeap) (k : Nat) (v : Column) (env : BitEnv) :
(⟨h.width, h.columns.insert k v⟩ : BitHeap).eval env
= h.eval env + 2 ^ k * (v.eval env : Int) - 2 ^ k * ((h.get k).eval env : Int) := by
cases h
case mk width cols =>
simp [BitHeap.eval]
rw [Std.HashMap.fold_eq_foldl_toList]
rw [Std.HashMap.fold_eq_foldl_toList]
rw [foldl_sum]
rw [foldl_sum]
sorry

@[grind => ]
theorem eval_eraseColumn_eq_eval_sub (h : BitHeap) (k : Nat) (env : BitEnv) :
(⟨h.width, h.columns.erase k⟩ : BitHeap).eval env
= h.eval env - 2 ^ k * ((h.get k).eval env : Int) := by
simp [BitHeap.eval]
rw [Std.HashMap.fold_eq_foldl_toList]
rw [Std.HashMap.fold_eq_foldl_toList]

rw?
sorry


theorem eval_insertColumn (h : BitHeap) (k : Nat) (v : Column) (env : BitEnv) :
(⟨h.width, h.columns.insert k v⟩ : BitHeap).eval env
= (⟨h.width, h.columns.erase k⟩ : BitHeap).eval env + 2 ^ k * (v.eval env : Int) := by
-- have := eval_insertColumn_eq_eval_add h k v env
-- have := eval_eraseColumn_eq_eval_sub h k env
-- grind only

simp [eval]
repeat rw [Std.HashMap.fold_eq_foldl_toList]
repeat rw [foldl_sum]
simp only [zero_add]

-- Both list are permutations of the same the same list
have hp : (h.columns.insert k v).toList.Perm ((k, v) :: (h.columns.erase k).toList) := by
sorry
-- After mapping, they are still permutations of each other
have hp_mapped : ((h.columns.insert k v).toList.map (fun p => 2 ^ p.fst * (p.snd.eval env : Int))).Perm
(((k, v) :: (h.columns.erase k).toList).map (fun p => 2 ^ p.fst * (p.snd.eval env : Int))) :=
hp.map _
have hp_sum := hp_mapped.sum_eq -- Since they are permutations of each other, their sums are equal
rw [hp_sum]
grind


theorem eval_eraseColumn (h : BitHeap) (k : Nat) (env : BitEnv) :
h.eval env
= (⟨h.width, h.columns.erase k⟩ : BitHeap).eval env + 2 ^ k * ((h.get k).eval env : Int) := by
sorry

@[simp]
theorem evalMod_heap_addBit (column : Nat) (c : Circuit) (h : BitHeap) (env : BitEnv) :
(h.addBit column c).evalMod env = (h.evalMod env + 2^column * (c.eval env).toInt) % 2^(h.width) := by
fun_induction addBit with
| case1 col h h1 =>
simp [evalMod]
have h3 : 2 ^ col * (c.eval env).toInt % 2 ^ h.width = 0 := by
generalize hvi : c.eval env = vi
rcases vi
· simp
· simp only [Bool.toInt_true]
rw [Int.mul_one]
apply Int.emod_eq_zero_of_dvd
exact_mod_cast by_pow2_of_zero_eval h h1
simp [Int.add_emod, h3]
| case2 col h h1 =>
simp only [evalMod, Int.emod_add_emod]
have h3 : (⟨h.width, h.columns.insert col ((h.get col).insert c)⟩ : BitHeap).eval env = h.eval env + 2 ^ col * (c.eval env).toInt := by

rw [eval_insertColumn, eval_eraseColumn h col env]

rw [Column.eval_insert]
· grind
· simp
grind
rw [h3]
| case3 _ _ _ h2 h1 ih =>
simp only [ih, removeBit_width]
rw [evalMod_heap_removeBit]
· simp only [Int.emod_add_emod]
grind
· simp at h1
simp [mem_iff_contains]
grind

@[simp]
theorem get_removeBit_of_ne (column : Nat) (h : BitHeap) (i j : Circuit)
(h1 : i ∈ h.get column) (hne : i ≠ j) :
i ∈ (removeBit column j h).get column := by
sorry
rw [get_eq_getD] at h1
rw [get_eq_getD]
simp only [removeBit]
rcases hcol : h.columns[column]?
· simp_all only
grind
· simp_all only [Option.getD_some, mem_iff_contains, ne_eq, Std.HashMap.getElem?_modify_self,
Option.map_some, Column.erase, Column.contains, Std.HashSet.contains_erase]
grind

theorem removeBit_decreases_size (col : Nat) (c : Circuit) (h : BitHeap) (h1: c ∈ h.get col) :
((removeBit col c h).get col).height < (h.get col).height := by
sorry
simp only [removeBit, height_eq_size]
simp [erase]
sorry

theorem double_removeBit_decreases (col : Nat) (c₁ c₂ : Circuit) (h : BitHeap)
(h1 : c₁ ∈ h.get col) (h2 : c₂ ∈ h.get col) (hne : c₁ ≠ c₂) :
Expand All @@ -128,59 +277,44 @@ theorem triple_removeBit_decreases (col : Nat) (c₁ c₂ c₃ : Circuit) (h : B
(double_removeBit_decreases col c₁ c₂ (removeBit col c₃ h) h1' h2' hne12)
(removeBit_decreases_size col c₃ h h3)

@[simp]
theorem toNat_and (a b : Bool) :
(a && b).toNat = a.toNat * b.toNat := by
cases a <;> cases b <;> simp

theorem halfAdder_correct (column : Nat) (i j : Circuit) (h : BitHeap) (h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (hne : i ≠ j):
∀ (env : BitEnv), (h.halfAdder column i j).heap.eval env = h.eval env := by
intros env
have h3 := get_removeBit_of_ne column h j i h2 hne.symm
simp [halfAdder, h1, h3]
generalize hvi : i.eval env = vi
generalize hvj : j.eval env = vj
rcases vi <;> rcases vj <;> grind

-- TODO: Seems to me we need the termination proof of addBit first
@[simp]
theorem halfAdder_preserves_width (column : Nat) (i j : Circuit) (h : BitHeap) :
(h.halfAdder column i j).heap.width = h.width := by
simp [halfAdder, removeBit]
sorry

theorem halfAdder_correct_mod (column : Nat) (i j : Circuit) (h : BitHeap)
(h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (hne : i ≠ j) :
∀ (env : BitEnv), (h.halfAdder column i j).heap.evalMod env = h.evalMod env := by
intros env
have h3 := get_removeBit_of_ne column h j i h2 hne.symm
simp [halfAdder, evalMod_heap_addBit, addBit_width, removeBit_width]
simp only [evalMod_heap_removeBit, h1, h3]
simp [evalMod]
rw [halfAdder_correct column i j h h1 h2 hne env]

theorem fullAdder_correct (column : Nat) (i j k : Circuit) (h : BitHeap)
(h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (h3 : k ∈ h.get column) (hne : i ≠ j) (hne2 : i ≠ k) (hne3 : j ≠ k) :
∀ (env : BitEnv), (h.fullAdder column i j k).heap.eval env = h.eval env := by
intros env
have h4 := get_removeBit_of_ne column h j i h2 hne.symm
have h5 := get_removeBit_of_ne column (removeBit column i h) k
have h6 := h5 j (get_removeBit_of_ne column h k i h3 hne2.symm) hne3.symm
simp [fullAdder, h1, h4, h6]
generalize hvi : i.eval env = vi
generalize hvj : j.eval env = vj
generalize hvk : k.eval env = vk
rcases vi <;> rcases vj <;> rcases vk <;> grind
rcases vi <;> rcases vj <;> simp_all
grind

-- TODO: Seems to me we need the termination proof of addBit first
@[simp]
theorem fullAdder_preserves_width (column : Nat) (i j k : Circuit) (h : BitHeap) :
(h.fullAdder column i j k).heap.width = h.width := by
simp [fullAdder, removeBit]
sorry

theorem fullAdder_correct_mod (column : Nat) (i j k : Circuit) (h : BitHeap)
(h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (h3 : k ∈ h.get column) (hne : i ≠ j) (hne2 : i ≠ k) (hne3 : j ≠ k) :
(h1 : i ∈ h.get column) (h2 : j ∈ h.get column) (h3 : k ∈ h.get column)
(hne : i ≠ j) (hne2 : i ≠ k) (hne3 : j ≠ k) :
∀ (env : BitEnv), (h.fullAdder column i j k).heap.evalMod env = h.evalMod env := by
intros env
have h4 := get_removeBit_of_ne column h j i h2 hne.symm
have h5 := get_removeBit_of_ne column (removeBit column i h) k
have h6 := h5 j (get_removeBit_of_ne column h k i h3 hne2.symm) hne3.symm
simp [fullAdder, evalMod_heap_addBit, addBit_width, removeBit_width]
simp only [evalMod_heap_removeBit, h1, h4, h6]
simp [evalMod]
rw [fullAdder_correct column i j k h h1 h2 h3 hne hne2 hne3 env]
generalize hvi : i.eval env = vi
generalize hvj : j.eval env = vj
generalize hvk : k.eval env = vk
rw [Int.add_emod]
rcases vi <;> rcases vj <;> rcases vk <;> simp_all <;> grind

end BitHeap
Loading
Loading