From 0e15fbd4d07f397712cf7108e246c289f9c731ab Mon Sep 17 00:00:00 2001 From: Baoshuo Date: Thu, 1 Dec 2022 21:00:00 +0800 Subject: [PATCH] =?UTF-8?q?P3803=20=E3=80=90=E6=A8=A1=E6=9D=BF=E3=80=91?= =?UTF-8?q?=E5=A4=9A=E9=A1=B9=E5=BC=8F=E4=B9=98=E6=B3=95=EF=BC=88FFT?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://www.luogu.com.cn/record/96249847 --- Luogu/P3803/P3803.cpp | 127 +++++++++++++++++++++++++++++++----------- 1 file changed, 93 insertions(+), 34 deletions(-) diff --git a/Luogu/P3803/P3803.cpp b/Luogu/P3803/P3803.cpp index d0882a69..20c88a28 100644 --- a/Luogu/P3803/P3803.cpp +++ b/Luogu/P3803/P3803.cpp @@ -1,39 +1,108 @@ #include #include -#include -#include #include using std::cin; using std::cout; const char endl = '\n'; -const double PI = std::acos(-1); +const int mod = 998244353; -void FFT(std::vector>& a) { - if (a.size() == 1) return; +constexpr long long binpow(long long a, long long b) { + a %= mod; - int m = a.size() >> 1; - std::vector> a0, a1; + long long res = 1; - for (int i = 0; i < m; i++) { - a0.emplace_back(a[i << 1]); - a1.emplace_back(a[i << 1 | 1]); + while (b) { + if (b & 1) res = res * a % mod; + a = a * a % mod; + b >>= 1; } - FFT(a0), FFT(a1); - - std::complex - w0{std::cos(PI / m), std::sin(PI / m)}, - w1{1.0, 0.0}; - - for (int i = 0; i < m; i++) { - a[i] = a0[i] + w1 * a1[i]; - a[i + m] = a0[i] - w1 * a1[i]; - w1 *= w0; - } + return res; } +std::vector number_theoretic_transform(std::vector a) { + // assert(a.size() == (1 << std::__lg(a.size()))); + int k = std::__lg(a.size()); + + for (int i = 0; i < a.size(); i++) { + int t = 0; + + for (int j = 0; j < k; j++) { + if (i & (1 << j)) { + t |= 1 << (k - j - 1); + } + } + + if (i < t) std::swap(a[i], a[t]); + } + + for (int len = 2; len <= a.size(); len <<= 1) { + int m = len >> 1; + long long wn = binpow(3, (mod - 1) / len); + + for (int i = 0; i < a.size(); i += len) { + long long w = 1; + + for (int j = 0; j < m; j++) { + long long u = a[i + j], + v = a[i + j + m] * w % mod; + + a[i + j] = ((u + v) % mod + mod) % mod; + a[i + j + m] = ((u - v) % mod + mod) % mod; + w = w * wn % mod; + } + } + } + + return a; +} + +class Poly : public std::vector { + private: + public: + using std::vector::vector; + + Poly() = default; + + Poly(const std::vector &__v) + : std::vector(__v) {} + + Poly(std::vector &&__v) + : std::vector(std::move(__v)) {} + + Poly operator*(const Poly &b) const { + int n = size() - 1, + m = b.size() - 1, + k = 1 << (std::__lg(n + m) + 1), + inv = binpow(k, mod - 2); + + std::vector f(*this), g(b); + + f.resize(k); + f = number_theoretic_transform(f); + g.resize(k); + g = number_theoretic_transform(g); + + for (int i = 0; i < k; i++) { + f[i] = f[i] * g[i] % mod; + } + + f = number_theoretic_transform(f); + // assert(f.size() > 0) + std::reverse(f.begin() + 1, f.end()); + + std::vector res(n + m + 1); + + for (int i = 0; i <= n + m; i++) { + res[i] = f[i] * inv % mod; + } + + return Poly(res); + } +} poly; + int main() { std::ios::sync_with_stdio(false); cin.tie(nullptr); @@ -42,8 +111,7 @@ int main() { cin >> n >> m; - int k = 1 << (std::__lg(n + m) + 1); - std::vector> f(k), g(k); + Poly f(n + 1), g(m + 1); for (int i = 0; i <= n; i++) { cin >> f[i]; @@ -53,18 +121,9 @@ int main() { cin >> g[i]; } - FFT(f), FFT(g); + auto res = f * g; - for (int i = 0; i < k; i++) { - f[i] *= g[i]; - } - - FFT(f); - std::reverse(f.begin() + 1, f.end()); - - for (int i = 0; i <= n + m; i++) { - cout << static_cast(std::round(f[i].real() / k)) << ' '; - } + for (int x : res) cout << x << ' '; cout << endl;