あるプログラマの日記

プログラマのメモ、出来事、考えたこと、勉強とかの雑記

パターンマッチ (基本のメモ)

  • パターンマッチは条件分岐を記述する switch 文に似た構文
  • 主にデータの比較、分解、抽出の用途で使用する。
  • 実行時に該当するパターンが選択肢にない場合は scala.MatchError が発生する。
  • switch の default はないが、default と同等のパターンとしてワイルカードを指定できる。
  • Scala のパターンマッチには break は不要で上から順番にマッチするまで評価してマッチした後は自動で制御から抜ける。
セレクタ式 match {
  case パターン1 => 処理1
  case パターン2 => 処理2
  case ...
  case _ => 上記バターン以外の処理
}
  • パターンに該当した後は => の後の処理が順番に評価される。
  • パターンに記述できるものが C や java の switch 文よりもはるかに柔軟で豊富
パターンの種類 記述例 内容
ワイルドカード case _ あらゆるオブジェクトにマッチ
リテラル(整数) case 1 整数 1 であればマッチ
リテラル(文字列) case "foo" 文字列 "foo" であれば
定数(空のリスト) *1 case Nil 空リストだけにマッチ
変数 *2 case v すべてにマッチ。結果を変数 v に束縛する
case v: Long Long型であれば
コンストラク case Foo(1, n) *3 case クラスのオブジェクトがマッチ。さらにコンストラクタ引数のマッチ
タプル case (a, b) a と b の変数を持つTuple2であれば
シーケンス case x :: xs 先頭要素と残りリスト要素があれば
パターンガード case (a, b) if (a % 2) == 0 細かい条件分岐を行う場合はガード条件(ifの条件)を付ける

case クラスの定義

  • class 宣言の前に case を付ける。
  • case クラスはフィールド宣言の必要がない。コンストラクタの引数が自動的にフィールドとして扱われる。
  • toString, hashCode, equals, copy メソッドはコンパイラが自動で実装を追加してくれる。
  • Scala はクラス本体が空であれば中括弧({ })を省略可能
scala> case class Foo(a: Int, b: Long)
defined class Foo

scala> val v = Foo(1, 2L)
v: Foo = Foo(1,2)

scala> val v2 = v.copy(a = 3)
v2: Foo = Foo(3,2)

scala> v == v2
res0: Boolean = false

*1:先頭が小文字でないものは定数とみなされる

*2:パターン変数は小文字で始める必要がある

*3:引数リスト内の引数には暗黙に val が付けられるためフィールドとして扱える

Option#getOrElse

前に勉強用に書いたXMLファイルの要素と属性の表示「XMLファイルの要素と属性をベタに表示」はXMLファイルにXML宣言がない場合は処理できなかったので、Opttion#getOrElse でXMLファイルにXML宣言がない場合は、デフォルトのエンコーディングUTF-8 を指定するように変更しました。
encoding メソッドはXMLファイルからエンコーディングを取得して Some を返しますが、XMLファイルからエンコーディングが読めない場合は None を返します。
encoding から返された Option[String] が None かを isEmpty で判断してたのですが、Option#getOrElse を使用して None の場合 "UTF-8" をデフォルトのエンコーディングに指定するように変更しました。

(変更前)

