Scala で簡易 Web サーバのプロトタイプを書いてみた

Httpd.scala

NIO 経由でのリクエストとレスポンスの処理。

package etc9.aokan

import java.nio.ByteBuffer
import java.io.{File, FileInputStream}
import java.nio.channels._
import etc9.aokan.HttpDef._
import collection.mutable.{ListBuffer}
import Util._
import collection.immutable.Map
import scala.concurrent.ops.spawn
import xml.Elem
import main.scala.etc9.web.Web
import java.net.{URLEncoder, URLDecoder, InetSocketAddress}

/** Command-line entry point. */
object Main {
  // Constructing a Server binds the default port and then blocks inside its selector loop.
  def main(args: Array[String]): Unit = {
    new Server()
  }
}

/**
 * Single-threaded NIO HTTP server. Construction binds `port` and then runs the selector loop on
 * the calling thread; it only returns once no keys remain registered (or on an unhandled error).
 *
 * @param port TCP port to listen on (default 8900)
 */
class Server(port: Int = 8900) {

  loan(Selector.open) { selector =>
    loan(ServerSocketChannel.open) { channel =>
      channel.socket.setReuseAddress(true)
      channel.configureBlocking(false)
      channel.socket.bind(new InetSocketAddress(port))
      channel.register(selector, SelectionKey.OP_ACCEPT, new AcceptHandler)

      while (selector.keys.size > 0) {
        selector.select()
        val it = selector.selectedKeys.iterator
        while (it.hasNext) {
          val key = it.next()
          it.remove() // the selector never clears its selected-key set by itself
          dispatch(key)
        }
      }
    }
  }

  /** Runs the key's handler, isolating per-connection failures from the accept loop. */
  private def dispatch(key: SelectionKey): Unit =
    try key.attachment.asInstanceOf[Handler].handle(key)
    catch {
      // A reset/broken client connection must not tear down the whole server.
      case _: java.io.IOException => key.channel match {
        case client: SocketChannel =>
          key.cancel()
          try client.close() catch { case _: java.io.IOException => () }
        case _ => () // listening socket: keep it registered so accepting continues
      }
      // The key was cancelled (e.g. its channel closed from another thread) after selection.
      case _: CancelledKeyException => ()
    }
}


/**
 * Per-channel event handler. An instance is stored as the SelectionKey attachment when a
 * channel is registered, and the selector loop in Server calls `handle` for each selected key.
 */
abstract class Handler { def handle(key: SelectionKey): Unit }

/** Accepts pending connections on the listening socket and registers each one for reads. */
class AcceptHandler extends Handler {
  def handle(key: SelectionKey): Unit =
    if (key.isValid && key.isAcceptable) {
      val server = key.channel.asInstanceOf[ServerSocketChannel]
      // In non-blocking mode accept() returns null when no connection is actually pending.
      Option(server.accept) foreach { sc =>
        sc.configureBlocking(false)
        // Each connection gets its own IoHandler, which holds that connection's read/write state.
        sc.register(key.selector, SelectionKey.OP_READ, new IoHandler)
      }
    }
}

/**
 * Connection handler: reads a request, hands it to `process`, and queues the response for
 * writing; it also resumes pending writes when the channel becomes writable.
 */
class IoHandler extends Handler with IoRead with Process with IoWrite {
  def handle(key: SelectionKey): Unit = {
    val readable = key.isValid && key.isReadable
    if (readable) {
      for (request <- ioRead(key)) {
        process(request) { response =>
          reserveWrite(response.writer)
          ioWrite(key)
        }
      }
    }
    // Re-check validity: the read path may already have written everything and closed the channel.
    if (key.isValid && key.isWritable) ioWrite(key)
  }
}


/** Accumulates bytes from a connection until one complete HTTP request has arrived. */
trait IoRead {
  // Per-connection receive buffer; IoBuffer grows it as needed.
  private val buffer = IoBuffer(2048)

