tidymodelsでロジスティック回帰 parsnip::logistic_reg
書き慣れたglmのみで分析を回してしまっていることが多いため、tidymodels
を用いることで、このようなtidyな感じに分析できるということを一通り確認したメモ。
ここでは、kernlab
のspam
データを用いて、スパムメールかそうではないかをロジスティック回帰parsnip::logistic_reg
にて分類する。(本気で分類精度を高めたいならば、他のアルゴリズムを使うべきでしょう)
> library(tidyverse) > library(tidymodels) > > data(spam, package = "kernlab") > spam %>% skimr::skim() Skim summary statistics n obs: 4601 n variables: 58 ─ Variable type:factor ────────────────────────────────────────────────── variable missing complete n n_unique top_counts ordered type 0 4601 4601 2 non: 2788, spa: 1813, NA: 0 FALSE ─ Variable type:numeric ────────────────────────────────────────────────── variable missing complete n mean sd p0 p25 p50 p75 p100 hist address 0 4601 4601 0.21 1.29 0 0 0 0 14.28 ▇▁▁▁▁▁▁▁ addresses 0 4601 4601 0.049 0.26 0 0 0 0 4.41 ▇▁▁▁▁▁▁▁ all 0 4601 4601 0.28 0.5 0 0 0 0.42 5.1 ▇▁▁▁▁▁▁▁ business 0 4601 4601 0.14 0.44 0 0 0 0 7.14 ▇▁▁▁▁▁▁▁ capitalAve 0 4601 4601 5.19 31.73 1 1.59 2.28 3.71 1102.5 ▇▁▁▁▁▁▁▁ capitalLong 0 4601 4601 52.17 194.89 1 6 15 43 9989 ▇▁▁▁▁▁▁▁ capitalTotal 0 4601 4601 283.29 606.35 1 35 95 266 15841 ▇▁▁▁▁▁▁▁ charDollar 0 4601 4601 0.076 0.25 0 0 0 0.052 6 ▇▁▁▁▁▁▁▁ charExclamation 0 4601 4601 0.27 0.82 0 0 0 0.32 32.48 ▇▁▁▁▁▁▁▁ charHash 0 4601 4601 0.044 0.43 0 0 0 0 19.83 ▇▁▁▁▁▁▁▁ charRoundbracket 0 4601 4601 0.14 0.27 0 0 0.065 0.19 9.75 ▇▁▁▁▁▁▁▁ charSemicolon 0 4601 4601 0.039 0.24 0 0 0 0 4.38 ▇▁▁▁▁▁▁▁ charSquarebracket 0 4601 4601 0.017 0.11 0 0 0 0 4.08 ▇▁▁▁▁▁▁▁ conference 0 4601 4601 0.032 0.29 0 0 0 0 10 ▇▁▁▁▁▁▁▁ credit 0 4601 4601 0.086 0.51 0 0 0 0 18.18 ▇▁▁▁▁▁▁▁ cs 0 4601 4601 0.044 0.36 0 0 0 0 7.14 ▇▁▁▁▁▁▁▁ data 0 4601 4601 0.097 0.56 0 0 0 0 18.18 ▇▁▁▁▁▁▁▁ direct 0 4601 4601 0.065 0.35 0 0 0 0 4.76 ▇▁▁▁▁▁▁▁ edu 0 4601 4601 0.18 0.91 0 0 0 0 22.05 ▇▁▁▁▁▁▁▁ email 0 4601 4601 0.18 0.53 0 0 0 0 9.09 ▇▁▁▁▁▁▁▁ font 0 4601 4601 0.12 1.03 0 0 0 0 17.1 ▇▁▁▁▁▁▁▁ free 0 4601 4601 0.25 0.83 0 0 0 0.1 20 ▇▁▁▁▁▁▁▁ george 0 4601 4601 0.77 3.37 0 0 0 0 33.33 ▇▁▁▁▁▁▁▁ hp 0 4601 4601 0.55 1.67 0 0 0 0 20.83 ▇▁▁▁▁▁▁▁ hpl 0 4601 4601 0.27 0.89 0 0 0 0 16.66 ▇▁▁▁▁▁▁▁ internet 0 4601 4601 0.11 0.4 0 0 0 0 11.11 ▇▁▁▁▁▁▁▁ lab 0 4601 4601 0.099 0.59 0 0 0 0 14.28 ▇▁▁▁▁▁▁▁ labs 0 4601 4601 0.1 0.46 0 0 0 0 5.88 ▇▁▁▁▁▁▁▁ mail 0 4601 4601 0.24 0.64 0 0 0 0.16 18.18 ▇▁▁▁▁▁▁▁ make 0 4601 4601 0.1 0.31 0 0 0 0 4.54 ▇▁▁▁▁▁▁▁ meeting 0 4601 4601 0.13 0.77 0 0 0 0 14.28 ▇▁▁▁▁▁▁▁ money 0 4601 4601 0.094 0.44 0 0 0 0 12.5 ▇▁▁▁▁▁▁▁ num000 0 4601 4601 0.1 0.35 0 0 0 0 5.45 ▇▁▁▁▁▁▁▁ num1999 0 4601 4601 0.14 0.42 0 0 0 0 6.89 ▇▁▁▁▁▁▁▁ num3d 0 4601 4601 0.065 1.4 0 0 0 0 42.81 ▇▁▁▁▁▁▁▁ num415 0 4601 4601 0.048 0.33 0 0 0 0 4.76 ▇▁▁▁▁▁▁▁ num650 0 4601 4601 0.12 0.54 0 0 0 0 9.09 ▇▁▁▁▁▁▁▁ num85 0 4601 4601 0.11 0.53 0 0 0 0 20 ▇▁▁▁▁▁▁▁ num857 0 4601 4601 0.047 0.33 0 0 0 0 4.76 ▇▁▁▁▁▁▁▁ order 0 4601 4601 0.09 0.28 0 0 0 0 5.26 ▇▁▁▁▁▁▁▁ original 0 4601 4601 0.046 0.22 0 0 0 0 3.57 ▇▁▁▁▁▁▁▁ our 0 4601 4601 0.31 0.67 0 0 0 0.38 10 ▇▁▁▁▁▁▁▁ over 0 4601 4601 0.096 0.27 0 0 0 0 5.88 ▇▁▁▁▁▁▁▁ parts 0 4601 4601 0.013 0.22 0 0 0 0 8.33 ▇▁▁▁▁▁▁▁ people 0 4601 4601 0.094 0.3 0 0 0 0 5.55 ▇▁▁▁▁▁▁▁ pm 0 4601 4601 0.079 0.43 0 0 0 0 11.11 ▇▁▁▁▁▁▁▁ project 0 4601 4601 0.079 0.62 0 0 0 0 20 ▇▁▁▁▁▁▁▁ re 0 4601 4601 0.3 1.01 0 0 0 0.11 21.42 ▇▁▁▁▁▁▁▁ receive 0 4601 4601 0.06 0.2 0 0 0 0 2.61 ▇▁▁▁▁▁▁▁ remove 0 4601 4601 0.11 0.39 0 0 0 0 7.27 ▇▁▁▁▁▁▁▁ report 0 4601 4601 0.059 0.34 0 0 0 0 10 ▇▁▁▁▁▁▁▁ table 0 4601 4601 0.0054 0.076 0 0 0 0 2.17 ▇▁▁▁▁▁▁▁ technology 0 4601 4601 0.097 0.4 0 0 0 0 7.69 ▇▁▁▁▁▁▁▁ telnet 0 4601 4601 0.065 0.4 0 0 0 0 12.5 ▇▁▁▁▁▁▁▁ will 0 4601 4601 0.54 0.86 0 0 0.1 0.8 9.67 ▇▁▁▁▁▁▁▁ you 0 4601 4601 1.66 1.78 0 0 1.31 2.64 18.75 ▇▃▁▁▁▁▁▁ your 0 4601 4601 0.81 1.2 0 0 0.22 1.27 11.11 ▇▂▁▁▁▁▁▁
はじめに、spam
データをtrainとtestに分割します。ここは、rsample
を使うことで簡単にできる。
> spam_split <- spam %>% initial_split(prop = 0.8, strata = "type") > spam_train_df <- spam_split %>% training() > spam_test_df <- spam_split %>% testing()
次に、recipes
を用いて、標準化する前処理を準備して、クロスバリデーションさせる。
> spam_recipe <- recipe(type ~., spam_train_df) %>% + step_center(all_predictors()) %>% + step_scale(all_predictors()) %>% + prep(spam_train_df) > > cv_tbl <- spam_recipe %>% + juice() %>% + vfold_cv(v = 10) > cv_tbl # 10-fold cross-validation # A tibble: 10 x 2 splits id <list> <chr> 1 <split [3.3K/369]> Fold01 2 <split [3.3K/369]> Fold02 3 <split [3.3K/368]> Fold03 4 <split [3.3K/368]> Fold04 5 <split [3.3K/368]> Fold05 6 <split [3.3K/368]> Fold06 7 <split [3.3K/368]> Fold07 8 <split [3.3K/368]> Fold08 9 <split [3.3K/368]> Fold09 10 <split [3.3K/368]> Fold10
いよいよモデルを定義し、分割したそれぞれに対して学習と予測を行う。
> lr <- logistic_reg() %>% + set_engine("glm") > > cv_fit_tbl <- cv_tbl %>% + mutate(fitted = map(splits, ~ fit(lr, type ~ ., data = analysis(.)))) %>% + mutate(pred = map2(fitted, splits, ~ predict(.x, assessment(.y)) %>% + bind_cols(assessment(.y) %>% select(type))))
全体と分割したそれぞれに対する予測精度を確認しておく。
> cv_fit_tbl %>% + unnest(pred) %>% + metrics(truth = type, estimate = .pred_class) # A tibble: 2 x 3 .metric .estimator .estimate <chr> <chr> <dbl> 1 accuracy binary 0.928 2 kap binary 0.848 > > cv_fit_tbl %>% + unnest(pred) %>% + group_by(id) %>% + metrics(truth = type, estimate = .pred_class) # A tibble: 20 x 4 id .metric .estimator .estimate <chr> <chr> <chr> <dbl> 1 Fold01 accuracy binary 0.913 2 Fold02 accuracy binary 0.927 3 Fold03 accuracy binary 0.916 4 Fold04 accuracy binary 0.935 5 Fold05 accuracy binary 0.935 6 Fold06 accuracy binary 0.932 7 Fold07 accuracy binary 0.913 8 Fold08 accuracy binary 0.935 9 Fold09 accuracy binary 0.943 10 Fold10 accuracy binary 0.932 11 Fold01 kap binary 0.820 12 Fold02 kap binary 0.844 13 Fold03 kap binary 0.819 14 Fold04 kap binary 0.865 15 Fold05 kap binary 0.864 16 Fold06 kap binary 0.843 17 Fold07 kap binary 0.819 18 Fold08 kap binary 0.866 19 Fold09 kap binary 0.878 20 Fold10 kap binary 0.860
意外と分類精度が高い印象?分割された中で大きな精度の違いはなさそうということで、trainデータ全体を使って学習させることにする。
> lr_fitted <- lr %>% + fit(type ~ ., data = spam_recipe %>% juice()) > > lr_fitted %>% + pluck("fit") %>% + summary() Call: stats::glm(formula = formula, family = stats::binomial, data = data) Deviance Residuals: Min 1Q Median 3Q Max -4.4160 -0.1822 0.0000 0.0842 4.7901 Coefficients: Estimate Std. Error z value Pr(>|z|) (Intercept) -11.03057 2.21921 -4.971 6.68e-07 *** make -0.11567 0.08275 -1.398 0.162183 address -0.18874 0.09824 -1.921 0.054699 . all -0.01276 0.06590 -0.194 0.846415 num3d 4.13411 2.99730 1.379 0.167809 our 0.32354 0.07183 4.504 6.66e-06 *** over 0.17603 0.07044 2.499 0.012457 * remove 1.11893 0.17777 6.294 3.09e-10 *** internet 0.33036 0.10318 3.202 0.001366 ** order 0.34472 0.10244 3.365 0.000765 *** mail 0.05606 0.04921 1.139 0.254569 receive -0.08768 0.07037 -1.246 0.212778 will -0.18379 0.07907 -2.325 0.020098 * people -0.01070 0.07714 -0.139 0.889698 report 0.08011 0.06044 1.325 0.185006 addresses 0.44321 0.27905 1.588 0.112219 free 1.01890 0.14696 6.933 4.11e-12 *** business 0.78720 0.15591 5.049 4.44e-07 *** email 0.01583 0.07185 0.220 0.825669 you 0.16197 0.06965 2.326 0.020036 * credit 0.46118 0.32447 1.421 0.155223 your 0.33921 0.07446 4.555 5.23e-06 *** font 0.25124 0.17733 1.417 0.156556 num000 0.77821 0.18296 4.253 2.11e-05 *** money 0.28989 0.11954 2.425 0.015308 * hp -4.26292 0.70479 -6.048 1.46e-09 *** hpl -0.69120 0.40816 -1.693 0.090369 . george -32.84208 7.40414 -4.436 9.18e-06 *** num650 0.25363 0.12315 2.059 0.039454 * lab -2.38990 1.44426 -1.655 0.097974 . labs -0.08034 0.15910 -0.505 0.613579 telnet 0.35114 0.30083 1.167 0.243113 num857 1.11316 0.86138 1.292 0.196257 data -0.64293 0.24542 -2.620 0.008802 ** num415 0.17660 0.47874 0.369 0.712205 num85 -1.61614 0.75269 -2.147 0.031782 * technology 0.52329 0.15197 3.443 0.000574 *** num1999 0.05341 0.07911 0.675 0.499578 parts -0.17965 0.11995 -1.498 0.134203 pm -0.44567 0.19173 -2.324 0.020099 * direct -0.09013 0.15100 -0.597 0.550583 cs -17.52864 12.18618 -1.438 0.150320 meeting -1.68504 0.50663 -3.326 0.000881 *** original -0.20614 0.15032 -1.371 0.170268 project -0.90776 0.33530 -2.707 0.006784 ** re -0.80397 0.16877 -4.764 1.90e-06 *** edu -1.02751 0.22487 -4.569 4.89e-06 *** table -0.28628 0.22561 -1.269 0.204474 conference -1.42051 0.58870 -2.413 0.015824 * charSemicolon -0.37106 0.13565 -2.736 0.006228 ** charRoundbracket -0.15627 0.11291 -1.384 0.166342 charSquarebracket -0.12351 0.13754 -0.898 0.369219 charExclamation 0.24374 0.06468 3.768 0.000164 *** charDollar 1.25848 0.19345 6.505 7.74e-11 *** charHash 0.84666 0.55791 1.518 0.129125 capitalAve 0.08332 0.63932 0.130 0.896308 capitalLong 1.67504 0.59140 2.832 0.004621 ** capitalTotal 0.64674 0.16980 3.809 0.000140 *** --- Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 (Dispersion parameter for binomial family taken to be 1) Null deviance: 4937.8 on 3681 degrees of freedom Residual deviance: 1365.7 on 3624 degrees of freedom AIC: 1481.7 Number of Fisher Scoring iterations: 13
この結果をもとにtestデータに対して予測を行う。
> lr_pred <- lr_fitted %>% + predict(spam_recipe %>% bake(spam_test_df) %>% select(-type)) %>% + mutate(truth = spam_test_df %>% select(type) %>% pull())
最後に予測精度を確認する。yardstick
を用いれば、混同行列の結果を簡単に可視化できる。
> lr_pred %>% + metrics(truth = truth, estimate = .pred_class) # A tibble: 2 x 3 .metric .estimator .estimate <chr> <chr> <dbl> 1 accuracy binary 0.917 2 kap binary 0.826 > > lr_pred %>% + conf_mat(truth = truth, estimate = .pred_class) %>% + autoplot(type = "heatmap")
以上のようにして、tidyな世界で一連の分析を行うことができる。tidymodels
に含まれるパッケージには便利な関数がたくさんあるが、書き方に慣れないと恩恵が得られないので、実務投入しながら、色々と試します...