srupのメモ帳

競プロで解いた問題や勉強したことを記録していくメモ帳

Codeforces #361 (Div.2) D. Friends and Subsequences

問題

問題概要

[l, r]の区間を考えた時に, al, .. , ar の最大値と bl, .. , br の最小値が一致する区間の総数を求める.

解法

区間の左端を固定して, 右端を伸ばしていくと, aの最大値は単調増加, bの最小値は単調減少していくので, 区間を伸ばしていくとどこかで一致する(条件を満たすものがあれば). 単調性を利用すると, 左端を決めた時に, 右端となる場所を二分探索でもとめることができる. ひとつの左端に対して, 右端となりうる場所は連続して複数ある場合があるので, lower_bound と upper_bound みたいな感じで, 2つ求める. またa, bのn+1番目の要素として, 番兵のようなものをいれておかないとすべての値が同じ時になぜかばぐったのでいれた. (よくわからない)
あとは区間の最大値最小値をどのように求めるかだが, segtreeでできる. segtreeの場合, 区間の最大値最小値を計算する際に, lognかかるため, 全体でO(nlognlogn)となりだいぶ危ない感じになる. 実際蟻本の実装方法でやると, TLEしたので, 更新処理を高速化したらAC.
また今回のように, 値の更新が途中でない場合, ダブリングのような感じで,
table[i][k] := [i, i + 2k)の最大値/最小値
というような配列を先にO(nlogn)で構築しておくことで, 区間の最大値/最小値を計算するのがO(1)で可能になり, 全体として, O(nlong)でできるようになる.

ミス

なし.

コード

高速化したsegtree

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vint;
typedef pair<int,int> pint;
typedef vector<pint> vpint;
#define rep(i,n) for(int i=0;i<(n);i++)
#define REP(i,n) for(int i=n-1;i>=(0);i--)
#define reps(i,f,n) for(int i=(f);i<(n);i++)
#define each(it,v) for(__typeof((v).begin()) it=(v).begin();it!=(v).end();it++)
#define all(v) (v).begin(),(v).end()
#define eall(v) unique(all(v), v.end())
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define chmax(a, b) a = (((a)<(b)) ? (b) : (a))
#define chmin(a, b) a = (((a)>(b)) ? (b) : (a))
const int MOD = 1e9 + 7;
const int INF = 1e9;
const ll INFF = 1e18;

int n;
int a[200010], b[200010];

template <class T> //T : dat[]の中身の型
class segtree_max{
public:
    int n;
    vector<T> dat;
    segtree_max(int n_): n(n_){ //n_要素数
        n = 1;
        while(n < n_) n *= 2;
        dat.resize(n * 2, -INF); //(1) 初期値を最小に -INFかも
    }
    void update(int k, T val){ // k番目の値(0-indexed)を val に変更
        for (dat[k += n] = val; k > 0; k >>= 1){ // kを含む区間のインデックスを下から順に列挙
            dat[k>>1] = max(dat[k], dat[k ^ 1]); // (2) 区間の最大値で更新
        }
    }
    T query(int l, int r){
        T ret = -INF; //(3) 最大値に関係ない値 -INFかも
        for (l += n, r += n; l < r; l >>= 1, r >>= 1){
            if(l & 1) ret = max(ret, dat[l++]); //(4) 区間の最大値で更新
            if(r & 1) ret = max(ret, dat[--r]); //(4) 区間の最大値で更新
        }
        return ret;
    }
};

template <class T> //T : dat[]の中身の型
class segtree_min{
public:
    int n;
    vector<T> dat;
    segtree_min(int n_): n(n_){ //n_要素数
        n = 1;
        while(n < n_) n *= 2;
        dat.resize(n * 2, INF); //(1) 初期値を最大に
    }
    void update(int k, T val){ // k番目の値(0-indexed)を val に変更
        for (dat[k += n] = val; k > 0; k >>= 1){ // kを含む区間のインデックスを下から順に列挙
            dat[k>>1] = min(dat[k], dat[k ^ 1]); // (2) 区間の最大値で更新
        }
    }
    T query(int l, int r){
        T ret = INF; //(3) 最小値に関係ない値
        for (l += n, r += n; l < r; l >>= 1, r >>= 1){
            if(l & 1) ret = min(ret, dat[l++]); //(4) 区間の最小値で更新
            if(r & 1) ret = min(ret, dat[--r]); //(4) 区間の最小値で更新
        }
        return ret;
    }
};

int main(void){
    scanf("%d", &n);
    rep(i, n)scanf("%d", &a[i]);
    rep(i, n)scanf("%d", &b[i]);
    segtree_max<int> sega(n + 1);
    segtree_min<int> segb(n + 1);

    rep(i, n) sega.update(i, a[i]);
    sega.update(n, INF + 1); //右端を追加
    rep(i, n) segb.update(i, b[i]);
    segb.update(n, -1); //右端を追加

    ll ret = 0;
    rep(i, n){ // 左端
        int l = i, r = n + 1;
        int ansl, ansr;
        while(r - l > 1){ // lower_bound
            int m = (l + r) / 2;
            auto da = sega.query(i, m); //単調増加
            auto db = segb.query(i, m); //単調減少
            if(da < db) l = m;
            else r = m;
        }
        if(sega.query(i, l + 1) == segb.query(i, l + 1))ansl = l;
        else continue;

        l = ansl, r = n + 1;
        while(r - l > 1){ //upper_bound
            int m = (l + r) / 2;
            auto da = sega.query(i, m); //単調増加
            auto db = segb.query(i, m); //単調減少
            if(da < db + 1) l = m;
            else r = m;
        }
        ansr = l;
        //[ansl, anrl], .. , [ansl, ansr - 1] までが答えの区間
        ret += ansr - ansl;
    }
    printf("%lld\n", ret);
    return 0;
}