  /**
   * Reads what is available on the key's channel and tries to assemble a request.
   *
   * @return the request once the header (and, for POST, the whole Content-Length body) has been
   *         received; None while more data is needed, when the header cannot be parsed, or when
   *         the peer closed the connection (the channel is then closed).
   */
  def ioRead(key: SelectionKey): Option[HttpRequest] = {
    val channel = key.channel.asInstanceOf[SocketChannel]
    if ((buffer read channel) < 0) {
      // End of stream before a complete request: close, otherwise the key stays readable
      // and the selector loop keeps spinning on it.
      channel.close()
      None
    } else {
      takeHeader(buffer) flatMap { headerBytes =>
        HttpHeader(headerBytes) match {
          case Some(h) => assemble(h, headerBytes.size)
          case None =>
            buffer.remove(headerBytes.size) // discard an unparsable header
            None
        }
      } map { request =>
        halfClose(key, channel)
        request
      }
    }
  }

  /**
   * Builds the request for a parsed header. Bytes are consumed from the buffer only once the
   * request is complete, so an incomplete POST is re-examined on the next read.
   */
  private def assemble(h: HttpHeader, headerSize: Int): Option[HttpRequest] = h.method match {
    case POST =>
      val end = headerSize + math.max(0, h.contentLength)
      if (buffer.length < end) None // body incomplete: keep header + partial body buffered
      else {
        val body = buffer.slice(headerSize, end) // copy before the buffer is compacted
        buffer.remove(end)
        Some(new HttpRequest(h, Some(body)))
      }
    case _ =>
      buffer.remove(headerSize)
      Some(new HttpRequest(h, None))
  }

  /** Stops watching for reads and shuts down the input side: one request per connection. */
  private def halfClose(key: SelectionKey, ch: SocketChannel): Unit = {
    key.interestOps(key.interestOps() & ~SelectionKey.OP_READ)
    ch.socket.shutdownInput()
  }

  /** Returns the header bytes up to and including the terminating CRLFCRLF, if present. */
  private def takeHeader(buf: IoBuffer): Option[Seq[Byte]] =
    (0 until buf.length - 3) find { i =>
      buf(i) == CR && buf(i + 1) == LF && buf(i + 2) == CR && buf(i + 3) == LF
    } map { i => buf.take(i + 4) } // + 4 so the final LF is not left in front of the body
}

/** Queue of pending response writers for one connection; closes the socket once all are sent. */
trait IoWrite {
  private val writers = ListBuffer.empty[Writer]

  /**
   * Pushes the head writer as far as the channel allows. While data remains, OP_WRITE interest
   * is kept; once the queue drains, OP_WRITE is cleared and the socket is closed.
   */
  def ioWrite(key: SelectionKey): Unit = {
    val sc = key.channel.asInstanceOf[SocketChannel]
    writers.headOption match {
      case None =>
        // Nothing queued (spurious write readiness): stop asking for it.
        key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE)
      case Some(current) =>
        if (current.write(sc).hasRemaining) awaitWritable(key)
        else {
          writers.remove(0)
          if (writers.isEmpty) {
            key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE)
            sc.close()
          } else awaitWritable(key)
        }
    }
  }

  /** Appends a writer to the queue; it is sent after the ones already queued. */
  def reserveWrite(writer: Writer): Unit = writers += writer

  /**
   * Requests OP_WRITE readiness. ioWrite may run on a non-selector thread (dynamic pages are
   * processed via `spawn`); a select() already blocked would not observe the new interest set,
   * so the selector is woken up explicitly.
   */
  private def awaitWritable(key: SelectionKey): Unit = {
    key.interestOps(key.interestOps() | SelectionKey.OP_WRITE)
    key.selector.wakeup()
  }
}

/** Maps a request to a response: `*.do` paths are dynamic, everything else is a static file. */
trait Process {
  /**
   * Builds the response for `req` and hands it to `reply`.
   *
   * Dynamic (`.do`) pages are rendered via `spawn`, so `reply` may run off the selector thread;
   * static files are answered synchronously. A fresh HttpResponse is created per request
   * because each response's writer can only be consumed once.
   */
  def process(req: HttpRequest)(reply: HttpResponse => Unit): Unit = {
    val path = req.header.path
    if (path.endsWith(".do")) spawn {
      reply(Web.mkContent(req) match {
        case Some(page) => new HttpResponse(OK_RESPONSE, Some(page))
        case None       => new HttpResponse(NOT_FOUND) // no route: 404 instead of an empty 200
      })
    } else {
      reply(publicFile(path) match {
        case Some(file) => new HttpResponse(OK_RESPONSE, Some(new FileContent(file)))
        case None       => new HttpResponse(NOT_FOUND)
      })
    }
  }

