024号文書

主にプログラミング

ABC116 D - Various Sushi

問題

ABC116 D- Various Sushiです。

解法

方針

まず、おいしさ基礎ポイントを最大化するだけならば、おいしさが大きい方から順にK個の寿司を食べるだけでよいです。 単純においしさ基礎ポイントを最大化する食べ方をベースに、種類ボーナスポイントを利用して、満足ポイントをさらに増やせないかを考えます。

種類ボーナスポイントを利用して満足ポイントを増やせるか

例として、以下のケースを考えます。

7 4
1 1
2 3
3 2
3 4
4 7
4 6
4 5

まず、おいしさ基礎ポイントを最大化することを考えます。各寿司をおいしさ基礎ポイントに関して降順でソートします。

4 7
4 6
4 5
3 4
2 3
3 2
1 1

K=4なので、おいしさ基礎ポイントを最大化するならば、ネタとおいしさが(4, 7), (4, 6), (4, 5), (3, 4) の寿司を食べればよいことがわかります(ソートした寿司を先頭から食べればよい)。 このとき、おいしさ基礎ポイントは7+6+5+4=22, 種類ボーナスポイントは2*2=4となります(ネタ3, 4を食べるため)。 したがって、満足ポイントは22+4=26です。

次に、寿司(2, 3)を食べ、種類ボーナスポイントを増やせないか考えます。 寿司(2, 3)を食べる場合、(4, 7), (4, 6), (4, 5), (3, 4) のいずれかの寿司を諦める必要があります。 (3, 4)を諦めると、ネタ2の代わりにネタ3の寿司がなくなることになるので、種類ボーナスポイントは増えません。 したがって、それ以外の寿司を諦めることになります。 どの寿司を諦めるかですが、これは単純においしさが最小の寿司を諦めればよいでしょう(わざわざ、おいしさの高い寿司を諦める必要はありません)。 この場合、(4, 5)ですね。 以上より、ネタ2も食べるならば、(4, 7), (4, 6), (3, 4), (2, 3)を食べることになります。 このとき、おいしさ基礎ポイントは7+6+4+3=20, 種類ボーナスポイントは3*3=9となります。 したがって、満足ポイントは20+9=29です。最初の食べ方より3ポイント増えました。

次に、寿司(3, 2)ですが、これを食べても種類ボーナスポイントは増えないのでスルーします。

最後に、寿司(1, 1)を食べ、種類ボーナスポイントを増やせないか考えます。 寿司(1, 1)を食べる場合、(4, 7), (4, 6), (3, 4), (2, 3) のいずれかの寿司を諦める必要があります。 (3, 4), (2, 3)を諦めると、ネタ1の代わりにネタ2またはネタ3の寿司がなくなることになるので、種類ボーナスポイントは増えません。 したがって、それ以外の寿司を諦めることになります。先程と同様の議論により、(4, 6)ですね。 以上より、ネタ1も食べるならば、(4, 7), (3, 4), (2, 3), (1, 1)を食べることになります。 このとき、おいしさ基礎ポイントは7+4+3+1=15, 種類ボーナスポイントは4*4=16となります。 したがって、満足ポイントは15+16=31です。(2, 3)を追加する食べ方より2ポイント増えました。

上記例を一般化しましょう。 満足ポイントを最大化する候補は以下に絞られます。

  • 1: おいしさ基礎ポイントが高い寿司から順にK個選ぶ食べ方
  • 2: 1の食べ方から以下を変更した選び方
    • 追加: 1で食べていないネタのうち、おいしさが最大の寿司
    • 削除: 1で2個以上選んでいる寿司のうち、おいしさが最小の寿司
  • 3: 2の食べ方から以下を変更した選び方
    • 追加: 2で食べていないネタのうち、おいしさが最大の寿司
    • 削除: 2で2個以上選んでいる寿司のうち、おいしさが最小の寿司
  • ...

ACしたコード

reduce を使って解を表現したかったのですが、各要素をiterateするたびに集合やリストを生成すると間に合わないので、手続き的なコードにしました。

from operator import itemgetter
from functools import reduce
N, K = map(int, input().split())
t, d = zip(*(map(int, input().split()) for _ in range(N)))
fn = itemgetter(0)
fd = itemgetter(1)
ss = sorted(zip(t, d), key=fd, reverse=True)
netas = set()
duplicates = []
for s in ss[:K]:
  if s[0] in netas:
    duplicates.append(s[1])
  else:
    netas.add(s[0])
p = sum(map(fd, ss[:K])) + len(netas)**2
ans = p
for s in ss[K:]:
  if s[0] not in netas and duplicates:
    netas.add(s[0])
    duplicate = duplicates.pop()
    p += (s[1] - duplicate) + 2 * len(netas) - 1
    ans = max(ans, p)
print(ans)