做法
长链剖分 + 线段树
可过
本代码有细节疏漏,但是由于使我挂掉的数据比较难构造,所以可以通过。如果要弥补这个细节疏漏,需要加一些代码,我懒得调了。(滑稽)
时间复杂度 $O(n\log n)$
做法大致描述见:
附赠一组数据:
7
1 2
1 3
3 4
4 5
2 6
2 7
5
ans = 48
1 2
1 3
3 4
4 5
2 6
2 7
5
ans = 48
代码
#include <bits/stdc++.h> using namespace std; const int N=100005,mod=1e9+7; int read(){ int x=0; char ch=getchar(); while (!isdigit(ch)) ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+ch-48,ch=getchar(); return x; } int Pow(int x,int y){ int ans=1; for (;y;y>>=1,x=1LL*x*x%mod) if (y&1) ans=1LL*ans*x%mod; return ans; } struct Gragh{ static const int M=N*2; int cnt,y[M],nxt[M],fst[N]; void clear(){ cnt=0; memset(fst,0,sizeof fst); } void add(int a,int b){ y[++cnt]=b,nxt[cnt]=fst[a],fst[a]=cnt; } }g; int n; int fa[N],sz[N],top[N],Maxd[N],depth[N],p[N],ap[N]; vector <int> son[N]; bool cmp(int x,int y){ return Maxd[x]>Maxd[y]; } void dfs(int x,int pre,int d){ depth[x]=Maxd[x]=d,fa[x]=pre; son[x].clear(); for (int i=g.fst[x];i;i=g.nxt[i]) if (g.y[i]!=pre){ int y=g.y[i]; dfs(y,x,d+1); Maxd[x]=max(Maxd[x],Maxd[y]); son[x].push_back(y); } sort(son[x].begin(),son[x].end(),cmp); sz[x]=(int)son[x].size(); } int Time=0; void Get_Top(int x,int TOP){ top[x]=TOP; ap[p[x]=++Time]=x; if (!sz[x]) return; Get_Top(son[x][0],TOP); for (int i=1;i<sz[x];i++) Get_Top(son[x][i],son[x][i]); } struct Seg{ int v,add; }t[N<<2]; void build(int rt,int L,int R){ t[rt].v=0,t[rt].add=1; if (L==R) return; int mid=(L+R)>>1,ls=rt<<1,rs=ls|1; build(ls,L,mid); build(rs,mid+1,R); } void Times(int rt,int d){ t[rt].v=1LL*t[rt].v*d%mod; t[rt].add=1LL*t[rt].add*d%mod; } void pushdown(int rt){ int ls=rt<<1,rs=ls|1,&v=t[rt].add; if (v==1) return; Times(ls,v); Times(rs,v); v=1; } void update(int rt,int L,int R,int xL,int xR,int opt,int d){ if (L>xR||R<xL||xL>xR) return; if (xL<=L&&R<=xR){ if (opt==0) Times(rt,d); else t[rt].v=(t[rt].v+d)%mod; return; } pushdown(rt); int mid=(L+R)>>1,ls=rt<<1,rs=ls|1; update(ls,L,mid,xL,xR,opt,d); update(rs,mid+1,R,xL,xR,opt,d); t[rt].v=(t[ls].v+t[rs].v)%mod; } int query(int rt,int L,int R,int xL,int xR){ if (L>xR||R<xL||xL>xR) return 0; if (xL<=L&&R<=xR) return t[rt].v; pushdown(rt); int mid=(L+R)>>1,ls=rt<<1,rs=ls|1; return (query(ls,L,mid,xL,xR)+query(rs,mid+1,R,xL,xR))%mod; } void Prepare(){ dfs(1,0,0); Get_Top(1,1); } int D,addv[N],sv[N]; void DFS(int x){ update(1,1,n,p[x],p[x],1,1); if (sz[x]){ DFS(son[x][0]); update(1,1,n,p[x]+1,p[x]+min(D,Maxd[x]-depth[x]),0,2); for (int i=1;i<sz[x];i++){ int y=son[x][i],lim=Maxd[x]-depth[x]; int vy=min(D,Maxd[y]-depth[y]+1); DFS(y); int lastv=1; for (int j=0;j<=vy;j++) addv[j]=0; for (int j=1;j<=vy;j++) sv[j]=query(1,1,n,p[x],p[x]+min(j,D-j)); for (int j=1;j<=vy;j++){ int v=query(1,1,n,p[y]+j-1,p[y]+j-1); addv[j]=(1LL*(sv[j]+1)*v+addv[j])%mod; int k=D-j; if (k>j){ int inv=Pow(lastv,mod-2); update(1,1,n,p[x]+j+1,p[x]+min(k,lim),0,inv); lastv=(lastv+v)%mod; update(1,1,n,p[x]+j+1,p[x]+min(k,lim),0,lastv); } } for (int j=0;j<=vy;j++) update(1,1,n,p[***[x]+j,1,addv[j]); for (int j=vy+1;j<=Maxd[y]-depth[y]+1;j++){ int v=query(1,1,n,p[y]+j-1,p[y]+j-1); update(1,1,n,p[***[x]+j,1,v); } } } } int solve(int DD){ if (DD==0) return n; build(1,1,n); D=DD; DFS(1); return query(1,1,n,1,Maxd[1]+1); } int main(){ n=read(); g.clear(); for (int i=1;i<n;i++){ int a=read(),b=read(); g.add(a,b); g.add(b,a); } Prepare(); int DD=read(); printf("%d",(solve(DD)-solve(DD-1)+mod)%mod); return 0; }
全部评论
(0) 回帖