HOME/AI/

自作でMNISTを分類しよう

Article Outline
TOC
Collection Outline

数字のクラス分類器をつくろう

ここでは、ニューラルネットワークとかの実装ではなく、実際に手で理論だてて数字のクラス分類器を作っていましょう。
fastaiのデータセットを使って、3と7の数字を分類してみよう。 クラス分類器を作るにあたって、以下のような手順で行っていこうと思う。
今回使うのは、平均絶対誤差(以下、L1ノルム)と平均平方二乗誤差(以下、L2ノルム)を使ってやってみよう。

  • それぞれの数字の全ての画像を重ね合わせて、平均を求める
  • 個々の画像と上記で求めた平均の画像のL1ノルムとL2ノルムを求める
  • 識別できているかを確認する

では、早速やってみようと思うが、L1ノルムとL2ノルムの違いを簡単に触れておこう。

平均絶対誤差(L1ノルム)

まず、求め方は(個々の画素値 - 平均の画素値)の絶対値の和の平均から算出します。
数式にすると、

L1ノルム = \frac{1}{n}\ \sum_{i=1}^{n} |a_i - f_i|

平均二乗誤差(L2ノルム)

求め方は、(個々の画素値 - 平均画素値)の2乗の平均から算出します。
数式にすると、

L2ノルム = \sqrt{\frac{1}{n}\ \sum_{i=1}^{n} (a_i - f_i)^2}

こうすることで、大きい誤差ではより大きいペナルティを与え、小さい誤差には寛容な値を出すことができます。 これで必要な知識がそろったので、実際にコーディングしていきましょう。

数字ごとの画像の平均を求めよう

まずは、fastaiをインポートします。

! [ -e /content ] && pip install -Uqq fastbook
import fastbook
fastbook.setup_book()

from fastbook import *
from fastai.vision.all import *

次に、MNISTのデータセットを取得します。

path = untar_data(URLs.MNIST_SAMPLE)

今回の使うのは3と7の数字だけなので、その画像を取得します。

threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()

試しに、3の画像を出力してみましょう。

im3_path = threes[1]
im3 = Image.open(im3_path)
im3

また、画像の画素値を見てみましょう。

im3_t = tensor(im3)
df = pd.DataFrame(im3_t[4:15, 4:22])
df.style.set_properties(**{'font-style':'6pt'}).background_gradient('Greys')

次に、2次元の画素値をテンソル型にしましょう。

seven_tensors = [tensor(Image.open(o)) for o in sevens]
three_tensors = [tensor(Image.open(o)) for o in threes]

そしたら、服すのテンソルを積み重ねて1つのテンソルにして、ついでに正則化しましょう。

stacked_sevens = torch.stack(seven_tensors).float() / 255
stacked_threes = torch.stack(three_tensors).float() / 255

また、テンソルのサイズを確認してみましょう。

stacked_threes.shape

>>> torch.Size([6131, 28, 28])

このことから、6131枚のがぞうがあり、それぞれ28×28ピクセルで構成されていることが分かります。
では、この章最後に平均を求めましょう。といっても簡単です。
全ての画像の画素値の平均を求めたいのでmean関数に0次元目(3であれば6131)を表す0を引数として代入します。

mean3 = stacked_threes.mean(0)
mean7 = stacked_sevens.mean(0)

では、平均の画像を出力してみましょう。

show_image(mean3)

ぼやけているように見えると思います。これが、全ての画像を積み重ねたときの平均値です。
お疲れさまでした。

L1ノルムとL2ノルムを求めよう

ここは簡単なので、さくっと説明します。
といっても、上記ですでに公式を確認していただければ、簡単できると思います。

dist_3_abs = (a_3 - mean3).abs().mean()
dist_3_sqr = ((a_3 - mean3)**2).mean().sqrt()
dist_3_abs, dist_3_sqr

同様に7にも

dist_7_abs = (a_3 - mean7).abs().mean()
dist_7_sqr = ((a_3 - mean7)**2).mean().sqrt()
dist_7_abs, dist_7_sqr

ちなみに上のコードは下のコードと等価です。

F.l1_loss(a_3.float(), mean7)
F.mse_loss(a_3, mean7).sqrt()

以上です。
関数とかも用意していただいているので、本当に助かります(笑)。

精度を確かめてみよう

今までに書いたコードをすべての画像に対して実行できるように拡大してみましょう。
と、その前に検証セットの画像を一つのテンソルに積み重ねましょう。

valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()])
valid_3_tens = valid_3_tens.float() / 255
valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()])
valid_7_tens = valid_7_tens.float() / 255

次は、L1ノルムを返す簡単な関数を定義します。

def mnist_distance(a, b):
  return (a - b).abs().mean((-1, -2))

mnist_distance(a_3, mean3)

次に、ある画像が平均の3との距離が近い(L1の値が小さい)場合に3であるということにしたい。
そこで以下のような関数を用意します。

def is_3(x):
  return mnist_distance(x, mean3) < mnist_distance(x, mean7)

では、最後に精度を見ていきましょう。

accuracy_3s = is_3(valid_3_tens).float().mean()
accuracy_7s = (1 - is_3(valid_7_tens).float()).mean()
accuracy_3s, accuracy_7s

かなりの高精度な結果となりました。
ただし、今回は3と7に限定したのですべての数字に拡張するとどうなるのか、時間があれば取り組んでみたいと思います。
次回は確率勾配降下法とやらを用いてみたいと思います。やっと機械学習ぽっくなってきましたね(笑)。
お付き合いくださいありがとうございました。