[洛谷P5342][TJOI2019]甲苯先生的线段树(数位dp)

Address

Luogu#5342

Solution

  • 对于 \(c=1\),由于路径长度为 \(O(d)\) 级别,只要知道 \(lca(x,y)\) 就是 \(x,y\) 二进制下的 \(lcp\) 就可以做了。
  • 对于 \(c=2\),先和 \(c=1\) 一样求出路径编号和,然后要知道一个显然性质(好像也不显然)和一个神仙性质:
    \((1).\)\(x\) 二进制中 \(1\) 的个数为 \(cnt(x)\),则 \(x\) 到根的路径编号和为 \(2x-cnt(x)\)注意这个性质在根为 \(0\) 时也成立,只要 \(x\) 的儿子还是 \(2x,2x+1\)
    \((2).\) 记路径的编号和为 \(s\),路径的两个端点为 \(x,y\)\(lca(x,y)=z\)\(x→z\) 要经过 \(a\) 条边,\(y→z\) 要经过 \(b\) 条边。那么当 \(a,b,s\) 都为定值时,\(z\) 也必为定值。
  • 性质 \((1)\) 可以归纳证明。
  • 考虑证明性质 \((2):\)
    把路径 \(x→y(x<y)\) 上的每个点编号写成二进制,它们的 \(lcp\) 就是 \(z\) 的二进制。记 \(z\) 的二进制有 \(t\) 位,那么先算出每个点编号前 \(t\) 位的值的贡献。\(z\) 的贡献是 \(z\)\(z\) 的每个儿子的贡献是 \(2z\),每个孙子的贡献是 \(4z\),依此类推,可以得出前 \(t\) 位的总贡献 \(v\) 为:
    \(z+2z+4z+....+2^{a}z+z+2z+4z+...+2^{b}z-z\)
    \(=(2^{a+1}+2^{b+1}-3)z\)
  • \(k\) 为其它贡献,即:\(vz+k=s\)
    讨论 \(k\) 的取值范围(先假设 \(a,b>0\)):
    \(1.\) \(z\) 往左走 \(1\) 步,然后往\(a-1\) 步到 \(x\)\(z\) 往右走 \(1\) 步,然后往\(b-1\) 步到 \(y\)。路径 \(x→z\)\(k\) 的贡献为 \(0\)\(y→z\)\(k\) 的贡献为 \(2^b-1\),此时 \(k\) 取最值。
    \(2.\) \(z\) 往左走 \(1\) 步,然后往\(a-1\) 步到 \(x\)\(z\) 往右走 \(1\) 步,然后往\(b-1\) 步到 \(y\)。路径 \(x→z\)\(k\) 的贡献为 \(2^a-a-1\)\(y→z\)\(k\) 的贡献为 \(2^{b+1}-b-2\),此时 \(k\) 取最值。
    综上所述,\(k∈[2^b-1,2^a+2^{b+1}-a-b-3]\)
  • 发现 \(k < v\),那么 \(z=s/v,k=s\%v\)\(v\) 的值只和 \(a,b\) 有关,所以 \(z\) 为定值,证毕。
  • 那么问题转化为:找到两个数 \(p,q\)(就是 \(x,y\) 砍掉前 \(t\) 位之后的值,砍掉前 \(t\) 位后,\(x\)最高位是 \(0\)\(y\) 最高位是 \(1\)),\(p<2^{a-1},q<2^b,2(p+q)-cnt(p)-cnt(q)=k\),注意 \(q\)\(2^{b-1}\) 位必须为 \(1\)
  • 枚举 \(cnt(p)+cnt(q)\),得到 \(a+b=w\),做数位\(dp\)\(f[i][j][k=0/1]\) 表示从个位开始(\(a,b\) 同时填),填了 \(i\) 位,此时共填 \(j\)\(1\),是否进位到第 \(i+1\) 位,要求 \(a+b\) 的每一位都和 \(w\) 对上即可。
  • 注意细节,注意多开 \(long\) \(long\),注意树高对 \(a,b,z\) 值的限制。
  • 时间复杂度 \(O(d^5)\),常数显然很小。

Code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

template <class t>
inline void read(t & res)
{
    char ch;
    while (ch = getchar(), !isdigit(ch));
    res = ch ^ 48;
    while (ch = getchar(), isdigit(ch))
    res = res * 10 + (ch ^ 48);
}

template <class t>
inline void print(t x)
{
    if (x > 9) print(x / 10);
    putchar(x % 10 + 48);
}

const int e = 205;
ll f[e][e][2], ans;
bool a1[e], b1[e];

inline ll dp(int a, int b, ll s, int c)
{
    int i, j, t, x, y, bit = max(max(a - 1, b), (int)log2(s) + 1) + 1; 
	bool flag = a == 3 && b == 2;
    for (i = 0; i < bit; i++)
    for (j = 0; j <= c; j++)
    f[i][j][0] = f[i][j][1] = 0;
    i = 0;
    for (x = 0; x <= 1; x++)
    for (y = 0; y <= 1; y++)
    {
        int k = x + y >> 1, z = x + y & 1;
        if (i == b - 1 && !y) continue;
        if (z != (s & 1)) continue;
        if (i > b - 1 && y) continue;
        if (i > a - 2 && x) continue;
        f[0][x + y][k]++;
    }
    for (i = 0; i < bit - 1; i++)
    {
        ll d = s & (1ll << i + 1);
        if (d) d = 1;
        for (j = 0; j <= c; j++)
        for (t = 0; t <= 1; t++)
        if (f[i][j][t])
        for (x = 0; x <= 1; x++)
        for (y = 0; y <= 1; y++)
        {
            int now = x + y + t, k = now >> 1, z = now & 1;
            if (i + 1 == b - 1 && !y) continue;
            if (z != d) continue;
            if (i + 1 > b - 1 && y) continue;
            if (i + 1 > a - 2 && x) continue;
            f[i + 1][j + x + y][k] += f[i][j][t];
        }
    }
    return f[bit - 1][c][0];
}

inline ll calc(ll x)
{
    ll res = x << 1;
    while (x) res -= x & 1, x >>= 1;
    return res;
}

int main()
{
    int i, lx, ly, op, tst, h;
    ll x, y, z;
    read(tst);
    while (tst--)
    {
        read(h); read(x); read(y); read(op);
        ll x2 = x, y2 = y;
        lx = ly = z = 0;
        while (x) a1[++lx] = x & 1, x >>= 1;
        while (y) b1[++ly] = y & 1, y >>= 1;
        reverse(a1 + 1, a1 + lx + 1);
        reverse(b1 + 1, b1 + ly + 1);
        for (i = 1; i <= lx && i <= ly; i++)
        if (a1[i] == b1[i]) z = z * 2 + a1[i];
        else break;
        ll s = calc(x2) + calc(y2) - calc(z) - calc(z >> 1);
        if (op == 1)
        {
            print(s);
            putchar('\n');
            continue;
        }
        ans = 0;
        for (int a = 0; a < h; a++)
        for (int b = 0; b < h; b++)
        {
            ll y = ((1ll << a + 1) + (1ll << b + 1) - 3), x = s / y;
            if (!x) continue;
            if ((int)log2(x) + 1 + max(a, b) > h) continue;
            ll k = s - x * y;
            for (int c = 0; c <= a + b; c++) 
            if ((k + c) % 2 == 0) ans += dp(a, b, k + c >> 1, c);
        }
        print(ans - 1);
        putchar('\n');
    }
    return 0;
}
原文地址:https://www.cnblogs.com/cyf32768/p/12196242.html