Matrix Calculatorとは、知る人ぞ知るアホ問題です。
サンプルを見ただけで嫌になるすごい問題なんですが、今回はパーサジェネレーターで楽々解決を目指します。
使用したのは
caper -- LALR(1) パーサジェネレータ
です。C++でのソースコードが出力に出てくるので、プロコンにうってつけですね。
まず、構文ファイルがこちら。
%token Digit<V> Var<V> Add Mul Sub Space LBranket RBranket LMatBranket RMatBranket SemiColon Comma Quote;
%namespace calc;
Expr<V> : [Identity] Term(0)
| [MakeAdd] Expr(0) Add Term(1)
| [MakeSub] Expr(0) Sub Term(1)
;
Term<V> : [Identity] Fact(0)
| [MakeMul] Term(0) Mul Fact(1)
;
Fact<V>
: [Identity] Prim(0)
| [MakeNot] Sub Fact(0)
;
Prim<V> : [Identity] Inum(0)
| [Identity] Var(0)
| [Identity] Mat(0)
| [Identity] IdxPrim(0)
| [Identity] LBranket Expr(0) RBranket
| [Identity] TransPrim(0)
;
IdxPrim<V>
: [MakeIdx] Prim(0) LBranket Expr(1) Comma Expr(2) RBranket
;
TransPrim<V>
: [MakeTrans] Prim(0) Quote
;
Mat<V> : [Identity] LMatBranket RowSeq(0) RMatBranket
;
RowSeq<V>
: [Identity] Row(0)
| [RowMerge] RowSeq(0) SemiColon Row(1)
;
Row<V> : [Identity] Expr(0)
| [ColumMerge] Row(0) Space Expr(1)
;
Inum<V> : [Identity] Digit(0)
| [InumMerge] Inum(0) Digit(1)
;
長いですが、問題文のBNFをほとんどそのまま書き下しただけです。でも20分ぐらいかかった。
次にソースがこちら
#include <iostream>
#include <algorithm>
#include <array>
#include <vector>
#include <cassert>
using namespace std;
using ll = long long;
using ull = unsigned long long;
template<uint MD>
struct ModInt {
uint v;
ModInt() : v(0) {}
ModInt(ll v) : v(normS(v%MD+MD)) {}
uint value() {return v;}
static uint normS(const uint &x) {return (x<MD)?x:x-MD;};
static ModInt make(const uint &x) {ModInt m; m.v = x; return m;}
const ModInt operator+(const ModInt &r) const {return make(normS(v+r.v));}
const ModInt operator-(const ModInt &r) const {return make(normS(v+normS(MD-r.v)));}
const ModInt operator*(const ModInt &r) const {return make((ull)v*r.v%MD);}
const ModInt operator-() const {return make(normS(MD-v));};
ModInt& operator+=(const ModInt &r) {return *this=*this+r;}
ModInt& operator-=(const ModInt &r) {return *this=*this-r;}
ModInt& operator*=(const ModInt &r) {return *this=*this*r;}
static ModInt inv(const ModInt &x) {
return pow(ModInt(x), MD-2);
}
};
template<class D>
struct Matrix {
vector<vector<D>> d;
int N, M;
Matrix() {}
Matrix(int N, int M) : N(N), M(M),
d(vector<vector<D>>(N, vector<D>(M, D(0)))) {}
vector<D>& operator[](int p) {return d[p];}
const vector<D>& operator[](int p) const {return d[p];}
const Matrix operator+(const Matrix &right) const {
assert(right.N == N && right.M == M);
Matrix res(N, M);
for (int i = 0; i < N; i++) {
for (int j = 0; j < M; j++) {
res[i][j] = d[i][j]+right[i][j];
}
}
return res;
}
const Matrix operator-() const {
Matrix res(N, M);
for (int i = 0; i < N; i++) {
for (int j = 0; j < M; j++) {
res[i][j] = -d[i][j];
}
}
return res;
}
const Matrix operator*(const Matrix &right) const {
if (N == 1 and M == 1) {
return right * d[0][0];
}
if (right.N == 1 and right.M == 1) {
return *this * right[0][0];
}
assert(M == right.N);
Matrix res(N, right.M), r(right.M, right.N);
for (int i = 0; i < right.M; i++) {
for (int j = 0; j < right.N; j++) {
r[i][j] = right[j][i];
}
}
for (int i = 0; i < N; i++) {
for (int j = 0; j < right.M; j++) {
for (int k = 0; k < M; k++) {
res[i][j] += d[i][k]*r[j][k];
}
}
}
return res;
}
const Matrix operator*(const D &x) const {
Matrix res(N, M);
for (int i = 0; i < N; i++) {
for (int j = 0; j < M; j++) {
res[i][j] = d[i][j]*x;
}
}
return res;
}
};
using V = Matrix<ModInt<(1<<15)>>;
#include "parser.hpp"
class unexpected_char : public std::exception {};
struct SemanticAction {
void syntax_error() {}
void stack_overflow() {}
void downcast(V& x, V y) { x = y; }
void upcast(V& x, V y) { x = y; }
V Identity(V n) { return n; }
V MakeNot(V n) {
return -n;
}
V MakeMul(V x, V y) {
return x * y;
}
V MakeAdd(V x, V y) {
return x + y;
}
V MakeSub(V x, V y) {
return x + (-y);
}
V MakeTrans(V x) {
V res(x.M, x.N);
for (int i = 0; i < x.N; i++) {
for (int j = 0; j < x.M; j++) {
res.d[j][i] = x.d[i][j];
}
}
return res;
}
V MakeIdx(V d, V x, V y) {
V res(x.M, y.M);
for (int i = 0; i < x.M; i++) {
for (int j = 0; j < y.M; j++) {
res[i][j] = d[x[0][i].v-1][y[0][j].v-1];
}
}
return res;
}
V RowMerge(V x, V y) {
assert(x.M == y.M);
V res(x.N+y.N, x.M);
for (int i = 0; i < x.N; i++) {
for (int j = 0; j < x.M; j++) {
res.d[i][j] = x[i][j];
}
}
for (int i = 0; i < y.N; i++) {
for (int j = 0; j < y.M; j++) {
res.d[x.N+i][j] = y[i][j];
}
}
return res;
}
V ColumMerge(V x, V y) {
assert(x.N == y.N);
V res(x.N, x.M+y.M);
for (int i = 0; i < x.N; i++) {
for (int j = 0; j < x.M; j++) {
res.d[i][j] = x[i][j];
}
}
for (int i = 0; i < y.N; i++) {
for (int j = 0; j < y.M; j++) {
res.d[i][x.M+j] = y[i][j];
}
}
return res;
}
V InumMerge(V x, V y) {
V res(1, 1);
res[0][0] = x[0][0] * 10 + y[0][0];
return res;
}
};
V var[26];
class scanner {
public:
using char_type = char;
private:
string s;
int pos;
public:
scanner(string s) : s(s), pos(0) {}
calc::Token get(V& v) {
char c = getc();
switch (c) {
case '+': return calc::token_Add;
case '*': return calc::token_Mul;
case '-': return calc::token_Sub;
case ' ': return calc::token_Space;
case '(': return calc::token_LBranket;
case ')': return calc::token_RBranket;
case '[': return calc::token_LMatBranket;
case ']': return calc::token_RMatBranket;
case ';': return calc::token_SemiColon;
case ',': return calc::token_Comma;
case '\'': return calc::token_Quote;
case '.': return calc::token_eof;
}
if (isdigit(c)) {
v = V(1, 1);
v[0][0] = c - '0';
return calc::token_Digit;
}
if ('A' <= c and c <= 'Z') {
v = var[c - 'A'];
return calc::token_Var;
}
cerr << c << endl;
throw unexpected_char();
}
private:
char_type getc() {
if (pos == s.size()) return EOF;
return s[pos++];
}
};
int main() {
while (true) {
int n;
cin >> n;
cin.ignore();
if (n == 0) break;
V last;
for (int i = 0; i < n; i++) {
string s;
getline(cin, s);
int VC = s[0] - 'A';
s = s.substr(2);
scanner scn(s);
SemanticAction sa;
calc::Parser<V, SemanticAction> parser(sa);
calc::Token token;
while (true) {
V v;
token = scn.get(v);
if (parser.post(token, v)) { break; }
}
V v;
if (parser.accept(v)) {
var[VC] = v;
last = v;
}
for (int i = 0; i < last.N; i++) {
for (int j = 0; j < last.M; j++) {
printf("%d", last.d[i][j].v);
if (j != last.M-1) {
printf(" ");
} else {
printf("\n");
}
}
}
}
printf("-----\n");
}
return 0;
}
自動mod取り構造体やら行列やらを使いまくっているので読むのが大変。
ACしたの?
ACしてません。出力がインデントを消しても40kbyteぐらいあって、AOJに投げられません。残念。
実装時間も63分かかったようです。
ですが、構文解析部分がバグらないというのは非常に強く、デバッグ時間は10分程度です。(サンプルが通っただけなので実際に正しいかはわかりません)
かなり強力なんじゃないかと思いました。(ICPCではもちろん使えないんだけど…)