竞赛讨论区 > 周赛143视频讲解对应代码
头像
_Bingbong
发布于 05-10 21:10 福建
+ 关注

周赛143视频讲解对应代码

牛客周赛 143 讲解

A.小红的区间构造

代码查看
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

// Problem A: read x and print the pair "1 x" (the interval [1, x]).
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    i64 upper;
    std::cin >> upper;
    std::cout << "1 " << upper;
}

import java.io.*;
import java.util.*;

public class Main {
    /** Problem A: reads x and prints the pair "1 x". */
    public static void main(String[] args) throws Exception {
        FastScanner scanner = new FastScanner(System.in);
        long upper = scanner.nextLong();
        System.out.println(new StringBuilder("1 ").append(upper));
    }

    /** Buffered reader of whitespace-separated tokens from a byte stream. */
    static class FastScanner {
        private final InputStream in;
        private final byte[] buffer = new byte[1 << 16];
        private int ptr = 0;
        private int len = 0;

        FastScanner(InputStream is) {
            in = is;
        }

        /** Returns the next byte, refilling the buffer when exhausted; -1 at end of input. */
        int read() throws IOException {
            if (ptr < len) {
                return buffer[ptr++];
            }
            len = in.read(buffer);
            ptr = 0;
            return len <= 0 ? -1 : buffer[ptr++];
        }

        /** Skips bytes in [0, 32] and returns the first byte after them (or -1). */
        private int firstNonBlank() throws IOException {
            int c = read();
            while (c >= 0 && c <= 32) {
                c = read();
            }
            return c;
        }

        /** Parses the next optionally signed decimal token. */
        long nextLong() throws IOException {
            int c = firstNonBlank();
            boolean negative = c == '-';
            if (negative) {
                c = read();
            }
            long value = 0;
            for (; c > 32; c = read()) {
                value = value * 10 + c - '0';
            }
            return negative ? -value : value;
        }
    }
}
# Problem A: read x and output the interval endpoints "1 x".
upper = int(input())
print(1, upper)

B.小红的冷门副本

代码查看
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

// Problem B: read n values; print m minus the number of distinct values that
// occur more than x times.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    int n, m, x;
    std::cin >> n >> m >> x;

    std::map<int, int> freq;
    for (int read = 0; read < n; ++read) {
        int value;
        std::cin >> value;
        ++freq[value];
    }

    int popular = 0;
    for (const auto& entry : freq) {
        if (entry.second > x) {
            ++popular;
        }
    }
    std::cout << m - popular;
    return 0;
}

import java.io.*;
import java.util.*;

public class Main {
    /**
     * Problem B: counts the distinct input values occurring more than x times and
     * prints m minus that count.
     */
    public static void main(String[] args) throws Exception {
        FastScanner scanner = new FastScanner(System.in);
        int n = scanner.nextInt();
        long m = scanner.nextLong();
        int x = scanner.nextInt();
        Map<Long, Integer> freq = new HashMap<>(n * 2);
        for (int i = 0; i < n; i++) {
            freq.merge(scanner.nextLong(), 1, Integer::sum);
        }
        long overLimit = freq.values().stream().filter(c -> c > x).count();
        System.out.println(m - overLimit);
    }

    /** Buffered reader of whitespace-separated tokens from a byte stream. */
    static class FastScanner {
        private final InputStream in;
        private final byte[] buffer = new byte[1 << 16];
        private int ptr = 0;
        private int len = 0;

        FastScanner(InputStream is) {
            in = is;
        }

        /** Returns the next byte, refilling the buffer when exhausted; -1 at end of input. */
        int read() throws IOException {
            if (ptr < len) {
                return buffer[ptr++];
            }
            len = in.read(buffer);
            ptr = 0;
            return len <= 0 ? -1 : buffer[ptr++];
        }

        /** Skips bytes in [0, 32] and returns the first byte after them (or -1). */
        private int firstNonBlank() throws IOException {
            int c = read();
            while (c >= 0 && c <= 32) {
                c = read();
            }
            return c;
        }

        /** Parses the next optionally signed decimal token. */
        long nextLong() throws IOException {
            int c = firstNonBlank();
            boolean negative = c == '-';
            if (negative) {
                c = read();
            }
            long value = 0;
            for (; c > 32; c = read()) {
                value = value * 10 + c - '0';
            }
            return negative ? -value : value;
        }

