В последнее время я публикую заметки, которые демонстрируют работу с пакетом tidymodels
. Я разбираю как простые, так и более сложные модели. Сегодняшняя заметка подойдет тем, кто только начинает свое знакомство с пакетом tidymodels
.
Знакомимся с данными
Мы будем работать с набором данных, в котором хранится информация о пингвинах, живущих на архипелаге Палмера. Наша задача — предсказать пол пингвинов. Для этого мы будем использовать модель классификации.
library(tidyverse)
library(palmerpenguins)
penguins
# A tibble: 344 × 8
species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex
<fct> <fct> <dbl> <dbl> <int> <int> <fct>
1 Adelie Torgersen 39.1 18.7 181 3750 male
2 Adelie Torgersen 39.5 17.4 186 3800 female
3 Adelie Torgersen 40.3 18 195 3250 female
4 Adelie Torgersen NA NA NA NA NA
5 Adelie Torgersen 36.7 19.3 193 3450 female
6 Adelie Torgersen 39.3 20.6 190 3650 male
7 Adelie Torgersen 38.9 17.8 181 3625 female
8 Adelie Torgersen 39.2 19.6 195 4675 male
9 Adelie Torgersen 34.1 18.1 193 3475 NA
10 Adelie Torgersen 42 20.2 190 4250 NA
# … with 334 more rows, and 1 more variable: year <int>
В нашем датасете есть такие переменные как:
species
- вид пингвинаisland
— остров, на котором обитает особьbill_length_mm
— длина клюва;bill_depth_mm
— глубина клюва;flipper_length_mm
— длина плавника;body_mass_g
— масса тела;sex
— полyear
— год наблюдения
Хочу обратить внимание, что если мы решим использовать модель классификации для предсказания вида пингвина, то обнаружим, что переменные о клюве и массе тела дают высокую точность прогноза. Поэтому мы сосредоточимся на предсказании пола особи, так как эта задача чуть интересней.
Посмотрим на переменную пола особи:
penguins %>%
count(sex)
# A tibble: 3 × 2
sex n
<fct> <int>
1 female 165
2 male 168
3 NA 11
В наших данных есть 11 пропущенных значений, что, в целом, не является критической проблемой для нас.
Давайте визуализируем наши данные:
penguins %>%
#исключим пропущенные значения
filter(!is.na(sex)) %>%
ggplot(aes(flipper_length_mm,
bill_length_mm,
color = sex,
size = body_mass_g)) +
geom_point(alpha = 0.5) +
facet_wrap(~species) +
theme_light()
Похоже, что самки более мелкие, чем особи мужского пола. Что, в целом, очевидно. Давайте приступим к нашему моделированию. Стоит отметить, что мы не будем использовать переменные island
и year
.
penguins_df <- penguins %>%
#Удалим из наших данных пропущенные значения
filter(!is.na(sex)) %>%
#Удалим из наших данных переменные island и year
select(-year, -island)
Построение модели
Начнем наше построение модели с подключения пакета tidymodels
и разделим данные на тестовую и обучающую выборки.
library(tidymodels)
set.seed(123)
# Мы разделим нашу выборку по полу, чтобы соблюсти пропорции самцов и самок
penguin_split <- initial_split(penguins_df, strata = sex)
penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)
Так как у нас недостаточно наблюдений мы можем создать набор данных с повторной выборкой.
set.seed(123)
penguin_boot <- bootstraps(penguin_train)
penguin_boot
# Bootstrap sampling
# A tibble: 25 × 2
splits id
<list> <chr>
1 <split [249/93]> Bootstrap01
2 <split [249/91]> Bootstrap02
3 <split [249/90]> Bootstrap03
4 <split [249/91]> Bootstrap04
5 <split [249/85]> Bootstrap05
6 <split [249/87]> Bootstrap06
7 <split [249/94]> Bootstrap07
8 <split [249/88]> Bootstrap08
9 <split [249/95]> Bootstrap09
10 <split [249/89]> Bootstrap10
# … with 15 more rows
Давайте сравним две различные модели: модель логистической регрессии и модель случайного леса. Начнем с создания спецификаций моделей.
# спецификация для логистичсекой модели
glm_spec <- logistic_reg() %>%
set_engine("glm")
glm_spec
Logistic Regression Model Specification (classification)
Computational engine: glm
# спецификация для модели случайного леса
rf_spec <- rand_forest() %>%
set_mode("classification") %>%
set_engine("ranger")
rf_spec
Random Forest Model Specification (classification)
Computational engine: ranger
Сейчас это пустые спецификации модели, в которых нет данных. Нам нужно “собрать” модель и для этого мы будем использовать функцию workflow()
На первом шаге нам нужно указать формулу, где мы выберем переменные для предсказания пола особи.
penguin_wf <- workflow() %>%
add_formula(sex ~ .) # будем использовать все переменные для предсказания пола особи
penguin_wf
══ Workflow ══════════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: None
── Preprocessor ──────────────────────────────────────────────────────────────────────
sex ~ .
На втором шаге, нам нужно добавить спецификацию модели. Начнем с логистической регрессии.
glm_rs <- penguin_wf %>%
add_model(glm_spec) %>%
fit_resamples(
resamples = penguin_boot,
control = control_resamples(save_pred = TRUE)
)
glm_rs
# Resampling results
# Bootstrap sampling
# A tibble: 25 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [249/93]> Bootstrap01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [93 × 6]>
2 <split [249/91]> Bootstrap02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [91 × 6]>
3 <split [249/90]> Bootstrap03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [90 × 6]>
4 <split [249/91]> Bootstrap04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [91 × 6]>
5 <split [249/85]> Bootstrap05 <tibble [2 × 4]> <tibble [1 × 3]> <tibble [85 × 6]>
6 <split [249/87]> Bootstrap06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [87 × 6]>
7 <split [249/94]> Bootstrap07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [94 × 6]>
8 <split [249/88]> Bootstrap08 <tibble [2 × 4]> <tibble [1 × 3]> <tibble [88 × 6]>
9 <split [249/95]> Bootstrap09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [95 × 6]>
10 <split [249/89]> Bootstrap10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [89 × 6]>
# … with 15 more rows
Сделаем тоже самое для модели случайного леса:
rf_rs <- penguin_wf %>%
add_model(rf_spec) %>%
fit_resamples(
resamples = penguin_boot,
control = control_resamples(save_pred = TRUE)
)
rf_rs
# Resampling results
# Bootstrap sampling
# A tibble: 25 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [249/93]> Bootstrap01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [93 × 6]>
2 <split [249/91]> Bootstrap02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [91 × 6]>
3 <split [249/90]> Bootstrap03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [90 × 6]>
4 <split [249/91]> Bootstrap04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [91 × 6]>
5 <split [249/85]> Bootstrap05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [85 × 6]>
6 <split [249/87]> Bootstrap06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [87 × 6]>
7 <split [249/94]> Bootstrap07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [94 × 6]>
8 <split [249/88]> Bootstrap08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [88 × 6]>
9 <split [249/95]> Bootstrap09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [95 × 6]>
10 <split [249/89]> Bootstrap10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [89 × 6]>
# … with 15 more rows
Оцениваем модель
Теперь давайте проверим, что у нас получилось.
Посмотрим метрики точности для модели случайного леса
collect_metrics(rf_rs)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.914 25 0.00545 Preprocessor1_Model1
2 roc_auc binary 0.977 25 0.00202 Preprocessor1_Model1
Теперь посмотрим метрики точности для логистической регрессии
collect_metrics(glm_rs)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.918 25 0.00639 Preprocessor1_Model1
2 roc_auc binary 0.979 25 0.00254 Preprocessor1_Model1
Мы видим, что модель glm_rs
сработала чуть лучше, чем модель rf_rs
. Учитывая, что эти модели дают примерно одинаковый результат, есть смысл выбрать ту, которая по своей сути проще. Поэтому мы остановимся на модели логистической регрессии.
Построим матрицу ошибок
glm_rs %>%
conf_mat_resampled()
# A tibble: 4 × 3
Prediction Truth Freq
<fct> <fct> <dbl>
1 female female 41.1
2 female male 3
3 male female 4.4
4 male male 42.3
Мы видим, что у нас нет проблем с предсказанием пола особи. Модель справляется достаточно хорошо.
Одним из наших показателей был roc_auc
поэтому было бы интересно посмотреть на ROC-кривые.
glm_rs %>%
collect_predictions() %>%
group_by(id) %>%
roc_curve(sex, .pred_female) %>%
ggplot(aes(1 - specificity, sensitivity, color = id)) +
geom_abline(lty = 2, color = "gray80", size = 1.5) +
geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
coord_equal() +
theme_light()
Мы видим, что ROC-кривая ступенчатая — это связано с тем, что у нас не большой набор данных.
Теперь мы можем вернуться к тестовому набору. Обратите внимание, что мы еще не использовали тестовую выборку. Ее мы можем использовать только для оценки производительности модели на новых данных.
penguin_final <- penguin_wf %>%
add_model(glm_spec) %>%
last_fit(penguin_split)
penguin_final
# Resampling results
# Manual resampling
# A tibble: 1 × 6
splits id .metrics .notes .predictions .workflow
<list> <chr> <list> <list> <list> <list>
1 <split [249/84]> train/test split <tibble [2 × 4]> <tibble> <tibble> <workflow>
Теперь мы модем посмотреть на оценки модели для тестового набора данных
collect_metrics(penguin_final)
# A tibble: 2 × 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 accuracy binary 0.857 Preprocessor1_Model1
2 roc_auc binary 0.938 Preprocessor1_Model1
Мы видим, что результат не сильно отличается от нашего примера выше, что является хорошим признаком.
Мы также можем посмотреть на предсказанные данные и построить матрицу ошибок.
collect_predictions(penguin_final)
# A tibble: 84 × 7
id .pred_female .pred_male .row .pred_class sex .config
<chr> <dbl> <dbl> <int> <fct> <fct> <chr>
1 train/test split 0.597 0.403 2 female female Preprocessor1_Mo…
2 train/test split 0.928 0.0724 3 female female Preprocessor1_Mo…
3 train/test split 0.647 0.353 4 female female Preprocessor1_Mo…
4 train/test split 0.219 0.781 18 male female Preprocessor1_Mo…
5 train/test split 0.0132 0.987 25 male male Preprocessor1_Mo…
6 train/test split 0.970 0.0298 28 female female Preprocessor1_Mo…
7 train/test split 0.0000232 1.00 31 male male Preprocessor1_Mo…
8 train/test split 0.872 0.128 34 female female Preprocessor1_Mo…
9 train/test split 0.998 0.00250 38 female female Preprocessor1_Mo…
10 train/test split 0.00000253 1.00 39 male male Preprocessor1_Mo…
# … with 74 more rows
Строим матрицу ошибок
collect_predictions(penguin_final) %>%
conf_mat(sex, .pred_class)
Truth
Prediction female male
female 37 7
male 5 35
Мы видим, что наша модель достаточно хорошо справляется.
Теперь мы можем посмотреть на подходящий рабочий процесс нашей модели.
penguin_final$.workflow[[1]]
══ Workflow [trained] ════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: logistic_reg()
── Preprocessor ──────────────────────────────────────────────────────────────────────
sex ~ .
── Model ─────────────────────────────────────────────────────────────────────────────
Call: stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
Coefficients:
(Intercept) speciesChinstrap speciesGentoo bill_length_mm
-1.042e+02 -8.892e+00 -1.138e+01 6.459e-01
bill_depth_mm flipper_length_mm body_mass_g
2.124e+00 5.654e-02 8.102e-03
Degrees of Freedom: 248 Total (i.e. Null); 242 Residual
Null Deviance: 345.2
Residual Deviance: 70.02 AIC: 84.02
Здесь мы видим какую модель мы использовали, а также мы можем увидеть рассчитанные коэффициенты нашей модели.
Используя функцию tidy()
, мы получим данные коэффициенты в аккуратном виде. Кроме того, мы можем применить аргумент exponentiate = TRUE
, преобразовав их в коэффициенты шансов.
penguin_final$.workflow[[1]] %>%
tidy(exponentiate = TRUE)
# A tibble: 7 × 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) 5.75e-46 19.6 -5.31 0.000000110
2 speciesChinstrap 1.37e- 4 2.34 -3.79 0.000148
3 speciesGentoo 1.14e- 5 3.75 -3.03 0.00243
4 bill_length_mm 1.91e+ 0 0.180 3.60 0.000321
5 bill_depth_mm 8.36e+ 0 0.478 4.45 0.00000868
6 flipper_length_mm 1.06e+ 0 0.0611 0.926 0.355
7 body_mass_g 1.01e+ 0 0.00176 4.59 0.00000442
Мы видим, что глубина и длина клюва являются основными предикторами для классификации пола особи. Увеличение глубины клюва на 1 мм почти в 8 раза увеличивает шанс быть самцом.
Посмотрим на наш график, который мы строили в начале, но заменим переменную flipper_length_mm
на bill_depth_mm
penguins %>%
filter(!is.na(sex)) %>%
ggplot(aes(bill_depth_mm,
bill_length_mm,
color = sex,
size = body_mass_g)) +
geom_point(alpha = 0.5) +
facet_wrap(~species) +
theme_light()
Да, можно с уверенность сказать, что переменная bill_depth_mm
делит особей по полу более четко.