My simply typed lambda calculus interpreter

22 views
Skip to first unread message

Adrian King

unread,
Apr 8, 2014, 12:43:16 PM4/8/14
to sf-...@googlegroups.com
What's the best way to start writing an interpreter for the simply typed lambda calculus? Well, by writing an interpreter for the untyped lambda calculus. The untyped lambda calculus is a good target to compile to after you've finished typechecking a simply typed program, and simple types would be ignored at execution time anyway.

Below is the untyped interpreter, in Scala, that I used as my target for this exercise. It's a bit different from the Scala interpreter I wrote for the previous exercise on the untyped lambda calculus:
  • To the basic lambda calculus Var/Lam/App constructs, I've added Con (for constants of types Unit, Boolean, and Int) and If (baking in booleans and conditionals gives them a more familiar look than defining them as functions, which you're forced to do in the unaugmented calculus). Runtime values include not only Unit, Boolean, and Int constants, but also function closures (Clo) and built-in functions (Fun) that manipulate the three supported base types.
  • I've tightened up my syntactic sugar for my internal DSL representation of the lambda calculus. There are implicit conversions from everything you're likely to want to put into a lambda calculus expression, and conversion of a term to a String emits fewer redundant parentheses. The intention is that if you write a lambda calculus DSL expression as clearly as possible in the Scala source, and then call toString on that, you should get back the same Scala source text.
  • For the previous exercise, I wrote interpreters in two different styles—structural reduction by rewriting the lambda calculus term, and a closure-based tail-recursive interpreter in SECD style. Here I have a third style of interpreter, a naive (non-tail-recursive) closure-based interpreter of the type you're likely to write in a beginning course on programming languages. This is probably the most familiar and readable type of interpreter, but the lack of tail recursion means that it's likely to overflow the Scala stack if you try to make it do a nontrivial amount of work.


package lambdas

import scala.language.implicitConversions

object Untyped {

  // Untyped lambda calculus top-level types
  trait Term {
    def ^: (v: Var) = Lam(v,this)
    def apply (arg: Term) = App(this,arg)
  }
  trait Value
  type Env = List[(Symbol,Value)]
  val emptyEnv = List.empty[(Symbol,Value)]

  // Terms
  case class Var (sym: Symbol) extends Term {
    override def toString: String = sym.toString
  }
  implicit def symbolToVar (sym: Symbol): Var = Var(sym)
  case class Lam (v: Var, body: Term) extends Term {
    override def toString: String = s"$v ^: $body"
  }
  case class App (fun: Term, arg: Term) extends Term {
    override def toString: String =
      fun match {
        case _: Lam => s"($fun)($arg)"
        case _ => s"$fun($arg)"
      }
  }
  case class Con (v: Value) extends Term {
    override def toString: String = v.toString
  }
  implicit def unitToCon (u: Unit) = Con(BaseValue(u))
  implicit def booleanToCon (bo: Boolean) = Con(BaseValue(bo))
  implicit def intToCon (i: Int) = Con(BaseValue(i))
  implicit def funToCon (f: Fun) = Con(f)
  // Conditional term: `cond` decides which of `ifTrue` / `ifFalse` gets evaluated.
  // It keeps the compiler-generated case-class toString (no DSL-style override).
  case class If (cond: Term, ifTrue: Term, ifFalse: Term) extends Term
  def fix (v: Var, body: Term) = {
    // call-by-value fixpoint combinator:
    val z = 'f ^: ('x ^: 'f('v ^: ('x('x))('v)))('x ^: 'f('v ^: ('x('x))('v)))
    z(Lam(v,body))
  }

