Contest 高数题 樹的點分治 樹形DP

高数题

HJA最近在刷高数题,他遇到了这样一道高数题。这道高数题里面有一棵N个点的树,树上每个点有点权,每条边有颜色。一条路径的权值是这条路径上所有点的点权和,一条合法的路径需要满足该路径上任意相邻的两条边颜色都不相同。问这棵树上所有合法路径的权值和是多少

输入第一行一个整数N,代表树上有多少个点。
接下来一行N个整数,代表树上每个点的权值。
接下来N-1行,每行三个整数S、E、C,代表S与E之间有一条颜色为C的边。输出一行一个整数,代表所求的值。样例输入

6
6 2 3 7 1 4
1 2 1
1 3 2
1 4 3
2 5 1
2 6 2

样例输出

134

提示

对与30%的数据,1≤N≤1000。 
对于另外20%的数据,可用的颜色数不超过109且随机数据。
对于另外20%的数据,树的形态为一条链。
对于100%的数据,1≤N≤3*105,可用的颜色数不超过109,所有点权的大小不超过105。

這道題簡單的說是一個樹形DP,由下至上分別統計路徑條數,而考場上我想都沒想就開始寫樹的點分治,當天狀態不佳,加之本身的不熟練,沒能按時把點分治寫出來。

唯一需要注意的是n=3*10^5,dfs足以使程序崩掉,以後最好改寫bfs

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<string>
#include<queue>
#include<stack>
using namespace std;
#ifdef WIN32
#define LL "%I64d"
#else
#define LL "%lld"
#endif
#define MAXN 410000
#define MAXV MAXN*2
#define MAXE MAXV*2
#define INF 0x3f3f3f3f
#define PROB "gaoshu"
typedef long long qword;
int nextInt()
{
        char ch;
        int x=0;
        while (ch=getchar(),ch < '0' || ch > '9' );
        //        cout<<ch;
        do
                x=x*10+ch-'0';
        while (ch=getchar(),ch<='9' && ch>='0');
        return x;
}
int n;
struct Edge
{
        int np,col;
        int val;
        pair<qword,qword> val2;
        Edge *next,*neg;
        int disable;
}E[MAXE],*V[MAXV];
qword val[MAXN];
int tope=-1;
int bad[MAXN][3];
int size[MAXN];
//int fa[MAXN],depth[MAXN];
//int jump[20][MAXN];

