ローリングコンバットピッチなう!

AIとか仮想化とかペーパークラフトとか

ニューラルネットワークの無駄使い(九九をニューラルネットワークに学習させる)

[technology]ディープラーニングで九九を学習する(by chainer)

全く持って無駄無駄無駄ぁああああああ!!!!!って感じなのですが、ふと思いついてchainerでニューラルネットワークに九九を丸暗記させてみました(笑)

class DNN(Chain):
    def __init__(self):
        super(DNN, self).__init__(
            l1 = L.Linear(None,100),
            l2 = L.Linear(None,100),
            l3 = L.Linear(None,100)
        )
    def forward(self,x0,x1):
        h = F.relu(self.l1(F.concat((x0,x1),axis=1)))
        h = F.relu(self.l2(h))
        h = self.l3(h)
        return h

ネットワークはこんな感じです。隠れ層2層あるので一応ディープラーニング
forward()の入力にx0,x1という2つのベクトルを取っています。それぞれ0〜9を示すOne-Hotベクトルを想定します。
これを結合して、隠れ層2層を通して、出力段は100ノード。これは分類器風に0〜99を示すのが期待値。
九九なので0の行は要らないのですが、コーディング上0オリジンにしたほうがすっきりするのでこうしました。
出力も最大値が81なので、0〜81の82ノードあれば良いのですが、それではあまりにも予定調和すぎるので敢えて範囲を広げています。

レーニングとテスト用のコード全体は下記の様な感じです。
今回は「丸暗記」が目的なのでドロップアウトは使いません。

# -*- Coding: utf-8 -*-

# Numpy
import numpy as np
from numpy.random import *
# Chainer
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import Chain,optimizers,Variable

# Parameters
n_epoch = 200

# Neural Network

class DNN(Chain):
    def __init__(self):
        super(DNN, self).__init__(
            l1 = L.Linear(None,100),
            l2 = L.Linear(None,100),
            l3 = L.Linear(None,100)
        )
    def forward(self,x0,x1):
        h = F.relu(self.l1(F.concat((x0,x1),axis=1)))
        h = F.relu(self.l2(h))
        h = self.l3(h)
        return h

x0_train = []
x1_train = []
t_train = []

for x0 in range(10):
    x0_vec = [0] * 10
    x0_vec[x0] = 1
    for x1 in range(10):
        x1_vec = [0] * 10
        x1_vec[x1] = 1
        x0_train.append(x0_vec)
        x1_train.append(x1_vec)
        t_train.append(x0 * x1)

x0_train = np.array(x0_train,dtype=np.float32)
x1_train = np.array(x1_train,dtype=np.float32)
t_train = np.array(t_train,dtype=np.int32)

# Create DNN class instance
model = DNN()

# Set optimizer
optimizer = optimizers.Adam()
optimizer.setup(model)

# Training
for epoch in range(n_epoch):
    perm = np.random.permutation(len(x0_train))
    x0v = Variable(x0_train[perm])
    x1v = Variable(x1_train[perm])
    t = Variable(t_train[perm])
    y = model.forward(x0v,x1v)
    model.cleargrads()
    loss = F.softmax_cross_entropy(y, t)
    loss.backward()
    optimizer.update()
    print("epoch: {}, mean loss: {}".format(epoch, loss.data))

# Execute test
ok_cnt = 0
for x0 in range(10):
    for x1 in range(10):
        x0_vec = [0] * 10
        x0_vec[x0] = 1
        x1_vec = [0] * 10
        x1_vec[x1] = 1
        x0_test = []
        x1_test = []
        x0_test.append(x0_vec)
        x1_test.append(x1_vec)
        x0_test = np.array(x0_test,dtype=np.float32)
        x1_test = np.array(x1_test,dtype=np.float32)
        x0v = Variable(x0_test)
        x1v = Variable(x1_test)
        y = model.forward(x0v,x1v)
        y = np.argmax(y.data[0])
        match = False
        if y == x0 * x1:
            ok_cnt += 1
            match = True
        print("{} * {} = Predicted {}, Expected {},Match {}".format(x0,x1,y,x0 * x1,match))

print("Ok {}/Total {}".format(ok_cnt,100))

学習の結果は100エポックくらいで正解率100%になりますが100エポックだとちょいちょい98,99%くらいが混ざるので、110エポックくらい回すと大丈夫でした。


うまく行った時はテストの結果は以下の様に表示されます。
ちなみに隠れ層は1層でも大丈夫ですが、その場合は最低200エポック回さないと正解率100%になりませんでした。
batch normalizationとか入れて早く収束するかは別途試したいと思います。

全く持って無意味なニューラルネットワークの使い方なので良い子は真似しないでください。

