BZOJ 2566 xmastree(树分治+multiset)

题目链接:http://www.lydsy.com:808/JudgeOnline/problem.php?id=2566

题意:一棵有边权的树。结点有颜色。每次修改一个点的颜色。求每次修改后所有同色结点的最近距离。

思路:整体是树分治的方法。其实,分治之后,我们可以理解为重构了这棵树,使得最大深度最小。这棵树的每个结点对于每种颜色保存两个值。一个是该种颜色的所有点到该结点的距离,设这些距离中最小的两个值Min1,Min2,那么Min1+Min2可看做是经过当前结点的这种颜色的两个结点的最近距离,另外,还要保存其他孩子中关于这种颜色的最小值。所有这些最小值的最小值和Min1+Min2再取最小值就是以该节点为根的树的这种颜色的最小值。这些最小值用multiset维护。

另外,结点维护所有同种颜色的距离时也用multiset维护。

对于修改操作,设u节点的颜色由c1变为c2,分两步完成,第一步删除u结点的c1,第二步加入u结点的c2。

两种操作类似,都是从当前结点一直向上更新。细节比较多。

multiset直接删除一个值时是将所有这个值都删掉,只删除一个的话要用指针。

const int M=12005;

struct node
{
    int v,w,next;
};

node edges[M<<1];
int head[M],eNum;

void add(int u,int v,int w)
{
    edges[eNum].v=v;
    edges[eNum].w=w;
    edges[eNum].next=head[u];
    head[u]=eNum++;
}

int n,m;
int color[M];

int fa[M][20];
int d[M],dep[M];

int visit[M];

queue<int> Q;

void init()
{

    Q.push(1);
    visit[1]=1;

    while(!Q.empty())
    {
        int u=Q.front();
        Q.pop();

        for(int i=head[u];i!=-1;i=edges[i].next)
        {
            int v=edges[i].v;
            int w=edges[i].w;
            if(!visit[v])
            {
                fa[v][0]=u;
                d[v]=d[u]+w;
                dep[v]=dep[u]+1;
                Q.push(v);
                visit[v]=1;
            }
        }
    }

    for(int i=1;i<20;i++) for(int j=1;j<=n;j++)
    {
        fa[j][i]=fa[fa[j][i-1]][i-1];
    }
}

int calLca(int u,int v)
{
    if(dep[u]<dep[v]) swap(u,v);
    int x=dep[u]-dep[v];
    for(int i=0;i<20;i++) if(x&(1<<i)) u=fa[u][i];
    if(u==v) return u;
    for(int i=19;i>=0;i--)
    {
        if(fa[u][i]&&fa[v][i]&&fa[u][i]!=fa[v][i])
        {
            u=fa[u][i];
            v=fa[v][i];
        }
    }
    return fa[u][0];
}

int calDis(int u,int v)
{
    int lca=calLca(u,v);
    return d[u]+d[v]-d[lca]*2;
}

int sonNum[M];

int nodeSum;

int arr[M],arrNum;


void dfs(int u,int pre)
{
    arr[++arrNum]=u;

    nodeSum++;
    sonNum[u]=1;
    for(int i=head[u];i!=-1;i=edges[i].next)
    {
        int v=edges[i].v;
        if(v!=pre&&!visit[v])
        {
            dfs(v,u);
            sonNum[u]+=sonNum[v];
        }
    }
}

int calCenter(int u)
{
    nodeSum=0; arrNum=0;
    dfs(u,0);
    int ans=u,Min=INF;
    for(int i=1;i<=arrNum;i++)
    {
        int u=arr[i];
        int tmp=max(sonNum[u],nodeSum-sonNum[u]);
        if(tmp<Min) Min=tmp,ans=u;
    }
    return ans;
}


struct NODE
{
    multiset<int> S,p;

    int getAns()
    {
        if(p.size()<2) return INF;
        int Min1=*p.begin();
        multiset<int>::iterator it=p.begin();
        it++;
        int Min2=*it;
        return Min1+Min2;
    }

    void del(int x)
    {
        p.erase(p.find(x));
    }
};

map<int,NODE> mp[M];
map<int,NODE>::iterator it;


int inq[M],KK;
int dis[M];
int parent[M];


