Weak Pair---hud5877大连网选(线段树优化+dfs)

题目链接:http://acm.split.hdu.edu.cn/showproblem.php?pid=5877

 题意:给你一颗树,有n个节点,每个节点都有一个权值v[i];现在求有多少对(u,v)满足u是v的祖先,并且au*av<=k, k是已知的;

思路:从根节点开始dfs遍历整棵树,当遍历到某点u时,已经在栈中的节点都是u的祖先的,所以我们只要找到在栈中的节点有多少个是<=k/a[u]的即可;

由于n的值最大可达到10e5,所以直接查找是会TLE的,我们可以用线段树优化即可;在dfs的时候插入当前节点的权值,在回溯的时候删除节点即可;

#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<iostream>
#include<vector>
#include<queue>
using namespace std;
#define met(a, b) memset(a, b, sizeof(a))
#define N 100005
typedef long long LL;
#define Lson r<<1
#define Rson r<<1|1

vector<int> G[N];
LL v[N], t[2*N], ans, k;
int du[N], m;

struct node
{
    int L, R;
    LL sum;
    int mid(){ return (L+R)/2; }
} a[N*8];

void Build(int r, int L, int R)
{
    a[r].L = L, a[r].R = R, a[r].sum = 0;
    if(L == R) return ;
    Build(Lson, L, a[r].mid());
    Build(Rson, a[r].mid()+1, R);
}

void Update(int r, int pos, LL num)
{
    if(a[r].L == a[r].R && a[r].L == pos)
    {
        a[r].sum += num;
        return ;
    }
    if(pos <= a[r].mid())
        Update(Lson, pos, num);
    else
        Update(Rson, pos, num);

    a[r].sum = a[Lson].sum + a[Rson].sum;
}

LL Query(int r, int L, int R)
{
    if(a[r].L == L && a[r].R == R)
        return a[r].sum;
    if(R <= a[r].mid())
        return Query(Lson, L, R);
    else if(L > a[r].mid())
        return Query(Rson, L, R);
    else
        return Query(Lson, L, a[r].mid()) + Query(Rson, a[r].mid()+1, R);
}

void dfs(int u)
{
    int pos1 = lower_bound(t+1, t+m+1, v[u]) - t;
    int pos2 = lower_bound(t+1, t+m+1, k/v[u]) - t;

    ans += Query(1, 1, pos2);

    Update(1, pos1, 1ll);

    for(int i=0, len=G[u].size(); i<len; i++)
        dfs(G[u][i]);

    Update(1, pos1, -1ll);
}

int main()
{
    int T, n;
    while(scanf("%d", &T) != EOF)
    {
        while(T--)
        {
            scanf("%d %I64d", &n, &k);
            for(int i=1; i<=n; i++)
            {
                scanf("%I64d", &v[i]);
                t[i] = v[i];
                t[i+n] = k/t[i];
                G[i].clear();
                du[i] = 0;
            }
            for(int i=1; i<n; i++)
            {
                int u, v;
                scanf("%d %d", &u, &v);
                G[u].push_back(v);
                du[v] ++;
            }
            sort(t+1, t+n*2+1);
            m = unique(t+1, t+n*2+1)-t-1;
            Build(1, 1, m);
            ans = 0;
            for(int i=1; i<=n; i++)
            {
                if(du[i] == 0)
                {
                    dfs(i);
                    ///break;
                }
            }
            printf("%I64d
", ans);
        }
    }
    return 0;
}
/*
5
5 10
1 2 3 4 5
2 3
2 5
3 4
3 1

*/
View Code
原文地址:https://www.cnblogs.com/zhengguiping--9876/p/5874826.html