【bzoj2243】[SDOI2011]染色 树链剖分+线段树

题目描述

给定一棵有n个节点的无根树和m个操作,操作有2类:

1、将节点a到节点b路径上所有点都染成颜色c;

2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”3段组成:“11”、“222”和“1”

请你写一个程序依次完成这m个操作。

输入

第一行包含2个整数n和m,分别表示节点数和操作数;

第二行包含n个正整数表示n个节点的初始颜色

下面 行每行包含两个整数x和y,表示xy之间有一条无向边。

下面 行每行描述一个操作:

“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;

“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

输出

对于每个询问操作,输出一行答案。

样例输入

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

样例输出

3
1
2

提示

数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间


题解

裸的树链剖分+线段树。

区间修改非常恶心,很多细节。

多写写应该就能好了吧。。。

#include <stdio.h>
#include <algorithm>
using namespace std;
#define lson l , mid , x << 1
#define rson mid + 1 , r , x << 1 | 1
#define N 100005
int fa[N] , deep[N] , si[N] , val[N] , bl[N] , pos[N] , tot;
int head[N] , to[N << 1] , next[N << 1] , cnt;
int sum[N << 2] , lc[N << 2] , rc[N << 2] , mark[N << 2] , n;
char str[10];
void add(int x , int y)
{
    to[++cnt] = y;
    next[cnt] = head[x];
    head[x] = cnt;
}
void dfs1(int x)
{
    int i , y;
    si[x] = 1;
    for(i = head[x] ; i ; i = next[i])
    {
        y = to[i];
        if(y != fa[x])
        {
            fa[y] = x;
            deep[y] = deep[x] + 1;
            dfs1(y);
            si[x] += si[y];
        }
    }
}
void dfs2(int x , int c)
{
    int k = 0 , i , y;
    bl[x] = c;
    pos[x] = ++tot;
    for(i = head[x] ; i ; i = next[i])
    {
        y = to[i];
        if(fa[x] != y && si[y] > si[k])
            k = y;
    }
    if(k != 0)
    {
        dfs2(k , c);
        for(i = head[x] ; i ; i = next[i])
        {
            y = to[i];
            if(fa[x] != y && y != k)
                dfs2(y , y);
        }
    }
}
void pushup(int x)
{
    lc[x] = lc[x << 1];
    rc[x] = rc[x << 1 | 1];
    sum[x] = sum[x << 1] + sum[x << 1 | 1];
    if(rc[x << 1] == lc[x << 1 | 1])
        sum[x] -- ;
}
void pushdown(int x)
{
    int tmp = mark[x];
    mark[x] = 0;
    if(tmp)
    {
        sum[x << 1] = sum[x << 1 | 1] = 1;
        lc[x << 1] = rc[x << 1] = lc[x << 1 | 1] = rc[x << 1 | 1] = tmp;
        mark[x << 1] = mark[x << 1 | 1] = tmp;
    }
}
void update(int b , int e , int v , int l , int r , int x)
{
    if(b <= l && r <= e)
    {
        sum[x] = 1;
        lc[x] = rc[x] = v;
        mark[x] = v;
        return;
    }
    pushdown(x);
    int mid = (l + r) >> 1;
    if(b <= mid)
        update(b , e , v , lson);
    if(e > mid)
        update(b , e , v , rson);
    pushup(x);
}
void solveupdate(int x , int y , int v)
{
    while(bl[x] != bl[y])
    {
        if(deep[bl[x]] < deep[bl[y]])
        {
            swap(x , y);
        }
        update(pos[bl[x]] , pos[x] , v , 1 , n , 1);
        x = fa[bl[x]];
    }
    if(deep[x] > deep[y])
        swap(x , y);
    update(pos[x] , pos[y] , v , 1 , n , 1);
}
int query(int b , int e , int l , int r , int x)
{
    if(b <= l && r <= e)
    {
        return sum[x];
    }
    pushdown(x);
    int mid = (l + r) >> 1 , ans = 0;
    if(b <= mid)
        ans += query(b , e , lson);
    if(e > mid)
        ans += query(b , e , rson);
    if(b <= mid && e > mid && rc[x << 1] == lc[x << 1 | 1])
        ans -- ;
    return ans;
}
int getcl(int p , int l , int r , int x)
{
    if(l == r)
        return lc[x];
    pushdown(x);
    int mid = (l + r) >> 1;
    if(p <= mid)
        return getcl(p , lson);
    else
        return getcl(p , rson);
}
int solvequery(int x , int y)
{
    int ans = 0;
    while(bl[x] != bl[y])
    {
        if(deep[bl[x]] < deep[bl[y]])
            swap(x , y);
        ans += query(pos[bl[x]] , pos[x] , 1 , n , 1);
        if(getcl(pos[bl[x]] , 1 , n , 1) == getcl(pos[fa[bl[x]]] , 1 , n , 1))
            ans -- ;
        x = fa[bl[x]];
    }
    if(deep[x] > deep[y])
        swap(x , y);
    ans += query(pos[x] , pos[y] , 1 , n , 1);
    return ans;
}
int main()
{
    int i , x , y , z , m;
    scanf("%d%d" , &n , &m);
    for(i = 1 ; i <= n ; i ++ )
        scanf("%d" , &val[i]);
    for(i = 1 ; i < n ; i ++ )
    {
        scanf("%d%d" , &x , &y);
        add(x , y);
        add(y , x);
    }
    dfs1(1);
    dfs2(1 , 1);
    for(i = 1 ; i <= n ; i ++ )
        update(pos[i] , pos[i] , val[i] , 1 , n , 1);
    while(m -- )
    {
        scanf("%s" , str);
        switch(str[0])
        {
            case 'C': scanf("%d%d%d" , &x , &y , &z); solveupdate(x , y , z); break;
            default: scanf("%d%d" , &x , &y); printf("%d
" , solvequery(x , y));
        }
    }
    return 0;
}
原文地址:https://www.cnblogs.com/GXZlegend/p/6184309.html