Parの実装

注: この記事はScala関数型プログラミング&デザイン7章前半の劣化版まとめです

整数列の足し算について考えてみよう。たたみ込みなどがパッと思い付くだろう。

def sum(ints: Seq[Int]): Int = ints.foldLeft(0)((a, b) => a + b)

// 注: scalaでは整数列のたたみ込み演算についてあらかじめ関数が用意されていて、
// IntelliJを使っている場合は上記のコードに対して以下のsum関数を用いることを勧められるかもしれない
// at scala.TraversableOnce
def sum[B >: A](implicit num: Numeric[B]): B = foldLeft(num.zero)(num.plus)
def sum(ints: Seq[Int]): Int = ints.foldLeft(0)((a, b) => a + b)s

さて上記のたたみ込みだと、列の端から順に足し算を繰り返すことになり、計算を並列化することはできない。 容易に分割できる整数列であるならば、分割したそれぞれのパーツを並列に計算して最後に足し合わせることができそうである。 そこで上のコードを少し改良してみよう。

def sum(ints: IndexedSeq[Int]): Int 
  if (ints.size <= 1) ints.headOption.getOrElse(0)
  else {
    val (l, r) = ints.splitAt(ints.length / 2)
    sum(l) + sum(r)
  }

さて、Scalaでは式は正格評価されるため sum(l) + sum(r) の部分について、 左側の sum(l) の評価が終わってから右側の sum(r) の評価が始まる。

評価を遅らせるためには単純にthunkを作ればよい。 つまり、() => sum(l), () => sum(r) としてやればよい。

def sum(ints: IndexedSeq[Int]): Int =
  if (ints.size <= 1) ints.headOption.getOrElse(0)
  else {
    val (l, r) = ints.splitAt(ints.size / 2)
    val sumL = () => sum(l)
    val sumR = () => sum(r)
    sumL() + sumR()
  }

しかしこのコードも sumL() の評価が終わってから、sumR()の評価が始まるため何も解決されていない。 val sumL = () => sum(l), val sumR = () => sum(r) の部分で別スレッドで評価を始める計算を得て、 sumL() + sumR() の部分で両スレッドの計算を待ち、値を返すようになれば並列化ができそうである。

そこでそういった機能を持つ関数を実装は置いておいて、とりあえずインターフェースだけ定義しておきましょう。

trait Par[A] { 
}

object Par {
  // 未評価なA型の式を受け取り、別スレッドで評価するための計算を返す
  def unit[A](a: => A): Par[A] = ???

  // 並列計算結果を取り出す
  def get[A](a: Par[A]): A = ???
}

これを用いると以下のように整数列の足し算を書き直すことができる。

def sum(ints: IndexedSeq[Int]): Int =
  if (ints.size <= 1) ints.headOption.getOrElse(0)
  else {
    val (l, r) = ints.splitAt(ints.size / 2)
    val sumL: Par[Int] = Par.unit(sum(l))
    val sumR: Par[Int] = Par.unit(sum(r))
    Par.get(sumL) + Par.get(sumR)
  }

さて unit は引数を受け取った瞬間、その評価を別のスレッドで直ちに開始する実装だとした場合、 確かにsum関数は並列化することを達成できる。

しかしながら、Par.get(sumL) + Par.get(sumR)をインライン展開したとすれば、Par.get(Par.unit(sum(l))) + Par.get(Par.unit(sum(r)))となり並列性が失われる。なぜならば、get は Par[Int] の計算が終わるまで待機するからである。つまり、unit は get に対し副作用を持っているということになる。

ということは、sum関数からPar[Int]を直接返してしまえば問題ない。Par.get(sumL) + Par.get(sumR) としている部分は sumL と sumR を合成したPar[Int]値を返せば良い。そしてこの時点でunitは引数の評価を非正格にする必要がなくなる。

object Par {
 
  def unit[A](a: A): Par[A] = ???
  def map2[A, B, C](a: Par[A], b: Par[B])(f: (A, B) => C): Par[C] = ???
 
  def sum(ints: IndexedSeq[Int]): Par[Int] =
    if (ints.size <= 1) Par.unit(ints.headOption.getOrElse(0))
    else {
      val (l, r) = ints.splitAt(ints.size / 2)
      Par.map2(sum(l), sum(r))(_ + _)
    }
 
}

さて、このようにしてできたsum関数に IndexedSeq(1, 2, 3, 4) を渡してみると、 Par.map2(sum(l), sum(r))(_ + ) の部分で左側の引数 sum(l) のほうが先に展開されてしまうという問題がある。したがってmap2関数の引数評価を遅らせる必要があるように思える。ただ Par.map2(Par.unit(1), Par.unit(2))( + _) のような単純な計算については即座に引数を評価したい。そこで、明示的に別スレッドで実行すべきであるという意味をもたせたfork関数を導入しよう。

object Par {
 
  def unit[A](a: A): Par[A] = ???
  def fork[A](a: => Par[A]): Par[A] = ???
  def map2[A, B, C](a: Par[A], b: Par[B])(f: (A, B) => C): Par[C] = ???
 
  def sum(ints: IndexedSeq[Int]): Par[Int] =
    if (ints.size <= 1) Par.unit(ints.headOption.getOrElse(0))
    else {
      val (l, r) = ints.splitAt(ints.size / 2)
      Par.map2(fork(sum(l)), fork(sum(r)))(_ + _)
    }
 
}

こうすることにより forkで包まれた sum 関数は直ちに評価されないため、sum(l) と sum(r) は同時に計算が開始される。 さて、こうして定めたインターフェースに対して java.concurrent.ExecutorService を用いて実装を与えてみよう。以下のようになるだろう。

import java.util.concurrent.{Callable, TimeUnit, Future, ExecutorService}

object Par {

  type Par[A] = ExecutorService => Future[A]

  /*
   * primitives
   */
  def unit[A](a: A): Par[A] = (es: ExecutorService) => UnitFuture(a)

  private case class UnitFuture[A](get: A) extends Future[A] {

    override def isDone: Boolean = true

    override def get(timeout: Long, units: TimeUnit): A = get

    override def isCancelled: Boolean = false

    override def cancel(mayInterruptIfRunning: Boolean): Boolean = false

  }

  def fork[A](a: => Par[A]): Par[A] =
    (es: ExecutorService) => {
      es.submit(new Callable[A] {
        override def call(): A = a(es).get
      })
    }

  def map2[A, B, C](a: Par[A], b: Par[B])(f: (A, B) => C): Par[C] =
    (es: ExecutorService) => {
      val af = a(es)
      val bf = b(es)
      UnitFuture(f(af.get, bf.get))
    }

  def run[A](es: ExecutorService)(a: Par[A]): Future[A] = a(es)

  /*
   * derivative-combinators
   */
  def lazyUnit[A](a: => A): Par[A] = fork(unit(a))

  def asyncF[A, B](f: A => B): A => Par[B] = a => lazyUnit(f(a))

  def map[A, B](par: Par[A])(f: A => B): Par[B] =
    map2(par, unit(()))((a, _) => f(a))

  def sum(ints: IndexedSeq[Int]): Par[Int] =
    if (ints.size <= 1) Par.unit(ints.headOption.getOrElse(0))
    else {
      val (l, r) = ints.splitAt(ints.size / 2)
      Par.map2(fork(sum(l)), fork(sum(r)))(_ + _)
    }

}

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です