noiac132 B君的第三题 (树形dp)

传送门

本来想用点分治做,结果root又求不对 算的时候还算错了 我好菜啊

结果szr大佬告诉我是树形dp

我好菜啊!!

我们有$lceil frac{x}{k} ceil = frac{x+(k-x)\%k}{k}$

于是可以把这个拆成两部分来求,最后加在一起再除个k

距离和很好求,连接x和fa[x]的边的贡献就是$size[x]*(N-size[x])$

然后考虑到k很小,我们可以直接记x的子树中到x距离%k=y的个数f[x][y],然后拿这个去算

 1 #pragma GCC optimize(3)
 2 #include<bits/stdc++.h>
 3 #define pa pair<ll,ll>
 4 #define CLR(a,x) memset(a,x,sizeof(a))
 5 using namespace std;
 6 typedef long long ll;
 7 const int maxn=1e5+10,maxk=15;
 8 
 9 inline char gc(){
10     return getchar();
11     static const int maxs=1<<16;static char buf[maxs],*p1=buf,*p2=buf;
12     return p1==p2&&(p2=(p1=buf)+fread(buf,1,maxs,stdin),p1==p2)?EOF:*p1++;
13 }
14 inline ll rd(){
15     ll x=0;char c=gc();bool neg=0; 
16     while(c<'0'||c>'9'){if(c=='-') neg=1;c=gc();}
17     while(c>='0'&&c<='9') x=(x<<1)+(x<<3)+c-'0',c=gc();
18     return neg?(~x+1):x;
19 }
20 
21 int eg[maxn*2][3],egh[maxn],ect=1;
22 int N,K,siz[maxn],dcnt[maxk],smsiz,fa[maxn];
23 ll ans;
24 bool flag[maxn];
25 
26 inline void adeg(int a,int b,int c){
27     eg[++ect][0]=b,eg[ect][1]=c,eg[ect][2]=egh[a],egh[a]=ect;
28 }
29 
30 void getroot(int x,int f,int &rt,int &mis){
31     siz[x]=1;
32     int mm=0;
33     for(int i=egh[x];i;i=eg[i][2]){
34         int b=eg[i][0];if(b==f||flag[b]) continue;
35         fa[b]=x;getroot(b,x,rt,mis);
36         siz[x]+=siz[b];mm=max(mm,siz[b]);
37     }
38     mm=max(mm,smsiz-siz[x]);
39     if(mm<mis) rt=x,mis=mm;
40 }
41 
42 void getdis(int x,int f,int d){
43     dcnt[d]++;
44     for(int i=egh[x];i;i=eg[i][2]){
45         int b=eg[i][0];if(b==f||flag[b]) continue;
46         getdis(b,x,(d+eg[i][1])%K);
47     }
48 }
49 
50 inline ll calc(int x,int ini){
51     ll re=0;
52     CLR(dcnt,0);getdis(x,0,ini);
53     for(int i=0;i<K;i++){
54         for(int j=i+1;j<K;j++){
55             re+=1ll*dcnt[i]*dcnt[j]*((K+(K-i-j)%K)%K);
56         }
57     }
58     for(int i=0;i<K;i++) re+=1ll*dcnt[i]*(dcnt[i]-1)/2*((K+(K-i-i)%K)%K);
59     return re;
60 }
61 
62 void solve(int x){
63     flag[x]=1;
64     ans+=calc(x,0);
65     for(int i=egh[x];i;i=eg[i][2]){
66         int b=eg[i][0];if(flag[b]) continue;
67         ans-=calc(b,eg[i][1]%K);
68         int rt=0,mis=1e9;smsiz=siz[b];
69         getroot(b,0,rt,mis);
70         siz[fa[rt]]=smsiz-siz[rt];
71         solve(rt);
72     }
73 }
74 
75 void dfs(int x,int f){
76     siz[x]=1;
77     for(int i=egh[x];i;i=eg[i][2]){
78         int b=eg[i][0];if(b==f) continue;
79         dfs(b,x);siz[x]+=siz[b];
80         ans+=1ll*siz[b]*(N-siz[b])*eg[i][1];
81     }
82 }
83 
84 int main(){
85     // freopen("t3.in","r",stdin);
86     // freopen("t3.out","w",stdout);
87     int i,j,k;
88     N=rd(),K=rd();
89     for(i=1;i<N;i++){
90         int a=rd(),b=rd(),c=rd();
91         adeg(a,b,c);adeg(b,a,c);
92     }
93     smsiz=N;int mis=1e9,rt=0;
94     getroot(1,0,rt,mis);
95     solve(rt);
96     dfs(1,0);
97     printf("%lld
",ans/K);
98     return 0;
99 }
ZZ点分治
 1 #include<bits/stdc++.h>
 2 #define CLR(a,x) memset(a,x,sizeof(a))
 3 using namespace std;
 4 typedef long long ll;
 5 typedef pair<int,int> pa;
 6 const int maxn=1e5+10,maxk=15;
 7 
 8 inline ll rd(){
 9     ll x=0;char c=getchar();int neg=1;
10     while(c<'0'||c>'9'){if(c=='-') neg=-1;c=getchar();}
11     while(c>='0'&&c<='9') x=x*10+c-'0',c=getchar();
12     return x*neg;
13 }
14 
15 int eg[maxn*2][3],egh[maxn],ect;
16 int dp[maxn][maxk],siz[maxn];
17 int N,K;
18 ll ans;
19 
20 inline void adeg(int a,int b,int c){
21     eg[++ect][0]=b,eg[ect][1]=c,eg[ect][2]=egh[a],egh[a]=ect;
22 }
23 
24 inline void dfs(int x,int f){
25     siz[x]=1;dp[x][0]=1;
26     for(int i=egh[x];i;i=eg[i][2]){
27         int b=eg[i][0];if(b==f) continue;
28         dfs(b,x);siz[x]+=siz[b];
29         ans+=1ll*eg[i][1]*(N-siz[b])*siz[b];
30         for(int j=0;j<K;j++)
31             for(int k=0;k<K;k++)
32                 ans+=1ll*((K+(-(j+eg[i][1])%K-k)%K)%K)*dp[b][j]*dp[x][k];
33         for(int j=0;j<K;j++)
34             dp[x][(j+eg[i][1])%K]+=dp[b][j];
35     }
36 }
37 
38 int main(){
39     //freopen("","r",stdin);
40     int i,j,k;
41     N=rd(),K=rd();
42     for(i=1;i<N;i++){
43         int a=rd(),b=rd(),c=rd();
44         adeg(a,b,c);adeg(b,a,c);
45     }
46     dfs(1,0);
47     printf("%lld
",ans/K);
48     return 0;
49 }

 

原文地址:https://www.cnblogs.com/Ressed/p/9934805.html