

tidymodelsでロジスティック回帰 parsnip::logistic_reg



> # Attach the modelling stack and load kernlab's spam e-mail data set
> # (a factor outcome `type` plus numeric predictors, see summary below).
> library(tidyverse)
> library(tidymodels)
> data("spam", package = "kernlab")
> skimr::skim(spam)
Skim summary statistics
 n obs: 4601 
 n variables: 58 

─ Variable type:factor ──────────────────────────────────────────────────
 variable missing complete    n n_unique                  top_counts ordered
     type       0     4601 4601        2 non: 2788, spa: 1813, NA: 0   FALSE

─ Variable type:numeric ──────────────────────────────────────────────────
          variable missing complete    n     mean      sd p0   p25    p50     p75     p100     hist
           address       0     4601 4601   0.21     1.29   0  0     0       0        14.28 ▇▁▁▁▁▁▁▁
         addresses       0     4601 4601   0.049    0.26   0  0     0       0         4.41 ▇▁▁▁▁▁▁▁
               all       0     4601 4601   0.28     0.5    0  0     0       0.42      5.1  ▇▁▁▁▁▁▁▁
          business       0     4601 4601   0.14     0.44   0  0     0       0         7.14 ▇▁▁▁▁▁▁▁
        capitalAve       0     4601 4601   5.19    31.73   1  1.59  2.28    3.71   1102.5  ▇▁▁▁▁▁▁▁
       capitalLong       0     4601 4601  52.17   194.89   1  6    15      43      9989    ▇▁▁▁▁▁▁▁
      capitalTotal       0     4601 4601 283.29   606.35   1 35    95     266     15841    ▇▁▁▁▁▁▁▁
        charDollar       0     4601 4601   0.076    0.25   0  0     0       0.052     6    ▇▁▁▁▁▁▁▁
   charExclamation       0     4601 4601   0.27     0.82   0  0     0       0.32     32.48 ▇▁▁▁▁▁▁▁
          charHash       0     4601 4601   0.044    0.43   0  0     0       0        19.83 ▇▁▁▁▁▁▁▁
  charRoundbracket       0     4601 4601   0.14     0.27   0  0     0.065   0.19      9.75 ▇▁▁▁▁▁▁▁
     charSemicolon       0     4601 4601   0.039    0.24   0  0     0       0         4.38 ▇▁▁▁▁▁▁▁
 charSquarebracket       0     4601 4601   0.017    0.11   0  0     0       0         4.08 ▇▁▁▁▁▁▁▁
        conference       0     4601 4601   0.032    0.29   0  0     0       0        10    ▇▁▁▁▁▁▁▁
            credit       0     4601 4601   0.086    0.51   0  0     0       0        18.18 ▇▁▁▁▁▁▁▁
                cs       0     4601 4601   0.044    0.36   0  0     0       0         7.14 ▇▁▁▁▁▁▁▁
              data       0     4601 4601   0.097    0.56   0  0     0       0        18.18 ▇▁▁▁▁▁▁▁
            direct       0     4601 4601   0.065    0.35   0  0     0       0         4.76 ▇▁▁▁▁▁▁▁
               edu       0     4601 4601   0.18     0.91   0  0     0       0        22.05 ▇▁▁▁▁▁▁▁
             email       0     4601 4601   0.18     0.53   0  0     0       0         9.09 ▇▁▁▁▁▁▁▁
              font       0     4601 4601   0.12     1.03   0  0     0       0        17.1  ▇▁▁▁▁▁▁▁
              free       0     4601 4601   0.25     0.83   0  0     0       0.1      20    ▇▁▁▁▁▁▁▁
            george       0     4601 4601   0.77     3.37   0  0     0       0        33.33 ▇▁▁▁▁▁▁▁
                hp       0     4601 4601   0.55     1.67   0  0     0       0        20.83 ▇▁▁▁▁▁▁▁
               hpl       0     4601 4601   0.27     0.89   0  0     0       0        16.66 ▇▁▁▁▁▁▁▁
          internet       0     4601 4601   0.11     0.4    0  0     0       0        11.11 ▇▁▁▁▁▁▁▁
               lab       0     4601 4601   0.099    0.59   0  0     0       0        14.28 ▇▁▁▁▁▁▁▁
              labs       0     4601 4601   0.1      0.46   0  0     0       0         5.88 ▇▁▁▁▁▁▁▁
              mail       0     4601 4601   0.24     0.64   0  0     0       0.16     18.18 ▇▁▁▁▁▁▁▁
              make       0     4601 4601   0.1      0.31   0  0     0       0         4.54 ▇▁▁▁▁▁▁▁
           meeting       0     4601 4601   0.13     0.77   0  0     0       0        14.28 ▇▁▁▁▁▁▁▁
             money       0     4601 4601   0.094    0.44   0  0     0       0        12.5  ▇▁▁▁▁▁▁▁
            num000       0     4601 4601   0.1      0.35   0  0     0       0         5.45 ▇▁▁▁▁▁▁▁
           num1999       0     4601 4601   0.14     0.42   0  0     0       0         6.89 ▇▁▁▁▁▁▁▁
             num3d       0     4601 4601   0.065    1.4    0  0     0       0        42.81 ▇▁▁▁▁▁▁▁
            num415       0     4601 4601   0.048    0.33   0  0     0       0         4.76 ▇▁▁▁▁▁▁▁
            num650       0     4601 4601   0.12     0.54   0  0     0       0         9.09 ▇▁▁▁▁▁▁▁
             num85       0     4601 4601   0.11     0.53   0  0     0       0        20    ▇▁▁▁▁▁▁▁
            num857       0     4601 4601   0.047    0.33   0  0     0       0         4.76 ▇▁▁▁▁▁▁▁
             order       0     4601 4601   0.09     0.28   0  0     0       0         5.26 ▇▁▁▁▁▁▁▁
          original       0     4601 4601   0.046    0.22   0  0     0       0         3.57 ▇▁▁▁▁▁▁▁
               our       0     4601 4601   0.31     0.67   0  0     0       0.38     10    ▇▁▁▁▁▁▁▁
              over       0     4601 4601   0.096    0.27   0  0     0       0         5.88 ▇▁▁▁▁▁▁▁
             parts       0     4601 4601   0.013    0.22   0  0     0       0         8.33 ▇▁▁▁▁▁▁▁
            people       0     4601 4601   0.094    0.3    0  0     0       0         5.55 ▇▁▁▁▁▁▁▁
                pm       0     4601 4601   0.079    0.43   0  0     0       0        11.11 ▇▁▁▁▁▁▁▁
           project       0     4601 4601   0.079    0.62   0  0     0       0        20    ▇▁▁▁▁▁▁▁
                re       0     4601 4601   0.3      1.01   0  0     0       0.11     21.42 ▇▁▁▁▁▁▁▁
           receive       0     4601 4601   0.06     0.2    0  0     0       0         2.61 ▇▁▁▁▁▁▁▁
            remove       0     4601 4601   0.11     0.39   0  0     0       0         7.27 ▇▁▁▁▁▁▁▁
            report       0     4601 4601   0.059    0.34   0  0     0       0        10    ▇▁▁▁▁▁▁▁
             table       0     4601 4601   0.0054   0.076  0  0     0       0         2.17 ▇▁▁▁▁▁▁▁
        technology       0     4601 4601   0.097    0.4    0  0     0       0         7.69 ▇▁▁▁▁▁▁▁
            telnet       0     4601 4601   0.065    0.4    0  0     0       0        12.5  ▇▁▁▁▁▁▁▁
              will       0     4601 4601   0.54     0.86   0  0     0.1     0.8       9.67 ▇▁▁▁▁▁▁▁
               you       0     4601 4601   1.66     1.78   0  0     1.31    2.64     18.75 ▇▃▁▁▁▁▁▁
              your       0     4601 4601   0.81     1.2    0  0     0.22    1.27     11.11 ▇▂▁▁▁▁▁▁


