# install.packages("fpp3")
library(tidymodels)
library(torch)
library(tidyverse)
library(magrittr)
library(skimr)
library(knitr)
library(fpp3)
library(modeltime)
library(timetk)
library(data.table)
ggplot2::theme_set(ggplot2::theme_bw())

# Back to the basics.
# Library load
# Dataset load
# These are the files provided in the competition. We will only use the
# train and test data in this notebook. Remember: keep it simple.
# Read the competition training data and pass it through
# janitor::clean_names() to clean up the column names.
dat <- read_csv("train.csv") %>%
  janitor::clean_names()
# Console output recorded from the original run (kept for reference):
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00C4>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00D6>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00DC>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00E4>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00F6>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00FC>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00DF>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00C6>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00E6>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00D8>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00F8>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00C5>' to native encoding
#> Warning in FUN(X[[i]], ...): unable to translate '<U+00E5>' to native encoding
<8d> <84><95>
# Keep a single series (Norway / KaggleRama / Kaggle Mug), drop the
# identifier columns and rename the target column to `value`.
dat <- dat %>%
  filter(
    country == "Norway",
    store == "KaggleRama",
    product == "Kaggle Mug"
  ) %>%
  select(-c(row_id, country, store, product)) %>%
  rename(value = num_sold)

# Plot the selected series over time.
dat %>%
  plot_time_series(
    .date_var = date,
    .value = value
  )

# Ordered train/test split: the first 70% of the rows go to training.
splits <- initial_time_split(dat, prop = 0.7)

# Preprocessing recipe, specified on the training portion only:
#   1. expand `date` into calendar features (time-series signature),
#   2. drop signature columns matching iso/xts/hour/minute/second/am.pm,
#   3. normalise the columns matching index.num/year/yday,
#   4. one-hot encode the nominal columns,
#   5. add week2 x wday.lbl interaction terms,
#   6. add Fourier terms (K = 2) for periods of 7, 14, 30, 90 and 365.
store_recipe <-
  recipe(value ~ ., data = training(splits)) %>%
  step_timeseries_signature(date) %>%
  step_rm(matches("(iso)|(xts)|(hour)|(minute)|(second)|(am.pm)")) %>%
  step_normalize(matches("(index.num)|(year)|(yday)")) %>%
  step_dummy(all_nominal(), one_hot = TRUE) %>%
  step_interact(~ matches("week2") * matches("wday.lbl")) %>%
  step_fourier(date, period = c(7, 14, 30, 90, 365), K = 2)

# Time-series cross-validation slices built from the training portion
# (rows with missing values removed): cumulative training windows starting
# at 2 months, 20-week assessment windows, 2 weeks skipped between slices,
# at most 6 slices.
resamples_tscv_lag <- time_series_cv(
  data = training(splits) %>% drop_na(),
  cumulative = TRUE,
  initial = "2 months",
  assess = "20 weeks",
  skip = "2 weeks",
  slice_limit = 6
)
#> Using date_var: date
# Visualise the training / assessment windows of every CV slice.
resamples_tscv_lag %>%
  tk_time_series_cv_plan() %>%
  plot_time_series_cv_plan(date, value)

# NNETAR model specification with a weekly seasonal period. The AR orders,
# the number of hidden units and the penalty are marked for tuning; the
# number of networks and the number of epochs are fixed.
model_spec_nnetar <- nnetar_reg(
  seasonal_period = 7,
  non_seasonal_ar = tune(id = "non_seasonal_ar"),
  seasonal_ar = tune(),
  hidden_units = tune(),
  num_networks = 10,
  penalty = tune(),
  epochs = 50
) %>%
  set_engine("nnetar")

# Latin-hypercube grid of 15 candidate parameter sets.
# `parameters()` on a model spec was deprecated in tune 0.1.6.9003
# ("Please use `hardhat::extract_parameter_set_dials()` instead."), so the
# parameter set is extracted with `extract_parameter_set_dials()`.
set.seed(123)
grid_spec_nnetar_1 <- grid_latin_hypercube(
  extract_parameter_set_dials(model_spec_nnetar),
  size = 15
)
# Bundle the preprocessing recipe and the (still untuned) model spec into
# one workflow.
wflw_fit_nnetar <- workflow() %>%
  add_recipe(store_recipe) %>%
  add_model(model_spec_nnetar)

# Parallel backend package; start-up messages recorded from the original run.
library(doFuture)
#> Loading required package: foreach
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#>     accumulate, when
#> Loading required package: future
# Register the future-based foreach backend and start one worker per
# detected core. The cluster object is kept in a variable so that it can be
# shut down explicitly once tuning has finished.
registerDoFuture()
n_cores <- parallel::detectCores()
tune_cluster <- parallel::makeCluster(n_cores)
plan(
  strategy = cluster,
  workers = tune_cluster
)

