024号文書

主にプログラミング

PythonistaがRustはじめました#007 -- クロージャは便利

極限残業ラッシュにより、10日ぶりの記事になっちゃいました。 来週も厳しい一週間になりそうですが、Streakは切らさないようにしたいです。

CADDi 2018-C

  • 入力: 正の整数N, P
  • 出力: a1 a2 ... aN = P を満たす数列 a について、 a の最大公約数の最大値

例えば、(N, P) = (2, 720) ならば 12を出力すればよいです(a1 = 60, a2 = 12 とすることで P = a1 * a2 であり、 gcd(a1, a2) = 12を満たす。これより最大公約数が大きくなるようなaは存在しない)。

https://atcoder.jp/contests/caddi2018/tasks/caddi2018_a

方針

300点問題だけあって、straight forwardな解法では解けないです。 ポイントは素因数分解です。 2以上sqrt(P)以下の整数でPを割り続けることで、P を 2^P2 * 3^P3 * ... の形に表したときのP2, P3, ... を求めることができます。 そして、各素数について、P2, P3, ... を a1, a2, ... aNになるべく等分割することで、gcd(a1, a2, ..., aN) を最大化できます。

数え上げの際に便利なfold

あとは実装するだけです。 ループをゴリゴリ回して解いてもいいのですが、Rustは高階関数も豊富なので、それを使って解いてみましょう。 今回使うのはIteratorfold メソッドです。 Pythonでいうreduceですね。 例えば、1+2+...+N を以下の式で表現できます。

(1..=N).fold(0, |acc, i| acc + i)

非常に簡潔ですね!!(和なのでsum使えばもっと綺麗なのですが、任意の演算について上記のように書けることがポイントです) Pythonより綺麗だと思います(Pythonだと reduce(lambda acc, i: acc + i, range(1, N + 1), 0) ですね。lambdaがだるい)。 おっと、初めて出てきた記法が二つあるので、それらを解説していきます。

Range

(1..=N) はRangeと呼ばれるオブジェクトです。 Pythonでいう range(1, N + 1) ですね。 ちなみに、(a..b) だと a, a + 1, ..., b - 1 のように半開区間となります。

クロージャ

今日の本題?です。 |acc, i| acc + iクロージャと呼ばれるものです。 クロージャに関する詳しい説明は公式ドキュメント https://doc.rust-jp.rs/book/second-edition/ch13-01-closures.html を読めばよいです。 ざっくりと、クロージャを使う上でのポイントを挙げてみます。

最後がかなりお気に入りです。Pythonだと lambda t, c: t[0] * t[1] + c のように書く必要があり、非常につらい気持ちになります。 foldを適用する際、(解, 補助累積値)のようなタプルを畳み込むケースがかなり多いので、そのときは気持ちよく束縛していきましょう。

コード

素因数分解の指数部を求める関数をfとしています。

fn f(p: i64, i: i64) -> i64 {
    if p % i == 0 {
      1 + f(p / i, i)
    } else {
      0
    }
}

fn main() {
    // 入力
    let (N, P) = get!(i64, i64);

    // 素因数分解した結果を用いて解を求める
    let (ans, _) = if N == 1 {
        (P, 1)
    } else {
        let Q = ((P as f64).sqrt().ceil() as i64) + 1;
        (2..Q).fold(
            (1, P),
            |(acc, p), i| {
                let k = f(p, i);
                (
                    acc * i.pow((k / N) as u32),
                    p / i.pow((k as u32))
                )
            }
        )
    };

    // 出力
    println!("{}", ans);
}