P7077 函数调用

https://www.luogu.com.cn/problem/P7077
https://loj.ac/p/3381

可以转换一下变更值的思路,由于询问只有,一次,可以理解为维护一个全局乘法标记,让所有数都先成上这个标记,这解决了所有乘法操作
对于加法操作,设原来要加的值为 (add),那么由于其他乘法操作会对它产生影响,从而实际应该乘一个系数,也就是加的值为 (add imes val)

接下来关键的问题是如何求出 (val)
首先注意到,函数的调用关系是一个 DAG,那么可以给每个节点记录一个 (mul_u) 表示如果调用这个节点,会对全局乘法标记产生多少影响
也就是对于第一类,(mul=1);对于第二类,(mul=v);对于第三类,(mul) 为所以要调用的函数的 (mul) 之积
这可以通过一遍 dfs 求出

然后考虑依次调用的 (q) 个函数,后面的函数的乘法,肯定会对前面的函数的加法的系数(也就是 (val))产生影响,要加到它的系数上
所以倒着处理,维护一个乘积每次让 (Q_i)(val) 加上这个乘积,然后这个乘积再乘上 (Q_i)(mul)
一个例子,比如:

+3 *2 +2 *3

这里省略了加法操作的对应元素,不过那并不重要,第一个 +3 会由于后面的 *2 *3 而被加六遍,也就是它的 (val=6);同理,那个 +2(val=3)

但是这样处理以后有一些 (val) 是在第三类函数上的,需要将他们下传至叶子节点
如何下传,其实思路是相同的,因为一开始函数的依次调用可以看作是一个零号的三类函数
那么就反着第三类函数的执行顺序,同样维护乘积,对于他的子节点的 (val) 做更改
这可以通过一遍拓扑排序解决
注意由于邻接表加边后顺序自然边反,所以直接正常遍历就行

m才是函数个数,我因为这个WA65了一年

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<map>
#include<iomanip>
#include<cstring>
#define reg register
#define EN puts("")
inline int read(){
	register int x=0;register int y=1;
	register char c=std::getchar();
	while(c<'0'||c>'9'){if(c=='-') y=0;c=getchar();}
	while(c>='0'&&c<='9'){x=x*10+(c^48);c=getchar();}
	return y?x:-x;
}
#define N 200006
#define M 2200006
struct Graph{
	int fir[N],nex[M],to[M],tot;
	inline void add(int u,int v){
		to[++tot]=v;
		nex[tot]=fir[u];fir[u]=tot;
	}
}G;
#define mod 998244353
int n,m;
int Q[N];
long long a[N];
struct Node{
	long long add,mul,val;//val 表示 add 的系数
	int pos;
}b[N];
int vis[N];
void getmul(reg int u){
	vis[u]=1;
	for(reg int v,i=G.fir[u];i;i=G.nex[i]){
		v=G.to[i];
		if(!vis[v]) getmul(v);
		b[u].mul=b[u].mul*b[v].mul%mod;
	}
}
int que[N],left,right;
int in[N];
inline void topo(){
	right=-1;left=0;
	for(reg int i=1;i<=m;i++)if(!in[i]) que[++right]=i;
	reg int u,v,i;
	while(left<=right){
		u=que[left++];
		long long prod=1;
		for(i=G.fir[u];i;i=G.nex[i]){
			v=G.to[i];
			b[v].val=(b[v].val+b[u].val*prod%mod)%mod;
			prod=prod*b[v].mul%mod;
			if(!--in[v]) que[++right]=v;
		}
	}
}
int main(){
//		std::freopen("call.in","r",stdin);
//		std::freopen("call.out","w",stdout);
	n=read();
	for(reg int i=1;i<=n;i++) a[i]=read();
	m=read();
	for(reg int op,i=1;i<=m;i++){
		op=read();
		b[i].mul=1;
		if(op==1) b[i].pos=read(),b[i].add=read();
		else if(op==2) b[i].mul=read();
		else{
			op=read();int x;
			while(op--){
				x=read();
				G.add(i,x);in[x]++;
			}
		}
	}
	for(reg int i=1;i<=m;i++)if(!in[i]) getmul(i);
	int q=read();
	for(reg int i=1;i<=q;i++) Q[i]=read();
	long long prod=1;
	for(reg int i=q;i;i--){
		b[Q[i]].val=(b[Q[i]].val+prod)%mod;
		prod=prod*b[Q[i]].mul%mod;
	}
//		puts("
----------------");
//		printf("prod=%lld
",prod);
//		for(reg int i=1;i<=n;i++) printf("i=%d pos=%d add=%lld val=%lld mul=%lld
",i,b[i].pos,b[i].add,b[i].val,b[i].mul);
//		puts("------------------
");
	topo();
	for(reg int i=1;i<=n;i++) a[i]=a[i]*prod%mod;
	for(reg int i=1;i<=m;i++)if(b[i].pos){//操作 1
		a[b[i].pos]=(a[b[i].pos]+b[i].add*b[i].val%mod)%mod;
	}
	for(reg int i=1;i<=n;i++) printf("%lld ",a[i]);
//		puts("
----------------");
//		for(reg int i=1;i<=n;i++) printf("i=%d pos=%d add=%lld val=%lld mul=%lld
",i,b[i].pos,b[i].add,b[i].val,b[i].mul);
//		puts("------------------
");
	return 0;
}
原文地址:https://www.cnblogs.com/suxxsfe/p/14058265.html