diff --git a/Strata/DDM/Elab/Core.lean b/Strata/DDM/Elab/Core.lean index b27608045..568ef47e7 100644 --- a/Strata/DDM/Elab/Core.lean +++ b/Strata/DDM/Elab/Core.lean @@ -250,7 +250,6 @@ def translateQualifiedIdent (t : Tree) : MaybeQualifiedIdent := | q`Init.qualifiedIdentImplicit, 1 => Id.run do let .ident _ name := args[0] | return panic! "Expected ident" - let name := name.dropPrefix "«" |>.dropSuffix "»" |>.toString match name.splitOn "." with | [dialect, rest] => .qid { dialect, name := rest } | _ => .name name @@ -1448,7 +1447,7 @@ partial def catElaborator (c : SyntaxCat) : TypingContext → Syntax → ElabM T fun tctx stx => do let some loc := mkSourceRange? stx | panic! "ident missing source location" - let info : IdentInfo := { inputCtx := tctx, loc := loc, val := stx.getId.toString } + let info : IdentInfo := { inputCtx := tctx, loc := loc, val := stx.getId.toString (escape := false) } pure <| .node (.ofIdentInfo info) #[] | q`Init.Num => fun tctx stx => do diff --git a/Strata/DDM/Format.lean b/Strata/DDM/Format.lean index e7dc7d2e0..6cb5fda88 100644 --- a/Strata/DDM/Format.lean +++ b/Strata/DDM/Format.lean @@ -61,11 +61,6 @@ Strips Lean's «» notation if present. Follows SMT-LIB 2.6 specification for quoted symbols. -/ private def formatIdent (s : String) : Format := - -- Strip Lean's «» notation if present - let s := if s.startsWith "«" && s.endsWith "»" then - s.drop 1 |>.dropEnd 1 |>.toString - else - s if needsPipeDelimiters s then Format.text ("|" ++ escapePipeIdent s ++ "|") else diff --git a/Strata/Languages/Core/Verifier.lean b/Strata/Languages/Core/Verifier.lean index 7f1e4e392..9b0c7efae 100644 --- a/Strata/Languages/Core/Verifier.lean +++ b/Strata/Languages/Core/Verifier.lean @@ -522,7 +522,8 @@ structure Diagnostic where deriving Repr, BEq def DiagnosticModel.toDiagnostic (files: Map Strata.Uri Lean.FileMap) (dm: DiagnosticModel): Diagnostic := - let fileMap := (files.find? dm.fileRange.file).getD (panic s!"Could not find {repr dm.fileRange.file} in {repr files.keys} when converting model '{dm}' to a diagnostic") + let fileMap := (files.find? dm.fileRange.file).getD + (dbg_trace s!"Could not find {repr dm.fileRange.file} in {repr files.keys} when converting model '{dm}' to a diagnostic"; default) let startPos := fileMap.toPosition dm.fileRange.range.start let endPos := fileMap.toPosition dm.fileRange.range.stop { diff --git a/Strata/Languages/Laurel/CoreDefinitionsForLaurel.lean b/Strata/Languages/Laurel/CoreDefinitionsForLaurel.lean new file mode 100644 index 000000000..43baec7b8 --- /dev/null +++ b/Strata/Languages/Laurel/CoreDefinitionsForLaurel.lean @@ -0,0 +1,46 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import Strata.DDM.Elab +import Strata.DDM.AST +import Strata.Languages.Laurel.Grammar.LaurelGrammar +import Strata.Languages.Laurel.Grammar.ConcreteToAbstractTreeTranslator + +namespace Strata.Laurel + +/-- +Core map operations (`select`, `update`, `const`) expressed in Laurel syntax. +These are polymorphic map primitives used by the Laurel-to-Core translator. +Since Laurel doesn't have polymorphic types, `int` is used as a placeholder type +for all parameters — the actual types are inferred during Core translation. +-/ +def coreDefinitionsForLaurelDDM := +#strata +program Laurel; + +// The types for these Map functions are incorrect. +// We'll fix them when Laurel supports polymorphism +function select(map: int, key: int) : int + external + +function update(map: int, key: int, value: int) : int + external + +function const(value: int) : int + external + +#end + +/-- +The core map operation definitions as a `Laurel.Program`, parsed at compile time. +-/ +def coreDefinitionsForLaurel : Program := + let uri := Strata.Uri.file "Strata/Languages/Laurel/CoreDefinitionsForLaurel.lean" + match TransM.run uri (parseProgram coreDefinitionsForLaurelDDM) with + | .ok program => program + | .error e => panic! s!"CoreDefinitionsForLaurel parse error: {e}" + +end Strata.Laurel diff --git a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean index a30c6b8a8..fd676dd23 100644 --- a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean @@ -19,12 +19,12 @@ open Lean.Parser (InputContext) open Imperative (MetaData) structure TransState where - uri : Uri + uri : Option Uri errors : Array String abbrev TransM := StateT TransState (Except String) -def TransM.run (uri : Uri) (m : TransM α) : Except String α := +def TransM.run (uri : Option Uri) (m : TransM α) : Except String α := match StateT.run m { uri := uri, errors := #[] } with | .ok (v, _) => .ok v | .error e => .error e @@ -36,8 +36,10 @@ def SourceRange.toMetaData (uri : Uri) (sr : SourceRange) : Imperative.MetaData let fileRangeElt := ⟨ Imperative.MetaDataElem.Field.label "fileRange", .fileRange ⟨ uri, sr.start, sr.stop ⟩ ⟩ #[fileRangeElt] -def getArgMetaData (arg : Arg) : TransM (Imperative.MetaData Core.Expression) := - return SourceRange.toMetaData (← get).uri arg.ann +def getArgMetaData (arg : Arg) : TransM (Imperative.MetaData Core.Expression) := do + return match (← get).uri with + | some uri => SourceRange.toMetaData uri arg.ann + | none => default def checkOp (op : Strata.Operation) (name : QualifiedIdent) (argc : Nat) : TransM Unit := do @@ -55,7 +57,7 @@ def checkOp (op : Strata.Operation) (name : QualifiedIdent) (argc : Nat) : def translateIdent (arg : Arg) : TransM Identifier := do let .ident _ id := arg | TransM.error s!"translateIdent expects ident" - return id + return { text := id } def translateBool (arg : Arg) : TransM Bool := do match arg with @@ -72,7 +74,7 @@ def translateBool (arg : Arg) : TransM Bool := do | x => TransM.error s!"translateBool expects expression or operation, got {repr x}" instance : Inhabited Parameter where - default := { name := "", type := ⟨.TVoid, #[]⟩ } + default := { name := "" , type := ⟨.TVoid, #[]⟩ } def mkHighTypeMd (t : HighType) (md : MetaData Core.Expression) : HighTypeMd := ⟨t, md⟩ def mkStmtExprMd (e : StmtExpr) (md : MetaData Core.Expression) : StmtExprMd := ⟨e, md⟩ @@ -85,7 +87,12 @@ partial def translateHighType (arg : Arg) : TransM HighTypeMd := do match op.name, op.args with | q`Laurel.intType, _ => return mkHighTypeMd .TInt md | q`Laurel.boolType, _ => return mkHighTypeMd .TBool md + | q`Laurel.float64Type, _ => return mkHighTypeMd .TFloat64 md | q`Laurel.stringType, _ => return mkHighTypeMd .TString md + | q`Laurel.mapType, #[keyArg, valArg] => + let keyType ← translateHighType keyArg + let valType ← translateHighType valArg + return mkHighTypeMd (.TMap keyType valType) md | q`Laurel.compositeType, #[nameArg] => let name ← translateIdent nameArg return mkHighTypeMd (.UserDefined name) md @@ -257,12 +264,12 @@ partial def translateStmtExpr (arg : Arg) : TransM StmtExprMd := do let name ← translateIdent nameArg let ty ← translateHighType tyArg let body ← translateStmtExpr bodyArg - return mkStmtExprMd (.Forall name ty body) md + return mkStmtExprMd (.Forall { name := name, type := ty } body) md | q`Laurel.existsExpr, #[nameArg, tyArg, bodyArg] => let name ← translateIdent nameArg let ty ← translateHighType tyArg let body ← translateStmtExpr bodyArg - return mkStmtExprMd (.Exists name ty body) md + return mkStmtExprMd (.Exists { name := name, type := ty } body) md | _, #[arg0] => match getUnaryOp? op.name with | some primOp => let inner ← translateStmtExpr arg0 @@ -370,14 +377,22 @@ def parseProcedure (arg : Arg) : TransM Procedure := do -- Parse modifies clauses (zero or more) let modifies ← translateModifiesClauses modifiesArg -- Parse optional body + let isExternal ← match bodyArg with + | .option _ (some (.op bodyOp)) => match bodyOp.name, bodyOp.args with + | q`Laurel.externalBody, #[] => pure true + | _, _ => pure false + | _ => pure false let body ← match bodyArg with | .option _ (some (.op bodyOp)) => match bodyOp.name, bodyOp.args with | q`Laurel.optionalBody, #[exprArg] => translateCommand exprArg >>= (pure ∘ some) - | _, _ => TransM.error s!"Expected optionalBody operation, got {repr bodyOp.name}" + | q`Laurel.externalBody, #[] => pure none + | _, _ => TransM.error s!"Expected optionalBody or externalBody operation, got {repr bodyOp.name}" | .option _ none => pure none | _ => TransM.error s!"Expected optionalBody, got {repr bodyArg}" -- Determine procedure body kind - let procBody := match postconditions, body with + let procBody := + if isExternal then Body.External + else match postconditions, body with | _ :: _, bodyOpt => Body.Opaque postconditions bodyOpt modifies | [], some b => Body.Transparent b | [], none => Body.Opaque [] none modifies @@ -435,6 +450,51 @@ def parseComposite (arg : Arg) : TransM TypeDefinition := do | _, _ => TransM.error s!"parseComposite expects composite, got {repr op.name}" +def parseDatatypeConstructorArg (arg : Arg) : TransM Parameter := do + let .op op := arg + | TransM.error s!"parseDatatypeConstructorArg expects operation" + match op.name, op.args with + | q`Laurel.datatypeConstructorArg, #[nameArg, typeArg] => + let name ← translateIdent nameArg + let argType ← translateHighType typeArg + return { name := name, type := argType } + | _, _ => + TransM.error s!"parseDatatypeConstructorArg expects datatypeConstructorArg, got {repr op.name}" + +def parseDatatypeConstructor (arg : Arg) : TransM DatatypeConstructor := do + let .op op := arg + | TransM.error s!"parseDatatypeConstructor expects operation" + match op.name, op.args with + | q`Laurel.datatypeConstructor, #[nameArg, argsSeq] => + let name ← translateIdent nameArg + let args ← match argsSeq with + | .seq _ .comma args => args.toList.mapM parseDatatypeConstructorArg + | _ => pure [] + return { name := name, args := args } + | q`Laurel.datatypeConstructorNoArgs, #[nameArg] => + let name ← translateIdent nameArg + return { name := name, args := [] } + | _, _ => + TransM.error s!"parseDatatypeConstructor expects datatypeConstructor, got {repr op.name}" + +def parseDatatype (arg : Arg) : TransM TypeDefinition := do + let .op op := arg + | TransM.error s!"parseDatatype expects operation" + match op.name, op.args with + | q`Laurel.datatype, #[nameArg, constructorsArg] => + let name ← translateIdent nameArg + let constructors ← match constructorsArg with + | .op listOp => match listOp.name, listOp.args with + | q`Laurel.datatypeConstructorList, #[csArg] => + match csArg with + | .seq _ .comma args => args.toList.mapM parseDatatypeConstructor + | singleArg => do let c ← parseDatatypeConstructor singleArg; pure [c] + | _, _ => TransM.error s!"Expected datatypeConstructorList, got {repr listOp.name}" + | _ => TransM.error s!"Expected datatypeConstructorList operation" + return .Datatype { name := name, typeArgs := [], constructors := constructors } + | _, _ => + TransM.error s!"parseDatatype expects datatype, got {repr op.name}" + def parseTopLevel (arg : Arg) : TransM (Option Procedure × Option TypeDefinition) := do let .op op := arg | TransM.error s!"parseTopLevel expects operation" @@ -446,8 +506,11 @@ def parseTopLevel (arg : Arg) : TransM (Option Procedure × Option TypeDefinitio | q`Laurel.topLevelComposite, #[compositeArg] => let typeDef ← parseComposite compositeArg return (none, some typeDef) + | q`Laurel.topLevelDatatype, #[datatypeArg] => + let typeDef ← parseDatatype datatypeArg + return (none, some typeDef) | _, _ => - TransM.error s!"parseTopLevel expects topLevelProcedure or topLevelComposite, got {repr op.name}" + TransM.error s!"parseTopLevel expects topLevelProcedure, topLevelComposite, or topLevelDatatype, got {repr op.name}" /-- Translate concrete Laurel syntax into abstract Laurel syntax diff --git a/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean b/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean index 632e0a69b..0511e158b 100644 --- a/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean +++ b/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean @@ -7,6 +7,7 @@ -- Laurel dialect definition, loaded from LaurelGrammar.st -- NOTE: Changes to LaurelGrammar.st are not automatically tracked by the build system. -- Update this file (e.g. this comment) to trigger a recompile after modifying LaurelGrammar.st. +-- Last grammar change: added externalBody op for external procedure/function bodies. import Strata.DDM.Integration.Lean namespace Strata.Laurel diff --git a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st index b1432aef1..2c7850c54 100644 --- a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st +++ b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st @@ -6,7 +6,9 @@ dialect Laurel; category LaurelType; op intType : LaurelType => "int"; op boolType : LaurelType => "bool"; +op float64Type : LaurelType => "float64"; op stringType : LaurelType => "string"; +op mapType (keyType: LaurelType, valueType: LaurelType): LaurelType => "Map " keyType " " valueType; op compositeType (name: Ident): LaurelType => name; category StmtExpr; @@ -106,6 +108,20 @@ op optionalExtends(parents: CommaSepBy Ident): OptionalExtends => "extends " par category Composite; op composite (name: Ident, extending: Option OptionalExtends, fields: Seq Field): Composite => "composite " name extending "{" fields "}"; +// Datatype definitions +category DatatypeConstructorArg; +op datatypeConstructorArg (name: Ident, argType: LaurelType): DatatypeConstructorArg => name ":" argType; + +category DatatypeConstructor; +op datatypeConstructor (name: Ident, args: CommaSepBy DatatypeConstructorArg): DatatypeConstructor => name "(" args ")"; +op datatypeConstructorNoArgs (name: Ident): DatatypeConstructor => name; + +category DatatypeConstructorList; +op datatypeConstructorList (constructors: CommaSepBy DatatypeConstructor): DatatypeConstructorList => constructors; + +category Datatype; +op datatype (name: Ident, constructors: DatatypeConstructorList): Datatype => "datatype " name "{" constructors "}"; + // Procedures category OptionalReturnType; op optionalReturnType(returnType: LaurelType): OptionalReturnType => ":" returnType; @@ -124,6 +140,7 @@ op returnParameters(parameters: CommaSepBy Parameter): ReturnParameters => "retu category OptionalBody; op optionalBody(body: StmtExpr): OptionalBody => body:0; +op externalBody: OptionalBody => "external"; category Procedure; op procedure (name : Ident, parameters: CommaSepBy Parameter, @@ -147,5 +164,6 @@ op function (name : Ident, parameters: CommaSepBy Parameter, category TopLevel; op topLevelComposite(composite: Composite): TopLevel => composite; op topLevelProcedure(procedure: Procedure): TopLevel => procedure; +op topLevelDatatype(datatype: Datatype): TopLevel => datatype; op program (items: Seq TopLevel): Command => items; \ No newline at end of file diff --git a/Strata/Languages/Laurel/HeapParameterization.lean b/Strata/Languages/Laurel/HeapParameterization.lean index b1ff3e6fd..5f298009e 100644 --- a/Strata/Languages/Laurel/HeapParameterization.lean +++ b/Strata/Languages/Laurel/HeapParameterization.lean @@ -7,6 +7,7 @@ import Strata.Languages.Laurel.Laurel import Strata.Languages.Laurel.LaurelFormat import Strata.Languages.Laurel.LaurelTypes +import Strata.Languages.Laurel.HeapParameterizationConstants import Strata.Util.Tactics /- @@ -77,8 +78,8 @@ def collectExpr (expr : StmtExpr) : StateM AnalysisResult Unit := do | .ReferenceEquals l r => collectExprMd l; collectExprMd r | .AsType t _ => collectExprMd t | .IsType t _ => collectExprMd t - | .Forall _ _ b => collectExprMd b - | .Exists _ _ b => collectExprMd b + | .Forall _ b => collectExprMd b + | .Exists _ b => collectExprMd b | .Assigned n => collectExprMd n | .Old v => collectExprMd v | .Fresh v => collectExprMd v @@ -112,6 +113,7 @@ def analyzeProc (proc : Procedure) : AnalysisResult := writesHeapDirectly := r1.writesHeapDirectly || r2.writesHeapDirectly, callees := r1.callees ++ r2.callees } | .Abstract postconds => (postconds.forM collectExprMd).run {} |>.2 + | .External => {} -- Also analyze preconditions let precondResult := (proc.preconditions.forM collectExprMd).run {} |>.2 { readsHeapDirectly := bodyResult.readsHeapDirectly || precondResult.readsHeapDirectly, @@ -147,22 +149,12 @@ def computeWritesHeap (procs : List Procedure) : List Identifier := fixpoint procs.length direct structure TransformState where - fieldConstants : List Constant := [] heapReaders : List Identifier heapWriters : List Identifier - fieldTypes : List (Identifier × HighTypeMd) := [] -- Maps "TypeName.fieldName" to their value types - types : List TypeDefinition := [] -- Type definitions for resolving field owners freshCounter : Nat := 0 -- Counter for generating fresh variable names abbrev TransformM := StateM TransformState -def addFieldConstant (name : Identifier) (valueType : HighTypeMd) : TransformM Unit := - modify fun s => if s.fieldConstants.any (·.name == name) then s - else { s with fieldConstants := { name := name, type := ⟨.TTypedField valueType, #[] ⟩ } :: s.fieldConstants } - -def lookupFieldType (name : Identifier) : TransformM (Option HighTypeMd) := do - return (← get).fieldTypes.find? (·.1 == name) |>.map (·.2) - /-- Get the Box destructor name for a given Laurel HighType -/ def boxDestructorName (ty : HighType) : Identifier := match ty with @@ -195,35 +187,14 @@ def freshVarName : TransformM Identifier := do /-- Helper to wrap a StmtExpr into StmtExprMd with empty metadata -/ private def mkMd (e : StmtExpr) : StmtExprMd := ⟨e, #[]⟩ -/-- -Find the composite type that actually declares a given field, walking up the inheritance chain. -Returns the declaring type's name, or falls back to the given type name. --/ -def findFieldOwner (types : List TypeDefinition) (typeName : Identifier) (fieldName : Identifier) : Identifier := - let rec go (fuel : Nat) (current : Identifier) : Option Identifier := - match fuel with - | 0 => none - | fuel' + 1 => - types.findSome? fun td => - match td with - | .Composite ct => - if ct.name == current then - if ct.fields.any (·.name == fieldName) then some ct.name - else ct.extending.findSome? (go fuel') - else none - | _ => none - (go types.length typeName).getD (panic "type inheritance forms a cycle") - /-- Resolve the owning composite type name for a field access by computing the target expression's type. Returns the qualified field name "DeclaringType.fieldName". -/ -def resolveQualifiedFieldName (env : TypeEnv) (types : List TypeDefinition) (target : StmtExprMd) (fieldName : Identifier) : Identifier := - match (computeExprType env types target).val with - | .UserDefined typeName => - let owner := findFieldOwner types typeName fieldName - owner ++ "." ++ fieldName - | _ => panic "assigning to a target that's not a composite type" +def resolveQualifiedFieldName (model: SemanticModel) (fieldName : Identifier) : String := + match model.get fieldName with + | .field owner _ => owner.text ++ "." ++ fieldName.text + | _ => panic! s!"resolveQualifiedFieldName {fieldName} did not resolve to a field" /-- Transform an expression, adding heap parameters where needed. @@ -231,23 +202,20 @@ Transform an expression, adding heap parameters where needed. - `env`: the type environment for resolving field owners - `valueUsed`: whether the result value of this expression is used (affects optimization of heap-writing calls) -/ -def heapTransformExpr (heapVar : Identifier) (env : TypeEnv) (expr : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd := - recurse env expr valueUsed +def heapTransformExpr (heapVar : Identifier) (model: SemanticModel) (expr : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd := + recurse expr valueUsed where - recurse (env : TypeEnv) (expr : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd := do + recurse (expr : StmtExprMd) (valueUsed : Bool := true) : TransformM StmtExprMd := do let md := expr.md - let types := (← get).types match _h : expr.val with | .FieldSelect selectTarget fieldName => - let qualifiedName := resolveQualifiedFieldName env types selectTarget fieldName - let fieldType ← lookupFieldType qualifiedName - let valTy := fieldType.getD (panic s!"could not find field type for {qualifiedName}") - addFieldConstant qualifiedName valTy - let readExpr := ⟨ .StaticCall "readField" [mkMd (.Identifier heapVar), selectTarget, mkMd (.Identifier qualifiedName)], md ⟩ + let qualifiedName : Identifier := resolveQualifiedFieldName model fieldName + let valTy := (model.get fieldName).getType.getD (panic! "heapTransformExpr1") + let readExpr := ⟨ .StaticCall "readField" [mkMd (.Identifier heapVar), selectTarget, mkMd (.StaticCall qualifiedName [])], md ⟩ -- Unwrap Box: apply the appropriate destructor return mkMd <| .StaticCall (boxDestructorName valTy.val) [readExpr] | .StaticCall callee args => - let args' ← args.mapM (recurse env ·) + let args' ← args.mapM (recurse ·) let calleeReadsHeap ← readsHeap callee let calleeWritesHeap ← writesHeap callee if calleeWritesHeap then @@ -265,73 +233,67 @@ where else return ⟨ .StaticCall callee args', md ⟩ | .InstanceCall callTarget callee args => - let t ← recurse env callTarget - let args' ← args.mapM (recurse env ·) + let t ← recurse callTarget + let args' ← args.mapM (recurse ·) return ⟨ .InstanceCall t callee args', md ⟩ | .IfThenElse c t e => - let e' ← match e with | some x => some <$> recurse env x valueUsed | none => pure none - return ⟨ .IfThenElse (← recurse env c) (← recurse env t valueUsed) e', md ⟩ + let e' ← match e with | some x => some <$> recurse x valueUsed | none => pure none + return ⟨ .IfThenElse (← recurse c) (← recurse t valueUsed) e', md ⟩ | .Block stmts label => let n := stmts.length - let rec processStmts (env : TypeEnv) (idx : Nat) (remaining : List StmtExprMd) : TransformM (List StmtExprMd) := do + let rec processStmts (idx : Nat) (remaining : List StmtExprMd) : TransformM (List StmtExprMd) := do match remaining with | [] => pure [] | s :: rest => let isLast := idx == n - 1 - -- Extend env for LocalVariable declarations - let env' := match s.val with - | .LocalVariable name ty _ => (name, ty) :: env - | _ => env - let s' ← recurse env s (isLast && valueUsed) - let rest' ← processStmts env' (idx + 1) rest + let s' ← recurse s (isLast && valueUsed) + let rest' ← processStmts (idx + 1) rest pure (s' :: rest') termination_by sizeOf remaining - let stmts' ← processStmts env 0 stmts + let stmts' ← processStmts 0 stmts return ⟨ .Block stmts' label, md ⟩ | .LocalVariable n ty i => - let i' ← match i with | some x => some <$> recurse env x | none => pure none + let i' ← match i with | some x => some <$> recurse x | none => pure none return ⟨ .LocalVariable n ty i', md ⟩ | .While c invs d b => - let invs' ← invs.mapM (recurse env ·) - return ⟨ .While (← recurse env c) invs' d (← recurse env b false), md ⟩ + let invs' ← invs.mapM (recurse ·) + return ⟨ .While (← recurse c) invs' d (← recurse b false), md ⟩ | .Return v => - let v' ← match v with | some x => some <$> recurse env x | none => pure none + let v' ← match v with | some x => some <$> recurse x | none => pure none return ⟨ .Return v', md ⟩ | .Assign targets v => match targets with | [fieldSelectMd] => match _h2 : fieldSelectMd.val with | .FieldSelect target fieldName => - let qualifiedName := resolveQualifiedFieldName env types target fieldName - let fieldType ← lookupFieldType qualifiedName - let valTy := fieldType.getD (panic s!"could not find field type for {qualifiedName}") - addFieldConstant qualifiedName valTy - let target' ← recurse env target - let v' ← recurse env v + let qualifiedName : Identifier := resolveQualifiedFieldName model fieldName + let valTy := (model.get fieldName).getType.getD (panic! "heapTransformExpr2") + let target' ← recurse target + let v' ← recurse v -- Wrap value in Box constructor let boxedVal := mkMd <| .StaticCall (boxConstructorName valTy.val) [v'] let heapAssign := ⟨ .Assign [mkMd (.Identifier heapVar)] - (mkMd (.StaticCall "updateField" [mkMd (.Identifier heapVar), target', mkMd (.Identifier qualifiedName), boxedVal])), md ⟩ + (mkMd (.StaticCall "updateField" [mkMd (.Identifier heapVar), target', mkMd (.StaticCall qualifiedName []), boxedVal])), md ⟩ if valueUsed then return ⟨ .Block [heapAssign, v'] none, md ⟩ else return heapAssign | _ => - let tgt' ← recurse env fieldSelectMd - return ⟨ .Assign [tgt'] (← recurse env v), md ⟩ + let tgt' ← recurse fieldSelectMd + return ⟨ .Assign [tgt'] (← recurse v), md ⟩ | [] => - return ⟨ .Assign [] (← recurse env v), md ⟩ + return ⟨ .Assign [] (← recurse v), md ⟩ | tgt :: rest => - let tgt' ← recurse env tgt - let targets' ← rest.mapM (recurse env ·) - return ⟨ .Assign (tgt' :: targets') (← recurse env v), md ⟩ - | .PureFieldUpdate t f v => return ⟨ .PureFieldUpdate (← recurse env t) f (← recurse env v), md ⟩ + let tgt' ← recurse tgt + let targets' ← rest.mapM (recurse ·) + return ⟨ .Assign (tgt' :: targets') (← recurse v), md ⟩ + | .PureFieldUpdate t f v => return ⟨ .PureFieldUpdate (← recurse t) f (← recurse v), md ⟩ | .PrimitiveOp op args => - let args' ← args.mapM (recurse env ·) + let args' ← args.mapM (recurse ·) -- For == and != on Composite types, compare refs instead match op, args with | .Eq, [e1, _e2] => - let ty := (computeExprType env types e1).val + let ty := (computeExprType model e1).val match ty with | .UserDefined _ => let ref1 := mkMd (.StaticCall "Composite..ref!" [args'[0]!]) @@ -339,7 +301,7 @@ where return ⟨ .PrimitiveOp .Eq [ref1, ref2], md ⟩ | _ => return ⟨ .PrimitiveOp op args', md ⟩ | .Neq, [e1, _e2] => - let ty := (computeExprType env types e1).val + let ty := (computeExprType model e1).val match ty with | .UserDefined _ => let ref1 := mkMd (.StaticCall "Composite..ref!" [args'[0]!]) @@ -348,22 +310,22 @@ where | _ => return ⟨ .PrimitiveOp op args', md ⟩ | _, _ => return ⟨ .PrimitiveOp op args', md ⟩ | .New _ => return expr - | .ReferenceEquals l r => return ⟨ .ReferenceEquals (← recurse env l) (← recurse env r), md ⟩ + | .ReferenceEquals l r => return ⟨ .ReferenceEquals (← recurse l) (← recurse r), md ⟩ | .AsType t ty => - let t' ← recurse env t valueUsed + let t' ← recurse t valueUsed let isCheck := ⟨ .IsType t' ty, md ⟩ let assertStmt := ⟨ .Assert isCheck, md ⟩ return ⟨ .Block [assertStmt, t'] none, md ⟩ - | .IsType t ty => return ⟨ .IsType (← recurse env t) ty, md ⟩ - | .Forall n ty b => return ⟨ .Forall n ty (← recurse env b), md ⟩ - | .Exists n ty b => return ⟨ .Exists n ty (← recurse env b), md ⟩ - | .Assigned n => return ⟨ .Assigned (← recurse env n), md ⟩ - | .Old v => return ⟨ .Old (← recurse env v), md ⟩ - | .Fresh v => return ⟨ .Fresh (← recurse env v), md ⟩ - | .Assert c => return ⟨ .Assert (← recurse env c), md ⟩ - | .Assume c => return ⟨ .Assume (← recurse env c), md ⟩ - | .ProveBy v p => return ⟨ .ProveBy (← recurse env v) (← recurse env p), md ⟩ - | .ContractOf ty f => return ⟨ .ContractOf ty (← recurse env f), md ⟩ + | .IsType t ty => return ⟨ .IsType (← recurse t) ty, md ⟩ + | .Forall p b => return ⟨ .Forall p (← recurse b), md ⟩ + | .Exists p b => return ⟨ .Exists p (← recurse b), md ⟩ + | .Assigned n => return ⟨ .Assigned (← recurse n), md ⟩ + | .Old v => return ⟨ .Old (← recurse v), md ⟩ + | .Fresh v => return ⟨ .Fresh (← recurse v), md ⟩ + | .Assert c => return ⟨ .Assert (← recurse c), md ⟩ + | .Assume c => return ⟨ .Assume (← recurse c), md ⟩ + | .ProveBy v p => return ⟨ .ProveBy (← recurse v) (← recurse p), md ⟩ + | .ContractOf ty f => return ⟨ .ContractOf ty (← recurse f), md ⟩ | _ => return expr termination_by sizeOf expr decreasing_by @@ -376,16 +338,12 @@ where | -- For the FieldSelect-inside-Assign case: target < fieldSelectMd < expr (have hfs := WithMetadata.sizeOf_val_lt fieldSelectMd; term_by_mem) -def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do - let heapName := "$heap" - let heapInName := "$heap_in" +def heapTransformProcedure (model: SemanticModel) (proc : Procedure) : TransformM Procedure := do + let heapName : Identifier := "$heap" + let heapInName : Identifier := "$heap_in" let readsHeap := (← get).heapReaders.contains proc.name let writesHeap := (← get).heapWriters.contains proc.name - -- Build the type environment from procedure parameters and constants - let initEnv : TypeEnv := proc.inputs.map (fun p => (p.name, p.type)) ++ - proc.outputs.map (fun p => (p.name, p.type)) - if writesHeap then -- This procedure writes the heap - add $heap_in as input and $heap as output -- At the start, assign $heap_in to $heap, then use $heap throughout @@ -396,29 +354,30 @@ def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do let outputs' := heapOutParam :: proc.outputs -- Preconditions use $heap_in (the input state) - let preconditions' ← proc.preconditions.mapM (heapTransformExpr heapInName initEnv) + let preconditions' ← proc.preconditions.mapM (heapTransformExpr heapInName model) let bodyValueIsUsed := !proc.outputs.isEmpty let body' ← match proc.body with | .Transparent bodyExpr => -- First assign $heap_in to $heap, then transform body using $heap let assignHeap := mkMd (.Assign [mkMd (.Identifier heapName)] (mkMd (.Identifier heapInName))) - let bodyExpr' ← heapTransformExpr heapName initEnv bodyExpr bodyValueIsUsed + let bodyExpr' ← heapTransformExpr heapName model bodyExpr bodyValueIsUsed pure (.Transparent (mkMd (.Block [assignHeap, bodyExpr'] none))) | .Opaque postconds impl modif => -- Postconditions use $heap (the output state) - let postconds' ← postconds.mapM (heapTransformExpr heapName initEnv ·) + let postconds' ← postconds.mapM (heapTransformExpr heapName model ·) let impl' ← match impl with | some implExpr => let assignHeap := mkMd (.Assign [mkMd (.Identifier heapName)] (mkMd (.Identifier heapInName))) - let implExpr' ← heapTransformExpr heapName initEnv implExpr bodyValueIsUsed + let implExpr' ← heapTransformExpr heapName model implExpr bodyValueIsUsed pure (some (mkMd (.Block [assignHeap, implExpr'] none))) | none => pure none - let modif' ← modif.mapM (heapTransformExpr heapName initEnv ·) + let modif' ← modif.mapM (heapTransformExpr heapName model ·) pure (.Opaque postconds' impl' modif') | .Abstract postconds => - let postconds' ← postconds.mapM (heapTransformExpr heapName initEnv ·) + let postconds' ← postconds.mapM (heapTransformExpr heapName model ·) pure (.Abstract postconds') + | .External => pure .External return { proc with inputs := inputs', @@ -431,20 +390,21 @@ def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do let heapParam : Parameter := { name := heapName, type := ⟨.THeap, #[]⟩ } let inputs' := heapParam :: proc.inputs - let preconditions' ← proc.preconditions.mapM (heapTransformExpr heapName initEnv) + let preconditions' ← proc.preconditions.mapM (heapTransformExpr heapName model) let body' ← match proc.body with | .Transparent bodyExpr => - let bodyExpr' ← heapTransformExpr heapName initEnv bodyExpr + let bodyExpr' ← heapTransformExpr heapName model bodyExpr pure (.Transparent bodyExpr') | .Opaque postconds impl modif => - let postconds' ← postconds.mapM (heapTransformExpr heapName initEnv ·) - let impl' ← impl.mapM (heapTransformExpr heapName initEnv ·) - let modif' ← modif.mapM (heapTransformExpr heapName initEnv ·) + let postconds' ← postconds.mapM (heapTransformExpr heapName model ·) + let impl' ← impl.mapM (heapTransformExpr heapName model ·) + let modif' ← modif.mapM (heapTransformExpr heapName model ·) pure (.Opaque postconds' impl' modif') | .Abstract postconds => - let postconds' ← postconds.mapM (heapTransformExpr heapName initEnv ·) + let postconds' ← postconds.mapM (heapTransformExpr heapName model ·) pure (.Abstract postconds') + | .External => pure .External return { proc with inputs := inputs', @@ -455,26 +415,28 @@ def heapTransformProcedure (proc : Procedure) : TransformM Procedure := do -- This procedure doesn't read or write the heap - no changes needed return proc -def heapParameterization (program : Program) : Program := +def heapParameterization (model: SemanticModel) (program : Program) : Program := + let program := { program with + types := program.types + staticProcedures := program.staticProcedures } let heapReaders := computeReadsHeap program.staticProcedures let heapWriters := computeWritesHeap program.staticProcedures - -- Extract field types from composite type definitions, qualified with composite type name - let fieldTypes := program.types.foldl (fun acc typeDef => - match typeDef with - | .Composite ct => acc ++ ct.fields.map (fun f => (ct.name ++ "." ++ f.name, f.type)) - | .Constrained _ => acc - | .Datatype _ => acc) [] - let (procs', _) := (program.staticProcedures.mapM heapTransformProcedure).run - { heapReaders, heapWriters, fieldTypes, types := program.types } + let (procs', _) := (program.staticProcedures.mapM (heapTransformProcedure model)).run + { heapReaders, heapWriters } -- Collect all qualified field names and generate a Field datatype let fieldNames := program.types.foldl (fun acc td => match td with - | .Composite ct => acc ++ ct.fields.map (fun f => ct.name ++ "." ++ f.name) + | .Composite ct => acc ++ ct.fields.map (fun f => (mkId $ ct.name.text ++ "." ++ f.name.text)) | _ => acc) ([] : List Identifier) let fieldDatatype : TypeDefinition := .Datatype { name := "Field", typeArgs := [], constructors := fieldNames.map fun n => { name := n, args := [] } } + -- Remove fields from composite types since they are now stored in the heap + let types' := program.types.map fun td => + match td with + | .Composite ct => .Composite { ct with fields := [] } + | other => other { program with - staticProcedures := procs', - types := program.types ++ [fieldDatatype] } + staticProcedures := heapConstants.staticProcedures ++ procs', + types := fieldDatatype :: heapConstants.types ++ types' } end Strata.Laurel diff --git a/Strata/Languages/Laurel/HeapParameterizationConstants.lean b/Strata/Languages/Laurel/HeapParameterizationConstants.lean new file mode 100644 index 000000000..5fde9755e --- /dev/null +++ b/Strata/Languages/Laurel/HeapParameterizationConstants.lean @@ -0,0 +1,81 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import Strata.DDM.Elab +import Strata.DDM.AST +import Strata.Languages.Laurel.Grammar.LaurelGrammar +import Strata.Languages.Laurel.Grammar.ConcreteToAbstractTreeTranslator + +namespace Strata.Laurel + +/-- +The Laurel Core prelude defines the heap model types and operations +used by the Laurel-to-Core translator. These declarations are expressed +in Laurel syntax via the `#strata program Laurel` macro and parsed into +a `Laurel.Program` at compile time. + +The heap model uses: +- `Composite` - datatype with a reference (int) and a runtime type tag +- `Field` - abstract type for field names (zero-constructor datatype) +- `TypeTag` - abstract type for type tags (zero-constructor datatype) +- `Box` - tagged union for field values (int, bool, float64, Composite) +- `Heap` - datatype with a `data` map and a `nextReference` for allocation +- `readField` / `updateField` / `increment` - heap access functions +-/ + +private def laurelPreludeDDM := +#strata +program Laurel; + +// Composite: datatype with a reference (int) +datatype Composite { MkComposite(ref: int) } + +// Box: tagged union for field values +datatype Box { + BoxInt(intVal: int), + BoxBool(boolVal: bool), + BoxFloat64(float64Val: float64), + BoxComposite(compositeVal: Composite) +} + +// Heap: contains the data map and a nextReference for allocation +datatype Heap { + MkHeap(data: Map Composite Map Field Box, nextReference: int) +} + +// Read a field from the heap: readField(heap, obj, field) = Heap..data!(heap)[obj][field] +function readField(heap: Heap, obj: Composite, field: Field): Box { + select(select(Heap..data!(heap), obj), field) +} + +// Update a field in the heap +function updateField(heap: Heap, obj: Composite, field: Field, val: Box): Heap { + MkHeap( + update(Heap..data!(heap), obj, + update(select(Heap..data!(heap), obj), field, val)), + Heap..nextReference!(heap)) +} + +// Increment the heap allocation nextReference, returning a new heap +function increment(heap: Heap): Heap { + MkHeap(Heap..data!(heap), Heap..nextReference!(heap) + 1) +} + +// The #strata macro does not identify the end macro correctly, +// because Laurel's grammar also parses # signs +// Having this datatype here brings the parser in a state where it won't consume the # +// A fix would be to require ';' after the body of functions/procedures +datatype Workaround {} +#end + +/-- The Laurel Core prelude as a Laurel Program. -/ +def heapConstants : Program := + let uri := Strata.Uri.file "Strata/Languages/Laurel/HeapParameterizationConstants.lean" + match Laurel.TransM.run uri (Laurel.parseProgram laurelPreludeDDM) with + | .ok program => program + | .error e => panic! s!"Laurel heap prelude parse error: {e}" + +end Strata.Laurel diff --git a/Strata/Languages/Laurel/Laurel.lean b/Strata/Languages/Laurel/Laurel.lean index 80d66d0a3..46519de20 100644 --- a/Strata/Languages/Laurel/Laurel.lean +++ b/Strata/Languages/Laurel/Laurel.lean @@ -15,7 +15,30 @@ Documentation for Laurel can be found in docs/verso/LaurelDoc.lean namespace Strata namespace Laurel -abbrev Identifier := String /- Potentially this could be an Int to save resources. -/ + +/-- A name-introduction site (variable declaration, procedure, field, type, etc.). + Carries a mandatory unique ID assigned by the resolution pass. -/ +structure Identifier where + /-- The declared name. -/ + text : String + /-- Unique ID assigned by the resolution pass. -/ + uniqueId : Option Nat := none + deriving Repr + +-- Temporary hack because the Python through Laurel pipeline doesn't resolve +instance : BEq Identifier where + beq a b := a.text == b.text + +instance : Inhabited Identifier where + default := { text := "defaultIdentifier" } + +instance : ToString Identifier where + toString id := id.text + +instance : Coe String Identifier where + coe s := Identifier.mk s none + +def mkId (name: String): Identifier := Identifier.mk name none /-- Primitive operations available in Laurel expressions. @@ -112,7 +135,7 @@ inductive HighType : Type where | TSet (elementType : WithMetadata HighType) /-- Map type. -/ | TMap (keyType : WithMetadata HighType) (valueType : WithMetadata HighType) - /-- A reference to a user-defined composite or constrained type by name. -/ + /-- A Identifier to a user-defined composite or constrained type by name. -/ | UserDefined (name : Identifier) /-- A generic type application, e.g. `List`. -/ | Applied (base : WithMetadata HighType) (typeArguments : List (WithMetadata HighType)) @@ -146,7 +169,7 @@ structure Procedure : Type where /-- Optional termination measure for recursive procedures. -/ decreases : Option (WithMetadata StmtExpr) -- optionally prove termination /-- If true, the body may only have functional constructs, so no destructive assignments or loops. -/ - isFunctional : Bool := false + isFunctional : Bool /-- The procedure body: transparent, opaque, or abstract. -/ body : Body /-- Source-level metadata. -/ @@ -188,6 +211,8 @@ inductive Body where (modifies : List (WithMetadata StmtExpr)) /-- An abstract body that must be overridden in extending types. A type containing any members with abstract bodies cannot be instantiated. -/ | Abstract (postconditions : List (WithMetadata StmtExpr)) + /-- An external body for procedures that are not translated to Core (e.g., built-in primitives). -/ + | External /-- The unified statement-expression type for Laurel programs. @@ -201,7 +226,7 @@ inductive StmtExpr : Type where /-- Conditional with a then-branch and optional else-branch. -/ | IfThenElse (cond : WithMetadata StmtExpr) (thenBranch : WithMetadata StmtExpr) (elseBranch : Option (WithMetadata StmtExpr)) /-- A sequence of statements with an optional label for `Exit`. -/ - | Block (statements : List (WithMetadata StmtExpr)) (label : Option Identifier) + | Block (statements : List (WithMetadata StmtExpr)) (label : Option String) /-- A local variable declaration with a type and optional initializer. The initializer must be set if this `StmtExpr` is pure. -/ | LocalVariable (name : Identifier) (type : WithMetadata HighType) (initializer : Option (WithMetadata StmtExpr)) /-- A while loop with a condition, invariants, optional termination measure, and body. Only allowed in impure contexts. -/ @@ -209,7 +234,7 @@ inductive StmtExpr : Type where (decreases : Option (WithMetadata StmtExpr)) (body : WithMetadata StmtExpr) /-- Exit a labelled block. Models `break` and `continue` statements. -/ - | Exit (target : Identifier) + | Exit (target : String) /-- Return from the enclosing procedure with an optional value. -/ | Return (value : Option (WithMetadata StmtExpr)) /-- An integer literal. -/ @@ -231,8 +256,8 @@ inductive StmtExpr : Type where /-- Apply a primitive operation to the given arguments. -/ | PrimitiveOp (operator : Operation) (arguments : List (WithMetadata StmtExpr)) /-- Create new object (`new`). -/ - | New (name: Identifier) - /-- Reference to the current object (`this`/`self`). -/ + | New (ref : Identifier) + /-- Identifier to the current object (`this`/`self`). -/ | This /-- Reference equality test between two expressions. -/ | ReferenceEquals (lhs : WithMetadata StmtExpr) (rhs : WithMetadata StmtExpr) @@ -242,10 +267,10 @@ inductive StmtExpr : Type where | IsType (target : WithMetadata StmtExpr) (type : WithMetadata HighType) /-- Call an instance method on a target object. -/ | InstanceCall (target : WithMetadata StmtExpr) (callee : Identifier) (arguments : List (WithMetadata StmtExpr)) - /-- Universal quantification over a typed variable. -/ - | Forall (name : Identifier) (type : WithMetadata HighType) (body : WithMetadata StmtExpr) - /-- Existential quantification over a typed variable. -/ - | Exists (name : Identifier) (type : WithMetadata HighType) (body : WithMetadata StmtExpr) + /-- Universal quantification over a typed parameter. -/ + | Forall (param : Parameter) (body : WithMetadata StmtExpr) + /-- Existential quantification over a typed parameter. -/ + | Exists (param : Parameter) (body : WithMetadata StmtExpr) /-- Check whether a variable has been assigned. -/ | Assigned (name : WithMetadata StmtExpr) /-- Refer to the pre-state value of an expression in a postcondition. -/ @@ -299,7 +324,7 @@ def highEq (a : HighTypeMd) (b : HighTypeMd) : Bool := match _a: a.val, _b: b.va | HighType.TTypedField t1, HighType.TTypedField t2 => highEq t1 t2 | HighType.TSet t1, HighType.TSet t2 => highEq t1 t2 | HighType.TMap k1 v1, HighType.TMap k2 v2 => highEq k1 k2 && highEq v1 v2 - | HighType.UserDefined n1, HighType.UserDefined n2 => n1 == n2 + | HighType.UserDefined r1, HighType.UserDefined r2 => r1.text == r2.text | HighType.Applied b1 args1, HighType.Applied b2 args2 => highEq b1 b2 && args1.length == args2.length && (args1.attach.zip args2 |>.all (fun (a1, a2) => highEq a1.1 a2)) | HighType.Pure b1, HighType.Pure b2 => highEq b1 b2 @@ -320,6 +345,10 @@ def HighType.isBool : HighType → Bool | TBool => true | _ => false +def Body.isExternal : Body → Bool + | .External => true + | _ => false + def HighTypeMd.isBool (t : HighTypeMd) : Bool := t.val.isBool /-- @@ -372,7 +401,7 @@ structure ConstrainedType where /-- A constructor of a Laurel datatype, with a name and typed arguments. -/ structure DatatypeConstructor where name : Identifier - args : List (Identifier × HighTypeMd) + args : List Parameter /-- A Laurel datatype definition with optional type parameters. Zero constructors produces an opaque (abstract) type in Core. diff --git a/Strata/Languages/Laurel/LaurelFormat.lean b/Strata/Languages/Laurel/LaurelFormat.lean index 263749946..a241f3e84 100644 --- a/Strata/Languages/Laurel/LaurelFormat.lean +++ b/Strata/Languages/Laurel/LaurelFormat.lean @@ -9,7 +9,7 @@ import Strata.Languages.Laurel.Laurel namespace Strata namespace Laurel -open Std (Format) +open Std (Format format) open Std.Format def formatOperation : Operation → Format @@ -49,7 +49,7 @@ def formatHighTypeVal : HighType → Format | .TTypedField valueType => "Field[" ++ formatHighType valueType ++ "]" | .TSet elementType => "Set[" ++ formatHighType elementType ++ "]" | .TMap keyType valueType => "Map[" ++ formatHighType keyType ++ ", " ++ formatHighType valueType ++ "]" - | .UserDefined name => Format.text name + | .UserDefined ref => format ref | .Applied base args => Format.text "(" ++ formatHighType base ++ " " ++ Format.joinSep (args.map formatHighType) " " ++ ")" @@ -78,8 +78,8 @@ def formatStmtExprVal (s : StmtExpr) : Format := | some e => " else " ++ formatStmtExpr e | .Block stmts _ => group $ "{" ++ nestD (line ++ joinSep (stmts.map formatStmtExpr) (";" ++ line)) ++ line ++ "}" - | .LocalVariable name ty init => - "var " ++ Format.text name ++ ": " ++ formatHighType ty ++ + | .LocalVariable defn ty init => + "var " ++ format defn ++ ": " ++ formatHighType ty ++ match init with | none => "" | some e => " := " ++ formatStmtExpr e @@ -96,24 +96,24 @@ def formatStmtExprVal (s : StmtExpr) : Format := | .LiteralInt n => Format.text (toString n) | .LiteralBool b => if b then "true" else "false" | .LiteralString s => "\"" ++ Format.text s ++ "\"" - | .Identifier name => Format.text name + | .Identifier ref => format ref | .Assign [single] value => formatStmtExpr single ++ " := " ++ formatStmtExpr value | .Assign targets value => "(" ++ Format.joinSep (targets.map formatStmtExpr) ", " ++ ")" ++ " := " ++ formatStmtExpr value | .FieldSelect target field => - formatStmtExpr target ++ "#" ++ Format.text field + formatStmtExpr target ++ "#" ++ format field | .PureFieldUpdate target field value => - formatStmtExpr target ++ " with { " ++ Format.text field ++ " := " ++ formatStmtExpr value ++ " }" + formatStmtExpr target ++ " with { " ++ format field ++ " := " ++ formatStmtExpr value ++ " }" | .StaticCall name args => - Format.text name ++ "(" ++ Format.joinSep (args.map formatStmtExpr) ", " ++ ")" + format name ++ "(" ++ Format.joinSep (args.map formatStmtExpr) ", " ++ ")" | .PrimitiveOp op [a] => formatOperation op ++ formatStmtExpr a | .PrimitiveOp op [a, b] => formatStmtExpr a ++ " " ++ formatOperation op ++ " " ++ formatStmtExpr b | .PrimitiveOp op args => formatOperation op ++ "(" ++ Format.joinSep (args.map formatStmtExpr) ", " ++ ")" - | .New name => "new " ++ Format.text name + | .New name => "new " ++ format name | .This => "this" | .ReferenceEquals lhs rhs => formatStmtExpr lhs ++ " === " ++ formatStmtExpr rhs @@ -122,12 +122,12 @@ def formatStmtExprVal (s : StmtExpr) : Format := | .IsType target ty => formatStmtExpr target ++ " is " ++ formatHighType ty | .InstanceCall target name args => - formatStmtExpr target ++ "." ++ Format.text name ++ "(" ++ + formatStmtExpr target ++ "." ++ format name ++ "(" ++ Format.joinSep (args.map formatStmtExpr) ", " ++ ")" - | .Forall name ty body => - "forall " ++ Format.text name ++ ": " ++ formatHighType ty ++ " => " ++ formatStmtExpr body - | .Exists name ty body => - "exists " ++ Format.text name ++ ": " ++ formatHighType ty ++ " => " ++ formatStmtExpr body + | .Forall param body => + "forall " ++ format param.name ++ ": " ++ formatHighType param.type ++ " => " ++ formatStmtExpr body + | .Exists param body => + "exists " ++ format param.name ++ ": " ++ formatHighType param.type ++ " => " ++ formatStmtExpr body | .Assigned name => "assigned(" ++ formatStmtExpr name ++ ")" | .Old value => "old(" ++ formatStmtExpr value ++ ")" | .Fresh value => "fresh(" ++ formatStmtExpr value ++ ")" @@ -144,7 +144,7 @@ def formatStmtExprVal (s : StmtExpr) : Format := end def formatParameter (p : Parameter) : Format := - Format.text p.name ++ ": " ++ formatHighType p.type + format p.name ++ ": " ++ formatHighType p.type def formatBody : Body → Format | .Transparent body => formatStmtExpr body @@ -156,6 +156,7 @@ def formatBody : Body → Format | none => Format.nil | some e => " := " ++ formatStmtExpr e | .Abstract posts => "abstract" ++ Format.join (posts.map (fun p => " ensures " ++ formatStmtExpr p)) + | .External => "external" def formatDeterminism : Determinism → Format | .deterministic none => "deterministic" @@ -166,7 +167,7 @@ instance : Std.ToFormat Determinism where format := formatDeterminism def formatProcedure (proc : Procedure) : Format := - (if proc.isFunctional then "function " else "procedure ") ++ Format.text proc.name ++ + (if proc.isFunctional then "function " else "procedure ") ++ format proc.name ++ "(" ++ Format.joinSep (proc.inputs.map formatParameter) ", " ++ ") returns " ++ Format.line ++ "(" ++ Format.joinSep (proc.outputs.map formatParameter) ", " ++ ")" ++ Format.line ++ Format.join (proc.preconditions.map (fun p => "requires " ++ formatStmtExpr p ++ Format.line)) ++ @@ -175,28 +176,28 @@ def formatProcedure (proc : Procedure) : Format := def formatField (f : Field) : Format := (if f.isMutable then "var " else "val ") ++ - Format.text f.name ++ ": " ++ formatHighType f.type + format f.name ++ ": " ++ formatHighType f.type def formatCompositeType (ct : CompositeType) : Format := - "composite " ++ Format.text ct.name ++ + "composite " ++ format ct.name ++ (if ct.extending.isEmpty then Format.nil else " extends " ++ - Format.joinSep (ct.extending.map Format.text) ", ") ++ + Format.joinSep (ct.extending.map format) ", ") ++ " { " ++ Format.joinSep (ct.fields.map formatField) "; " ++ " }" def formatConstrainedType (ct : ConstrainedType) : Format := - "constrained " ++ Format.text ct.name ++ - " = " ++ Format.text ct.valueName ++ ": " ++ formatHighType ct.base ++ + "constrained " ++ format ct.name ++ + " = " ++ format ct.valueName ++ ": " ++ formatHighType ct.base ++ " | " ++ formatStmtExpr ct.constraint def formatDatatypeConstructor (c : DatatypeConstructor) : Format := - Format.text c.name ++ + format c.name ++ if c.args.isEmpty then Format.nil - else "(" ++ Format.joinSep (c.args.map fun (n, ty) => Format.text n ++ ": " ++ formatHighType ty) ", " ++ ")" + else "(" ++ Format.joinSep (c.args.map fun p => formatParameter p) ", " ++ ")" def formatDatatypeDefinition (dt : DatatypeDefinition) : Format := - "datatype " ++ Format.text dt.name ++ + "datatype " ++ format dt.name ++ (if dt.typeArgs.isEmpty then Format.nil - else "(" ++ Format.joinSep (dt.typeArgs.map Format.text) ", " ++ ")") ++ + else "(" ++ Format.joinSep (dt.typeArgs.map format) ", " ++ ")") ++ " { " ++ Format.joinSep (dt.constructors.map formatDatatypeConstructor) ", " ++ " }" def formatTypeDefinition : TypeDefinition → Format @@ -204,8 +205,18 @@ def formatTypeDefinition : TypeDefinition → Format | .Constrained ty => formatConstrainedType ty | .Datatype ty => formatDatatypeDefinition ty +def formatConstant (c : Constant) : Format := + "const " ++ format c.name ++ ": " ++ formatHighType c.type ++ + match c.initializer with + | none => "" + | some e => " := " ++ formatStmtExpr e + def formatProgram (prog : Program) : Format := - Format.joinSep (prog.staticProcedures.map formatProcedure) "\n\n" + let types := prog.types.map formatTypeDefinition + let constants := prog.constants.map formatConstant + let fields := prog.staticFields.map formatField + let procs := prog.staticProcedures.map formatProcedure + Format.joinSep (types ++ constants ++ fields ++ procs) "\n\n" instance : Std.ToFormat Operation where format := formatOperation @@ -246,10 +257,28 @@ instance : Std.ToFormat DatatypeConstructor where instance : Std.ToFormat DatatypeDefinition where format := formatDatatypeDefinition +instance : Std.ToFormat Constant where + format := formatConstant + instance : Std.ToFormat TypeDefinition where format := formatTypeDefinition instance : Std.ToFormat Program where format := formatProgram +instance : Repr StmtExpr where + reprPrec r _ := s!"{Std.format r}" + +instance : Repr HighType where + reprPrec r _ := s!"{Std.format r}" + +deriving instance Repr for Strata.Laurel.Parameter +deriving instance Repr for Strata.Laurel.Procedure +deriving instance Repr for Strata.Laurel.Field +deriving instance Repr for Strata.Laurel.CompositeType +deriving instance Repr for Strata.Laurel.ConstrainedType +deriving instance Repr for Strata.Laurel.DatatypeConstructor +deriving instance Repr for Strata.Laurel.DatatypeDefinition +deriving instance Repr for Strata.Laurel.Constant + end Laurel diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index 3aefa2cf9..ffd7cb444 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -15,7 +15,7 @@ import Strata.Languages.Laurel.HeapParameterization import Strata.Languages.Laurel.TypeHierarchy import Strata.Languages.Laurel.LaurelTypes import Strata.Languages.Laurel.ModifiesClauses -import Strata.Languages.Laurel.CorePrelude +import Strata.Languages.Laurel.CoreDefinitionsForLaurel import Strata.DL.Imperative.Stmt import Strata.DL.Imperative.MetaData import Strata.DL.Lambda.LExpr @@ -34,7 +34,7 @@ open Lambda (LMonoTy LTy LExpr) /- Translate Laurel HighType to Core Type -/ -def translateType (ty : HighTypeMd) : LMonoTy := +def translateType (model : SemanticModel) (ty : HighTypeMd) : LMonoTy := match _h : ty.val with | .TInt => LMonoTy.int | .TBool => LMonoTy.bool @@ -42,18 +42,24 @@ def translateType (ty : HighTypeMd) : LMonoTy := | .TVoid => LMonoTy.bool -- Using bool as placeholder for void | .THeap => .tcons "Heap" [] | .TTypedField _ => .tcons "Field" [] - | .TSet elementType => Core.mapTy (translateType elementType) LMonoTy.bool - | .TMap keyType valueType => Core.mapTy (translateType keyType) (translateType valueType) - | .UserDefined _ => .tcons "Composite" [] + | .TSet elementType => Core.mapTy (translateType model elementType) LMonoTy.bool + | .TMap keyType valueType => Core.mapTy (translateType model keyType) (translateType model valueType) + | .UserDefined name => + -- Composite types map to "Composite"; datatypes map to their own name + match name.uniqueId.bind model.refToDef.get? with + | some (.compositeType _) => .tcons "Composite" [] + | some (.datatypeDefinition dt) => .tcons dt.name.text [] + | _ => .tcons "Composite" [] -- fallback for unresolved refs | .TCore s => .tcons s [] - | _ => panic s!"unsupported type {ToFormat.format ty}" + | .TFloat64 => LMonoTy.real -- Incorrect? + | _ => panic s!"translateType: unsupported type {ToFormat.format ty}" termination_by ty.val decreasing_by all_goals (first | (cases elementType; term_by_mem) | (cases keyType; term_by_mem) | (cases valueType; term_by_mem)) -def lookupType (env : TypeEnv) (name : Identifier) : LMonoTy := - match env.find? (fun (n, _) => n == name) with - | some (_, ty) => translateType ty - | none => panic s!"could not find variable {name} in environment '{Std.format env}'" +def lookupType (model : SemanticModel) (name : Identifier) : LMonoTy := + match (model.get name).getType with + | .some ty => translateType model ty + | none => panic s!"no type for {name}" def isFieldName (fieldNames : List Identifier) (name : Identifier) : Bool := fieldNames.contains name @@ -61,27 +67,14 @@ def isFieldName (fieldNames : List Identifier) (name : Identifier) : Bool := /-- Set of names that are translated to Core functions (not procedures) -/ abbrev FunctionNames := List Identifier -def isCoreFunction (funcNames : FunctionNames) (name : Identifier) : Bool := - -- readField, updateField, and Box constructors/destructors are always functions - name == "readField" || name == "updateField" || name == "increment" || - name == "MkHeap" || name == "Heap..data!" || name == "Heap..nextReference!" || - name == "MkComposite" || name == "Composite..ref!" || name == "Composite..typeTag!" || - name == "BoxInt" || name == "BoxBool" || name == "BoxFloat64" || name == "BoxComposite" || - name == "Box..intVal!" || name == "Box..boolVal!" || name == "Box..float64Val!" || name == "Box..compositeVal!" || - -- Map operations (select/update are built-in map functions, const creates a constant map) - name == "select" || name == "update" || name == "const" || - -- Type hierarchy functions generated by typeHierarchyTransform - name == "ancestorsPerType" || name.startsWith "ancestorsFor" || name.endsWith "_TypeTag" || - funcNames.contains name - /-- State threaded through expression and statement translation -/ structure TranslateState where /-- Diagnostics accumulated during translation -/ diagnostics : List DiagnosticModel := [] + /-- Next fresh ID to allocate. -/ + nextId : Nat := 1 /-- Constants known to the program (field constants, etc.) -/ - fieldNames : List String := [] - /-- Names of procedures that are translated as Core functions -/ - funcNames : FunctionNames := [] + model : SemanticModel /-- The translation monad: state over Id -/ abbrev TranslateM := StateT TranslateState Id @@ -94,6 +87,13 @@ def emitDiagnostic (d : DiagnosticModel) : TranslateM Unit := def runTranslateM (s : TranslateState) (m : TranslateM α) : α × TranslateState := m s +/-- Allocate a fresh unique ID. -/ +private def freshId : TranslateM Nat := do + let s ← get + let id := s.nextId + set { s with nextId := id + 1 } + return id + /-- Translate Laurel StmtExpr to Core Expression using the `TranslateM` monad. Diagnostics for disallowed constructs are emitted into the monad state. @@ -107,14 +107,13 @@ because `liftImperativeExpressions` should have already removed them. When an Identifier matches a bound name at index `i`, it becomes `bvar i` (de Bruijn index) instead of `fvar`. -/ -def translateExpr (env : TypeEnv) (expr : StmtExprMd) +def translateExpr (expr : StmtExprMd) (boundVars : List Identifier := []) (isPureContext : Bool := false) : TranslateM Core.Expression.Expr := do let s ← get - let fieldNames := s.fieldNames - let funcNames := s.funcNames + let model := s.model -- Dummy expression used as placeholder when an error is emitted in pure context - let dummy := .fvar () (⟨s!"DUMMY_VAR_{env.length}", ()⟩) none + let dummy := .fvar () (⟨s!"DUMMY_VAR_{← freshId}", ()⟩) none -- Emit an error in pure context; panic in impure context (lifting invariant violated) let disallowed (e : StmtExprMd) (msg : String) : TranslateM Core.Expression.Expr := do if isPureContext then @@ -133,23 +132,23 @@ def translateExpr (env : TypeEnv) (expr : StmtExprMd) -- Bound variable: use de Bruijn index return .bvar () idx | none => - -- Check if this is a field name (datatype constructor) or local variable - if isFieldName fieldNames name then - return .op () ⟨name, ()⟩ none - else - return .fvar () ⟨name, ()⟩ (some (lookupType env name)) + match model.get name with + | .field _ f => + return .op () ⟨f.name.text, ()⟩ none + | astNode => + return .fvar () ⟨name.text, ()⟩ (some (translateType model $ astNode.getType.getD (panic! "LaurelToCore.translateExpr"))) | .PrimitiveOp op [e] => match op with | .Not => - let re ← translateExpr env e boundVars isPureContext + let re ← translateExpr e boundVars isPureContext return .app () boolNotOp re | .Neg => - let re ← translateExpr env e boundVars isPureContext + let re ← translateExpr e boundVars isPureContext return .app () intNegOp re | _ => panic! s!"translateExpr: Invalid unary op: {repr op}" | .PrimitiveOp op [e1, e2] => - let re1 ← translateExpr env e1 boundVars isPureContext - let re2 ← translateExpr env e2 boundVars isPureContext + let re1 ← translateExpr e1 boundVars isPureContext + let re2 ← translateExpr e2 boundVars isPureContext let binOp (bop : Core.Expression.Expr) : Core.Expression.Expr := LExpr.mkApp () bop [re1, re2] match op with @@ -174,38 +173,38 @@ def translateExpr (env : TypeEnv) (expr : StmtExprMd) | .PrimitiveOp op args => panic! s!"translateExpr: PrimitiveOp {repr op} with {args.length} args" | .IfThenElse cond thenBranch elseBranch => - let bcond ← translateExpr env cond boundVars isPureContext - let bthen ← translateExpr env thenBranch boundVars isPureContext + let bcond ← translateExpr cond boundVars isPureContext + let bthen ← translateExpr thenBranch boundVars isPureContext let belse ← match elseBranch with | none => panic "if-then without else expression not yet implemented" | some e => have : sizeOf e < sizeOf expr := by have := WithMetadata.sizeOf_val_lt expr cases expr; simp_all; omega - translateExpr env e boundVars isPureContext + translateExpr e boundVars isPureContext return .ite () bcond bthen belse - | .StaticCall name args => + | .StaticCall callee args => -- In a pure context, only Core functions (not procedures) are allowed - if isPureContext && !isCoreFunction funcNames name then + if isPureContext && !model.isFunction callee then disallowed expr "calls to procedures are not supported in functions or contracts" else - let fnOp : Core.Expression.Expr := .op () ⟨name, ()⟩ none + let fnOp : Core.Expression.Expr := .op () ⟨callee.text, ()⟩ none args.attach.foldlM (fun acc ⟨arg, _⟩ => do - let re ← translateExpr env arg boundVars isPureContext + let re ← translateExpr arg boundVars isPureContext return .app () acc re) fnOp - | .Block [single] _ => translateExpr env single boundVars isPureContext - | .Forall name ty body => - let coreTy := translateType ty - let coreBody ← translateExpr env body (name :: boundVars) isPureContext - return LExpr.all () name (some coreTy) coreBody - | .Exists name ty body => - let coreTy := translateType ty - let coreBody ← translateExpr env body (name :: boundVars) isPureContext - return LExpr.exist () name (some coreTy) coreBody + | .Block [single] _ => translateExpr single boundVars isPureContext + | .Forall ⟨ name, ty ⟩ body => + let coreTy := translateType model ty + let coreBody ← translateExpr body (name :: boundVars) isPureContext + return LExpr.all () name.text (some coreTy) coreBody + | .Exists ⟨ name, ty ⟩ body => + let coreTy := translateType model ty + let coreBody ← translateExpr body (name :: boundVars) isPureContext + return LExpr.exist () name.text (some coreTy) coreBody | .Hole => return dummy | .ReferenceEquals e1 e2 => - let re1 ← translateExpr env e1 boundVars isPureContext - let re2 ← translateExpr env e2 boundVars isPureContext + let re1 ← translateExpr e1 boundVars isPureContext + let re2 ← translateExpr e2 boundVars isPureContext return .eq () re1 re2 | .Assign _ _ => disallowed expr "destructive assignments are not supported in functions or contracts" @@ -215,10 +214,11 @@ def translateExpr (env : TypeEnv) (expr : StmtExprMd) | .IsType _ _ => panic "IsType should have been lowered" | .New _ => panic! s!"New should have been eliminated by typeHierarchyTransform" - | .FieldSelect target fieldName => + | .FieldSelect target fieldId => -- Field selects should have been eliminated by heap parameterization -- If we see one here, it's an error in the pipeline - panic! s!"FieldSelect should have been eliminated by heap parameterization: {Std.ToFormat.format target}#{fieldName}" + panic! s!"FieldSelect should have been eliminated by heap parameterization: {Std.ToFormat.format target}#{fieldId.text}" + | .Block _ _ => panic "block expression not yet implemented (should be lowered in a separate pass)" | .LocalVariable _ _ _ => panic "local variable expression not yet implemented (should be lowered in a separate pass)" | .Return _ => disallowed expr "return expression not yet implemented (should be lowered in a separate pass)" @@ -244,7 +244,7 @@ def getNameFromMd (md : Imperative.MetaData Core.Expression): String := let fileRange := (Imperative.getFileRange md).getD (panic "getNameFromMd bug") s!"({fileRange.range.start})" -def defaultExprForType (ty : HighTypeMd) : Core.Expression.Expr := +def defaultExprForType (model : SemanticModel) (ty : HighTypeMd) : Core.Expression.Expr := match ty.val with | .TInt => .const () (.intConst 0) | .TBool => .const () (.boolConst false) @@ -253,120 +253,115 @@ def defaultExprForType (ty : HighTypeMd) : Core.Expression.Expr := -- For types without a natural default (arrays, composites, etc.), -- use a fresh free variable. This is only used when the value is -- immediately overwritten by a procedure call. - let coreTy := translateType ty + let coreTy := translateType model ty .fvar () (⟨"$default", ()⟩) (some coreTy) /-- Translate Laurel StmtExpr to Core Statements using the `TranslateM` monad. Diagnostics are emitted into the monad state. -/ -def translateStmt (env : TypeEnv) (outputParams : List Parameter) (stmt : StmtExprMd) - : TranslateM (TypeEnv × List Core.Statement) := do +def translateStmt (outputParams : List Parameter) (stmt : StmtExprMd) + : TranslateM (List Core.Statement) := do let s ← get - let functionNames := s.funcNames + let model := s.model let md := stmt.md match _h : stmt.val with | @StmtExpr.Assert cond => -- Assert/assume bodies must be pure expressions (no assignments, loops, or procedure calls) - let coreExpr ← translateExpr env cond [] (isPureContext := true) - return (env, [Core.Statement.assert ("assert" ++ getNameFromMd md) coreExpr md]) + let coreExpr ← translateExpr cond [] (isPureContext := true) + return [Core.Statement.assert ("assert" ++ getNameFromMd md) coreExpr md] | @StmtExpr.Assume cond => - let coreExpr ← translateExpr env cond [] (isPureContext := true) - return (env, [Core.Statement.assume ("assume" ++ getNameFromMd md) coreExpr md]) - | .Block stmts _ => - let (env', stmtsList) ← stmts.attach.foldlM (fun (e, acc) ⟨s, _hs⟩ => do - let (e', ss) ← translateStmt e outputParams s - return (e', acc ++ ss)) (env, []) - return (env', stmtsList) - | .LocalVariable name ty initializer => - let env' := (name, ty) :: env - let boogieMonoType := translateType ty - let boogieType := LTy.forAll [] boogieMonoType - let ident := ⟨name, ()⟩ + let coreExpr ← translateExpr cond [] (isPureContext := true) + return [Core.Statement.assume ("assume" ++ getNameFromMd md) coreExpr md] + | .Block stmts _ => stmts.flatMapM (fun s => translateStmt outputParams s) + | .LocalVariable id ty initializer => + let coreMonoType := translateType model ty + let coreType := LTy.forAll [] coreMonoType + let ident := ⟨id.text, ()⟩ match initializer with | some (⟨ .StaticCall callee args, callMd⟩) => -- Check if this is a function or a procedure call - if isCoreFunction functionNames callee then + if model.isFunction callee then -- Translate as expression (function application) - let boogieExpr ← translateExpr env (⟨ .StaticCall callee args, callMd ⟩) - return (env', [Core.Statement.init ident boogieType (some boogieExpr) md]) + let coreExpr ← translateExpr (⟨ .StaticCall callee args, callMd ⟩) + return [Core.Statement.init ident coreType (some coreExpr) md] else -- Translate as: var name; call name := callee(args) - let coreArgs ← args.mapM (fun a => translateExpr env a) - let defaultExpr := defaultExprForType ty - let initStmt := Core.Statement.init ident boogieType (some defaultExpr) md - let callStmt := Core.Statement.call [ident] callee coreArgs md - return (env', [initStmt, callStmt]) + let coreArgs ← args.mapM (fun a => translateExpr a) + let defaultExpr := defaultExprForType model ty + let initStmt := Core.Statement.init ident coreType (some defaultExpr) md + let callStmt := Core.Statement.call [ident] callee.text coreArgs md + return [initStmt, callStmt] | some initExpr => - let coreExpr ← translateExpr env initExpr - return (env', [Core.Statement.init ident boogieType (some coreExpr) md]) + let coreExpr ← translateExpr initExpr + return [Core.Statement.init ident coreType (some coreExpr) md] | none => - let defaultExpr := defaultExprForType ty - return (env', [Core.Statement.init ident boogieType (some defaultExpr) md]) + let defaultExpr := defaultExprForType model ty + return [Core.Statement.init ident coreType (some defaultExpr) md] | .Assign targets value => match targets with - | [⟨ .Identifier name, _ ⟩] => - let ident := ⟨name, ()⟩ + | [⟨ .Identifier targetId, _ ⟩] => + let ident := ⟨targetId.text, ()⟩ -- Check if RHS is a procedure call (not a function) match value.val with | .StaticCall callee args => - if isCoreFunction functionNames callee then + if model.isFunction callee then -- Functions are translated as expressions - let boogieExpr ← translateExpr env value - return (env, [Core.Statement.set ident boogieExpr md]) + let coreExpr ← translateExpr value + return [Core.Statement.set ident coreExpr md] else -- Procedure calls need to be translated as call statements - let coreArgs ← args.mapM (fun a => translateExpr env a) - return (env, [Core.Statement.call [ident] callee coreArgs md]) + let coreArgs ← args.mapM (fun a => translateExpr a) + return [Core.Statement.call [ident] callee.text coreArgs md] | _ => - let boogieExpr ← translateExpr env value - return (env, [Core.Statement.set ident boogieExpr md]) + let coreExpr ← translateExpr value + return [Core.Statement.set ident coreExpr md] | _ => -- Parallel assignment: (var1, var2, ...) := expr -- Example use is heap-modifying procedure calls: (result, heap) := f(heap, args) match value.val with | .StaticCall callee args => - let coreArgs ← args.mapM (fun a => translateExpr env a) + let coreArgs ← args.mapM (fun a => translateExpr a) let lhsIdents := targets.filterMap fun t => match t.val with - | .Identifier name => some (⟨name, ()⟩) + | .Identifier name => some (⟨name.text, ()⟩) | _ => none - return (env, [Core.Statement.call lhsIdents callee coreArgs value.md]) + return [Core.Statement.call lhsIdents callee.text coreArgs value.md] | _ => panic "Assignments with multiple target but without a RHS call should not be constructed" | .IfThenElse cond thenBranch elseBranch => - let bcond ← translateExpr env cond - let (_, bthen) ← translateStmt env outputParams thenBranch + let bcond ← translateExpr cond + let bthen ← translateStmt outputParams thenBranch let belse ← match elseBranch with - | some e => (·.2) <$> translateStmt env outputParams e + | some e => translateStmt outputParams e | none => pure [] - return (env, [Imperative.Stmt.ite bcond bthen belse .empty]) - | .StaticCall name args => + return [Imperative.Stmt.ite bcond bthen belse .empty] + | .StaticCall callee args => -- Check if this is a function or procedure - if isCoreFunction functionNames name then + if model.isFunction callee then -- Functions as statements have no effect (shouldn't happen in well-formed programs) - return (env, []) + return [] else - let coreArgs ← args.mapM (fun a => translateExpr env a) - return (env, [Core.Statement.call [] name coreArgs md]) + let coreArgs ← args.mapM (fun a => translateExpr a) + return [Core.Statement.call [] callee.text coreArgs md] | .Return valueOpt => match valueOpt, outputParams.head? with | some value, some outParam => - let ident := ⟨outParam.name, ()⟩ - let coreExpr ← translateExpr env value + let ident := ⟨outParam.name.text, ()⟩ + let coreExpr ← translateExpr value let assignStmt := Core.Statement.set ident coreExpr md - return (env, [assignStmt, .exit (some "$body") md]) + return [assignStmt, .exit (some "$body") md] | none, _ => - return (env, [.exit (some "$body") md]) + return [.exit (some "$body") md] | some _, none => panic! "Return statement with value but procedure has no output parameters" | .While cond invariants decreasesExpr body => - let condExpr ← translateExpr env cond - let invExprs ← invariants.mapM (translateExpr env) - let decreasingExprCore ← decreasesExpr.mapM (translateExpr env) - let (_, bodyStmts) ← translateStmt env outputParams body - return (env, [Imperative.Stmt.loop condExpr decreasingExprCore invExprs bodyStmts md]) - | _ => return (env, []) + let condExpr ← translateExpr cond + let invExprs ← invariants.mapM (translateExpr) + let decreasingExprCore ← decreasesExpr.mapM (translateExpr) + let bodyStmts ← translateStmt outputParams body + return [Imperative.Stmt.loop condExpr decreasingExprCore invExprs bodyStmts md] + | _ => return [] termination_by sizeOf stmt decreasing_by all_goals @@ -377,20 +372,20 @@ def translateStmt (env : TypeEnv) (outputParams : List Parameter) (stmt : StmtEx Translate a list of checks (preconditions or postconditions) to Core checks. Each check gets a label like `"requires"` or `"requires_0"`, `"requires_1"`, etc. -/ -private def translateChecks (env : TypeEnv) (checks : List StmtExprMd) (labelBase : String) +private def translateChecks (checks : List StmtExprMd) (labelBase : String) : TranslateM (ListMap Core.CoreLabel Core.Procedure.Check) := checks.mapIdxM (fun i check => do let label := if checks.length == 1 then labelBase else s!"{labelBase}_{i}" - let checkExpr ← translateExpr env check [] (isPureContext := true) + let checkExpr ← translateExpr check [] (isPureContext := true) let c : Core.Procedure.Check := { expr := checkExpr, md := check.md } return (label, c)) /-- Translate Laurel Parameter to Core Signature entry -/ -def translateParameterToCore (param : Parameter) : (Core.CoreIdent × LMonoTy) := - let ident := ⟨param.name, ()⟩ - let ty := translateType param.type +def translateParameterToCore (model : SemanticModel) (param : Parameter) : (Core.CoreIdent × LMonoTy) := + let ident := ⟨param.name.text, ()⟩ + let ty := translateType model param.type (ident, ty) /-- @@ -399,31 +394,29 @@ Diagnostics from disallowed constructs in preconditions, postconditions, and bod are emitted into the monad state. -/ def translateProcedure (proc : Procedure) : TranslateM Core.Procedure := do - let inputPairs := proc.inputs.map translateParameterToCore + let inputPairs := proc.inputs.map (translateParameterToCore (← get).model) let inputs := inputPairs - let outputs := proc.outputs.map translateParameterToCore + let outputs := proc.outputs.map (translateParameterToCore (← get).model) let header : Core.Procedure.Header := { - name := proc.name + name := proc.name.text typeArgs := [] inputs := inputs outputs := outputs } - let initEnv : TypeEnv := proc.inputs.map (fun p => (p.name, p.type)) ++ - proc.outputs.map (fun p => (p.name, p.type)) -- Translate preconditions - let preconditions ← translateChecks initEnv proc.preconditions "requires" + let preconditions ← translateChecks proc.preconditions "requires" -- Translate postconditions for Opaque bodies let postconditions : ListMap Core.CoreLabel Core.Procedure.Check ← match proc.body with | .Opaque postconds _ _ => - translateChecks initEnv postconds "postcondition" + translateChecks postconds "postcondition" | _ => pure [] let modifies : List Core.Expression.Ident := [] let bodyStmts : List Core.Statement ← match proc.body with - | .Transparent bodyExpr => (·.2) <$> translateStmt initEnv proc.outputs bodyExpr - | .Opaque _postconds (some impl) _ => (·.2) <$> translateStmt initEnv proc.outputs impl + | .Transparent bodyExpr => translateStmt proc.outputs bodyExpr + | .Opaque _postconds (some impl) _ => translateStmt proc.outputs impl | _ => pure [Core.Statement.assume "no_body" (.const () (.boolConst false)) .empty] -- Wrap body in a labeled block so early returns (exit) work correctly. let body : List Core.Statement := [.block "$body" bodyStmts .empty] @@ -489,6 +482,7 @@ private def canBeCoreFunctionBody (proc : Procedure) : Bool := | .Opaque _ bodyExprOption _ => (bodyExprOption.map isPureExpr).getD true && proc.outputs.length == 1 + | .External => false | _ => false /-- @@ -496,24 +490,24 @@ Translate a Laurel Procedure to a Core Function (when applicable) using `Transla Diagnostics for disallowed constructs in the function body are emitted into the monad state. -/ def translateProcedureToFunction (proc : Procedure) : TranslateM Core.Decl := do - let inputs := proc.inputs.map translateParameterToCore + let model := (← get).model + let inputs := proc.inputs.map (translateParameterToCore model) let outputTy := match proc.outputs.head? with - | some p => translateType p.type + | some p => translateType model p.type | none => LMonoTy.int - let initEnv : TypeEnv := proc.inputs.map (fun p => (p.name, p.type)) -- Translate precondition to FuncPrecondition (skip trivial `true`) let preconditions ← proc.preconditions.mapM (fun precondition => do - let checkExpr ← translateExpr initEnv precondition [] true + let checkExpr ← translateExpr precondition [] true return { expr := checkExpr, md := () }) let body ← match proc.body with - | .Transparent bodyExpr => some <$> translateExpr initEnv bodyExpr [] (isPureContext := true) + | .Transparent bodyExpr => some <$> translateExpr bodyExpr [] (isPureContext := true) | .Opaque _ (some bodyExpr) _ => emitDiagnostic (proc.md.toDiagnostic "functions with postconditions are not yet supported") - some <$> translateExpr initEnv bodyExpr [] (isPureContext := true) + some <$> translateExpr bodyExpr [] (isPureContext := true) | _ => pure none return .func { - name := ⟨proc.name, ()⟩ + name := ⟨proc.name.text, ()⟩ typeArgs := [] inputs := inputs output := outputTy @@ -525,18 +519,18 @@ def translateProcedureToFunction (proc : Procedure) : TranslateM Core.Decl := do Translate a Laurel DatatypeDefinition to a Core type declaration. Zero constructors produces an opaque (abstract) type; otherwise a Core datatype. -/ -def translateDatatypeDefinition (dt : DatatypeDefinition) : Core.Decl := +def translateDatatypeDefinition (model : SemanticModel) (dt : DatatypeDefinition) : Core.Decl := match h : dt.constructors with | [] => -- Zero constructors: opaque type - Core.Decl.type (.con { name := dt.name, numargs := dt.typeArgs.length }) + Core.Decl.type (.con { name := dt.name.text, numargs := dt.typeArgs.length }) | first :: rest => let constrs : List (Lambda.LConstr Unit) := (first :: rest).map fun c => - { name := ⟨c.name, ()⟩ - args := c.args.map fun (n, ty) => (⟨n, ()⟩, translateType ty) } + { name := ⟨c.name.text, ()⟩ + args := c.args.map fun ⟨ n, ty ⟩ => (⟨n.text, ()⟩, translateType model ty) } let ldt : Lambda.LDatatype Unit := { - name := dt.name - typeArgs := dt.typeArgs + name := dt.name.text + typeArgs := dt.typeArgs.map (fun id => id.text) constrs := constrs constrs_ne := by simp [constrs] } @@ -559,24 +553,41 @@ def tryTranslatePureToFunction (proc : Procedure) (initState : TranslateState) Translate Laurel Program to Core Program -/ def translate (program : Program) : Except (Array DiagnosticModel) (Core.Program × Array DiagnosticModel) := do - let program := heapParameterization program - let program := typeHierarchyTransform program - let (program, modifiesDiags) := modifiesClausesTransform program - dbg_trace "=== Program after heapParameterization + modifiesClausesTransform ===" - dbg_trace (toString (Std.Format.pretty (Std.ToFormat.format program))) - dbg_trace "=================================" - let program := liftImperativeExpressions program - -- Collect field names from the Field datatype (generated by heapParameterization) - let fieldNames : List Identifier := program.types.foldl (fun acc td => - match td with - | .Datatype dt => if dt.name == "Field" then acc ++ dt.constructors.map (·.name) else acc - | _ => acc) [] + let program := { program with + staticProcedures := coreDefinitionsForLaurel.staticProcedures ++ program.staticProcedures + } + + let result := resolve program + let (program, model) := (result.program, result.model) + let mut _resolutionDiags := result.errors + let diamondErrors := validateDiamondFieldAccesses model program + + let program := heapParameterization model program + let result := resolve program (some model) + let (program, model) := (result.program, result.model) + _resolutionDiags := _resolutionDiags ++ result.errors + + let program := typeHierarchyTransform model program + let result := resolve program (some model) + let (program, model) := (result.program, result.model) + _resolutionDiags := _resolutionDiags ++ result.errors + let (program, modifiesDiags) := modifiesClausesTransform model program + let result := resolve program (some model) + let (program, model) := (result.program, result.model) + _resolutionDiags := _resolutionDiags ++ result.errors + -- dbg_trace "=== Program after heapParameterization + modifiesClausesTransform ===" + -- dbg_trace (toString (Std.Format.pretty (Std.ToFormat.format program))) + -- dbg_trace "=================================" + let program := liftExpressionAssignments model program + let result := resolve program (some model) + let (program, model) := (result.program, result.model) + _resolutionDiags := _resolutionDiags ++ result.errors -- Procedures marked isFunctional are translated to Core functions; all others become Core procedures. - let (markedPure, procProcs) := program.staticProcedures.partition (·.isFunctional) - -- Build the shared initial state with constants and function names - let funcNames : FunctionNames := markedPure.map (·.name) - let initState : TranslateState := { fieldNames := fieldNames, funcNames } + -- External procedures are completely ignored (not translated to Core). + let nonExternal := program.staticProcedures.filter (fun p => !p.body.isExternal) + let (markedPure, procProcs) := nonExternal.partition (·.isFunctional) + let initState : TranslateState := { model := model } -- Try to translate each isFunctional procedure to a Core function, collecting errors for failures let (pureErrors, pureFuncDecls) := markedPure.foldl (fun (errs, decls) p => match tryTranslatePureToFunction p initState with @@ -589,31 +600,38 @@ def translate (program : Program) : Except (Array DiagnosticModel) (Core.Program -- Translate Laurel constants to Core function declarations (0-ary functions) let (constantDecls, constantsState) := runTranslateM initState $ program.constants.mapM fun c => do - let coreTy := translateType c.type - let body ← c.initializer.mapM (translateExpr [] ·) + let coreTy := translateType model c.type + let body ← c.initializer.mapM (translateExpr ·) return Core.Decl.func { - name := ⟨c.name, ()⟩ + name := ⟨c.name.text, ()⟩ typeArgs := [] inputs := [] output := coreTy body := body } - -- Collect ALL errors from both functions and procedures before deciding whether to fail - let allErrors := pureErrors ++ procDiags ++ constantsState.diagnostics + -- Collect ALL errors from both functions, procedures, and resolution before deciding whether to fail + let allErrors := + -- Not including resolution diagnostics yet because the Python through Laurel pipeline + -- does not resolve yet. + -- resolutionDiags.toList ++ + pureErrors ++ procDiags ++ constantsState.diagnostics if !allErrors.isEmpty then .error allErrors.toArray let procDecls := procedures.map (fun p => Core.Decl.proc p .empty) - -- Filter out the Field and TypeTag opaque types. These are only in the prelude to satisfy the DDM type checker. - let preludeDecls := corePrelude.decls.filter fun d => - d.name.name != "Field" && d.name.name != "TypeTag" - -- Translate Laurel datatype definitions to Core datatype declarations let laurelDatatypeDecls := program.types.filterMap fun td => match td with - | .Datatype dt => some (translateDatatypeDefinition dt) + | .Datatype dt => some (translateDatatypeDefinition model dt) | _ => none - pure ({ decls := laurelDatatypeDecls ++ preludeDecls ++ constantDecls ++ pureFuncDecls.toList ++ procDecls }, modifiesDiags) + let program := { + decls := laurelDatatypeDecls ++ constantDecls ++ pureFuncDecls.toList ++ procDecls + } + + -- dbg_trace "=== Generated Strata Core Program ===" + -- dbg_trace (toString (Std.Format.pretty (Strata.Core.formatProgram program) 100)) + -- dbg_trace "=================================" + pure (program, diamondErrors ++ modifiesDiags) /-- Verify a Laurel program using an SMT solver @@ -628,9 +646,6 @@ def verifyToVcResults (program : Program) -- Enable removeIrrelevantAxioms to avoid polluting simple assertions with heap axioms let options := { options with removeIrrelevantAxioms := true } -- Debug: Print the generated Strata Core program - dbg_trace "=== Generated Strata Core Program ===" - dbg_trace (toString (Std.Format.pretty (Strata.Core.formatProgram strataCoreProgram) 100)) - dbg_trace "=================================" let runner tempDir := EIO.toIO (fun f => IO.Error.userError (toString f)) (Core.verify strataCoreProgram tempDir .none options) @@ -645,11 +660,6 @@ def verifyToVcResults (program : Program) def verifyToDiagnostics (files: Map Strata.Uri Lean.FileMap) (program : Program) (options : VerifyOptions := .default): IO (Array Diagnostic) := do - -- Validate for diamond-inherited field accesses before translation - let uri := files.keys.head! - let diamondErrors := validateDiamondFieldAccesses uri program - if !diamondErrors.isEmpty then - return diamondErrors.map (fun dm => dm.toDiagnostic files) let results <- verifyToVcResults program options match results with | .error errors => return errors.map (fun dm => dm.toDiagnostic files) diff --git a/Strata/Languages/Laurel/LaurelTypes.lean b/Strata/Languages/Laurel/LaurelTypes.lean index feec05c08..6af4c786e 100644 --- a/Strata/Languages/Laurel/LaurelTypes.lean +++ b/Strata/Languages/Laurel/LaurelTypes.lean @@ -6,6 +6,7 @@ import Strata.Languages.Laurel.Laurel import Strata.Languages.Laurel.LaurelFormat +import Strata.Languages.Laurel.Resolution import Strata.Util.Tactics /- @@ -17,34 +18,12 @@ no inference is performed. namespace Strata.Laurel -abbrev TypeEnv := List (Identifier × HighTypeMd) - -/-- -Look up a field's type in a composite type by name, walking the inheritance chain. --/ -def lookupFieldInTypes (types : List TypeDefinition) (typeName : Identifier) (fieldName : Identifier) : Option HighTypeMd := - let rec go (fuel : Nat) (current : Identifier) : Option HighTypeMd := - match fuel with - | 0 => none - | fuel' + 1 => - types.findSome? fun td => - match td with - | .Composite ct => - if ct.name == current then - match ct.fields.findSome? fun f => if f.name == fieldName then some f.type else none with - | some ty => some ty - | none => ct.extending.findSome? (go fuel') - else none - | _ => none - go types.length typeName - /-- Compute the HighType of a StmtExpr given a type environment, type definitions, and procedure list. No inference is performed — all types are determined by annotations on parameters and variable declarations. -/ -def computeExprType (env : TypeEnv) (types : List TypeDefinition) (expr : StmtExprMd) - (procedures : List Procedure := []) : HighTypeMd := +def computeExprType (model : SemanticModel) (expr : StmtExprMd) : HighTypeMd := match expr with | WithMetadata.mk val md => match val with @@ -53,25 +32,20 @@ def computeExprType (env : TypeEnv) (types : List TypeDefinition) (expr : StmtEx | .LiteralBool _ => ⟨ .TBool, md ⟩ | .LiteralString _ => ⟨ .TString, md ⟩ -- Variables - | .Identifier name => - match env.find? (fun (n, _) => n == name) with - | some (_, ty) => ty - | none => panic s!"Could not find variable {name} in environment '{Std.format env}'" + | .Identifier id => (model.get id).getType.getD (panic "computeExprType1") -- Field access - | .FieldSelect target fieldName => - match computeExprType env types target procedures with - | WithMetadata.mk (.UserDefined typeName) _ => - match lookupFieldInTypes types typeName fieldName with - | some ty => ty - | none => panic s!"Could not find field {fieldName} in type {typeName} or its ancestors" - | _ => panic s!"Selecting from a type that's not a composite" + | .FieldSelect _ fieldName => (model.get fieldName).getType.getD (panic "computeExprType2") -- Pure field update returns the same type as the target - | .PureFieldUpdate target _ _ => computeExprType env types target procedures - -- Calls — look up return type from first output of matching procedure - | .StaticCall name _ => - match procedures.find? (·.name == name) with - | some proc => proc.outputs.head?.map (·.type) |>.getD ⟨ .TVoid, md ⟩ - | none => ⟨ .TVoid, md ⟩ + | .PureFieldUpdate target _ _ => computeExprType model target + -- Calls — we don't track return types here, so fall back to TVoid + | .StaticCall callee _ => match model.get callee with + | .staticProcedure proc => match proc.outputs with + | [singleOutput] => singleOutput.type + | _ => { val := .TVoid, md := default } + | .unresolved => + -- The Python through Laurel pipeline does not resolve yet + { val := .TVoid, md := default } + | astNode => panic! s!"static call to {callee} not to a procedure but to a {repr astNode}" | .InstanceCall _ _ _ => panic "Not supported InstanceCall" -- Operators | .PrimitiveOp op _ => @@ -80,18 +54,18 @@ def computeExprType (env : TypeEnv) (types : List TypeDefinition) (expr : StmtEx | .Neg | .Add | .Sub | .Mul | .Div | .Mod | .DivT | .ModT => ⟨ .TInt, md ⟩ | .StrConcat => ⟨ .TString, md ⟩ -- Control flow - | .IfThenElse _ thenBranch _ => computeExprType env types thenBranch procedures + | .IfThenElse _ thenBranch _ => computeExprType model thenBranch | .Block stmts _ => match _blockGetLastResult: stmts.getLast? with | some last => have := List.mem_of_getLast? _blockGetLastResult - computeExprType env types last procedures + computeExprType model last | none => ⟨ .TVoid, md ⟩ - -- Statements (void-typed) + -- Statements | .LocalVariable _ _ _ => ⟨ .TVoid, md ⟩ | .While _ _ _ _ => ⟨ .TVoid, md ⟩ | .Exit _ => ⟨ .TVoid, md ⟩ | .Return _ => ⟨ .TVoid, md ⟩ - | .Assign _ _ => ⟨ .TVoid, md ⟩ + | .Assign _ value => computeExprType model value | .Assert _ => ⟨ .TVoid, md ⟩ | .Assume _ => ⟨ .TVoid, md ⟩ -- Instance related @@ -101,13 +75,13 @@ def computeExprType (env : TypeEnv) (types : List TypeDefinition) (expr : StmtEx | .AsType _ ty => ty | .IsType _ _ => ⟨ .TBool, md ⟩ -- Verification specific - | .Forall _ _ _ => ⟨ .TBool, md ⟩ - | .Exists _ _ _ => ⟨ .TBool, md ⟩ + | .Forall _ _ => ⟨ .TBool, md ⟩ + | .Exists _ _ => ⟨ .TBool, md ⟩ | .Assigned _ => ⟨ .TBool, md ⟩ - | .Old v => computeExprType env types v procedures + | .Old v => computeExprType model v | .Fresh _ => ⟨ .TBool, md ⟩ -- Proof related - | .ProveBy v _ => computeExprType env types v procedures + | .ProveBy v _ => computeExprType model v | .ContractOf _ _ => panic "Not supported" -- Special | .Abstract => panic "Not supported" diff --git a/Strata/Languages/Laurel/LiftImperativeExpressions.lean b/Strata/Languages/Laurel/LiftImperativeExpressions.lean index 027eb3380..6e5520a40 100644 --- a/Strata/Languages/Laurel/LiftImperativeExpressions.lean +++ b/Strata/Languages/Laurel/LiftImperativeExpressions.lean @@ -73,13 +73,9 @@ structure LiftState where /-- Substitution map: variable name → name to use -/ subst : SubstMap := [] /-- Type environment -/ - env : TypeEnv := [] - /-- Type definitions from the program -/ - types : List TypeDefinition := [] + model : SemanticModel /-- Global counter for fresh conditional variables -/ condCounter : Nat := 0 - /-- Names of imperative procedures whose calls must be lifted from expression position -/ - imperativeNames : List Identifier := [] /-- All procedures in the program, used to look up return types of imperative calls -/ procedures : List Procedure := [] @@ -97,7 +93,7 @@ private def freshTempFor (varName : Identifier) : LiftM Identifier := do let counters := (← get).varCounters let counter := counters.find? (·.1 == varName) |>.map (·.2) |>.getD 0 modify fun s => { s with varCounters := (varName, counter + 1) :: s.varCounters.filter (·.1 != varName) } - return s!"${varName}_{counter}" + return s!"${varName.text}_{counter}" private def freshCondVar : LiftM Identifier := do let n := (← get).condCounter @@ -112,42 +108,35 @@ private def takePrepends : LiftM (List StmtExprMd) := do modify fun s => { s with prependedStmts := [] } return stmts -private def getVarType (varName : Identifier) : LiftM HighTypeMd := do - let env := (← get).env - match env.find? (fun (n, _) => n == varName) with - | some (_, ty) => return ty - | none => panic s!"Could not find {varName} in environment." - -private def addToEnv (varName : Identifier) (ty : HighTypeMd) : LiftM Unit := - modify fun s => { s with env := (varName, ty) :: s.env } - private def getSubst (varName : Identifier) : LiftM Identifier := do - match (← get).subst.find? varName with + match (← get).subst.lookup varName with | some mapped => return mapped | none => return varName private def setSubst (varName : Identifier) (value : Identifier) : LiftM Unit := - modify fun s => { s with subst := s.subst.insert varName value } + modify fun s => { s with subst := ⟨ varName, value ⟩ :: s.subst } private def computeType (expr : StmtExprMd) : LiftM HighTypeMd := do let s ← get - return computeExprType s.env s.types expr s.procedures + return computeExprType s.model expr /-- Check if an expression contains any assignments or imperative calls (recursively). -/ -private def containsAssignmentOrImperativeCall (imperativeNames : List Identifier) (expr : StmtExprMd) : Bool := +private def containsAssignmentOrImperativeCall (model: SemanticModel) (expr : StmtExprMd) : Bool := match expr with | WithMetadata.mk val _ => match val with | .Assign .. => true | .StaticCall name args1 => - imperativeNames.contains name || - args1.attach.any (fun x => containsAssignmentOrImperativeCall imperativeNames x.val) - | .PrimitiveOp _ args2 => args2.attach.any (fun x => containsAssignmentOrImperativeCall imperativeNames x.val) - | .Block stmts _ => stmts.attach.any (fun x => containsAssignmentOrImperativeCall imperativeNames x.val) + (match model.get name with + | .staticProcedure proc => !proc.isFunctional + | _ => false) || + args1.attach.any (fun x => containsAssignmentOrImperativeCall model x.val) + | .PrimitiveOp _ args2 => args2.attach.any (fun x => containsAssignmentOrImperativeCall model x.val) + | .Block stmts _ => stmts.attach.any (fun x => containsAssignmentOrImperativeCall model x.val) | .IfThenElse cond th el => - containsAssignmentOrImperativeCall imperativeNames cond || - containsAssignmentOrImperativeCall imperativeNames th || - match el with | some e => containsAssignmentOrImperativeCall imperativeNames e | none => false + containsAssignmentOrImperativeCall model cond || + containsAssignmentOrImperativeCall model th || + match el with | some e => containsAssignmentOrImperativeCall model e | none => false | _ => false termination_by expr decreasing_by @@ -167,7 +156,7 @@ private def liftAssignExpr (targets : List StmtExprMd) (seqValue : StmtExprMd) match target.val with | .Identifier varName => let snapshotName ← freshTempFor varName - let varType ← getVarType varName + let varType ← computeType target -- Snapshot goes before the assignment (cons pushes to front) addPrepend (⟨.LocalVariable snapshotName varType (some (⟨.Identifier varName, md⟩)), md⟩) setSubst varName snapshotName @@ -206,28 +195,27 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do let seqArgs ← args.reverse.mapM transformExpr return ⟨.PrimitiveOp op seqArgs.reverse, md⟩ - | .StaticCall name args => - let imperative := (← get).imperativeNames - let seqArgs ← args.reverse.mapM transformExpr - let seqCall := ⟨.StaticCall name seqArgs.reverse, md⟩ - if imperative.contains name then - -- Imperative call in expression position: lift it like an assignment - -- Order matters: assign must be prepended first (it's newest-first), - -- so that when reversed the var declaration comes before the call. - let callResultVar ← freshCondVar - let callResultType ← computeType expr - addToEnv callResultVar callResultType - addPrepend (⟨.Assign [bare (.Identifier callResultVar)] seqCall, md⟩) - addPrepend (bare (.LocalVariable callResultVar callResultType none)) - return bare (.Identifier callResultVar) - else - return seqCall + | .StaticCall callee args => + let model := (← get).model + let seqArgs ← args.reverse.mapM transformExpr + let seqCall := ⟨.StaticCall callee seqArgs.reverse, md⟩ + if model.isFunction callee then + return seqCall + else + -- Imperative call in expression position: lift it like an assignment + -- Order matters: assign must be prepended first (it's newest-first), + -- so that when reversed the var declaration comes before the call. + let callResultVar ← freshCondVar + let callResultType ← computeType expr + addPrepend (⟨.Assign [bare (.Identifier callResultVar)] seqCall, md⟩) + addPrepend (bare (.LocalVariable callResultVar callResultType none)) + return bare (.Identifier callResultVar) | .IfThenElse cond thenBranch elseBranch => - let imperative := (← get).imperativeNames - let thenHasAssign := containsAssignmentOrImperativeCall imperative thenBranch + let model := (← get).model + let thenHasAssign := containsAssignmentOrImperativeCall model thenBranch let elseHasAssign := match elseBranch with - | some e => containsAssignmentOrImperativeCall imperative e + | some e => containsAssignmentOrImperativeCall model e | none => false if thenHasAssign || elseHasAssign then -- Lift the entire if-then-else. Introduce a fresh variable for the result. @@ -251,8 +239,10 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do | none => pure none -- Restore outer state modify fun s => { s with subst := savedSubst, prependedStmts := savedPrepends } - -- Infer type from the then-branch result - let condType ← computeType seqThen + -- Infer type from the ORIGINAL then-branch (not the transformed one), + -- because the transformed expression may reference freshly generated + -- variables (e.g. $c_2) that don't exist in the SemanticModel yet. + let condType ← computeType thenBranch -- IfThenElse added first (cons puts it deeper), then declaration (cons puts it on top) -- Output order: declaration, then if-then-else addPrepend (⟨.IfThenElse seqCond thenBlock seqElse, md⟩) @@ -274,14 +264,6 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do | some last => do have := List.mem_of_getLast? h_last - -- Pre-populate the environment with all LocalVariable declarations - -- so that getVarType works when creating snapshots - for s in stmts do - match s with - | WithMetadata.mk val _ => - match val with - | .LocalVariable name ty _ => addToEnv name ty - | _ => pure () -- Process all-but-last as statements and prepend them in order let mut blockStmts : List StmtExprMd := [] for nonLastStatement in stmts.dropLast.attach do @@ -292,12 +274,10 @@ def transformExpr (expr : StmtExprMd) : LiftM StmtExprMd := do transformExpr last | .LocalVariable name ty initializer => - -- Add the variable to the environment - addToEnv name ty -- If the substitution map has an entry for this variable, it was -- assigned to the right and we need to lift this declaration so it -- appears before the snapshot that references it. - let hasSubst := (← get).subst.find? name |>.isSome + let hasSubst := (← get).subst.lookup name |>.isSome if hasSubst then match initializer with | some initExpr => @@ -336,55 +316,54 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do return [bare (.Block seqStmts.flatten metadata)] | .LocalVariable name ty initializer => - addToEnv name ty - match _: initializer with - | some initExpr => + match _ : initializer with + | some initExprMd => -- If the initializer is a direct imperative StaticCall, don't lift it — -- translateStmt handles LocalVariable + StaticCall directly as a call statement. + match _: initExprMd with + | WithMetadata.mk initExpr md => match _: initExpr with - | WithMetadata.mk initExprVal md => - match _: initExprVal with | .StaticCall callee args => - let imperative := (← get).imperativeNames - if imperative.contains callee then + let model := (← get).model + if model.isFunction callee then + let seqInit ← transformExpr initExprMd + let prepends ← takePrepends + modify fun s => { s with subst := [] } + return prepends ++ [⟨.LocalVariable name ty (some seqInit), md⟩] + else -- Pass through as-is; translateStmt will emit init + call let seqArgs ← args.mapM transformExpr let argPrepends ← takePrepends modify fun s => { s with subst := [] } - return argPrepends ++ [⟨.LocalVariable name ty (some ⟨.StaticCall callee seqArgs, initExpr.md⟩), md⟩] - else - let seqInit ← transformExpr initExpr - let prepends ← takePrepends - modify fun s => { s with subst := [] } - return prepends ++ [⟨.LocalVariable name ty (some seqInit), md⟩] + return argPrepends ++ [⟨.LocalVariable name ty (some ⟨.StaticCall callee seqArgs, initExprMd.md⟩), md⟩] | _ => - let seqInit ← transformExpr initExpr + let seqInit ← transformExpr initExprMd let prepends ← takePrepends modify fun s => { s with subst := [] } return prepends ++ [⟨.LocalVariable name ty (some seqInit), md⟩] | none => return [stmt] - | .Assign targets value => + | .Assign targets valueMd => -- If the RHS is a direct imperative StaticCall, don't lift it — -- translateStmt handles Assign + StaticCall directly as a call statement. - match _:value with - | WithMetadata.mk val valueMd => - match _:val with + match _: valueMd with + | WithMetadata.mk value _ => + match _: value with | .StaticCall callee args => - let imperative := (← get).imperativeNames - if imperative.contains callee then - let seqArgs ← args.mapM transformExpr - let argPrepends ← takePrepends - modify fun s => { s with subst := [] } - return argPrepends ++ [⟨.Assign targets ⟨.StaticCall callee seqArgs, valueMd⟩, md⟩] - else - let seqValue ← transformExpr value + let model := (← get).model + if model.isFunction callee then + let seqValue ← transformExpr valueMd let prepends ← takePrepends modify fun s => { s with subst := [] } return prepends ++ [⟨.Assign targets seqValue, md⟩] + else + let seqArgs ← args.mapM transformExpr + let argPrepends ← takePrepends + modify fun s => { s with subst := [] } + return argPrepends ++ [⟨.Assign targets ⟨.StaticCall callee seqArgs, md⟩, md⟩] | _ => - let seqValue ← transformExpr value + let seqValue ← transformExpr valueMd let prepends ← takePrepends modify fun s => { s with subst := [] } return prepends ++ [⟨.Assign targets seqValue, md⟩] @@ -419,7 +398,8 @@ def transformStmt (stmt : StmtExprMd) : LiftM (List StmtExprMd) := do return [stmt] termination_by (sizeOf stmt, 0) decreasing_by - all_goals (term_by_mem) + all_goals (try term_by_mem) + all_goals (apply Prod.Lex.left; try term_by_mem) end def transformProcedureBody (body : StmtExprMd) : LiftM StmtExprMd := do @@ -429,10 +409,7 @@ def transformProcedureBody (body : StmtExprMd) : LiftM StmtExprMd := do | multiple => pure (bare (.Block multiple none)) def transformProcedure (proc : Procedure) : LiftM Procedure := do - let initEnv : TypeEnv := - proc.inputs.map (fun p => (p.name, p.type)) ++ - proc.outputs.map (fun p => (p.name, p.type)) - modify fun s => { s with subst := [], prependedStmts := [], varCounters := [], env := initEnv } + modify fun s => { s with subst := [], prependedStmts := [], varCounters := [] } match proc.body with | .Transparent bodyExpr => let seqBody ← transformProcedureBody bodyExpr @@ -442,13 +419,14 @@ def transformProcedure (proc : Procedure) : LiftM Procedure := do pure { proc with body := .Opaque postconds impl' modif } | .Abstract _ => pure proc + | .External => + pure proc /-- Transform a program to lift all assignments that occur in an expression context. -/ -def liftImperativeExpressions (program : Program) : Program := - let imperativeNames := program.staticProcedures.filter (fun p => !p.isFunctional) |>.map (·.name) - let initState : LiftState := { types := program.types, imperativeNames := imperativeNames, procedures := program.staticProcedures } +def liftExpressionAssignments (model: SemanticModel) (program : Program) : Program := + let initState : LiftState := { model := model } let (seqProcedures, _) := (program.staticProcedures.mapM transformProcedure).run initState { program with staticProcedures := seqProcedures } diff --git a/Strata/Languages/Laurel/ModifiesClauses.lean b/Strata/Languages/Laurel/ModifiesClauses.lean index 398c5d01a..444b68f56 100644 --- a/Strata/Languages/Laurel/ModifiesClauses.lean +++ b/Strata/Languages/Laurel/ModifiesClauses.lean @@ -7,6 +7,7 @@ import Strata.Languages.Laurel.Laurel import Strata.Languages.Laurel.LaurelTypes import Strata.Languages.Core.Verifier +import Strata.Languages.Laurel.Resolution /- Modifies clause transformation (Laurel → Laurel). @@ -47,10 +48,10 @@ inductive ModifiesEntry where Extract modifies entries from the list of modifies StmtExprs, using the type environment and type definitions to distinguish Composite from Set Composite. -/ -def extractModifiesEntries (env : TypeEnv) (types : List TypeDefinition) +def extractModifiesEntries (model: SemanticModel) (modifiesExprs : List StmtExprMd) : List ModifiesEntry := modifiesExprs.map fun expr => - match (computeExprType env types expr).val with + match (computeExprType model expr).val with | .TSet _ => .set expr | _ => .single expr /-- @@ -83,12 +84,11 @@ Generates a single quantified formula: Returns `none` if there are no entries. -/ -def buildModifiesEnsures (proc: Procedure) (env : TypeEnv) - (types : List TypeDefinition) (modifiesExprs : List StmtExprMd) +def buildModifiesEnsures (proc: Procedure) (model: SemanticModel) (modifiesExprs : List StmtExprMd) (heapInName heapOutName : Identifier) : Option StmtExprMd := - let entries := extractModifiesEntries env types modifiesExprs - let objName := "$modifies_obj" - let fldName := "$modifies_fld" + let entries := extractModifiesEntries model modifiesExprs + let objName : Identifier := "$modifies_obj" + let fldName : Identifier := "$modifies_fld" let obj := mkMd <| .Identifier objName let fld := mkMd <| .Identifier fldName let heapIn := mkMd <| .Identifier heapInName @@ -111,8 +111,8 @@ def buildModifiesEnsures (proc: Procedure) (env : TypeEnv) -- Build: antecedent ==> heapUnchanged let implBody := mkMd <| .PrimitiveOp .Implies [antecedent, heapUnchanged] -- Build: forall $obj: Composite, $fld: Field => ... - let innerForall := mkMd <| .Forall fldName (⟨ .TTypedField ⟨.TInt, .empty⟩, .empty ⟩) implBody - let outerForall := ⟨ .Forall objName (⟨ .UserDefined "Composite", .empty ⟩) innerForall, proc.md ⟩ + let innerForall := mkMd <| .Forall ⟨ fldName, (⟨ .TTypedField ⟨.TInt, .empty⟩, .empty ⟩) ⟩ implBody + let outerForall := ⟨ .Forall ⟨ objName, (⟨ .UserDefined "Composite", .empty ⟩) ⟩ innerForall, proc.md ⟩ some outerForall /-- @@ -120,7 +120,7 @@ Check whether a procedure has a `$heap` output parameter, indicating it mutates the heap. -/ def hasHeapOut (proc : Procedure) : Bool := - proc.outputs.any (fun p => p.name == "$heap") + proc.outputs.any (fun p => p.name.text == "$heap") /-- Transform a single procedure: if it has modifies clauses, generate the frame @@ -130,17 +130,15 @@ If the procedure has a `$heap` but no modifies clause, adds a postcondition that all allocated objects are preserved between heaps: `forall $obj: Composite, $fld: Field => $obj < $heap_in.nextReference ==> readField($heap_in, $obj, $fld) == readField($heap, $obj, $fld)` -/ -def transformModifiesClauses (constants : List Constant) (types : List TypeDefinition) +def transformModifiesClauses (model: SemanticModel) (proc : Procedure) : Except (Array DiagnosticModel) Procedure := match proc.body with + | .External => .ok proc | .Opaque postconds impl modifiesExprs => if hasHeapOut proc then - let env : TypeEnv := proc.inputs.map (fun p => (p.name, p.type)) ++ - proc.outputs.map (fun p => (p.name, p.type)) ++ - constants.map (fun c => (c.name, c.type)) - let heapInName := "$heap_in" - let heapName := "$heap" - let frameCondition := buildModifiesEnsures proc env types modifiesExprs heapInName heapName + let heapInName : Identifier := "$heap_in" + let heapName : Identifier := "$heap" + let frameCondition := buildModifiesEnsures proc model modifiesExprs heapInName heapName let postconds' := match frameCondition with | some frame => postconds ++ [frame] | none => postconds @@ -156,9 +154,9 @@ This is a Laurel → Laurel pass that should run after heap parameterization. Always returns the (best-effort) transformed program together with any diagnostics, so that later passes can continue and report additional errors. -/ -def modifiesClausesTransform (program : Program) : Program × Array DiagnosticModel := +def modifiesClausesTransform (model: SemanticModel) (program : Program) : Program × Array DiagnosticModel := let (procs', errors) := program.staticProcedures.foldl (fun (acc, errs) proc => - match transformModifiesClauses program.constants program.types proc with + match transformModifiesClauses model proc with | .ok proc' => (acc ++ [proc'], errs) | .error newErrs => (acc ++ [proc], errs ++ newErrs.toList) ) ([], []) diff --git a/Strata/Languages/Laurel/Resolution.lean b/Strata/Languages/Laurel/Resolution.lean new file mode 100644 index 000000000..ca3891dd6 --- /dev/null +++ b/Strata/Languages/Laurel/Resolution.lean @@ -0,0 +1,731 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import Strata.Languages.Laurel.Laurel +import Strata.Languages.Laurel.LaurelFormat +import Strata.Util.Tactics + +/-! +# Name Resolution Pass + +Assigns a unique numeric ID to every definition and reference node in a +Laurel program, then resolves references to their definitions. + +## Design + +The resolution pass operates in two phases: + +### Phase 1: ID Assignment and Reference Resolution +Walks the AST, assigning fresh unique IDs to all definition nodes and +resolving references by looking up names in the current lexical scope. +After this phase, every definition and reference node has its `id` field +filled in. + +### Phase 2: Build refToDef Map +Walks the *resolved* AST (where all definitions already have their UUIDs) +and builds a map from each definition's ID to its `AstNode`. Because this +happens after Phase 1, the `AstNode` values in the map contain the fully +resolved sub-trees (e.g. a procedure's parameters already have their IDs). + +### Definition nodes (introduce a name into scope) +- `StmtExpr.LocalVariable` — local variable declaration +- `StmtExpr.Forall` / `StmtExpr.Exists` — quantifier-bound variable +- `Parameter` — procedure parameter +- `Procedure` — procedure definition +- `Field` — field on a composite type +- `CompositeType` / `ConstrainedType` / `DatatypeDefinition` — type definitions +- `DatatypeConstructor` — datatype constructor +- `Constant` — named constant + +### Reference nodes (use a name) +- `StmtExpr.Identifier` — variable reference +- `StmtExpr.StaticCall` — static procedure call +- `StmtExpr.InstanceCall` — instance method call +- `StmtExpr.FieldSelect` — field access +- `StmtExpr.New` — object creation (references a type) +- `StmtExpr.Exit` — exit a labelled block +- `HighType.UserDefined` — type reference + +Each of these nodes carries an `id : Nat` field (defaulting to `0`). +The ID assignment pass fills in unique values. The resolution pass then +builds a map from reference IDs to `AstNode` values describing the +definition each reference resolves to. +-/ + +namespace Strata.Laurel + +/-! ## AstNode — the target of a resolved reference -/ + +/-- A definition-site AST node that a reference can resolve to. -/ +inductive AstNode where + /-- A local variable declaration. -/ + | var (name : Identifier) (type : HighTypeMd) + /-- A procedure parameter. -/ + | parameter (param : Parameter) + /-- A static procedure. -/ + | staticProcedure (proc : Procedure) + /-- An instance procedure (method) on a composite type. -/ + | instanceProcedure (typeName : Identifier) (proc : Procedure) + /-- A field on a composite type. -/ + | field (typeName : Identifier) (fld : Field) + /-- A composite type definition. -/ + | compositeType (ty : CompositeType) + /-- A constrained type definition. -/ + | constrainedType (ty : ConstrainedType) + /-- A datatype definition. -/ + | datatypeDefinition (ty : DatatypeDefinition) + /-- A datatype constructor. -/ + | datatypeConstructor (typeName : Identifier) (ctor : DatatypeConstructor) + /-- A constant. -/ + | constant (c : Constant) + /-- A quantifier-bound variable. -/ + | quantifierVar (name : Identifier) (type : HighTypeMd) + | unresolved + deriving Repr + +instance : Inhabited AstNode where + default := AstNode.unresolved + +def AstNode.getType (node: AstNode): Option HighTypeMd := match node with + | .var _ type => type + | .parameter p => p.type + | .field _ f => f.type + | .datatypeConstructor type _ => some ⟨ .UserDefined type, default ⟩ + | .constant c => c.type + | .unresolved => + -- The Python through Laurel pipeline does not resolve yet + some ⟨ .UserDefined "dummyName", default ⟩ + | _ => panic! s!"getType called on {repr node}" + +/-! ## Resolution result -/ + +structure SemanticModel where + nextId: Nat + compositeCount: Nat + refToDef: Std.HashMap Nat AstNode + deriving Repr + +def SemanticModel.get (model: SemanticModel) (iden: Identifier): AstNode := + match iden.uniqueId with + | some key => (model.refToDef.get? key).getD (panic! s!"could not find key {key}") + | none => default -- panic! s!"model.get called on identifier {iden.text} without number" + +def SemanticModel.isFunction (model: SemanticModel) (id: Identifier): Bool := + if id.uniqueId == none then + -- The Python pipeline generates constructor/discriminator calls that may not + -- be resolved at the Laurel level. Treating them as functions keeps them as + -- expressions; any real errors will be caught during Core type checking. + -- Make an exception for 'test_helper_procedure' since it's a procedure + -- We will remove this hack when we enable the Python through Laurel pipeline to correctly resolve + id.text != "test_helper_procedure" + else + match model.get id with + | .staticProcedure proc => proc.isFunctional + | .parameter _ => true + | .datatypeConstructor _ _ => true + | .constant _ => true + | node => panic! s!"id: {repr id}, is not a procedure, node {repr node}" + +/-- The output of the resolution pass. -/ +structure ResolutionResult where + /-- The program with unique IDs on all definition and reference nodes. -/ + program : Program + /-- Map from reference node ID to the definition it resolves to. -/ + model : SemanticModel + /-- Diagnostics collected during resolution (e.g. unresolved references). -/ + errors : Array DiagnosticModel := #[] + +/-! ## Phase 1: ID assignment and reference resolution -/ + +/-- A scope entry stores the definition-site ID and the AstNode for type lookups. -/ +abbrev ScopeEntry := Nat × AstNode + +/-- Scope maps a name to its definition-site ID and optional AstNode. -/ +abbrev Scope := Std.HashMap String ScopeEntry + +/-- Per-composite-type scope mapping field names to their scope entries. -/ +abbrev TypeScopes := Std.HashMap String Scope + +/-- State threaded through the resolution pass. -/ +structure ResolveState where + /-- Next fresh ID to allocate. -/ + nextId : Nat := 1 + /-- Current lexical scope (name → definition ID). -/ + scope : Scope := {} + /-- Per-composite-type field scopes (type name → field name → scope entry). -/ + typeScopes : TypeScopes := {} + /-- Diagnostics collected during resolution. -/ + errors : Array DiagnosticModel := #[] + +abbrev ResolveM := StateM ResolveState + +/-- Allocate a fresh unique ID. -/ +private def freshId : ResolveM Nat := do + let s ← get + let id := s.nextId + set { s with nextId := id + 1 } + return id + +/-- Register a definition: assign a fresh ID to the identifier and record it in scope with its AstNode. -/ +def defineName (iden : Identifier) (node : AstNode) (overrideResolutionName: Option String := none) : ResolveM Identifier := do + let resolutionName := overrideResolutionName.getD iden.text + let name' ← if iden.uniqueId == none then + let id ← freshId + pure { iden with uniqueId := some (id) } + else + pure iden + + modify fun s => { s with scope := s.scope.insert resolutionName (name'.uniqueId.getD (panic "key was just inserted"), node) } + return name' + +/-- Resolve a reference: look up the name in scope and assign the definition's ID. + Returns the identifier with its ID filled in. -/ +def resolveRef (name : Identifier) (md : Imperative.MetaData Core.Expression := .empty) : ResolveM Identifier := do + let s ← get + match s.scope.get? name.text with + | some (defId, _) => + let name' := { name with uniqueId := some defId } + return name' + | none => + let diag := md.toDiagnostic s!"Resolution failed: '{name}' is not defined" + modify fun s => { s with errors := s.errors.push diag } + return { name with uniqueId := none } + +/-- Extract the UserDefined type name from a resolved target expression by looking up its scope entry. -/ +private def targetTypeName (target : StmtExprMd) : ResolveM (Option String) := do + let s ← get + match target.val with + | .Identifier ref => + match s.scope.get? ref.text with + | some (_, node) => + match node.getType with + | some ty => + match ty.val with + | .UserDefined typRef => pure (some typRef.text) + | _ => pure none + | none => pure none + | none => pure none + | _ => pure none + +/-- Resolve a field reference using the target's type to build a qualified lookup key. + Falls back to unqualified lookup if the target type cannot be determined. -/ +def resolveFieldRef (target : StmtExprMd) (fieldName : Identifier) + (md : Imperative.MetaData Core.Expression) : ResolveM Identifier := do + let typeName? ← targetTypeName target + match typeName? with + | some typeName => + let s ← get + match s.typeScopes.get? typeName with + | some typeScope => + match typeScope.get? fieldName.text with + | some (defId, _) => return { fieldName with uniqueId := some defId } + | none => resolveRef fieldName md + | none => resolveRef fieldName md + | none => resolveRef fieldName md + +/-- Save and restore scope around a block (for lexical scoping). -/ +def withScope (action : ResolveM α) : ResolveM α := do + let savedScope := (← get).scope + let result ← action + modify fun s => { s with scope := savedScope } + return result + +/-! ## AST traversal (Phase 1) -/ + + +def resolveHighType (ty : HighTypeMd) : ResolveM HighTypeMd := do + match ty with + | WithMetadata.mk val _ => + let val' ← match val with + | .UserDefined ref => + let ref' ← resolveRef ref ty.md + pure (.UserDefined ref') + | .TTypedField vt => + let vt' ← resolveHighType vt + pure (.TTypedField vt') + | .TSet et => + let et' ← resolveHighType et + pure (.TSet et') + | .TMap kt vt => + let kt' ← resolveHighType kt + let vt' ← resolveHighType vt + pure (.TMap kt' vt') + | .Applied base args => + let base' ← resolveHighType base + let args' ← args.mapM resolveHighType + pure (.Applied base' args') + | .Pure base => + let base' ← resolveHighType base + pure (.Pure base') + | .Intersection tys => + let tys' ← tys.mapM resolveHighType + pure (.Intersection tys') + | other => pure other + return ⟨val', ty.md⟩ + +def resolveStmtExpr (exprMd : StmtExprMd) : ResolveM StmtExprMd := do + match _: exprMd with + | WithMetadata.mk expr md => + let val' ← match _: expr with + | .IfThenElse cond thenBr elseBr => + let cond' ← resolveStmtExpr cond + let thenBr' ← resolveStmtExpr thenBr + let elseBr' ← elseBr.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) + pure (.IfThenElse cond' thenBr' elseBr') + | .Block stmts label => + withScope do + let stmts' ← stmts.mapM resolveStmtExpr + pure (.Block stmts' label) + | .LocalVariable name ty init => + let ty' ← resolveHighType ty + let init' ← init.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) + let name' ← defineName name (.var name ty') + pure (.LocalVariable name' ty' init') + | .While cond invs dec body => + let cond' ← resolveStmtExpr cond + let invs' ← invs.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) + let dec' ← dec.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) + let body' ← resolveStmtExpr body + pure (.While cond' invs' dec' body') + | .Exit target => pure (.Exit target) + | .Return val => do + let val' ← val.attach.mapM (fun a => have := a.property; resolveStmtExpr a.val) + pure (.Return val') + | .LiteralInt v => pure (.LiteralInt v) + | .LiteralBool v => pure (.LiteralBool v) + | .LiteralString v => pure (.LiteralString v) + | .Identifier ref => + let ref' ← resolveRef ref md + pure (.Identifier ref') + | .Assign targets value => + let targets' ← targets.mapM resolveStmtExpr + let value' ← resolveStmtExpr value + pure (.Assign targets' value') + | .FieldSelect target fieldName => + let target' ← resolveStmtExpr target + let fieldName' ← resolveFieldRef target' fieldName md + pure (.FieldSelect target' fieldName') + | .PureFieldUpdate target fieldName newVal => + let target' ← resolveStmtExpr target + let fieldName' ← resolveFieldRef target' fieldName md + let newVal' ← resolveStmtExpr newVal + pure (.PureFieldUpdate target' fieldName' newVal') + | .StaticCall callee args => + let callee' ← resolveRef callee md + let args' ← args.mapM resolveStmtExpr + pure (.StaticCall callee' args') + | .PrimitiveOp op args => + let args' ← args.mapM resolveStmtExpr + pure (.PrimitiveOp op args') + | .New ref => + let ref' ← resolveRef ref md + pure (.New ref') + | .This => pure .This + | .ReferenceEquals lhs rhs => + let lhs' ← resolveStmtExpr lhs + let rhs' ← resolveStmtExpr rhs + pure (.ReferenceEquals lhs' rhs') + | .AsType target ty => + let target' ← resolveStmtExpr target + let ty' ← resolveHighType ty + pure (.AsType target' ty') + | .IsType target ty => + let target' ← resolveStmtExpr target + let ty' ← resolveHighType ty + pure (.IsType target' ty') + | .InstanceCall target callee args => + let target' ← resolveStmtExpr target + let callee' ← resolveRef callee md + let args' ← args.mapM resolveStmtExpr + pure (.InstanceCall target' callee' args') + | .Forall param body => + withScope do + let paramTy' ← resolveHighType param.type + let paramName' ← defineName param.name (.quantifierVar param.name paramTy') + let body' ← resolveStmtExpr body + pure (.Forall ⟨paramName', paramTy'⟩ body') + | .Exists param body => + withScope do + let paramTy' ← resolveHighType param.type + let paramName' ← defineName param.name (.quantifierVar param.name paramTy') + let body' ← resolveStmtExpr body + pure (.Exists ⟨paramName', paramTy'⟩ body') + | .Assigned name => + let name' ← resolveStmtExpr name + pure (.Assigned name') + | .Old val => + let val' ← resolveStmtExpr val + pure (.Old val') + | .Fresh val => + let val' ← resolveStmtExpr val + pure (.Fresh val') + | .Assert cond => + let cond' ← resolveStmtExpr cond + pure (.Assert cond') + | .Assume cond => + let cond' ← resolveStmtExpr cond + pure (.Assume cond') + | .ProveBy val proof => + let val' ← resolveStmtExpr val + let proof' ← resolveStmtExpr proof + pure (.ProveBy val' proof') + | .ContractOf ty fn => + let fn' ← resolveStmtExpr fn + pure (.ContractOf ty fn') + | .Abstract => pure .Abstract + | .All => pure .All + | .Hole => pure .Hole + return ⟨val', md⟩ + termination_by exprMd + decreasing_by all_goals term_by_mem + +/-- Resolve a parameter: assign a fresh ID and add to scope. -/ +def resolveParameter (param : Parameter) : ResolveM Parameter := do + let ty' ← resolveHighType param.type + let name' ← defineName param.name (.parameter ⟨param.name, ty'⟩) + return ⟨name', ty'⟩ + +/-- Resolve a procedure body. -/ +def resolveBody (body : Body) : ResolveM Body := do + match body with + | .Transparent b => + let b' ← resolveStmtExpr b + return .Transparent b' + | .Opaque posts impl mods => + let posts' ← posts.mapM resolveStmtExpr + let impl' ← impl.mapM resolveStmtExpr + let mods' ← mods.mapM resolveStmtExpr + return .Opaque posts' impl' mods' + | .Abstract posts => + let posts' ← posts.mapM resolveStmtExpr + return .Abstract posts' + | .External => return .External + +/-- Resolve a determinism clause. -/ +def resolveDeterminism (d : Determinism) : ResolveM Determinism := do + match d with + | .deterministic reads => + let reads' ← reads.mapM resolveStmtExpr + return .deterministic reads' + | .nondeterministic => return .nondeterministic + +/-- Resolve a procedure: define its name, then resolve params, contracts, and body in a new scope. -/ +def resolveProcedure (proc : Procedure) : ResolveM Procedure := do + let procName' ← defineName proc.name (.staticProcedure proc) + withScope do + let inputs' ← proc.inputs.mapM resolveParameter + let outputs' ← proc.outputs.mapM resolveParameter + let pres' ← proc.preconditions.mapM resolveStmtExpr + let det' ← resolveDeterminism proc.determinism + let dec' ← proc.decreases.mapM resolveStmtExpr + let body' ← resolveBody proc.body + return { name := procName', inputs := inputs', outputs := outputs', + isFunctional := proc.isFunctional, + preconditions := pres', determinism := det', decreases := dec', + body := body', md := proc.md } + +/-- Resolve a field: define its name under the qualified key (OwnerType.fieldName) and resolve its type. -/ +def resolveField (ownerName : Identifier) (field : Field) : ResolveM Field := do + let ty' ← resolveHighType field.type + let qualifiedName := ownerName.text ++ "." ++ field.name.text + let name' ← defineName field.name (.field ownerName { field with type := ty' }) (some qualifiedName) + return { name := name', isMutable := field.isMutable, type := ty' } + +/-- Resolve an instance procedure on a composite type. -/ +def resolveInstanceProcedure (typeName : Identifier) (proc : Procedure) : ResolveM Procedure := do + let procName' ← defineName proc.name (.instanceProcedure typeName proc) + withScope do + let inputs' ← proc.inputs.mapM resolveParameter + let outputs' ← proc.outputs.mapM resolveParameter + let pres' ← proc.preconditions.mapM resolveStmtExpr + let det' ← resolveDeterminism proc.determinism + let dec' ← proc.decreases.mapM resolveStmtExpr + let body' ← resolveBody proc.body + return { name := procName', inputs := inputs', outputs := outputs', + isFunctional := proc.isFunctional, + preconditions := pres', determinism := det', decreases := dec', + body := body', md := proc.md } + +/-- Resolve a type definition. -/ +def resolveTypeDefinition (td : TypeDefinition) : ResolveM TypeDefinition := do + match td with + | .Composite ct => + let ctName' ← defineName ct.name (.compositeType ct) + let extending' ← ct.extending.mapM (resolveRef · .empty) + let fields' ← ct.fields.mapM (resolveField ctName') + let instProcs' ← ct.instanceProcedures.mapM (resolveInstanceProcedure ctName') + -- Build per-type scope: start with inherited fields from parents, then add own fields + let s ← get + let mut typeScope : Scope := {} + for parent in extending' do + match s.typeScopes.get? parent.text with + | some parentScope => + for (k, v) in parentScope do + typeScope := typeScope.insert k v + | none => pure () + -- Add own fields (these override inherited ones with the same name) + for field in fields' do + let qualifiedKey := ctName'.text ++ "." ++ field.name.text + match s.scope.get? qualifiedKey with + | some entry => typeScope := typeScope.insert field.name.text entry + | none => pure () + modify fun s => { s with typeScopes := s.typeScopes.insert ctName'.text typeScope } + return .Composite { name := ctName', extending := extending', + fields := fields', instanceProcedures := instProcs' } + | .Constrained ct => + let ctName' ← defineName ct.name (.constrainedType ct) + let base' ← resolveHighType ct.base + let constraint' ← resolveStmtExpr ct.constraint + let witness' ← resolveStmtExpr ct.witness + return .Constrained { name := ctName', base := base', valueName := ct.valueName, + constraint := constraint', witness := witness' } + | .Datatype dt => + let dtName' ← defineName dt.name (.datatypeDefinition dt) + let ctors' ← dt.constructors.mapM fun ctor => do + let ctorName' ← defineName ctor.name (.datatypeConstructor dt.name ctor) + let args' ← ctor.args.mapM fun (p: Parameter) => do + let ty' ← resolveHighType p.type + let destructorId ← defineName p.name (.parameter p) (some $ dt.name.text ++ ".." ++ p.name.text) + return ⟨ destructorId, ty' ⟩ + return { name := ctorName', args := args' : DatatypeConstructor } + return .Datatype { name := dtName', typeArgs := dt.typeArgs, constructors := ctors' } + +/-- Resolve a constant definition. -/ +def resolveConstant (c : Constant) : ResolveM Constant := do + let ty' ← resolveHighType c.type + let init' ← c.initializer.mapM resolveStmtExpr + let name' ← defineName c.name (.constant c) + return { name := name', type := ty', initializer := init' } + +/-! ## Phase 2: Build refToDef map from the resolved program -/ + +/-- Insert a definition into the refToDef map using the ID already on the identifier. -/ +private def register (map : Std.HashMap Nat AstNode) (iden : Identifier) (node : AstNode) + : Std.HashMap Nat AstNode := + match iden.uniqueId with + | some uuid => map.insert uuid node + | none => map -- shouldn't happen after Phase 1 + +private def collectHighType (map : Std.HashMap Nat AstNode) (ty : HighTypeMd) + : Std.HashMap Nat AstNode := + match ty with + | WithMetadata.mk val _ => + match val with + | .TTypedField vt => collectHighType map vt + | .TSet et => collectHighType map et + | .TMap kt vt => + let map := collectHighType map kt + collectHighType map vt + | .Applied base args => + let map := collectHighType map base + args.foldl collectHighType map + | .Pure base => collectHighType map base + | .Intersection tys => tys.foldl collectHighType map + | _ => map + +private def collectStmtExpr (map : Std.HashMap Nat AstNode) (expr : StmtExprMd) + : Std.HashMap Nat AstNode := + match expr with + | WithMetadata.mk val _ => + match val with + | .IfThenElse cond thenBr elseBr => + let map := collectStmtExpr map cond + let map := collectStmtExpr map thenBr + match elseBr with + | some e => collectStmtExpr map e + | none => map + | .Block stmts _ => stmts.foldl collectStmtExpr map + | .LocalVariable name ty init => + let map := register map name (.var name ty) + let map := collectHighType map ty + match init with + | some i => collectStmtExpr map i + | none => map + | .While cond invs dec body => + let map := collectStmtExpr map cond + let map := invs.foldl collectStmtExpr map + let map := match dec with | some d => collectStmtExpr map d | none => map + collectStmtExpr map body + | .Return val => match val with | some v => collectStmtExpr map v | none => map + | .Identifier _ => map + | .Assign targets value => + let map := targets.foldl collectStmtExpr map + collectStmtExpr map value + | .FieldSelect target _ => collectStmtExpr map target + | .PureFieldUpdate target _ newVal => + let map := collectStmtExpr map target + collectStmtExpr map newVal + | .StaticCall _ args => args.foldl collectStmtExpr map + | .PrimitiveOp _ args => args.foldl collectStmtExpr map + | .ReferenceEquals lhs rhs => + let map := collectStmtExpr map lhs + collectStmtExpr map rhs + | .AsType target ty => + let map := collectStmtExpr map target + collectHighType map ty + | .IsType target ty => + let map := collectStmtExpr map target + collectHighType map ty + | .InstanceCall target _ args => + let map := collectStmtExpr map target + args.foldl collectStmtExpr map + | .Forall param body => + let map := register map param.name (.quantifierVar param.name param.type) + let map := collectHighType map param.type + collectStmtExpr map body + | .Exists param body => + let map := register map param.name (.quantifierVar param.name param.type) + let map := collectHighType map param.type + collectStmtExpr map body + | .Assigned name => collectStmtExpr map name + | .Old val => collectStmtExpr map val + | .Fresh val => collectStmtExpr map val + | .Assert cond => collectStmtExpr map cond + | .Assume cond => collectStmtExpr map cond + | .ProveBy val proof => + let map := collectStmtExpr map val + collectStmtExpr map proof + | .ContractOf _ fn => collectStmtExpr map fn + | .New _ | .This | .Exit _ | .LiteralInt _ | .LiteralBool _ | .LiteralString _ + | .Abstract | .All | .Hole => map + +private def collectBody (map : Std.HashMap Nat AstNode) (body : Body) + : Std.HashMap Nat AstNode := + match body with + | .Transparent b => collectStmtExpr map b + | .Opaque posts impl mods => + let map := posts.foldl collectStmtExpr map + let map := match impl with | some i => collectStmtExpr map i | none => map + mods.foldl collectStmtExpr map + | .Abstract posts => posts.foldl collectStmtExpr map + | .External => map + +private def collectDeterminism (map : Std.HashMap Nat AstNode) (d : Determinism) + : Std.HashMap Nat AstNode := + match d with + | .deterministic (some reads) => collectStmtExpr map reads + | _ => map + +private def collectParameter (map : Std.HashMap Nat AstNode) (param : Parameter) + : Std.HashMap Nat AstNode := + let map := register map param.name (.parameter param) + collectHighType map param.type + +private def collectProcedure (map : Std.HashMap Nat AstNode) (proc : Procedure) + (mkNode : Procedure → AstNode) : Std.HashMap Nat AstNode := + let map := register map proc.name (mkNode proc) + let map := proc.inputs.foldl collectParameter map + let map := proc.outputs.foldl collectParameter map + let map := proc.preconditions.foldl collectStmtExpr map + let map := collectDeterminism map proc.determinism + let map := match proc.decreases with | some d => collectStmtExpr map d | none => map + collectBody map proc.body + +private def collectField (map : Std.HashMap Nat AstNode) (ownerName : Identifier) (field : Field) + : Std.HashMap Nat AstNode := + let map := register map field.name (.field ownerName field) + collectHighType map field.type + +private def collectTypeDefinition (map : Std.HashMap Nat AstNode) (td : TypeDefinition) + : Std.HashMap Nat AstNode := + match td with + | .Composite ct => + let map := register map ct.name (.compositeType ct) + let map := ct.fields.foldl (collectField · ct.name ·) map + ct.instanceProcedures.foldl (collectProcedure · · (.instanceProcedure ct.name ·)) map + | .Constrained ct => + let map := register map ct.name (.constrainedType ct) + let map := collectHighType map ct.base + let map := collectStmtExpr map ct.constraint + collectStmtExpr map ct.witness + | .Datatype dt => + let map := register map dt.name (.datatypeDefinition dt) + dt.constructors.foldl (fun map ctor => + let map := register map ctor.name (.datatypeConstructor dt.name ctor) + ctor.args.foldl (fun map p => + let map := register map p.name (.parameter p) + collectHighType map p.type + ) map + ) map + +private def collectConstant (map : Std.HashMap Nat AstNode) (c : Constant) + : Std.HashMap Nat AstNode := + let map := register map c.name (.constant c) + let map := collectHighType map c.type + match c.initializer with + | some init => collectStmtExpr map init + | none => map + +/-- Build the refToDef map by walking the fully-resolved program (Phase 2). -/ +def buildRefToDef (program : Program) : Std.HashMap Nat AstNode := + let map : Std.HashMap Nat AstNode := {} + let map := program.types.foldl collectTypeDefinition map + let map := program.constants.foldl collectConstant map + let map := program.staticFields.foldl (collectField · "$static" ·) map + program.staticProcedures.foldl (collectProcedure · · .staticProcedure) map + +/-! ## Pre-registration: populate scope with all top-level names before resolving bodies -/ + +/-- A default AstNode used as a placeholder during pre-registration. + It will be overwritten with the real node when the definition is fully resolved. -/ +private def placeholderNode : AstNode := .var "$placeholder" ⟨.TVoid, #[]⟩ + +/-- Pre-register all top-level names into scope so that declaration order doesn't matter. + This assigns fresh IDs and adds placeholder scope entries for: + - Type names (composite, constrained, datatype) and their constructors/destructors/fields + - Constant names + - Static procedure names -/ +private def preRegisterTopLevel (program : Program) : ResolveM Unit := do + -- Pre-register type definitions + for td in program.types do + match td with + | .Composite ct => + let _ ← defineName ct.name (.compositeType ct) + for field in ct.fields do + let qualifiedName := ct.name.text ++ "." ++ field.name.text + let _ ← defineName field.name placeholderNode (some qualifiedName) + for proc in ct.instanceProcedures do + let _ ← defineName proc.name placeholderNode + | .Constrained ct => + let _ ← defineName ct.name (.constrainedType ct) + | .Datatype dt => + let _ ← defineName dt.name (.datatypeDefinition dt) + for ctor in dt.constructors do + let _ ← defineName ctor.name (.datatypeConstructor dt.name ctor) + for p in ctor.args do + let _ ← defineName p.name placeholderNode (some $ dt.name.text ++ ".." ++ p.name.text) + -- Pre-register constants + for c in program.constants do + let _ ← defineName c.name (.constant c) + -- Pre-register static procedures + for proc in program.staticProcedures do + let _ ← defineName proc.name (.staticProcedure proc) + +/-! ## Entry point -/ + +/-- Run the full resolution pass on a Laurel program. -/ +def resolve (program : Program) (existingModel: Option SemanticModel := none) : ResolutionResult := + -- Phase 1: pre-register all top-level names, then assign IDs and resolve references + let phase1 : ResolveM Program := do + preRegisterTopLevel program + let types' ← program.types.mapM resolveTypeDefinition + let constants' ← program.constants.mapM resolveConstant + let staticFields' ← program.staticFields.mapM (resolveField "$static") + let staticProcs' ← program.staticProcedures.mapM resolveProcedure + return { staticProcedures := staticProcs', staticFields := staticFields', + types := types', constants := constants' } + let nextId := existingModel.elim 1 (fun m => m.nextId) + let (program', finalState) := phase1.run { nextId := nextId } + -- Phase 2: build refToDef from the resolved program (all definitions now have UUIDs) + let refToDef := buildRefToDef program' + { program := program', + model := { + compositeCount := program.types.length, + refToDef := refToDef, + nextId := finalState.nextId + }, + errors := finalState.errors + } diff --git a/Strata/Languages/Laurel/TypeHierarchy.lean b/Strata/Languages/Laurel/TypeHierarchy.lean index a78f918b8..8871d98ee 100644 --- a/Strata/Languages/Laurel/TypeHierarchy.lean +++ b/Strata/Languages/Laurel/TypeHierarchy.lean @@ -17,26 +17,20 @@ open Strata Compute the flattened set of ancestors for a composite type, including itself. Traverses the `extending` list transitively. -/ -def computeAncestors (types : List TypeDefinition) (name : Identifier) : List CompositeType := +def computeAncestors (model: SemanticModel) (name : Identifier) : List CompositeType := let rec go (fuel : Nat) (current : Identifier) : List CompositeType := match fuel with | 0 => - types.filterMap fun td => match td with - | .Composite ct => if ct.name == current then some ct else none - | _ => none + match model.get current with + | .compositeType (ty : CompositeType) => [ty] + | _ => [] | fuel' + 1 => - let self := types.filterMap fun td => match td with - | .Composite ct => if ct.name == current then some ct else none - | _ => none - self ++ (types.foldl (fun acc td => - match td with - | .Composite ct => - if ct.name == current then - ct.extending.foldl (fun acc2 parent => acc2 ++ go fuel' parent) acc - else acc - | _ => acc) []) + match model.get current with + | .compositeType (ty : CompositeType) => + [ty] ++ ty.extending.flatMap (fun parent => go fuel' parent) + | _ => [] let seen : List Identifier := [] - (go types.length name).foldl (fun (acc, seen) ct => + (go model.compositeCount name).foldl (fun (acc, seen) ct => if seen.contains ct.name then (acc, seen) else (acc ++ [ct], seen ++ [ct.name])) ([], seen) |>.1 @@ -52,8 +46,8 @@ We use this to translate ` is `. The runtime type of `` is used for the outer Map lookup while `` for the inner one. -/ -def generateTypeHierarchyDecls (types : List TypeDefinition) : List Constant := - let composites := types.filterMap fun td => match td with +def generateTypeHierarchyDecls (model : SemanticModel) (program: Program) : List Constant := + let composites := program.types.filterMap fun td => match td with | .Composite ct => some ct | _ => none if composites.isEmpty then [] else @@ -64,18 +58,20 @@ def generateTypeHierarchyDecls (types : List TypeDefinition) : List Constant := -- Helper: build an inner map (Map TypeTag bool) for a given composite type -- Start with const(false), then update each composite type's entry let mkInnerMap (ct : CompositeType) : StmtExprMd := - let ancestors := computeAncestors types ct.name + let ancestors := computeAncestors model ct.name let falseConst := mkMd (.LiteralBool false) let emptyInner := mkMd (.StaticCall "const" [falseConst]) composites.foldl (fun acc otherCt => - let otherConst := mkMd (.StaticCall (otherCt.name ++ "_TypeTag") []) let isAncestor := ancestors.any (·.name == otherCt.name) - let boolVal := mkMd (.LiteralBool isAncestor) - mkMd (.StaticCall "update" [acc, otherConst, boolVal]) + if isAncestor then + let otherConst := mkMd (.StaticCall (mkId $ otherCt.name.text ++ "_TypeTag") []) + let boolVal := mkMd (.LiteralBool true) + mkMd (.StaticCall "update" [acc, otherConst, boolVal]) + else acc ) emptyInner -- Generate a separate constant `ancestorsFor` for each composite type let ancestorsForDecls := composites.map fun ct => - { name := s!"ancestorsFor{ct.name}" + { name := s!"ancestorsFor{ct.name.text}" type := innerMapTy initializer := some (mkInnerMap ct) : Constant } -- Build ancestorsPerType by referencing the individual ancestorsFor constants @@ -83,8 +79,8 @@ def generateTypeHierarchyDecls (types : List TypeDefinition) : List Constant := let emptyInner := mkMd (.StaticCall "const" [falseConst]) let emptyOuter := mkMd (.StaticCall "const" [emptyInner]) let outerMapExpr := composites.foldl (fun acc ct => - let typeConst := mkMd (.StaticCall (ct.name ++ "_TypeTag") []) - let innerMapRef := mkMd (.StaticCall s!"ancestorsFor{ct.name}" []) + let typeConst := mkMd (.StaticCall (mkId $ ct.name.text ++ "_TypeTag") []) + let innerMapRef := mkMd (.StaticCall s!"ancestorsFor{ct.name.text}" []) mkMd (.StaticCall "update" [acc, typeConst, innerMapRef]) ) emptyOuter let ancestorsDecl : Constant := @@ -97,83 +93,66 @@ def generateTypeHierarchyDecls (types : List TypeDefinition) : List Constant := Check if a field can be reached through a given type (directly declared or inherited). Returns true if the type or any of its ancestors declares the field. -/ -def canReachField (types : List TypeDefinition) (typeName : Identifier) (fieldName : Identifier) : Bool := - let rec go (fuel : Nat) (current : Identifier) : Bool := - match fuel with - | 0 => false - | fuel' + 1 => - types.any fun td => - match td with - | .Composite ct => - ct.name == current && - (ct.fields.any (·.name == fieldName) || - ct.extending.any (go fuel')) - | _ => false - go types.length typeName +def canReachField (model : SemanticModel) (typeName : Identifier) (fieldName : Identifier) : Bool := + match model.get fieldName with + | .field owner _ => ((computeAncestors model typeName).find? (fun t => t.name == owner)).isSome + | _ => panic! s!"{fieldName} did not resolve to a field" /-- Check if a field is inherited through multiple parent paths (diamond inheritance). Returns true if more than one direct parent of the given type can reach the field. -/ -def isDiamondInheritedField (types : List TypeDefinition) (typeName : Identifier) (fieldName : Identifier) : Bool := - let findComposite := types.findSome? fun td => - match td with - | .Composite ct => if ct.name == typeName then some ct else none - | _ => none - match findComposite with - | none => false - | some ct => +def isDiamondInheritedField (model : SemanticModel) (typeName : Identifier) (fieldName : Identifier) : Bool := + match model.get typeName with + | .compositeType ct => -- If the field is directly declared on this type, it's not a diamond if ct.fields.any (·.name == fieldName) then false else -- Count how many direct parents can reach this field - let parentsWithField := ct.extending.filter (canReachField types · fieldName) + let parentsWithField := ct.extending.filter (canReachField model · fieldName) parentsWithField.length > 1 + | _ => false /-- Walk a StmtExpr AST and collect DiagnosticModel errors for diamond-inherited field accesses. -/ -def validateDiamondFieldAccessesForStmtExpr (uri : Uri) (types : List TypeDefinition) (env : TypeEnv) +def validateDiamondFieldAccessesForStmtExpr (model : SemanticModel) (expr : StmtExprMd) : List DiagnosticModel := match _h : expr.val with | .FieldSelect target fieldName => - let targetErrors := validateDiamondFieldAccessesForStmtExpr uri types env target - let fieldError := match (computeExprType env types target).val with + let targetErrors := validateDiamondFieldAccessesForStmtExpr model target + let fieldError := match (computeExprType model target).val with | .UserDefined typeName => - if isDiamondInheritedField types typeName fieldName then + if isDiamondInheritedField model typeName fieldName then let fileRange := (Imperative.getFileRange expr.md).getD FileRange.unknown [DiagnosticModel.withRange fileRange s!"fields that are inherited multiple times can not be accessed."] else [] | _ => [] targetErrors ++ fieldError | .Block stmts _ => - (stmts.attach.foldl (fun (acc, env') ⟨s, _⟩ => - let env'' := match s.val with - | .LocalVariable name ty _ => (name, ty) :: env' - | _ => env' - (acc ++ validateDiamondFieldAccessesForStmtExpr uri types env' s, env'')) ([], env)).1 + stmts.flatMap (fun s => validateDiamondFieldAccessesForStmtExpr model s) | .Assign targets value => - let targetErrors := targets.attach.foldl (fun acc ⟨t, _⟩ => acc ++ validateDiamondFieldAccessesForStmtExpr uri types env t) [] - targetErrors ++ validateDiamondFieldAccessesForStmtExpr uri types env value + let targetErrors := targets.attach.foldl (fun acc ⟨t, _⟩ => acc ++ validateDiamondFieldAccessesForStmtExpr model t) [] + targetErrors ++ validateDiamondFieldAccessesForStmtExpr model value | .IfThenElse c t e => - let errs := validateDiamondFieldAccessesForStmtExpr uri types env c ++ - validateDiamondFieldAccessesForStmtExpr uri types env t + let errs := validateDiamondFieldAccessesForStmtExpr model c ++ + validateDiamondFieldAccessesForStmtExpr model t match e with - | some eb => errs ++ validateDiamondFieldAccessesForStmtExpr uri types env eb + | some eb => errs ++ validateDiamondFieldAccessesForStmtExpr model eb | none => errs | .LocalVariable _ _ (some init) => - validateDiamondFieldAccessesForStmtExpr uri types env init + validateDiamondFieldAccessesForStmtExpr model init | .While c invs _ b => - let errs := validateDiamondFieldAccessesForStmtExpr uri types env c ++ - validateDiamondFieldAccessesForStmtExpr uri types env b - invs.attach.foldl (fun acc ⟨inv, _⟩ => acc ++ validateDiamondFieldAccessesForStmtExpr uri types env inv) errs - | .Assert cond => validateDiamondFieldAccessesForStmtExpr uri types env cond - | .Assume cond => validateDiamondFieldAccessesForStmtExpr uri types env cond + let errs := validateDiamondFieldAccessesForStmtExpr model c ++ + validateDiamondFieldAccessesForStmtExpr model b + invs.attach.foldl (fun acc ⟨inv, _⟩ => acc ++ validateDiamondFieldAccessesForStmtExpr model inv) errs + | .Assert cond => validateDiamondFieldAccessesForStmtExpr model cond + | .Assume cond => validateDiamondFieldAccessesForStmtExpr model cond | .PrimitiveOp _ args => - args.attach.foldl (fun acc ⟨a, _⟩ => acc ++ validateDiamondFieldAccessesForStmtExpr uri types env a) [] + args.attach.foldl (fun acc ⟨a, _⟩ => acc ++ validateDiamondFieldAccessesForStmtExpr model a) [] | .StaticCall _ args => - args.attach.foldl (fun acc ⟨a, _⟩ => acc ++ validateDiamondFieldAccessesForStmtExpr uri types env a) [] - | .Return (some v) => validateDiamondFieldAccessesForStmtExpr uri types env v + args.attach.foldl (fun acc ⟨a, _⟩ => acc ++ validateDiamondFieldAccessesForStmtExpr model a) [] + | .Return (some v) => validateDiamondFieldAccessesForStmtExpr model v | _ => [] termination_by sizeOf expr decreasing_by all_goals (have := WithMetadata.sizeOf_val_lt expr; term_by_mem) @@ -182,19 +161,18 @@ def validateDiamondFieldAccessesForStmtExpr (uri : Uri) (types : List TypeDefini Validate a Laurel program for diamond-inherited field accesses. Returns an array of DiagnosticModel errors. -/ -def validateDiamondFieldAccesses (uri : Uri) (program : Program) : Array DiagnosticModel := +def validateDiamondFieldAccesses (model: SemanticModel) (program : Program) : Array DiagnosticModel := let errors := program.staticProcedures.foldl (fun acc proc => - let env : TypeEnv := proc.inputs.map (fun p => (p.name, p.type)) ++ - proc.outputs.map (fun p => (p.name, p.type)) let bodyErrors := match proc.body with - | .Transparent bodyExpr => validateDiamondFieldAccessesForStmtExpr uri program.types env bodyExpr + | .Transparent bodyExpr => validateDiamondFieldAccessesForStmtExpr model bodyExpr | .Opaque postconds impl _ => - let postErrors := postconds.foldl (fun acc2 pc => acc2 ++ validateDiamondFieldAccessesForStmtExpr uri program.types env pc) [] + let postErrors := postconds.foldl (fun acc2 pc => acc2 ++ validateDiamondFieldAccessesForStmtExpr model pc) [] let implErrors := match impl with - | some implExpr => validateDiamondFieldAccessesForStmtExpr uri program.types env implExpr + | some implExpr => validateDiamondFieldAccessesForStmtExpr model implExpr | none => [] postErrors ++ implErrors - | .Abstract postconds => postconds.foldl (fun acc p => acc ++ validateDiamondFieldAccessesForStmtExpr uri program.types env p) [] + | .Abstract postconds => postconds.foldl (fun acc p => acc ++ validateDiamondFieldAccessesForStmtExpr model p) [] + | .External => [] acc ++ bodyErrors) [] errors.toArray @@ -204,12 +182,12 @@ Lower `IsType target ty` to Laurel-level map lookups: -/ def lowerIsType (target : StmtExprMd) (ty : HighTypeMd) (md : Imperative.MetaData Core.Expression) : StmtExprMd := let typeName := match ty.val with - | .UserDefined name => name + | .UserDefined name => name.text | _ => panic! s!"IsType: expected UserDefined type" let typeTag := mkMd (.StaticCall "Composite..typeTag!" [target]) let ancestorsPerType := mkMd (.StaticCall "ancestorsPerType" []) let innerMap := mkMd (.StaticCall "select" [ancestorsPerType, typeTag]) - let typeConst := mkMd (.StaticCall (typeName ++ "_TypeTag") []) + let typeConst := mkMd (.StaticCall (mkId $ typeName ++ "_TypeTag") []) ⟨.StaticCall "select" [innerMap, typeConst], md⟩ /-- State for the type hierarchy rewrite monad -/ @@ -230,13 +208,13 @@ Lower `New name` to a block that: 3. Constructs a `MkComposite(counter, name_TypeTag())` value -/ def lowerNew (name : Identifier) (md : Imperative.MetaData Core.Expression) : THM StmtExprMd := do - let heapVar := "$heap" + let heapVar : Identifier := "$heap" let freshVar ← freshVarName let getCounter := mkMd (.StaticCall "Heap..nextReference!" [mkMd (.Identifier heapVar)]) let saveCounter := mkMd (.LocalVariable freshVar ⟨.TInt, #[]⟩ (some getCounter)) let newHeap := mkMd (.StaticCall "increment" [mkMd (.Identifier heapVar)]) let updateHeap := mkMd (.Assign [mkMd (.Identifier heapVar)] newHeap) - let compositeResult := mkMd (.StaticCall "MkComposite" [mkMd (.Identifier freshVar), mkMd (.StaticCall (name ++ "_TypeTag") [])]) + let compositeResult := mkMd (.StaticCall "MkComposite" [mkMd (.Identifier freshVar), mkMd (.StaticCall (name.text ++ "_TypeTag") [])]) return ⟨ .Block [saveCounter, updateHeap, compositeResult] none, md ⟩ /-- @@ -282,8 +260,8 @@ def rewriteTypeHierarchyExpr (exprMd : StmtExprMd) : THM StmtExprMd := | .InstanceCall t callee args => do let args' ← args.attach.mapM fun ⟨a, _⟩ => rewriteTypeHierarchyExpr a return ⟨.InstanceCall (← rewriteTypeHierarchyExpr t) callee args', md⟩ - | .Forall n ty b => do return ⟨.Forall n ty (← rewriteTypeHierarchyExpr b), md⟩ - | .Exists n ty b => do return ⟨.Exists n ty (← rewriteTypeHierarchyExpr b), md⟩ + | .Forall p b => do return ⟨.Forall p (← rewriteTypeHierarchyExpr b), md⟩ + | .Exists p b => do return ⟨.Exists p (← rewriteTypeHierarchyExpr b), md⟩ | .Assigned n => do return ⟨.Assigned (← rewriteTypeHierarchyExpr n), md⟩ | .Old v => do return ⟨.Old (← rewriteTypeHierarchyExpr v), md⟩ | .Fresh v => do return ⟨.Fresh (← rewriteTypeHierarchyExpr v), md⟩ @@ -307,6 +285,7 @@ def rewriteTypeHierarchyProcedure (proc : Procedure) : THM Procedure := do let modif' ← modif.mapM rewriteTypeHierarchyExpr pure (.Opaque postconds' impl' modif') | .Abstract postconds => pure (.Abstract (← postconds.mapM rewriteTypeHierarchyExpr)) + | .External => pure .External return { proc with preconditions := preconditions', body := body' } /-- @@ -317,18 +296,30 @@ Type hierarchy transformation pass (Laurel → Laurel). 3. Generates the `TypeTag` datatype with one constructor per composite type 4. Generates type hierarchy constants (`ancestorsFor`, `ancestorsPerType`) -/ -def typeHierarchyTransform (program : Program) : Program := +def typeHierarchyTransform (model: SemanticModel) (program : Program) : Program := let compositeNames := program.types.filterMap fun td => match td with - | .Composite ct => some ct.name + | .Composite ct => some ct.name.text | _ => none let typeTagDatatype : TypeDefinition := - .Datatype { name := "TypeTag", typeArgs := [], constructors := compositeNames.map fun n => { name := n ++ "_TypeTag", args := [] } } - let typeHierarchyConstants := generateTypeHierarchyDecls program.types + .Datatype { name := "TypeTag", typeArgs := [], constructors := compositeNames.map fun n => { name := (mkId $ n ++ "_TypeTag"), args := [] } } + let typeHierarchyConstants := generateTypeHierarchyDecls model program let (procs', _) := (program.staticProcedures.mapM rewriteTypeHierarchyProcedure).run {} + -- Update the Composite datatype to include the typeTag field (introduced in this phase) + let typeTagTy : HighTypeMd := ⟨.UserDefined "TypeTag", #[]⟩ + let remainingTypes := program.types.map fun td => + match td with + | .Datatype dt => + if dt.name.text == "Composite" then + .Datatype { dt with constructors := dt.constructors.map fun c => + if c.name.text == "MkComposite" then + { c with args := c.args ++ [{ name := ("typeTag" : Identifier), type := typeTagTy }] } + else c } + else td + | _ => td { program with staticProcedures := procs', - types := program.types ++ [typeTagDatatype], + types := [typeTagDatatype] ++ remainingTypes, constants := program.constants ++ typeHierarchyConstants } end Laurel diff --git a/Strata/Languages/Python/PythonToLaurel.lean b/Strata/Languages/Python/PythonToLaurel.lean index 818ab412d..2fb8b2437 100644 --- a/Strata/Languages/Python/PythonToLaurel.lean +++ b/Strata/Languages/Python/PythonToLaurel.lean @@ -207,6 +207,17 @@ def resolveDispatch (ctx : TranslationContext) /-! ## Expression Translation -/ +/-- Check if a Python expression has string type, using the Python AST and variable types. + Used to disambiguate `+` between arithmetic Add and string StrConcat. -/ +def isPyExprStringTyped (ctx : TranslationContext) (e : Python.expr SourceRange) : Bool := + match e with + | .Constant _ (.ConString ..) _ => true + | .Name _ name _ => + match ctx.variableTypes.find? (·.1 == name.val) with + | some (_, ty) => highEq ty (mkHighTypeMd .TString) + | none => false + | _ => false + /-- Check if a function has a model (is in prelude or user-defined) -/ def hasModel (ctx : TranslationContext) (funcName : String) : Bool := ctx.preludeProcedures.any (·.1 == funcName) || ctx.userFunctions.contains funcName @@ -267,10 +278,8 @@ partial def translateExpr (ctx : TranslationContext) (e : Python.expr SourceRang let laurelOp ← match op with -- Arithmetic | .Add _ => - let typeEnv : Laurel.TypeEnv := ctx.variableTypes - let leftType := Laurel.computeExprType typeEnv [] leftExpr - -- Dispatch on left operand - if highEq leftType { val := .TString, md := leftExpr.md } then + -- Dispatch on left operand type (determined from Python AST) + if isPyExprStringTyped ctx left then .ok Operation.StrConcat else .ok Operation.Add @@ -478,7 +487,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang let targetExpr := mkStmtExprMd (StmtExpr.Identifier target) let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] valueExpr) md return (ctx, assignStmt) - | .Attribute _ obj attr _ => + | .Attribute _ _ _ _ => -- Field assignment: obj.field = expr or self.field = expr let valueExpr ← translateExpr ctx value let targetExpr ← translateExpr ctx targets.val[0]! -- This will handle self.field via translateExpr @@ -898,6 +907,7 @@ def translateMethod (ctx : TranslationContext) (className : String) outputs := outputs preconditions := [mkStmtExprMd (StmtExpr.LiteralBool true)] determinism := .nondeterministic + isFunctional := false decreases := none body := .Transparent bodyBlock md := default diff --git a/Strata/Languages/Python/Specs/ToLaurel.lean b/Strata/Languages/Python/Specs/ToLaurel.lean index 38fb2d09c..072d63ecd 100644 --- a/Strata/Languages/Python/Specs/ToLaurel.lean +++ b/Strata/Languages/Python/Specs/ToLaurel.lean @@ -240,7 +240,7 @@ def specTypeToLaurelType (ty : SpecType) : ToLaurelM HighTypeMd := do if args.size > 0 then reportError default s!"Generic class '{name}' with type args unsupported" - return mkTy (.UserDefined name) + return mkTy (.UserDefined { text := name }) | .intLiteral _ => return mkTy .TInt | .stringLiteral _ => return mkTy .TString | .typedDict _ _ _ => return mkCore "DictStrAny" @@ -271,6 +271,7 @@ def funcDeclToLaurel (procName : String) (func : FunctionDecl) preconditions := [] determinism := .nondeterministic decreases := none + isFunctional := false body := .Opaque [] none [] md := .empty } @@ -282,7 +283,7 @@ def classDefToLaurel (cls : ClassDef) : ToLaurelM Unit := do pure { name := f.name, isMutable := true, type := ty : Laurel.Field } pushType (.Composite { name := cls.name - extending := cls.bases.toList.map toString + extending := cls.bases.toList.map (fun cd => mkId $ toString cd) fields := laurelFields instanceProcedures := [] }) diff --git a/StrataMain.lean b/StrataMain.lean index 852cf4e57..434314257 100644 --- a/StrataMain.lean +++ b/StrataMain.lean @@ -328,12 +328,11 @@ structure PySpecPrelude where are appended to the base prelude (with duplicates filtered out). Also accumulates overload dispatch tables. -/ def buildPySpecPrelude (pyspecPaths : Array String) : IO PySpecPrelude := do - -- Laurel.translate prepends corePrelude.decls to every output. - -- Add them once here and strip the prefix from each translated result. - -- Accumulate into an Array for efficient appending; build Core.Program at the end. - let laurelPreludeSize := Strata.Laurel.corePrelude.decls.length + -- The Laurel prelude is now included during HeapParameterization at the Laurel level. + -- We no longer need to strip it from translate output. + let laurelPreludeSize := 0 let mut preludeDecls : Array Core.Decl := - Strata.Python.Core.prelude.decls.toArray ++ Strata.Laurel.corePrelude.decls.toArray + Strata.Python.Core.prelude.decls.toArray let mut existingNames : Std.HashSet String := preludeDecls.foldl (init := {}) fun s d => (Core.Decl.names d).foldl (init := s) fun s n => s.insert n.name @@ -362,7 +361,8 @@ def buildPySpecPrelude (pyspecPaths : Array String) : IO PySpecPrelude := do | .error diagnostics => exitFailure s!"PySpec Laurel to Core translation failed for {ionPath}: {diagnostics}" | .ok (coreSpec, _modifiesDiags) => - -- Strip the Laurel corePrelude prefix (always emitted by Laurel.translate) + -- The Laurel prelude is now included at the Laurel level during HeapParameterization, + -- so translate output already contains the prelude declarations as normal decls. let pyspecDecls := coreSpec.decls.drop laurelPreludeSize -- Register new names, failing on collisions for d in pyspecDecls do @@ -444,27 +444,23 @@ def pyAnalyzeLaurelCommand : Command where IO.println "\n==== Core Program ====" IO.print (coreProgramDecls, modifiesDiags) - -- Strip the Laurel corePrelude prefix (always emitted by - -- Laurel.translate); already present in pyPrelude. - -- We don't want to strip types defined by the user program - -- (e.g., Class declarations), so we add those back. - let laurelPreludeSize := Strata.Laurel.corePrelude.decls.length - let droppedPrefix := coreProgramDecls.decls.take laurelPreludeSize - let programDecls := coreProgramDecls.decls.drop laurelPreludeSize - let pyPreludeDecls := pyPrelude.decls.map fun d => - match droppedPrefix.find? (fun pd => pd.name.name == d.name.name) with - | some replacement => replacement - | none => d + -- The Laurel prelude is now included at the Laurel level during + -- HeapParameterization, so translate output contains prelude decls as normal decls. + -- No stripping needed. + let programDecls := coreProgramDecls.decls -- Check for name collisions between program and prelude let preludeNames : Std.HashSet String := - pyPreludeDecls.flatMap Core.Decl.names + pyPrelude.decls.flatMap Core.Decl.names |>.foldl (init := {}) fun s n => s.insert n.name let collisions := programDecls.flatMap fun d => d.names.filter fun n => preludeNames.contains n.name if !collisions.isEmpty then let names := ", ".intercalate (collisions.map (·.name)) exitFailure s!"Core name collision between program and prelude: {names}" - let coreProgram := {decls := pyPreludeDecls ++ programDecls } + let coreProgram := {decls := pyPrelude.decls ++ programDecls } + -- dbg_trace "=== Generated Strata Core Program ===" + -- dbg_trace (toString (Std.Format.pretty (Strata.Core.formatProgram coreProgram) 100)) + -- dbg_trace "=================================" -- Verify using Core verifier let vcResults ← IO.FS.withTempDir (fun tempDir => diff --git a/StrataTest/DDM/PipeIdent.lean b/StrataTest/DDM/PipeIdent.lean index 66f1880dc..b1c5136ff 100644 --- a/StrataTest/DDM/PipeIdent.lean +++ b/StrataTest/DDM/PipeIdent.lean @@ -127,10 +127,10 @@ def getRHSIdent (op : Operation) : String := | _ => "" -- Verify: \| is unescaped to | in AST (stored with Lean's «» notation) -#guard (getRHSIdent testEscapeAST.commands[0]!) == "«name|with|pipes»" +#guard (getRHSIdent testEscapeAST.commands[0]!) == "name|with|pipes" -- Verify: \\ is unescaped to single \ in AST (stored with Lean's «» notation) -#guard (getRHSIdent testEscapeAST.commands[1]!) == "«path\\to\\file»" +#guard (getRHSIdent testEscapeAST.commands[1]!) == "path\\to\\file" -- Verify dots are preserved in AST def testDotIdent := #strata @@ -143,11 +143,11 @@ v := trailing..end; #end -- Verify: dots are preserved in identifier names in AST (stored with Lean's «» notation) -#guard (getRHSIdent testDotIdent.commands[0]!) == "«qualified.name»" -#guard (getRHSIdent testDotIdent.commands[1]!) == "«another.dotted.identifier»" -#guard (getRHSIdent testDotIdent.commands[2]!) == "«a..b»" -#guard (getRHSIdent testDotIdent.commands[3]!) == "«x...y»" -#guard (getRHSIdent testDotIdent.commands[4]!) == "«trailing..end»" +#guard (getRHSIdent testDotIdent.commands[0]!) == "qualified.name" +#guard (getRHSIdent testDotIdent.commands[1]!) == "another.dotted.identifier" +#guard (getRHSIdent testDotIdent.commands[2]!) == "a..b" +#guard (getRHSIdent testDotIdent.commands[3]!) == "x...y" +#guard (getRHSIdent testDotIdent.commands[4]!) == "trailing..end" -- Test dialect with | operator that has NO spaces in syntax definition #dialect diff --git a/StrataTest/Languages/B3/DDMFormatProgramsTests.lean b/StrataTest/Languages/B3/DDMFormatProgramsTests.lean index 7bf2f21fc..7cda00af1 100644 --- a/StrataTest/Languages/B3/DDMFormatProgramsTests.lean +++ b/StrataTest/Languages/B3/DDMFormatProgramsTests.lean @@ -38,12 +38,12 @@ section ProgramRoundtripTests /-- info: CST→AST Errors: - Unresolved identifier '«myFileSystemName: string»' - Unresolved identifier '«BlockPublicAcls: string»' - Unresolved identifier '«bucket: string»' - Unresolved identifier '«is-blocked: string»' - Unresolved identifier '«bucket: string»' - Unresolved identifier '«is-not-blocked: string»' + Unresolved identifier 'myFileSystemName: string' + Unresolved identifier 'BlockPublicAcls: string' + Unresolved identifier 'bucket: string' + Unresolved identifier 'is-blocked: string' + Unresolved identifier 'bucket: string' + Unresolved identifier 'is-not-blocked: string' AST→CST Errors: Variable index @2 is out of bounds (context has 2 variables) Variable index @12 is out of bounds (context has 12 variables) diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T3_ControlFlow.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T3_ControlFlow.lean index cb434dda2..295239f56 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T3_ControlFlow.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T3_ControlFlow.lean @@ -51,40 +51,3 @@ procedure dag(a: int) returns (r: int) #guard_msgs (error, drop all) in #eval! testInputWithOffset "ControlFlow" program 14 processLaurelFile - -/- -Translation towards expression form: - -function guards(a: int): int { - var b = a + 2; - if (b > 2) { - var c = b + 3; - if (c > 3) { - c + 4; - } else { - var d = c + 5; - d + 6; - } - } else { - var e = b + 1; - e - } -} - -To translate towards SMT we only need to apply something like WP calculus. - Here's an example of what that looks like: - -function dag(a: int): int { - ( - assume a > 0; - assume b == 1; - b; - ) - OR - ( - assume a <= 0; - assume b == 2; - b; - ) -} --/ diff --git a/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean b/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean index dc486f7de..92a19f0b9 100644 --- a/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean +++ b/StrataTest/Languages/Laurel/LiftExpressionAssignmentsTest.lean @@ -47,7 +47,10 @@ def parseLaurelAndLift (input : String) : IO Program := do let uri := Strata.Uri.file "test" match Laurel.TransM.run uri (Laurel.parseProgram strataProgram) with | .error e => throw (IO.userError s!"Translation errors: {e}") - | .ok program => pure (liftImperativeExpressions program) + | .ok program => + let result := resolve program + let (program, model) := (result.program, result.model) + pure (liftExpressionAssignments model program) /-- info: procedure heapUpdateInBlockExpr(b: Box) returns ⏎ diff --git a/docs/verso/LaurelDoc.lean b/docs/verso/LaurelDoc.lean index b1abaa274..9d13497c7 100644 --- a/docs/verso/LaurelDoc.lean +++ b/docs/verso/LaurelDoc.lean @@ -12,7 +12,6 @@ import Strata.Languages.Laurel.LaurelToCoreTranslator import Strata.Languages.Laurel.HeapParameterization import Strata.Languages.Laurel.LiftImperativeExpressions import Strata.Languages.Laurel.ModifiesClauses -import Strata.Languages.Laurel.CorePrelude open Strata.Laurel @@ -184,10 +183,3 @@ The final translation converts Laurel types, expressions, statements, and proced their Strata Core equivalents. Procedures with bodies that only have constructs supported by Core expressions are translated to a Core function, while other procedures become Core procedures. - -## Core Prelude - -The Laurel translator prepends a Core prelude that defines the heap model types and -operations: `Composite`, `Field`, `Box`, `readField`, and `updateField`. - -{docstring Strata.Laurel.corePreludeDDM}