  /**
   * Resolves an (untrusted) request path to a regular file under ./public. Returns None for
   * missing files, directories, and paths that escape the document root (e.g. `/../secret`).
   */
  private def publicFile(path: String): Option[File] =
    try {
      val root = new File("./public").getCanonicalFile
      val file = new File(root, path).getCanonicalFile
      val insideRoot = file.getPath.startsWith(root.getPath + File.separator)
      if (insideRoot && file.isFile) Some(file) else None
    } catch {
      case _: java.io.IOException => None // path could not be resolved
    }
}

/**
 * Parsed request line and header fields of an HTTP request.
 *
 * @param method  request method
 * @param uri     request target as sent, including any query string
 * @param version HTTP version from the request line (e.g. "1.1")
 * @param headers header fields, keyed by the field name as sent
 * @param size    size in bytes of the raw header block this was parsed from
 */
case class HttpHeader (
    method: Method, uri: String, version: String,
    headers: Map[String, String], size: Int) {

  /**
   * Declared body length, 0 if absent. Field names are case-insensitive in HTTP, so the lookup
   * is too; the value is trimmed. A non-numeric value throws NumberFormatException on construction.
   */
  val contentLength: Int = headers.collectFirst {
    case (name, value) if name.equalsIgnoreCase("Content-Length") => value.trim.toInt
  }.getOrElse(0)

  /** Request target without its query string (safe even for a target of just "?"). */
  def path: String = uri.takeWhile(_ != '?')

  /** Decoded query-string parameters; empty when the target has no '?'. */
  def query: Map[String, String] = parseQuery(uri.indexOf('?') match {
    case -1 => ""
    case i  => uri.substring(i + 1)
  })
}

object HttpHeader {
  /**
   * Parses a raw header block: a request line followed by `name: value` lines.
   *
   * @param seq header bytes, decoded as US-ASCII
   * @return the parsed header, or None if the request line, the method, or Content-Length is invalid
   */
  def apply(seq: Seq[Byte]): Option[HttpHeader] = try {
    // Split explicitly: on JDK 11+ `String.lines` is a Java method returning a Stream,
    // which shadows Scala's StringOps.lines.
    val lines = new String(seq.toArray, US_ASCII).split("\r?\n").iterator
    val REQUEST_LINE_RE(method, path, version) = lines.next()
    // Lines not shaped like `name: value` (e.g. a blank terminator) are filtered out.
    val headers = for (HEADER_RE(k, v) <- lines) yield (k, v)
    Some(new HttpHeader(Method(method), path, version, headers.toMap, seq.length))
  } catch {
    // MatchError (bad request line / unknown method), NumberFormatException (Content-Length).
    // Fatal errors such as OutOfMemoryError are no longer swallowed.
    case _: Exception => None
  }
}

/**
 * A received request.
 *
 * @param header parsed request line and header fields
 * @param body   body bytes for POST requests, None otherwise
 */
class HttpRequest(val header: HttpHeader, val body: Option[Seq[Byte]]) {

  /**
   * Form fields of a URL-encoded POST body; None for other methods, other content types
   * (multipart/form-data is not implemented), or a missing body.
   */
  def forms: Option[Map[String, String]] =
    if (header.method != POST) None
    else mediaType match {
      case Some("application/x-www-form-urlencoded") =>
        // The body is %-encoded ASCII, so ISO-8859-1 decodes it losslessly before URL-decoding.
        body map { bytes => parseQuery(new String(bytes.toArray, ISO_8859_1)) }
      case _ => None
    }

  /** Content-Type without parameters such as `; charset=UTF-8`, lower-cased; name matched case-insensitively. */
  private def mediaType: Option[String] =
    header.headers.collectFirst {
      case (name, value) if name.equalsIgnoreCase("Content-Type") =>
        value.split(';').head.trim.toLowerCase(java.util.Locale.ROOT)
    }
}


/**
 * An HTTP/1.1 response: status line, header fields and an optional body.
 *
 * @param status  response status
 * @param content optional body; its length is sent as Content-Length
 * @param headers extra header fields; they override the defaults on key collision
 */
