题解:
(被无良的出题人拿来骗钱的题)o( ̄ヘ ̄o)
此题比较水,首先可以发现一个性质:一个位置改过一次之后,再改这个位置对答案没有影响。之后又发现了一个玄妙的东西:每个没有被改过的位置,它在第一次被重排的时候对答案的影响不会变化。
于是就变成了求每次修改时候,被影响的没被影响过的位置,然后就是一个线段树#include <cstdio>
#include <algorithm> #include <cstring> #include <cstdlib> #include <memory.h> #include <stack> #include <vector> #define lowbit(x) ((x) & (-(x))) #ifdef YJQ_LOCAL #define LL "%I64d" #else #define LL "%lld" #endif using namespace std; int n,m,pcnt; const int MAXN = 500010; struct OPT { int v,pos; } opt[MAXN]; vector<int> g[MAXN]; int siz[MAXN]; bool cmp(OPT o1,OPT o2) { return o1.v < o2.v; } int lab[MAXN],a[MAXN]; void Uni() { for (int i=1;i<=n;i++) { opt[i].v = a[i]; opt[i].pos = i; } sort(opt+1,opt+n+1,cmp); for (int i=1;i<=n;i++) { if (opt[i].v == opt[i-1].v) lab[opt[i].pos] = pcnt; else lab[opt[i].pos] = ++pcnt; } } int low[MAXN],bit[MAXN],pr[MAXN]; long long ans = 0; bool vis[MAXN]; void modify_bit(int p,int v) { for (;p<=pcnt;p+=lowbit(p)) bit[p] += v; } int query_bit(int p) { int ret= 0; for (; p; p-=lowbit(p)) ret += bit[p]; return ret; } struct NODE { int l,r,maxi; } seg[MAXN << 2]; void build(int x,int l,int r) { seg[x].l = l; seg[x].r= r; if (l == r) { seg[x].maxi = pr[l]; return; } int mid = (l + r) >> 1; build(x << 1,l,mid); build((x << 1) ^ 1,mid+1,r); seg[x].maxi = max(seg[x << 1].maxi,seg[(x << 1) + 1].maxi); } void modify(int x,int p,int v) { if (seg[x].l == seg[x].r) { seg[x].maxi = v; return; } int mid = (seg[x].l + seg[x].r) >> 1; if (p <= mid) modify(x << 1,p,v); else modify((x << 1) + 1,p,v); seg[x].maxi= max(seg[x << 1].maxi,seg[(x << 1) ^ 1].maxi); } void solve(int x,int p) { while (siz[x]) { int cur = g[x][siz[x]-1]; if (cur < p) { modify(1,x,cur); return; } ans -= low[cur]; vis[cur] = 1; siz[x] --; } modify(1,x,0); } void find(int x,int p) { if (seg[x].l == seg[x].r) { solve(seg[x].l,p); return; } if (seg[x << 1].maxi >= p) find(x << 1,p); if (seg[(x << 1) + 1].maxi >= p) find((x << 1) + 1,p); } void Find(int x,int l,int r,int p) { if (seg[x].l ==l && seg[x].r == r) { find(x,p); return; } int mid = (seg[x].l + seg[x].r) >> 1; if (r <= mid) Find(x << 1,l,r,p); else if (l > mid) Find((x << 1) ^ 1,l,r,p); else { Find(x << 1,l,mid,p); Find((x << 1) ^ 1,mid+1,r,p); } return; } int main(){ scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) { scanf("%d",&a[i]); } Uni(); for (int i=n;i;i--) { int cur = lab[i]; // printf("%d %d\n",i,cur); low[i] = query_bit(cur-1); modify_bit(cur,1); ans = ans + low[i]; if (!pr[cur]) pr[cur] = i; } printf(LL "\n",ans); build(1,1,pcnt); for (int i=1;i<=n;i++) { g[lab[i]].push_back(i); siz[lab[i]] ++; } while (m--) { int p; scanf("%d",&p); if (vis[p]) { printf(LL "\n",ans); continue; } Find(1,1,lab[p],p); printf(LL "\n",ans); } }