HDU 4812:D Tree(树上点分治+逆元)

题目链接

题意

给一棵树,每个点上有一个权值,问是否存在一条路径(不能是单个点)上的所有点相乘并对1e6+3取模等于k,输出路径的两个端点。如果存在多组答案,输出字典序小的点对。

思路

首先,(a * b) % MOD = k,知道a和k,求b,可以使用逆元来求,于是可以想到用一个类似于map的东西(我这里的Hash数组,记录值为i的时候它的最小下标是多少)存路径长度为b的时候,那个端点是哪个点。

但是我一开始是想着先全部处理好,然后再O(MOD)判一遍,但是发现这种做法的话在有删除的情况下难以解决。

于是要考虑一边更新一边删除。对于当前的根结点的不同子树,分开处理。先判断再更新(因为这样才可以使得路径的两个端点在不同的子树上),最后做完这棵树后要删除。

学习到了逆元打表的写法:

inv[1] = 1LL;
for(LL i = 2; i < MOD; i++)
    inv[i] = inv[MOD % i] * (MOD - MOD / i) % MOD;
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
const int MOD = 1e6 + 3;
const int INF = 0x3f3f3f3f;
typedef long long LL;
struct Edge {
    int v, nxt;
} edge[N*2];
LL dis[N], w[N], dep[N], inv[MOD + 11];
int n, k, head[N], tot, root, sum, son[N], vis[N], f[N], ans1, ans2;
int Hash[MOD + 11], id[MOD + 11];

void Add(int u, int v) {
    edge[tot] = (Edge) { v, head[u] }; head[u] = tot++;
    edge[tot] = (Edge) { u, head[v] }; head[v] = tot++;
}

void getroot(int u, int fa) {
    son[u] = 1; f[u] = 0;
    for(int i = head[u]; ~i; i = edge[i].nxt) {
        int v = edge[i].v;
        if(vis[v] || fa == v) continue;
        getroot(v, u);
        son[u] += son[v];
        f[u] = max(f[u], son[v]);
    }
    f[u] = max(f[u], sum - son[u]);
    if(f[u] < f[root]) root = u;
}

void getdeep(int u, int fa) {
    dep[++dep[0]] = dis[u] % MOD; id[dep[0]] = u;
    for(int i = head[u]; ~i; i = edge[i].nxt) {
        int v = edge[i].v;
        if(vis[v] || fa == v) continue;
        dis[v] = dis[u] * w[v] % MOD;
        getdeep(v, u);
    }
}

void update(LL now, int x) {
    now = inv[now] * k % MOD;
    int y = Hash[now];
    if(y == INF) return ;
    if(x > y) swap(x, y);
    if(x < ans1 || (x == ans1 && y < ans2)) ans1 = x, ans2 = y;
}

int cal(int u, int st) {
    if(st) {
        Hash[w[u]] = u;
        for(int i = head[u]; ~i; i = edge[i].nxt) {
            int v = edge[i].v;
            if(vis[v]) continue;
            // 判断是否存在答案
            dep[0] = 0;
            dis[v] = w[v] % MOD;
            getdeep(v, u);
            for(int j = 1; j <= dep[0]; j++)
                update(dep[j], id[j]);
            // 更新Hash表
            dep[0] = 0;
            dis[v] = w[u] * w[v] % MOD;
            getdeep(v, u);
            for(int j = 1; j <= dep[0]; j++)
                Hash[dep[j]] = min(Hash[dep[j]], id[j]);
        }
    } else {
        // 删除操作
        Hash[w[u]] = INF;
        for(int i = head[u]; ~i; i = edge[i].nxt) {
            int v = edge[i].v;
            if(vis[v]) continue;
            dep[0] = 0;
            dis[v] = w[u] * w[v] % MOD;
            getdeep(v, u);
            for(int j = 1; j <= dep[0]; j++)
                Hash[dep[j]] = INF;
        }
    }
}

void work(int u) {
    vis[u] = 1;
    cal(u, 1);
    cal(u, 0);
    for(int i = head[u]; ~i; i = edge[i].nxt) {
        int v = edge[i].v;
        if(vis[v]) continue;
        sum = son[v];
        getroot(v, root = 0);
        work(root);
    }
}

LL f_pow(LL a, int b) {
    LL ans = 1;
    while(b) {
        if(b & 1) ans = ans * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    } return ans % MOD;
}

int main() {
//    for(LL i = 0; i < MOD; i++) inv[i] = f_pow(i, MOD - 2);
    inv[1] = 1LL;
    for(LL i = 2; i < MOD; i++)
        inv[i] = inv[MOD % i] * (MOD - MOD / i) % MOD;

    while(~scanf("%d%d", &n, &k)) {
        memset(head, -1, sizeof(head));
        memset(Hash, INF, sizeof(Hash));
        memset(vis, 0, sizeof(vis));
        tot = 0;
        for(int i = 1; i <= n; i++) scanf("%lld", &w[i]);
        for(int i = 1; i < n; i++) {
            int u, v; scanf("%d%d", &u, &v);
            Add(u, v);
        }
        sum = n, root = 0, f[0] = ans1 = ans2 = INF;
        getroot(1, 0);
        work(root);
        if(ans1 == INF || ans2 == INF) puts("No solution");
        else printf("%d %d
", ans1, ans2);
    }
    return 0;
}

/*
5 60
2 5 2 3 3
1 2
1 3
2 4
2 5
5 2
2 5 2 3 3
1 2
1 3
2 4
2 5
*/

原文地址:https://www.cnblogs.com/fightfordream/p/7607197.html