> # Fix the RNG so the stratified 80/20 train/test split is reproducible.
> # (The figures printed in this transcript came from an unseeded run, so a
> # rerun will differ slightly from them.)
> set.seed(1234)
> spam_split <- spam %>% initial_split(prop = 0.8, strata = "type")
> spam_train_df <- training(spam_split)
> spam_test_df <- testing(spam_split)


> # Center and scale every predictor using statistics from the training set.
> spam_recipe <- recipe(type ~ ., data = spam_train_df) %>%
+   step_center(all_predictors()) %>%
+   step_scale(all_predictors()) %>%
+   prep(training = spam_train_df)
> # NOTE: the recipe is prepped on the whole training set *before* resampling,
> # so every assessment fold already contributed to the means/SDs. For an
> # unpenalized glm this does not matter (its predictions are invariant to
> # affine rescaling of the predictors, up to numerical tolerance). For
> # penalized or distance-based models, re-estimate the preprocessing inside
> # each fold instead (e.g. a workflow passed to fit_resamples()).
> # Seed and stratify the folds, consistent with the initial split.
> set.seed(5678)
> cv_tbl <- spam_recipe %>%
+   juice() %>%
+   vfold_cv(v = 10, strata = "type")
> cv_tbl
#  10-fold cross-validation 
# A tibble: 10 x 2
   splits             id    
   <list>             <chr> 
 1 <split [3.3K/369]> Fold01
 2 <split [3.3K/369]> Fold02
 3 <split [3.3K/368]> Fold03
 4 <split [3.3K/368]> Fold04
 5 <split [3.3K/368]> Fold05
 6 <split [3.3K/368]> Fold06
 7 <split [3.3K/368]> Fold07
 8 <split [3.3K/368]> Fold08
 9 <split [3.3K/368]> Fold09
