#include<bits/stdc++.h>
using namespace std;
const int N=100010;
typedef long long LL;
int p[N],sum[N],sum2[N];//sum表示联通块(并查集)内点的数量,sum2表示每个联通块内是否有炸弹(几个无所谓,只要知道他是否为0即可)
int find(int n)
{
if(n!=p[n]) p[n]=find(p[n]);
return p[n];
}
int main()
{
int n,m;
cin>>n>>m;
for(int i=1;i<=n;i++)
{
sum[i]=1;
p[i]=i;
}
while(m--)
{
int a,b;
cin>>a>>b;
if(find(a)==find(b)) continue;
sum[find(b)]+=sum[find(a)];
p[find(a)]=find(b);
}
int sum3=0;//sum3的意思是总的炸弹数量,若为0即所有连通块无炸弹
for(int i=1;i<=n;i++)
{
int x;
cin>>x;
sum3+=x;
if(x)
sum2[find(i)]+=x;
}
bool st1[N],st2[N];
LL res=0;
if(sum3==0)
{
for(int i=1;i<=n;i++)
{
if(!st2[find(i)])
res+=(LL)sum[find(i)]*(LL)sum[find(i)],st2[find(i)]=true;
}
printf("%lld",res);
}
else
{
int ans=0;
for(int i=1;i<=n;i++)
{
if(!st1[find(i)]&&sum2[find(i)]!=0)
res+=(LL)sum[find(i)]*(LL)sum[find(i)],ans++,st1[find(i)]=true;
}
if(ans==1)
printf("%lld",res);
else
printf("0");
}
}
全部评论
(1) 回帖