Editorial for duong3982oj Contest 04 - Miku mày mò


Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.

Author: clue_

Subtask 1

Với subtask này, ta có thể duyệt mọi xâu con, sau đó chia đôi và kiểm tra tính đối xứng.

Độ phức tạp: ~O~ (~n^3~).

Subtask 2

Với subtask này, thay vì kiểm tra xâu con bằng cách duyệt, ta có thể tiền xử lý: ~f[i][j] = 0/1~ nghĩa là có đối xứng hay không.

Độ phức tạp: ~O~ (~n^2~).

Subtask 3

Với subtask này, mọi xâu con chẵn đều đối xứng. Lúc này, ta có thể dễ dàng tính kết quả.

Độ phức tạp: ~O~ (~n~).

Subtask 4

Cải tiến từ subtask 2, ta có thể tìm kiếm nhị phân để tìm vị trí xa nhất thỏa mãn tính chất siêu đối xứng.

Độ phức tạp: ~O~ (~n \times log_ {n}~).

Code mẫu:

#include<bits/stdc++.h>
#define task "PALIN"
#define int long long
#define mod 1000000007
#define base 31
#define maxn 2000005
using namespace std;

int n;
int a[maxn];
char s[maxn];
int pw[maxn];
int hashL[maxn];
int hashR[maxn];

int get_hashL(int l, int r) {
    return ((hashL[r] - hashL[l-1] * pw[r-l+1]) % mod + mod) % mod;
}

int get_hashR(int l, int r) {
    return ((hashR[l] - hashR[r+1] * pw[r-l+1]) % mod + mod) % mod;
}

bool check1(int i, int x) {
    if(i + x - 1 > n || i - x + 1 < 1) return false;
    return (get_hashL(i - x + 1, i + x - 1) == get_hashR(i - x + 1, i + x - 1));
}

bool check2(int i, int j, int x) {
    if(i - x + 1 < 1 || j + x - 1 > n) return false;
    return (get_hashL(i - x + 1, j + x - 1) == get_hashR(i - x + 1, j + x - 1));
}

void initHash() {
    pw[0] = 1;
    for(int i = 1; i <= n; i++) {
        pw[i] = (1LL * pw[i-1] * base) % mod;
    }
    for(int i = 1; i <= n; i++) {
        hashL[i] = (hashL[i-1] * base + (s[i] - 'a')) % mod;
    }
    for(int i = n; i >= 1; i--) {
        hashR[i] = (hashR[i+1] * base + (s[i] - 'a')) % mod;
    }
}

// Cấu trúc cây Fenwick (BIT) để xử lý câu hỏi và cập nhật các đoạn con đối xứng
struct {
    int s[maxn];
    void init() {
        for(int i = 1; i <= n; i++) {
            s[i] = 0;
        }
    }
    void update(int x, int val) {
        for(; x > 0; x -= x & (-x)) {
            s[x] += val;
        }
    }
    int get(int x) {
        int sum = 0;
        for(; x <= n; x += x & (-x)) {
            sum += s[x];
        }
        return sum;
    }
} bit[2];

// Hàm tìm kiếm nhị phân để xác định các đoạn con đối xứng
int bs1(int lo, int hi, int i) {
    if(check1(i, hi)) return hi;
    while(hi - lo > 1) {
        int mid = (lo + hi) / 2;
        if(check1(i, mid)) lo = mid;
        else hi = mid;
    }
    return lo;
}

int bs2(int lo, int hi, int i) {
    if(check2(i, i + 1, hi)) return hi;
    while(hi - lo > 1) {
        int mid = (lo + hi) / 2;
        if(check2(i, i + 1, mid)) lo = mid;
        else hi = mid;
    }
    return lo;
}

vector<int> add[maxn], del[maxn];

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n;
    for(int i = 1; i <= n; i++) {
        cin >> s[i];
    }

    initHash();

    // Xử lý cho các chuỗi đối xứng từ trái sang phải
    for(int i = 1; i <= n; i++) {
        int r = bs1(1, min(i, n - i + 1), i);
        a[i] = r;
        add[i].push_back(i);
        del[min(i + 2 * r - 1, n)].push_back(i);
    }

    int ds = 0;
    for(int i = 1; i <= n; i++) {
        int ope = 1 - (i % 2);
        ds += bit[ope].get(max(1LL, i - (2 * a[i] - 1)));
        for(int x : add[i]) {
            bit[x % 2].update(x, 1);
        }
        for(int x : del[i]) {
            bit[x % 2].update(x, -1);
        }
    }

    // Xử lý cho các chuỗi đối xứng từ phải sang trái
    for(int i = 1; i <= n; i++) {
        add[i].clear();
        del[i].clear();
    }
    bit[0].init();
    bit[1].init();

    for(int i = 1; i < n; i++) {
        if(s[i] != s[i + 1]) continue;
        int r = bs2(1, min(i, n - i), i);
        a[i] = r;
        add[i].push_back(i);
        del[min(i + 2 * r, n)].push_back(i);
    }

    for(int i = 1; i < n; i++) {
        if(s[i] != s[i + 1]) {
            for(int x : del[i]) bit[x % 2].update(x, -1);
            continue;
        }
        int ope = (i % 2);
        ds += bit[ope].get(max(1LL, i - a[i] * 2));
        for(int x : add[i]) bit[x % 2].update(x, 1);
        for(int x : del[i]) bit[x % 2].update(x, -1);
    }

    cout << ds << '\n';
}

Comments

Please read the guidelines before commenting.


There are no comments at the moment.