        /** Next token narrowed to int. */
        int nextInt() throws IOException {
            return (int) nextLong();
        }
    }
}
import sys
from collections import Counter

# Problem B: tokens are n, m, x followed by the values; print m minus the
# number of distinct values whose frequency exceeds x.
tokens = [int(tok) for tok in sys.stdin.buffer.read().split()]
n, m, x = tokens[:3]
freq = Counter(tokens[3:])
over = sum(count > x for count in freq.values())
print(m - over)

C.小红的因子幂和

代码查看
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
const int mod = 1e9 + 7;

// Accumulated prime factorisation: cnt[p] is the total exponent of prime p
// over every value passed to fc() so far.
map<int, int> cnt;

// Adds the prime factorisation of x into the global `cnt`.
// Trial division while i * i <= x. The bound is written as `i <= x / i` so it
// stays in integer arithmetic (no floating-point sqrt, no i * i overflow). It
// shrinks as factors are divided out; whatever remains above 1 afterwards has
// no divisor <= its square root and is therefore prime.
void fc(int x){
    for (int i = 2; i <= x / i; i++){
        int exponent = 0;
        while (x % i == 0){
            ++exponent;
            x /= i;
        }
        if (exponent > 0)
            cnt[i] += exponent;
    }
    if (x > 1)
        cnt[x]++;
}
// Modular exponentiation: a^b modulo `mod` by square-and-multiply over the
// bits of b. Intended for b >= 0.
i64 qmi(i64 a, i64 b){
    i64 base = a % mod;
    i64 acc = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1)
            acc = acc * base % mod;
        base = base * base % mod;
    }
    return acc;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int x, y;
    cin >> x >> y;
    fc(x), fc(y);
    vector<i64> a;
    a.push_back(1);
    for (auto [num, c] : cnt){
        i64 cur = 1;
        int tmp_size = a.size();
        for (int i = 1; i <= c; i++){
            cur = cur * num;
            for (int j = 0; j < tmp_size; j++){
                a.push_back(a[j] * cur);
            }
        }
    }
    i64 ans = 0;
    for (auto t : a)
        ans = (ans + qmi(t, t)) % mod;
    cout << ans;
    return 0;
}


import java.io.*;
import java.util.*;

public class Main {
    static final long MOD = 1000000007L;

    /** Returns a^e mod MOD by square-and-multiply; e is expected to be >= 0. */
    static long modPow(long a, long e) {
        long base = a % MOD;
        long result = 1;
        for (; e > 0; e >>= 1) {
            if ((e & 1) == 1) {
                result = result * base % MOD;
            }
            base = base * base % MOD;
        }
        return result;
    }

    /** Adds the prime factorisation of v into fac (prime -> exponent). */
    static void factor(long v, TreeMap<Long, Integer> fac) {
        long rest = v;
        for (long p = 2; p <= rest / p; p++) {
            int exponent = 0;
            while (rest % p == 0) {
                rest /= p;
                exponent++;
            }
            if (exponent > 0) {
                fac.merge(p, exponent, Integer::sum);
            }
        }
        if (rest > 1) {
            fac.merge(rest, 1, Integer::sum);
        }
    }

    /** Problem C: prints the sum of d^d mod MOD over all divisors d of x * y. */
    public static void main(String[] args) throws Exception {
        FastScanner scanner = new FastScanner(System.in);
        long x = scanner.nextLong();
        long y = scanner.nextLong();
        TreeMap<Long, Integer> fac = new TreeMap<>();
        for (long value : new long[] {x, y}) {
            factor(value, fac);
        }
        List<Long> divisors = new ArrayList<>();
        divisors.add(1L);
        for (Map.Entry<Long, Integer> entry : fac.entrySet()) {
            long prime = entry.getKey();
            int exponent = entry.getValue();
            int baseCount = divisors.size();
            long power = 1;
            for (int e = 0; e < exponent; e++) {
                power *= prime;
                for (int j = 0; j < baseCount; j++) {
                    divisors.add(divisors.get(j) * power);
                }
            }
        }
        long ans = 0;
        for (long d : divisors) {
            ans = (ans + modPow(d, d)) % MOD;
        }
        System.out.println(ans);
    }

