Here is the code that compiles the simply typed lambda calculus for the untyped interpreter of my previous message.
The main function that actually does something here is
compile, which performs typechecking while doing a translation to the untyped lambda calculus—the
run function just compiles and then invokes the untyped interpreter.
The structure of the simply typed lambda calculus is largely parallel to that of the untyped. I could perhaps have tried to factor out the differences between the corresponding
Term/
Type/
Value/
Env/
Var/
Lam/
App/
Con/
If classes, but there are enough minor differences between the two calculi that I doubt the result would have been more readable than just writing corresponding pairs of not-quite-identical classes.
I've introduced a primitive
Fix term to permit recursion, which cannot otherwise be given a correct simple type. The corresponding construct in the untyped implementation is just a function
fix that applies a fixpoint combinator, since recursion
can be expressed directly in the untyped calculus.
Fix (sometimes known as the mu operator) has a structure parallel to
Lam, and I've factored out their commonalities into a superclass
Parmed.
Compilation conceptually just verifies that the declared simple types match up, and translates the elements of the simply typed calculus to their untyped equivalents. If compilation fails,
compile produces a list of
String error messages. The place where the
App case checks whether
funArgT == argT is most characteristic of simple types: function parameter and argument types must match exactly. A more complicated type system would have to do more work here—for example, an implementation of parametric polymorphism might need to unify different type variables.
package lambdas
import scala.language.implicitConversions
object SimplyTyped {
  import lambdas.{ Untyped => Un }

  // Simply-typed lambda calculus top-level types

  /** A term of the simply typed lambda calculus.
    *
    * The right-associative operators build typed binders:
    * `v :: t ^: body` is a lambda abstraction, `v :: t ^^: body` is a
    * fixpoint, and `term(arg)` is an application.
    */
  trait Term {
    def ^: (vt: VarTyping): Lam = Lam(vt.v,vt.t,this)
    def ^^: (vt: VarTyping): Fix = Fix(vt.v,vt.t,this)
    def apply (arg: Term): App = App(this,arg)
  }

  /** A simple type: either a base type or a function type. */
  trait Type {
    /** `v :: t` pairs variable `v` with its declared type `t`. */
    def :: (v: Var): VarTyping = VarTyping(v,this)
    /** `a =>: r` is the function type from `a` to `r` (right-associative). */
    def =>: (arg: Type): FunType = FunType(arg,this)
  }

  /** A compile-time constant: its simple type plus its untyped counterpart. */
  trait Value {
    def t: Type
    def unValue: Un.Value
  }

  /** Typing environment. New bindings are prepended, and lookup takes the
    * first match, so inner bindings shadow outer ones.
    */
  type Env = List[(Symbol,Type)]
  val emptyEnv: Env = List.empty[(Symbol,Type)]

  // Terms

  case class Var (sym: Symbol) extends Term {
    override def toString: String = sym.toString
  }
  implicit def symbolToVar (sym: Symbol): Var = Var(sym)
  case class VarTyping (v: Var, t: Type) // for syntactic convenience with ::

  /** Common structure of the binders `Lam` and `Fix`: a typed variable bound
    * over a body.
    *
    * @param operator the surface-syntax operator, used only by `toString`
    */
  abstract class Parmed (operator: String) extends Term {
    def v: Var
    def vType: Type
    def body: Term
    /** Type of the whole binder, given the type of its body. */
    def overallType (bodyType: Type): Type
    /** Builds the untyped equivalent from the translated variable and body. */
    def toUntyped (unV: Un.Var, unBody: Un.Term): Un.Term
    override def toString: String =
      vType match {
        // Parenthesize function types so their `=>:` doesn't run into the binder.
        case _: FunType => s"$v :: ($vType) $operator $body"
        case _ => s"$v :: $vType $operator $body"
      }
  }

  /** Lambda abstraction `\v: vType. body`. */
  case class Lam (v: Var, vType: Type, body: Term) extends Parmed("^:") {
    def overallType (bodyType: Type) = vType =>: bodyType
    def toUntyped (unV: Un.Var, unBody: Un.Term) = Un.Lam(unV,unBody)
  }

  case class App (fun: Term, arg: Term) extends Term {
    override def toString: String =
      fun match {
        // A binder extends as far right as possible, so it needs parentheses
        // when it is the function being applied.
        case _: Parmed => s"($fun)($arg)"
        case _ => s"$fun($arg)"
      }
  }