  // Values
  case class BaseValue [+Underlying] (v: Underlying) extends Value {
    override def toString: String = v.toString
  }
  implicit def unitToValue (u: Unit) = BaseValue(u)
  implicit def booleanToValue (bo: Boolean) = BaseValue(bo)
  implicit def intToValue (i: Int) = BaseValue(i)
  abstract class Fun (val arity: Int) extends Value {
    assert(arity > 0)
    def apply (term: Term) = App(Con(this),term)
    final def call (args: Value*): Value = {
      assert(args.size == arity)
      eval(args: _*)
    }
    protected def eval (args: Value*): Value
  }
  case class Clo (lam: Lam, env: Env) extends Fun(1) {
    protected def eval (args: Value*) =
      run(lam.body,(lam.v.sym -> args.head) :: env)
    def toString (depth: Int): String = {
      val indent = "  " * depth
      val indent1 = indent + "  "
      val envStr =
        env.map { case (sym,v) =>
          v match {
            case clo: Clo => s"$sym =\n${clo.toString(depth + 2)}"
            case _ => s"$sym = $v"
          }
        }.mkString("\n" + indent1)
      indent +
        (if (envStr.isEmpty) lam.toString
          else s"$lam\n${indent}where:\n$indent1$envStr")
    }
    override def toString: String = toString(0)
  }

  // Built-in functions
  val not =
    new Fun(1) {
      protected def eval (args: Value*) = {
        val Seq(bo) = args
        bo match {
          case BaseValue(true) => BaseValue(false)
          case BaseValue(false) => BaseValue(true)
        }
      }
      override def toString: String = "not"
    }
  val neg =
    new Fun(1) {
      protected def eval (args: Value*) = {
        val Seq(a) = args
        a match {
          case BaseValue(a: Int) =>
            BaseValue(- a)
        }
      }
      override def toString: String = "neg"
    }
  def intBinOp (f: (Int,Int) => Int, name: String) =
    new Fun(2) {
      protected def eval (args: Value*) = {
        val Seq(a,b) = args
        a match {
          case BaseValue(a : Int) =>
            b match {
              case BaseValue(b : Int) =>
                BaseValue(f(a,b))
            }
        }
      }
      override def toString: String = name
    }
  val (plus,minus,times,div,mod) =
    (intBinOp(_ + _,"plus"),intBinOp(_ - _,"minus"),intBinOp(_ * _,"times"),
      intBinOp(_ / _,"div"),intBinOp(_ % _,"mod"))
  def intCompareOp (f: (Int,Int) => Boolean, name: String) =
    new Fun(2) {
      protected def eval (args: Value*) = {
        val Seq(a,b) = args
        a match {
          case BaseValue(a : Int) =>
            b match {
              case BaseValue(b : Int) =>
                BaseValue(f(a,b))
            }
        }
      }
      override def toString: String = name
    }
    val (equ,gt,ge,lt,le,ne) =
      (intCompareOp(_ == _,"equ"),intCompareOp(_ > _,"gt"),
        intCompareOp(_ >= _,"ge"),intCompareOp(_ < _,"lt"),
        intCompareOp(_ <= _,"le"),intCompareOp(_ != _,"ne"))

  // Execution. Beware: interpreter is not tail-recursive!
  def run (term: Term, env: Env = emptyEnv): Value = term match {
    case Var(sym) =>
      env. find(_._1 == sym) match {
        case Some((_,v)) => v
        case _ => sys.error(s"Undefined variable $sym")
      }
    case lam: Lam =>
      Clo(lam,env)
    case App(fun,arg) =>
      (run(fun,env),run(arg,env)) match {
        case (Clo(Lam(v,body),env),arg) =>
          run(body,(v.sym -> arg) :: env)
        case (f: Fun,arg) =>
          if (f.arity == 1)
            f.call(arg)
          else
            new Fun(f.arity - 1) {
              protected def eval (args: Value*) = f.call((arg +: args): _*)
              override def toString: String = s"$f($arg)"
            }
        case (badFun,arg) =>
          sys.error(s"$badFun is not a closure; cannot apply to $arg")
      }
    case Con(v) =>
      v
    case If(cond,ifTrue,ifFalse) =>
      run(cond,env) match {
        case BaseValue(true) => run(ifTrue,env)
        case BaseValue(false) => run(ifFalse,env)
        case v => sys.error(s"$v is not a boolean; cannot govern If")
      }
  }

}