    /** Buffered reader of whitespace-separated non-negative integer tokens. */
    static class FastScanner {
        private final InputStream in;
        private final byte[] buffer = new byte[1 << 16];
        private int ptr = 0;
        private int len = 0;

        FastScanner(InputStream is) {
            in = is;
        }

        /** Returns the next byte, refilling the buffer when exhausted; -1 at end of input. */
        int read() throws IOException {
            if (ptr < len) {
                return buffer[ptr++];
            }
            len = in.read(buffer);
            ptr = 0;
            return len <= 0 ? -1 : buffer[ptr++];
        }

        /** Parses the next unsigned decimal token (no sign handling). */
        long nextLong() throws IOException {
            int c = read();
            while (c >= 0 && c <= 32) {
                c = read();
            }
            long value = 0;
            for (; c > 32; c = read()) {
                value = value * 10 + c - '0';
            }
            return value;
        }
    }
}
import sys
MOD = 1000000007  # answers are reported modulo this prime
x, y = (int(token) for token in sys.stdin.buffer.read().split())

def factor(v, fac):
    """Add the prime factorisation of v into the dict fac (prime -> exponent).

    Trial division by every p with p * p <= v; any remainder above 1 is a
    prime. Mutates fac in place and returns None.
    """
    p = 2
    while p * p <= v:
        exponent = 0
        while v % p == 0:
            v //= p
            exponent += 1
        if exponent:
            fac[p] = fac.get(p, 0) + exponent
        p += 1
    if v > 1:
        fac[v] = fac.get(v, 0) + 1

# Combined factorisation of x * y, built without forming the product.
fac = {}
for value in (x, y):
    factor(value, fac)

# Enumerate divisors: for each prime, append every earlier divisor times
# prime^1 .. prime^exp (same order as a nested power/divisor loop).
divisors = [1]
for prime, exp in fac.items():
    base = list(divisors)
    divisors.extend(d * prime ** e for e in range(1, exp + 1) for d in base)

ans = sum(pow(d, d, MOD) for d in divisors) % MOD
print(ans)

D.小红的最佳区间

代码查看
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

// Problem D. Every input interval [l, r] adds +1 to each position s with
// l - k <= s <= r (presumably the start positions of a window [s, s + k] that
// touches [l, r] -- confirm against the statement). A difference map swept in
// key order yields the largest number of intervals covering one position.
//
// Coordinates are 64-bit: `l - k` and `r + 1` can leave the int range for
// large inputs (signed overflow is UB). The Java and Python versions of this
// solution already use 64-bit / unbounded values.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int n;
    i64 k;
    cin >> n >> k;
    map<i64, int> d;
    for (int i = 1; i <= n; i++){
        i64 l, r;
        cin >> l >> r;
        d[l - k]++;
        d[r + 1]--;
    }

    int ans = 0, sum = 0;
    for (const auto& entry : d){
        sum += entry.second;
        ans = max(ans, sum);
    }
    cout << ans;
    return 0;
}


import java.io.*;
import java.util.*;

public class Main {
    /**
     * Problem D: every interval [l, r] adds +1 on positions l - k .. r; prints the
     * maximum coverage found by sweeping the difference map in key order.
     */
    public static void main(String[] args) throws Exception {
        FastScanner scanner = new FastScanner(System.in);
        int n = scanner.nextInt();
        long k = scanner.nextLong();
        TreeMap<Long, Integer> diff = new TreeMap<>();
        for (int i = 0; i < n; i++) {
            long l = scanner.nextLong();
            long r = scanner.nextLong();
            diff.merge(l - k, 1, Integer::sum);
            diff.merge(r + 1, -1, Integer::sum);
        }
        int running = 0;
        int best = 0;
        for (int delta : diff.values()) {
            running += delta;
            if (running > best) {
                best = running;
            }
        }
        System.out.println(best);
    }

    /** Buffered reader of whitespace-separated tokens from a byte stream. */
    static class FastScanner {
        private final InputStream in;
        private final byte[] buffer = new byte[1 << 16];
        private int ptr = 0;
        private int len = 0;

