#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <math.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, M = N * 2;
int n,m,seq;
int L[N], R[N];
int h[N], e[M], ne[M], w[M], idx;
int depth[N],mx;
LL s[N];
vector<int> f[N];
vector<int> val[N];
vector<vector<vector<LL>>> p;//RMQ
void init()
{
for(int i = 1; i <= mx; i++)
{
int len = f[i].size() - 1;
// cout << "len: " << len << endl;
vector<vector<LL>> vt(len + 1);
for(int j = 0; j <= len; j++)
{
vector<LL> tmp(20);
vt[j] = tmp;
}
for(int k = 0; k <= 17; k++)
{
for(int j = 0; j + (1 << k) - 1 <= len; j++)
{
if(!k)vt[j][k] = val[i][j];
else vt[j][k] = max(vt[j][k - 1], vt[j + (1 << k - 1) ][k - 1]);
}
}
p.push_back(vt);
}
}
void add(int a, int b, int c)
{
e[idx] = b; w[idx] = c; ne[idx] = h[a]; h[a] = idx++;
}
void dfs(int u, int fa)
{
L[u] = ++ seq;
for(int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if(j == fa) continue;
depth[j] = depth[u] + 1;
mx = max(mx, depth[j]);
s[j] = s[u] + w[i];
dfs(j, u);
f[depth[j]].push_back( L[j] ); //存储dfs序
val[depth[j]].push_back(s[j]); //对应每个dfs序的权值
}
R[u] = seq;
}
LL query(int t, int l, int r)
{
int len = r - l + 1;
int k = log(len) / log(2);
// cout << "l: "<< l << " r: " << r << " k: " << k << " " << r - (1 << k) << endl;
LL res = p[t - 1][l][k];
if(r - (1 << k) >= 0)res = max(res, p[t - 1][r - (1 << k)][k]);
return res;
}
int up(int t, int x)//找第一个大于等于x的下标位置
{
int len = f[t].size();
if(!len)return -1;
int l = 0, r = len - 1;
while(l < r)
{
int mid = (l + r) >> 1;
if(f[t][mid] >= x)r = mid;
else l = mid + 1;
}
if(f[t][l] < x) return -1;
return l;
}
int lower(int t, int x)//找第一个小于等于x的下标位置
{
int len = f[t].size();
if(!len)return -1;
int l = 0, r = len - 1;
while(l < r)
{
int mid = (l + r + 1) >> 1;
if(f[t][mid] > x)r = mid - 1;
else l = mid;
}
if(f[t][l] > x) return -1;
return l;
}
int main()
{
memset(h, -1, sizeof h);
scanf("%d", &n);
for(int i = 1; i < n; i++)
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
dfs(1, -1);
// for(int i = 1; i <= n; i++)cout << L[i] << " " ;
// cout << endl;
// for(int i = 1; i <= n; i++)cout << R[i] << " ";
// cout << endl;
// for(int i = 0; i < 3 ; i++)
// {
// for(int d:f[i])
// {
// cout << d << " ";
// }
// cout << endl;
// }
init();
scanf("%d", &m);
while(m--)
{
int u, k;
scanf("%d%d", &u, &k);
// 询问
int t = depth[u] + k;
if(t>mx)
{
printf("-1\n");
continue;
}
int l = up(t, L[u]), r = lower(t, R[u]);
if(l == -1 || r == -1)
{
printf("-1\n");
continue;
}
// cout << l << " " << r << endl;
printf("%d\n", query(t, l, r) - s[u]);
}
return 0;
}
错误原因:二分时所找到的区间l, r 有可能l > r 此时非法
错误数据:
7
1 2 1
2 3 1
2 5 1
1 3 1
1 6 1
6 7 1
1
3 1
Ac代码 :
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <math.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, M = N * 2;
int n,m,seq;
int L[N], R[N];
int h[N], e[M], ne[M], w[M], idx;
int depth[N],mx;
LL s[N];
vector<int> f[N];
vector<LL> val[N];
vector<vector<vector<LL>>> p;
void init()
{
for(int i = 1; i <= mx; i++)
{
int len = f[i].size() - 1;
// cout << "len: " << len << endl;
vector<vector<LL>> vt(len + 1);
for(int j = 0; j <= len; j++)
{
vector<LL> tmp(20);
vt[j] = tmp;
}
for(int k = 0; k <= 17; k++)
{
for(int j = 0; j + (1 << k) - 1 <= len; j++)
{
if(!k)vt[j][k] = val[i][j];
else vt[j][k] = max(vt[j][k - 1], vt[j + (1 << k - 1)][k - 1]);
}
}
p.push_back(vt);
}
}
void add(int a, int b, int c)
{
e[idx] = b; w[idx] = c; ne[idx] = h[a]; h[a] = idx++;
}
void dfs(int u, int fa)
{
L[u] = ++ seq;
for(int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if(j == fa) continue;
depth[j] = depth[u] + 1;
mx = max(mx, depth[j]);
s[j] = s[u] + w[i];
dfs(j, u);
f[depth[j]].push_back( L[j] ); //存储dfs序
val[depth[j]].push_back(s[j]); //对应每个dfs序的权值
}
R[u] = seq;
}
LL query(int t, int l, int r)
{
int len = r - l + 1;
int k = log(len) / log(2);
return max(p[t - 1][l][k], p[t - 1][r - (1 << k) + 1][k]);
}
int up(int t, int x)
{
int len = f[t].size();
if(!len)return -1;
int l = 0, r = len - 1;
while(l < r)
{
int mid = (l + r) >> 1;
if(f[t][mid] >= x)r = mid;
else l = mid + 1;
}
if(f[t][l] < x) return -1;
return l;
}
int lower(int t, int x)
{
int len = f[t].size();
if(!len)return -1;
int l = 0, r = len - 1;
while(l < r)
{
int mid = (l + r + 1) >> 1;
if(f[t][mid] > x)r = mid - 1;
else l = mid;
}
if(f[t][l] > x) return -1;
return l;
}
int main()
{
memset(h, -1, sizeof h);
scanf("%d", &n);
for(int i = 1; i < n; i++)
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
dfs(1, -1);
init();
scanf("%d", &m);
while(m--)
{
int u, k;
scanf("%d%d", &u, &k);
// 询问
int t = depth[u] + k;
if(t > mx)
{
printf("-1\n");
continue;
}
int l = up(t, L[u]), r = lower(t, R[u]);
if(l > r || l == -1 || r == -1)
{
printf("-1\n");
continue;
}
printf("%lld\n", query(t, l, r) - s[u]);
}
return 0;
}

全部评论
(0) 回帖