ローリングコンバットピッチなう!

AIとか仮想化とかペーパークラフトとか

chainer用のMNISTのデータの中身を確認する

機械学習とかDeep Learningとかに興味は持っているものの中々取り組めていません。
だいぶ前にscikit-learnでSVMとかは動かしてみたのですが、最近chainerを動かしてみたいと思い、まずは定番のMNISTから...というところなんですが、動かすだけなら最近はそこいら中にサンプルコードが転がっていてコピペで動かせるわけで、下記のAI研究所さんのブログとか見ながら準備をしていました。
機械学習用ライブラリ「Chainer」を使ったディープラーニング
https://ai-kenkyujo.com/2017/09/18/chainer-deeplearning/

サンプルコード動かす動機として、

  1. chainerでの基本的なニューラルネットワークの組み方、使い方を理解する
  2. GPU無しのCPU環境(しかもシングルコア)でどの程度の処理時間がかかるか把握する
  3. chainerで画像(とりあえずグレースケールで良い)を扱う場合のデータ構造を理解する、その例題としてMNISTを見る
というのがあって、前者については動かす前にサンプルコードを自分の頭で読み解いています。
chainerの場合、MNISTのデータが

train, test = chainer.datasets.get_mnist()

というコードで自動的にDLされ、ローカルにキャッシュされるみたいなのですが、このtrain,testに含まれるデータの形式をちゃんと理解するためにtrain内の任意の位置のデータについて画像とラベルを表示する簡単なScriptを書きました。

とりあえず理解した事として

  • trainには60,000点のデータが入っている
  • iをデータのindexとすると、train[i][0]に手書き数字の画像情報が1次元のベクタとして入っている。train[i][1]にはintでその手書き数字を示すラベルが入っている。(0〜9)
  • 画像情報は0〜255のGray Scaleデータを255で割り算したFloat値で、予めchainerで使える様に正規化済み(scikit-learnでDLすると画像は縦横2次元の整数行列で0〜255のGray Scale値が入っていたと思います。)
  • なので、画像をmatplotlibやopencvなんかで画面表示したい場合はtrain[i][0]の各要素を255倍した上で整数化し、更に28x28の行列に構成しなおす必要があります。(MNISTのデータは縦横28ピクセルのGray Scale画像)
判っている人にはなんでも無い話だと思いますが、色々な解説記事読んでも意外とすっきり整理されて説明されていないんですよね。

というわけで、自分で下記python script書いて色々データを見ていましたが、かなり崩れた字が入っていてちょっとびっくり。人間の目で見ても「むむ?」ってなるデータが結構入っていますね。

python3 + numpy + matplotlib + chainerがインストールされている環境(自分はanaconda3で作ったpython3.6環境にchainerをインストールしました)で下記のscriptを実行し、0〜59999までのindex番号を打ち込むとその番号に対応した画像とラベルが表示されます。画像ウィンドウを消すと次のindex番号を入力できます。
終了する時はindex番号として-1を入力してください。

# -*- Coding: utf-8 -*-

# Numpy
import numpy as np
# Chainer
import chainer
# Matplotlib
import matplotlib.pyplot as plt

train, test = chainer.datasets.get_mnist()
train_max = len(train)

plt.gray()

while(True):
    try:
        index = int(input('enter data no.(0 - %d) (-1: exit): ' % (train_max - 1)))
        if index >= train_max:
            print('Maximum data no. is %d' % (train_max - 1))
            print('Try again!')
        elif index < 0:
            print('Exit')
            break
        else:
            data = train[index][0]
            num  = train[index][1]

            img = (data * 255).astype('uint8').reshape(28,28)
            print('number = %d' % num)
            plt.imshow(img)
            plt.show()
    except Exception as e:
        print('Exception: ',e.args)

ちなみにAI研究所さんのブログに載っているMNISTのコードをそのまま動かした結果は以下の通りでした。
使ったCPUはIntel(R) Celeron(R) CPU 550 @ 2.00GHz。

(py36) toy@toy-VGN-NR52 ~/chainer_test/mnist $ time python mnist_test.py
/home/toy/anaconda3/envs/py36/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
epoch: 0, mean loss: 0.8545098960399627
epoch: 1, mean loss: 0.2892803058028221
epoch: 2, mean loss: 0.2260947120686372
epoch: 3, mean loss: 0.1872865878045559
epoch: 4, mean loss: 0.15878612051407495
epoch: 5, mean loss: 0.1371945430835088
epoch: 6, mean loss: 0.12088873945176601
epoch: 7, mean loss: 0.10777940427263578
epoch: 8, mean loss: 0.09627761915326119
epoch: 9, mean loss: 0.08701442939539751
9674
accuracy: 0.9674

real	1m11.938s
user	0m57.416s
sys	0m1.500s

で、不正解だったテストデータってどんなものか確認するために、AI研究所さんのサンプルコードのテスト部分に下記の様に手を入れて、不正解データを表示しながら確認出来るようにしてみました。最大20個の不正解をmatplotlib使ってテスト画像表示しながらテストを進めます。(画像表示したところで止まるので、画像ウィンドウをクローズするとテストが次に進みます)

cnt = 0
# --ここから--
import matplotlib.pyplot as plt
miscnt = 0
maxmisdisp = 20
plt.gray()
# --ここまで追加--
for i in range(10000):
    x = Variable(np.array([x_test[i]], dtype=np.float32))
    t = t_test[i]
    y = model.forward(x)
    y = np.argmax(y.data[0])
    if t == y:
        cnt += 1
# --ここから--
    else:
        miscnt += 1
        if miscnt <= maxmisdisp:
            mis_data = test[i][0]
            img = (mis_data * 255).astype('uint8').reshape(28,28)
            print('test number = %d, predict result = %d' % (t,y))
            plt.imshow(img)
            plt.show()
# --ここまで追加--

ほとんどの不正解テストデータは、人間が見えればまあ正解するだろうなというものも多いのですが、以下のようなものも入っています。
最初のやつは正解は5なんですが、ニューラルネットワークは3と解答。2つ目の正解は8らしいのですが、ニューラルネットワークは2と解答。このあたりは字汚すぎ!!って気が。こんなのを読まされるAIさんも大変です(笑)