点分治练习

点分治

P3806 【模板】点分治1

题目链接

解题思路:

点分治。这里对每次询问都重新做一遍完整的点分治,复杂度偏高(每次询问约 $O(n\log^2 n)$)。在每个分治中心处求出各节点到该中心的距离,排序后用左右双指针统计两段距离之和恰为 $k$ 的点对数(距离值相同的元素必须按出现次数成组计数,否则计数不准),再减去两端落在同一棵子树内的点对数。

#include <bits/stdc++.h>
using namespace std;
/*    freopen("k.in", "r", stdin);
    freopen("k.out", "w", stdout); */
// clock_t c1 = clock();
// std::cerr << "Time:" << clock() - c1 <<"ms" << std::endl;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
// Competitive-programming helpers shared by the solution below.
#define de(a) cout << #a << " = " << (a) << endl      // debug-print an expression and its value
#define rep(i, a, n) for (int i = (a); i <= (n); i++) // inclusive ascending loop a..n
#define per(i, a, n) for (int i = (n); i >= (a); i--) // inclusive descending loop n..a
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
// Vector of int pairs. (vector<int, int> would pass `int` as the allocator
// template argument, which is ill-formed as soon as the type is used.)
typedef vector<PII> VII;
#define inf 0x3f3f3f3f             // "infinite" int; init() uses it for the mx[0] centroid sentinel
const ll INF = 0x3f3f3f3f3f3f3f3f; // "infinite" long long
const ll MAXN = 1e5 + 7;           // vertex capacity (edge array holds MAXN << 1 entries)
const ll MAXM = 1e6 + 7;
const ll MOD = 1e9 + 7;
const double eps = 1e-6;
const double pi = acos(-1.0);
// Forward-star adjacency list: each vertex's outgoing edges form a singly linked
// list threaded through e[].net, starting at head[vertex] (-1 terminates the list).
struct Edge
{
    int u;   // tail vertex
    int v;   // head vertex
    int val; // edge weight
    int net; // index of the next edge leaving u, or -1
};
Edge e[MAXN << 1];
int cnt = -1;   // index of the most recently written edge
int head[MAXN]; // first edge of each vertex (-1 after init())
int n, m, sum;  // vertex count, query count, size of the component being searched
int k;          // query distance

// Appends the directed edge u -> v of weight val at the front of u's list.
void add(int u, int v, int val)
{
    Edge &slot = e[++cnt];
    slot.u = u;
    slot.v = v;
    slot.val = val;
    slot.net = head[u];
    head[u] = cnt;
}
// mx[x]: largest piece left if x were removed from the current component.
// vis[x]: set once x has been used as a centroid.
// sz[x]: subtree size of x in the most recent dfs.
int mx[MAXN], vis[MAXN], sz[MAXN];
int rt;  // best centroid candidate so far; 0 is a sentinel (init() sets mx[0] = inf)
int ans; // accumulated pair count for the current query

// Computes subtree sizes below `now` (skipping removed vertices) and moves rt to the
// vertex whose largest remaining piece is smallest. `sum` must be the component size.
void dfs(int now, int fa)
{
    sz[now] = 1;
    mx[now] = 0;
    for (int i = head[now]; i != -1; i = e[i].net)
    {
        int to = e[i].v;
        if (to == fa || vis[to])
            continue;
        dfs(to, now);
        sz[now] += sz[to];
        if (sz[to] > mx[now])
            mx[now] = sz[to];
    }
    int rest = sum - sz[now]; // the piece on the parent's side
    if (rest > mx[now])
        mx[now] = rest;
    if (mx[now] < mx[rt])
        rt = now;
}
int tot;        // number of distances stored in dis[1..tot]
int dis[MAXN];  // distances collected by solve()/getdis()

// Appends the distance of every unvisited vertex below `now` (not `now` itself)
// to dis[]; `len` is the distance already assigned to `now`.
void getdis(int now, int fa, int len)
{
    for (int i = head[now]; i != -1; i = e[i].net)
    {
        const int to = e[i].v;
        if (to == fa || vis[to])
            continue;
        const int d = len + e[i].val;
        dis[++tot] = d;
        getdis(to, now, d);
    }
}
int solve(int now, int len)
{
    int ret = 0;
    dis[tot = 1] = len;
    getdis(now, -1, len);
    int l = 1, r = tot;
    sort(dis + 1, dis + 1 + tot);
    while (l < r)
    {
        if (dis[l] + dis[r] == k)
        {
            l++, r--;
            ret++;
        }
        else if (dis[l] + dis[r] < k)
            l++;
        else
            r--;
    }
    // cout << ret << endl;
    return ret;
}
// Handles the component whose centroid is `now`: it adds the pair count for paths
// through `now`, marks `now` removed, and recurses into each remaining
// sub-component through that sub-component's own centroid.
void divide(int now)
{
    vis[now] = 1; // remove `now` from the tree for every deeper level
    // pairs among all vertices of the component, distances measured from `now`
    ans += solve(now, 0);
    for (int i = head[now]; ~i; i = e[i].net)
    {
        int v = e[i].v;
        if (vis[v])
            continue;
        // Remove pairs lying entirely inside v's subtree: their real path does not pass
        // through `now`. Starting v at distance e[i].val reproduces exactly the distances
        // those vertices had in solve(now, 0).
        ans -= solve(v, e[i].val);
        rt = 0; // reset to the mx[0] = inf sentinel before searching a new centroid
        // NOTE(review): sz[v] comes from the dfs that selected `now`. If v was `now`'s
        // parent in that dfs, sz[v] overestimates the component size. This only affects
        // which vertex dfs picks as centroid, not which pairs are counted.
        sum = sz[v];
        dfs(v, now);
        divide(rt);
    }
}
void init()
{
    memset(head, -1, sizeof(head));
    cnt = -1;
    rt = 0;
    mx[0] = inf;
}
// Reads the tree and m queries. For each query k it runs a fresh centroid
// decomposition and prints "AYE" when divide() accumulated a positive number of
// vertex pairs at distance k, otherwise "NAY".
int main()
{
    if (scanf("%d%d", &n, &m) != 2)
        return 0;
    init();
    for (int i = 0; i < n - 1; i++)
    {
        int u, v, val;
        scanf("%d%d%d", &u, &v, &val);
        add(u, v, val);
        add(v, u, val);
    }
    for (int i = 0; i < m; i++)
    {
        scanf("%d", &k);
        // each query redoes the whole decomposition, so centroid marks are cleared
        memset(vis, 0, sizeof(vis));
        ans = 0;
        rt = 0;
        sum = n;
        dfs(1, -1);
        divide(rt);
        // ans is a pair count, so only a positive value means such a path exists
        printf("%s\n", ans > 0 ? "AYE" : "NAY");
    }
    return 0;
}

P2634 [国家集训队]聪聪可可

聪聪可可

解题思路:

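在上题代码的基础上修改 solve 函数即可。把各节点到根结点的距离模 3,分别统计余数为 0、1、2 的节点个数 dis[0]、dis[1]、dis[2]。两端距离之和为 3 的倍数的有序点对(含 $u=v$)数为 $dis_0^2 + 2\,dis_1\,dis_2$:余数 0 与 0 配对,余数 1 与 2 配对(两种顺序)。最终概率为该总数除以 $n^2$ 后约分。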

#include <bits/stdc++.h>
using namespace std;
/*    freopen("k.in", "r", stdin);
    freopen("k.out", "w", stdout); */
// clock_t c1 = clock();
// std::cerr << "Time:" << clock() - c1 <<"ms" << std::endl;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
// Competitive-programming helpers shared by the solution below.
#define de(a) cout << #a << " = " << (a) << endl      // debug-print an expression and its value
#define rep(i, a, n) for (int i = (a); i <= (n); i++) // inclusive ascending loop a..n
#define per(i, a, n) for (int i = (n); i >= (a); i--) // inclusive descending loop n..a
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
// Vector of int pairs. (vector<int, int> would pass `int` as the allocator
// template argument, which is ill-formed as soon as the type is used.)
typedef vector<PII> VII;
#define inf 0x3f3f3f3f             // "infinite" int; init() uses it for the mx[0] centroid sentinel
const ll INF = 0x3f3f3f3f3f3f3f3f; // "infinite" long long
const ll MAXN = 2e5 + 7;           // vertex capacity (edge array holds MAXN << 1 entries)
const ll MAXM = 1e6 + 7;
const ll MOD = 1e9 + 7;
const double eps = 1e-6;
const double pi = acos(-1.0);
// Forward-star adjacency list: each vertex's outgoing edges form a singly linked
// list threaded through e[].net, starting at head[vertex] (-1 terminates the list).
struct Edge
{
    int u;   // tail vertex
    int v;   // head vertex
    int val; // edge weight (already reduced mod 3 by main)
    int net; // index of the next edge leaving u, or -1
};
Edge e[MAXN << 1];
int cnt = -1;   // index of the most recently written edge
int head[MAXN]; // first edge of each vertex (-1 after init())
int n;          // vertex count
int sum;        // size of the component being searched for a centroid

// Appends the directed edge u -> v of weight val at the front of u's list.
void add(int u, int v, int val)
{
    Edge &slot = e[++cnt];
    slot.u = u;
    slot.v = v;
    slot.val = val;
    slot.net = head[u];
    head[u] = cnt;
}
// mx[x]: largest piece left if x were removed from the current component.
// vis[x]: set once x has been used as a centroid.
// sz[x]: subtree size of x in the most recent dfs.
int mx[MAXN], vis[MAXN], sz[MAXN];
int rt;  // best centroid candidate so far; 0 is a sentinel (init() sets mx[0] = inf)
ll ans;  // number of ordered pairs whose path length is a multiple of 3

// Computes subtree sizes below `now` (skipping removed vertices) and moves rt to the
// vertex whose largest remaining piece is smallest. `sum` must be the component size.
void dfs(int now, int fa)
{
    sz[now] = 1;
    mx[now] = 0;
    for (int i = head[now]; i != -1; i = e[i].net)
    {
        int to = e[i].v;
        if (to == fa || vis[to])
            continue;
        dfs(to, now);
        sz[now] += sz[to];
        if (sz[to] > mx[now])
            mx[now] = sz[to];
    }
    int rest = sum - sz[now]; // the piece on the parent's side
    if (rest > mx[now])
        mx[now] = rest;
    if (mx[now] < mx[rt])
        rt = now;
}
// dis[r]: how many collected vertices have distance ≡ r (mod 3); only 0..2 are used.
ll dis[4];

// Adds every unvisited vertex below `now` (not `now` itself) to the residue
// histogram; `len` is the distance already assigned to `now`.
void getdis(int now, int fa, int len)
{
    for (int i = head[now]; i != -1; i = e[i].net)
    {
        const int to = e[i].v;
        if (to == fa || vis[to])
            continue;
        const int d = len + e[i].val;
        ++dis[d % 3];
        getdis(to, now, d);
    }
}
// Builds the mod-3 distance histogram for `now` (at distance `len`) and everything
// below it. Returns the number of ordered pairs (x, y), x == y allowed, whose
// distances sum to a multiple of 3: residue 0 with 0, or 1 with 2 in either order.
ll solve(int now, int len)
{
    for (int r = 0; r < 3; ++r)
        dis[r] = 0;
    ++dis[len % 3];
    getdis(now, -1, len);
    const ll zero = dis[0];
    const ll one = dis[1];
    const ll two = dis[2];
    return zero * zero + 2 * one * two;
}
// Handles the component whose centroid is `now`: it adds the ordered pairs whose
// path passes through `now`, marks `now` removed, and recurses into each remaining
// sub-component through that sub-component's own centroid.
void divide(int now)
{
    vis[now] = 1; // remove `now` from the tree for every deeper level
    // ordered pairs over the whole component, distances measured from `now`
    ans += solve(now, 0);
    for (int i = head[now]; ~i; i = e[i].net)
    {
        int v = e[i].v;
        if (vis[v])
            continue;
        // Remove pairs lying entirely inside v's subtree: their real path does not pass
        // through `now`. Starting v at distance e[i].val reproduces exactly the distances
        // those vertices had in solve(now, 0).
        ans -= solve(v, e[i].val);
        rt = 0; // reset to the mx[0] = inf sentinel before searching a new centroid
        // NOTE(review): sz[v] comes from the dfs that selected `now`. If v was `now`'s
        // parent in that dfs, sz[v] overestimates the component size. This only affects
        // which vertex dfs picks as centroid, not which pairs are counted.
        sum = sz[v];
        dfs(v, now);
        divide(rt);
    }
}
void init()
{
    memset(vis, 0, sizeof(vis));
    memset(head, -1, sizeof(head));
    cnt = -1;
    rt = 0;
    mx[0] = inf;
}
// Greatest common divisor by Euclid's algorithm, iterative form; GCD(a, 0) == a.
long long GCD(long long a, long long b)
{
    while (b != 0)
    {
        const long long r = a % b;
        a = b;
        b = r;
    }
    return a;
}
int main()
{
    scanf("%d", &n);
    init();
    for (int i = 0; i < n - 1; i++)
    {
        int u, v, val;
        scanf("%d%d%d", &u, &v, &val);
        val %= 3;
        add(u, v, val);
        add(v, u, val);
    }
    sum = n;
    dfs(1, -1);
    divide(rt);
    ll div = GCD(ans, 1LL * n * n);
    printf("%lld/%lld
", ans / div, 1LL * n * n / div);
    // system("pause");
    return 0;
}

PKU-1741 Tree

Tree

解题思路:

点分治 + 二分:在每个分治中心把距离排序后,对每个距离 $d_i$ 二分查找满足 $d_i + d_j \le k$ 的 $j$ 的个数,总复杂度 $O(n\log^2 n)$。

#include <algorithm>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <iostream>
#include <cstdlib>
#include <set>
#include <vector>
#include <cctype>
#include <iomanip>
#include <sstream>
#include <climits>
#include <queue>
#include <stack>
using namespace std;
/*    freopen("k.in", "r", stdin);
    freopen("k.out", "w", stdout); */
// clock_t c1 = clock();
// std::cerr << "Time:" << clock() - c1 <<"ms" << std::endl;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
// Competitive-programming helpers shared by the solution below.
#define de(a) cout << #a << " = " << (a) << endl      // debug-print an expression and its value
#define rep(i, a, n) for (int i = (a); i <= (n); i++) // inclusive ascending loop a..n
#define per(i, a, n) for (int i = (n); i >= (a); i--) // inclusive descending loop n..a
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
// Vector of int pairs. (vector<int, int> would pass `int` as the allocator
// template argument, which is ill-formed as soon as the type is used.)
typedef vector<PII> VII;
#define inf 0x3f3f3f3f             // "infinite" int; init() uses it for the mx[0] centroid sentinel
const ll INF = 0x3f3f3f3f3f3f3f3f; // "infinite" long long
const ll MAXN = 1e5 + 7;           // vertex capacity (edge array holds MAXN << 1 entries)
const ll MAXM = 1e6 + 7;
const ll MOD = 1e9 + 7;
const double eps = 1e-6;
const double pi = acos(-1.0);
// Forward-star adjacency list: each vertex's outgoing edges form a singly linked
// list threaded through e[].net, starting at head[vertex] (-1 terminates the list).
struct Edge
{
    int u, v, val, net; // tail, head, weight, next edge leaving u (-1 = none)
    Edge(int _u = 0, int _v = 0, int _val = 0, int _net = 0) : u(_u), v(_v), val(_val), net(_net) {}
};
Edge e[MAXN << 1];
int cnt = -1;   // index of the most recently written edge
int head[MAXN]; // first edge of each vertex (-1 after init())

// Appends the directed edge u -> v of weight val at the front of u's list.
void add(int u, int v, int val)
{
    ++cnt;
    e[cnt] = Edge(u, v, val, head[u]);
    head[u] = cnt;
}
// vis[x]: set once x has been used as a centroid.
// sz[x]: subtree size of x in the most recent dfs.
// mx[x]: largest piece left if x were removed from the current component.
int vis[MAXN], sz[MAXN], mx[MAXN];
int n, k, sum, rt; // vertices, distance bound, component size, best centroid (0 = sentinel)

// Computes subtree sizes below `now` (skipping removed vertices) and moves rt to the
// vertex whose largest remaining piece is smallest. `sum` must be the component size.
void dfs(int now, int fa)
{
    sz[now] = 1;
    mx[now] = 0;
    for (int i = head[now]; i != -1; i = e[i].net)
    {
        int to = e[i].v;
        if (to == fa || vis[to])
            continue;
        dfs(to, now);
        sz[now] += sz[to];
        if (sz[to] > mx[now])
            mx[now] = sz[to];
    }
    int rest = sum - sz[now]; // the piece on the parent's side
    if (rest > mx[now])
        mx[now] = rest;
    if (mx[now] < mx[rt])
        rt = now;
}
int tot;       // number of distances stored in dis[1..tot]
int ans;       // answer for the current test case
int dis[MAXN]; // distances collected by solve()/getdis()
// The routines below count vertex pairs whose path length is at most k.

// Appends the distance of every unvisited vertex below `now` (not `now` itself)
// to dis[]; `len` is the distance already assigned to `now`.
void getdis(int now, int fa, int len)
{
    for (int i = head[now]; i != -1; i = e[i].net)
    {
        const int to = e[i].v;
        if (to == fa || vis[to])
            continue;
        const int d = len + e[i].val;
        dis[++tot] = d;
        getdis(to, now, d);
    }
}
// Collects the distance of `now` (given as `len`) and of every unvisited vertex
// below it, sorts them, and returns the number of unordered pairs {i, j}, i != j,
// with dis[i] + dis[j] <= k. For each i, a binary search counts the partners;
// i itself is excluded, and each pair is seen twice, hence the final halving.
int solve(int now, int len)
{
    dis[tot = 1] = len;
    getdis(now, -1, len);
    int *first = dis + 1;
    int *last = dis + 1 + tot;
    sort(first, last);
    int orderedPairs = 0;
    for (int idx = 1; idx <= tot; ++idx)
    {
        // positions 1..within hold every value <= k - dis[idx]
        const int within = int(upper_bound(first, last, k - dis[idx]) - first);
        orderedPairs += within - (within >= idx ? 1 : 0);
    }
    return orderedPairs / 2;
}
// Handles the component whose centroid is `now`: it adds the pairs with length <= k
// whose path passes through `now`, marks `now` removed, and recurses into each
// remaining sub-component through that sub-component's own centroid.
void divide(int now)
{
    vis[now] = 1; // remove `now` from the tree for every deeper level
    // pairs over the whole component, distances measured from `now`
    ans += solve(now, 0);
    for (int i = head[now]; ~i; i = e[i].net)
    {
        int v = e[i].v;
        if (vis[v])
            continue;
        // Remove pairs lying entirely inside v's subtree: their real path does not pass
        // through `now`. Starting v at distance e[i].val reproduces exactly the distances
        // those vertices had in solve(now, 0).
        ans -= solve(v, e[i].val);
        rt = 0; // reset to the mx[0] = inf sentinel before searching a new centroid
        // NOTE(review): sz[v] comes from the dfs that selected `now`. If v was `now`'s
        // parent in that dfs, sz[v] overestimates the component size. This only affects
        // which vertex dfs picks as centroid, not which pairs are counted.
        sum = sz[v];
        dfs(v, now);
        divide(rt);
    }
}
void init()
{
    memset(head, -1, sizeof(head));
    memset(vis, 0, sizeof(vis));
    cnt = -1;
    rt = 0;
    sum = n;
    ans = 0;
    mx[0] = inf;
}
int main()
{
    while (~scanf("%d%d", &n, &k) && n + k)
    {
        init();
        for (int i = 0; i < n - 1; i++)
        {
            int u, v, val;
            scanf("%d%d%d", &u, &v, &val);
            add(u, v, val);
            add(v, u, val);
        }
        dfs(1, -1);
        divide(rt);
        printf("%d
", ans);
    }
    return 0;
}
/* 
5 4
1 2 3
1 3 1
1 4 2
3 5 1 

5 4
1 2 1
2 3 2
3 4 3
4 5 2
*/
原文地址:https://www.cnblogs.com/graytido/p/11872481.html