ABC116 D - Various Sushi
問題
解法
方針
まず、おいしさ基礎ポイントを最大化するだけならば、おいしさが大きい方から順にK個の寿司を食べるだけでよいです。 単純においしさ基礎ポイントを最大化する食べ方をベースに、種類ボーナスポイントを利用して、満足ポイントをさらに増やせないかを考えます。
種類ボーナスポイントを利用して満足ポイントを増やせるか
例として、以下のケースを考えます。
7 4 1 1 2 3 3 2 3 4 4 7 4 6 4 5
まず、おいしさ基礎ポイントを最大化することを考えます。各寿司をおいしさ基礎ポイントに関して降順でソートします。
4 7 4 6 4 5 3 4 2 3 3 2 1 1
K=4なので、おいしさ基礎ポイントを最大化するならば、ネタとおいしさが(4, 7), (4, 6), (4, 5), (3, 4) の寿司を食べればよいことがわかります(ソートした寿司を先頭から食べればよい)。 このとき、おいしさ基礎ポイントは7+6+5+4=22, 種類ボーナスポイントは2*2=4となります(ネタ3, 4を食べるため)。 したがって、満足ポイントは22+4=26です。
次に、寿司(2, 3)を食べ、種類ボーナスポイントを増やせないか考えます。 寿司(2, 3)を食べる場合、(4, 7), (4, 6), (4, 5), (3, 4) のいずれかの寿司を諦める必要があります。 (3, 4)を諦めると、ネタ2の代わりにネタ3の寿司がなくなることになるので、種類ボーナスポイントは増えません。 したがって、それ以外の寿司を諦めることになります。 どの寿司を諦めるかですが、これは単純においしさが最小の寿司を諦めればよいでしょう(わざわざ、おいしさの高い寿司を諦める必要はありません)。 この場合、(4, 5)ですね。 以上より、ネタ2も食べるならば、(4, 7), (4, 6), (3, 4), (2, 3)を食べることになります。 このとき、おいしさ基礎ポイントは7+6+4+3=20, 種類ボーナスポイントは3*3=9となります。 したがって、満足ポイントは20+9=29です。最初の食べ方より3ポイント増えました。
次に、寿司(3, 2)ですが、これを食べても種類ボーナスポイントは増えないのでスルーします。
最後に、寿司(1, 1)を食べ、種類ボーナスポイントを増やせないか考えます。 寿司(1, 1)を食べる場合、(4, 7), (4, 6), (3, 4), (2, 3) のいずれかの寿司を諦める必要があります。 (3, 4), (2, 3)を諦めると、ネタ1の代わりにネタ2またはネタ3の寿司がなくなることになるので、種類ボーナスポイントは増えません。 したがって、それ以外の寿司を諦めることになります。先程と同様の議論により、(4, 6)ですね。 以上より、ネタ1も食べるならば、(4, 7), (3, 4), (2, 3), (1, 1)を食べることになります。 このとき、おいしさ基礎ポイントは7+4+3+1=15, 種類ボーナスポイントは4*4=16となります。 したがって、満足ポイントは15+16=31です。(2, 3)を追加する食べ方より2ポイント増えました。
上記例を一般化しましょう。 満足ポイントを最大化する候補は以下に絞られます。
- 1: おいしさ基礎ポイントが高い寿司から順にK個選ぶ食べ方
- 2: 1の食べ方から以下を変更した選び方
- 追加: 1で食べていないネタのうち、おいしさが最大の寿司
- 削除: 1で2個以上選んでいる寿司のうち、おいしさが最小の寿司
- 3: 2の食べ方から以下を変更した選び方
- 追加: 2で食べていないネタのうち、おいしさが最大の寿司
- 削除: 2で2個以上選んでいる寿司のうち、おいしさが最小の寿司
- ...
ACしたコード
reduce
を使って解を表現したかったのですが、各要素をiterateするたびに集合やリストを生成すると間に合わないので、手続き的なコードにしました。
from operator import itemgetter from functools import reduce N, K = map(int, input().split()) t, d = zip(*(map(int, input().split()) for _ in range(N))) fn = itemgetter(0) fd = itemgetter(1) ss = sorted(zip(t, d), key=fd, reverse=True) netas = set() duplicates = [] for s in ss[:K]: if s[0] in netas: duplicates.append(s[1]) else: netas.add(s[0]) p = sum(map(fd, ss[:K])) + len(netas)**2 ans = p for s in ss[K:]: if s[0] not in netas and duplicates: netas.add(s[0]) duplicate = duplicates.pop() p += (s[1] - duplicate) + 2 * len(netas) - 1 ans = max(ans, p) print(ans)