# Tune the workflow over the grid on the time-series CV slices, timing the
# run with tictoc and keeping the out-of-sample predictions.
library(tictoc)
tic()
set.seed(123)
tune_results_nnetar_1 <- wflw_fit_nnetar %>%
  tune_grid(
    resamples = resamples_tscv_lag,
    grid = grid_spec_nnetar_1,
    metrics = default_forecast_accuracy_metric_set(),
    control = control_grid(save_pred = TRUE)
  )
# Console output recorded from the original run (truncated by the console):
#> ! Slice1: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
#> Ple...
#> ! Slice1: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
#> ! Slice2: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
#> Ple...
#> ! Slice2: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
#> ! Slice3: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
#> Ple...
#> ! Slice3: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
#> ! Slice4: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
#> Ple...
#> ! Slice4: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
#> ! Slice5: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
#> Ple...
#> ! Slice5: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
#> ! Slice6: preprocessor 1/1: `terms_select()` was deprecated in recipes 0.1.17.
#> Ple...
#> ! Slice6: preprocessor 1/1, model 1/15: unable to translate '<U+00C4>' to native e...
toc()
#> 36.27 sec elapsed

# Release the workers: switch back to sequential evaluation first, then stop
# the cluster so that no R worker processes are left running.
plan(sequential)
parallel::stopCluster(tune_cluster)
# Inspect the notes stored for the first resample of the tuning results.
tune_results_nnetar_1$.notes[[1]]$note
#> [1] "`terms_select()` was deprecated in recipes 0.1.17.\nPlease use `recipes_eval_select()` instead."
#> [2] "unable to translate '<U+00C4>' to native encoding, unable to translate '<U+00D6>' to native encoding, unable to translate '<U+00DC>' to native encoding, unable to translate '<U+00E4>' to native encoding, unable to translate '<U+00F6>' to native encoding, unable to translate '<U+00FC>' to native encoding, unable to translate '<U+00DF>' to native encoding, unable to translate '<U+00C6>' to native encoding, unable to translate '<U+00E6>' to native encoding, unable to translate '<U+00D8>' to native encoding, unable to translate '<U+00F8>' to native encoding, unable to translate '<U+00C5>' to native encoding, unable to translate '<U+00E5>' to native encoding"
# Plug the best parameter set (lowest RMSE across the CV slices) into the
# workflow and refit it on the full training portion.
# `select_best()` returns the top-ranked candidate directly, replacing the
# manual `show_best(n = Inf) %>% slice(1)`.
set.seed(123)
wflw_fit_nnetar_tscv <- wflw_fit_nnetar %>%
  finalize_workflow(
    select_best(tune_results_nnetar_1, metric = "rmse")
  ) %>%
  fit(training(splits))
#> Warning: `terms_select()` was deprecated in recipes 0.1.17.
#> Please use `recipes_eval_select()` instead.
#> This warning is displayed once every 8 hours.
#> Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
# Print the fitted workflow; the output below is from the original run.
wflw_fit_nnetar_tscv
#> == Workflow [trained] ==========================================================
#> Preprocessor: Recipe
#> Model: nnetar_reg()
#> -- Preprocessor ----------------------------------------------------------------
#> 6 Recipe Steps
#> * step_timeseries_signature()
#> * step_rm()
#> * step_normalize()
#> * step_dummy()
#> * step_interact()
#> * step_fourier()
#> -- Model -----------------------------------------------------------------------
#> Series: outcome
#> Model: NNAR(3,1,1)[7]
#> Call: forecast::nnetar(y = outcome, p = p, P = P, size = size, repeats = repeats,
#> xreg = xreg_matrix, decay = decay, maxit = maxit)
#> Average of 10 networks, each of which is
#> a 66-1-1 network with 69 weights
#> options were - linear output units decay=8.006317e-10
#> sigma^2 estimated as 4413
# Forecast the hold-out period with the finalized workflow.
pred <- predict(wflw_fit_nnetar_tscv, new_data = testing(splits))

# Overlay the observed values (blue) and the predictions (red) on the
# hold-out dates.
bind_cols(testing(splits), pred) %>%
  ggplot(aes(x = date)) +
  geom_line(aes(y = value), color = "blue") +
  geom_line(aes(y = .pred), color = "red")
Citation
BibTeX citation:
@online{don2022,
author = {Don, Don and Don, Don},
title = {Modeltime Tune},
date = {2022-07-10},
url = {https://dondonkim.netlify.app/posts/2022-07-10-modeltime-tune/tune.html},
langid = {en}
}
For attribution, please cite this work as:
Don, Don, and Don Don. 2022. “Modeltime Tune.” July 10,
2022. https://dondonkim.netlify.app/posts/2022-07-10-modeltime-tune/tune.html.