Skip to content

Issues with device = "mps" on macOS Sonoma #144

@cgoo4

Description

@cgoo4

When I run the example below in a fresh R session, tabnet_pretrain() works with device = "mps", but tabnet_fit() with device = "mps" hangs without any message, and I have to terminate R to recover. Session info is attached below.

# Reprex for the report: pretrain a TabNet model with device = "mps", then
# continue fitting from that pretrained model. As reported, the fit hangs when
# device = "mps", so it is run with device = "cpu" here to produce output.
library(tabnet)
library(tidymodels)
library(modeldata)

# Seed set before the train/test split.
set.seed(123)
data("lending_club", package = "modeldata")
# Split stratified on the Class column; keep both partitions.
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

# Recipe that only assigns roles: Class gets the "outcome" role, and every
# column that does not have an "outcome" or "id" role is set to "predictor".
tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id")), new_role = "predictor")

# Re-seed before the model-training calls.
set.seed(1)

# Pretraining step; per the report this call completes with device = "mps".
# NOTE(review): checkpoint_epochs = 2 presumably saves a checkpoint every
# 2 epochs, which from_epoch = 2 below relies on -- confirm against tabnet docs.
tab_pre <- tab_rec |> 
  tabnet_pretrain(train, device = "mps", checkpoint_epochs = 2)

# Fit starting from the pretrained model (tabnet_model = tab_pre) at epoch 2.
# Switching device to "mps" on this call is what triggers the reported hang.
tab_fit <- tab_rec |>
  tabnet_fit(train, tabnet_model = tab_pre, from_epoch = 2, device = "cpu") # hangs with "mps"

# Bind the predictions for the test partition onto the test rows.
test |> bind_cols(predict(tab_fit, test))
#> # A tibble: 2,465 × 24
#>    funded_amnt term    int_rate sub_grade addr_state verification_status
#>          <int> <fct>      <dbl> <fct>     <fct>      <fct>              
#>  1       10000 term_36    11.5  B5        TX         Source_Verified    
#>  2        7000 term_36    13.0  C2        CA         Source_Verified    
#>  3       35000 term_36    11.5  B5        TN         Source_Verified    
#>  4       15000 term_36    10.8  B4        TX         Not_Verified       
#>  5       27200 term_60    10.8  B4        NC         Not_Verified       
#>  6       12000 term_36    14.5  C4        OR         Source_Verified    
#>  7       15025 term_36    13.7  C3        MA         Source_Verified    
#>  8       20000 term_36     5.32 A1        WI         Not_Verified       
#>  9       20000 term_36    12.0  C1        VA         Verified           
#> 10       10000 term_36    10.8  B4        NC         Verified           
#> # ℹ 2,455 more rows
#> # ℹ 18 more variables: annual_inc <dbl>, emp_length <fct>, delinq_2yrs <int>,
#> #   inq_last_6mths <int>, revol_util <dbl>, acc_now_delinq <int>,
#> #   open_il_6m <int>, open_il_12m <int>, open_il_24m <int>, total_bal_il <int>,
#> #   all_util <int>, inq_fi <int>, inq_last_12m <int>, delinq_amnt <int>,
#> #   num_il_tl <int>, total_il_high_credit_limit <int>, Class <fct>,
#> #   .pred_class <fct>

