今週も特にありません

進捗どうですか?

tidymodelsでロジスティック回帰 parsnip::logistic_reg

書き慣れたglmのみで分析を回してしまっていることが多いため、tidymodelsを用いることで、このようなtidyな感じに分析できるということを一通り確認したメモ。

ここでは、kernlabspamデータを用いて、スパムメールかそうではないかをロジスティック回帰parsnip::logistic_regにて分類する。(本気で分類精度を高めたいならば、他のアルゴリズムを使うべきでしょう)

> library(tidyverse)
> library(tidymodels)
>
> data(spam, package = "kernlab")
> spam %>% skimr::skim()
Skim summary statistics
 n obs: 4601 
 n variables: 58 

─ Variable type:factor ──────────────────────────────────────────────────
 variable missing complete    n n_unique                  top_counts ordered
     type       0     4601 4601        2 non: 2788, spa: 1813, NA: 0   FALSE

─ Variable type:numeric ──────────────────────────────────────────────────
          variable missing complete    n     mean      sd p0   p25    p50     p75     p100     hist
           address       0     4601 4601   0.21     1.29   0  0     0       0        14.28 ▇▁▁▁▁▁▁▁
         addresses       0     4601 4601   0.049    0.26   0  0     0       0         4.41 ▇▁▁▁▁▁▁▁
               all       0     4601 4601   0.28     0.5    0  0     0       0.42      5.1  ▇▁▁▁▁▁▁▁
          business       0     4601 4601   0.14     0.44   0  0     0       0         7.14 ▇▁▁▁▁▁▁▁
        capitalAve       0     4601 4601   5.19    31.73   1  1.59  2.28    3.71   1102.5  ▇▁▁▁▁▁▁▁
       capitalLong       0     4601 4601  52.17   194.89   1  6    15      43      9989    ▇▁▁▁▁▁▁▁
      capitalTotal       0     4601 4601 283.29   606.35   1 35    95     266     15841    ▇▁▁▁▁▁▁▁
        charDollar       0     4601 4601   0.076    0.25   0  0     0       0.052     6    ▇▁▁▁▁▁▁▁
   charExclamation       0     4601 4601   0.27     0.82   0  0     0       0.32     32.48 ▇▁▁▁▁▁▁▁
          charHash       0     4601 4601   0.044    0.43   0  0     0       0        19.83 ▇▁▁▁▁▁▁▁
  charRoundbracket       0     4601 4601   0.14     0.27   0  0     0.065   0.19      9.75 ▇▁▁▁▁▁▁▁
     charSemicolon       0     4601 4601   0.039    0.24   0  0     0       0         4.38 ▇▁▁▁▁▁▁▁
 charSquarebracket       0     4601 4601   0.017    0.11   0  0     0       0         4.08 ▇▁▁▁▁▁▁▁
        conference       0     4601 4601   0.032    0.29   0  0     0       0        10    ▇▁▁▁▁▁▁▁
            credit       0     4601 4601   0.086    0.51   0  0     0       0        18.18 ▇▁▁▁▁▁▁▁
                cs       0     4601 4601   0.044    0.36   0  0     0       0         7.14 ▇▁▁▁▁▁▁▁
              data       0     4601 4601   0.097    0.56   0  0     0       0        18.18 ▇▁▁▁▁▁▁▁
            direct       0     4601 4601   0.065    0.35   0  0     0       0         4.76 ▇▁▁▁▁▁▁▁
               edu       0     4601 4601   0.18     0.91   0  0     0       0        22.05 ▇▁▁▁▁▁▁▁
             email       0     4601 4601   0.18     0.53   0  0     0       0         9.09 ▇▁▁▁▁▁▁▁
              font       0     4601 4601   0.12     1.03   0  0     0       0        17.1  ▇▁▁▁▁▁▁▁
              free       0     4601 4601   0.25     0.83   0  0     0       0.1      20    ▇▁▁▁▁▁▁▁
            george       0     4601 4601   0.77     3.37   0  0     0       0        33.33 ▇▁▁▁▁▁▁▁
                hp       0     4601 4601   0.55     1.67   0  0     0       0        20.83 ▇▁▁▁▁▁▁▁
               hpl       0     4601 4601   0.27     0.89   0  0     0       0        16.66 ▇▁▁▁▁▁▁▁
          internet       0     4601 4601   0.11     0.4    0  0     0       0        11.11 ▇▁▁▁▁▁▁▁
               lab       0     4601 4601   0.099    0.59   0  0     0       0        14.28 ▇▁▁▁▁▁▁▁
              labs       0     4601 4601   0.1      0.46   0  0     0       0         5.88 ▇▁▁▁▁▁▁▁
              mail       0     4601 4601   0.24     0.64   0  0     0       0.16     18.18 ▇▁▁▁▁▁▁▁
              make       0     4601 4601   0.1      0.31   0  0     0       0         4.54 ▇▁▁▁▁▁▁▁
           meeting       0     4601 4601   0.13     0.77   0  0     0       0        14.28 ▇▁▁▁▁▁▁▁
             money       0     4601 4601   0.094    0.44   0  0     0       0        12.5  ▇▁▁▁▁▁▁▁
            num000       0     4601 4601   0.1      0.35   0  0     0       0         5.45 ▇▁▁▁▁▁▁▁
           num1999       0     4601 4601   0.14     0.42   0  0     0       0         6.89 ▇▁▁▁▁▁▁▁
             num3d       0     4601 4601   0.065    1.4    0  0     0       0        42.81 ▇▁▁▁▁▁▁▁
            num415       0     4601 4601   0.048    0.33   0  0     0       0         4.76 ▇▁▁▁▁▁▁▁
            num650       0     4601 4601   0.12     0.54   0  0     0       0         9.09 ▇▁▁▁▁▁▁▁
             num85       0     4601 4601   0.11     0.53   0  0     0       0        20    ▇▁▁▁▁▁▁▁
            num857       0     4601 4601   0.047    0.33   0  0     0       0         4.76 ▇▁▁▁▁▁▁▁
             order       0     4601 4601   0.09     0.28   0  0     0       0         5.26 ▇▁▁▁▁▁▁▁
          original       0     4601 4601   0.046    0.22   0  0     0       0         3.57 ▇▁▁▁▁▁▁▁
               our       0     4601 4601   0.31     0.67   0  0     0       0.38     10    ▇▁▁▁▁▁▁▁
              over       0     4601 4601   0.096    0.27   0  0     0       0         5.88 ▇▁▁▁▁▁▁▁
             parts       0     4601 4601   0.013    0.22   0  0     0       0         8.33 ▇▁▁▁▁▁▁▁
            people       0     4601 4601   0.094    0.3    0  0     0       0         5.55 ▇▁▁▁▁▁▁▁
                pm       0     4601 4601   0.079    0.43   0  0     0       0        11.11 ▇▁▁▁▁▁▁▁
           project       0     4601 4601   0.079    0.62   0  0     0       0        20    ▇▁▁▁▁▁▁▁
                re       0     4601 4601   0.3      1.01   0  0     0       0.11     21.42 ▇▁▁▁▁▁▁▁
           receive       0     4601 4601   0.06     0.2    0  0     0       0         2.61 ▇▁▁▁▁▁▁▁
            remove       0     4601 4601   0.11     0.39   0  0     0       0         7.27 ▇▁▁▁▁▁▁▁
            report       0     4601 4601   0.059    0.34   0  0     0       0        10    ▇▁▁▁▁▁▁▁
             table       0     4601 4601   0.0054   0.076  0  0     0       0         2.17 ▇▁▁▁▁▁▁▁
        technology       0     4601 4601   0.097    0.4    0  0     0       0         7.69 ▇▁▁▁▁▁▁▁
            telnet       0     4601 4601   0.065    0.4    0  0     0       0        12.5  ▇▁▁▁▁▁▁▁
              will       0     4601 4601   0.54     0.86   0  0     0.1     0.8       9.67 ▇▁▁▁▁▁▁▁
               you       0     4601 4601   1.66     1.78   0  0     1.31    2.64     18.75 ▇▃▁▁▁▁▁▁
              your       0     4601 4601   0.81     1.2    0  0     0.22    1.27     11.11 ▇▂▁▁▁▁▁▁

