マージテクの逆でよく出てくる"2個の木のうち小さいほうを探す"処理ってcoroutineと相性がいいよね

背景

次のような問題を考えます

  • 2個の木が与えられます。部分木の頂点数を$n, m$とした時に、$O(\min{(n, m)})$時間で小さいほうの部分木の頂点を列挙してください。

このような問題は「データ構造をマージする一般的なテクの逆」などと呼ばれるテクニックを使う問題で出てきます。具体例としては I - 盆栽 が一番有名だと思います。

冒頭の問題ですが、解法自体は対して難しくなく、「2つの木に並列にBFS/DFSして、どちらかが終わったら打ち切ればいい」というだけの話です。ですがいざ実装をしようとするとなかなか面倒です。しかもウニグラフ等で計算量が壊れがちだったりして厄介です。

実はこの実装はcoroutineと呼ばれる概念と相性が良いです。coroutineはC++だとC++20で入った機能 + 主な用途が並列処理やI/O bottleneckの処理等なので、おそらく競プロでの知名度は低いと思いますが、大体の新しめの言語には実装されている機能です。

実際に冒頭の問題を実装することを考えます。まず、$O(\max{(n, m)})$時間かけていいときの実装例を示します。ただ愚直にdfsをしているだけです。

using Tree = vector<vector<int>>;

void list_vertex(const Tree& tree, int u, int p, vector<int>& result) {
    result.push_back(u);
    for (int v : tree[u]) {
        if (v == p) continue;
        list_vertex(tree, v, p, result);
    }
}

vector<int> small_tree_vertex(const Tree& tree1, const Tree& tree2) {
    vector<int> result1, result2;
    list_vertex(tree1, 0, -1, result1);
    list_vertex(tree2, 0, -1, result2);

    if (result1.size() < result2.size()) {
        return result1;
    } else {
        return result2;
    }
}

これをcoroutineを使って実装すると次のようになります。

using Tree = vector<vector<int>>;

// https://github.com/lewissbaker/cppcoro/blob/master/include/cppcoro/recursive_generator.hpp
cppcoro::recursive_generator<int> list_vertex(const Tree& tree, int u, int p) {
    co_yield u;
    for (int v : tree[u]) {
        if (v == p) continue;
        co_yield list_vertex(tree, v, p);
    }
}

vector<int> small_tree_vertex(const Tree& tree1, const Tree& tree2) {
    vector<int> result1, result2;
    auto co1 = list_vertex(tree1, 0, -1);
    auto co2 = list_vertex(tree2, 0, -1);
    for (auto it1 = co1.begin(), it2 = co2.begin();; it1++, it2++) {
        if (it1 == co1.end()) return result1;
        if (it2 == co2.end()) return result2;
        result1.push_back(*it1);
        result2.push_back(*it2);
    }
}

少しlist_small_tree_vertexがごちゃごちゃしましたが、これで $O(\min{(n, m)})$ 時間で動作します。並列BFSを実装したことがあればなかなか驚きの実装量だと思います。また、C++23ならばstd::views::zipを使えばより簡潔な実装になるはずです。

coroutineというのは、ざっくり言うと「途中で中断と再開」が可能な関数です。実際に、新しいlist_vertex関数は、「頂点を見つけたら(= co_yield uにたどり着いたら)その頂点を返して関数を中断、そしてit++が呼ばれたらdfsをそこから再開」という挙動をします。なので、list_vertexの帰り値を普通のイテレーターのように扱い、どちらかのイテレーターが末尾に到達したらそこまでの結果を返すだけでよいです。

なお、C++だと再帰関数をcoroutineにするにはcppcoro::recursive_generatorのような追加実装が必要なようですが、MITライセンスで公開されているので適切にやれば自分で実装しなくても大丈夫です。

実際に盆栽を解いたコードはこちらです: Submission #49240549 - 東京大学プログラミングコンテスト2014 。冒頭(286行目まで)にこのrecursive_generatorが張り付けられているのでウォっとなりますが、それ以降だけ見ると結構簡潔ではないでしょうか。