Matrix Calculator with パーサジェネレーター

Matrix Calculatorとは、知る人ぞ知るアホ問題です。

サンプルを見ただけで嫌になるすごい問題なんですが、今回はパーサジェネレーターで楽々解決を目指します。

使用したのは

caper -- LALR(1) パーサジェネレータ

です。C++でのソースコードが出力に出てくるので、プロコンにうってつけですね。

まず、構文ファイルがこちら。

%token Digit<V> Var<V> Add Mul Sub Space LBranket RBranket LMatBranket RMatBranket SemiColon Comma Quote;
%namespace calc;

Expr<V> : [Identity] Term(0)
        | [MakeAdd] Expr(0) Add Term(1)
        | [MakeSub] Expr(0) Sub Term(1)
        ;


Term<V> : [Identity] Fact(0)
        | [MakeMul] Term(0) Mul Fact(1)
        ;

Fact<V>
        : [Identity] Prim(0)
        | [MakeNot] Sub Fact(0)
        ;

Prim<V> : [Identity] Inum(0)
        | [Identity] Var(0)
        | [Identity] Mat(0)
        | [Identity] IdxPrim(0)
        | [Identity] LBranket Expr(0) RBranket
        | [Identity] TransPrim(0)
        ;

IdxPrim<V>
        : [MakeIdx] Prim(0) LBranket Expr(1) Comma Expr(2) RBranket
        ;

TransPrim<V>
        : [MakeTrans] Prim(0) Quote
        ;

Mat<V>    : [Identity] LMatBranket RowSeq(0) RMatBranket
        ;

RowSeq<V>
        : [Identity] Row(0)
        | [RowMerge] RowSeq(0) SemiColon Row(1)
        ;

Row<V>    : [Identity] Expr(0)
        | [ColumMerge] Row(0) Space Expr(1)
        ;

Inum<V> : [Identity] Digit(0)
        | [InumMerge] Inum(0) Digit(1)
        ;

長いですが、問題文のBNFをほとんどそのまま書き下しただけです。でも20分ぐらいかかった。

次にソースがこちら

#include <iostream>
#include <algorithm>
#include <array>
#include <vector>
#include <cassert>

using namespace std;
using ll = long long;
using ull = unsigned long long;

template<uint MD>
struct ModInt {
    uint v;
    ModInt() : v(0) {}
    ModInt(ll v) : v(normS(v%MD+MD)) {}
    uint value() {return v;}
    static uint normS(const uint &x) {return (x<MD)?x:x-MD;};
    static ModInt make(const uint &x) {ModInt m; m.v = x; return m;}
    const ModInt operator+(const ModInt &r) const {return make(normS(v+r.v));}
    const ModInt operator-(const ModInt &r) const {return make(normS(v+normS(MD-r.v)));}
    const ModInt operator*(const ModInt &r) const {return make((ull)v*r.v%MD);}
    const ModInt operator-() const {return make(normS(MD-v));};
    ModInt& operator+=(const ModInt &r) {return *this=*this+r;}
    ModInt& operator-=(const ModInt &r) {return *this=*this-r;}
    ModInt& operator*=(const ModInt &r) {return *this=*this*r;}
    static ModInt inv(const ModInt &x) {
        return pow(ModInt(x), MD-2);
    }
};

template<class D>
struct Matrix {
    vector<vector<D>> d;
    int N, M;
    Matrix() {}
    Matrix(int N, int M) : N(N), M(M),
        d(vector<vector<D>>(N, vector<D>(M, D(0)))) {}

    vector<D>& operator[](int p) {return d[p];}
    const vector<D>& operator[](int p) const {return d[p];}

    const Matrix operator+(const Matrix &right) const {
        assert(right.N == N && right.M == M);
        Matrix res(N, M);
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < M; j++) {
                res[i][j] = d[i][j]+right[i][j];
            }
        }
        return res;
    }
    const Matrix operator-() const {
        Matrix res(N, M);
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < M; j++) {
                res[i][j] = -d[i][j];
            }
        }
        return res;
    }
    const Matrix operator*(const Matrix &right) const {
        if (N == 1 and M == 1) {
            return right * d[0][0];
        }
        if (right.N == 1 and right.M == 1) {
            return *this * right[0][0];
        }
        assert(M == right.N);
        Matrix res(N, right.M), r(right.M, right.N);
        for (int i = 0; i < right.M; i++) {
            for (int j = 0; j < right.N; j++) {
                r[i][j] = right[j][i]; 
            }
        }
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < right.M; j++) {
                for (int k = 0; k < M; k++) {
                    res[i][j] += d[i][k]*r[j][k];
                }
            }
        }
        return res;
    }

    const Matrix operator*(const D &x) const {
        Matrix res(N, M);
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < M; j++) {
                res[i][j] = d[i][j]*x;
            }
        }
        return res;
    }
};

