18 January 2012

In programming languages, a delimited continuation, composable continuation or partial continuation, is a "slice" of a continuation frame that has been reified into a function.

— Wikipedia

这次使用 Scala 作为示范语言介绍 Delimited continuation, 因为 Scala 是我知道的除了括号语言(Lisp, Scheme 及其方言)之外, 唯一一个支持 Delimited continuation 的语言.

按照惯例, 还是从 Fibonacci 函数开始, 这里定义一个 Ruby yield 风格的 Fibonacci 函数:

def fibonacci(f: BigDecimal => Unit) {
  var (i, j) = (BigDecimal(1), BigDecimal(1))
  while (true) {
    f(j)
    val t = i + j; i = j; j = t
  }
}

fibonacci { println _ }

这个函数是一个高阶函数
[Higher-order function: 参数或者返回值为函数的函数]
, 其参数 f 是一个能处理 BigDecimal 的函数. 虽然这个函数工作得非常好, 但是有几个致命缺点:

  • 无法终止: 一旦开始遍历, 就没法停下来

  • 难以控制: 比如打印[100,10000)区间的 fibonacci 数, 或者打印第10~15个 fibonacci 数, 这种需求很难实现

如果我们能够实现一个 fibonacci 数的 Iterator, 通过 Scala 的 Iterator 库, 很容易就可以实现这些需求:

// print all fibonacci number inside [100,10000)
fibonacciIteritor
  .dropWhile(_ < 100)
  .takeWhile(_ < 10000)
  .foreach(println _)

// print the 10th ~ 15th fibonacci number
fibonacciIteritor
  .slice(10, 15)
  .foreach(println _)

现在就变个魔术, 把上面的 Ruby yield 风格的 fibonacci 函数, 变成一个 Iterator:

object fibonacciSeq {
  import scala.util.continuations._
  private var continuation: Unit => BigDecimal = reset {
    shift(identity[Unit => BigDecimal])
    var (i, j) = (BigDecimal(1), BigDecimal(1))
    while (true) {
      shift { k: (Unit => BigDecimal) => continuation = k; j } // f(j)
      val t = i + j; i = j; j = t
    }
    null // unreachable code, trick the compiler's type system
  }

  def next = continuation()
}

val fibonacciIteritor = Iterator.continually(fibonacciSeq.next)

这段代码最关键的东西, 就是 reset, shift 这两个函数. 在 Scala 的语法规则中, 这两个函数仅仅是普通的高阶函数
[hof]
, 其参数就是后面的 { … } 部分. reset 函数接受一段代码作为其参数, shift 函数接受一个函数作为其参数, 下面是这两个函数的声明:

def reset[A,C](ctx: =>(A @cpsParam[A,C])): C
def shift[A,B,C](fun: (A => B) => C): A @cpsParam[B,C]

reset 函数定义了一个闭锁空间(其参数), 在这个空间里面, 会发生一些违背物理常识的事情. shift 就是这个闭锁空间中的凉宫春日, 可以自由自在地操作这个空间, 具体方式就是操作时间流动(程序流程). 当程序运行到 shift 函数的时候, 会中断闭锁空间内部程序的执行, 生成一个当前闭锁空间的 snapshot, 包含当前的调用堆栈等信息, 然后调用 shift 函数的参数, 当这个函数执行完之后, 程序*不会*从 shift 函数后面继续执行, 而是跳出闭锁空间, 从当初进入闭锁空间的地方继续执行. 比如下面这段代码:

scala> reset {
     |   shift { k: (Unit => Unit) =>
     |     println("get here")
     |   }
     |   println("never reach here")
     | }
get here

shift 函数的参数执行完之后, 跳出闭锁空间, 回到当初进入的地方, 所以 shift 参数的返回值, 就是 reset 的返回值:

scala> val result = reset {
     |   shift { k: (Unit => Int) => 1 }
     |   2
     | }
result: Int = 1

根据上面 shift 函数的定义, shift 函数的参数也是一个函数, 这个函数接受一个参数, 这个参数就是当前闭锁空间的 snapshot. 现在来关注一下这个 snapshot, 首先这个 snapshot 代表了一段程序(从当前 shift 函数返回开始到闭锁空间内剩下的程序流程), 所以这个 snapshot 一定是一个函数闭包, 既然是函数, 那么就必须有参数和返回值, snapshot 对象是在进入 shift 函数时生成的, 那么其入口也就是从当前 shift 返回开始, snapshot 函数的参数就是 shift 的返回值:

scala> reset {
     |   val x = shift { snapshot: (Int => Unit) =>
     |     snapshot(1)
     |     snapshot(5)
     |     snapshot(25)
     |   }
     |   println("shifted: " + x)
     | }
shifted: 1
shifted: 5
shifted: 25

这里 snapshot 函数就捕捉了从 shift 返回开始, 到 reset 结束为止的程序流程, 即相当于:

val snapshot = { argument: Int =>
  val x = argument
  println("shifted: " + x)
}

下面考察一下 snapshot 函数的返回值. reset 作为一个程序块, 必然有返回值, 既然 snapshot 捕捉了到 reset 为止的所有程序流程, 那么 snapshot 的返回值就应该是 reset 程序块的返回值:

scala> reset {
     |   val x = shift { snapshot: (Unit => String) =>
     |     println(snapshot())
     |   }
     |   "I'm the return value"
     | }
I'm the return value

现在, 讲解了魔术的原理, 再回头看看上面那个 fibonacciSeq 实现, 应该知道其中变化了. 下次会介绍一些实际的应用场景