Fenwick Treeの定数倍高速化
この記事はrogy advent calenderの3日目の記事です。
日本時間では4日になっているような気がしますが、地球上にはまだ3日の部分が残っているのでセーフです。
(本当は丁寧に書くつもりだったのですが、時間がないので、プロコンをしている人向けの記事になりました。悲しいね)
概要
Fenwick Treeというデータ構造があります、これを高速化しようと頑張りました
Fenwick Treeって?
調べればたくさん説明は出てくると思います。
数列について、
という操作が、どちらもO(logN)で出来るデータ構造です。
例えば普通に配列でやると、O(1) / O(N)、累積和を用いると O(N) / O(1) になります。
高速化
template<class T>
struct Fenwick {
int N;
vector<T> seg;
Fenwick(int N) : N(N) {
seg.resize(N+1);
fill_n(begin(seg), N+1, 0);
}
void add(int i, T x) {
i++;
while (i <= N) {
seg[i] += x;
i += i & -i;
}
}
T sum(int i) {
T s{0};
while (i > 0) {
s += seg[i];
i -= i & -i;
}
return s;
}
T sum(int a, int b) {
return sum(b) - sum(a);
}
};
まず、これがFenwick Treeのコードです。使いやすいように0-indexedになっています。
上記のコードを見てもわかるように、Fenwick Treeは非常に美しいデータ構造です。
そして定数倍も非常に速いです。これをさらに高速化することなど出来るのでしょうか?
まず、Nが大きく(100,000とか)なった時に、どこがボトルネックになるかを考えます。
メモリアクセスな気がします。log2(N) / 2 回バラバラなところにアクセスするので、これをどうにか減らせないでしょうか?
まず、Fenwick Treeではなく、普通の累積和を求めるSegment Treeで考えます。
考えると、2分木ではなく多分木にすればメモリアクセスの回数が減りそうな気がします。
例えば4分木にすれば、アクセスする場所は半分になります。代わりに、それぞれのノードにサイズ4のFenwick Treeを持たせる必要があります。
このFenwick Treeを愚直に配列/累積和で実装すると、そこが重くなります。
そして元々のFenwick Treeの綺麗な構造が消えるので、要するに遅くなります。
ポインタを使い、愚直に多分木を書いてみたのですが、もうハチャメチャ遅いです。
変なことをしようとすると、綺麗な構造が壊れてなんか逆に遅くなる、ということがわかりました。
予想していた通り、Fenwick Treeの高速化は難しいです。
なので、SIMD命令を使います。CPUはIntel 6700で、Intel AVX2を使用しています。
SIMD命令は、連続した要素に対して、一括で命令を行うことができます。
最近のはすごくて、ポインタの配列について、一括でその指す場所の値をgetできます(gather命令)
これにより、sum関数の高速化を図ります。
更に、8分木にして、ノードごとの累積和の計算を高速化します。
これにより、get関数の高速化を図ります。
template<int X>
struct FenwickSimd {
int *seg_base;
int *(seg[X]);
m256 segC;
FenwickSimd() {}
FenwickSimd(int N) {
assert(N < 1<<(3*X));
int S = 0;
int segCbuf[8] = {};
for (int i = 0; i < X; i++) {
N = (N+7)/8;
segCbuf[i] = S;
S += 8*(N+1);
}
seg_base = (int *)aligned_alloc(32, sizeof(int)*S);
for (int i = 0; i < X; i++) {
seg[i] = seg_base+segCbuf[i];
}
segC = _mm256_load_si256((m256 *)segCbuf);
}
void add(int p, int x) {
for (int i = 0; i < X; i++) {
int dp = p&7, up = p&~7;
const m256 pd = _mm256_set1_epi32(dp);
const m256 base = _mm256_set_epi32(7,6,5,4,3,2,1,0);
m256 xx = _mm256_set1_epi32(x);
xx = _mm256_and_si256(xx, _mm256_cmpgt_epi32(base, pd));
m256 tar = _mm256_load_si256((m256 *)(seg[i]+up));
tar = _mm256_add_epi32(tar, xx);
_mm256_store_si256((m256 *)(seg[i]+up), tar);
p >>= 3;
}
}
int sum(int p) {
int s{0};
const m256 off = _mm256_set_epi32(21,18,15,12,9,6,3,0);
m256 adr = _mm256_set1_epi32(p);
adr = _mm256_srlv_epi32(adr, off);
adr = _mm256_add_epi32(adr, segC);
m256 buf = _mm256_setzero_si256();
const m256 mask = _mm256_set_epi32(
7<X?-1:0,6<X?-1:0,5<X?-1:0,4<X?-1:0,
3<X?-1:0,2<X?-1:0,1<X?-1:0,0<X?-1:0);
buf = _mm256_mask_i32gather_epi32(buf, seg_base, adr, mask, 4);
buf = _mm256_hadd_epi32(buf, buf);
buf = _mm256_hadd_epi32(buf, buf);
s += _mm256_extract_epi32(buf, 0);
s += _mm256_extract_epi32(buf, 4);
return s;
}
int sum(int a, int b) {
return sum(b) - sum(a);
}
};
こちらが完成したコードになります。
適当にベンチマークした結果ですが、N=200,000だとget/sum共に2.5倍ほど速くなりました。すごい
でも2D Bitにすると遅かったり、なんかいろいろ難しいです。悩ましいですね
プログラムは、8分木を作って、getはgatherで各段から回収し、sumは各段の累積和を更新しています。
SIMD命令、適当に関数をウリャオイするだけで動いてすごい