Created on 2024-01-12 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.3.2 (2023-10-31)
#>  os       macOS Sonoma 14.2.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/London
#>  date     2024-01-12
#>  pandoc   3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [2] CRAN (R 4.3.0)
#>  bit            4.0.5      2022-11-15 [2] CRAN (R 4.3.0)
#>  bit64          4.0.5      2020-08-30 [2] CRAN (R 4.3.0)
#>  broom        * 1.0.5      2023-06-09 [2] CRAN (R 4.3.0)
#>  callr          3.7.3      2022-11-02 [2] CRAN (R 4.3.0)
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.3.2)
#>  cli            3.6.2      2023-12-11 [1] CRAN (R 4.3.1)
#>  codetools      0.2-19     2023-02-01 [2] CRAN (R 4.3.2)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.3.0)
#>  coro           1.0.3      2022-07-19 [2] CRAN (R 4.3.0)
#>  data.table     1.14.10    2023-12-08 [1] CRAN (R 4.3.1)
#>  dials        * 1.2.0      2023-04-03 [1] CRAN (R 4.3.0)
#>  DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.3.1)
#>  digest         0.6.33     2023-07-07 [1] CRAN (R 4.3.0)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.3.1)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.3.0)
#>  evaluate       0.23       2023-11-01 [2] CRAN (R 4.3.1)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.3.1)
#>  fastmap        1.1.1      2023-02-24 [2] CRAN (R 4.3.0)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.0)
#>  fs             1.6.3      2023-07-20 [2] CRAN (R 4.3.0)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.0)
#>  future         1.33.1     2023-12-22 [1] CRAN (R 4.3.1)
#>  future.apply   1.11.1     2023-12-21 [1] CRAN (R 4.3.1)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.0)
#>  ggplot2      * 3.4.4      2023-10-12 [1] CRAN (R 4.3.1)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.3.0)
#>  glue           1.7.0      2024-01-09 [1] CRAN (R 4.3.1)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.3.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.0)
#>  gtable         0.3.4      2023-08-21 [1] CRAN (R 4.3.0)
#>  hardhat        1.3.0      2023-03-30 [1] CRAN (R 4.3.0)
#>  htmltools      0.5.7      2023-11-03 [2] CRAN (R 4.3.1)
#>  infer        * 1.0.5      2023-09-06 [2] CRAN (R 4.3.0)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.0)
#>  jsonlite       1.8.8      2023-12-04 [2] CRAN (R 4.3.1)
#>  knitr          1.45       2023-10-30 [2] CRAN (R 4.3.1)
#>  lattice        0.22-5     2023-10-24 [2] CRAN (R 4.3.1)
#>  lava           1.7.3      2023-11-04 [1] CRAN (R 4.3.1)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.0)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.3.1)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.3.0)
#>  lubridate      1.9.3      2023-09-27 [1] CRAN (R 4.3.1)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.0)
#>  MASS           7.3-60     2023-05-04 [2] CRAN (R 4.3.2)
#>  Matrix         1.6-4      2023-11-30 [2] CRAN (R 4.3.1)
#>  modeldata    * 1.2.0      2023-08-09 [2] CRAN (R 4.3.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.3.0)
#>  nnet           7.3-19     2023-05-03 [2] CRAN (R 4.3.2)
#>  parallelly     1.36.0     2023-05-26 [1] CRAN (R 4.3.0)
#>  parsnip      * 1.1.1      2023-08-17 [1] CRAN (R 4.3.0)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.3.0)
#>  processx       3.8.3      2023-12-10 [2] CRAN (R 4.3.1)
#>  prodlim        2023.08.28 2023-08-28 [1] CRAN (R 4.3.0)
#>  ps             1.7.5      2023-04-18 [2] CRAN (R 4.3.0)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.3.0)
#>  R.cache        0.16.0     2022-07-21 [2] CRAN (R 4.3.0)
#>  R.methodsS3    1.8.2      2022-06-13 [2] CRAN (R 4.3.0)
#>  R.oo           1.25.0     2022-06-12 [2] CRAN (R 4.3.0)
#>  R.utils        2.12.3     2023-11-18 [2] CRAN (R 4.3.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.0)
#>  Rcpp           1.0.12     2024-01-09 [1] CRAN (R 4.3.1)
#>  recipes      * 1.0.9      2023-12-13 [1] CRAN (R 4.3.1)
#>  reprex         2.0.2      2022-08-17 [2] CRAN (R 4.3.0)
#>  rlang          1.1.3      2024-01-10 [1] CRAN (R 4.3.1)
#>  rmarkdown      2.25       2023-09-18 [2] CRAN (R 4.3.1)
#>  rpart          4.1.23     2023-12-05 [2] CRAN (R 4.3.1)
#>  rsample      * 1.2.0      2023-08-23 [1] CRAN (R 4.3.0)
#>  rstudioapi     0.15.0     2023-07-07 [2] CRAN (R 4.3.0)
#>  safetensors    0.1.2      2023-09-12 [2] CRAN (R 4.3.0)
#>  scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.3.2)
#>  sessioninfo    1.2.2      2021-12-06 [2] CRAN (R 4.3.0)
#>  styler         1.10.2     2023-08-29 [2] CRAN (R 4.3.0)
#>  survival       3.5-7      2023-08-14 [2] CRAN (R 4.3.2)
#>  tabnet       * 0.5.0.9000 2024-01-11 [1] Github (mlverse/tabnet@962bafa)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.0)
#>  tidymodels   * 1.1.1      2023-08-24 [2] CRAN (R 4.3.0)
#>  tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.3.0)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.3.0)
#>  timechange     0.2.0      2023-01-11 [1] CRAN (R 4.3.0)
#>  timeDate       4032.109   2023-12-14 [1] CRAN (R 4.3.1)
#>  torch          0.12.0     2024-01-05 [1] Github (mlverse/torch@23071c1)
#>  tune         * 1.1.2      2023-08-23 [1] CRAN (R 4.3.0)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.3.1)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.3.1)
#>  withr          2.5.2      2023-10-30 [1] CRAN (R 4.3.1)
#>  workflows    * 1.1.3      2023-02-22 [1] CRAN (R 4.3.0)
#>  workflowsets * 1.0.1      2023-04-06 [2] CRAN (R 4.3.0)
#>  xfun           0.41       2023-11-01 [2] CRAN (R 4.3.1)
#>  yaml           2.3.8      2023-12-11 [2] CRAN (R 4.3.1)
#>  yardstick    * 1.2.0      2023-04-21 [1] CRAN (R 4.3.0)
#>  zeallot        0.1.0      2018-01-28 [2] CRAN (R 4.3.0)
#> 
#>  [1] /Users/carlgoodwin/Library/R/arm64/4.3/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions