modeltime tune

modeltime tuning <95> <86><8c><9c>

Author

Don Don

Published

July 10, 2022

Back to the basics.

Library load

# install.packages("fpp3")
library(tidymodels)
library(torch)
library(tidyverse)
library(magrittr)
library(skimr)
library(knitr)
library(fpp3)
library(modeltime)
library(timetk)
library(data.table)

# Make the black-and-white theme the default for every ggplot2 plot below.
ggplot2::theme_set(ggplot2::theme_bw())

Dataset load

This is the list of files in the competition. We will only use the train and test data for this notebook. Remember: keep it simple.

# Read the competition training file and standardise its column names.
dat <- janitor::clean_names(read_csv("train.csv"))
Warning in FUN(X[[i]], ...): unable to translate '<U+00C4>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00D6>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00DC>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00E4>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00F6>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00FC>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00DF>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00C6>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00E6>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00D8>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00F8>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00C5>' to native encoding
Warning in FUN(X[[i]], ...): unable to translate '<U+00E5>' to native encoding

<8d> <84><95>

# Keep a single series (one country / store / product combination), drop
# the identifier columns, and expose the target under the name `value`.
dat <- dat %>%
  filter(
    country == "Norway",
    store == "KaggleRama",
    product == "Kaggle Mug"
  ) %>%
  select(-c(row_id, country, store, product)) %>%
  rename(value = num_sold)

# Plot the selected series over time.
plot_time_series(dat, .date_var = date, .value = value)
Jan 2015Jul 2015Jan 2016Jul 2016Jan 2017Jul 2017Jan 2018Jul 2018400600800100012001400160018002000
Time Series Plot
# Chronological train/test split: the first 70% of rows are used for training.
splits <- initial_time_split(dat, prop = 0.7)

# Preprocessing recipe, estimated on the training portion only.
store_recipe <- recipe(value ~ ., data = training(splits)) %>%
  # Expand `date` into calendar features.
  step_timeseries_signature(date) %>%
  # Drop the generated columns whose names match these patterns.
  step_rm(matches("(iso)|(xts)|(hour)|(minute)|(second)|(am.pm)")) %>%
  # Centre and scale the columns matching index.num / year / yday.
  step_normalize(matches("(index.num)|(year)|(yday)")) %>%
  # One-hot encode every nominal column.
  step_dummy(all_nominal(), one_hot = TRUE) %>%
  # Interaction terms between the "week2" and "wday.lbl" columns.
  step_interact(~ matches("week2") * matches("wday.lbl")) %>%
  # Fourier terms on `date` for the listed periods, two harmonics each.
  step_fourier(date, period = c(7, 14, 30, 90, 365), K = 2)
# Time-series cross-validation folds built from the complete rows of the
# training split. `cumulative = TRUE` is set, so `initial` only matters as
# the minimum history; each fold is assessed on 20 weeks, consecutive
# folds are offset by 2 weeks, and at most 6 slices are kept.
# The date column is named explicitly instead of relying on auto-detection
# (which otherwise emits "Using date_var: date" and could pick the wrong
# column if another date-like column were ever added).
resamples_tscv_lag <- time_series_cv(
    data        = training(splits) %>% drop_na(),
    date_var    = date,
    cumulative  = TRUE,
    initial     = "2 months",
    assess      = "20 weeks",
    skip        = "2 weeks",
    slice_limit = 6
)
Using date_var: date
# Visualise the training/testing window of every cross-validation slice.
plot_time_series_cv_plan(
  tk_time_series_cv_plan(resamples_tscv_lag),
  .date_var = date,
  .value    = value
)
50010001500400600800100012001400400600800100012001400400600800100012001400400600800100012001400Jan 2015Jul 2015Jan 2016Jul 2016Jan 2017Jul 201750010001500
LegendtrainingtestingTime Series Cross Validation PlanSlice1Slice2Slice3Slice4Slice5Slice6
# NNETAR model specification with a weekly seasonal period. The AR orders,
# the number of hidden units and the penalty are left as tune()
# placeholders; the number of networks and the epochs are fixed.
model_spec_nnetar <- nnetar_reg(
    seasonal_period = 7,
    non_seasonal_ar = tune(id = "non_seasonal_ar"),
    seasonal_ar     = tune(),
    hidden_units    = tune(),
    num_networks    = 10,
    penalty         = tune(),
    epochs          = 50
) %>%
    set_engine("nnetar")

# Latin hypercube grid of 15 candidate parameter combinations.
# `parameters()` on a model spec is deprecated; the parameter set is now
# extracted with `hardhat::extract_parameter_set_dials()`.
set.seed(123)
grid_spec_nnetar_1 <- grid_latin_hypercube(
    hardhat::extract_parameter_set_dials(model_spec_nnetar),
    size = 15
)
Warning: `parameters.model_spec()` was deprecated in tune 0.1.6.9003.
Please use `hardhat::extract_parameter_set_dials()` instead.
# Bundle the preprocessing recipe and the (still untuned) model
# specification into a single workflow object.
wflw_fit_nnetar <- workflow() %>%
  add_recipe(recipe = store_recipe) %>%
  add_model(spec = model_spec_nnetar)