  case class Con (v: Value) extends Term {
    override def toString: String = v.toString
  }
  implicit def unitToCon (u: Unit): Con = Con(BaseValue(u,U))
  implicit def booleanToCon (bo: Boolean): Con = Con(BaseValue(bo,B))
  implicit def intToCon (i: Int): Con = Con(BaseValue(i,I))
  implicit def funToCon (f: Fun): Con = Con(f)

  case class If (cond: Term, ifTrue: Term, ifFalse: Term) extends Term
  def and (a: Term, b: Term): If = If(a,b,false)
  def or (a: Term, b: Term): If = If(a,true,b)

  /** Fixpoint `fix v: vType. body`: inside `body`, `v` refers to the whole
    * term, which permits recursion. `compile` requires the body to have
    * exactly the declared type `vType`.
    * NOTE(review): whether `Un.fix` behaves sensibly for a non-function
    * `vType` depends on the untyped interpreter — confirm.
    */
  case class Fix (v: Var, vType: Type, body: Term) extends Parmed("^^:") {
    def overallType (bodyType: Type) = bodyType
    def toUntyped (unV: Un.Var, unBody: Un.Term) = Un.fix(unV,unBody)
  }

  // Types

  abstract class BaseType [+Underlying] (override val toString: String)
    extends Type
  case class FunType (arg: Type, result: Type) extends Type {
    override def toString: String =
      arg match {
        // `=>:` is right-associative, so only a function-typed argument needs parentheses.
        case _: FunType => s"($arg) =>: $result"
        case _ => s"$arg =>: $result"
      }
  }

  // Values (compiletime values only, so no function closure here)

  case class BaseValue [+Underlying] (v: Underlying, t: BaseType[Underlying])
    extends Value {
    def unValue = Un.BaseValue(v)
    override def toString: String = v.toString
  }
  case class Fun (t: FunType, unValue: Un.Fun) extends Value {
    override def toString: String = unValue.toString
  }

  // Primitive types

  object U extends BaseType[Unit]("U")
  object B extends BaseType[Boolean]("B")
  object I extends BaseType[Int]("I")

  // Built-in functions

  val not = Fun(B =>: B,Un.not)
  val neg = Fun(I =>: I,Un.neg)
  val (plus,minus,times,div,mod) =
    (Fun(I =>: I =>: I,Un.plus),Fun(I =>: I =>: I,Un.minus),
     Fun(I =>: I =>: I,Un.times),Fun(I =>: I =>: I,Un.div),
     Fun(I =>: I =>: I,Un.mod))
  val (equ,gt,ge,lt,le,ne) =
    (Fun(I =>: I =>: B,Un.equ),Fun(I =>: I =>: B,Un.gt),
     Fun(I =>: I =>: B,Un.ge),Fun(I =>: I =>: B,Un.lt),
     Fun(I =>: I =>: B,Un.le),Fun(I =>: I =>: B,Un.ne))

  // Compilation to untyped lambda calculus