10 <split [3.3K/368]> Fold10


> # Unpenalized logistic regression using the stats::glm() engine.
> lr <- set_engine(logistic_reg(), "glm")
> # For each fold: fit on the analysis set, then predict the assessment set
> # and attach the observed class so metrics can be computed afterwards.
> cv_fit_tbl <- cv_tbl %>%
+   mutate(
+     fitted = map(splits, function(split) {
+       fit(lr, type ~ ., data = analysis(split))
+     }),
+     pred = map2(fitted, splits, function(model, split) {
+       holdout <- assessment(split)
+       predict(model, holdout) %>%
+         bind_cols(select(holdout, type))
+     })
+   )


> # Accuracy and kappa pooled over all out-of-fold predictions.
> cv_fit_tbl %>%
+   pull(pred) %>%
+   bind_rows() %>%
+   metrics(truth = type, estimate = .pred_class)
# A tibble: 2 x 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.928
2 kap      binary         0.848
> # The same metrics, computed separately for each fold.
> cv_fit_tbl %>%
+   select(id, pred) %>%
+   unnest(pred) %>%
+   group_by(id) %>%
+   metrics(truth = type, estimate = .pred_class)
# A tibble: 20 x 4
   id     .metric  .estimator .estimate
   <chr>  <chr>    <chr>          <dbl>
 1 Fold01 accuracy binary         0.913
 2 Fold02 accuracy binary         0.927
 3 Fold03 accuracy binary         0.916
 4 Fold04 accuracy binary         0.935
 5 Fold05 accuracy binary         0.935
 6 Fold06 accuracy binary         0.932
 7 Fold07 accuracy binary         0.913
 8 Fold08 accuracy binary         0.935
 9 Fold09 accuracy binary         0.943
10 Fold10 accuracy binary         0.932
11 Fold01 kap      binary         0.820
12 Fold02 kap      binary         0.844
13 Fold03 kap      binary         0.819
14 Fold04 kap      binary         0.865
15 Fold05 kap      binary         0.864
16 Fold06 kap      binary         0.843
17 Fold07 kap      binary         0.819
18 Fold08 kap      binary         0.866
19 Fold09 kap      binary         0.878
20 Fold10 kap      binary         0.860


> # Refit on the full preprocessed training set and inspect the glm object.
> # NOTE(review): the large coefficients with wide SEs (e.g. george, cs) and
> # 13 Fisher scoring iterations below suggest (quasi-)separation; a penalized
> # engine such as glmnet may be worth comparing.
> lr_fitted <- fit(lr, type ~ ., data = juice(spam_recipe))
> summary(lr_fitted[["fit"]])
Call:
stats::glm(formula = formula, family = stats::binomial, data = data)

Deviance Residuals: 
    Min       1Q   Median       3Q      Max  
-4.4160  -0.1822   0.0000   0.0842   4.7901  

                   Estimate Std. Error z value Pr(>|z|)    
