bzoj 2243 染色

题目大意:

给定一棵有n个节点的无根树和m个操作,操作有2类:

1、将节点a到节点b路径上所有点都染成颜色c

2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段)

如“112221”由3段组成:“11”、“222”和“1”

请你写一个程序依次完成这m个操作

思路:

直接树剖

但是需要注意一下bl[a]和fa[bl[a]]的颜色是否相同,对于这种情况做一下特殊处理

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cmath>
  4 #include<cstdlib>
  5 #include<cstring>
  6 #include<algorithm>
  7 #include<vector>
  8 #include<queue>
  9 #define inf 2139062143
 10 #define ll long long
 11 #define MAXN 151010
 12 using namespace std;
 13 inline int read()
 14 {
 15     int x=0,f=1;char ch=getchar();
 16     while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
 17     while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();}
 18     return x*f;
 19 }
 20 int n,T,bl[MAXN],cnt[MAXN],hsh[MAXN],fa[MAXN],dep[MAXN],val[MAXN];
 21 int Cnt,to[MAXN<<1],nxt[MAXN<<1],fst[MAXN];
 22 struct data {int tag,lc,rc,sum,l,r;}tr[MAXN<<2];
 23 void add(int u,int v) {nxt[++Cnt]=fst[u],fst[u]=Cnt,to[Cnt]=v;}
 24 void dfs(int x)
 25 {
 26     cnt[x]=1;
 27     for(int i=fst[x];i;i=nxt[i])
 28     {
 29         if(to[i]==fa[x]) continue;
 30         fa[to[i]]=x,dep[to[i]]=dep[x]+1;
 31         dfs(to[i]);
 32         cnt[x]+=cnt[to[i]];
 33     }
 34 }
 35 void Dfs(int x,int anc)
 36 {
 37     int hvs=0;hsh[x]=++Cnt,bl[x]=anc;
 38     for(int i=fst[x];i;i=nxt[i])
 39         if(to[i]!=fa[x]&&cnt[hvs]<cnt[to[i]]) hvs=to[i];
 40     if(!hvs) return ;
 41     Dfs(hvs,anc);
 42     for(int i=fst[x];i;i=nxt[i])
 43         if(to[i]!=fa[x]&&to[i]!=hvs) Dfs(to[i],to[i]);
 44 }
 45 void build(int k,int l,int r)
 46 {
 47     tr[k].l=l,tr[k].r=r;
 48     if(l==r) {tr[k].sum=1,tr[k].lc=tr[k].rc=tr[k].tag=0;return ;}
 49     int mid=(l+r)>>1;
 50     build(k<<1,l,mid);
 51     build(k<<1|1,mid+1,r);
 52 }
 53 void pushdown(int k)
 54 {
 55     tr[k<<1].tag=tr[k<<1|1].tag=tr[k].tag;
 56     tr[k<<1].sum=tr[k<<1|1].sum=1;
 57     tr[k<<1].lc=tr[k<<1].rc=tr[k<<1|1].lc=tr[k<<1|1].rc=tr[k].tag,tr[k].tag=0;
 58 }
 59 void pushup(int k)
 60 {
 61     tr[k].sum=tr[k<<1].sum+tr[k<<1|1].sum;
 62     if(!(tr[k<<1].rc^tr[k<<1|1].lc)) tr[k].sum--;
 63     tr[k].lc=tr[k<<1].lc,tr[k].rc=tr[k<<1|1].rc;
 64 }
 65 void upd(int k,int a,int b,int x)
 66 {
 67     int l=tr[k].l,r=tr[k].r;
 68     if(l==a&&r==b) {tr[k].sum=1,tr[k].tag=tr[k].lc=tr[k].rc=x;return ;}
 69     int mid=(l+r)>>1;
 70     if(tr[k].tag) pushdown(k);
 71     if(b<=mid) upd(k<<1,a,b,x);
 72     else if(a>mid) upd(k<<1|1,a,b,x);
 73     else {upd(k<<1,a,mid,x);upd(k<<1|1,mid+1,b,x);}
 74     pushup(k);
 75 }
 76 int query(int k,int a,int b)
 77 {
 78     int l=tr[k].l,r=tr[k].r;
 79     if(l==a&&r==b) return tr[k].sum;
 80     int mid=(l+r)>>1;
 81     if(tr[k].tag) pushdown(k);
 82     if(b<=mid) return query(k<<1,a,b);
 83     else if(a>mid) return query(k<<1|1,a,b);
 84     else return query(k<<1,a,mid)+query(k<<1|1,mid+1,b)-(tr[k<<1].rc==tr[k<<1|1].lc);
 85 }
 86 int gc(int k,int x)
 87 {
 88     if(tr[k].tag) pushdown(k);
 89     if(tr[k].l==tr[k].r) return tr[k].lc;
 90     int mid=(tr[k].l+tr[k].r)>>1;
 91     if(x<=mid) return gc(k<<1,x);
 92     if(x>mid) return gc(k<<1|1,x);
 93 }
 94 int main()
 95 {
 96     n=read(),T=read();
 97     int a,b,c,res;
 98     for(int i=1;i<=n;i++) val[i]=read();
 99     for(int i=1;i<n;i++) {a=read(),b=read();add(a,b);add(b,a);}
100     dep[1]=1,fa[1]=1,Cnt=0;
101     dfs(1);Dfs(1,1);
102     build(1,1,n);
103     for(int i=1;i<=n;i++) upd(1,hsh[i],hsh[i],val[i]);
104     char ch[3];
105     while(T--)
106     {
107         scanf("%s",ch);
108         if(ch[0]=='C')
109         {
110             a=read(),b=read(),c=read();
111             while(bl[a]!=bl[b])
112             {
113                 if(dep[bl[a]]<dep[bl[b]]) swap(a,b);
114                 upd(1,hsh[bl[a]],hsh[a],c);
115                 a=fa[bl[a]];
116             }
117             upd(1,min(hsh[a],hsh[b]),max(hsh[a],hsh[b]),c);
118         }
119         else
120         {
121             a=read(),b=read();
122             res=0;
123             while(bl[a]!=bl[b])
124             {
125                 if(dep[bl[a]]<dep[bl[b]]) swap(a,b);
126                 res+=query(1,hsh[bl[a]],hsh[a])-(gc(1,hsh[bl[a]])==gc(1,hsh[fa[bl[a]]]));
127                 a=fa[bl[a]];
128             }
129             res+=query(1,min(hsh[a],hsh[b]),max(hsh[a],hsh[b]));
130             printf("%d
",res);
131         }
132     }
133 }
View Code
原文地址:https://www.cnblogs.com/yyc-jack-0920/p/8184820.html