Linear Regression
Intro
The linear_regression() function fits a linear regression model inside a database. It uses dplyr programming to abstract the steps needed to produce a model, so that those steps can be translated into SQL statements in the background.
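A linear regression can be fitted from aggregate values alone, which is why it can be expressed as SQL. The sketch below shows this for one predictor: the count, the sums and the sums of products could all come from a single SELECT with COUNT() and SUM(). It assumes nothing about how modeldb builds its queries. It runs on the built-in mtcars data rather than a database, and `fit_from_sums()` is a hypothetical helper written only for this illustration.

```r
# Simple linear regression computed only from aggregates that a database
# could return in one pass. Illustration only; not modeldb's implementation.
fit_from_sums <- function(x, y) {
  n   <- length(x)   # COUNT(*)
  sx  <- sum(x)      # SUM(x)
  sy  <- sum(y)      # SUM(y)
  sxy <- sum(x * y)  # SUM(x * y)
  sxx <- sum(x^2)    # SUM(x * x)
  slope <- (n * sxy - sx * sy) / (n * sxx - sx^2)
  intercept <- (sy - slope * sx) / n
  c(intercept = intercept, slope = slope)
}

mt <- datasets::mtcars
est <- fit_from_sums(mt$wt, mt$mpg)
ref <- stats::coef(stats::lm(mpg ~ wt, data = mt))
est  # same values as ref
```

Comparing `est` with `ref` is a cheap way to confirm that the aggregate form matches lm().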
Example setup
This article uses a lightweight, in-memory SQLite database. A sample data set is also created from the nycflights13 flights table.
# Open an in-memory database connection
con <- DBI::dbConnect(RSQLite::SQLite(), dbname = ":memory:")
# Load RSQLite's extension functions (e.g. math functions)
RSQLite::initExtension(con)
library(dplyr)
library(modeldb)
# Copy data to the database
db_flights <- copy_to(con, nycflights13::flights, "flights")
# Create a simple sample: the first 20,000 flights with a recorded arrival time
db_sample <- db_flights %>%
filter(!is.na(arr_time)) %>%
head(20000)
Model inside the database
The linear_regression() function does not use a formula. Instead, it takes a table and the name of the dependent variable, and the remaining columns of the table act as the predictors. This means the data must be prepared before running the model, and piped dplyr operations are the best way to do that.
db_sample %>%
select(arr_delay, dep_delay, distance) %>%
linear_regression(arr_delay)
## # A tibble: 1 x 3
## dep_delay distance Intercept
## <dbl> <dbl> <dbl>
## 1 1.00 -0.00337 -0.659
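When you want to check an in-database fit against lm() on a local data frame, it helps to reproduce the same "table plus dependent variable" interface. The sketch below uses a hypothetical helper, `fit_table()`, which is not part of modeldb. It turns every column except the named one into a predictor through `stats::reformulate()`. The mtcars data stands in for the flights sample so that the example runs without a database.

```r
# Hypothetical helper (not part of modeldb): fit lm() using every column
# except y_var as a predictor, mimicking a formula-free interface.
fit_table <- function(df, y_var) {
  predictors <- setdiff(names(df), y_var)
  f <- stats::reformulate(predictors, response = y_var)
  stats::coef(stats::lm(f, data = df))
}

mt <- datasets::mtcars
# Choosing the columns plays the role of the formula, like select() above
coefs <- fit_table(mt[, c("mpg", "wt", "hp")], "mpg")
coefs
```

In this design the column selection is the model specification, so changing the model only means changing the select() step.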
Categorical variables
Adding a categorical variable to a model requires a prior data transformation. The add_dummy_variables() function appends a set of 0/1 variables, one for each possible value except one, which becomes the baseline. It therefore creates one fewer variable than there are possible values. For example, if the categorical variable has three possible values, the function appends two variables. By default, add_dummy_variables() removes the original variable.
The possible values are passed explicitly to reduce the number of database operations. Otherwise, the fitting function would have to query all of the unique values every time a new model runs, which creates unnecessary processing.
db_sample %>%
select(arr_delay, origin) %>%
add_dummy_variables(origin, values = c("EWR", "JFK", "LGA"))
## # Source: lazy query [?? x 3]
## # Database: sqlite 3.22.0 []
## arr_delay origin_JFK origin_LGA
## <dbl> <dbl> <dbl>
## 1 11. 0. 0.
## 2 20. 0. 1.
## 3 33. 1. 0.
## 4 -18. 1. 0.
## 5 -25. 0. 1.
## 6 12. 0. 0.
## 7 19. 0. 0.
## 8 -14. 0. 1.
## 9 -8. 1. 0.
## 10 8. 0. 1.
## # ... with more rows
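The encoding itself is simple. The sketch below reproduces it on a small local data frame. The first entry of `values` is treated as the omitted baseline, which matches the output above, where EWR has no column. `dummy_encode()` is a hypothetical helper for illustration only, because add_dummy_variables() does this work in the database. The three rows reuse the first three arr_delay/origin pairs shown above.

```r
# Hypothetical helper: one 0/1 column per value except the first (baseline),
# then drop the original column, as add_dummy_variables() does by default.
dummy_encode <- function(df, var, values) {
  for (v in values[-1]) {
    # 1 when the row has this value, 0 otherwise
    df[[paste0(var, "_", v)]] <- as.numeric(df[[var]] == v)
  }
  df[[var]] <- NULL
  df
}

flights_small <- data.frame(
  arr_delay = c(11, 20, 33),
  origin = c("EWR", "LGA", "JFK")
)
encoded <- dummy_encode(flights_small, "origin", c("EWR", "JFK", "LGA"))
encoded
```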
In a real-world scenario, the possible values are usually not known at the beginning of the analysis. It is therefore a good idea to load them into a vector that can be reused any time the variable is added to a model. The pull() function from dplyr makes this easy:
# Collect the distinct origin values once, for reuse in every model
origins <- db_flights %>%
group_by(origin) %>%
summarise() %>%
pull(origin)
origins
## [1] "EWR" "JFK" "LGA"
The add_dummy_variables() function can be used as part of the piped code that terminates in the modeling function.
db_sample %>%
select(arr_delay, origin) %>%
add_dummy_variables(origin, values = origins) %>%
linear_regression(arr_delay)
## # A tibble: 1 x 3
## origin_JFK origin_LGA Intercept
## <dbl> <dbl> <dbl>
## 1 -10.6 -7.79 9.62
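When the only predictors are the dummy variables, ordinary least squares reproduces group means. The intercept is the mean of the baseline group, and each dummy coefficient is that group's mean minus the baseline mean. Assuming the fit above is ordinary least squares on complete rows, 9.62 is the mean arr_delay of EWR flights in the sample. The sketch checks this identity with lm() on the built-in mtcars data. It runs locally and does not touch the flights database.

```r
# With a single categorical predictor, OLS coefficients equal group means:
# intercept = baseline mean, dummy coefficient = group mean - baseline mean.
mt <- datasets::mtcars
fit <- stats::lm(mpg ~ factor(cyl), data = mt)

group_means <- sapply(split(mt$mpg, mt$cyl), mean)  # named "4", "6", "8"
diffs <- group_means - group_means[["4"]]           # "4" is the baseline

stats::coef(fit)
diffs[-1]
```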
Multiple linear regression
When fitting a model with three or more independent variables, one of two arguments must be set. Both relate to the size of the data set used for the model: either pass the sample_size argument, or set auto_count to TRUE. When auto_count is TRUE and no sample size is passed, the function runs a table count as part of the model fitting. The count is not run automatically in order to prevent unnecessary database operations, especially when multiple models will be tested on the same sample data.
db_sample %>%
select(arr_delay, arr_time, dep_delay, dep_time) %>%
linear_regression(arr_delay, sample_size = 20000)
## # A tibble: 1 x 4
## arr_time dep_delay dep_time Intercept
## <dbl> <dbl> <dbl> <dbl>
## 1 -0.000208 1.01 -0.00155 -1.72
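One way to see why a fitting routine needs the row count is the normal equations. The matrix X'X contains n, the sum of the intercept column, in its top-left cell. The sketch below solves the normal equations from cross-products and checks the result against lm(). We are assuming that sample_size supplies this kind of n. This is not a description of modeldb's internal algorithm. `fit_normal_equations()` is a hypothetical helper, and mtcars stands in for the flights sample.

```r
# Multiple regression via the normal equations (X'X) beta = X'y.
# crossprod() gives the sums of products; n appears in X'X[1, 1].
# Illustration only; not modeldb's implementation.
fit_normal_equations <- function(X, y) {
  X1 <- cbind(Intercept = 1, X)  # prepend the intercept column
  xtx <- crossprod(X1)           # X'X
  xty <- crossprod(X1, y)        # X'y
  beta <- solve(xtx, xty)
  stats::setNames(drop(beta), colnames(X1))
}

mt <- datasets::mtcars
X <- as.matrix(mt[, c("wt", "hp", "qsec")])
beta <- fit_normal_equations(X, mt$mpg)
beta
```

If n is already known, one query fewer is needed, which is the trade-off the sample_size argument addresses.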
Interactions
Interactions have to be handled manually prior to the modeling step, by adding a column that multiplies the two variables with mutate().
db_sample %>%
mutate(distanceXarr_time = distance * arr_time) %>%
select(arr_delay, distanceXarr_time) %>%
linear_regression(arr_delay, sample_size = 20000)
## # A tibble: 1 x 2
## distanceXarr_time Intercept
## <dbl> <dbl>
## 1 -0.00000197 6.77
A more typical model would also include the two original variables:
db_sample %>%
mutate(distanceXarr_time = distance * arr_time) %>%
select(arr_delay, distance, arr_time, distanceXarr_time) %>%
linear_regression(arr_delay, sample_size = 20000)
## # A tibble: 1 x 4
## arr_time distance distanceXarr_time Intercept
## <dbl> <dbl> <dbl> <dbl>
## 1 0.00650 0.00269 -0.00000435 -2.11
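A manual product column gives exactly the same fit as R's built-in interaction syntax. The sketch below shows this with lm() on the built-in mtcars data: `wt * hp` expands to `wt + hp + wt:hp`. The hand-made `wtXhp` column follows the naming pattern used above and is only an illustration.

```r
mt <- datasets::mtcars

# Built-in interaction syntax: expands to wt + hp + wt:hp
auto_fit <- stats::lm(mpg ~ wt * hp, data = mt)

# Manual product column, the approach required before linear_regression()
mt$wtXhp <- mt$wt * mt$hp
manual_fit <- stats::lm(mpg ~ wt + hp + wtXhp, data = mt)

stats::coef(auto_fit)
stats::coef(manual_fit)  # same numbers, different name for the last term
```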
Full example
Fitting a model with regular, categorical and interaction variables will look like this:
remote_model <- db_sample %>%
mutate(distanceXarr_time = distance * arr_time) %>%
select(arr_delay, dep_time, distanceXarr_time, origin) %>%
add_dummy_variables(origin, values = origins) %>%
linear_regression(y_var = arr_delay, sample_size = 20000)
remote_model
## # A tibble: 1 x 5
## dep_time distanceXarr_time Intercept origin_JFK origin_LGA
## <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 0.0132 -0.00000275 -3.92 -10.1 -8.05
Run predictions with tidypredict
The as_parsed_model() function converts the linear_regression() output into a parsed model that the tidypredict package can read.
parsed <- as_parsed_model(remote_model)
parsed
## # A tibble: 7 x 8
## labels estimate type vals field_1 field_2 field_3 field_4
## <chr> <dbl> <chr> <chr> <chr> <chr> <chr> <chr>
## 1 labels 0. vari~ <NA> dep_ti~ distan~ origin~ origin~
## 2 model NA vari~ lm <NA> <NA> <NA> <NA>
## 3 dep_time 1.32e-2 term <NA> {{:}} <NA> <NA> <NA>
## 4 distanceXarr_time -2.75e-6 term <NA> <NA> {{:}} <NA> <NA>
## 5 (Intercept) -3.92e+0 term <NA> <NA> <NA> <NA> <NA>
## 6 origin_JFK -1.01e+1 term <NA> <NA> <NA> {{:}} <NA>
## 7 origin_LGA -8.05e+0 term <NA> <NA> <NA> <NA> {{:}}
To preview the SQL statement that will calculate the predictions, use tidypredict_sql():
library(tidypredict)
tidypredict_sql(parsed, dbplyr::simulate_dbi())
## <SQL> ("dep_time") * (0.0132251596085814) + ("distanceXarr_time") * (-2.75008443762809e-06) + -3.91880281681301 + ("origin_JFK") * (-10.0912262446948) + ("origin_LGA") * (-8.04838792899506)
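The SQL above is a plain linear combination: each column is multiplied by its coefficient and the intercept is added. The sketch below evaluates the same expression locally with the coefficients copied from that output. The input is the first two rows of the prediction output shown later in this article. It is a local check only and does not call tidypredict.

```r
# Coefficients copied from the tidypredict_sql() output above
coefs <- c(
  dep_time = 0.0132251596085814,
  distanceXarr_time = -2.75008443762809e-06,
  origin_JFK = -10.0912262446948,
  origin_LGA = -8.04838792899506
)
intercept <- -3.91880281681301

# First two rows of the prediction output (an EWR flight and an LGA flight)
new_rows <- data.frame(
  dep_time = c(517, 533),
  distanceXarr_time = c(1162000, 1203600),
  origin_JFK = c(0, 0),
  origin_LGA = c(0, 1)
)

# Same row-by-row linear combination that the SQL expression evaluates
fit <- drop(intercept + as.matrix(new_rows[, names(coefs)]) %*% coefs)
fit  # about -0.277 and -8.228, matching the fit column
```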
Visualize results
Consider using dbplot_raster(), from the dbplot package, together with tidypredict to get an idea of the model’s performance. The dbplot package pushes the plot calculations to the database, which makes it easier to view results for a very large sample. The tidypredict_to_column() function calculates the prediction inside the database and returns it as a new variable called fit.
SQLite does not support min() and max(), so the example includes a collect() step. Remove that step when working with a more sophisticated database back end.
library(dbplot)
db_sample %>%
mutate(distanceXarr_time = distance * arr_time) %>%
select(arr_delay, dep_time, distanceXarr_time, origin) %>%
add_dummy_variables(origin, values = origins) %>%
tidypredict_to_column(parsed) %>%
select(fit, arr_delay) %>%
collect() %>% # <----- This step is only needed if working with SQLite!
dbplot_raster(fit, arr_delay, resolution = 50)
## Warning: Removed 25 rows containing missing values (geom_raster).
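A raster plot like this one rests on a simple aggregation. Each axis is split into `resolution` equal-width bins, and the rows falling in each grid cell are counted. The sketch below does this locally with base R so the idea can be seen without a database. It is a simplified stand-in and not dbplot's implementation. `raster_counts()` is a hypothetical helper, and the data are simulated rather than taken from the flights sample.

```r
# Simplified raster binning: equal-width bins on both axes, rows counted
# per cell. Rows missing either coordinate are dropped first.
raster_counts <- function(x, y, resolution = 50) {
  keep <- !is.na(x) & !is.na(y)
  x <- x[keep]
  y <- y[keep]
  bx <- cut(x, breaks = resolution, include.lowest = TRUE)
  by <- cut(y, breaks = resolution, include.lowest = TRUE)
  as.data.frame(table(x_bin = bx, y_bin = by), responseName = "n")
}

set.seed(1)
fit <- rnorm(500)                 # simulated predictions
arr_delay <- fit + rnorm(500)     # simulated outcomes
arr_delay[1:5] <- NA              # a few missing outcomes
cells <- raster_counts(fit, arr_delay, resolution = 10)
head(cells[cells$n > 0, ])
```

Because only the per-cell counts leave the database, the amount of data returned depends on the resolution, not on the sample size.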
Predictions outside the sample
To run predictions on other data, take the same piped data transformations, start them from a different tbl_sql() variable, such as db_flights, and terminate them with tidypredict_to_column():
db_flights %>%
mutate(distanceXarr_time = distance * arr_time) %>%
select(arr_delay, dep_time, distanceXarr_time, origin) %>%
add_dummy_variables(origin, values = origins) %>%
tidypredict_to_column(parsed)
## # Source: lazy query [?? x 6]
## # Database: sqlite 3.22.0 []
## arr_delay dep_time distanceXarr_time origin_JFK origin_LGA fit
## <dbl> <int> <dbl> <dbl> <dbl> <dbl>
## 1 11. 517 1162000. 0. 0. -0.277
## 2 20. 533 1203600. 0. 1. -8.23
## 3 33. 542 1005147. 1. 0. -9.61
## 4 -18. 544 1582304. 1. 0. -11.2
## 5 -25. 554 618744. 0. 1. -6.34
## 6 12. 554 532060. 0. 0. 1.94
## 7 19. 555 972345. 0. 0. 0.747
## 8 -14. 557 162361. 0. 1. -5.05
## 9 -8. 557 791072. 1. 0. -8.82
## 10 8. 558 551949. 0. 1. -6.11
## # ... with more rows
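Keeping the preparation steps in one function is a way to guarantee that the fitting sample and any new data get identical columns. This is also where a fixed `origins` vector pays off: a batch with no JFK flights still gets an origin_JFK column, so the columns line up with the model's coefficients. The sketch below shows the idea on a local data frame. `prep_features()` is a hypothetical helper, and the two input rows are illustrative values chosen to reproduce the distanceXarr_time figures shown above. In the dplyr workflow, the same idea would mean wrapping the mutate(), select() and add_dummy_variables() steps in a function that accepts either db_sample or db_flights.

```r
# Hypothetical helper: the single place where model features are built
prep_features <- function(df, origins) {
  df$distanceXarr_time <- df$distance * df$arr_time
  for (v in origins[-1]) {
    # first value is the baseline, as with add_dummy_variables()
    df[[paste0("origin_", v)]] <- as.numeric(df$origin == v)
  }
  df[, c("arr_delay", "dep_time", "distanceXarr_time",
         paste0("origin_", origins[-1]))]
}

origins <- c("EWR", "JFK", "LGA")  # fixed up front, as with pull() above
new_data <- data.frame(
  arr_delay = c(11, 20),
  dep_time = c(517, 533),
  distance = c(1400, 1416),
  arr_time = c(830, 850),
  origin = c("EWR", "LGA")         # no JFK rows in this batch
)
prepped <- prep_features(new_data, origins)
names(prepped)
```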
For database write-back strategies, also known as “operationalizing” or “productionizing”, please refer to this page on the tidypredict website: https://tidypredict.netlify.com/articles/sql/