        FastScanner(InputStream is) {
            in = is;
        }

        /** Returns the next byte, refilling the buffer when exhausted; -1 at end of input. */
        int read() throws IOException {
            if (ptr < len) {
                return buffer[ptr++];
            }
            len = in.read(buffer);
            ptr = 0;
            return len <= 0 ? -1 : buffer[ptr++];
        }

        /** Skips bytes in [0, 32] and returns the first byte after them (or -1). */
        private int firstNonBlank() throws IOException {
            int c = read();
            while (c >= 0 && c <= 32) {
                c = read();
            }
            return c;
        }

        /** Parses the next optionally signed decimal token. */
        long nextLong() throws IOException {
            int c = firstNonBlank();
            boolean negative = c == '-';
            if (negative) {
                c = read();
            }
            long value = 0;
            for (; c > 32; c = read()) {
                value = value * 10 + c - '0';
            }
            return negative ? -value : value;
        }

        /** Next token narrowed to int. */
        int nextInt() throws IOException {
            return (int) nextLong();
        }
    }
}
import sys
from collections import defaultdict

# Problem D: each interval [l, r] adds +1 on positions l - k .. r; report the
# maximum running total when sweeping the difference map in key order.
values = [int(tok) for tok in sys.stdin.buffer.read().split()]
n, k = values[:2]
diff = defaultdict(int)
for i in range(n):
    l, r = values[2 + 2 * i], values[3 + 2 * i]
    diff[l - k] += 1
    diff[r + 1] -= 1

best = running = 0
for pos in sorted(diff):
    running += diff[pos]
    if running > best:
        best = running
print(best)

E.小红的好矩阵

代码查看
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