0 * 0 = Predicted 0, Expected 0,Match True
0 * 1 = Predicted 0, Expected 0,Match True
0 * 2 = Predicted 0, Expected 0,Match True
0 * 3 = Predicted 0, Expected 0,Match True
0 * 4 = Predicted 0, Expected 0,Match True
0 * 5 = Predicted 0, Expected 0,Match True
0 * 6 = Predicted 0, Expected 0,Match True
0 * 7 = Predicted 0, Expected 0,Match True
0 * 8 = Predicted 0, Expected 0,Match True
0 * 9 = Predicted 0, Expected 0,Match True
1 * 0 = Predicted 0, Expected 0,Match True
1 * 1 = Predicted 1, Expected 1,Match True
1 * 2 = Predicted 2, Expected 2,Match True
1 * 3 = Predicted 3, Expected 3,Match True
1 * 4 = Predicted 4, Expected 4,Match True
1 * 5 = Predicted 5, Expected 5,Match True
1 * 6 = Predicted 6, Expected 6,Match True
1 * 7 = Predicted 7, Expected 7,Match True
1 * 8 = Predicted 8, Expected 8,Match True
1 * 9 = Predicted 9, Expected 9,Match True
2 * 0 = Predicted 0, Expected 0,Match True
2 * 1 = Predicted 2, Expected 2,Match True
2 * 2 = Predicted 4, Expected 4,Match True
2 * 3 = Predicted 6, Expected 6,Match True
2 * 4 = Predicted 8, Expected 8,Match True
2 * 5 = Predicted 10, Expected 10,Match True
2 * 6 = Predicted 12, Expected 12,Match True
2 * 7 = Predicted 14, Expected 14,Match True
2 * 8 = Predicted 16, Expected 16,Match True
2 * 9 = Predicted 18, Expected 18,Match True
3 * 0 = Predicted 0, Expected 0,Match True
3 * 1 = Predicted 3, Expected 3,Match True
3 * 2 = Predicted 6, Expected 6,Match True
3 * 3 = Predicted 9, Expected 9,Match True
3 * 4 = Predicted 12, Expected 12,Match True
3 * 5 = Predicted 15, Expected 15,Match True
3 * 6 = Predicted 18, Expected 18,Match True
3 * 7 = Predicted 21, Expected 21,Match True
3 * 8 = Predicted 24, Expected 24,Match True
3 * 9 = Predicted 27, Expected 27,Match True
4 * 0 = Predicted 0, Expected 0,Match True
4 * 1 = Predicted 4, Expected 4,Match True
4 * 2 = Predicted 8, Expected 8,Match True
4 * 3 = Predicted 12, Expected 12,Match True
4 * 4 = Predicted 16, Expected 16,Match True
4 * 5 = Predicted 20, Expected 20,Match True
4 * 6 = Predicted 24, Expected 24,Match True
4 * 7 = Predicted 28, Expected 28,Match True
4 * 8 = Predicted 32, Expected 32,Match True
4 * 9 = Predicted 36, Expected 36,Match True
5 * 0 = Predicted 0, Expected 0,Match True
5 * 1 = Predicted 5, Expected 5,Match True
5 * 2 = Predicted 10, Expected 10,Match True
5 * 3 = Predicted 15, Expected 15,Match True
5 * 4 = Predicted 20, Expected 20,Match True
5 * 5 = Predicted 25, Expected 25,Match True
5 * 6 = Predicted 30, Expected 30,Match True
5 * 7 = Predicted 35, Expected 35,Match True
5 * 8 = Predicted 40, Expected 40,Match True
5 * 9 = Predicted 45, Expected 45,Match True
6 * 0 = Predicted 0, Expected 0,Match True
6 * 1 = Predicted 6, Expected 6,Match True
6 * 2 = Predicted 12, Expected 12,Match True
6 * 3 = Predicted 18, Expected 18,Match True
6 * 4 = Predicted 24, Expected 24,Match True
6 * 5 = Predicted 30, Expected 30,Match True
6 * 6 = Predicted 36, Expected 36,Match True
6 * 7 = Predicted 42, Expected 42,Match True
6 * 8 = Predicted 48, Expected 48,Match True
6 * 9 = Predicted 54, Expected 54,Match True
7 * 0 = Predicted 0, Expected 0,Match True
7 * 1 = Predicted 7, Expected 7,Match True
7 * 2 = Predicted 14, Expected 14,Match True
7 * 3 = Predicted 21, Expected 21,Match True
7 * 4 = Predicted 28, Expected 28,Match True
7 * 5 = Predicted 35, Expected 35,Match True
7 * 6 = Predicted 42, Expected 42,Match True
7 * 7 = Predicted 49, Expected 49,Match True
7 * 8 = Predicted 56, Expected 56,Match True
7 * 9 = Predicted 63, Expected 63,Match True
8 * 0 = Predicted 0, Expected 0,Match True
8 * 1 = Predicted 8, Expected 8,Match True
8 * 2 = Predicted 16, Expected 16,Match True
8 * 3 = Predicted 24, Expected 24,Match True
8 * 4 = Predicted 32, Expected 32,Match True
8 * 5 = Predicted 40, Expected 40,Match True
8 * 6 = Predicted 48, Expected 48,Match True
8 * 7 = Predicted 56, Expected 56,Match True
8 * 8 = Predicted 64, Expected 64,Match True
8 * 9 = Predicted 72, Expected 72,Match True
9 * 0 = Predicted 0, Expected 0,Match True
9 * 1 = Predicted 9, Expected 9,Match True
9 * 2 = Predicted 18, Expected 18,Match True
9 * 3 = Predicted 27, Expected 27,Match True
9 * 4 = Predicted 36, Expected 36,Match True
9 * 5 = Predicted 45, Expected 45,Match True
9 * 6 = Predicted 54, Expected 54,Match True
9 * 7 = Predicted 63, Expected 63,Match True
9 * 8 = Predicted 72, Expected 72,Match True
9 * 9 = Predicted 81, Expected 81,Match True
Ok 100/Total 100

追記: 2018/11/4
githubにコードを置きました。
github.com