这个题能改变成加法。
第三个while始终觉得区间变长了,看了一下午还是很难理解。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 50;
typedef long long ll;
int a[maxn];
struct node
{
int l, r;
int block;
int id;
};
node q[maxn];
bool cmp(node A, node B)
{
if(A.block == B.block) return A.r < B.r;
else return A.l < B.l;
}
ll out[maxn];
int sum[maxn];
ll cnt[1 << 20];
int main()
{
int n, m, k; scanf("%d %d %d", &n, &m, &k);
for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
for(int i = 1; i <= n; i++)
{
sum[i] = sum[i - 1] ^ a[i];
}
int block = sqrt(n);
for(int i = 1; i <= m; i++)
{
scanf("%d %d", &q[i].l, &q[i].r);
q[i].block = q[i].l / block;
q[i].id = i;
}
sort(q + 1, q + m + 1, cmp);
memset(cnt, 0, sizeof(cnt));
int l = 1, r = 0;
ll ans = 0;
cnt[0]++;
for(int i = 1; i <= m; i++)
{
while(r < q[i].r)
{
r++;
ans += cnt[sum[r] ^ k];
// printf("%d %d %d
", sum[r], sum[r] ^ k, cnt[sum[r] ^ k]);
cnt[sum[r]]++;
}
//printf("%d
", ans);
while(r > q[i].r)
{
cnt[sum[r]]--; ///先把我自己去掉,以免产生影响
ans -= cnt[sum[r] ^ k];
r--;
}
while(l < q[i].l)
{
cnt[sum[l - 1]]--; ///也是先把我自己去掉
ans -= cnt[sum[l - 1] ^ k];
l++;
}
while(l > q[i].l)
{
l--;
ans += cnt[sum[l - 1] ^ k];
cnt[sum[l - 1]]++;
}
out[q[i].id] = (ans >= 0LL ? ans : 0LL);
}
for(int i = 1; i <= m; i++)
{
printf("%lld
", out[i]);
}
return 0;
}
/*
6 1 3
1 2 1 1 0 3
3 4
*/
今天各种保研事情,没做上,明天继续补。
---------------------------------------------------
首先对于这颗树,如果我们去掉的是$u->v$这条边,那么我们考虑最优解是:
$sum=(u的子树内任意两点的和)+(v的子树内任意两点的和)+(u的子树内,所有点到某个点距离和的最小值 imes v的子树个数)+(v的子树内,所有点到某个点距离和的最小值 imes u的子树个数)+(u的子树个数 imes v的子树个数 imes 这条路径长度)$.
一个$O(n^3)$的做法:枚举删掉的边,还有枚举补上的边,再跑一遍。
在找新边的时候,前两项是定值,我们只需要计算后三项的最小值,要统计每个点到它子树的距离,然后$O(n^2)$枚举即可。
那么现在我们要求的就是,每个点它子树内的任意两点和,每个点到它子树的距离, 每个点的子树个数。
每个点到它子树内的任意两点和,枚举边的贡献即可,顺便计算点总数。
每个点的子树内,所有点到某个点距离和的最小值,先求以根的和,dfs处理的时候再转换一下。
(看网上都说 求某颗树内所有点到某个点距离和的最小值是经典套路……)
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int maxn = 5e3 + 5; struct node { int v, w; }; vector<node> g[maxn]; int u[maxn], v[maxn], w[maxn]; int son[maxn]; ll d[maxn]; ///树内所有点到某个点的距离和 void dfs1(int u, int fa) { son[u] = 0; son[u]++; for(int i = 0; i < (int)g[u].size(); i++) { int v = g[u][i].v; if(v == fa) continue; dfs1(v, u); son[u] += son[v]; } } int tot = 0; void dfs2(int u, int fa, ll & sum1) { d[u] = 0; for(int i = 0; i < (int)g[u].size(); i++) { int v = g[u][i].v; if(v == fa) continue; dfs2(v, u, sum1); sum1 += (ll)son[v] * (tot - son[v]) * g[u][i].w; ///任意两点之间的距离和 d[u] += d[v] + (ll)son[v] * g[u][i].w; } // printf("%d %I64d ", u, d[u]); } ll minv = 0; void dfs3(int u, int fa) { minv = min(minv, d[u]); for(int i = 0; i < (int)g[u].size(); i++) { int v = g[u][i].v; if(v == fa) continue; d[v] = d[v] + (d[u] - d[v] - (ll)son[v] * g[u][i].w) + (ll)(son[u] - son[v]) * g[u][i].w; son[v] = son[u]; dfs3(v, u); } } int main() { int n; scanf("%d", &n); for(int i = 1; i <= n - 1; i++) { scanf("%d %d %d", &u[i], &v[i], &w[i]); g[u[i]].push_back({v[i], w[i]}); g[v[i]].push_back({u[i], w[i]}); } ll ans = 1e18; for(int i = 1; i <= n - 1; i++) { ll tmp = 0; dfs1(u[i], v[i]); ll sum1 = 0; tot = son[u[i]]; dfs2(u[i], v[i], sum1); tmp += sum1; minv = 1e18; dfs3(u[i], v[i]); int cntu = son[u[i]]; dfs1(v[i], u[i]); sum1 = 0; tot = son[v[i]]; dfs2(v[i], u[i], sum1); tmp += sum1; tmp += minv * son[v[i]]; minv = 1e18; dfs3(v[i], u[i]); tmp += minv * cntu; tmp += (ll)cntu * son[v[i]] * w[i]; //printf("%I64d ", tmp); ans = min(ans, tmp); } printf("%I64d ", ans); return 0; }
写的有点乱,再整理一下。