  /** Typechecks `term` in `env` and translates it to the untyped calculus.
    *
    * @param term the simply typed term to compile
    * @param env  types of the free variables of `term`
    * @return `Right((untypedTerm, type))` on success. On failure, returns
    *         `Left(errors)`; errors from independent subterms (both sides of
    *         an `App`, or all three parts of an `If`) are combined.
    */
  def compile (term: Term, env: Env): Either[List[String],(Un.Term,Type)] = {
    def oops (s: String) = Left(List(s))
    // Shared by Lam and Fix: compile the body with the bound variable in
    // scope, then let the binder build its untyped form and overall type.
    def compileParmed (
      p: Parmed, vSym: Symbol, vType: Type, body: Term, env: Env)
      : Either[List[String],(Un.Term,Type)] =
      compile(body,(vSym -> vType) :: env) match {
        case Right((unBody,bodyT)) =>
          Right((p.toUntyped(Un.Var(vSym),unBody),p.overallType(bodyT)))
        case left @ Left(_) =>
          left
      }
    term match {
      case Var(sym) =>
        env.find(_._1 == sym).map(_._2) match {
          case Some(t) => Right((Un.Var(sym),t))
          case None => oops(s"Undefined variable $sym")
        }
      case lam @ Lam(Var(sym),vType,body) =>
        compileParmed(lam,sym,vType,body,env)
      case App(fun,arg) =>
        (compile(fun,env),compile(arg,env)) match {
          case (Right((unFun,funT)),Right((unArg,argT))) =>
            funT match {
              // Simple types: the argument type must equal the parameter type exactly.
              case FunType(funArgT,funResT) =>
                if (funArgT == argT) Right((Un.App(unFun,unArg),funResT))
                else oops(s"Can't apply function of type $funT to $argT")
              case _ =>
                oops(s"$funT is not a function type; can't apply to $argT")
            }
          case (Left(funErrs),Left(argErrs)) =>
            Left(funErrs ++ argErrs)
          case (left @ Left(_),_) =>
            left
          case (_,left @ Left(_)) =>
            left
        }
      case Con(v) =>
        Right((Un.Con(v.unValue),v.t))
      case If(cond,ifTrue,ifFalse) =>
        List(compile(cond,env),compile(ifTrue,env),compile(ifFalse,env))
        match {
          case List(Right((unCond,condT)),Right((unT,tT)),Right((unF,fT))) =>
            val condErrs =
              if (condT == B) Nil
              else List(s"If condition $cond type is $condT, not B")
            val branchErrs =
              if (tT == fT)
                condErrs
              else
                condErrs :+
                  (s"If true branch $ifTrue is $tT, " +
                   s"but false branch $ifFalse is $fT")
            if (branchErrs.isEmpty) Right((Un.If(unCond,unT,unF),tT))
            else Left(branchErrs)
          case errStrs =>
            // At least one part failed: gather the errors of all failed parts.
            Left(
              errStrs.foldLeft(List.empty[String]) { case (ss,lr) =>
                lr match {
                  case Right(_) => ss
                  case Left(ss1) => ss ++ ss1
                }
              })
        }
      case fix @ Fix(Var(sym),vType,body) =>
        compileParmed(fix,sym,vType,body,env) match {
          // The body was typechecked assuming the variable has type vType, so
          // it must actually have that type. Otherwise the recursive uses
          // inside it were typed against a wrong assumption.
          case Right((_,bodyT)) if bodyT != vType =>
            oops(s"Fix variable $sym declared as $vType, but body $body is $bodyT")
          case compiled =>
            compiled
        }
    }
  }

  // Execution

  /** Compiles `term` and, if it typechecks, runs it on the untyped
    * interpreter, printing a trace. Returns `None` if compilation fails.
    */
  def run (term: Term, env: Env = emptyEnv): Option[Any] =
    compile(term,env) match {
      case Right((unTerm,t)) =>
        println(s"\nCompiled:\n\n$term\n\nto:\n\n$unTerm\n\nof type $t.\n")
        println("Running...\n")
        val result = Un.run(unTerm)
        println(s"Result:\n$result\n")
        Some(result)
      case Left(errStrs) =>
        println(s"\nCompilation of:\n\n$term\n\nfailed because:\n")
        println(errStrs.mkString("\n\n"))
        println()
        None
    }

  /** Asserts that running `term` yields `desiredResult`, wrapped as an
    * untyped base value. `None` means compilation is expected to fail.
    */
  def check (term: Term, desiredResult: Option[Any]): Unit = {
    assert(run(term) == desiredResult.map(Un.BaseValue(_)))
    println()
  }

  val factorial =
    'fac :: (I =>: I) ^^:
      'n :: I ^:
        If(equ('n)(1),1,times('n)('fac(minus('n)(1))))
  val odd =
    'odd :: (I =>: B) ^^:
      'n :: I ^:
        ('even :: (I =>: B) ^:
          If(equ(0)('n),false,'even(minus('n)(1))))(
          'n :: I ^: If(equ(0)('n),true,'odd(minus('n)(1))))
  val intRangeFoldLeft =
    'fold :: ((I =>: I =>: I) =>: I =>: I =>: I =>: I) ^^:
      'f :: (I =>: I =>: I) ^:
        'acc :: I ^:
          'lo :: I ^:
            'hi :: I ^:
              If(
                gt('lo)('hi),
                'acc,
                'fold('f)('f('acc)('lo))(plus('lo)(1))('hi))

  def main (args: Array[String]): Unit = {
    check(factorial(5),Some(120))
    check(odd(3),Some(true))
    check(odd(4),Some(false))
    check(intRangeFoldLeft(plus)(0)(1)(10),Some(55))
    check(plus(2)(false),None)
    check(If(true,2,()),None)
    check(If((),1,2),None)
    check(If(3,plus(2)(false),not(7)),None)
    check(App(2,3),None)
    // Fix body type (B =>: B) disagrees with the declared type (I =>: I).
    check('x :: (I =>: I) ^^: ('n :: B ^: 'n),None)
  }
}