はじめに、spamデータをtrainとtestに分割します。ここは、rsampleを使うことで簡単にできる。

> spam_split <- spam %>% initial_split(prop = 0.8, strata = "type")
> spam_train_df <- spam_split %>% training()
> spam_test_df <- spam_split %>% testing()

次に、recipesを用いて、標準化する前処理を準備して、クロスバリデーションさせる。

> spam_recipe <- recipe(type ~., spam_train_df) %>%
+   step_center(all_predictors()) %>%
+   step_scale(all_predictors()) %>%
+   prep(spam_train_df)
> 
> cv_tbl <- spam_recipe %>%
+   juice() %>%
+   vfold_cv(v = 10)
> cv_tbl
#  10-fold cross-validation 
# A tibble: 10 x 2
   splits             id    
   <list>             <chr> 
 1 <split [3.3K/369]> Fold01
 2 <split [3.3K/369]> Fold02
 3 <split [3.3K/368]> Fold03
 4 <split [3.3K/368]> Fold04
 5 <split [3.3K/368]> Fold05
 6 <split [3.3K/368]> Fold06
 7 <split [3.3K/368]> Fold07
 8 <split [3.3K/368]> Fold08
 9 <split [3.3K/368]> Fold09
10 <split [3.3K/368]> Fold10

