RでFashion MNISTメモ シンプルなCNN
RでFashion MNISTの続き。前回の全結合モデルでは、正解率9割届かずという結果でした。今回は畳み込みニューラルネットを用いることで、それらの正解率が向上するか確認します。
データ準備
前回同様、まずは畳み込みニューラルネットで読み込む形にデータを変換します。
library(keras) library(tidyverse) fashion_mnist <- dataset_fashion_mnist() c(c(train_images, train_labels), c(test_images, test_labels)) %<-% fashion_mnist train_images = array_reshape(train_images, c(nrow(train_images), 28, 28, 1)) / 255 test_images = array_reshape(test_images, c(nrow(test_images), 28, 28, 1)) / 255 train_labels = to_categorical(train_labels, 10) test_labels = to_categorical(test_labels, 10)
今回は、(28, 28, 1)というサイズが入力になります。
モデル定義
畳み込み層のフィルタ数を変更した、2つのモデルを定義します。
cnn_model1 <- keras_model_sequential() %>% layer_conv_2d(filters = 32, kernel_size = c(3, 3), activation = "relu", input_shape = c(28, 28, 1)) %>% layer_max_pooling_2d(pool_size = c(2, 2)) %>% layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") %>% layer_max_pooling_2d(pool_size = c(2, 2)) %>% layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") %>% layer_max_pooling_2d(pool_size = c(2, 2)) %>% layer_flatten() %>% layer_dropout(rate = 0.5) %>% layer_dense(units = 64, activation = "relu") %>% layer_dense(units = 10, activation = "softmax") cnn_model2 <- keras_model_sequential() %>% layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu", input_shape = c(28, 28, 1)) %>% layer_max_pooling_2d(pool_size = c(2, 2)) %>% layer_conv_2d(filters = 128, kernel_size = c(3, 3), activation = "relu") %>% layer_max_pooling_2d(pool_size = c(2, 2)) %>% layer_conv_2d(filters = 128, kernel_size = c(3, 3), activation = "relu") %>% layer_max_pooling_2d(pool_size = c(2, 2)) %>% layer_flatten() %>% layer_dropout(rate = 0.5) %>% layer_dense(units = 128, activation = "relu") %>% layer_dense(units = 10, activation = "softmax")
定義したモデルの中身を確認します。
> cnn_model1 #> Model #> _______________________________________________________________ #> Layer (type) Output Shape Param # #> =============================================================== #> conv2d_1 (Conv2D) (None, 26, 26, 32) 320 #> _______________________________________________________________ #> max_pooling2d_1 (MaxPooling (None, 13, 13, 32) 0 #> _______________________________________________________________ #> conv2d_2 (Conv2D) (None, 11, 11, 64) 18496 #> _______________________________________________________________ #> max_pooling2d_2 (MaxPooling (None, 5, 5, 64) 0 #> _______________________________________________________________ #> conv2d_3 (Conv2D) (None, 3, 3, 64) 36928 #> _______________________________________________________________ #> flatten_1 (Flatten) (None, 576) 0 #> _______________________________________________________________ #> dropout_1 (Dropout) (None, 576) 0 #> _______________________________________________________________ #> dense_1 (Dense) (None, 64) 36928 #> _______________________________________________________________ #> dense_2 (Dense) (None, 10) 650 #> =============================================================== #> Total params: 93,322 #> Trainable params: 93,322 #> Non-trainable params: 0 #> _______________________________________________________________ > cnn_model2 #> Model #> _______________________________________________________________ #> Layer (type) Output Shape Param # #> =============================================================== #> conv2d_4 (Conv2D) (None, 26, 26, 64) 640 #> _______________________________________________________________ #> max_pooling2d_3 (MaxPooling (None, 13, 13, 64) 0 #> _______________________________________________________________ #> conv2d_5 (Conv2D) (None, 11, 11, 128) 73856 #> _______________________________________________________________ #> max_pooling2d_4 (MaxPooling (None, 5, 5, 128) 0 #> _______________________________________________________________ #> conv2d_6 (Conv2D) (None, 3, 3, 128) 147584 #> _______________________________________________________________ #> flatten_2 (Flatten) (None, 1152) 0 #> _______________________________________________________________ #> dropout_2 (Dropout) (None, 1152) 0 #> _______________________________________________________________ #> dense_3 (Dense) (None, 128) 147584 #> _______________________________________________________________ #> dense_4 (Dense) (None, 10) 1290 #> =============================================================== #> Total params: 370,954 #> Trainable params: 370,954 #> Non-trainable params: 0 #> _______________________________________________________________
学習・検証
前回作成したfit_model
関数に上記で定義したモデルを流します。
CPUではこれらの実行には少々時間がかかるため、注意が必要です。
result_cnn_model1 <- fit_model(cnn_model1, epochs = 10) result_cnn_model2 <- fit_model(cnn_model2, epochs = 10)
実行が終わったら、それぞれの学習の推移をプロットして確認します。
plot(result_cnn_model1$history)
plot(result_cnn_model2$history)
評価
テストセットに対する損失と精度を確認します。
> result_cnn_model1$evaluate$loss #> [1] 0.2751199 > result_cnn_model2$evaluate$loss #> [1] 0.2569514 > result_cnn_model1$evaluate$acc #> [1] 0.904 > result_cnn_model2$evaluate$acc #> [1] 0.9166
どちらのモデルでも9割を超える正答率になりました。 それぞれのファッションアイテムごとの正解率についても計算します。
> predict_labels <- result_cnn_model1$model %>% + predict_classes(test_images) > table(fashion_mnist$test$y, predict_labels) #> predict_labels #> 0 1 2 3 4 5 6 7 8 9 #> 0 930 1 14 13 4 3 29 0 6 0 #> 1 2 981 0 12 3 0 0 0 2 0 #> 2 23 1 852 12 54 0 56 0 2 0 #> 3 26 5 9 915 20 0 21 0 4 0 #> 4 2 1 43 31 868 0 53 0 2 0 #> 5 0 0 0 0 0 966 0 25 0 9 #> 6 195 2 62 24 83 0 616 0 18 0 #> 7 0 0 0 0 0 4 0 984 0 12 #> 8 4 0 1 3 2 4 2 6 978 0 #> 9 0 0 0 0 0 5 0 45 0 950 > diag(table(fashion_mnist$test$y, predict_labels)) / table(fashion_mnist$test$y) #> #> 0 1 2 3 4 5 6 7 8 9 #> 0.930 0.981 0.852 0.915 0.868 0.966 0.616 0.984 0.978 0.950 > > predict_labels <- result_cnn_model2$model %>% + predict_classes(test_images) > table(fashion_mnist$test$y, predict_labels) #> predict_labels #> 0 1 2 3 4 5 6 7 8 9 #> 0 910 1 23 14 4 1 42 0 5 0 #> 1 0 983 1 14 0 0 0 0 2 0 #> 2 16 2 880 8 43 0 50 0 1 0 #> 3 14 8 10 924 20 0 23 0 1 0 #> 4 0 0 44 23 881 0 49 0 3 0 #> 5 0 0 0 0 0 978 0 11 0 11 #> 6 150 2 61 27 62 0 689 0 9 0 #> 7 0 0 0 0 0 4 0 975 0 21 #> 8 4 1 2 4 3 1 1 2 982 0 #> 9 0 0 1 0 0 4 0 31 0 964 > diag(table(fashion_mnist$test$y, predict_labels)) / table(fashion_mnist$test$y) #> #> 0 1 2 3 4 5 6 7 8 9 #> 0.910 0.983 0.880 0.924 0.881 0.978 0.689 0.975 0.982 0.964
全結合モデルでは、2のPulloverや4のCoatは7割5分程度の正解率であったものが、8割5分を越えるまでに大きく上昇していることが確認できました。