30 January 2012

前面大量篇幅介绍了很多理论知识,这次介绍一下 delimited continuation 的一个应用,使用 Scala 的 reset/shift 实现一个单线程的 Echo Server

首先看一下下面这段代码:

object Echo extends App {
  val server = ServerSocketChannel.open
  server.socket.bind(new InetSocketAddress(12345))
  while (true) {
    server.accept match {
      case c: SocketChannel =>
        println("Accept: " + c)
        while (c.isOpen && c.isConnected) {
          val bb = ByteBuffer.allocateDirect(1024)
          c.read(bb) match {
            case count if count > 0 =>
              println("Read: " + c + " count: " + count)
              bb.flip
              while (bb.hasRemaining) {
                c.write(bb) match {
                  case count if count > 0 =>
                    println("Write: " + c + " count: " + count)
                  case count if count == 0 =>
                    println("WriteBlock: " + c)
                  case _ =>
                    println("WriteError: " + c)
                    bb.clear
                    c.close
                }
              }
            case count if count == 0 =>
              println("ReadBlock: " + c)
            case _ =>
              println("ReadError: " + c)
              c.close
          }
        }
      case null =>
        println("AcceptBlock")
    }
  }
}

这看起来像是一个 Echo Server,不过写出这种代码的哥们估计第二天就会被炒鱿鱼了。众所周知,实现单线程的多 TCP 连接程序,肯定要使用 Select/Epoll 这些多路 IO 复用机制,如果不使用这些机制,就会像上面这段程序一样,只能处理一条连接。但是这段程序并不是一无是处,它*逻辑清晰,简单易懂*。几十行代码一气呵成,把需要的逻辑都包含了:接受连接,读取数据,返回数据。后来我想尝试用 Java NIO 写了一下 Echo Server,不过在我动手前,我的脑袋已经被各种回调扭成麻花了,如果没有 mina netty 这种框架,写网络程序无异于自杀。

回到上面那段有问题的程序,将错就错,先使用多线程,让它能跑起来:

object EchoMultiThread extends App {
  val server = ServerSocketChannel.open
  server.socket.bind(new InetSocketAddress(12345))
  while (true) {
    server.accept match {
      case c: SocketChannel =>
        (new Thread(new Runnable {  // <============================= 1
          override def run { // <==================================== 2
            println("Accept: " + c)
            while (c.isOpen && c.isConnected) {
              val bb = ByteBuffer.allocateDirect(1024)
              c.read(bb) match {
                case count if count > 0 =>
                  println("Read: " + c + " count: " + count)
                  bb.flip
                  while (bb.hasRemaining) {
                    c.write(bb) match {
                      case count if count > 0 =>
                        println("Write: " + c + " count: " + count)
                      case count if count == 0 =>
                        println("WriteBlock: " + c)
                      case _ =>
                        println("WriteError: " + c)
                        bb.clear
                        c.close
                    }
                  }
                case count if count == 0 =>
                  println("ReadBlock: " + c)
                case _ =>
                  println("ReadError: " + c)
                  c.close
              }
            }
          } // // <================================================== 3
        })).start // <=============================================== 4
      case null =>
        println("AcceptBlock")
    }
  }
}

上面加了四行代码,把这段程序变成多线程,跑起来了,不过离我们的目标还有一段距离,我们需要使用单线程和 Java NIO 把这段程序跑起来。分析一下这段程序,主线程阻塞在外层 while 循环,所有子线程阻塞在里面的 while 循环,如果我们能够使用 reset/shift 把这两个 while 循环变成可中断的,那么问题不就解决了么?

先把这段程序中,需要使用线程的部分用 reset 包起来

object Echo extends App {
  val server = ServerSocketChannel.open
  server.socket.bind(new InetSocketAddress(12345))
  server.configureBlocking(false)
  reset { // <=================================== here it is the main thread
    while (true) {
      server.accept match {
        case c: SocketChannel =>
          reset { // <=========================== here it is child threads
            println("Accept: " + c)
            c.configureBlocking(false)
            while (c.isOpen && c.isConnected) {
              val bb = ByteBuffer.allocateDirect(1024)
              c.read(bb) match {
                case count if count > 0 =>
                  println("Read: " + c + " count: " + count)
                  bb.flip
                  while (bb.hasRemaining) {
                    c.write(bb) match {
                      case count if count > 0 =>
                        println("Write: " + c + " count: " + count)
                      case count if count == 0 =>
                        println("WriteBlock: " + c)
                      case _ =>
                        println("WriteError: " + c)
                        bb.clear
                        c.close
                    }
                  }
                case count if count == 0 =>
                  println("ReadBlock: " + c)
                case _ =>
                  println("ReadError: " + c)
                  c.close
              }
            }
          }
        case null =>
          println("AcceptBlock")
      }
    }
  }
}