class HttpResponse(val status: Status, val content:Option[Content] = None,
    val headers: Map[HeaderField, String] = Map.empty) extends Writable {

  /** Status line, then one `Name: value` CRLF line per field, then the empty line. */
  def headerBytes: Array[Byte] = {
    val default: Map[HeaderField, String] = Map(
      ContentLength -> content.map(_.length).getOrElse(0L).toString,
      Connection -> "close",
      Server -> "Aokan",
      Date -> GMT.dateString(System.currentTimeMillis))

    // HTTP requires each field on its own line (RFC 7230 section 3.2), not a Map's toString form.
    val fields = (default ++ headers) map { case (field, value) => field.name + ": " + value + "\r\n" }
    ("HTTP/1.1 " + status + "\r\n" + fields.mkString + "\r\n").getBytes(US_ASCII)
  }

  /** Single-use writer: sends the header block first, then the body once the header is complete. */
  val writer: Writer = new Writer { self =>
    private val head = ByteBuffer.wrap(headerBytes)
    def write(sc: SocketChannel): Writer = {
      if (head.hasRemaining) sc.write(head)
      // After a partial (non-blocking) header write, body bytes must not follow on the wire yet.
      if (!head.hasRemaining) content foreach (_.writer.write(sc))
      self
    }
    def hasRemaining: Boolean = head.hasRemaining || content.exists(_.writer.hasRemaining)
  }
}

// Predefined body-less responses for common statuses.
// NOTE(review): each object is ONE shared HttpResponse whose `writer` wraps a single header
// buffer built when the object is initialized. After it has been sent to one client, a second
// use sends nothing, and its Date header is frozen at initialization time. Prefer creating
// `new HttpResponse(STATUS)` per request.
object BadRequest       extends HttpResponse(BAD_REQUEST)
object NotFound         extends HttpResponse(NOT_FOUND)
object MethodNotAllowed extends HttpResponse(METHOD_NOT_ALLOWED)
object MovedPermanently extends HttpResponse(MOVED_PERMANENTLY)
object MovedTemporarily extends HttpResponse(MOVED_TEMPORARILY)


/** Something that can be streamed to a socket through its [[Writer]]. */
trait Writable { val writer: Writer }

/**
 * Incremental writer for non-blocking channels. IoWrite calls `write` repeatedly (each call may
 * send only part of the data) until `hasRemaining` is false.
 */
trait Writer {
  /** Sends as much as the channel accepts; the implementations in this file return themselves. */
  def write(sc: SocketChannel): Writer
  /** True while some bytes have not been written yet. */
  def hasRemaining: Boolean
}

/** A response body: a Writable whose size is known before it is sent. */
abstract class Content extends Writable {
  /** Body size in bytes; HttpResponse advertises it as Content-Length. */
  def length: Long
}

/**
 * Streams a file's bytes with zero-copy `transferTo`.
 *
 * @param file regular file to send; it is opened at construction
 */
class FileContent(val file: File) extends Content {
  // Opened eagerly so `length` and the bytes actually sent come from the same snapshot.
  private val channel = new FileInputStream(file).getChannel
  private val size = channel.size

  def length: Long = size

  val writer: Writer = new Writer { self =>
    private var sent = 0L
    if (size == 0) channel.close() // nothing to send: release the descriptor right away

    def write(sc: SocketChannel): Writer = {
      // Guarded so calls after completion do not touch the closed file channel.
      if (sent < size) {
        sent += channel.transferTo(sent, size - sent, sc)
        if (sent >= size) channel.close()
      }
      self
    }

    def hasRemaining: Boolean = sent < size
  }
}

/** Content produced by rendering an XML/HTML element tree as UTF-8 (a tiny "servlet"). */
trait Let extends Content {
  /** Markup to render; supplied by the concrete page. */
  def elem: Elem

  /** Renders `elem` to UTF-8 bytes (re-rendered on every call). */
  def html: Array[Byte] = elem.mkString.getBytes(UTF_8)

  // Rendered exactly once, so Content-Length always matches the bytes that are sent even if
  // `elem` is not deterministic (e.g. it embeds the current time).
  private lazy val rendered: Array[Byte] = html

  def length: Long = rendered.length

  val writer: Writer = new Writer { self =>
    private val b = ByteBuffer.wrap(rendered)
    def write(sc: SocketChannel): Writer = {
      if (b.hasRemaining) sc.write(b)
      self
    }
    def hasRemaining: Boolean = b.hasRemaining
  }
}

/** Small helpers shared by the server: HTTP dates, query-string decoding, resource handling. */
object Util {
  /** Formats timestamps as HTTP dates, e.g. "Thu, 01 Jan 1970 00:00:00 GMT". */
  object GMT {
    val fmt = {
      val f = new java.text.SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss", java.util.Locale.US)
      // Format in GMT directly: subtracting the zone's raw offset ignored daylight saving time.
      f.setTimeZone(java.util.TimeZone.getTimeZone("GMT"))
      f
    }
    /** Raw offset of the default time zone; no longer used by `dateString`, kept for callers. */
    val offset = java.util.TimeZone.getDefault.getRawOffset
    /** @param mills epoch milliseconds */
    def dateString(mills : Long): String = fmt.synchronized { // SimpleDateFormat is not thread-safe
      fmt.format(new java.util.Date(mills)) + " GMT"
    }
  }

  /**
   * Parses `k1=v1&k2=v2` into a map, URL-decoding keys and values as UTF-8. Pairs with no key or
   * no '=', and pairs with malformed %-escapes, are skipped; for duplicate keys the last one wins.
   */
  def parseQuery(str: String): Map[String, String] =
    str.split('&').toList.flatMap { pair =>
      // Split at the first '=' only, so values may themselves contain '='.
      val eq = pair.indexOf('=')
      if (eq <= 0) Nil
      else
        try List(decode(pair.substring(0, eq)) -> decode(pair.substring(eq + 1)))
        catch { case _: IllegalArgumentException => Nil } // malformed %-escape from the client
    }.toMap

  def decode(str: String): String = URLDecoder.decode(str, "UTF-8")
  def encode(str: String): String = URLEncoder.encode(str, "UTF-8")

  /** Applies `f` to `t` and always closes `t` afterwards, even if `f` throws. */
  def loan[T <: {def close(): Unit}, R](t: T)(f: T => R): R = try { f(t) } finally { t.close() }

  /** Evaluates `r`, then runs `f` for its side effect, and returns the value of `r`. */
  def withReturn[R, F](r: => R)(f: => F): R = { val ret = r; f; ret }
}

IoBuffer.scala

ByteBuffer が使いにくすぎるので、簡易なラッパを用意した。

package etc9.aokan
import java.nio.channels.ReadableByteChannel
import java.nio.ByteBuffer

/**
 * Growable byte buffer viewed as a `Seq[Byte]` of the bytes received so far.
 *
 * The underlying ByteBuffer stays in "fill" mode: bytes `[0, position)` are the contents and
 * channel reads append after them.
 */
class IoBuffer(var buf: ByteBuffer) extends Seq[Byte] {
  def apply(idx: Int): Byte = {
    if (idx < 0 || idx >= length) throw new IndexOutOfBoundsException(idx.toString)
    buf.get(idx)
  }

  def update(n: Int, e: Byte) = buf.put(n, e)

  /** Number of buffered bytes: the fill position, not the ByteBuffer's limit/capacity. */
  def length: Int = buf.position()

  /** Drops the first `idx` buffered bytes and shifts the rest to the front. */
  def remove(idx: Int): Unit = {
    require(0 <= idx && idx <= length, "cannot remove " + idx + " of " + length + " bytes")
    buf.flip()        // limit = end of data, position = 0
    buf.position(idx)
    buf.compact()     // moves [idx, end) to the front; position = number of bytes kept
  }

  def iterator: Iterator[Byte] = new Iterator[Byte] {
    private val view = buf.asReadOnlyBuffer
    view.flip() // iterate over [0, position) without disturbing `buf`
    def next(): Byte = view.get()
    def hasNext: Boolean = view.hasRemaining
  }