// Problem E: minimum number of cells to change in a 2 x n 0/1 matrix so it
// matches one of four periodic patterns made of 3-column blocks. Prints -1
// when n is not a multiple of 3.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int n;
    cin >> n;
    vector<string> s(3);
    cin >> s[1] >> s[2];
    if (n % 3 != 0){
        cout << "-1";
        return 0;
    }

    // Cost of forcing column `col` to hold (top, bottom).
    auto paint = [&](int col, char top, char bottom) -> int {
        return (s[1][col] != top) + (s[2][col] != bottom);
    };

    // Patterns 1/2: each block is a horizontal stripe ("1/0" or "0/1" in all
    // three columns) and consecutive blocks alternate.
    int stripeA = 0, stripeB = 0;
    for (int block = 0; block * 3 < n; ++block) {
        int topOnes = 0, bottomOnes = 0;
        for (int col = block * 3; col < block * 3 + 3; ++col) {
            topOnes += paint(col, '1', '0');
            bottomOnes += paint(col, '0', '1');
        }
        if (block % 2 == 0) {
            stripeA += topOnes;
            stripeB += bottomOnes;
        } else {
            stripeA += bottomOnes;
            stripeB += topOnes;
        }
    }

    // Patterns 3/4: every block is "11 | ? | 00" (or the complement
    // "00 | ? | 11"); the middle column may be either "10" or "01".
    int cornerA = 0, cornerB = 0;
    for (int l = 0; l < n; l += 3) {
        const int mid = min(paint(l + 1, '1', '0'), paint(l + 1, '0', '1'));
        cornerA += paint(l, '1', '1') + paint(l + 2, '0', '0') + mid;
        cornerB += paint(l, '0', '0') + paint(l + 2, '1', '1') + mid;
    }

    cout << min({stripeA, stripeB, cornerA, cornerB});
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    /** 1 when the character differs from the required one, else 0. */
    static int mismatch(char got, char need) {
        return got != need ? 1 : 0;
    }

    /**
     * Repaint costs of columns l..r as a horizontal stripe:
     * [0] = top '1' / bottom '0', [1] = top '0' / bottom '1'.
     */
    static int[] cul1(String a, String b, int l, int r) {
        int[] cost = new int[2];
        for (int i = l; i <= r; i++) {
            char top = a.charAt(i);
            char bottom = b.charAt(i);
            cost[0] += mismatch(top, '1') + mismatch(bottom, '0');
            cost[1] += mismatch(top, '0') + mismatch(bottom, '1');
        }
        return cost;
    }

    /**
     * Repaint costs of block l..r (r = l + 2) as a corner pattern:
     * [0] = column l "11" and column r "00"; [1] = the complement.
     * The middle column takes the cheaper of "10" and "01".
     */
    static int[] cul2(String a, String b, int l, int r) {
        int mid = Math.min(
                mismatch(a.charAt(l + 1), '1') + mismatch(b.charAt(l + 1), '0'),
                mismatch(a.charAt(l + 1), '0') + mismatch(b.charAt(l + 1), '1'));
        int first = mismatch(a.charAt(l), '1') + mismatch(b.charAt(l), '1')
                + mismatch(a.charAt(r), '0') + mismatch(b.charAt(r), '0') + mid;
        int second = mismatch(a.charAt(l), '0') + mismatch(b.charAt(l), '0')
                + mismatch(a.charAt(r), '1') + mismatch(b.charAt(r), '1') + mid;
        return new int[] {first, second};
    }

    /** Problem E: prints the cheapest of the four block patterns, or -1 if n % 3 != 0. */
    public static void main(String[] args) throws Exception {
        FastScanner scanner = new FastScanner(System.in);
        int n = scanner.nextInt();
        String a = scanner.next();
        String b = scanner.next();
        if (n % 3 != 0) {
            System.out.println(-1);
            return;
        }
        int stripeA = 0, stripeB = 0;
        for (int block = 0, i = 0; i < n; block++, i += 3) {
            int[] t = cul1(a, b, i, i + 2);
            boolean even = block % 2 == 0;
            stripeA += even ? t[0] : t[1];
            stripeB += even ? t[1] : t[0];
        }
        int cornerA = 0, cornerB = 0;
        for (int i = 0; i < n; i += 3) {
            int[] t = cul2(a, b, i, i + 2);
            cornerA += t[0];
            cornerB += t[1];
        }
        System.out.println(Math.min(Math.min(stripeA, stripeB), Math.min(cornerA, cornerB)));
    }

    /** Buffered reader of whitespace-separated tokens from a byte stream. */
    static class FastScanner {
        private final InputStream in;
        private final byte[] buffer = new byte[1 << 16];
        private int ptr = 0;
        private int len = 0;

        FastScanner(InputStream is) {
            in = is;
        }

        /** Returns the next byte, refilling the buffer when exhausted; -1 at end of input. */
        int read() throws IOException {
            if (ptr < len) {
                return buffer[ptr++];
            }
            len = in.read(buffer);
            ptr = 0;
            return len <= 0 ? -1 : buffer[ptr++];
        }

        /** Next whitespace-delimited token as a String. */
        String next() throws IOException {
            int c = read();
            while (c >= 0 && c <= 32) {
                c = read();
            }
            StringBuilder token = new StringBuilder();
            for (; c > 32; c = read()) {
                token.append((char) c);
            }
            return token.toString();
        }

        /** Next token parsed with Integer.parseInt. */
        int nextInt() throws IOException {
            return Integer.parseInt(next());
        }
    }
}
import sys

# Input: n, then the top row a and the bottom row b as 0/1 strings.
tokens = sys.stdin.read().split()
n = int(tokens[0])
a = tokens[1]
b = tokens[2]
if n % 3:
    print(-1)
    sys.exit()

def mis(ch, need):
    """Return 1 if character ch differs from the required character need, else 0."""
    return int(ch != need)

def cul1(l, r):
    """Repaint costs of columns l..r (inclusive) as a horizontal stripe.

    Returns (cost of top '1' / bottom '0', cost of top '0' / bottom '1'),
    reading the module-level rows a and b.
    """
    cols = range(l, r + 1)
    t1 = sum(mis(a[i], '1') + mis(b[i], '0') for i in cols)
    t2 = sum(mis(a[i], '0') + mis(b[i], '1') for i in cols)
    return t1, t2

