这题……做法很多吧。据说莫队玄学多交几次就AC了。
题目要我们查询两个区间的不同的数的个数。我们先考虑一个区间内的做法。
这是一个经典的问题?(自行百度)
在线做法的话,只需要用可持久化线段树维护前个数中每个数最后一次出现的下标。
这样在扫一边的时候,如果某个数在前面出现过,就只要把前面的那个删掉,加上这个就好。
代码如下:
for (int i = 1; i <= n; i++) { int v = a[i]; if (~last[v]) { int t = update(last[v], root[i - 1], -1, 1, n); root[i] = update(i, t, 1, 1, n); } else root[i] = update(i, root[i - 1], 1, 1, n); last[v] = i; }
这样在查询的时候,我们只要在第个线段树中查询大于等于的数有多少个就行了。
现在考虑原来的问题。怎么处理两个区间?容易想到,把原来的序列在后面复制一次,那么原来的查询就等价于查询中不同的数的个数,那么这题就做完了。
(当然这个问题也是可以离线树状数组来做的)
完整代码(额,似乎比赛的时候判题机跑得比较快,如果过不了请自行开读入挂):
#include <bits/stdc++.h> using namespace std; const int N = 1 << 18; int cnt = 0; struct Node { int l, r, sum; } p[N << 5]; int update(int pos, int c, int v, int l, int r) { int nc = ++cnt; p[nc] = p[c]; p[nc].sum += v; if (l == r) return nc; int m = l + r >> 1; if (m >= pos) p[nc].l = update(pos, p[c].l, v, l, m); else p[nc].r = update(pos, p[c].r, v, m + 1, r); return nc; } int query(int pos, int c, int l, int r) { if (l == r) return p[c].sum; int m = l + r >> 1; if (m >= pos) return p[p[c].r].sum + query(pos, p[c].l, l, m); return query(pos, p[c].r, m + 1, r); } int a[N]; int root[N]; int last[N]; int main() { int n, q; while (~scanf("%d%d", &n, &q)) { cnt = 0; memset(last, -1, sizeof last); for (int i = 1; i <= n; i++) scanf("%d", a + i); for (int i = 1; i <= n; i++) a[n + i] = a[i]; int m = n; n <<= 1; for (int i = 1; i <= n; i++) { int v = a[i]; if (~last[v]) { int t = update(last[v], root[i - 1], -1, 1, n); root[i] = update(i, t, 1, 1, n); } else root[i] = update(i, root[i - 1], 1, 1, n); last[v] = i; } while (q--) { int x, y; scanf("%d %d", &x, &y); x += m; swap(x, y); printf("%d\n", query(x, root[y], 1, n)); } } return 0; }
不过非常可惜,判题机实在是太慢了。这个做法在比赛的时候也是卡过去的。
因此我们可以考虑离线操作,按查询的右端点递增的顺序来处理,这样求和就可以用树状数组来解决了。
代码如下:
#include <bits/stdc++.h> using namespace std; struct query { int l, r, id; }; int main() { int n, q; while (~scanf("%d%d", &n, &q)) { int m = n << 1 | 1; vector<int> a(m), bit(m); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); for (int i = 1; i <= n; i++) a[n + i] = a[i]; vector<query> v(q); vector<int> last(n + 1, -1), ans(q); for (int i = 0; i < q; i++) scanf("%d%d", &v[i].r, &v[i].l), v[i].id = i, v[i].r += n; sort(v.begin(), v.end(), [](const query& a, const query& b) { return a.r < b.r; }); for (int i = 0, j = 1; i < q; i++) { for (; j <= v[i].r; j++) { if (~last[a[j]]) for (int k = last[a[j]]; k < m; k += k & -k) bit[k]--; for (int k = j; k < m; k += k & -k) bit[k]++; last[a[j]] = j; } int t = 0; for (int k = m; k; k -= k & -k) t += bit[k]; for (int k = v[i].l - 1; k; k -= k & -k) t -= bit[k]; ans[v[i].id] = t; } for(int i = 0; i < q; i++) printf("%d\n", ans[i]); } return 0; }
全部评论
(5) 回帖