  /**
   * Reads from `channel`, growing the buffer while it is (nearly) full.
   *
   * @return number of bytes read, or -1 if the channel was at end of stream before any byte was read
   */
  def read[C <: ReadableByteChannel](channel: C): Int = {
    var count = channel.read(buf)
    while (buf.capacity < buf.position() + 2) {
      expand()
      val n = channel.read(buf)
      if (n > 0) count += n // an EOF (-1) after expansion must not be subtracted from the count
    }
    count
  }

  /** Replaces `buf` with a larger buffer holding the same contents, ready for further filling. */
  private def expand(): Unit = {
    val grown = ByteBuffer.allocate(nextCapacity())
    buf.flip()
    grown.put(buf) // grown: position = data length, limit = full new capacity
    buf = grown
  }

  private def nextCapacity(current: Int = buf.capacity) = {
    // max(…, 1) so a zero-capacity buffer still grows instead of looping forever.
    val next = Integer.highestOneBit(math.max(current, 1)) << 1
    if (next < 0) Integer.MAX_VALUE else next
  }
}

/** Factories for IoBuffer. */
object IoBuffer {
  /** A second view sharing `ioBuf`'s underlying ByteBuffer (not a copy). */
  def wrap(ioBuf: IoBuffer): IoBuffer = new IoBuffer(ioBuf.buf)

  /** Adopts `buf` as-is; IoBuffer treats the bytes before its current position as contents. */
  def wrap(buf: ByteBuffer): IoBuffer = new IoBuffer(buf)

  /** Wraps `array` without copying, with all of its bytes counted as contents. */
  def wrap(array: Array[Byte]): IoBuffer = {
    val bb = ByteBuffer.wrap(array)
    // IoBuffer iterates [0, position); a freshly wrapped buffer has position 0 and would look empty.
    bb.position(array.length)
    new IoBuffer(bb)
  }

  /** An empty buffer with the given initial capacity. */
  def apply(capacity: Int): IoBuffer = new IoBuffer(ByteBuffer.allocate(capacity))
}

HttpDef.scala

HTTP の各種定義のまとめ。

package etc9.aokan

object HttpDef {

  val CR:Int = 13 // carriage return
  val LF:Int = 10 // line feed

  val US_ASCII   = java.nio.charset.Charset.forName("US-ASCII")   // protocol charset (request line, headers)
  val ISO_8859_1 = java.nio.charset.Charset.forName("ISO-8859-1") // HTTP body charset
  val UTF_8      = java.nio.charset.Charset.forName("UTF-8")      // content charset

  /** `METHOD SP target SP HTTP/version`; groups: method, request target, version number. */
  val REQUEST_LINE_RE = """^([A-Z]+) +([^ ]+) +HTTP/([0-9\.]+)$""".r

  /**
   * `name: value` header line; groups: field name, field value.
   * The name ends at the FIRST ':' (a greedy name would swallow values containing ": "),
   * whitespace after the colon is optional, and trailing whitespace is excluded from the value.
   */
  val HEADER_RE = """^([^:]+):[ \t]*(.*?)[ \t]*$""".r

  /** Supported request methods. */
  sealed abstract class Method
  case object GET  extends Method
  case object POST extends Method
  case object HEAD extends Method
  object Method {
    /** Maps a method token to its case object; any other token throws a MatchError. */
    def apply(s: String) = s match {
      case "GET" => GET; case "POST" => POST; case "HEAD" => HEAD;
    }
  }

  /** Response status; `toString` yields the status-line text, e.g. "404 Not Found". */
  sealed abstract class Status(code: Int, description: String) {
    override def toString = code + " " + description
  }
  case object OK_RESPONSE        extends Status(code=200, description="OK")
  case object BAD_REQUEST        extends Status(code=400, description="Bad Request")
  case object NOT_FOUND          extends Status(code=404, description="Not Found")
  case object METHOD_NOT_ALLOWED extends Status(code=405, description="Method Not Allowed")
  case object MOVED_PERMANENTLY  extends Status(code=301, description="Moved Permanently")
  case object MOVED_TEMPORARILY  extends Status(code=302, description="Moved Temporarily")