def cul2(l, r):
    """Repaint costs of the block l, l + 1, r as a corner pattern.

    First value: column l all '1' and column r all '0'; second value: the
    complement. The middle column l + 1 takes the cheaper of '10' and '01'.
    Reads the module-level rows a and b.
    """
    mid = min(mis(a[l + 1], '1') + mis(b[l + 1], '0'),
              mis(a[l + 1], '0') + mis(b[l + 1], '1'))
    t1 = mis(a[l], '1') + mis(b[l], '1') + mis(a[r], '0') + mis(b[r], '0') + mid
    t2 = mis(a[l], '0') + mis(b[l], '0') + mis(a[r], '1') + mis(b[r], '1') + mid
    return t1, t2

# Stripe patterns: consecutive 3-column blocks alternate orientation.
stripe_a = stripe_b = 0
for block, start in enumerate(range(0, n, 3)):
    top, bottom = cul1(start, start + 2)
    if block % 2:
        top, bottom = bottom, top
    stripe_a += top
    stripe_b += bottom

# Corner patterns: every block has the same orientation.
corner_a = corner_b = 0
for start in range(0, n, 3):
    first, second = cul2(start, start + 2)
    corner_a += first
    corner_b += second

print(min(stripe_a, stripe_b, corner_a, corner_b))

F.小红的网格路径 II

代码查看
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
const int mod = 1e9 + 7;

// Returns a^b modulo `mod` by square-and-multiply over the bits of b.
// Intended for b >= 0.
i64 qmi(i64 a, i64 b){
    i64 base = a % mod;
    i64 acc = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1)
            acc = acc * base % mod;
        base = base * base % mod;
    }
    return acc;
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int n, m, k;
    cin >> n >> m >> k;
    map<int, vector<pair<int, int>>> v; // 存放每列的可通过区间
    vector<int> a; // 存放存在不可通过的点的列编号
    for (int i = 1; i <= k; i++){
        int x, y;
        cin >> x >> y;
        a.push_back(y);
        if(x == 1){
            if (n == 1){
                cout << "0";
                return 0;
            }
            else{
                v[y].push_back({2, n});
            }
        }else if(x == n){
            if (n == 1){
                cout << "0";
                return 0;
            }
            else{
                v[y].push_back({1, n - 1});
            }
        }else{
            v[y].push_back({1, x - 1});
            v[y].push_back({x + 1, n});
        }
    }
    map<pair<int, int>, i64> dp;
    if (v[1].size() == 0){
        v[1].push_back({1, n});
        dp[{1, n}] = 1;
    }else{
        dp[{v[1][0].first, v[1][0].second}] = 1;
        dp[{v[1][1].first, v[1][1].second}] = 0;
    }
    if (v[m].size() == 0)
        v[m].push_back({1, n});
    a.push_back(1);
    a.push_back(m);
    sort(a.begin(), a.end());
    a.erase(unique(a.begin(), a.end()), a.end());
    auto get = [&](auto x, auto y) -> i64 {
        if (x.first > y.first)
            swap(x, y);
        if (x.second < y.first)
            return 0;
        if (y.second < x.second)
            return y.second - y.first + 1;
        return x.second - y.first + 1;
    };
    for (int i = 0; i < a.size() - 1; i++){
        map<pair<int, int>, i64> ndp;
        for (auto r1 : v[a[i]]){
            for (auto r2 : v[a[i + 1]]){
                if (a[i] + 1 == a[i + 1]){
                    i64 len = get(r1, r2);
                    ndp[r2] = (ndp[r2] + dp[r1] * len % mod) % mod;
                }else{
                    ndp[r2] += dp[r1] * (r1.second - r1.first + 1) %mod * (r2.second - r2.first + 1) % mod * qmi(n, a[i + 1] - a[i] - 2)%mod;
                    ndp[r2] %= mod;
                }
            }
        }
        dp.swap(ndp);
    }
    for (auto ran : dp){
        if (ran.first.second == n){
            cout << ran.second;
            return 0;
        }
    }
    return 0;
}

import java.io.*;
import java.util.*;

public class Main {
    static final long MOD = 1000000007L;