然后找到需要阻塞的地方 shift 出去,shift 的时候把 continuation 函数注册到 Selector 里面:

object Echo extends App {
  val selector = Selector.open
  val server = ServerSocketChannel.open
  server.socket.bind(new InetSocketAddress(12345))
  server.configureBlocking(false)
  reset {
    while (true) {
      server.accept match {
        case c: SocketChannel =>
          reset {
            println("Accept: " + c)
            c.configureBlocking(false)
            while (c.isOpen && c.isConnected) {
              val bb = ByteBuffer.allocateDirect(1024)
              c.read(bb) match {
                case count if count > 0 =>
                  println("Read: " + c + " count: " + count)
                  bb.flip
                  while (bb.hasRemaining) {
                    c.write(bb) match {
                      case count if count > 0 =>
                        println("Write: " + c + " count: " + count)
                      case count if count == 0 =>
                        println("WriteBlock: " + c)
                        // <=============================== May block here
                        shift[Unit, Unit, Unit] { cont =>
                          c.register(selector, SelectionKey.OP_WRITE, cont)
                        }
                      case _ =>
                        println("WriteError: " + c)
                        bb.clear
                        c.close
                    }
                  }
                case count if count == 0 =>
                  println("ReadBlock: " + c)
                  // <===================================== May block here
                  shift[Unit, Unit, Unit] { cont =>
                    c.register(selector, SelectionKey.OP_READ, cont)
                  }
                case _ =>
                  println("ReadError: " + c)
                  c.close
              }
            }
          }
        case null =>
          println("AcceptBlock")
          // <============================================= May block here
          shift[Unit, Unit, Unit] { cont =>
            server.register(selector, SelectionKey.OP_ACCEPT, cont)
          }
      }
    }
  }
}

最后,加上 Selector 的 select 部分,select 成功后调用 shift 时存放在 attachment 中的 continuation 函数。为了编译通过,必须保证所有的 match/case 条件分支语句返回值类型相同,需要加入一些 shiftUnit 语句:

object Echo extends App {
  val selector = Selector.open
  val server = ServerSocketChannel.open
  server.socket.bind(new InetSocketAddress(12345))
  server.configureBlocking(false)
  reset {
    while (true) {
      server.accept match {
        case c: SocketChannel =>
          reset {
            println("Accept: " + c)
            c.configureBlocking(false)
            while (c.isOpen && c.isConnected) {
              val bb = ByteBuffer.allocateDirect(1024)
              c.read(bb) match {
                case count if count > 0 =>
                  println("Read: " + c + " count: " + count)
                  bb.flip
                  while (bb.hasRemaining) {
                    c.write(bb) match {
                      case count if count > 0 =>
                        println("Write: " + c + " count: " + count)
                        shiftUnit[Unit, Unit, Unit]()
                      case count if count == 0 =>
                        println("WriteBlock: " + c)
                        shift[Unit, Unit, Unit] { cont =>
                          c.register(selector, SelectionKey.OP_WRITE, cont)
                        }
                      case _ =>
                        println("WriteError: " + c)
                        bb.clear
                        c.close
                        shiftUnit[Unit, Unit, Unit]()
                    }
                  }
                case count if count == 0 =>
                  println("ReadBlock: " + c)
                  shift[Unit, Unit, Unit] { cont =>
                    c.register(selector, SelectionKey.OP_READ, cont)
                  }
                case _ =>
                  println("ReadError: " + c)
                  c.close
                  shiftUnit[Unit, Unit, Unit]()
              }
            }
          }
          shiftUnit[Unit, Unit, Unit]()
        case null =>
          println("AcceptBlock")
          shift[Unit, Unit, Unit] { cont =>
            server.register(selector, SelectionKey.OP_ACCEPT, cont)
          }
      }
      shiftUnit[Unit, Unit, Unit]()
    }
  }

  val keys = selector.selectedKeys
  while (true) {
    selector.select
    keys foreach { k =>
      k.interestOps(0)
      k.attachment.asInstanceOf[Function1[Unit, Unit]].apply(Unit)
    }
    keys.clear
  }
}

大功告成,怎么样,感觉不错吧,这样一个逻辑非常清晰的 Echo Server 就做完了,所有麻花一样的回调过程都隐藏在 reset/shift 的过程中了。源代码在 这里 ,有兴趣可以研究一下。