再帰
関数型言語では、副作用のないプログラムを行うため繰り返し処理はループではなく再帰を使用。
ただ、Scala で限定的に使用するループで副作用(変数の更新)があっても特に問題なさそうだが、やはりできるだけ副作用のない再帰を使いたい。
再帰のサンプルは,リスト要素の整数値を合算する再帰関数
リストのパターンマッチで Nil(空)の場合は加算する要素がないので、acc をそのまま返す。
パターンマッチでリストのhead(先頭要素) と tail(先頭以外の要素のリスト) の形式になる場合は
acc にリストの先頭要素を足した値とtail部分を引数として自分自身を再帰呼び出ししている。
import scala.annotation.tailrec object Test1 { @tailrec def sum(acc: Int, list: List[Int]): Int = { print(acc + " ") list match { case Nil => acc case e :: es => sum(acc + e, es) } } def main(args: Array[String]) = { val list = List(1, 2, 3, 4, 5, 6, 7) val result = sum(list.head, list.tail) println("result: " + result) } }
sumメソッドの引数 acc は、アキュムレータと呼ばれ、再帰関数の途中の計算結果の値を保持する。
アキュムレータを使用するのは関数の最後で再帰的に自分自身を呼び出して末尾再帰にするため。
末尾再帰はコンパイル時に最適化されて命令型コードと同等の形に変換される。
なので、リストの要素数が増えてもスタックオーバーフローは発生しない。
Scala 2.8 から再帰関数に @tailrec アノテーションをつけると末尾再帰でない場合に
コンパイルエラーにしてくれる。
実行結果
$ scala Test1 1 3 6 10 15 21 28 result: 28
次は、アキュムレータを使用していない末尾再帰でない例
sumメソッドの再帰呼び出し部分は、e + sum(es) になっているため
再帰呼び出しが最後ではなく再帰呼出し後に e と加算処理を行う。
import scala.annotation.tailrec object Test2 { @tailrec def sum(list: List[Int]): Int = { list match { case Nil => 0 case e :: es => e + sum(es) } } def main(args: Array[String]) = { val list = List(1, 2, 3, 4, 5, 6, 7) val result = sum(list) println("result: " + result) } }
@tailrec アノテーションを付けているのでコンパイルするとsumメソッドが
末尾再帰でないためコンパイルエラーになる。
@tailrec アノテーションをとっておくとコンパイルは通り
リストの要素数がある程度少なければ動作するが、リストの要素数が
増えるとスタックオーバーフロー(java.lang.StackOverflowError)が発生する。
$ scalac test2.scala test2.scala:8: error: could not optimize @tailrec annotated method sum: it contains a recursive call not in tail position case e :: es => e + sum(es) ^
list のチェックは .. match { case .. => .. } 式のパターンマッチを使用。
case Nil は空のリストのケースで、case e :: es *1 は e が list.head で es が list.tail になりリストが先頭要素とそれ以外の要素のリストの形式になる場合をあらわしている。
list match { case Nil => acc case e :: es => sum(acc + e, es) }
は
if (list.isEmpty) acc else sum(acc + list.head, list.tail)
と同じ処理になる。
上の if .. else .. 式に書き換えた Test1 オブジェクト
import scala.annotation.tailrec object Test1 { @tailrec def sum(acc: Int, list: List[Int]): Int = { print(acc + " ") if (list.isEmpty) acc else sum(acc + list.head, list.tail) } def main(args: Array[String]) = { val list = List(1, 2, 3, 4, 5, 6, 7) val result = sum(list.head, list.tail) println("result: " + result) } }