    /** Inclusive segment [l, r] of passable rows inside one column. */
    static class Range implements Comparable<Range> {
        long l, r;
        Range(long l, long r) { this.l = l; this.r = r; }
        @Override
        public int compareTo(Range o) {
            int c = Long.compare(l, o.l);
            return c != 0 ? c : Long.compare(r, o.r);
        }
        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Range)) return false;
            Range o = (Range) obj;
            return l == o.l && r == o.r;
        }
        @Override
        public int hashCode() { return Objects.hash(l, r); }
    }

    /** a^b mod MOD by square-and-multiply; b is expected to be >= 0. */
    static long modPow(long a, long b) {
        a %= MOD;
        long ans = 1;
        while (b > 0) {
            if ((b & 1) == 1) ans = ans * a % MOD;
            a = a * a % MOD;
            b >>= 1;
        }
        return ans;
    }

    /** Number of rows shared by two inclusive segments (0 if disjoint). */
    static long overlap(Range x, Range y) {
        if (x.l > y.l) {
            Range t = x; x = y; y = t;
        }
        if (x.r < y.l) return 0;
        if (y.r < x.r) return y.r - y.l + 1;
        return x.r - y.l + 1;
    }

    /** Number of rows in an inclusive segment. */
    static long len(Range r) {
        return r.r - r.l + 1;
    }

    /** Whether row lies inside the inclusive segment seg. */
    static boolean contains(Range seg, long row) {
        return seg.l <= row && row <= seg.r;
    }

    /**
     * Problem F: counts paths from (1, 1) to (n, m) modulo MOD. Uses a DP over
     * the passable row segments of column 1, column m and every column that has
     * a blocked cell. Presumably a path moves freely within a segment and steps
     * right through a row open in both columns -- confirm against the statement.
     * Assumes at most one blocked cell per column.
     */
    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        long n = fs.nextLong();
        long m = fs.nextLong();
        int k = fs.nextInt();
        TreeMap<Long, ArrayList<Range>> ranges = new TreeMap<>();
        ArrayList<Long> cols = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            long x = fs.nextLong();
            long y = fs.nextLong();
            cols.add(y);
            ArrayList<Range> list = ranges.computeIfAbsent(y, key -> new ArrayList<>());
            if (x == 1) {
                if (n == 1) {
                    // A single-row grid with a blocked cell cannot be crossed.
                    System.out.println(0);
                    return;
                }
                list.add(new Range(2, n));
            } else if (x == n) {
                // Here x != 1, so n != 1: no single-row check is needed.
                list.add(new Range(1, n - 1));
            } else {
                list.add(new Range(1, x - 1));
                list.add(new Range(x + 1, n));
            }
        }

        // Columns without blocked cells are fully open.
        ArrayList<Range> first = ranges.computeIfAbsent(1L, key -> new ArrayList<>());
        if (first.isEmpty()) first.add(new Range(1, n));
        ArrayList<Range> last = ranges.computeIfAbsent(m, key -> new ArrayList<>());
        if (last.isEmpty()) last.add(new Range(1, n));

        // Seed only the column-1 segment that contains the start row 1. If (1, 1)
        // itself is blocked, no segment qualifies and every count stays 0.
        TreeMap<Range, Long> dp = new TreeMap<>();
        for (Range seg : first) dp.put(seg, contains(seg, 1) ? 1L : 0L);

        cols.add(1L);
        cols.add(m);
        Collections.sort(cols);
        ArrayList<Long> uniq = new ArrayList<>();
        for (long c : cols) {
            if (uniq.isEmpty() || uniq.get(uniq.size() - 1).longValue() != c) uniq.add(c);
        }
        for (int i = 0; i + 1 < uniq.size(); i++) {
            long c1 = uniq.get(i), c2 = uniq.get(i + 1);
            TreeMap<Range, Long> ndp = new TreeMap<>();
            for (Range r1 : ranges.get(c1)) {
                long ways = dp.getOrDefault(r1, 0L);
                for (Range r2 : ranges.get(c2)) {
                    long add;
                    if (c1 + 1 == c2) {
                        // Adjacent columns: step right through any row open in both.
                        add = ways * overlap(r1, r2) % MOD;
                    } else {
                        // Open columns in between: leave r1 through any of its rows,
                        // pick any exit row in the first c2 - c1 - 2 open columns,
                        // and leave the last one through a row of r2.
                        add = ways * (len(r1) % MOD) % MOD * (len(r2) % MOD) % MOD * modPow(n, c2 - c1 - 2) % MOD;
                    }
                    ndp.put(r2, (ndp.getOrDefault(r2, 0L) + add) % MOD);
                }
            }
            dp = ndp;
        }
        for (Map.Entry<Range, Long> e : dp.entrySet()) {
            if (e.getKey().r == n) {
                System.out.println(e.getValue() % MOD);
                return;
            }
        }
        // No segment of column m reaches row n: the target cell (n, m) is blocked.
        System.out.println(0);
    }

    /** Buffered reader of whitespace-separated non-negative integer tokens. */
    static class FastScanner {
        private final InputStream in; private final byte[] buffer = new byte[1 << 16]; private int ptr = 0, len = 0;
        FastScanner(InputStream is) { in = is; }
        /** Next byte, refilling the buffer as needed; -1 at end of input. */
        int read() throws IOException { if (ptr >= len) { len = in.read(buffer); ptr = 0; if (len <= 0) return -1; } return buffer[ptr++]; }
        /** Next unsigned decimal token (no sign handling). */
        long nextLong() throws IOException { int c; do { c = read(); } while (c <= 32 && c >= 0); long v = 0; while (c > 32) { v = v * 10 + c - '0'; c = read(); } return v; }
        int nextInt() throws IOException { return (int) nextLong(); }
    }
}
import sys
from collections import defaultdict

