竞赛讨论区 > 树上路径
头像
shyyhs
编辑于 2021-04-26 20:26
+ 关注

树上路径


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod=1e9+7;
const int N=1e5+5;
const int iv2=(mod+1)/2;
ll w[N];
vector<int>g[N];
ll add(ll a,ll b)
{
	return (a+b)%mod; 
}

ll add(ll a,ll b,ll c)
{
	return add(add(a,b),c);
}

ll mul(ll a,ll b)
{
	return a*b%mod;
}

ll mul(ll a,ll b,ll c)
{
	return mul(mul(a,b),c);
}

int sz[N],son[N],dep[N],f[N];
void dfs(int u,int fa)
{
	dep[u]=dep[fa]+1;
	f[u]=fa;
	sz[u]=1;
	for(int v:g[u])
	{
		if(v==fa)	continue;
		dfs(v,u);
		sz[u]+=sz[v];
		if(sz[son[u]]<sz[v])
		{
			son[u]=v;
		}
	}
	
}

int idx[N],top[N],id;
ll val[N];
void DFS(int u,int tp)
{
	idx[u]=++id;
	top[u]=tp;
	val[id]=w[u];
	if(!son[u])	return;
	DFS(son[u],tp);
	for(int v:g[u])
	{
		if(!idx[v])
		{
			DFS(v,v);
		}
	}
}

struct SegTree{
	int l,r,len;
	ll lazy,sum,ans;
}Tr[N<<2];

void change(int u,ll k)
{
	Tr[u].lazy=add(Tr[u].lazy,k);
	Tr[u].ans=add(Tr[u].ans,mul(mul(Tr[u].len,iv2,Tr[u].len-1),(k*k%mod)),mul(Tr[u].sum,k,Tr[u].len-1));
	Tr[u].sum=add(Tr[u].sum,(Tr[u].len)*k%mod);
}

void pushup(int u)
{
	Tr[u].sum=add(Tr[u<<1].sum,Tr[u<<1|1].sum);
	Tr[u].ans=add(Tr[u<<1].ans,Tr[u<<1|1].ans,mul(Tr[u<<1].sum,Tr[u<<1|1].sum));
}

void pushdown(int u)
{
	if(Tr[u].lazy)
	{
		change(u<<1,Tr[u].lazy);
		change(u<<1|1,Tr[u].lazy);		
		Tr[u].lazy=0;
	}
}

void build(int u,int l,int r)
{
	Tr[u].l=l,Tr[u].r=r;
	Tr[u].len=(r-l+1);
	if(l==r)
	{
		Tr[u].sum=val[l];
		return;
	}
	int mid=(l+r)>>1;
	build(u<<1,l,mid);
	build(u<<1|1,mid+1,r);
	pushup(u);
}

void add(int u,int l,int r,ll k)
{
	//if(l>Tr[u].r||r<Tr[u].l)	return;
	if(Tr[u].l>=l&&Tr[u].r<=r)
	{
		change(u,k);
		return;
	}
	pushdown(u);
	int mid=(Tr[u].l+Tr[u].r)/2;
	if(l<=mid)	add(u<<1,l,r,k);
	if(r>mid)	add(u<<1|1,l,r,k);
	pushup(u);
}

SegTree merge(SegTree l,SegTree r)
{
	SegTree res;
	res.ans=add(l.ans,r.ans,mul(l.sum,r.sum));
	res.sum=add(l.sum,r.sum);
	return res;
}

SegTree query(int u,int l,int r)
{
	if(Tr[u].l>=l&&Tr[u].r<=r)
	{
		return Tr[u];
	}
	pushdown(u);
	int mid=(Tr[u].l+Tr[u].r)>>1;
	if(r<=mid)	return query(u<<1,l,r);
	else if(l>mid)	return query(u<<1|1,l,r);
	else return merge(query(u<<1,l,r),query(u<<1|1,l,r));
}

void Treeadd(int u,int v,ll k)
{
	while(top[u]!=top[v])
	{
		if(dep[top[u]]<dep[top[v]])	swap(u,v);
		add(1,idx[top[u]],idx[u],k);                   
		u=f[top[u]];
	}
	if(dep[v]>dep[u])	swap(u,v);
	add(1,idx[v],idx[u],k);
}

ll Treequery(int u,int v)
{
	SegTree res;
	res.sum=0,res.ans=0;
	while(top[u]!=top[v])
	{
		if(dep[top[u]]<dep[top[v]])	swap(u,v);
		res=merge(res,query(1,idx[top[u]],idx[u]));
		u=f[top[u]];
	}
	if(dep[v]>dep[u])	swap(u,v);
	res=merge(res,query(1,idx[v],idx[u]));
	return res.ans;
}

int main()
{
	int n,m;
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
	{
		scanf("%lld",&w[i]);
	}
	for(int i=1;i<n;i++)
	{
		int u,v;
		scanf("%d%d",&u,&v);
		g[u].push_back(v);
		g[v].push_back(u);
	}
	dfs(1,1);
	DFS(1,1);
	build(1,1,n);
	while(m--)
	{
		int opt,v,u;ll k;
		scanf("%d",&opt);
		if(opt==1)
		{
			scanf("%d%lld",&u,&k);
			add(1,idx[u],idx[u]+sz[u]-1,k);
		}   
		else if(opt==2)
		{
			scanf("%d%d%lld",&u,&v,&k);
			Treeadd(u,v,k);
		}
		else
		{
			scanf("%d%d",&u,&v);
			printf("%lld\n",Treequery(u,v));
		}    
	}
	return 0;
}

这份代码交这个题,为什么有时候段错误,有时候ac...


全部评论

(2) 回帖
加载中...
话题 回帖

本文相关内容

等你来战

查看全部

热门推荐