「CSP-S 2020」函数调用(拓扑排序+DP)

Address

LOJ3381
LuoguP7077

Solution

因为加是单点加,乘是全体乘,所以考虑计算后面的乘对前面的加的影响。

也就是说,对于某次执行 (T_j=1) 的操作 (a_p+=v),设在它之后执行的 (T_j=2) 的操作的 (prod V_j=x)。那么计算最终答案的时候,只要把 (a_p+=v imes x) 即可。

对于 (T_j=3) 的操作,题目说保证不会出现递归(即不会直接或间接地调用本身)。因此建一张 DAG,如果函数 (u) 直接调用了函数 (v),那么连一条 (u→v) 的边。方便起见,再建一个点 (m+1),向 (Q)(f_i) 都连一条边。

接下来开始暴力,我们记一个 (prod),表示当前访问过的 (T_j=2) 节点的 (prod V_j)。(重复访问就重复计算)

(m+1) 开始 DFS(注意出边的顺序要反过来,因为是后面的乘对前面的加的影响)。DFS 到 (u) 的时候,如果 (T_j=2)(prod imes=V_j),如果 (T_j=1)(a_{P_j}+=V_j imes prod),如果 (T_j=3),就什么都不做。

怎么优化这个暴力?

考虑对于一个点 (u),它连向的点分别为 (v_1,v_2,...,v_k)。那么 DFS 到 (u) 之后,设当前的 (prod)(s),接下来肯定是 DFS (v_1),那么执行完 DFS (v_1),准备 DFS (v_2) 的时候,(prod) 是多少?

预处理出 (dp_u) 表示从 (u) 开始 DFS,经过的所有 (T_j=2) 节点的 (prod V_j),按拓扑序倒序转移即可。

那么上述的 (prod) 就是 (dp_{v_1} imes s),以此类推,准备 DFS (v_i) 的时候,(prod) 就是 (prod_{j=1}^{i-1}dp_{v_j} imes s)

我们可以这样描述这个 DFS:从 (u) 开始,带着大小为 (prod) 的标记走下去,接下来,对于每个 (v_i),带着大小为 (prod_{j=1}^{i-1}dp_{v_j} imes s) 的标记走下去。也就是说,我们不用把 (v_1sim v_{i-1}) 都 DFS 一遍,就可以知道 (v_i) 的标记大小,它仅仅取决于所有连向它的 (u)

我们记 (tag_u) 表示点 (u) 的标记大小。这个 (tag) 有什么用呢?我们发现所有的 (T_jin{1,2})(j) 都是底层节点,没有出边,所以如果 (T_j=1),我们求出 (tag_j) 之后,直接让 (a_{P_j}+=V_j imes tag_j) 即可。

根据上述分析,对于一个点 (v),只要知道所有连向它的 (u)(tag_u),即可用形如 (tag_v=sum_{u→v}tag_u imes prod_{xin pre(u,v)}dp_x) 的式子求出 (tag_v)。按照拓扑序转移即可。

注意把所有 (m+1) 走不到的点和边删掉。

时间复杂度 (O(n+m+Q+sum C_j))

Code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

template <class t>
inline void read(t & res)
{
	char ch;
	while (ch = getchar(), !isdigit(ch));
	res = ch ^ 48;
	while (ch = getchar(), isdigit(ch))
		res = res * 10 + (ch ^ 48);
}

template <class t>
inline void print(t x)
{
	if (x > 9) print(x / 10);
	putchar(x % 10 + 48);
}

const int N = 1e5 + 15, M = 2e6 + 15, mod = 998244353;

int adj[N], nxt[M], go[M], val[N], pos[N], typ[N], n, m, q, tag[N];
int f[N], deg[N], seq[N], cnt, a[N], num;
bool vis[N];

inline void add(int &x, int y)
{
	(x += y) >= mod && (x -= mod);
}

inline void link(int x, int y)
{
	nxt[++num] = adj[x];
	adj[x] = num;
	go[num] = y;
	deg[y]++;
}

inline void dfs(int u)
{
	if (vis[u]) return;
	vis[u] = 1;
	for (int i = adj[u]; i; i = nxt[i]) dfs(go[i]);
}

inline void topsort()
{
	queue<int>q;
	int i, j;
	q.push(m + 1);
	seq[cnt = 1] = m + 1;
	while (!q.empty())
	{
		int u = q.front();
		q.pop();
		for (i = adj[u]; i; i = nxt[i])
		{
			int v = go[i];
			if (!vis[v]) continue;
			deg[v]--;
			if (!deg[v]) q.push(v), seq[++cnt] = v;
		}
	}
	for (i = cnt; i >= 1; i--)
	{
		int u = seq[i];
		for (j = adj[u]; j; j = nxt[j])
		{
			int v = go[j];
			f[u] = (ll)f[u] * f[v] % mod;
		}
	}
}

inline void solve()
{
	int i, j;
	for (i = 1; i <= cnt; i++)
	{
		int u = seq[i], pre = 1;
		for (j = adj[u]; j; j = nxt[j])
		{
			int v = go[j];
			add(tag[v], (ll)tag[u] * pre % mod);
			pre = (ll)pre * f[v] % mod;
		}
	}
}

int main()
{
	freopen("call.in", "r", stdin);
	freopen("call.out", "w", stdout);
	read(n);
	int i, j, k, x;
	for (i = 1; i <= n; i++) read(a[i]);
	read(m);
	for (i = 1; i <= m; i++)
	{
		read(typ[i]);
		f[i] = 1;
		if (typ[i] == 1) read(pos[i]), read(val[i]);
		else if (typ[i] == 2) read(val[i]), f[i] = val[i];
		else
		{
			read(k);
			for (j = 1; j <= k; j++)
			{
				read(x);
				link(i, x);
			}
		}
	}
	read(q);
	for (i = 1; i <= q; i++)
	{
		read(x);
		link(m + 1, x);
	}
	dfs(m + 1);
	for (i = 1; i <= m + 1; i++)
		for (j = adj[i]; j; j = nxt[j])
		{
			k = go[j];
			if (!vis[k] || !vis[i]) deg[k]--;
		}
	f[m + 1] = tag[m + 1] = 1;
	topsort();
	solve();
	for (i = 1; i <= n; i++) a[i] = (ll)a[i] * f[m + 1] % mod;
	for (i = 1; i <= m; i++)
		if (typ[i] == 1) add(a[pos[i]], (ll)val[i] * tag[i] % mod);
	for (i = 1; i <= n; i++)
		printf("%d ", a[i]);
	putchar('
');
	fclose(stdin);
	fclose(stdout);
	return 0;
}
原文地址:https://www.cnblogs.com/cyf32768/p/15240240.html