竞赛讨论区 > 本题有长链剖分 + 线段树做法
头像
zhouzhendong
编辑于 2018-08-10 18:37
+ 关注

本题有长链剖分 + 线段树做法

做法

长链剖分 + 线段树

可过

本代码有细节疏漏,但是由于使我挂掉的数据比较难构造,所以可以通过。如果要弥补这个细节疏漏,需要加一些代码,我懒得调了。(滑稽)

时间复杂度 $O(n\log n)$

做法大致描述见:

附赠一组数据:
7
1 2
1 3
3 4
4 5
2 6
2 7
5
ans = 48

代码

#include <bits/stdc++.h>
using namespace std;
const int N=100005,mod=1e9+7;
int read(){
    int x=0;
    char ch=getchar();
    while (!isdigit(ch))
        ch=getchar();
    while (isdigit(ch))
        x=(x<<1)+(x<<3)+ch-48,ch=getchar();
    return x;
}
int Pow(int x,int y){
    int ans=1;
    for (;y;y>>=1,x=1LL*x*x%mod)
        if (y&1)
            ans=1LL*ans*x%mod;
    return ans;
}
struct Gragh{
    static const int M=N*2;
    int cnt,y[M],nxt[M],fst[N];
    void clear(){
        cnt=0;
        memset(fst,0,sizeof fst);
    }
    void add(int a,int b){
        y[++cnt]=b,nxt[cnt]=fst[a],fst[a]=cnt;
    }
}g;
int n;
int fa[N],sz[N],top[N],Maxd[N],depth[N],p[N],ap[N];
vector <int> son[N];
bool cmp(int x,int y){
    return Maxd[x]>Maxd[y];
}
void dfs(int x,int pre,int d){
    depth[x]=Maxd[x]=d,fa[x]=pre;
    son[x].clear();
    for (int i=g.fst[x];i;i=g.nxt[i])
        if (g.y[i]!=pre){
            int y=g.y[i];
            dfs(y,x,d+1);
            Maxd[x]=max(Maxd[x],Maxd[y]);
            son[x].push_back(y);
        }
    sort(son[x].begin(),son[x].end(),cmp);
    sz[x]=(int)son[x].size();
}
int Time=0;
void Get_Top(int x,int TOP){
    top[x]=TOP;
    ap[p[x]=++Time]=x;
    if (!sz[x])
        return;
    Get_Top(son[x][0],TOP);
    for (int i=1;i<sz[x];i++)
        Get_Top(son[x][i],son[x][i]);
}
struct Seg{
    int v,add;
}t[N<<2];
void build(int rt,int L,int R){
    t[rt].v=0,t[rt].add=1;
    if (L==R)
        return;
    int mid=(L+R)>>1,ls=rt<<1,rs=ls|1;
    build(ls,L,mid);
    build(rs,mid+1,R);
}
void Times(int rt,int d){
    t[rt].v=1LL*t[rt].v*d%mod;
    t[rt].add=1LL*t[rt].add*d%mod;
}
void pushdown(int rt){
    int ls=rt<<1,rs=ls|1,&v=t[rt].add;
    if (v==1)
        return;
    Times(ls,v);
    Times(rs,v);
    v=1;
}
void update(int rt,int L,int R,int xL,int xR,int opt,int d){
    if (L>xR||R<xL||xL>xR)
        return;
    if (xL<=L&&R<=xR){
        if (opt==0)
            Times(rt,d);
        else
            t[rt].v=(t[rt].v+d)%mod;
        return;
    }
    pushdown(rt);
    int mid=(L+R)>>1,ls=rt<<1,rs=ls|1;
    update(ls,L,mid,xL,xR,opt,d);
    update(rs,mid+1,R,xL,xR,opt,d);
    t[rt].v=(t[ls].v+t[rs].v)%mod;
}
int query(int rt,int L,int R,int xL,int xR){
    if (L>xR||R<xL||xL>xR)
        return 0;
    if (xL<=L&&R<=xR)
        return t[rt].v;
    pushdown(rt);
    int mid=(L+R)>>1,ls=rt<<1,rs=ls|1;
    return (query(ls,L,mid,xL,xR)+query(rs,mid+1,R,xL,xR))%mod;
}
void Prepare(){
    dfs(1,0,0);
    Get_Top(1,1);
}
int D,addv[N],sv[N];
void DFS(int x){
    update(1,1,n,p[x],p[x],1,1);
    if (sz[x]){
        DFS(son[x][0]);
        update(1,1,n,p[x]+1,p[x]+min(D,Maxd[x]-depth[x]),0,2);
        for (int i=1;i<sz[x];i++){
            int y=son[x][i],lim=Maxd[x]-depth[x];
            int vy=min(D,Maxd[y]-depth[y]+1);
            DFS(y);
            int lastv=1;
            for (int j=0;j<=vy;j++)
                addv[j]=0;
            for (int j=1;j<=vy;j++)
                sv[j]=query(1,1,n,p[x],p[x]+min(j,D-j));
            for (int j=1;j<=vy;j++){
                int v=query(1,1,n,p[y]+j-1,p[y]+j-1);
                addv[j]=(1LL*(sv[j]+1)*v+addv[j])%mod;
                int k=D-j;
                if (k>j){
                    int inv=Pow(lastv,mod-2);
                    update(1,1,n,p[x]+j+1,p[x]+min(k,lim),0,inv);
                    lastv=(lastv+v)%mod;
                    update(1,1,n,p[x]+j+1,p[x]+min(k,lim),0,lastv);
                }
            }
            for (int j=0;j<=vy;j++)
                update(1,1,n,p[***[x]+j,1,addv[j]);
            for (int j=vy+1;j<=Maxd[y]-depth[y]+1;j++){
                int v=query(1,1,n,p[y]+j-1,p[y]+j-1);
                update(1,1,n,p[***[x]+j,1,v);
            }
        }
    }
}
int solve(int DD){
    if (DD==0)
        return n;
    build(1,1,n);
    D=DD;
    DFS(1);
    return query(1,1,n,1,Maxd[1]+1);
}
int main(){
    n=read();
    g.clear();
    for (int i=1;i<n;i++){
        int a=read(),b=read();
        g.add(a,b);
        g.add(b,a);
    }
    Prepare();
    int DD=read();
    printf("%d",(solve(DD)-solve(DD-1)+mod)%mod);
    return 0;
}

全部评论

(0) 回帖
加载中...
话题 回帖

等你来战

查看全部

热门推荐