..snip..
  def main(args: Array[String]): Unit = {
    val enc = encoding(args(0))
    if (enc.isEmpty) {
      println("not XML file : " + args(0))
      return
    }
    println("encoding = " + enc.get)

    val f = new java.io.File(args(0))
    val s = Source.fromFile(f, enc.get)
..snip..


(変更後)

..snip..
  def main(args: Array[String]): Unit = {
    val f = new java.io.File(args(0))
    if (!f.exists) {
      println("none file : " + args(0))
      return
    }
    val enc = encoding(args(0)).getOrElse("UTF-8")
    println("encoding = " + enc)

    val s = Source.fromFile(f, enc)
..snip..

Optoin には None の場合に引数に指定したオブジェクトを返してくれるメソッドがあったのですね。
似ているメソッドに orElse があるが Option[T] 型を扱う必要がない場合は、 orElse ではなく getOrElse ですね。
Option#getOrElse は None の場合に指定した v を返しますが Option#orElse は None の場合に指定した Some(v) を返します。

XMLファイルの要素と属性をベタに表示

XMLファイルを読み込むプログラムを勉強がてらつくりました。

ScalaXML 操作が便利。
scala.xml.parsing.ConstructingParser に Source を渡すとXML操作用の Document が取得できるので Node を使って要素のラベルと値、属性、子要素をすべてベタに表示。

import scala.collection.mutable.Seq
import scala.io.Source
import scala.io.BufferedSource
import scala.xml._
import scala.xml.parsing.ConstructingParser

object XmlFlatReader {

  def main(args: Array[String]): Unit = {
    val enc = encoding(args(0))
    if (enc.isEmpty) {
      println("not XML file : " + args(0))
      return
    }
    println("encoding = " + enc.get)

    val f = new java.io.File(args(0))
    val s = Source.fromFile(f, enc.get)
    using[Unit](s) { s =>
      val doc = ConstructingParser.fromSource(s, false).document()
      statement(doc.docElem)
    }
  }

  def statement(node: Node) {
    val elems = node \ "_"
    println(node.label +
        (if (!elems.isEmpty || node.text.isEmpty) "" else " : " + node.text))
    val attrmap = node.attributes.asAttrMap
    attrmap.foreach { pa => println("  attr : " + pa._1 + " = " + pa._2) }
    elems.foreach { statement }
  }

  def encoding(fname: String): Option[String] = {
    val src = Source.fromFile(fname)
    using[Option[String]](src) { src =>
      for (line <- src.getLines) {
        if (!line.contains("<?xml") || !line.contains("?>")) return None
        val items = line.split(" ")
        for (str <- items) {
          if (str.contains("encoding")) {
            val ent = str.split("=")
            return Some(ent(1).replace("\"", "").trim)
          }
        }
      }
      None      
    }
  }

  def using[T](s: Source)(f: Source => T): T = {
    try f(s) finally s.asInstanceOf[BufferedSource].close
  }
}

Node は子要素がある場合、text に子要素の値をすべて持っているので、自己の値のtext だけを表示してます。
ConstructingParser は XML の encoding を自動で識別してくれないのでencoding の取得のためだけに一旦、XMLファイルを読み込んでXML宣言にある encoding を取得。
もうちょっと良い方法があるのかもしれないが、とりあえず encoding が取れて指定できました。
XML宣言がない XMLファイルは読めないので XML宣言がない時は決め打ちで、UTF-8 を指定してもよかったかもしれない。
Source は close がないのでローンパターン *1 で BufferedSource にキャストして close。using に渡すクロージャの戻り値の型を型パラメータで指定してます。
Map で foreach すると キーと値の Pair *2 を渡してくれるのですね。
\ は NodeSeq のメソッドで、指定タグの要素を返してくれて便利。
ここでは使っていないが \\ はさらにネストした要素をサーチしてくれて、指定する名前の先頭に @ を付けると属性をサーチできるようだ。
"_" を指定したときは全ての子要素の集合を返してくれます。これで全要素を取得。
要素の Node パース用の statement メソッドは末尾再帰になっていないですが、末尾再帰関数にする良い方法が他にあるのかもしれません。

*1:リソース処理後の処分(close)の保証

*2:Tuple2

ベンチマーク

Scala で処理時間を計測する Benchmark trait があったので
試しに 1から100万の整数要素を持つリストの各要素を合計する計算処理の時間を、
末尾再帰関数、forループ、whileループ、List の sum で計測してみた。

Benchmark の runBenchmarkメソッドは引数の times で計測の回数を指定し、
計測回数分の計測結果が入った List[Long]をかえしてくれる。計測結果の時間の単位はミリ秒。
Benchmark には multiplier フィールドがあり 1回の計測で計測対象の処理を
何回実行するかを指定できる。デフォルト値は 1 が指定されている。
ここで multiplier はデフォルト値 1 のままで実行。times はデフォルト引数で 10 回を指定。

import scala.annotation.tailrec
import scala.testing.Benchmark

object Test {
  def main(args: Array[String]) = {
    val list = (1 to 1000000).toList
    test("Test1", sum(list.head, list.tail))
    test("Test2", sum2(list))
    test("Test3", sum3(list))
    test("Test4", list.sum)
  }

  final def test(name: String, f: => Any, times: Int = 10) = {
    val result = new Benchmark { def run() = f }.runBenchmark(times)
    println(name + ": " + (result.sum / result.size) + " " + result)
  }

  @tailrec
  final def sum(acc: Int, list: List[Int]): Int = {
    if (list.isEmpty) acc else sum(acc + list.head, list.tail)
  }

  final def sum2(list: List[Int]): Int = {
    var cnt = 0;
    for (i <- list) { cnt += i }
    cnt
  }

  final def sum3(list: List[Int]): Int = {
    var cnt = 0;
    var tmp = list
    while (!tmp.isEmpty) {
      cnt += tmp.head
      tmp = tmp.tail
    }
    cnt
  }
}

[実行結果]
処理10回の計測テストを4回やってみた。
結果は10回の計測結果の平均値と各10回の計測結果のリストの内容を表示。
Test1:末尾再帰 Test2:for Test3:while Test4:list.sum
最適化された末尾再帰関数と while ループは、ほぼ同じくらいで最も早く、
以外にも Scala の List に元からある list.sum が一番遅かった。
for は while よりは遅く list.sum よりは早かった。

$ scala -version
Scala code runner version 2.9.1 -- Copyright 2002-2011, LAMP/EPFL
$ scala Test
Test1: 44 List(33, 51, 26, 44, 46, 44, 45, 64, 44, 44)
Test2: 53 List(57, 53, 53, 53, 53, 53, 53, 53, 52, 53)
Test3: 42 List(31, 47, 25, 44, 45, 44, 45, 48, 45, 48)
Test4: 72 List(59, 101, 50, 73, 73, 74, 73, 73, 76, 73)
$ scala Test
Test1: 41 List(32, 48, 25, 44, 44, 44, 44, 44, 44, 47)
Test2: 56 List(58, 53, 55, 53, 54, 53, 79, 53, 53, 54)
Test3: 45 List(31, 50, 25, 71, 45, 46, 47, 45, 47, 45)
Test4: 71 List(83, 71, 42, 76, 76, 78, 76, 78, 66, 66)
$ scala Test
Test1: 41 List(32, 48, 25, 45, 44, 44, 44, 44, 45, 45)
Test2: 52 List(55, 53, 52, 53, 53, 52, 53, 52, 52, 52)
Test3: 41 List(31, 49, 24, 45, 44, 44, 45, 45, 45, 44)
Test4: 64 List(53, 70, 41, 65, 77, 65, 73, 66, 66, 66)

Benchmark.scala のソースファイル

Benchmark.scala の runBenchmarkメソッドの実装をみてみると

計測の開始時の時刻と終了時の時刻を Platform.currentTime *1 で計測
1回の計測後にPlatform.collectGarbage *2 を呼び出している。

Platform object の各メソッドには @inline アノテーションが指定されていたが、
これは C の inline と同じようなものかな ?

..snip..
  def runBenchmark(noTimes: Int): List[Long] =
    for (i <- List.range(1, noTimes + 1)) yield {
      setUp
      val startTime = Platform.currentTime
      var i = 0; while (i < multiplier) {
        run()
        i += 1
      }
      val stopTime = Platform.currentTime
      tearDown
      Platform.collectGarbage

      stopTime - startTime
    }
..snip..

[参考]
hishidamaさんの Scala Benchmark ページ

*1:Platform.currentTime は java の System.currentTimeMillis()

*2:Platform.collectGarbage は java の System.gc()

再帰

関数型言語では、副作用のないプログラムを行うため繰り返し処理はループではなく再帰を使用。
ただ、Scala で限定的に使用するループで副作用(変数の更新)があっても特に問題なさそうだが、やはりできるだけ副作用のない再帰を使いたい。
再帰のサンプルは,リスト要素の整数値を合算する再帰関数
リストのパターンマッチで Nil(空)の場合は加算する要素がないので、acc をそのまま返す。
パターンマッチでリストのhead(先頭要素) と tail(先頭以外の要素のリスト) の形式になる場合は
acc にリストの先頭要素を足した値とtail部分を引数として自分自身を再帰呼び出ししている。

import scala.annotation.tailrec

object Test1 {
  @tailrec
  def sum(acc: Int, list: List[Int]): Int = {
    print(acc + " ")
    list match {
      case Nil => acc
      case e :: es => sum(acc + e, es)
    }
  }
  
  def main(args: Array[String]) = {
    val list = List(1, 2, 3, 4, 5, 6, 7)
    val result = sum(list.head, list.tail)
    println("result: " + result)
  }
}

sumメソッドの引数 acc は、アキュムレータと呼ばれ、再帰関数の途中の計算結果の値を保持する。
アキュムレータを使用するのは関数の最後で再帰的に自分自身を呼び出して末尾再帰にするため。
末尾再帰コンパイル時に最適化されて命令型コードと同等の形に変換される。
なので、リストの要素数が増えてもスタックオーバーフローは発生しない。
Scala 2.8 から再帰関数に @tailrec アノテーションをつけると末尾再帰でない場合に
コンパイルエラーにしてくれる。

実行結果

$ scala Test1
1 3 6 10 15 21 28 result: 28

次は、アキュムレータを使用していない末尾再帰でない例
sumメソッドの再帰呼び出し部分は、e + sum(es) になっているため
再帰呼び出しが最後ではなく再帰呼出し後に e と加算処理を行う。

import scala.annotation.tailrec

object Test2 {
  @tailrec
  def sum(list: List[Int]): Int = {
    list match {
      case Nil => 0
      case e :: es => e + sum(es)
    }
  }
  
  def main(args: Array[String]) = {
    val list = List(1, 2, 3, 4, 5, 6, 7)
    val result = sum(list)
    println("result: " + result)
  }
}

@tailrec アノテーションを付けているのでコンパイルするとsumメソッドが
末尾再帰でないためコンパイルエラーになる。
@tailrec アノテーションをとっておくとコンパイルは通り
リストの要素数がある程度少なければ動作するが、リストの要素数
増えるとスタックオーバーフロー(java.lang.StackOverflowError)が発生する。

$ scalac test2.scala
test2.scala:8: error: could not optimize @tailrec annotated method sum: it contains a recursive call not in tail position
      case e :: es => e + sum(es)
                        ^

list のチェックは .. match { case .. => .. } 式のパターンマッチを使用。
case Nil は空のリストのケースで、case e :: es *1 は e が list.head で es が list.tail になりリストが先頭要素とそれ以外の要素のリストの形式になる場合をあらわしている。

list match {
  case Nil => acc
  case e :: es => sum(acc + e, es)
}

if (list.isEmpty) acc else sum(acc + list.head, list.tail)

と同じ処理になる。
上の if .. else .. 式に書き換えた Test1 オブジェクト

import scala.annotation.tailrec

object Test1 {
  @tailrec
  def sum(acc: Int, list: List[Int]): Int = {
    print(acc + " ")
    if (list.isEmpty) acc else sum(acc + list.head, list.tail)
  }

  def main(args: Array[String]) = {
    val list = List(1, 2, 3, 4, 5, 6, 7)
    val result = sum(list.head, list.tail)
    println("result: " + result)
  }
}

末尾再帰が最適化される条件

  • 自分自身を呼び出している末尾再帰関数
  • 再帰関数がオーバーライドで変更されないメソッド(関数) *2

*1: ::(e, es) と同じ

*2:メソッドが final か private か、または メソッドのクラスが final

関数型言語を採用する時の壁

副作用のない処理、パターンマッチ、ケースクラス、ファーストクラス関数、高階関数クロージャ、カリー化、関数部分適用、末尾再帰、アクター、並列処理、型推論.. それぞれが関連しあって簡潔なコードが書けて、副作用によるバグも減少できるメリットがあるのですが、現場では、なかなか関数型言語へ移行できない壁があったりします。
ネックになるところは

関数型言語導入のとっかかりとして

というような感じで徐々に現場で認知してもらって、
関数型言語へフェードインできると良いのですが..

そして、関数型言語の勉強をちまちまと継続

暗黙の引数 (implicit parameters)

引数の型に合わせた暗黙の値を指定しておくと、関数(メソッド)呼び出し時の引数を省略できる。
関数定義の引数リストで引数名の前に implicit を付けると暗黙の引数が適用される。
引数を省略しないときは関数呼び出し時の指定値がそのまま渡される。
implicit は引数リスト内の先頭の引数以外には指定できないが、これは個別の引数への指定ではなく
引数リスト全体に適用されている。
デフォルト引数によく似ているが、デフォルト引数は、関数の定義でデフォルト値を指定する
のに対して暗黙の引数は呼び出し元で引き数の暗黙の値を定義する。
例では、一般的な Int と Long にマッチさせる暗黙の引数を定義しているが、通常は個別に型を
作って偶然の型の一致が起こらないようにするのが良さそうである。

scala> def foo(a: Int, implicit b: Long): Long = a * b
<console>:1: error: identifier expected but 'implicit' found.
       def foo(a: Int, implicit b: Int) = a * b
                       ^
scala> def foo(implicit a: Int, b: Long): Long = a * b
foo: (implicit a: Int, implicit b: Long)Long

scala> implicit val defValue = 10
defValue: Int = 10

scala> implicit val defLValue = 20L
defLValue: Long = 20

scala> foo
res1: Long = 200

scala> foo(3, 4)
res2: Long = 12

複数の引数リストをとる関数の最後尾の引数リスト以外に暗黙の引数は適用できない。

scala> def bar(a: Short)(implicit b: Long, c: Int): Long = a + b * c
bar: (a: Short)(implicit b: Long, implicit c: Int)Long

scala> bar(5)
res3: Long = 205

scala> def foo2(implicit a: Int)(b: Int): Int = a * b
<console>:1: error: '=' expected but '(' found.
       def foo2(implicit a: Int)(b: Int): Int = a * b
                                ^

scala> def foo3(a: Int)(implicit b: Long)(c: Short): Long = a * b + c
<console>:1: error: '=' expected but '(' found.
       def foo3(a: Int)(implicit b: Long)(c: Short): Long = a * b + c
                                         ^

scala> def foo4(implicit a: Int)(implicit b: Long): Long = a * b
<console>:1: error: '=' expected but '(' found.
       def foo4(implicit a: Int)(implicit b: Long): Long = a * b
                                ^

呼び出した関数で適用した暗黙の引数の型の implicit val 定義が
スコープ内になければエラーになる。

scala> def bar2(a: Int)(implicit b: Long, c: Short): Long = a + b * c
bar7: (a: Int)(implicit b: Long, implicit c: Short)Long

scala> bar2(7)
<console>:11: error: could not find implicit value for parameter c: Short
       bar2(7)

scala> implicit val defSValue: Short = 5
defSValue: Short = 5

scala> bar2(7)
res8: Long = 107

呼び出した関数で適用した暗黙の引数の型のimplicit val の定義がスコープ内で
2つ以上ある場合もエラーになる。

scala> implicit val defs2: Short = 3
defs2: Short = 3

scala> bar2(7)
<console>:13: error: ambiguous implicit values:
 both value defSValue in object $iw of type => Short
 and value defs2 in object $iw of type => Short
 match expected type Short
       bar2(7)
           ^

あるスコープで、型に定義している暗黙の値を確認するには Predef の implicitly[型] メソッドを使う。

scala> implicitly[Int]
res10: Int = 10

scala> implicitly[Long]
res11: Long = 20

Predef の implicitly の 定義

def implicitly[T](implicit e: T): T = e