一道有趣的题。
3s+O2
原题大概是 这道。
看到模 3 显然是出题人精心构造的。
于是我们打出一张组合数模 3 的表。
好多 0 啊~
于是我们考虑如何快速求出组合数不是 0 的位置。
先说一下我考场上想到的 log 做法:
我们考虑 lucas
,对于组合数(nm),其实就是(NiMi) 的乘积,其中 Mi,Ni 表示 m,n 在3进制下的第 i 位。
我们发现 (NiMi) 只要有一位是 0,最终结果就是0,所以我们只要枚举m 后枚举三进制下每一位是比 m 小的 n 就行了,这样每次找是 log 的,所以复杂度是 O(0 的个数 ×logn),hdu
上过了,可惜模拟赛的时候被卡掉了……
我们考虑 Lucas
是怎么递归下来的。对于每一层,我们只有 6 种方案使得 (nm)>0,我们直接反向递归上去,这样就能每次O(1) 搜出所有结果了。
关于 0 的个数为什么这么多(打表可得 ),我们发现每次递归下去的时候,我们都可以把序列根据mod3 分成 3 段:0,1,2,而每一层只有 6 个组合数大于0,这样我们有T(n)=6T(n/3)+O(n)。
下图截自出题人的solution
:
那就假装能过吧~~(反正比我的复杂度优)~~。
我被卡的代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
| #include <bits/stdc++.h>
using namespace std;
const int maxn = 100005;
#define LL long long #define inf 0x3f3f3f3f #define put putchar('\n') #define sqr(x) ((x)*(x)) #define re register #define ret return puts("-1"),0;
inline char gc() { static char buf[100000],*p1=buf,*p2=buf; return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++; } #define gc getchar inline int read() { char c=getchar(); int tot=1; while ((c<'0'|| c>'9')&&c!='-') c=getchar(); if (c=='-') { tot=-1; c=getchar(); } int sum=0; while (c>='0'&&c<='9') { sum=sum*10+c-'0'; c=getchar(); } return sum*tot; } inline void wr(int x) { if (x<0) { putchar('-'); wr(-x); return; } if(x>=10)wr(x/10); putchar(x%10+'0'); } inline void wri(int x) { wr(x); putchar(' '); }
int a[maxn], b[maxn]; int C[3333][3333];
inline int Lucas(int n, int m) { if (!n) return 1; if (n <= 3000 && m <= 3000) return C[m][n]; return Lucas(n / 3, m / 3) * C[m % 3][n % 3] % 3; }
int topp; int bit[20]; int jbit[20];
inline void pre(int x) { memset(bit, 0, sizeof(bit)); memset(jbit, 0, sizeof(jbit)); topp = 0; if (!x) { topp = 1; return; } while (x) { bit[++topp] = x % 3; x /= 3; } }
inline int nxt() { jbit[1]++; for (register int i = 1; i < topp; ++i) { if (jbit[i] <= bit[i]) break; jbit[i] = 0; jbit[i + 1]++; } if (jbit[topp] > bit[topp]) return -1; register int ans = 0; for (register int i = topp; i; --i) { ans *= 3; ans += jbit[i]; } return ans; }
inline void solve() { int cnt = 0; register int n = read(); for (register int i = 0; i < n; ++i) a[i] = read(); for (register int i = 0; i < n; ++i) b[i] = read(); for (register int i = 0, c, j; i < (n << 1) - 1; ++i) { c = 0; pre(i); for (j = 0; ~j && j < n; j = nxt()) { cnt++; if (i - j < n && a[j] && b[i - j]) (c += Lucas(j, i) * a[j] * b[i - j]) %= 3; } wri(c); } puts(""); }
int main() { freopen("cal.in", "r", stdin); freopen("cal.out", "w", stdout); for (register int i = 0; i <= 3000; ++i) { C[i][0] = 1; for (register int j = 1; j <= i; ++j) C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % 3; } register int T = read(); while (T--) solve(); fclose(stdin); fclose(stdout); return 0; }
|
正解:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
| #include <bits/stdc++.h>
using namespace std;
inline char gc() { static const int L = 23333; static char sxd[L], *sss = sxd, *ttt = sxd; if (sss == ttt) { ttt = (sss = sxd) + fread(sxd, 1, L, stdin); if (sss == ttt) return EOF; } return *sss++; }
#define dd c = gc() template <class T> inline bool read(T& x) { x = 0; register char dd; register bool flag = false; for (; !isdigit(c); dd) { if(c == '-') flag = true; else if(c == EOF) return false; } for (; isdigit(c); dd) x = (x << 1) + (x << 3) + (c ^ 48); if (flag) x = -x; return true; } #undef dd
template <class T> inline void write(T x) { if(!x) { putchar('0'); return; } if (x < 0) putchar('-'), x = -x; int buf[20]; *buf = 0; while (x) { buf[++(*buf)] = x % 10; x /= 10; } while (*buf) putchar(buf[(*buf)--] | 48); }
template <class T> inline void writesp(T x) { write(x); putchar(' '); }
const int maxn = 100005;
int a[maxn], b[maxn], c[maxn << 1], n; const int x[] = {0, 0, 1, 0, 1, 2}; const int y[] = {0, 1, 1, 2, 2, 2}; const int z[] = {1, 1, 1, 1, 2, 1};
inline void dfs(const register int xx, const register int yy, const register int zz) { for (register int i = 0, xxx, yyy; i < 6; ++i) { xxx = xx + x[i], yyy = yy + y[i]; if (yyy - xxx < n && xxx < n) c[yyy] += zz * z[i] * a[xxx] * b[yyy - xxx]; if ((xxx || yyy) && (xxx * 3 < n) && (yyy * 3 < ((n << 1) - 1)) && ((yyy - xxx) * 3 < n)) dfs(xxx * 3, yyy * 3, zz * z[i]); } }
int main() { freopen("cal.in", "r", stdin); freopen("cal.out", "w", stdout); int T; read(T); while (T--) { memset(c, 0, sizeof(c)); read(n); for (register int i = 0; i < n; ++i) read(a[i]); for (register int i = 0; i < n; ++i) read(b[i]); dfs(0, 0, 1); for (register int i = 0; i < (n << 1) - 1; ++i) writesp(c[i] % 3); puts(""); } fclose(stdin); fclose(stdout); return 0; }
|