いよいよモデルを定義し、分割したそれぞれに対して学習と予測を行う。

> lr <- logistic_reg() %>%
+   set_engine("glm")
> 
> cv_fit_tbl <- cv_tbl %>%
+   mutate(fitted = map(splits, ~ fit(lr, type ~ ., data = analysis(.)))) %>%
+   mutate(pred = map2(fitted, splits, ~ predict(.x, assessment(.y)) %>%
+                        bind_cols(assessment(.y) %>% select(type))))

全体と分割したそれぞれに対する予測精度を確認しておく。

> cv_fit_tbl %>%
+   unnest(pred) %>%
+   metrics(truth = type, estimate = .pred_class)
# A tibble: 2 x 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.928
2 kap      binary         0.848
> 
> cv_fit_tbl %>%
+   unnest(pred) %>%
+   group_by(id) %>%
+   metrics(truth = type, estimate = .pred_class)
# A tibble: 20 x 4
   id     .metric  .estimator .estimate
   <chr>  <chr>    <chr>          <dbl>
 1 Fold01 accuracy binary         0.913
 2 Fold02 accuracy binary         0.927
 3 Fold03 accuracy binary         0.916
 4 Fold04 accuracy binary         0.935
 5 Fold05 accuracy binary         0.935
 6 Fold06 accuracy binary         0.932
 7 Fold07 accuracy binary         0.913
 8 Fold08 accuracy binary         0.935
 9 Fold09 accuracy binary         0.943
10 Fold10 accuracy binary         0.932
11 Fold01 kap      binary         0.820
12 Fold02 kap      binary         0.844
13 Fold03 kap      binary         0.819
14 Fold04 kap      binary         0.865
15 Fold05 kap      binary         0.864
16 Fold06 kap      binary         0.843
17 Fold07 kap      binary         0.819
18 Fold08 kap      binary         0.866
19 Fold09 kap      binary         0.878
20 Fold10 kap      binary         0.860

意外と分類精度が高い印象?分割された中で大きな精度の違いはなさそうということで、trainデータ全体を使って学習させることにする。

> lr_fitted <- lr %>%
+   fit(type ~ ., data = spam_recipe %>% juice())
>
> lr_fitted %>%
+   pluck("fit") %>%
+   summary()
Call:
stats::glm(formula = formula, family = stats::binomial, data = data)

Deviance Residuals: 
    Min       1Q   Median       3Q      Max  
-4.4160  -0.1822   0.0000   0.0842   4.7901  

Coefficients:
                   Estimate Std. Error z value Pr(>|z|)    