SparseTable

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vint;
typedef pair<int,int> pint;
typedef vector<pint> vpint;
#define rep(i,n) for(int i=0;i<(n);i++)
#define REP(i,n) for(int i=n-1;i>=(0);i--)
#define reps(i,f,n) for(int i=(f);i<(n);i++)
#define each(it,v) for(__typeof((v).begin()) it=(v).begin();it!=(v).end();it++)
#define all(v) (v).begin(),(v).end()
#define eall(v) unique(all(v), v.end())
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define chmax(a, b) a = (((a)<(b)) ? (b) : (a))
#define chmin(a, b) a = (((a)>(b)) ? (b) : (a))
const int MOD = 1e9 + 7;
const int INF = 1e9;
const ll INFF = 1e18;

int n;
int a[200010], b[200010];

template <class T> //T : table[][]の中身の型
class SparseTable_max{
public:
    int N, M; //table[N][M]
    // table[i][k] := [i, i + 2^k)の最大値
    vector<vector<T>> table;
    template<class S> SparseTable_max(int n, S &val): N(n){ // O(nlogn)
        M = 32 - __builtin_clz(N); // M - 1 <= logN < M
        table.resize(N, vector<T>(M));
        for (int i = 0; i < N; ++i){ // [i, i + 1)までの区間の最大値
            table[i][0] = val[i];
        }
        for (int k = 0; k < M - 1; ++k){ // [i, i + 2^(k+1))の区間を計算
            for (int i = 0; i + (1<< k) < N; ++i){
                // iから2^(k+1)の長さの区間の最小値を2^kの長さの区間の最大値を利用して求める
                table[i][k + 1] = max(table[i][k], table[i + (1 << k)][k]); // (1)最大値
            }
        }
    }
    T query(int l, int r){ // O(1) [l, r) の間の最大値
        int k = 31 - __builtin_clz(r - l); //区間の長さの半分以上の値 (k<= r - l < k + 1)
        return max(table[l][k], table[r - (1 << k)][k]); // (2) 最大値
    }
};

template <class T> //T : table[][]の中身の型
class SparseTable_min{
public:
    int N, M; //table[N][M]
    // table[i][k] := [i, i + 2^k)の最小値
    vector<vector<T>> table;
    template<class S> SparseTable_min(int n, S &val): N(n){ // O(nlogn)
        M = 32 - __builtin_clz(N); // M - 1 <= logN < M
        table.resize(N, vector<T>(M));
        for (int i = 0; i < N; ++i){ // [i, i + 1)までの区間の最小値
            table[i][0] = val[i];
        }
        for (int k = 0; k < M - 1; ++k){ // [i, i + 2^(k+1))の区間を計算
            for (int i = 0; i + (1<< k) < N; ++i){
                // iから2^(k+1)の長さの区間の最小値を2^kの長さの区間の最小値を利用して求める
                table[i][k + 1] = min(table[i][k], table[i + (1 << k)][k]); // (1)最小値
            }
        }
    }
    T query(int l, int r){ // O(1) [l, r) の間の最小値
        int k = 31 - __builtin_clz(r - l); //区間の長さの半分以上の値 (k<= r - l < k + 1)
        return min(table[l][k], table[r - (1 << k)][k]); // (2) 最小値
    }
};

int main(void){
    scanf("%d", &n);
    rep(i, n)scanf("%d", &a[i]);
    a[n] = INF + 1;
    rep(i, n)scanf("%d", &b[i]);
    b[n] = -1;
    SparseTable_max<int> sega(n + 1, a);
    SparseTable_min<int> segb(n + 1, b);

    ll ret = 0;
    rep(i, n){ // 左端
        int l = i, r = n + 1;
        int ansl, ansr;
        while(r - l > 1){ // lower_bound
            int m = (l + r) / 2;
            auto da = sega.query(i, m); //単調増加
            auto db = segb.query(i, m); //単調減少
            if(da < db) l = m;
            else r = m;
        }
        if(sega.query(i, l + 1) == segb.query(i, l + 1))ansl = l;
        else continue;

        l = ansl, r = n + 1;
        while(r - l > 1){ //upper_bound
            int m = (l + r) / 2;
            auto da = sega.query(i, m); //単調増加
            auto db = segb.query(i, m); //単調減少
            if(da < db + 1) l = m;
            else r = m;
        }
        ansr = l;
        //[ansl, anrl], .. , [ansl, ansr - 1] までが答えの区間
        ret += ansr - ansl;
    }
    printf("%lld\n", ret);
    return 0;
}