2016年1月14日木曜日

【POH7】階乗 n! の最下位から続く0を除いた値の下9けたの出力

みなさんこんにちは。

去年12月に Paiza さんの POH7 に参加して、Gigazine さんに捕捉されて「猛者」認定されてしまったエントリを書きましたが、イベントのAmazonギフト券応募期間が終わったので解説を書いてみたいと思います。今回取り上げるのは「水着」の問題
入力された自然数 n (最大1000000)の階乗 n! での、最下位から続く0を除いた値の下9けたを出力せよ
という問題です。「nの階乗」というのは
n! = 1 x 2 x 3 x ... x n
で計算できる数です。プログラミングでは一般的に32ビットや64ビットのメモリでの表現を整数型として計算することが多いのですが、階乗を真面目に計算していくとすぐに表現できなくなります。
21! = 51090942171709440000 > 2^64
複数の整数型をまとめて、より大きな数を多倍長整数として表現し計算することも可能ですが、最大で 1000000!を計算する必要があり、ひたすら大きい数に対してひたすら掛け算しないといけません。桁数が大きくなればなるほど、かけ算の計算量も増えるわけですが、この問題ではそれだけ大きい数を正直に計算する必要はなく「最下位けたから続く0を除き、その数の下位9けた」を出力すればよいとされています。

「最下位けたから続く0を除く」というのは「10で割った余りが0であるかぎり10で割り算する」ということで、「下位9けた」というのは「10^9で割った余り」なので、新しい数をかけ末尾の0を削り10^9で割った余り求めるということを1〜nまで繰り返せば計算できそうです。Python2で書くなら

#! /usr/bin/env python2
n = int(raw_input())
r = 1
for i in xrange(1,n+1):
    r *= i
    while r % 10 == 0:
        r /= 10
    r %= pow(10,9)
print r

といった形で計算できそうです。がしかし、実はこれでは正しく計算できません。例えば n=25 の場合を考えると、 
24! = 620448401733239439360000
この問題での出力: 323943936
25! = 15511210043330985984000000
この問題での出力: 330985984 
ということになるのですが、n=24での出力に25をかけると
323943936 x 25 = 8098598400
と、除外しないといけない最下位の0が出てきてしまいます。n=24のときの結果に対して9桁よりも多くの桁を準備しないと正しい数値が計算できません。
323943936 x 25 = 8098598400 ... NG
3323943936 x 25 = 83098598400 ... NG
73323943936 x 25 = 1833098598400 ... OK
実践的に考えると、n の数に応じて「下位9けたよりも多く」の数字を準備しておくのが、とりあえずは現実的なところかと思います。なんとなく、log10(n)+1桁分を9桁に加えて準備しておけばよさそうですが、証明したわけではありません。

この問題を真面目に解くためには「ルジャンドルの定理」というのを使うことができます。アイディアとしては、階乗の数を素因数分解するというものです。

定理の説明はこちらこちらを見ていただくとして、25! を素因数分解した時に出てくる 2 の個数で考えると、 floor(25/2) = 12, floor(25/4) = 6, floor(25/8) = 3, floor(25/16) = 1 の総計22個となります。事前にnまでの素数列を準備しておけば素因数分解をするのはそれほど大変な計算とはなりません。

与えられた n について n! を素因数分解の形で表現することができたら、「最下位けたから続く0を除き、その数の下位9けた」を求めることになりますが、いうまでもなく 2 x 5 = 10 ですから、素因数分解した結果の2と5の個数を調整(2の個数と5の個数のうち少ない方を調べて、2と5の個数それぞれから引く)すれば最下位けたの連続する0を取り除くことができます。

残ったものを素因数分解の形式から数字に戻せばよいのですが、0が最下位に出る可能性は排除できているので、順次掛け算をして10^9 で割った剰余を計算することでほしい結果を得ることができます。

#! /usr/bin/env python2

def primes(n):
    # n までの素数列を返す関数を実装(記述は省略します)

n = int(raw_input())

ps = primes(n)
idx = []

for p in ps:
    m = p
    e = 0
    while m<=n:
        e+=int(n/m)
        m*=p
    idx.append(e)

if len(idx)<3:
    idx.extend([0,0,0])

ten = min(idx[0],idx[2])
idx[0] -= ten
idx[2] -= ten

mi = map(lambda (a,b):pow(a,b,pow(10,9)),zip(ps,idx))
print reduce(lambda a,b:a*b%pow(10,9),mi)

POH7 での問題ではnの上限(最大1000000)が決まっているので、ここまでやる必要はかならずしもないのですが、こういったアプローチを論理的に展開できるようになりたいものだと思いました。