  /** Header field name; `toString` yields the wire name. */
  abstract class HeaderField(val name: String) { override def toString = name }
  case object Accept            extends HeaderField("Accept")
  case object AcceptCharset     extends HeaderField("Accept-Charset")
  case object AcceptLanguage    extends HeaderField("Accept-Language")
  case object AcceptRange       extends HeaderField("Accept-Ranges") // the field name is plural
  case object Age               extends HeaderField("Age")
  case object Allow             extends HeaderField("Allow")
  case object Authorization     extends HeaderField("Authorization")
  case object CacheControl      extends HeaderField("Cache-Control")
  case object Connection        extends HeaderField("Connection")
  case object ContentLanguage   extends HeaderField("Content-Language")
  case object ContentLength     extends HeaderField("Content-Length")
  case object ContentLocation   extends HeaderField("Content-Location")
  case object ContentMD5        extends HeaderField("Content-MD5")
  case object ContentRange      extends HeaderField("Content-Range")
  case object ContentType       extends HeaderField("Content-Type")
  case object Date              extends HeaderField("Date")
  case object Expect            extends HeaderField("Expect")
  case object Expires           extends HeaderField("Expires")
  case object From              extends HeaderField("From")
  case object Host              extends HeaderField("Host")
  case object IfModifiedSince   extends HeaderField("If-Modified-Since")
  case object IfRange           extends HeaderField("If-Range")
  case object IfUnmodifiedSince extends HeaderField("If-Unmodified-Since")
  case object LastModified      extends HeaderField("Last-Modified")
  case object Location          extends HeaderField("Location")
  case object MaxForwards       extends HeaderField("Max-Forwards")
  case object Pragma            extends HeaderField("Pragma")
  case object Range             extends HeaderField("Range")
  case object Referer           extends HeaderField("Referer")
  case object RetryAfter        extends HeaderField("Retry-After")
  case object Server            extends HeaderField("Server")
  case object Upgrade           extends HeaderField("Upgrade")
  case object UserAgent         extends HeaderField("User-Agent")
  case object Vary              extends HeaderField("Vary")
  case object Via               extends HeaderField("Via")
  case object WWWAuthenticate   extends HeaderField("WWW-Authenticate")
}

Web.scala

動的なコンテンツは、簡単なサーブレット風の仕組みで処理する。

package main.scala.etc9.web
import _root_.etc9.aokan.{Let, Content, HttpRequest}

/** Router for dynamic (`*.do`) requests. */
object Web {
  /** Returns the page for `req`, or None when no route matches its path. */
  def mkContent(req: HttpRequest): Option[Content] = route(req)

  // Only one dynamic page is defined so far.
  private def route(req: HttpRequest): Option[Content] =
    if (req.header.path == "/index.do") Some(new IndexPage(req))
    else None
}

/**
 * Sample dynamic page served for `/index.do`: shows the current time, the parsed query string,
 * and the request's form fields.
 *
 * @param req the request being answered
 */
class IndexPage(req: HttpRequest) extends Let {
  // A def, so each render picks up a new timestamp.
  def date = new java.util.Date
  // Page markup; Let renders it to UTF-8 bytes for the response body.
  def elem =
<html>
  <head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
    <title>Index</title>
  </head>
  <body>
    <h1>Hello</h1>
    <p>current time : {date.toString}</p>
    <p>querys : {req.header.query}</p>
    <p>forms : {req.forms}</p>
  </body>
</html>
}

テスト用の HTML

public フォルダに静的コンテンツを置く。

<html>
  <head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <title>Aokan</title>
  </head>
  <body>
    <h1>Aokan Httpd</h1>
    <!-- Static test page placed under ./public. The form POSTs "name" and "sex" with the default
         form encoding (application/x-www-form-urlencoded) to the relative URL index.do, which
         resolves to /index.do when this page is served from the site root. The server routes
         that path to its dynamic IndexPage, which echoes the submitted form fields. -->
    <form action="index.do" method="post">
      <p>name:<input type="text" name="name" size="40"></p>
      <p>sex:<input type="radio" name="sex" value="male">male
      <input type="radio" name="sex" value="female">female</p>
      <input type="submit"><input type="reset">
    </form>
  </body>
</html>

まとめ

リクエストをパースして適当にレスポンスを返すだけで、ヘッダなどもまじめには見ていません。単なる「えいや」のプロトタイプ実装でしかありません。結論としては、やはり NIO は扱いづらい……。