Adrian King

unread,
Apr 8, 2014, 1:05:55 PM4/8/14
to sf-...@googlegroups.com
Here is the code that compiles the simply typed lambda calculus for the untyped interpreter of my previous message.

The main function that actually does something here is compile, which performs typechecking while doing a translation to the untyped lambda calculus—the run function just compiles and then invokes the untyped interpreter.

The structure of the simply typed lambda calculus is largely parallel to that of the untyped. I could perhaps have tried to factor out the differences between the corresponding Term/Type/Value/Env/Var/Lam/App/Con/If classes, but there are enough minor differences between the two calculi that I doubt the result would have been more readable than just writing corresponding pairs of not-quite-identical classes.

I've introduced a primitive Fix term to permit recursion, which cannot otherwise be given a correct simple type. The corresponding construct in the untyped implementation is just a function fix that applies a fixpoint combinator, since recursion can be expressed directly in the untyped calculus. Fix (sometimes known as the mu operator) has a structure parallel to Lam, and I've factored out their commonalities into a superclass Parmed.

Compilation conceptually just verifies that the declared simple types match up, and translates the elements of the simply typed calculus to their untyped equivalents. If compilation fails, compile produces a list of String error messages. The place where the App case checks whether funArgT == argT is most characteristic of simple types: function parameter and argument types must match exactly. A more complicated type system would have to do more work here—for example, an implementation of parametric polymorphism might need to unify different type variables.



package lambdas

import scala.language.implicitConversions

object SimplyTyped {

  import lambdas.{ Untyped => Un }
 
  // Simply-typed lambda calculus top-level types
  trait Term {
    def ^: (vt: VarTyping) = Lam(vt.v,vt.t,this)
    def ^^: (vt: VarTyping) = Fix(vt.v,vt.t,this)

    def apply (arg: Term) = App(this,arg)
  }
  trait Type {
    def :: (v: Var) = VarTyping(v,this)
    def =>: (arg: Type) = FunType(arg,this)
  }
  trait Value {
    def t: Type
    def unValue: Un.Value
  }
  type Env = List[(Symbol,Type)]
  val emptyEnv = List.empty[(Symbol,Type)]


  // Terms
  case class Var (sym: Symbol) extends Term {
    override def toString: String = sym.toString
  }
  implicit def symbolToVar (sym: Symbol): Var = Var(sym)
  // A variable paired with its declared type: the result of `v :: T`, and the
  // left operand of the `^:` and `^^:` binders.
  case class VarTyping (v: Var, t: Type) // for syntactic convenience with ::
  abstract class Parmed (operator: String) extends Term {
    def v: Var
    def vType: Type
    def body: Term
    def overallType (bodyType: Type): Type
    def toUntyped (unV: Un.Var, unBody: Un.Term): Un.Term

    override def toString: String =
      vType match {
        case _: BaseType[_] => s"$v :: $vType $operator $body"
        case _: FunType => s"$v :: ($vType) $operator $body"
      }   
  }
  case class Lam (v: Var, vType: Type, body: Term) extends Parmed("^:") {
    def overallType (bodyType: Type) = vType =>: bodyType
    def toUntyped (unV: Un.Var, unBody: Un.Term) = Un.Lam(unV,unBody)

  }
  case class App (fun: Term, arg: Term) extends Term {
    override def toString: String =
      fun match {
        case _: Parmed =>  s"($fun)($arg)"

        case _ => s"$fun($arg)"
      }
  }
  case class Con (v: Value) extends Term {
    def t: Type = v.t

    override def toString: String = v.toString
  }
  implicit def unitToCon (u: Unit) = Con(BaseValue(u,U))
  implicit def booleanToCon (bo: Boolean) = Con(BaseValue(bo,B))
  implicit def intToCon (i: Int) = Con(BaseValue(i,I))

