現代の量子力学

主に精進の様子を記録する日記帳です

Codeforces Round #641 (Div. 2) D. Orac and Medians

久しぶりにこどふぉしました。それもこれもDeepLの翻訳のおかげです。

問題のリンク

Problem - D - Codeforces

問題の概要

・数列a が与えられる。以下の操作を繰り返して数列の全ての要素をk にできるか?
(操作)
数列aから連続した部分列Sを選択し、Sの要素全てを、Sの中央値に置き換える。

コンテスト中の考察

・中央値をとるにあたって、数列の要素はk より小さいか、k より大きいか、k そのものであるかという情報のみが必要なので変換してしまう。以下、それぞれ、小、大、k と表す。

・簡単な場合を考えてみる、長い部分列を考えるのはしんどいので短いものから。
・{k, 大} の場合
→{k, k} に変換できるので、{大, 大, ... , k, 大, 大, ... } みたいなやつは全部k に変換できる。
・{k, k, 小} の場合
→{k, k, k} に変換できる。小の部分が大でも同じことができるので一度{k, k} という部分を作れば全ての要素をk にすることができる。

というわけで、中央値がk になる要素数2以上の連続した部分列を一つでも発見できますか?という問題になる。

中央値がk になる連続した部分列を見つけるには

・中央値がk の部分列とは、{小, 小, k, 大, 大}こんな感じ、ここでk 以上の要素は+1、k 未満の要素は-1に置き換えて部分列の総和をとると、+1になっている。または、{小, k, 大, 大}というパターンで、総和は+2である。
・ある数列の連続した部分列の要素の総和は累積和を用いて表すことができる。累積和を S(r) := a _ 0 + a _ 1 + \dots + a _ {r-1} とすると、半開区間[l, r) における部分列{ a _ l,  \dots a _ {r-1} } の要素の和は S(r) - S(l) と表せる。
・つまり、① S(r) - S(l) = 1 or 2, ② r - l >= 2, ③a[i] = k (iは l <= i < r を満たす i) という(l, r) を見つければ良いということである。あとでやりやすいように少しだけ条件を書き換える。
・条件①について、S(r) - S(l) = 3という状況を考えてみる。a[i] = 1 or -1 なので、S(r-1) - S(l) は3 ± 1 である。このようにr を1ずつ減らしていくとS(x) - S(l) は1ずつ変化していく。ここでS(l+2) - S(l) は最大で2 であることから、S(r) - S(l) = 3ならば、S(i) - S(l) = 2, l+2 <= i <= r を満たすようなi が必ず存在することがわかる。

・つまり、① S(r) > S(l) , ② r - l >= 2, ③ a[i] = k (iは l <= i < r を満たす i) という(l, r) を見つければ良いということになる。

方針

・r を昇順に見ていくことにする。条件②, ③を満たすようなl の範囲でS(l) の最小値を求めれば良い。
・セグ木(区間最小)を使いました。

Wrong Answer on pretest 10

・はい
・k 以外の要素を増やしていくという戦略を思いつきませんでした。くやしい。
・k = 3, a = {3, 1, 1, 1, 4, 5} みたいなのはno だと思ったんですが、1 の区間を全部4 に変更すれば、k と大のみという状態になるのでyes なんですね、くそ〜。

・③の条件を、a[i] = k (i は 0 <= i < n ) に代えればいけます。
・改めて、① S(r) > S(l) , ② r - l >= 2, ③ a[i] = k (iは 0 <= i < n を満たす i) を満たす(l, r) が1つでもあればyes、ないならno

以下実装

#include <iostream>
#include <algorithm>
#include <vector>
#define rep(i,n) for(int i=0;i<n;++i)
#define rep1(i,n) for(int i=1;i<=n;++i)
using namespace std;
template<class T>bool chmax(T &a, const T &b) { if(a < b){ a = b; return 1; } return 0; }
template<class T>bool chmin(T &a, const T &b) { if(a > b){ a = b; return 1; } return 0; }
template <typename F,typename T>
struct SegTree{
  // 二項演算merge,単位元identify
  T identity;
  F merge; 
  int size;
  vector<T> dat;
  
  // 二項演算fと単位元idを渡して宣言する
  SegTree(F f,T id):merge(f),identity(id){}

  // データの要素の数nを渡して初期化、sizeはnより大きい2の冪乗
  void init(int n){
    size = 1;
    while(size<=n) size *= 2;
    dat.resize(size*2-1,identity);
  }

  // 配列を渡して0(n)で初期化
  void build(vector<T> vec){
    rep(i,vec.size()) dat[size-1+i] = vec[i];
    dfs(0);
  }
  
  T dfs(int k){
    if(k>=size-1) return dat[k];
    else return dat[k] = merge(dfs(2*k+1),dfs(2*k+2));
  }

  // index kの要素をaに変更
  void update(int k,T a){
    k += size - 1;
    dat[k] = a;
    while(k > 0){
      k = (k-1)/2;
      dat[k] = merge(dat[2*k+1],dat[2*k+2]);
    }
  }

  // 区間[a,b)に対するクエリに答える。(k,l,r)=(0,0,size)
  T query(int a,int b,int k,int l,int r){
    if(r<=a||b<=l) return identity;
    
    if(a<=l&&r<=b) return dat[k]; 
    else return merge(query(a,b,2*k+1,l,(l+r)/2),query(a,b,2*k+2,(l+r)/2,r));
  }
  
  void show(){
    int index = 0;
    int num = 1;
    while(index<size){
      rep(i,num){
    if(dat[i+index]==identity) cout << "e ";
    else cout << dat[i+index] << " ";
      }
      cout << "\n";
      num *= 2;
      index = index*2+1;
    }
  }
  
};

int main()
{
  int t;cin >> t;
  vector<int> n(t), k(t);
  vector<vector<int>> a(t);
  rep(i,t) {
    cin >> n[i] >> k[i];
    a[i].resize(n[i]);
    rep(j,n[i]) {
      cin >> a[i][j];
    }
  }

  auto f = [](int x,int y){ return min(x,y); }; // RMinQ
  int id = 1e+9;
  

  rep(i,t) {
    if(n[i] == 1 && a[i][0] == k[i]) {
      cout << "yes" << "\n";
      continue;
    }

    vector<int> sum(n[i]+1); // cumulative sum
    sum[0] = 0;
    rep(j,n[i]) {
      if(a[i][j] >= k[i]) sum[j+1] = sum[j] + 1;
      else sum[j+1] = sum[j] - 1;
    }
    SegTree<decltype(f),int> seg(f, id);
    seg.init(n[i]+2);
    seg.build(sum);
    int lim = 0;
    bool flag = false;
    bool ink = false; // is there k ? 
    rep(j,n[i]) {
      if(a[i][j] >= k[i]) lim = j+1;
      if(a[i][j] == k[i]) ink = true;
      if(lim > 0) {
    int mini = (a[i][j] >= k[i] ? seg.query(0, lim-1, 0, 0, seg.size) : seg.query(0, lim, 0, 0, seg.size));
    if(sum[j+1] > mini) flag = true;
      }
    }
    if(flag && ink) cout << "yes" << "\n";
    else cout << "no" << "\n";
  }


  
  return 0;
}