(Intercept)       -11.03057    2.21921  -4.971 6.68e-07 ***
make               -0.11567    0.08275  -1.398 0.162183    
address            -0.18874    0.09824  -1.921 0.054699 .  
all                -0.01276    0.06590  -0.194 0.846415    
num3d               4.13411    2.99730   1.379 0.167809    
our                 0.32354    0.07183   4.504 6.66e-06 ***
over                0.17603    0.07044   2.499 0.012457 *  
remove              1.11893    0.17777   6.294 3.09e-10 ***
internet            0.33036    0.10318   3.202 0.001366 ** 
order               0.34472    0.10244   3.365 0.000765 ***
mail                0.05606    0.04921   1.139 0.254569    
receive            -0.08768    0.07037  -1.246 0.212778    
will               -0.18379    0.07907  -2.325 0.020098 *  
people             -0.01070    0.07714  -0.139 0.889698    
report              0.08011    0.06044   1.325 0.185006    
addresses           0.44321    0.27905   1.588 0.112219    
free                1.01890    0.14696   6.933 4.11e-12 ***
business            0.78720    0.15591   5.049 4.44e-07 ***
email               0.01583    0.07185   0.220 0.825669    
you                 0.16197    0.06965   2.326 0.020036 *  
credit              0.46118    0.32447   1.421 0.155223    
your                0.33921    0.07446   4.555 5.23e-06 ***
font                0.25124    0.17733   1.417 0.156556    
num000              0.77821    0.18296   4.253 2.11e-05 ***
money               0.28989    0.11954   2.425 0.015308 *  
hp                 -4.26292    0.70479  -6.048 1.46e-09 ***
hpl                -0.69120    0.40816  -1.693 0.090369 .  
george            -32.84208    7.40414  -4.436 9.18e-06 ***
num650              0.25363    0.12315   2.059 0.039454 *  
lab                -2.38990    1.44426  -1.655 0.097974 .  
labs               -0.08034    0.15910  -0.505 0.613579    
telnet              0.35114    0.30083   1.167 0.243113    
num857              1.11316    0.86138   1.292 0.196257    
data               -0.64293    0.24542  -2.620 0.008802 ** 
num415              0.17660    0.47874   0.369 0.712205    
num85              -1.61614    0.75269  -2.147 0.031782 *  
technology          0.52329    0.15197   3.443 0.000574 ***
num1999             0.05341    0.07911   0.675 0.499578    
parts              -0.17965    0.11995  -1.498 0.134203    
pm                 -0.44567    0.19173  -2.324 0.020099 *  
direct             -0.09013    0.15100  -0.597 0.550583    
cs                -17.52864   12.18618  -1.438 0.150320    
meeting            -1.68504    0.50663  -3.326 0.000881 ***
original           -0.20614    0.15032  -1.371 0.170268    
project            -0.90776    0.33530  -2.707 0.006784 ** 
re                 -0.80397    0.16877  -4.764 1.90e-06 ***
edu                -1.02751    0.22487  -4.569 4.89e-06 ***
table              -0.28628    0.22561  -1.269 0.204474    
conference         -1.42051    0.58870  -2.413 0.015824 *  
charSemicolon      -0.37106    0.13565  -2.736 0.006228 ** 
charRoundbracket   -0.15627    0.11291  -1.384 0.166342    
charSquarebracket  -0.12351    0.13754  -0.898 0.369219    
charExclamation     0.24374    0.06468   3.768 0.000164 ***
charDollar          1.25848    0.19345   6.505 7.74e-11 ***
charHash            0.84666    0.55791   1.518 0.129125    
capitalAve          0.08332    0.63932   0.130 0.896308    
capitalLong         1.67504    0.59140   2.832 0.004621 ** 
capitalTotal        0.64674    0.16980   3.809 0.000140 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

(Dispersion parameter for binomial family taken to be 1)

    Null deviance: 4937.8  on 3681  degrees of freedom
Residual deviance: 1365.7  on 3624  degrees of freedom
AIC: 1481.7

Number of Fisher Scoring iterations: 13


> # Apply the training-set preprocessing to the test set, predict classes,
> # and attach the observed labels as `truth`.
> lr_pred <- spam_recipe %>%
+   bake(spam_test_df) %>%
+   select(-type) %>%
+   predict(lr_fitted, .) %>%
+   mutate(truth = pull(spam_test_df, type))


> # Held-out accuracy and kappa.
> metrics(lr_pred, truth = truth, estimate = .pred_class)
# A tibble: 2 x 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.917
2 kap      binary         0.826
> # Test-set confusion matrix, drawn as a heatmap.
> conf_mat(lr_pred, truth = truth, estimate = .pred_class) %>%
+   autoplot(type = "heatmap")


