Haskellでリストのソート(その弐)

今日は早く起きれました・・・。っても9時ですが。
ケータイのアラームを目覚まし時計代わりにしているのですが、これが効くんですよ。なんか絶望的な気持ちにさせる音で唸るんです。

さて、今日もHaskellでリストのソートを行うプログラムを書いてみましょう。
昨日はクイックソートだったので、実用性重視繋がりということで、今日はマージソートを書いてみましょう。


マージソートは全ての要素を分割した後、それらをマージしてソートするというアルゴリズムです。



merge a@(x:xs) b@(y:ys)
    | x < y = x:merge xs b
    | otherwise = y:merge a ys
merge (x:xs) = x:xs
merge
(y:ys) = y:ys
merge _ _ =

まずはマージ処理です。リストを二つ受け取って一つのリストにマージします。
やはりHaskellだと簡単に書けちゃいますね。アズパターンも便利。

次に、リストを半分に分割する処理です。


half xs = let (ls, rs, _) = (half' xs 0) in (ls, rs) where
    half' n = (, , div n 2)
    half' (x:xs) n = case (half' xs (n+1)) of
        (ys, zs, 0) -> (x:ys, zs, 0)
        c@(_, _, half) -> if n == half then (, x:xs, 0) else c

一旦リストの末尾までいって長さを確かめてから、半分のところでリストを分割します。

で、ここまで来たら後は楽チン。


msort =
msort [x] = [x]
msort xs = merge (msort ls) (msort rs) where
    (ls, rs) = half xs

よく出来ました。これで完成です。手続き型言語で書かれたものをまんまHaskellに直したような感じですね。
!!! でも待ってください、これ、滅茶苦茶遅そうじゃないですか?
だって見てください、分割処理の為に毎回リストの末尾まで行かなくちゃならないのはアホらしすぎません?

これはいけない、ということで改良しましょう。


import List

msort_plus xs = msort' 1 xs where
    len = length xs
    msort' n xs
        | n >= len = xs
        | otherwise = msort' (n*2) (msort'' n xs) where
            msort'' n =
            msort'' n xs = (merge ls ms) ++ (msort'' n rs) where
                (ls, ls') = splitAt n xs
                (ms, rs ) = splitAt n ls'

今度は何度もリストの長さを意識しないで済むようになりました。
しかしコレにもまだ無駄がありそうです・・・。よく見てください、これ、
msort' が呼ばれる毎に、splitAtで分割・再結合を繰り返してます。
splitAtはリストを手繰るので、分割する長さが増えれば増えるほど遅くなります。
この場合、msort' 毎にsplitAtでリスト全体を手繰ることになるので、nlog2 n(多分)程の計算量を食うことになります。


default(Int)

msort_plusPlus xs = msort' xs 1 where
    msort' (x:xs) ys n = msort' xs (f [x] ys n) (n+1) where
        f x (y:ys) n
            | even n = f (merge x y) ys (div n 2)
            | otherwise = x:y:ys
        f x _ _ = [x]
    msort' _ (y:ys) _ = foldl (\ a b -> merge a b) y ys

これならいいでしょう!分割・マージしたリストはリストに(つまりリストのリストにする)保持しています。
そして、数値演算でマージ相手となるリストを調べるようにしました。

では早速ベンチマークです。
昨日は色々最適化しましたがさっぱり早くなりませんでしたが・・・今日はどうでしょう。
最後のやつは結構自信あるんですがね!

以下がそのコードです。
(ソート処理の定義と、ベンチマーク用のコードは紙面節約のため省略。)

main =
   do {
         print "Hello, I am a pen.";
         (seed, _) <- getClockTimePrim;

         {- リストの長さが100の時 -}
         start (msort) 50 100 seed;
         start (msort_plus) 50 100 seed;
         start (msort_plusPlus) 50 100 seed;

         {- リストの長さが1000の時 -}
         start (msort) 50 1000 seed;
         start (msort_plus) 50 1000 seed;
         start (msort_plusPlus) 50 1000 seed;
       }

実行結果:
73042
72024
77633

949972
913993
752711

・・・まぁ、こんなもんでしょう。思ったより早くなってませんね・・・。悲しい。
実は他にmsort_plusPlusを改良したものも作ったのですが、複雑になりすぎて余計遅くなってしまいました・・・。
マージソートはここで終わり。

(すみません、ひょっとしたらはてな記法の為にソースコードの一部が壊れているかもしれません。)


# at 2006/3/24 14:23

平均の計算が間違っていたので修正して再計算。その結果↓

リストの長さ100
81520 - msort
82320 - msort_plus
69500 - msort_plusPlus


リストの長さ1000
969200 - msort
935140 - msort_plus
764700 - msort_plusPlus