(Intercept)       -11.03057    2.21921  -4.971 6.68e-07 ***
make               -0.11567    0.08275  -1.398 0.162183    
address            -0.18874    0.09824  -1.921 0.054699 .  
all                -0.01276    0.06590  -0.194 0.846415    
num3d               4.13411    2.99730   1.379 0.167809    
our                 0.32354    0.07183   4.504 6.66e-06 ***
over                0.17603    0.07044   2.499 0.012457 *  
remove              1.11893    0.17777   6.294 3.09e-10 ***
internet            0.33036    0.10318   3.202 0.001366 ** 
order               0.34472    0.10244   3.365 0.000765 ***
mail                0.05606    0.04921   1.139 0.254569    
receive            -0.08768    0.07037  -1.246 0.212778    
will               -0.18379    0.07907  -2.325 0.020098 *  
people             -0.01070    0.07714  -0.139 0.889698    
report              0.08011    0.06044   1.325 0.185006    
addresses           0.44321    0.27905   1.588 0.112219    
free                1.01890    0.14696   6.933 4.11e-12 ***
business            0.78720    0.15591   5.049 4.44e-07 ***
email               0.01583    0.07185   0.220 0.825669    
you                 0.16197    0.06965   2.326 0.020036 *  
credit              0.46118    0.32447   1.421 0.155223    
your                0.33921    0.07446   4.555 5.23e-06 ***
font                0.25124    0.17733   1.417 0.156556    
num000              0.77821    0.18296   4.253 2.11e-05 ***
money               0.28989    0.11954   2.425 0.015308 *  
hp                 -4.26292    0.70479  -6.048 1.46e-09 ***
hpl                -0.69120    0.40816  -1.693 0.090369 .  
george            -32.84208    7.40414  -4.436 9.18e-06 ***
num650              0.25363    0.12315   2.059 0.039454 *  
lab                -2.38990    1.44426  -1.655 0.097974 .  
labs               -0.08034    0.15910  -0.505 0.613579    
telnet              0.35114    0.30083   1.167 0.243113    
num857              1.11316    0.86138   1.292 0.196257    
data               -0.64293    0.24542  -2.620 0.008802 ** 
num415              0.17660    0.47874   0.369 0.712205    
num85              -1.61614    0.75269  -2.147 0.031782 *  
technology          0.52329    0.15197   3.443 0.000574 ***
num1999             0.05341    0.07911   0.675 0.499578    
parts              -0.17965    0.11995  -1.498 0.134203    
pm                 -0.44567    0.19173  -2.324 0.020099 *  
direct             -0.09013    0.15100  -0.597 0.550583    
cs                -17.52864   12.18618  -1.438 0.150320    
meeting            -1.68504    0.50663  -3.326 0.000881 ***
original           -0.20614    0.15032  -1.371 0.170268    
project            -0.90776    0.33530  -2.707 0.006784 ** 
re                 -0.80397    0.16877  -4.764 1.90e-06 ***
edu                -1.02751    0.22487  -4.569 4.89e-06 ***
table              -0.28628    0.22561  -1.269 0.204474    
conference         -1.42051    0.58870  -2.413 0.015824 *  
charSemicolon      -0.37106    0.13565  -2.736 0.006228 ** 
charRoundbracket   -0.15627    0.11291  -1.384 0.166342    
charSquarebracket  -0.12351    0.13754  -0.898 0.369219    
charExclamation     0.24374    0.06468   3.768 0.000164 ***
charDollar          1.25848    0.19345   6.505 7.74e-11 ***
charHash            0.84666    0.55791   1.518 0.129125    
capitalAve          0.08332    0.63932   0.130 0.896308    
capitalLong         1.67504    0.59140   2.832 0.004621 ** 
capitalTotal        0.64674    0.16980   3.809 0.000140 ***
---
Signif. codes:  0***0.001**0.01*0.05 ‘.’ 0.1 ‘ ’ 1

(Dispersion parameter for binomial family taken to be 1)

    Null deviance: 4937.8  on 3681  degrees of freedom
Residual deviance: 1365.7  on 3624  degrees of freedom
AIC: 1481.7

Number of Fisher Scoring iterations: 13

この結果をもとにtestデータに対して予測を行う。

> lr_pred <- lr_fitted %>%
+   predict(spam_recipe %>% bake(spam_test_df) %>% select(-type)) %>%
+   mutate(truth = spam_test_df %>% select(type) %>% pull())

最後に予測精度を確認する。yardstickを用いれば、混同行列の結果を簡単に可視化できる。

> lr_pred %>%
+   metrics(truth = truth, estimate = .pred_class)
# A tibble: 2 x 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.917
2 kap      binary         0.826
> 
> lr_pred %>%
+   conf_mat(truth = truth, estimate = .pred_class) %>%
+   autoplot(type = "heatmap")

f:id:masaqol:20191031010828p:plain

以上のようにして、tidyな世界で一連の分析を行うことができる。tidymodelsに含まれるパッケージには便利な関数がたくさんあるが、書き方に慣れないと恩恵が得られないので、実務投入しながら、色々と試します...

github.com