library(doFuture)
Loading required package: foreach

Attaching package: 'foreach'
The following objects are masked from 'package:purrr':

    accumulate, when
Loading required package: future
# Register doFuture as the foreach backend.
registerDoFuture()

# Number of cores reported for this machine.
n_cores <- parallel::detectCores()

# Run futures on a cluster with one worker per reported core.
# NOTE(review): the cluster created here is never stopped explicitly.
plan(cluster, workers = parallel::makeCluster(n_cores))



library(tictoc)

# Start the timer; it is stopped by the toc() call after tuning.
tic()

# Evaluate every grid candidate on every cross-validation slice, keeping
# the out-of-sample predictions alongside the forecast accuracy metrics.
set.seed(123)
tune_results_nnetar_1 <- tune_grid(
  wflw_fit_nnetar,
  resamples = resamples_tscv_lag,
  grid      = grid_spec_nnetar_1,
  metrics   = default_forecast_accuracy_metric_set(),
  control   = control_grid(save_pred = TRUE)
)
! Slice1: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
Ple...
! Slice1: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
! Slice2: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
Ple...
! Slice2: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
! Slice3: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
Ple...
! Slice3: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
! Slice4: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
Ple...
! Slice4: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
! Slice5: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
Ple...
! Slice5: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
! Slice6: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
Ple...
! Slice6: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
# Stop the timer started by tic() and print the elapsed time.
toc()
36.27 sec elapsed
# Show the notes (warnings/messages) recorded for the first resample.
tune_results_nnetar_1$.notes[[1]]$note
[1] "`terms_select()` was deprecated in recipes 0.1.17.\nPlease use `recipes_eval_select()` instead
[2] "unable to translate '<U+00C4>' to native encoding, unable to translate '<U+00D6>' to native encoding, unable to translate '<U+00DC>' to native encoding, unable to translate '<U+00E4>' to native encoding, unable to translate '<U+00F6>' to native encoding, unable to translate '<U+00FC>' to native encoding, unable to translate '<U+00DF>' to native encoding, unable to translate '<U+00C6>' to native encoding, unable to translate '<U+00E6>' to native encoding, unable to translate '<U+00D8>' to native encoding, unable to translate '<U+00F8>' to native encoding, unable to translate '<U+00C5>' to native encoding, unable to translate '<U+00E5>' to native encoding"
# Plug the best hyperparameters (lowest cross-validated RMSE) into the
# workflow and refit it on the full training split.
# `select_best()` returns just the winning parameter combination, replacing
# the manual `show_best(n = Inf) %>% slice(1)` which also carried the
# metric columns into `finalize_workflow()`.
set.seed(123)
wflw_fit_nnetar_tscv <- wflw_fit_nnetar %>%
    finalize_workflow(
        select_best(tune_results_nnetar_1, metric = "rmse")
    ) %>%
    fit(training(splits))
Warning: `terms_select()` was deprecated in recipes 0.1.17.
Please use `recipes_eval_select()` instead.
This warning is displayed once every 8 hours.
Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
# Print the fitted workflow (recipe steps and the trained model summary).
wflw_fit_nnetar_tscv
== Workflow [trained] ==========================================================
Preprocessor: Recipe
Model: nnetar_reg()

-- Preprocessor ----------------------------------------------------------------
6 Recipe Steps

* step_timeseries_signature()
* step_rm()
* step_normalize()
* step_dummy()
* step_interact()
* step_fourier()

-- Model -----------------------------------------------------------------------
Series: outcome 
Model:  NNAR(3,1,1)[7] 
Call:   forecast::nnetar(y = outcome, p = p, P = P, size = size, repeats = repeats, 
    xreg = xreg_matrix, decay = decay, maxit = maxit)

Average of 10 networks, each of which is
a 66-1-1 network with 69 weights
options were - linear output units  decay=8.006317e-10

sigma^2 estimated as 4413
# Forecast the hold-out period with the tuned, refitted workflow.
pred <- predict(wflw_fit_nnetar_tscv, new_data = testing(splits))

# Overlay the actual values (blue) and the predictions (red).
bind_cols(testing(splits), pred) %>%
  ggplot(aes(x = date)) +
  geom_line(aes(y = value), color = "blue") +
  geom_line(aes(y = .pred), color = "red")

Citation

BibTeX citation:
@online{don2022,
  author = {Don, Don and Don, Don},
  title = {Modeltime Tune},
  date = {2022-07-10},
  url = {https://dondonkim.netlify.app/posts/2022-07-10-modeltime-tune/tune.html},
  langid = {en}
}
For attribution, please cite this work as:
Don, Don, and Don Don. 2022. “Modeltime Tune.” July 10, 2022. https://dondonkim.netlify.app/posts/2022-07-10-modeltime-tune/tune.html.