-
Notifications
You must be signed in to change notification settings - Fork 15
Open
Labels
bug (Something isn't working)
Description
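When I run the example below in a fresh R session, tabnet_pretrain() works with device = "mps", but tabnet_fit() hangs with no message when given device = "mps", and I have to terminate R to recover. The example therefore shows the fit step with device = "cpu", which completes; switching that argument to "mps" reproduces the hang. Session info is attached below.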
# Reprex: pretrain a TabNet model with device = "mps", then fit a supervised
# model that resumes from the pretrained one. The fit step is shown with
# device = "cpu" because it hangs (no message) when given device = "mps".
library(tabnet)
library(tidymodels)
library(modeldata)
# Seed before splitting so the train/test partition is reproducible.
set.seed(123)
data("lending_club", package = "modeldata")
# Split stratified on the outcome column `Class`.
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)
# Recipe created from the data without a formula, so roles are set explicitly:
# `Class` is the outcome, and every column that does not already have the
# "outcome" or "id" role becomes a predictor.
tab_rec <-
train |>
recipe() |>
update_role(Class, new_role = "outcome") |>
update_role(-has_role(c("outcome", "id")), new_role = "predictor")
# Separate seed for the pretraining step.
set.seed(1)
# Unsupervised pretraining on the "mps" device -- this step completes.
# NOTE(review): checkpoint_epochs = 2 presumably keeps a checkpoint at epoch 2,
# which from_epoch = 2 below resumes from -- confirm against tabnet docs.
tab_pre <- tab_rec |>
tabnet_pretrain(train, device = "mps", checkpoint_epochs = 2)
# Supervised fit initialised from the pretrained model `tab_pre`.
# Completes on "cpu"; replacing device = "cpu" with "mps" reproduces the hang.
tab_fit <- tab_rec |>
tabnet_fit(train, tabnet_model = tab_pre, from_epoch = 2, device = "cpu") # hangs with "mps"
# Append class predictions for the held-out rows (printed output below).
test |> bind_cols(predict(tab_fit, test))
#> # A tibble: 2,465 × 24
#> funded_amnt term int_rate sub_grade addr_state verification_status
#> <int> <fct> <dbl> <fct> <fct> <fct>
#> 1 10000 term_36 11.5 B5 TX Source_Verified
#> 2 7000 term_36 13.0 C2 CA Source_Verified
#> 3 35000 term_36 11.5 B5 TN Source_Verified
#> 4 15000 term_36 10.8 B4 TX Not_Verified
#> 5 27200 term_60 10.8 B4 NC Not_Verified
#> 6 12000 term_36 14.5 C4 OR Source_Verified
#> 7 15025 term_36 13.7 C3 MA Source_Verified
#> 8 20000 term_36 5.32 A1 WI Not_Verified
#> 9 20000 term_36 12.0 C1 VA Verified
#> 10 10000 term_36 10.8 B4 NC Verified
#> # ℹ 2,455 more rows
#> # ℹ 18 more variables: annual_inc <dbl>, emp_length <fct>, delinq_2yrs <int>,
#> # inq_last_6mths <int>, revol_util <dbl>, acc_now_delinq <int>,
#> # open_il_6m <int>, open_il_12m <int>, open_il_24m <int>, total_bal_il <int>,
#> # all_util <int>, inq_fi <int>, inq_last_12m <int>, delinq_amnt <int>,
#> # num_il_tl <int>, total_il_high_credit_limit <int>, Class <fct>,
#> # .pred_class <fct>
Created on 2024-01-12 with reprex v2.0.2
Session info
# Print the R session details reported below: R version, platform, and the
# versions and sources of loaded packages.
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.3.2 (2023-10-31)
#> os macOS Sonoma 14.2.1
#> system aarch64, darwin20
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Europe/London
#> date 2024-01-12
#> pandoc 3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> backports 1.4.1 2021-12-13 [2] CRAN (R 4.3.0)
#> bit 4.0.5 2022-11-15 [2] CRAN (R 4.3.0)
#> bit64 4.0.5 2020-08-30 [2] CRAN (R 4.3.0)
#> broom * 1.0.5 2023-06-09 [2] CRAN (R 4.3.0)
#> callr 3.7.3 2022-11-02 [2] CRAN (R 4.3.0)
#> class 7.3-22 2023-05-03 [2] CRAN (R 4.3.2)
#> cli 3.6.2 2023-12-11 [1] CRAN (R 4.3.1)
#> codetools 0.2-19 2023-02-01 [2] CRAN (R 4.3.2)
#> colorspace 2.1-0 2023-01-23 [1] CRAN (R 4.3.0)
#> coro 1.0.3 2022-07-19 [2] CRAN (R 4.3.0)
#> data.table 1.14.10 2023-12-08 [1] CRAN (R 4.3.1)
#> dials * 1.2.0 2023-04-03 [1] CRAN (R 4.3.0)
#> DiceDesign 1.10 2023-12-07 [1] CRAN (R 4.3.1)
#> digest 0.6.33 2023-07-07 [1] CRAN (R 4.3.0)
#> dplyr * 1.1.4 2023-11-17 [1] CRAN (R 4.3.1)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.3.0)
#> evaluate 0.23 2023-11-01 [2] CRAN (R 4.3.1)
#> fansi 1.0.6 2023-12-08 [1] CRAN (R 4.3.1)
#> fastmap 1.1.1 2023-02-24 [2] CRAN (R 4.3.0)
#> foreach 1.5.2 2022-02-02 [1] CRAN (R 4.3.0)
#> fs 1.6.3 2023-07-20 [2] CRAN (R 4.3.0)
#> furrr 0.3.1 2022-08-15 [1] CRAN (R 4.3.0)
#> future 1.33.1 2023-12-22 [1] CRAN (R 4.3.1)
#> future.apply 1.11.1 2023-12-21 [1] CRAN (R 4.3.1)
#> generics 0.1.3 2022-07-05 [1] CRAN (R 4.3.0)
#> ggplot2 * 3.4.4 2023-10-12 [1] CRAN (R 4.3.1)
#> globals 0.16.2 2022-11-21 [1] CRAN (R 4.3.0)
#> glue 1.7.0 2024-01-09 [1] CRAN (R 4.3.1)
#> gower 1.0.1 2022-12-22 [1] CRAN (R 4.3.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.3.0)
#> gtable 0.3.4 2023-08-21 [1] CRAN (R 4.3.0)
#> hardhat 1.3.0 2023-03-30 [1] CRAN (R 4.3.0)
#> htmltools 0.5.7 2023-11-03 [2] CRAN (R 4.3.1)
#> infer * 1.0.5 2023-09-06 [2] CRAN (R 4.3.0)
#> ipred 0.9-14 2023-03-09 [1] CRAN (R 4.3.0)
#> iterators 1.0.14 2022-02-05 [1] CRAN (R 4.3.0)
#> jsonlite 1.8.8 2023-12-04 [2] CRAN (R 4.3.1)
#> knitr 1.45 2023-10-30 [2] CRAN (R 4.3.1)
#> lattice 0.22-5 2023-10-24 [2] CRAN (R 4.3.1)
#> lava 1.7.3 2023-11-04 [1] CRAN (R 4.3.1)
#> lhs 1.1.6 2022-12-17 [1] CRAN (R 4.3.0)
#> lifecycle 1.0.4 2023-11-07 [1] CRAN (R 4.3.1)
#> listenv 0.9.0 2022-12-16 [1] CRAN (R 4.3.0)
#> lubridate 1.9.3 2023-09-27 [1] CRAN (R 4.3.1)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.0)
#> MASS 7.3-60 2023-05-04 [2] CRAN (R 4.3.2)
#> Matrix 1.6-4 2023-11-30 [2] CRAN (R 4.3.1)
#> modeldata * 1.2.0 2023-08-09 [2] CRAN (R 4.3.0)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.3.0)
#> nnet 7.3-19 2023-05-03 [2] CRAN (R 4.3.2)
#> parallelly 1.36.0 2023-05-26 [1] CRAN (R 4.3.0)
#> parsnip * 1.1.1 2023-08-17 [1] CRAN (R 4.3.0)
#> pillar 1.9.0 2023-03-22 [1] CRAN (R 4.3.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.3.0)
#> processx 3.8.3 2023-12-10 [2] CRAN (R 4.3.1)
#> prodlim 2023.08.28 2023-08-28 [1] CRAN (R 4.3.0)
#> ps 1.7.5 2023-04-18 [2] CRAN (R 4.3.0)
#> purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.3.0)
#> R.cache 0.16.0 2022-07-21 [2] CRAN (R 4.3.0)
#> R.methodsS3 1.8.2 2022-06-13 [2] CRAN (R 4.3.0)
#> R.oo 1.25.0 2022-06-12 [2] CRAN (R 4.3.0)
#> R.utils 2.12.3 2023-11-18 [2] CRAN (R 4.3.1)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.0)
#> Rcpp 1.0.12 2024-01-09 [1] CRAN (R 4.3.1)
#> recipes * 1.0.9 2023-12-13 [1] CRAN (R 4.3.1)
#> reprex 2.0.2 2022-08-17 [2] CRAN (R 4.3.0)
#> rlang 1.1.3 2024-01-10 [1] CRAN (R 4.3.1)
#> rmarkdown 2.25 2023-09-18 [2] CRAN (R 4.3.1)
#> rpart 4.1.23 2023-12-05 [2] CRAN (R 4.3.1)
#> rsample * 1.2.0 2023-08-23 [1] CRAN (R 4.3.0)
#> rstudioapi 0.15.0 2023-07-07 [2] CRAN (R 4.3.0)
#> safetensors 0.1.2 2023-09-12 [2] CRAN (R 4.3.0)
#> scales * 1.2.1 2022-08-20 [1] CRAN (R 4.3.2)
#> sessioninfo 1.2.2 2021-12-06 [2] CRAN (R 4.3.0)
#> styler 1.10.2 2023-08-29 [2] CRAN (R 4.3.0)
#> survival 3.5-7 2023-08-14 [2] CRAN (R 4.3.2)
#> tabnet * 0.5.0.9000 2024-01-11 [1] Github (mlverse/tabnet@962bafa)
#> tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.3.0)
#> tidymodels * 1.1.1 2023-08-24 [2] CRAN (R 4.3.0)
#> tidyr * 1.3.0 2023-01-24 [1] CRAN (R 4.3.0)
#> tidyselect 1.2.0 2022-10-10 [1] CRAN (R 4.3.0)
#> timechange 0.2.0 2023-01-11 [1] CRAN (R 4.3.0)
#> timeDate 4032.109 2023-12-14 [1] CRAN (R 4.3.1)
#> torch 0.12.0 2024-01-05 [1] Github (mlverse/torch@23071c1)
#> tune * 1.1.2 2023-08-23 [1] CRAN (R 4.3.0)
#> utf8 1.2.4 2023-10-22 [1] CRAN (R 4.3.1)
#> vctrs 0.6.5 2023-12-01 [1] CRAN (R 4.3.1)
#> withr 2.5.2 2023-10-30 [1] CRAN (R 4.3.1)
#> workflows * 1.1.3 2023-02-22 [1] CRAN (R 4.3.0)
#> workflowsets * 1.0.1 2023-04-06 [2] CRAN (R 4.3.0)
#> xfun 0.41 2023-11-01 [2] CRAN (R 4.3.1)
#> yaml 2.3.8 2023-12-11 [2] CRAN (R 4.3.1)
#> yardstick * 1.2.0 2023-04-21 [1] CRAN (R 4.3.0)
#> zeallot 0.1.0 2018-01-28 [2] CRAN (R 4.3.0)
#>
#> [1] /Users/carlgoodwin/Library/R/arm64/4.3/library
#> [2] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
Metadata
Assignees
Labels
bug (Something isn't working)