这题……做法很多吧。据说莫队玄学多交几次就AC了。
题目要我们查询两个区间的不同的数的个数。我们先考虑一个区间内的做法。
这是一个经典的问题?(自行百度)
在线做法的话,只需要用可持久化线段树维护前个数中每个数最后一次出现的下标。
这样在扫一边的时候,如果某个数在前面出现过,就只要把前面的那个删掉,加上这个就好。
代码如下:
for (int i = 1; i <= n; i++)
{
int v = a[i];
if (~last[v])
{
int t = update(last[v], root[i - 1], -1, 1, n);
root[i] = update(i, t, 1, 1, n);
}
else
root[i] = update(i, root[i - 1], 1, 1, n);
last[v] = i;
} 这样在查询的时候,我们只要在第个线段树中查询大于等于
的数有多少个就行了。
现在考虑原来的问题。怎么处理两个区间?容易想到,把原来的序列在后面复制一次,那么原来的查询就等价于查询中不同的数的个数,那么这题就做完了。
(当然这个问题也是可以离线树状数组来做的)
完整代码(额,似乎比赛的时候判题机跑得比较快,如果过不了请自行开读入挂):
#include <bits/stdc++.h>
using namespace std;
const int N = 1 << 18;
int cnt = 0;
struct Node
{
int l, r, sum;
} p[N << 5];
int update(int pos, int c, int v, int l, int r)
{
int nc = ++cnt;
p[nc] = p[c];
p[nc].sum += v;
if (l == r) return nc;
int m = l + r >> 1;
if (m >= pos)
p[nc].l = update(pos, p[c].l, v, l, m);
else
p[nc].r = update(pos, p[c].r, v, m + 1, r);
return nc;
}
int query(int pos, int c, int l, int r)
{
if (l == r) return p[c].sum;
int m = l + r >> 1;
if (m >= pos)
return p[p[c].r].sum + query(pos, p[c].l, l, m);
return query(pos, p[c].r, m + 1, r);
}
int a[N];
int root[N];
int last[N];
int main()
{
int n, q;
while (~scanf("%d%d", &n, &q))
{
cnt = 0;
memset(last, -1, sizeof last);
for (int i = 1; i <= n; i++) scanf("%d", a + i);
for (int i = 1; i <= n; i++) a[n + i] = a[i];
int m = n;
n <<= 1;
for (int i = 1; i <= n; i++)
{
int v = a[i];
if (~last[v])
{
int t = update(last[v], root[i - 1], -1, 1, n);
root[i] = update(i, t, 1, 1, n);
}
else
root[i] = update(i, root[i - 1], 1, 1, n);
last[v] = i;
}
while (q--)
{
int x, y;
scanf("%d %d", &x, &y);
x += m;
swap(x, y);
printf("%d\n", query(x, root[y], 1, n));
}
}
return 0;
} 不过非常可惜,判题机实在是太慢了。这个做法在比赛的时候也是卡过去的。
因此我们可以考虑离线操作,按查询的右端点递增的顺序来处理,这样求和就可以用树状数组来解决了。
代码如下:
#include <bits/stdc++.h>
using namespace std;
struct query
{
int l, r, id;
};
int main()
{
int n, q;
while (~scanf("%d%d", &n, &q))
{
int m = n << 1 | 1;
vector<int> a(m), bit(m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i <= n; i++) a[n + i] = a[i];
vector<query> v(q);
vector<int> last(n + 1, -1), ans(q);
for (int i = 0; i < q; i++) scanf("%d%d", &v[i].r, &v[i].l), v[i].id = i, v[i].r += n;
sort(v.begin(), v.end(), [](const query& a, const query& b) { return a.r < b.r; });
for (int i = 0, j = 1; i < q; i++)
{
for (; j <= v[i].r; j++)
{
if (~last[a[j]])
for (int k = last[a[j]]; k < m; k += k & -k) bit[k]--;
for (int k = j; k < m; k += k & -k) bit[k]++;
last[a[j]] = j;
}
int t = 0;
for (int k = m; k; k -= k & -k) t += bit[k];
for (int k = v[i].l - 1; k; k -= k & -k) t -= bit[k];
ans[v[i].id] = t;
}
for(int i = 0; i < q; i++) printf("%d\n", ans[i]);
}
return 0;
}
全部评论
(5) 回帖