int DFS(int root)
{
    root=calCenter(root);
    Q.push(root);
    KK++;
    inq[root]=KK;
    dis[root]=0;


    while(!Q.empty())
    {
        int u=Q.front();
        Q.pop();

        int c=color[u];

        if(mp[root].count(c))
        {
            mp[root][c].p.insert(dis[u]);
        }
        else
        {
            NODE tmp;
            tmp.p.insert(dis[u]);
            mp[root][c]=tmp;
        }

        for(int i=head[u];i!=-1;i=edges[i].next)
        {
            int v=edges[i].v;
            int w=edges[i].w;
            if(!visit[v]&&KK!=inq[v])
            {
                inq[v]=KK;
                dis[v]=dis[u]+w;
                Q.push(v);
            }
        }
    }

    for(it=mp[root].begin();it!=mp[root].end();it++)
    {
        NODE tmp=it->second;
        it->second.S.insert(tmp.getAns());
    }
    visit[root]=1;

    for(int i=head[root];i!=-1;i=edges[i].next)
    {
        int v=edges[i].v;
        if(!visit[v])
        {
            v=DFS(v);

            for(it=mp[v].begin();it!=mp[v].end();it++)
            {
                int c=it->first;
                int w=*(it->second.S.begin());
                mp[root][c].S.insert(w);
            }
            parent[v]=root;
        }
    }
    return root;
}

int root;

multiset<int> SS;

void Add(int u,int c,int dis)
{
    if(mp[u].count(c))
    {
        mp[u][c].p.insert(dis);
    }
    else
    {
        NODE tmp;
        tmp.p.insert(dis);
        mp[u][c]=tmp;
    }
}

int st[M],stTop;


void setDel(multiset<int> &S,int x)
{
    multiset<int>::iterator it=S.find(x);
    if(it!=S.end()) S.erase(it);
}

void del(int u,int c)
{
    setDel(SS,*mp[root][c].S.begin());

    int curNode=u;

    stTop=0;
    while(curNode) st[++stTop]=curNode,curNode=parent[curNode];

    for(int i=stTop;i>1;i--)
    {
        int u=st[i];
        int v=st[i-1];
        setDel(mp[u][c].S,*mp[v][c].S.begin());
    }
    curNode=u;
    while(curNode)
    {
        int p=parent[curNode];

        setDel(mp[curNode][c].S,mp[curNode][c].getAns());
        mp[curNode][c].del(calDis(u,curNode));
        mp[curNode][c].S.insert(mp[curNode][c].getAns());

        if(p)
        {
            mp[p][c].S.insert(*(mp[curNode][c].S.begin()));
        }
        curNode=p;
    }
    SS.insert(*mp[root][c].S.begin());
}

void upd(int u,int c)
{
    if(mp[root].count(c))
    {
        setDel(SS,*mp[root][c].S.begin());
    }
    int curNode=u;
    stTop=0;
    while(curNode) st[++stTop]=curNode,curNode=parent[curNode];

    for(int i=stTop;i>1;i--)
    {
        int u=st[i];
        int v=st[i-1];
        if(!mp[v].count(c)) break;
        setDel(mp[u][c].S,*mp[v][c].S.begin());
    }
    for(int i=1;i<=stTop;i++)
    {
        int u=st[i];
        if(mp[u].count(c))
        {
            setDel(mp[u][c].S,mp[u][c].getAns());
        }
    }

    curNode=u;
    while(curNode)
    {
        Add(curNode,c,calDis(u,curNode));
        mp[curNode][c].S.insert(mp[curNode][c].getAns());

        int p=parent[curNode];
        if(p)
        {
            mp[p][c].S.insert(*(mp[curNode][c].S.begin()));
        }
        curNode=p;
    }
    SS.insert(*mp[root][c].S.begin());
}

void change(int u,int c)
{
    if(color[u]==c) return;
    del(u,color[u]);
    upd(u,c);

    color[u]=c;
}

int main()
{
    n=myInt();
    for(int i=1;i<=n;i++) color[i]=myInt();
    clr(head,-1);
    for(int i=1;i<n;i++)
    {
        int u=myInt();
        int v=myInt();
        int w=myInt();
        add(u,v,w);
        add(v,u,w);
    }
    init();
    clr(visit,0);

    root=DFS(1);

    for(it=mp[root].begin();it!=mp[root].end();it++)
    {
        SS.insert(*(it->second.S.begin()));
    }
    int minDis=*SS.begin();

    printf("%d
",minDis==INF?-1:minDis);

    m=myInt();
    while(m--)
    {
        int x=myInt();
        int y=myInt();
        change(x,y);

        minDis=*SS.begin();
        printf("%d
",minDis==INF?-1:minDis);
    }
}

  

原文地址:https://www.cnblogs.com/jianglangcaijin/p/4224470.html