2243: [SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 4270 Solved: 1608
[Submit][Status][Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面
下面
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
HINT
N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
Source
树链剖分,注意一下询问的细节就可以了。
代码:
#include<cstdio> #include<algorithm> #define maxn 1000000 using namespace std; int n,m,u,v,tot,opt,a,b,c,edgenum; int f[maxn],d[maxn],num[maxn],son[maxn],tid[maxn],pre[maxn],top[maxn],col[maxn],vet[maxn],next[maxn],head[maxn]; char s[5]; struct node{int l,r,sum,tag;}tr[maxn]; void add(int u,int v) { edgenum++; vet[edgenum]=v; next[edgenum]=head[u]; head[u]=edgenum; } void dfs(int u,int dep,int fa) { f[u]=fa; d[u]=dep; num[u]=1; int maxnum=0; for (int e=head[u];e;e=next[e]) { int v=vet[e]; if (v!=fa) { dfs(v,dep+1,u); num[u]+=num[v]; if (num[v]>maxnum){maxnum=num[v];son[u]=v;} } } } void dfs(int u,int number) { top[u]=number; tid[u]=++tot; pre[tot]=u; if (!son[u]) return; dfs(son[u],number); for (int e=head[u];e;e=next[e]) { int v=vet[e]; if (v!=f[u]&&v!=son[u]) dfs(v,v); } } void build(int l,int r,int p) { if (l==r){tr[p].l=tr[p].r=col[pre[l]];tr[p].sum=1;tr[p].tag=-1;return;} int mid=(l+r)>>1; build(l,mid,p*2);build(mid+1,r,p*2+1); tr[p].l=tr[p*2].l; tr[p].r=tr[p*2+1].r; tr[p].tag=-1; tr[p].sum=tr[p*2].sum+tr[p*2+1].sum-(tr[p*2].r==tr[p*2+1].l); } void change(int l,int r,int x,int y,int p,int c) { if (l==x&&r==y){tr[p].l=tr[p].r=c;tr[p].sum=1;tr[p].tag=c;return;} if (tr[p].tag!=-1) { tr[p*2].tag=tr[p].tag; tr[p*2].l=tr[p*2].r=tr[p].tag; tr[p*2+1].tag=tr[p].tag; tr[p*2+1].l=tr[p*2+1].r=tr[p].tag; tr[p*2].sum=tr[p*2+1].sum=1; tr[p].tag=-1; } int mid=(l+r)>>1; if (y<=mid) change(l,mid,x,y,p*2,c);else if (x>mid) change(mid+1,r,x,y,p*2+1,c);else { change(l,mid,x,mid,p*2,c); change(mid+1,r,mid+1,y,p*2+1,c); } tr[p].l=tr[p*2].l; tr[p].r=tr[p*2+1].r; tr[p].sum=tr[p*2].sum+tr[p*2+1].sum-(tr[p*2].r==tr[p*2+1].l); } void update(int u,int v,int c) { while (top[u]!=top[v]) { change(1,n,tid[top[u]],tid[u],1,c); u=f[top[u]]; } change(1,n,tid[v],tid[u],1,c); } int ask(int l,int r,int x,int y,int p) { if (l==x&&r==y) return tr[p].sum; if (tr[p].tag!=-1) { tr[p*2].tag=tr[p].tag; tr[p*2].l=tr[p*2].r=tr[p].tag; tr[p*2+1].tag=tr[p].tag; tr[p*2+1].l=tr[p*2+1].r=tr[p].tag; tr[p*2].sum=tr[p*2+1].sum=1; tr[p].tag=-1; } int mid=(l+r)>>1; int ans; if (y<=mid) ans=ask(l,mid,x,y,p*2);else if (x>mid) ans=ask(mid+1,r,x,y,p*2+1);else { int tmp=1; if (tr[p*2].r!=tr[p*2+1].l) tmp=0; ans=ask(l,mid,x,mid,p*2)+ask(mid+1,r,mid+1,y,p*2+1)-tmp; } tr[p].l=tr[p*2].l; tr[p].r=tr[p*2+1].r; tr[p].sum=tr[p*2].sum+tr[p*2+1].sum-(tr[p*2].r==tr[p*2+1].l); return ans; } int getc(int l,int r,int p,int x) { if (l==r) return tr[p].l; if (tr[p].tag!=-1) { tr[p*2].tag=tr[p].tag; tr[p*2].l=tr[p*2].r=tr[p].tag; tr[p*2+1].tag=tr[p].tag; tr[p*2+1].l=tr[p*2+1].r=tr[p].tag; tr[p*2].sum=tr[p*2+1].sum=1; tr[p].tag=-1; } int mid=(l+r)>>1; int color; if (x<=mid) color=getc(l,mid,p*2,x);else color=getc(mid+1,r,p*2+1,x); tr[p].l=tr[p*2].l; tr[p].r=tr[p*2+1].r; tr[p].sum=tr[p*2].sum+tr[p*2+1].sum-(tr[p*2].r==tr[p*2+1].l); return color; } int solve(int u,int v) { int ans=0; while (top[u]!=top[v]) { ans+=ask(1,n,tid[top[u]],tid[u],1); if (getc(1,n,1,tid[top[u]])==getc(1,n,1,tid[f[top[u]]])) ans--; u=f[top[u]]; } ans+=ask(1,n,tid[v],tid[u],1); return ans; } int lca(int u,int v) { while (top[u]!=top[v]) { if (d[top[u]]<d[top[v]])swap(u,v); u=f[top[u]]; } if (d[u]<d[v]) return u;else return v; } int main() { scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) scanf("%d",&col[i]); for (int i=1;i<n;i++) { scanf("%d%d",&u,&v); add(u,v);add(v,u); } dfs(1,1,0); dfs(1,1); build(1,n,1); while (m--) { scanf("%s",s); if (s[0]=='C') { scanf("%d%d%d",&a,&b,&c); u=lca(a,b); update(a,u,c);update(b,u,c); }else { scanf("%d%d",&a,&b); u=lca(a,b); printf("%d\n",solve(a,u)+solve(b,u)-1); } } }