void addedge(int x,int y,int z)
{
//        cout<<"Add:"<<x<<" "<<y<<endl;
        E[++tope].np=y;
        E[tope].col=z;
        E[tope].next=V[x];
        V[x]=&E[tope];

        E[++tope].np=x;
        E[tope].col=z;
        E[tope].next=V[y];
        V[y]=&E[tope];

        E[tope].neg=&E[tope-1];
        E[tope-1].neg=&E[tope];
}
/*
void dfs1(int now,int d)
{
        size[now]=1;depth[now]=d;
        Edge *ne;
        for (ne=V[now];ne;ne=ne->next)
        {
                if (ne->np==fa[now])continue;
                fa[ne->np]=now;
                dfs1(ne->np,d+1);
                size[now]+=size[ne->np];
        }
}
void init_lca()
{
        int i,j;
        for (i=1;i<=n;i++)
        {
                jump[0][i]=fa[i];
        }
        for (j=1;j<20;j++)
        {
                for (i=1;i<=n;i++)
                {
                        jump[j][i]=jump[j-1][jump[j-1][i]];
                }
        }
}
void swim(int &now,int len)
{
        int i=0;
        while (len)
        {
                if (len&1)now=jump[i][now];
                i++;
        }
}
int lca(int x,int y)
{
        if (depth[x]>depth[y])
        {
                swim(x,depth[x]-depth[y]);
        }else
        {
                swim(y,depth[y]-depth[x]);
        }
        int i;
        if (x==y)return x;
        for (i=19;i>=0;i--)
        {
                if (jump[i][x]!=jump[i][y])
                {
                        x=jump[i][x];
                        y=jump[i][y];
                }
        }
        return fa[x];
}*/
int bcore,vcore;
int size2[MAXN];
int get_core_sizt;
int gc_sizf[MAXN];
int gc_mxsz[MAXN];
int gc_col[MAXN];
int gc_now[MAXN];
Edge *nel[MAXN];
int get_core(int dep=1)
{
        gc_sizf[dep]=get_core_sizt-1;
        gc_mxsz[dep]=0;
        size2[gc_now[dep]]=1;
        for (nel[dep]=V[gc_now[dep]];nel[dep];nel[dep]=nel[dep]->next)
        {
                if (nel[dep]->np==gc_now[dep-1] || nel[dep]->disable)continue;
                gc_col[dep+1]=nel[dep]->col;
                gc_now[dep+1]=nel[dep]->np;
                get_core(dep+1);
                size2[gc_now[dep]]+=size2[nel[dep]->np];
                gc_sizf[dep]-=size2[nel[dep]->np];
                gc_mxsz[dep]=max(gc_mxsz[dep],size2[nel[dep]->np]);
        }
        gc_mxsz[dep]=max(gc_mxsz[dep],gc_sizf[dep]);
        if (gc_mxsz[dep]<vcore)
        {
                vcore=gc_mxsz[dep];
                bcore=gc_now[dep];
        }
}
pair<qword,qword> dfs2_ret[MAXN],dfs2_tt[MAXN];
int dfs2_now[MAXN],dfs2_col[MAXN];
pair<qword,qword> dfs2(int dep=1)
{
//        pair<qword,qword> ret,tt;
        dfs2_ret[dep]=make_pair(val[dfs2_now[dep]],1);
        for (nel[dep]=V[dfs2_now[dep]];nel[dep];nel[dep]=nel[dep]->next)
        {
                if (nel[dep]->np==dfs2_now[dep-1] || nel[dep]->disable)continue;
                dfs2_now[dep+1]=nel[dep]->np;
                dfs2_col[dep+1]=nel[dep]->col;
                dfs2_tt[dep]=dfs2(dep+1);
                dfs2_ret[dep].second+=dfs2_tt[dep].second*(nel[dep]->col!=dfs2_col[dep]);
                dfs2_ret[dep].first+=(dfs2_tt[dep].first+dfs2_tt[dep].second*val[dfs2_now[dep]]) *(nel[dep]->col!=dfs2_col[dep]);
        }
        return dfs2_ret[dep];
}
qword solve(int root,int siz)
{
        if (siz==1)return 0;
        vcore=INF;
        gc_now[0]=root;
        gc_col[1]=INF;
        gc_now[1]=root;
        get_core_sizt=siz;
        get_core(1);
        int core=bcore;
        gc_now[0]=core;
        gc_now[1]=core;
        get_core(1);
        qword ans=0;
        Edge *ne;
        for (ne=V[core];ne;ne=ne->next)
        {
                if (ne->disable)continue;
                ne->disable=core;
                ne->neg->disable=core;
                ne->val=size[ne->np];
        }
        for (ne=V[core];ne;ne=ne->next)
        {
                if (ne->disable!=core)continue;
                ans+=solve(ne->np,size2[ne->np]);
        }
        for (ne=V[core];ne;ne=ne->next)
        {
                if (ne->disable!=core)continue;
                dfs2_now[0]=dfs2_now[1]=ne->np;
                dfs2_col[1]=ne->col;
                ne->val2=dfs2(1);
        }
        Edge *ne2;
        int t=0;
        map<int,qword> mp;
        pair<qword,qword> sum=make_pair(0,0);
        for (ne=V[core];ne;ne=ne->next)
        {
                if (ne->disable!=core)continue;    
                sum.first+=ne->val2.first;
                sum.second+=ne->val2.second;
                mp[ne->col]+=ne->val2.second;
        }
        qword ans2=0;
        for (ne=V[core];ne;ne=ne->next)
        {
                if (ne->disable!=core)continue;
                ans+=ne->val2.first*(sum.second-mp[ne->col]);//分居兩邊
                ans2+=val[core]*ne->val2.second*(sum.second-mp[ne->col]);//分局兩邊,中心貢獻
        }
        ans+=ans2/2;
        for (ne=V[core];ne;ne=ne->next)
        {
                if (ne->disable!=core)continue;
                t+=ne->val2.second;//中心出發條數
                ans+=ne->val2.first;//中心出發,外點貢獻
        }
        ans+=val[core]*t;
        for (ne=V[core];ne;ne=ne->next)
        {
                if (ne->disable==core)
                        ne->disable=ne->neg->disable=0;
        }
        return ans;
}
int main()
{
        //freopen("input.txt","r",stdin);
        //freopen("output.txt","w",stdout);
        freopen(PROB".in","r",stdin);
        freopen(PROB".out","w",stdout);
        int i,j,k;
        int x,y,z;
        //scanf("%d",&n);
        n=nextInt();
        for (i=1;i<=n;i++)
                val[i]=nextInt();//scanf("%d",&val[i]);
        for (i=1;i<n;i++)
        {
                //scanf("%d%d%d",&x,&y,&z);
                x=nextInt();
                y=nextInt();
                z=nextInt();
                addedge(x,y,z);
        }
//        fa[1]=1;
//        dfs1(1,0);
//        init_lca();
        qword ans=solve(1,n);
        printf(LL "
",ans);
        return 0;

}
by mhy12345(http://www.cnblogs.com/mhy12345/) 未经允许请勿转载

本博客已停用,新博客地址:http://mhy12345.xyz

原文地址:https://www.cnblogs.com/mhy12345/p/4001131.html