srupのメモ帳

競プロで解いた問題や勉強したことを記録していくメモ帳

aoj 0529 - Darts

問題

問題概要

探索問題。単純に全探索してしまうと、TLE。

解法

まず、単純に考える。全てのパターンを考えれば良いので、以下のコードのように単純に4つの矢が刺さるマスを全探索すればいい。

int main(void){
    while(1){
        int n, m; cin >> n >> m;
        if(n == 0 && m == 0) return 0;
        vector<int> p(n);
        rep(i, n) cin >> p[i];
        p.push_back(0);//矢を使わないを入れる
        int ans = 0;
        //計算量(n^4)
        rep(i, n + 1)rep(j, n + 1)rep(k, n + 1)rep(l, n + 1){
            int tmp = p[i] + p[j] + p[k] + p[l];
            if(tmp <= m) ans = max(ans, tmp);
        }
        printf("%d\n", ans);
    }
    return 0;
}

上のコードでは計算量が(n4)となってしまい、とても大きくなってしまう。そこで、4本目の矢を2分探索で求めることにする。3本目までの合計をtmpとすると、4本目の矢は(m - tmp)以下のもので、最大のものであれば良い。よってコードが以下のようになる。

int main(void){
    while(1){
        int n, m; cin >> n >> m;
        if(n == 0 && m == 0) return 0;
        vector<int> p(n);
        rep(i, n) cin >> p[i];
        p.push_back(0);//矢を使わないを入れる
        sort(p.begin(), p.end());
        int ans = 0;
        //計算量(n^3*logn)
        rep(i, n + 1)rep(j, n + 1)rep(k, n + 1){
            int tmp = p[i] + p[j] + p[k];
            int aim = m - tmp;
            if(aim < 0) continue;
            tmp += *(upper_bound(p.begin(), p.end(), aim) - 1);
            if(tmp <= m) ans = max(ans, tmp);
        }
        printf("%d\n", ans);
    }
    return 0;
}

上記のコードだと、4つ目の矢を2分探索することで、計算量が(n3*logn)まで落ちているが、n <= 1000の制約ではn3が残っている時点で厳しい。よって、矢2本分で取れる得点をまず全列挙する。2本分の得点のtmpとすると、残り2本でm - tmp以下となる2本分の矢の合計の中で最大のものを探せば良いというんことになる。ここでのポイントは、全部でn通りであったポイントをn2通りにすることで、計算量を落とすことができるということだ。 以下のコードの計算量は(n2 * logn)である。

ミス

2分探索の良問。蟻本25pに類題あり。

コード

#include <iostream>
#include <vector>
#include <cstdio>
#include <algorithm>
using namespace std;
#define rep(i,n) for(int i=0;i<(n);i++)

int main(void){
    while(1){
        int n, m; cin >> n >> m;
        if(n == 0 && m == 0) return 0;
        vector<int> p(n);
        rep(i, n) cin >> p[i];
        p.push_back(0);//矢を使わないを入れる
        vector<int> pp((n + 1) * (n + 1));
        int cnt = 0;
        rep(i, n + 1)rep(j, n + 1){
            pp[cnt] = p[i] + p[j];
            cnt++;
        }
    
        sort(pp.begin(), pp.end());
        int ans = 0;
        //計算量(n^2*logn)
        rep(i, (n + 1) * (n + 1)){
            int tmp = pp[i];
            int aim = m - tmp;
            if(aim < 0) continue;
            tmp += *(upper_bound(pp.begin(), pp.end(), aim) - 1);
            if(tmp <= m) ans = max(ans, tmp);
        }
        printf("%d\n", ans);
    }
    return 0;
}