無名関数とカリー化

無名関数

Anonymous Functionsについて順を追って・・by ScalaByExample

2つの整数を引数に取り、その間の数の合計を計算するには再帰にて以下のように書けます。

object Main extends Application {
  def sumInts(a: Int, b: Int): Int =
    if (a > b) 0 else a + sumInts(a + 1, b)
  println(sumInts(1, 3))  // 1+2+3
}

実行結果は以下のようになります。

6


さらに、2つの整数を引数に取り、その間の数の2乗を合計するには同じように以下のように書けます。

object Main extends Application {
  def square(x: Int): Int = x * x
  def sumSquares(a: Int, b: Int): Int =
    if (a > b) 0 else square(a) + sumSquares(a + 1, b)
  println(sumSquares(1, 3))  // 1*1 + 2*2 + 3*3
}

実行結果は以下のようになります。

14


以上2つの sumInts と square という関数の中身は同じような構造をしています。異なる点は、引数で与えられた a をそのまま足しているか、2乗して足しているかになります。
そこで、これらの引数 a に関する処理を外部から関数を与えることで実現させてみます。

object Main extends Application {
  def sum(f: Int => Int, a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f, a + 1, b)

sun という関数を定義し、Int 型を引数に取り、Int 型を返却する関数を、1つ目の引数として受け取るようにします。sum 関数の内部では、引数として受け取った関数に a を適用しています。

では、この sum という関数を使ってみます。sum に渡す関数を2つ定義し、sum 関数を呼び出します。

object Main extends Application {
  def sum(f: Int => Int, a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f, a + 1, b)
  
  def func1(x:Int):Int = x
  println(sum(func1, 1, 3))    //sumInts 呼び出しに該当
  
  def func2(x:Int):Int = x * x //sumSquares 呼び出しに該当
  println(sum(func2, 1, 3))

実行結果は以下のようになり、最初に見た結果と同様の結果となっていることが分かります。

6
14


上記の func1 と func2 の定義を含めて以下のようにも書くことができます。

object Main extends Application {
  def sum(f: Int => Int, a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f, a + 1, b)
  
  println(sum({def f(x:Int):Int = x;     f _}, 1, 3))
  println(sum({def f(x:Int):Int = x * x; f _}, 1, 3))

{}で囲まれたブロック内で関数を定義し、f _ という形で関数を適用しています。_ は、外部より与えられた引数がそのままfという関数の引数となることを意味します。

前述までは関数に名前を付けていましたが、明示的に名前を定義せず以下のように簡素に書く構文糖が用意されています。

object Main extends Application {
  def sum(f: Int => Int, a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f, a + 1, b)
  
  println(sum(x=>x,   1, 3))
  println(sum(x=>x*x, 1, 3))

カリー化

上記の sum 関数は以下のように書き変えることができます。

object Main extends Application {
  def sum(f: Int => Int): (Int, Int) => Int = {
    def sumF(a: Int, b: Int): Int =
      if (a > b) 0 else f(a) + sumF(a + 1, b)
    sumF
  }  
  println(sum(x=>x)(1, 3))
  println(sum(x=>x*x)(1, 3))

このバージョンのsum関数は以下のように定義されています。

  • 引数「f: Int => Int」:1つの整数を引数に取り、1つの整数を返却する関数
  • 戻り値「(Int, Int) => Int」:2つの整数を引数に取り1つの整数を返却する関数

そして実際、sumFという関数を返却しています。
sum 関数の呼出部分は以下のように書くことができます。

  println((sum(x=>x))(1, 3))
  println((sum(x=>x*x))(1, 3))

sum(x=>x) でsumFとして定義された関数が返却され、そのsumFに(1, 3)という引数を適用していることになります。つまり、以下のように書くこともできるということになります。

object Main extends Application {
  def sum(f: Int => Int): (Int, Int) => Int = {
    def sumF(a: Int, b: Int): Int =
      if (a > b) 0 else f(a) + sumF(a + 1, b)
    sumF
  }  
  def sumInts = sum(x => x)
  def sumSquares = sum(x => x * x)
  println(sumInts(1, 3))
  println(sumSquares(1, 3))


さらにさらに、上記は以下のようにより短く書くことができます。

object Main extends Application {
  def sum(f: Int => Int)(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f)(a + 1, b)
  println(sum(x=>x)(1, 3))
  println(sum(x=>x*x)(1, 3))