  implicit def funToCon (f: Fun) = Con(f)
  // Conditional term. compile requires `cond` to have type B and both branches
  // to have the same type, which becomes the type of the whole conditional.
  case class If (cond: Term, ifTrue: Term, ifFalse: Term) extends Term
  def and (a: Term, b: Term) = If(a,b,false)
  def or (a: Term, b: Term) = If(a,true,b)
  case class Fix (v: Var, vType: Type, body: Term) extends Parmed("^^:") {
    def overallType (bodyType: Type) = bodyType
    def toUntyped (unV: Un.Var, unBody: Un.Term) = Un.fix(unV,unBody)
  }

  // Types
  // A primitive (non-function) type. The constructor argument is the type's
  // printed name; `Underlying` records the Scala type of its values.
  abstract class BaseType [+Underlying] (override val toString: String)
    extends Type
  case class FunType (arg: Type, result: Type) extends Type {

    override def toString: String =
      arg match {
        case _: BaseType[_] => s"$arg =>: $result"
        case _: FunType => s"($arg) =>: $result"
      }
  }

  // Values (compiletime values only, so no function closure here)
  case class BaseValue [+Underlying] (v: Underlying, t: BaseType[Underlying])
      extends Value {
    def unValue = Un.BaseValue(v)

    override def toString: String = v.toString
  }
  // A built-in function: an untyped implementation `unValue` paired with the
  // function type `t` assigned to it. Prints as the untyped function does.
  case class Fun (t: FunType, unValue: Un.Fun) extends Value {
    override def toString: String = unValue.toString
  }

  // Primitive types
  // The three primitive types, printed as "U", "B" and "I":
  // Unit, Boolean and Int respectively.
  object U extends BaseType[Unit]("U")
  object B extends BaseType[Boolean]("B")
  object I extends BaseType[Int]("I")

  // Built-in functions
  val not = Fun(B =>: B,Un.not)
  val neg = Fun(I =>: I,Un.neg)
  val (plus,minus,times,div,mod) =
    (Fun(I =>: I =>: I,Un.plus),Fun(I =>: I =>: I,Un.minus),
      Fun(I =>: I =>: I,Un.times),Fun(I =>: I =>: I,Un.div),
      Fun(I =>: I =>: I,Un.mod))
  val (equ,gt,ge,lt,le,ne) =
    (Fun(I =>: I =>: B,Un.equ),Fun(I =>: I =>: B,Un.gt),
      Fun(I =>: I =>: B,Un.ge),Fun(I =>: I =>: B,Un.lt),
      Fun(I =>: I =>: B,Un.le),Fun(I =>: I =>: B,Un.ne))
 
  // Compilation to untyped lambda calculus
  def compile (term: Term, env: Env): Either[List[String],(Un.Term,Type)] = {
    def oops (s: String) = Left(List(s))
    def compileParmed (
        p: Parmed, vSym: Symbol, vType: Type, body: Term, env: Env) =
      compile(body,(vSym -> vType) :: env) match {
        case Right((unBody,bodyT)) =>
          Right((p.toUntyped(Un.Var(vSym),unBody),p.overallType(bodyT)))
        case left @ Left(_) =>
          left

      }
    term match {
      case Var(sym) =>
        env.find(_._1 == sym).map(_._2) match {
          case Some(t) => Right((Un.Var(sym),t))
          case None => oops(s"Undefined variable $sym")
        }
      case lam @ Lam(Var(sym),vType,body) =>
        compileParmed(lam,sym,vType,body,env)
      case App(fun,arg) =>
        (compile(fun,env),compile(arg,env)) match {
          case (Right((unFun,funT)),Right((unArg,argT))) =>
            funT match {
              case FunType(funArgT,funResT) =>
                if (funArgT == argT) Right((Un.App(unFun,unArg),funResT))
                else oops(s"Can't apply function of type $funT to $argT")
              case ft =>
                oops(s"$funT is not a function type; can't apply to $argT")
            }
          case (Left(funErrs),Left(argErrs)) =>
            Left(funErrs ++ argErrs)
          case (left @ Left(_),_) =>
            left
          case (_,left @ Left(_)) =>
            left
        }
      case Con(v) =>
        Right((Un.Con(v.unValue),v.t))
      case If(cond,ifTrue,ifFalse) =>
        List(compile(cond,env),compile(ifTrue,env),compile(ifFalse,env))
            match {
          case List(Right((unCond,condT)),Right((unT,tT)),Right((unF,fT))) =>
            val condErrs =
              if (condT == B) Nil
              else List(s"If condition $cond type is $condT, not B")
            val branchErrs =
              if (tT == fT)
                condErrs
              else
                condErrs :+
                  (s"If true branch $ifTrue is $tT, " +
                    s"but false branch $ifFalse is $fT")
            if (branchErrs.isEmpty) Right((Un.If(unCond,unT,unF),tT))
            else Left(branchErrs)
          case errStrs =>
            Left(
              errStrs.foldLeft(List.empty[String]) { case (ss,lr) =>
                lr match {
                  case Right(_) => ss
                  case Left(ss1) => ss ++ ss1
                }
              })
        }
      case fix @ Fix(Var(sym),vType,body) =>
        compileParmed(fix,sym,vType,body,env)
    }
  }