MOD = 10**9 + 7
data = list(map(int, sys.stdin.buffer.read().split()))
n, m, k = data[:3]
# ranges[c]: passable row segments (lo, hi) of column c, built for every column
# that contains a blocked cell (assumes at most one blocked cell per column);
# cols: the indices of those columns.
ranges = defaultdict(list)
cols = []
idx = 3
for _ in range(k):
    x, y = data[idx], data[idx + 1]
    idx += 2
    cols.append(y)
    if x == 1:
        if n == 1:
            # A single-row grid with a blocked cell cannot be crossed.
            print(0)
            raise SystemExit
        ranges[y].append((2, n))
    elif x == n:
        # Here x != 1, so n != 1: no single-row check is needed.
        ranges[y].append((1, n - 1))
    else:
        ranges[y].append((1, x - 1))
        ranges[y].append((x + 1, n))

# Columns without blocked cells are fully open.
if not ranges[1]:
    ranges[1].append((1, n))
if not ranges[m]:
    ranges[m].append((1, n))

# dp: segment of the current column -> number of ways (mod MOD). Seed only the
# column-1 segment containing the start row 1; if (1, 1) itself is blocked, no
# segment qualifies and every count stays 0.
dp = {seg: (1 if seg[0] <= 1 <= seg[1] else 0) for seg in ranges[1]}

cols = sorted(set(cols + [1, m]))

def overlap(x, y):
    """Number of integer rows shared by inclusive segments x and y, each (lo, hi)."""
    first, second = (x, y) if x[0] <= y[0] else (y, x)
    if first[1] < second[0]:
        return 0
    return min(first[1], second[1]) - second[0] + 1

def length(r):
    """Number of integer rows in the inclusive segment r = (lo, hi)."""
    return 1 + r[1] - r[0]

# Advance dp over the materialised columns in increasing order.
for i in range(len(cols) - 1):
    c1, c2 = cols[i], cols[i + 1]
    ndp = {}
    for r1 in ranges[c1]:
        ways = dp.get(r1, 0)
        for r2 in ranges[c2]:
            if c1 + 1 == c2:
                # Adjacent columns: step right through any row open in both.
                add = ways * overlap(r1, r2) % MOD
            else:
                # Open columns in between: leave r1 through any of its rows, pick
                # any exit row in the first c2 - c1 - 2 open columns, and leave
                # the last one through a row of r2.
                add = ways * length(r1) % MOD * length(r2) % MOD * pow(n, c2 - c1 - 2, MOD) % MOD
            ndp[r2] = (ndp.get(r2, 0) + add) % MOD
    dp = ndp

# The answer is the count of the last-column segment that contains row n.
for r in sorted(dp):
    if r[1] == n:
        print(dp[r] % MOD)
        break
else:
    # No segment of column m reaches row n: the target cell (n, m) is blocked.
    print(0)

全部评论

(0) 回帖
加载中...
话题 回帖

等你来战

查看全部

热门推荐