[Updating]点分治学习笔记

Upd

  • (2020/2/15),又补了一题 LuoguP2664 树上游戏

  • (2020/2/14),补了一道例题 LuoguP3085 [USACO13OPEN]阴和阳Yin and Yang

To Do List

  • 动态点分治。这个看心情写吧......是贞德不想写qwq

嘛...上个世纪学的...好像全忘了....来写一下吧

这个应该算树上路径类问题的一类trick吧...

che dan环节

点分治嘛,顾名思义,先抓树上一个点算它对答案贡献,然后把这个点割掉,会变成几棵小一点的树,然后递归算就好了。

那么问题来了,点要怎么选呢?rand一个如果说他是一条链的话从上往下选点就被卡(n^2)默默码起手中的暴力,随便rand又有被针对的风险...

在分治递归的时候,每一层递归的总复杂度我们不想管它,我们要控制的就是每一次选点使得递归的层数变少。

树上有一个名词叫做重心详见CSP-2019 D2T3,重心旁边的子树大小最大是不会超过(n/2)的,所以我们每次点分治的时候先找当前分治到的这一联通块内的重心,然后算重心对答案的贡献在把重心割掉,分治就好了。

这样做递归的层数是不会超过(log n)的,具体的算一点对答案的贡献针对题目来看。

那怎么找重心呢?重心的定义,对于一棵树,其重心的最大子树大小一定是最小的,所以对树(Dfs)一遍,算出每个节点最大子树的大小是多少,取最小的就好了。

int siz[N],all,mx[N],rt,vis[N]; // all 表示当前联通块的大小,vis在下面会说
void getrt(int x,int prev)
{
    siz[x]=1,mx[x]=0;
    for(int i=fst[x];i;i=nxt[i])
    {
        int v=to[i]; if(v==prev) continue;
        getrt(v,x),siz[x]+=siz[v];
        mx[x]=max(mx[x],siz[v]);
    }
    mx[x]=max(mx[x],all-siz[x]); // 无根树嘛...从x父亲哪里跑出去的一坨也是x的子树
    if(mx[x]<mx[rt]) rt=x;
}

所以点分治的code大概长这样

int vis[N],siz[N],all,mx[N],rt; // all 表示当前联通块的大小
void getrt(int x,int prev)
{
    siz[x]=1,mx[x]=0;
    for(int i=fst[x];i;i=nxt[i])
    {
        int v=to[i]; if(v==prev) continue;
        getrt(v,x),siz[x]+=siz[v];
        mx[x]=max(mx[x],siz[v]);
    }
    mx[x]=max(mx[x],all-siz[x]); // 无根树嘛...从x父亲哪里跑出去的一坨也是x的子树
    if(mx[x]<mx[rt]) rt=x;
}
void dfz(int x)
{
    vis[x]=1; // 这里会用到vis
    /*
     * 假装这里是将x的贡献算上
    */
    // 分治
    for(int i=fst[x];i;i=nxt[i])
    {
        int v=to[i]; if(vis[v]) continue;
        mx[rt=0]=siz[v],all=siz[v];
        getrt(v,x),dfz(rt);
    }
}
int main()
{
    // 然后main里面要先求一下整棵树的重心
    mx[rt=0]=n,all=n;
    getrt(1,0),dfz(rt);
    return 0;
}

例题

好,扯了那么多,来看点题目吧。。。

LuoguP4178(POJ1741) Tree

套板子吧。。。每次算(x)的贡献的时候先把从(x)出发到当前分治到的联通块内所有点的路径找出来,然后排序,two-pointer算一下,然后发现样例都没过

[冷静分析.jpg]

按上面直接two-pointer后会出现(x)到同一子树内两个点的路径,也就是说会算上自交的路径。

咋办呢?对于(x)的一个子树(v),容斥掉(x)(v)里面两个点的路径就好了。

#include <bits/stdc++.h>
using namespace std;
#define fore(i,x) for(int i=head[x],v=e[i].to;i;i=e[i].nxt,v=e[i].to)
const int N=1e5+10;
int n,K;
struct edge
{
    int to,nxt,w;
}e[N<<1];
int head[N],cnt=0;
inline void ade(int x,int y,int w)
{e[++cnt]=(edge){y,head[x],w},head[x]=cnt;}
inline void addedge(int x,int y,int w){ade(x,y,w),ade(y,x,w);}
int siz[N],mx[N],rt,all,vis[N];
void getrt(int x,int prev)
{
    siz[x]=1,mx[x]=0;
    fore(i,x)if(!vis[v]&&v!=prev)
    {
        getrt(v,x),siz[x]+=siz[v];
        mx[x]=max(mx[x],siz[v]);
    }
    mx[x]=max(mx[x],all-siz[x]);
    if(mx[x]<mx[rt]) rt=x;
}
int dis[N],tot;
void getd(int x,int prev,int d)
{
    dis[++tot]=d;
    fore(i,x) if(v!=prev&&!vis[v]) getd(v,x,d+e[i].w);
}
int calc(int x,int w)
{
    tot=0,getd(x,0,w);
    sort(dis+1,dis+tot+1);
    int nw=tot,ans=0; for(int i=1;i<=tot;i++)
    {
        while(dis[nw]+dis[i]>K) nw--;
        if(nw<=i) break; ans+=nw-i;
    }
    return ans;
}
int ans;
void dfz(int x)
{
    vis[x]=1,ans+=calc(x,0);
    fore(i,x) if(!vis[v]) ans-=calc(v,e[i].w); // 去除不合法的路径(注意参数)
    fore(i,x) if(!vis[v])
    {
        rt=0,all=mx[rt]=siz[v];
        getrt(v,x),dfz(rt);
    }
}
int main()
{
    scanf("%d",&n);
    for(int i=1,x,y,w;i<=n-1;i++)
        scanf("%d%d%d",&x,&y,&w),addedge(x,y,w);
    scanf("%d",&K);
    rt=0,all=mx[rt]=n;
    getrt(1,0),dfz(rt);
    printf("%d
",ans);
    return 0;
}

CF161D Distance in Tree

嘛。。。这个要算的是长度等于(K)的路径数量。

一个可以直接套板子的做法就是用长度(<= K)的路径数量(-)小于(K)的路径数量,然后就是板子了。

上世纪的代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN=100010;
int n,K;
struct edge{
    int to,nxt;
}e[MAXN*2];
int head[MAXN],cnt=0;
void adde(int x,int y){
    e[++cnt]=(edge){y,head[x]},head[x]=cnt;
}
void addedge(int x,int y){
    adde(x,y);
    adde(y,x);
}
int size[MAXN],dp[MAXN],vis[MAXN],root,sum;
void getRoot(int x,int prev){
    size[x]=1,dp[x]=0;
    for (int i=head[x];i;i=e[i].nxt){
        int v=e[i].to;
        if (v==prev||vis[v]) continue;
        getRoot(v,x);
        size[x]+=size[v];
        dp[x]=max(dp[x],size[v]);
    }
    if ((dp[x]=max(dp[x],sum-size[x]))<dp[root])
        root=x;
}
int tot=0,dis[MAXN],D[MAXN];
void getDis(int x,int prev){
    D[++tot]=dis[x];
    for (int i=head[x];i;i=e[i].nxt){
        int v=e[i].to;
        if (v==prev||vis[v]) continue;
        dis[v]=dis[x]+1,getDis(v,x);
    }
}
ll calc(int x,int w){
    dis[x]=w,tot=0,getDis(x,0);
    sort(D+1,D+tot+1);
    int l=1,r=tot;
    ll tmp1=0,tmp2=0;
    while (l<r){
        if (D[l]+D[r]<=K) tmp1+=r-l,l++;
        else r--;
    }
    l=1,r=tot;
    while (l<r){
        if (D[l]+D[r]<K) tmp2+=r-l,l++;
        else r--;
    }
    return tmp1-tmp2;
}
ll ans=0;
void solve(int x){
    vis[x]=1,ans+=calc(x,0);
    for (int i=head[x];i;i=e[i].nxt){
        int v=e[i].to;
        if (vis[v]) continue;
        ans-=calc(v,1);
        sum=size[v],dp[root=0]=0x3f3f3f3f;
        getRoot(v,x),solve(root);
    }
}
int main(){
    scanf("%d%d",&n,&K);
    for (int i=1;i<=n-1;i++){
        int x,y;scanf("%d%d",&x,&y);
        addedge(x,y);
    }
    dp[root=0]=0x3f3f3f3f,sum=n;
    getRoot(1,0),solve(root);
    printf("%I64d
",ans);
    return 0;
}

LuoguP2634 [国家集训队]聪聪可可

答案就是 长度为(3)的倍数的路径数量$ / $所有路径数量。

这里因为交换两点算两条路径,端点还可以重合,所以所有路径数量为(n^2)

然后算一个点对答案的贡献的时候可以开个桶(cnt[0..2]),表示从当前重心出发到当前联通快内所有点的路径,长度除以(3)的余数为(0,1,2)的路径数量。

那么这个点对合法路径数量的贡献就是(cnt[0]*cnt[0] + 2*cnt[1]*cnt[2])(注意(1,2)要乘(2),而(0)不用),对于自交的路径同样容斥算一下就好了。

#include <bits/stdc++.h>
using namespace std;
#define rep(i,j,k) for (int i=(int)(j);i<=(int)(k);i++)
#define per(i,j,k) for (int i=(int)(j);i>=(int)(k);i--)
#define mp make_pair
#define pb push_back
#define fi first
#define se second
int gcd(int x,int y){return y==0?x:gcd(y,x%y);}
const int MAXN=20010;
struct edge{int to,w,nxt;}e[MAXN<<1];
int head[MAXN],cur=0;
void addedge(int x,int y,int w){
    e[++cur]=(edge){y,w,head[x]};head[x]=cur;
    e[++cur]=(edge){x,w,head[y]};head[y]=cur;
}
int dp[MAXN],size[MAXN],vis[MAXN],cnt[3],dis[MAXN],root,sum;
int ans=0,n;
void getRoot(int x,int fa){
    size[x]=1,dp[x]=0;
    for (int i=head[x];i;i=e[i].nxt){
        int v=e[i].to;
        if (vis[v]||v==fa)continue;
        getRoot(v,x);
        size[x]+=size[v];
        dp[x]=max(dp[x],size[v]);
    }
    dp[x]=max(dp[x],sum-size[x]);
    if (dp[x]<dp[root])root=x;
}
void getDis(int x,int fa){
    cnt[dis[x]%3]++;
    for (int i=head[x];i;i=e[i].nxt){
        int v=e[i].to;
        if (v==fa||vis[v])continue;
        dis[v]=(dis[x]+e[i].w)%3;
        getDis(v,x);
    }
}
int calc(int x,int w){
    cnt[0]=cnt[1]=cnt[2]=0,dis[x]=w;
    getDis(x,0);
    return cnt[0]*cnt[0]+cnt[1]*cnt[2]*2;
}
void solve(int x){
    vis[x]=1;ans+=calc(x,0);
    for (int i=head[x];i;i=e[i].nxt){
        int v=e[i].to;
        if (vis[v])continue;
        ans-=calc(v,e[i].w);
        sum=size[v],dp[root=0]=n,getRoot(v,0);
        solve(root);
    }
}
int main(){
    scanf("%d",&n);
    rep (i,1,n-1){
        int x,y,w;scanf("%d%d%d",&x,&y,&w);
        addedge(x,y,w);
    }
    sum=dp[root=0]=n,getRoot(1,0);
    solve(root);
    int tmp=gcd(ans,n*n);
    printf("%d/%d
",ans/tmp,n*n/tmp);
    return 0;
}

LuoguP3806 【模板】点分治1

话说怎么到现在才讲模板题...

这个之前好像是数据水了。。。然后导致我(calc)的时候双重循环都能过。。。

考虑到(m)很小,所以离线下来,在点分治的时候一起回答。

这里提供一个最简单粗暴的方法。

在遍历重心(x)的子树的时候枚举询问,然后枚举(x)的当前子树(v)里的所有路径,然后算当前路径与之前遍历过的路径内是否有满足要求的,遍历完一棵子树后把路径都插入到一个multiset,查询的话(set)二分就好了。

这样复杂度应该对了qwq

#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define fore(i,x) for(int i=head[x],v=e[i].to,w=e[i].w;i;i=e[i].nxt,v=e[i].to,w=e[i].w)
const int N=1e5+10;
int n,m;
struct edge
{
    int to,nxt,w;
}e[N<<1];
int head[N],cnt=0;
inline void ade(int x,int y,int w)
{e[++cnt]=(edge){y,head[x],w};head[x]=cnt;}
inline void addedge(int x,int y,int w){ade(x,y,w),ade(y,x,w);}
vector<int>qs;
int ans[1010];
int siz[N],vis[N],mx[N],rt,all;
void getrt(int x,int prev)
{
    siz[x]=1,mx[x]=0;
    fore(i,x) if(!vis[v]&&v!=prev)
    {
        getrt(v,x),siz[x]+=siz[v];
        mx[x]=max(mx[x],siz[v]);
    }
    mx[x]=max(mx[x],all-siz[x]);
    if(mx[x]<mx[rt]) rt=x;
}
int dis[N],tot=0;
inline void getd(int x,int prev,int d)
{
    dis[++tot]=d;
    fore(i,x) if(v!=prev&&!vis[v]) getd(v,x,d+w);
}
#define IT multiset<int>::iterator
multiset<int> s;
void dfz(int x)
{
    vis[x]=1;
    fore(ei,x) if(!vis[v])
    {
        tot=0,getd(v,x,w);
        for(int i=0;i<m;i++)
        {
            if(ans[i]) continue;
            for(int j=1;j<=tot;j++)
            {
                if(dis[j]==qs[i]){ans[i]=1;break;} // 注意特判
                IT it=s.lower_bound(qs[i]-dis[j]);
                if(it==s.end()) continue;
                if(dis[j]+(*it)==qs[i]){ans[i]=1;break;}
            }
        }
        for(int i=1;i<=tot;i++) s.insert(dis[i]);
    }
    s.clear();
    fore(i,x) if(!vis[v])
    {
        all=mx[rt=0]=siz[v];
        getrt(v,x),dfz(rt);
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1,x,y,w;i<n;i++)
        scanf("%d%d%d",&x,&y,&w),addedge(x,y,w);
    for(int i=1,x;i<=m;i++) scanf("%d",&x),qs.pb(x);
    all=mx[rt=0]=n,getrt(1,0),dfz(rt);
    for(int i=0;i<m;i++) puts(ans[i]?"AYE":"NAY");
    return 0;
}

LuoguP3085 [USACO13OPEN]阴和阳Yin and Yang

这个题也挺神的qwq

给一棵树,每条边有黑白两种颜色,下面用((x,y))表示树上点(x)到点(y)的路径

问有多少条路径((x,y)),满足存在路径上一点(z(z ot= x,y)),使得((x,z))这条路径上的黑白边数量相等,且((z,y))这条路径上的黑白边也相等。

终于看到一个要动脑子的题了

还是淀粉质,考虑如何算重心(x)对答案的贡献,下面我们令(d_v)表示当前分治到的块中,点(x)到点(v)路径上黑白边的数量差,算(d)的话可以把(0)的边边权设为(-1)(1)的边边权设为(1)

发现满足条件的一条经过(x)路径((u,v))一定满足(d_u = -d_v)(u,v)也不能在(x)的同一子树中)

((u,v))上是否存在点(z)呢?

分类讨论一下,点(z)要么在((x,u))上,要么在((x,v))上。

((x,u))上时,一定满足(d_z=d_x)(-d_z=d_v),在((v,x))上类似。

具体的把点分成两类,一类点是在该点(v)(x)的路径上存在(z),使得(d_v=d_z)的,二类点是没有的。

这样一类点可以和一,二类点产生贡献,而二类点只会和一类点产生贡献。

这个可以开几个桶算一下,实现看代码吧

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=4e5+10;
int n;
#define fore(i,x) for(int i=fst[x],v=to[i],c=col[i];i;i=nxt[i],v=to[i],c=col[i])
int fst[N],to[N<<1],nxt[N<<1],col[N<<1],es=0;
inline void ade(int x,int y,int w)
{to[++es]=y,col[es]=w,nxt[es]=fst[x],fst[x]=es;}
inline void addedge(int x,int y,int w)
{
    w=w==1?1:-1;
    ade(x,y,w),ade(y,x,w);
}
int siz[N],vis[N],mx[N],rt,all;
void getrt(int x,int prev)
{
    siz[x]=1,mx[x]=0;
    fore(i,x) if(v!=prev&&!vis[v])
    {
        getrt(v,x),siz[x]+=siz[v];
        mx[x]=max(mx[x],siz[v]);
    }
    mx[x]=max(mx[x],all-siz[x]);
    if(mx[x]<mx[rt]) rt=x;
}
int d[N],fa[N];
void dfs(int x,int prev,int w)
{
    d[x]=w,fa[x]=prev;
    fore(i,x) if(!vis[v]&&v!=prev) dfs(v,x,w+c);
}
int tmp[N<<1],cnt[2][N<<1]; 
/*
 cnt[0][i]表示2类点中,d为i的点的个数
 cnt[1][i]表示1类点中,d为i的点的个数
 tmp是用来遍历的时候确定1,2类点的。
*/
void upd(int x,int k1=1)
{
    if(tmp[n+d[x]]) cnt[1][n+d[x]]+=k1;
    else cnt[0][n+d[x]]+=k1;
    tmp[n+d[x]]++;
    fore(i,x) if(!vis[v]&&v!=fa[x]) upd(v,k1);
    tmp[n+d[x]]--;
}
ll ans=0;
void calc(int x)
{
    ans+=cnt[1][n-d[x]]+cnt[0][n-d[x]]*(tmp[n+d[x]]!=0);
    if(d[x]==0) ans+=tmp[n]>1; 
    /* 
     *这里要特判路径一个端点是重心的情况
     *因为z不能在两个端点上,所以tmp要大于1
    */
    tmp[n+d[x]]++;
    fore(i,x) if(!vis[v]&&v!=fa[x]) calc(v);
    tmp[n+d[x]]--;
}
void dfz(int x)
{
    vis[x]=1,dfs(x,0,0);
    tmp[n]=1; // 注意要把x也放到桶里面
    fore(i,x) if(!vis[v]) calc(v),upd(v); // 这里不用容斥算,枚举子树v,算与之前子树的贡献,然后在把v里面的点全部加入到桶里就好了
    fore(i,x) if(!vis[v]) upd(v,-1); // 清除桶
    tmp[n]=0;
    fore(i,x) if(!vis[v])
    {
        all=siz[v],mx[rt=0]=all;
        getrt(v,x),dfz(rt);
    }
}
int main()
{
    scanf("%d",&n);
    for(int i=1,x,y,w;i<=n-1;i++)
        scanf("%d%d%d",&x,&y,&w),addedge(x,y,w);
    mx[0]=n,all=siz[rt=0]=n;
    getrt(1,0),dfz(rt);
    printf("%lld
",ans);
    return 0;
}

LuoguP2664 树上游戏

嘛....这个题好像有(O(n))的做法...把(n log n)点分治吊起来锤......然而我不会

题意:给一颗树,每个点有自己的颜色,定义(s(i,j))表示点(i)到点(j)的路径上的颜色数量,(sum_i=sumlimits_{j=1}^{n}s(i,j)),然后让你求(sum_1...sum_n)

这玩意儿竟然能点分治...涨姿势了qwq

其实也不难理解,类似cdq分治,点分治在确定一个分治点(x)时,会把当前分治到的联通块分成几棵子树,那么我们要做的就是算这些子树两两之间经过点(x)的贡献就好了(当然还要把(x)的答案也更新一遍)

首先可以知道的一点,对于(x)(下面默认当前联通快以(x)为根)的一个儿子(v),子树(v)里面的一点(u),如果(u)的颜色是在路径(x,u)上第一次出现的话,那么对于子树(v)以外的点,都会产生(size_u)的贡献((size_u)表示(u)的子树的大小)。

那么对于到根的路径上第一次出现的颜色,这个节点(u),我们开个桶(tot),让(tot[col_u])(col_u)表示(u)的颜色)加上(size_u)就好了。

(sum = sum tot[c]),那么对(ans_x)的影响就是(sum),把(ans_x)加上(sum)

下面假设我现在要算点(y)的答案,点(y)在子树(v)中((v)(x)的一个儿子),假设(y)(x)的路径上出现了(k)中颜色,那么这些颜色对(ans_y)的贡献就是(k imes (size_x-size_v)),没有出现的颜色的贡献,就是除了子树(v),其他节点记录在(tot)里的(sum)

然后这个维护一个桶,对联通块(Dfs)记下就好了。

#include <bits/stdc++.h>
using namespace std;
#define ll long long 
#define fore(i,x) for(int i=head[x],v=e[i].to;i;i=e[i].nxt,v=e[i].to)
#define N 111111
int n,col[N];
struct edge
{
    int to,nxt;
}e[N<<1];
int head[N],cnt=0;
inline void ade(int x,int y)
{e[++cnt]=(edge){y,head[x]},head[x]=cnt;}
inline void addedge(int x,int y){ade(x,y),ade(y,x);}
int sz[N],vis[N],rt,mx[N],all;
void getrt(int x,int prev)
{
    sz[x]=1,mx[x]=0;
    fore(i,x) if(!vis[v]&&v!=prev)
    {
        getrt(v,x),sz[x]+=sz[v];
        mx[x]=max(mx[x],sz[v]);
    }
    mx[x]=max(mx[x],all-sz[x]);
    if(mx[x]<mx[rt]) rt=x;
}
int siz[N];
void dfs(int x,int prev)
{
    siz[x]=1; fore(i,x) if(!vis[v]&&v!=prev)
        dfs(v,x),siz[x]+=siz[v];
}
ll tot[N],sum; int tmp[N]; // tmp桶来算颜色个数以及是否是第一次出现
void upd(int x,int prev,int k1) // 跟新一棵子树对tot的贡献
{
    int c=col[x]; 
    if(!tmp[c]) tot[c]+=siz[x]*k1,sum+=siz[x]*k1;
    tmp[c]++;
    fore(i,x) if(!vis[v]&&v!=prev)
        upd(v,x,k1);
    tmp[c]--;
}
void clear(int x,int prev)
{
    tot[col[x]]=0;
    fore(i,x) if(!vis[v]&&v!=prev) clear(v,x);
}
ll nw,osiz,ans[N];
/*
 * osiz : size_x - size_v
 * nw : 当前点到x的颜色个数
*/
void getans(int x,int prev)
{
    int c=col[x];
    if(!tmp[c]) nw++,sum-=tot[c];
    ans[x]+=nw*osiz+sum;
    tmp[c]++; fore(i,x) if(!vis[v]&&v!=prev)
        getans(v,x);
    tmp[c]--;
    if(!tmp[c]) nw--,sum+=tot[c];
}
void calc(int x) // 算x对ans的贡献
{
    dfs(x,0),upd(x,0,1);
    ans[x]+=sum; int c=col[x];
    fore(i,x) if(!vis[v])
    {
        sum-=siz[v],tot[c]-=siz[v];
        tmp[c]++,upd(v,x,-1),tmp[c]--; // 注意要把子树v对tot和sum的贡献清除在算答案
        tmp[c]++,osiz=siz[x]-siz[v],sum-=tot[c],nw=1;
        getans(v,x);
        tmp[c]--,nw=0,sum+=tot[c];  // 撤销操作
        sum+=siz[v],tot[c]+=siz[v];
        tmp[c]++,upd(v,x,1),tmp[c]--;
    }
    clear(x,0);
    sum=0,osiz=0,nw=0;
}
void dfz(int x)
{
    calc(x),vis[x]=1;
    fore(i,x) if(!vis[v])
    {
        mx[rt=0]=all=sz[v];
        getrt(v,x),dfz(rt);
    }
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++) scanf("%d",&col[i]);
    for(int i=1,x,y;i<n;i++)
        scanf("%d%d",&x,&y),addedge(x,y);
    mx[rt=0]=all=n,getrt(1,0),dfz(rt);
    for(int i=1;i<=n;i++) printf("%lld
",ans[i]);
    return 0;
}

到这里先咕咕咕吧。。。【咕~】

原文地址:https://www.cnblogs.com/wxq1229/p/12304021.html