  // Execution
  def run (term: Term, env: Env = emptyEnv): Option[Any] =
    compile(term,env) match {
      case Right((unTerm,t)) =>
        println(s"\nCompiled:\n\n$term\n\nto:\n\n$unTerm\n\nof type $t.\n")
        println("Running...\n")
        val result = Un.run(unTerm)
        println(s"Result:\n$result\n")
        Some(result)
      case Left(errStrs) =>
        println(s"\nCompilation of:\n\n$term\n\nfailed because:\n")
        println(errStrs.mkString("\n\n"))
        println
        None
    }

  def check (term: Term, desiredResult: Option[Any]) = {
    assert(run(term) == desiredResult.map(Un.BaseValue(_)))
    println
  }
  // Recursive factorial of type I =>: I, defined with the Fix binder:
  // fac(n) = if n == 1 then 1 else n * fac(n - 1).
  // The only base case is n == 1.
  val factorial =
    'fac :: (I =>: I) ^^:
      'n :: I ^:
        If(equ('n)(1),1,times('n)('fac(minus('n)(1))))
  // Oddness test of type I =>: B, written as a mutual recursion with a local
  // `even` helper. For an argument n, `even` is bound (by applying a lambda) to
  //   n => if 0 == n then true else odd(n - 1)
  // and the result is
  //   if 0 == n then false else even(n - 1).
  val odd =
    'odd :: (I =>: B) ^^:
      'n :: I ^:
        ('even :: (I =>: B) ^:
            If(equ(0)('n),false,'even(minus('n)(1))))(
          'n :: I ^: If(equ(0)('n),true,'odd(minus('n)(1))))
  // Left fold of a curried binary function over the integers lo..hi inclusive:
  //   fold(f)(acc)(lo)(hi) =
  //     if lo > hi then acc else fold(f)(f(acc)(lo))(lo + 1)(hi)
  val intRangeFoldLeft =
    'fold :: ((I =>: I =>: I) =>: I =>: I =>: I =>: I) ^^:
      'f :: (I =>: I =>: I) ^:
        'acc :: I ^:
          'lo :: I ^:
            'hi :: I ^:
              If(
                gt('lo)('hi),
                'acc,
                'fold('f)('f('acc)('lo))(plus('lo)(1))('hi))
               

  def main (args: Array[String]): Unit = {
    check(factorial(5),Some(120))
    check(odd(3),Some(true))
    check(odd(4),Some(false))
    check(intRangeFoldLeft(plus)(0)(1)(10),Some(55))
    check(plus(2)(false),None)
    check(If(true,2,()),None)
    check(If((),1,2),None)
    check(If(3,plus(2)(false),not(7)),None)
    check(App(2,3),None)
  }

}
Reply all
Reply to author
Forward
0 new messages