RでFashion MNISTメモ データの可視化
今更ながら、RStudioのkeras
パッケージのチュートリアルを読んでいたら、Fashion MNISTデータセットをRからでも簡単に取って来れるようだったので、チュートリアルの実践とこのようなデータのggplot2
でのデータ可視化方法を含めてメモ。
パッケージの読み込みとデータ準備
まずは、パッケージを読み込みます。
library(tidyverse) library(keras) library(Rtsne)
KerasやTensorFlowの準備がまだの場合は、以下から始めます。
install_keras()
何も指定しない場合だと、CPUバージョンとなります。 GPUを用いる場合には、以下のようにします。
install_keras(tensorflow = "gpu")
エラーでKerasのセットアップに失敗する場合には、Anacondaの環境をアップデートしてみると、成功する可能性があります。
conda update -n base -c defaults conda
上記の準備が終わったら、いよいよデータセットを呼び出せます。一番最初はデータをダウンロードしてくるため、終わるまでに4, 5分程度かかります。
fashion_mnist <- dataset_fashion_mnist()
データ確認
fashion_mnist
はお馴染みのMNISTのデータと同じ訓練用とテスト用のイメージとラベルのデータがlist形式になっています。list形式からarray形式に直しておきます。
c(c(train_images, train_labels), c(test_images, test_labels)) %<-% fashion_mnist
この時、%<-%
が複数代入演算子として用いることができます。それぞれのデータの次元を確認します。
> dim(train_images) #> [1] 60000 28 28 > dim(train_labels) #> [1] 60000 > dim(test_images) #> [1] 10000 28 28 > dim(test_labels) #> [1] 10000
画像は28 x 28ピクセルのデータ、ラベルは0から9のintegerとなっています。
> head(train_labels, 30) #> [1] 9 0 0 3 0 2 7 2 5 5 0 9 5 5 7 9 1 0 6 4 3 1 4 8 4 3 0 2 4 4
これらの数字とファッションアイテムとの対応は、チュートリアルやdataset_fashion_mnist()
のヘルプから確認できます。
画像データの可視化
まずは、ggplot2
で1画像だけの可視化を行います。最初の画像データをdata.frameに変換し可視化します。
image1 <- train_images[1, , ] %>% as_data_frame() %>% rownames_to_column(var = "y") %>% gather(x, value, V1:V28, -y) %>% mutate(x = as.numeric(str_replace(x, "V", ""))) %>% mutate(y = as.numeric(y)) image1 %>% ggplot(aes(x = x, y = y, fill = value)) + geom_tile() + scale_fill_gradient(low = "white", high = "black", na.value = NA) + scale_y_reverse() + labs(x = "", y = "") + theme_minimal() + theme(panel.grid = element_blank(), legend.position = "none")
シューズ的なもの(Ankle boot)が可視化できました。array形式データを可視化しやすいdata.frameに変換するには一発ではいきませんが、通常の横持ちデータから縦持ちデータへの変換を行う感じにすればいけます。
t-SNEによる可視化
t-SNEでファッションアイテムの画像を2次元で可視化していきます。
train_images <- array_reshape(train_images, c(nrow(train_images), 784)) / 255
array形式からmatrix形式に変換した上で、2000個のファッションアイテムをランダムに取り出して、t-SNEを適用します。perplexityなどのパラメータはデフォルトのままです。
set.seed(1234) size <- 2000 index <- sample(1:nrow(train_images), size) tsne <- Rtsne(train_images[index, ]) tsne_df <- data.frame(tsne1 = tsne$Y[, 1], tsne2 = tsne$Y[, 2], labels = as.factor(train_labels[index])) tsne_df %>% ggplot(aes(x = tsne1, y = tsne2, colour = labels)) + geom_point(size = 3, alpha = 0.1) + geom_text(aes(label = labels)) + theme_bw()
1のTrouserや5, 7, 9のSandal, Sneaker, Ankle boot、8のBagは比較的分類しやすそうです。2のPulloverや4のCoatなどは分類しにくいアイテムであることがわかります。
次回は、Keras自体を使っていくつかモデルを作り、その分類精度を眺めてみます。