024号文書

主にプログラミング

ABC117 D - XXOR

問題

https://atcoder.jp/contests/abc117/tasks/abc117_d

解法

二進リテラルを0bXXXXのように表すことにします。 例えば、0b101は十進表記で5です。

方針

まず、Aの各要素が1以下、かつ、K=1の場合を考えてみましょう。 例えば、A=[1, 1, 0, 1]の場合です。この場合、Xは0または1です。それぞれの場合について計算してみます。

  • f(0)=(0 XOR 1) + (0 XOR 1) + (0 XOR 0) + (0 XOR 1) = 1 + 1 + 0 + 1 = 3
  • f(1)=(1 XOR 1) + (1 XOR 1) + (1 XOR 0) + (1 XOR 1) = 0 + 0 + 1 + 0 = 1

したがって、fはX=0のとき最大値3をとります。 同様に、Aの各要素が1以下、かつ、K=1の場合の最適解は以下のように求められます。

  • Aの要素について、1が多い場合: X=0とする
  • Aの要素について、0が多い場合: X=1とする

ちなみに、0と1が同数の場合、Xはどちらでもよいです。

次に、Aの各要素が3以下(二進表記で2桁で表現できる場合)、かつ、K=3(二進表記で2桁とも1)の場合を考えてみましょう。 例えば、A=[2, 3, 2, 0]の場合です。 二進表記での桁数が増えましたが、一桁ずつ考えれば大丈夫です(各桁は互いに作用しないため)。 つまり、Aの各要素の1桁目(最上位ビットを1桁目と呼ぶことにします)をみてXの1桁目を、Aの各要素の2桁目をみてXの2桁目を求めればよいです。 例のAを二進表記すると、A=[0b10, 0b11, 0b10, 0b00] です。 1桁目は1の方が多く、2桁目は0の方が多いです。 したがって、X=0b01=1のとき、fは最大値をとります。

さらに一般化し、Aの各要素が(2m) - 1以下、かつ、K=(2m)-1(二進表記で各桁が1)の場合を考えます。 同様に、Aを二進表記した値の各桁の値をみてXの各桁の値を求められます。

ここまで、Kについて二進表記で各桁が1である、という強い仮定をおいていました。 これから、この仮定を外していくのですが、「各桁の0の数と1の数に応じてXの各桁を決める」という方針は一緒です。

Kの一般化

例えば、A=[0, 0, 4], K=4の場合を考えます。 Aを二進表記すると、A=[0b000, 0b000, 0b100]です。 もし、K=7=0b111ならば、X=0b111としたいところです。 しかし、K=4=0b100なので、(ある種、貪欲に)Xの1桁目を1にすると、2桁目以降は0でなければなりません。 f(4)=0b100+0b100+0b000=4+4=8です。 一方、Xの1桁目を我慢して0にすると、2桁目以降は好きな値を割り当てられます。 そこで、2, 3桁目に1を割り当てる、すなわちX=0b011=3とすると、f(3)=0b011+0b011+0b111=3+3+7=13です。 このように、以下のどちらかの戦略により最適化を計る必要があります。

  • 戦略1: Xの1桁目を0にする代わりに、2桁目以降を(無制限に)最適な値とする
  • 戦略2: Xの1桁目を1にする代わりに、以降の一部の桁を0にする(例のようにすべて0になってしまう場合もある)

後者について補足します。 例えば、K=10=0b1010の場合を考えます。 このとき、Xの1桁目を1にすると、2桁目は必ず0です。ただし、3桁目は1にできます。ただし、3桁目も1にすると、4桁目は0です。 このように、2桁目以降に制約が課せられるということです。

戦略1の最適値は「各桁の0の数と1の数に応じてXの各桁を決める」という方針にしたがうだけなので簡単です。 戦略2の最適値の計算が難しいです。 K=10=0b1010をもう一度考えます。Xの1桁目を1にすると、2桁目は0に固定されます。 つまり、X=0b10xxの形になります。 さて、xxを求めればよいのですが、どのように最適値を計算するのでしょうか。 実は、再帰的に同様の問題に直面しています! すなわち、Xの上位2桁は10で固定されているとして、K=0b0010の問題を解けばよいのです。 K=0b0010として、再び(再帰的に)、戦略1と戦略2を比較し、最適解を求めます。 そして、求めた解をxxに当てはまればよいです。

ACしたコード

  • Pythonの整数はメソッド bit_length を持ってます。これを使うと、ビット数を簡単に計算できます
  • あらかじめ、Xの下r桁を(無制限に)決められる場合の最適なXの下r桁を求めています(opt)
from itertools import accumulate
N, K = map(int, input().split())
A = list(map(int, input().split()))
m = K.bit_length()
b = [sum((a >> k) & 1 for a in A) for k in range(m)]
opt = [0] + list(accumulate(1 << k if 2 * b[k] < N else 0 for k in range(m)))


def f(x):
    return sum(x ^ a for a in A)


def g(k, s):
    if k == 0:
        return f(s)
    r = k.bit_length()
    return max(
      f(s + opt[r - 1]),
      g(k - (1 << (r - 1)), s + (1 << (r - 1)))
    )


ans = g(K, 0)
print(ans)