using V = Matrix<ModInt<(1<<15)>>;

#include "parser.hpp"

class unexpected_char : public std::exception {};

struct SemanticAction {
    void syntax_error() {}
    void stack_overflow() {}
    void downcast(V& x, V y) { x = y; }
    void upcast(V& x, V y) { x = y; }

    V Identity(V n) { return n; }
    V MakeNot(V n) {
        return -n;
    }
    V MakeMul(V x, V y) {
        return x * y;
    }
    V MakeAdd(V x, V y) {
        return x + y;
    }
    V MakeSub(V x, V y) {
        return x + (-y);
    }
    V MakeTrans(V x) {
        V res(x.M, x.N);
        for (int i = 0; i < x.N; i++) {
            for (int j = 0; j < x.M; j++) {
                res.d[j][i] = x.d[i][j];
            }
        }
        return res;
    }
    V MakeIdx(V d, V x, V y) {
        V res(x.M, y.M);
        for (int i = 0; i < x.M; i++) {
            for (int j = 0; j < y.M; j++) {
                res[i][j] = d[x[0][i].v-1][y[0][j].v-1];
            }
        }
        return res;
    }
    V RowMerge(V x, V y) {
        assert(x.M == y.M);
        V res(x.N+y.N, x.M);
        for (int i = 0; i < x.N; i++) {
            for (int j = 0; j < x.M; j++) {
                res.d[i][j] = x[i][j];
            }
        }
        for (int i = 0; i < y.N; i++) {
            for (int j = 0; j < y.M; j++) {
                res.d[x.N+i][j] = y[i][j];
            }
        }
        return res;
    }
    V ColumMerge(V x, V y) {
        assert(x.N == y.N);
        V res(x.N, x.M+y.M);
        for (int i = 0; i < x.N; i++) {
            for (int j = 0; j < x.M; j++) {
                res.d[i][j] = x[i][j];
            }
        }
        for (int i = 0; i < y.N; i++) {
            for (int j = 0; j < y.M; j++) {
                res.d[i][x.M+j] = y[i][j];
            }
        }
        return res;
    }
    V InumMerge(V x, V y) {
        V res(1, 1);
        res[0][0] = x[0][0] * 10 + y[0][0];
        return res;
    }
};


V var[26];

class scanner {
public:
    using char_type = char;

private:
    string s;
    int pos;
public:
    scanner(string s) : s(s), pos(0) {}

    calc::Token get(V& v) {
        char c = getc();
        switch (c) {
        case '+': return calc::token_Add;
        case '*': return calc::token_Mul;
        case '-': return calc::token_Sub;
        case ' ': return calc::token_Space;
        case '(': return calc::token_LBranket;
        case ')': return calc::token_RBranket;
        case '[': return calc::token_LMatBranket;
        case ']': return calc::token_RMatBranket;
        case ';': return calc::token_SemiColon;
        case ',': return calc::token_Comma;
        case '\'': return calc::token_Quote;
        case '.': return calc::token_eof;
        }
        if (isdigit(c)) {
            v = V(1, 1);
            v[0][0] = c - '0';
            return calc::token_Digit;
        }
        if ('A' <= c and c <= 'Z') {
            v = var[c - 'A'];
            return calc::token_Var;
        }
        cerr << c << endl;
        throw unexpected_char();
    }

private:
    char_type getc() {
        if (pos == s.size()) return EOF;
        return s[pos++];
    }
};



int main() {
    while (true) {
        int n;
        cin >> n;
        cin.ignore();
        if (n == 0) break;
        V last;
        for (int i = 0; i < n; i++) {
            string s;
            getline(cin, s);
            int VC = s[0] - 'A';
            s = s.substr(2);
            scanner scn(s);
            SemanticAction sa;
            calc::Parser<V, SemanticAction> parser(sa);
            calc::Token token;
            while (true) {
                V v;
                token = scn.get(v);
                if (parser.post(token, v)) { break; }
            }
            V v;
            if (parser.accept(v)) {
                var[VC] = v;
                last = v;
            }
            for (int i = 0; i < last.N; i++) {
                for (int j = 0; j < last.M; j++) {
                    printf("%d", last.d[i][j].v);
                    if (j != last.M-1) {
                        printf(" ");
                    } else {
                        printf("\n");
                    }
                }
            }
        }
        printf("-----\n");
    }
    return 0;
}

自動mod取り構造体やら行列やらを使いまくっているので読むのが大変。

ACしたの?

ACしてません。出力がインデントを消しても40kbyteぐらいあって、AOJに投げられません。残念。

実装時間も63分かかったようです。

ですが、構文解析部分がバグらないというのは非常に強く、デバッグ時間は10分程度です。(サンプルが通っただけなので実際に正しいかはわかりません) かなり強力なんじゃないかと思いました。(ICPCではもちろん使えないんだけど…)