ML Regression in R

Visualize regression in Tidymodels with Plotly


New to Plotly?

Plotly is a free and open-source graphing library for R. We recommend you read our Getting Started guide for the latest installation or upgrade instructions, then move on to our Plotly Fundamentals tutorials or dive straight into some Basic Charts tutorials.

This page shows how to use Plotly charts to display various types of regression models, starting from simple models like Linear Regression and progressively moving toward models like Decision Tree and Polynomial Features. We highlight various capabilities of plotly, such as comparative analysis of the same model with different parameters, displaying LaTeX, and surface plots for 3D data.

We will use tidymodels to split and preprocess our data and train various regression models. Tidymodels is a popular Machine Learning (ML) library in R that is compatible with the "tidyverse" concepts, and offers various tools for creating and training ML algorithms, feature engineering, data cleaning, and evaluating and testing models. It is the next-gen version of the popular caret library for R.

Basic linear regression plots

In this section, we show you how to apply a simple regression model for predicting tips a server will receive based on various client attributes (such as sex, time of the week, and whether they are a smoker).

We will use linear regression, a simple model that fits an intercept (the predicted tip when every feature equals zero) and adds a slope for each feature we use, such as the value of the total bill.

Linear Regression with R

library(reshape2) # to load tips data
library(tidyverse)
library(tidymodels) # for the fit() function
library(plotly)
data(tips)

y <- tips$tip
X <- tips$total_bill

lm_model <- linear_reg() %>% 
  set_engine('lm') %>% 
  set_mode('regression') %>%
  fit(tip ~ total_bill, data = tips) 

x_range <- seq(min(X), max(X), length.out = 100)
x_range <- matrix(x_range, nrow=100, ncol=1)
xdf <- data.frame(x_range)
colnames(xdf) <- c('total_bill')

ydf <- lm_model %>% predict(xdf) 

colnames(ydf) <- c('tip')
xy <- data.frame(xdf, ydf) 

fig <- plot_ly(tips, x = ~total_bill, y = ~tip, type = 'scatter', alpha = 0.65, mode = 'markers', name = 'Tips')
fig <- fig %>% add_trace(data = xy, x = ~total_bill, y = ~tip, name = 'Regression Fit', mode = 'lines', alpha = 1)
fig
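
The intercept and slope described above can be read directly from a fitted model. The following is a minimal sketch, assuming the reshape2 package is installed for the tips data; it calls base R's lm() directly (the function behind the 'lm' engine) rather than going through tidymodels, and the bill value of 20 is an arbitrary illustration.

```r
library(reshape2) # assumed installed; provides the tips data
data(tips)

# Fit the same one-feature model with base R's lm()
fit <- lm(tip ~ total_bill, data = tips)

coefs <- coef(fit)
intercept <- coefs[["(Intercept)"]] # predicted tip for a total bill of 0
slope <- coefs[["total_bill"]]      # change in predicted tip per unit of total bill

# Reproduce one prediction by hand: intercept + slope * total_bill
manual <- intercept + slope * 20
from_model <- unname(predict(fit, newdata = data.frame(total_bill = 20)))

print(c(intercept = intercept, slope = slope))
print(c(manual = manual, model = from_model))
```

The hand-computed value and the predict() value should agree, which is all the regression line in the figure represents.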

Model generalization on unseen data

With add_trace(), you can easily color your plot based on a predefined data split. By coloring the training and the testing data points with different colors, you can easily see if the model generalizes well to the test data or not.

library(reshape2)
library(tidyverse)
library(tidymodels)
library(plotly)
data(tips)

y <- tips$tip
X <- tips$total_bill

set.seed(123)
tips_split <- initial_split(tips)
tips_training <- tips_split %>% 
  training()
tips_test <- tips_split %>% 
  testing()

lm_model <- linear_reg() %>% 
  set_engine('lm') %>% 
  set_mode('regression') %>%
  fit(tip ~ total_bill, data = tips_training) 

x_range <- seq(min(X), max(X), length.out = 100)
x_range <- matrix(x_range, nrow=100, ncol=1)
xdf <- data.frame(x_range)
colnames(xdf) <- c('total_bill')

ydf <- lm_model %>%
  predict(xdf) 

colnames(ydf) <- c('tip')
xy <- data.frame(xdf, ydf) 

fig <- plot_ly(data = tips_training, x = ~total_bill, y = ~tip, type = 'scatter', name = 'train', mode = 'markers', alpha = 0.65) %>% 
  add_trace(data = tips_test, x = ~total_bill, y = ~tip, type = 'scatter', name = 'test', mode = 'markers', alpha = 0.65 ) %>% 
  add_trace(data = xy, x = ~total_bill, y = ~tip, name = 'prediction', mode = 'lines', alpha = 1)
fig
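
The visual check above can be backed by a number. The sketch below compares the root mean squared error (RMSE) on the training and test rows; it is an illustration under assumptions, not the page's method: the split uses base R's sample() in place of initial_split(), and rmse() is a hypothetical helper defined here.

```r
library(reshape2) # assumed installed; provides the tips data
data(tips)

set.seed(123)
# Hypothetical 75/25 split with base R; initial_split() plays this role above
train_rows <- sample(nrow(tips), size = floor(0.75 * nrow(tips)))
train <- tips[train_rows, ]
test <- tips[-train_rows, ]

# Fit on the training rows only
fit <- lm(tip ~ total_bill, data = train)

# Hypothetical helper: root mean squared error
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

rmse_train <- rmse(train$tip, predict(fit, newdata = train))
rmse_test <- rmse(test$tip, predict(fit, newdata = test))
print(c(train = rmse_train, test = rmse_test))
```

A test RMSE close to the training RMSE suggests the model generalizes about as well as it fits; a much larger test RMSE suggests overfitting.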

Comparing different kNN model parameters

In addition to linear regression, it's possible to fit the same data using k-Nearest Neighbors. When you predict a new sample, this model takes either the distance-weighted or the unweighted average of its nearest neighbors. To see the difference between those two averaging options, we train one kNN model with each setting and plot both in the same way as the previous graph.

Notice how we can combine scatter points with lines using Plotly. You can learn more about multiple chart types.

library(reshape2)
library(tidyverse)
library(tidymodels)
library(plotly)
library(kknn)
data(tips)

y <- tips$tip
X <- tips$total_bill

# Model #1
knn_dist <- nearest_neighbor(neighbors = 10, weight_func = 'inv') %>% 
  set_engine('kknn') %>% 
  set_mode('regression') %>%
  fit(tip ~ total_bill, data = tips)

# Model #2
knn_uni <- nearest_neighbor(neighbors = 10, weight_func = 'rectangular') %>% 
  set_engine('kknn') %>% 
  set_mode('regression') %>%
  fit(tip ~ total_bill, data = tips) 

x_range <- seq(min(X), max(X), length.out = 100)
x_range <- matrix(x_range, nrow=100, ncol=1)
xdf <- data.frame(x_range)
colnames(xdf) <- c('total_bill')

y_dist <- knn_dist %>%
  predict(xdf) 
y_uni <- knn_uni %>%
  predict(xdf) 

colnames(y_dist) <- c('dist')
colnames(y_uni) <- c('uni')
xy <- data.frame(xdf, y_dist, y_uni) 

fig <- plot_ly(tips, type = 'scatter', mode = 'markers', colors = c("#FF7F50", "#6495ED")) %>% 
  add_trace(data = tips, x = ~total_bill, y = ~tip, type = 'scatter', mode = 'markers', color = ~sex, alpha = 0.65) %>% 
  add_trace(data = xy, x = ~total_bill, y = ~dist, name = 'Weights: Distance', mode = 'lines', alpha = 1) %>%
  add_trace(data = xy, x = ~total_bill, y = ~uni, name = 'Weights: Uniform', mode = 'lines', alpha = 1) 
fig
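
To make the two averaging options concrete, here is a simplified sketch of a one-feature kNN prediction in base R. It is not the kknn implementation: knn_predict() is a hypothetical helper, the toy data is made up for illustration, and the small constant added to the distances is an assumption to avoid division by zero.

```r
# Hypothetical helper: predict one value from the k nearest training points
knn_predict <- function(x_train, y_train, x_new, k, weighted = FALSE) {
  d <- abs(x_train - x_new)       # distance from the new point to each training point
  nearest <- order(d)[seq_len(k)] # indices of the k closest points
  if (!weighted) {
    return(mean(y_train[nearest])) # uniform: plain average of the neighbors
  }
  w <- 1 / (d[nearest] + 1e-8)     # inverse-distance weights
  sum(w * y_train[nearest]) / sum(w)
}

# Toy data, made up for illustration
x_train <- c(1, 2, 3, 4)
y_train <- c(1, 2, 3, 10)

uniform <- knn_predict(x_train, y_train, x_new = 2.4, k = 2)
distance <- knn_predict(x_train, y_train, x_new = 2.4, k = 2, weighted = TRUE)
print(c(uniform = uniform, distance = distance))
```

With distance weighting, the closer neighbor pulls the prediction toward its own value, which is why the two curves in the figure diverge where neighbors are unevenly spaced.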

3D regression surface with mesh3d and add_surface

Visualize the regression surface of your model whenever your input data has more than one variable. Here, we use svm_rbf with the kernlab engine in regression mode. To generate the 2D mesh over which the surface is evaluated, we use meshgrid() from the pracma package.

library(reshape2)
library(tidyverse)
library(tidymodels)
library(plotly)
library(kernlab)
library(pracma) #For meshgrid()
data(iris)

mesh_size <- .02
margin <- 0
X <- iris %>% select(Sepal.Width, Sepal.Length)
y <- iris %>% select(Petal.Width)

model <- svm_rbf(cost = 1.0) %>% 
  set_engine("kernlab") %>% 
  set_mode("regression") %>% 
  fit(Petal.Width ~ Sepal.Width + Sepal.Length, data = iris)

# Extend the plotted range by `margin` on each side
x_min <- min(X$Sepal.Width) - margin
x_max <- max(X$Sepal.Width) + margin
y_min <- min(X$Sepal.Length) - margin
y_max <- max(X$Sepal.Length) + margin
xrange <- seq(x_min, x_max, mesh_size)
yrange <- seq(y_min, y_max, mesh_size)
xy <- meshgrid(x = xrange, y = yrange)
xx <- xy$X
yy <- xy$Y
dim_val <- dim(xx)
# predict() needs a data frame whose column names match the model formula
final <- data.frame(Sepal.Width = as.vector(xx), Sepal.Length = as.vector(yy))
pred <- model %>%
  predict(final)

pred <- pred$.pred
pred <- matrix(pred, dim_val[1], dim_val[2])

fig <- plot_ly(iris, x = ~Sepal.Width, y = ~Sepal.Length, z = ~Petal.Width ) %>% 
  add_markers(size = 5) %>% 
  add_surface(x=xrange, y=yrange, z=pred, alpha = 0.65, type = 'mesh3d', name = 'pred_surface')
fig
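
The reshaping step is the part most likely to go wrong: a Plotly surface trace reads z as a matrix with one row per y value and one column per x value. The sketch below shows that layout on a tiny grid using base R's expand.grid() as a dependency-free alternative to meshgrid(); the function f() is a made-up stand-in for model predictions.

```r
xrange <- seq(0, 1, by = 0.5) # 3 x values
yrange <- seq(0, 3, by = 1)   # 4 y values

# Made-up stand-in for model predictions
f <- function(x, y) x + 10 * y

# The first variable in expand.grid() varies fastest, so y cycles within each x
grid <- expand.grid(y = yrange, x = xrange)

# matrix() fills column by column, so each column holds one x value
# and each row holds one y value
z <- matrix(f(grid$x, grid$y), nrow = length(yrange), ncol = length(xrange))

print(dim(z))  # rows = length(yrange), columns = length(xrange)
print(z[2, 3]) # value at x = xrange[3], y = yrange[2]
```

If the surface looks transposed or plotly reports mismatched dimensions, checking dim(z) against length(yrange) and length(xrange) is a quick first diagnostic.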

Prediction Error Plots

When you are working with very high-dimensional data, it is inconvenient to plot every dimension with your output y. Instead, you can use methods such as prediction error plots, which let you visualize how well your model does compared to the ground truth.

Simple actual vs predicted plot

This example shows you the simplest way to compare the predicted output vs. the actual output. A good model will have most of the scatter dots near the diagonal black line.

library(tidyverse)
library(tidymodels)
library(plotly)
library(ggplot2)

data("iris")

X <- data.frame(Sepal.Width = c(iris$Sepal.Width), Sepal.Length = c(iris$Sepal.Length))
y <- iris$Petal.Width

lm_model <- linear_reg() %>% 
  set_engine('lm') %>% 
  set_mode('regression') %>%
  fit(Petal.Width ~ Sepal.Width + Sepal.Length, data = iris) 

y_pred <- lm_model %>%
  predict(X) 

db = cbind(iris, y_pred)

colnames(db)[4] <- "Ground_truth"
colnames(db)[6] <- "prediction"

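# Endpoints of the diagonal reference line (where prediction equals ground truth)
x0 = min(y)
y0 = min(y)
x1 = max(y)
y1 = max(y)
# Set the point color as a constant outside aes() so the points are actually blue
p1 <- ggplot(db, aes(x= Ground_truth, y= prediction  )) +
  geom_point(color = "blue", show.legend = FALSE) + geom_segment(aes(x = x0, y = y0, xend = x1, yend = y1 ),linetype = 2)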


p1 <-  ggplotly(p1)
p1
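
How tightly the points hug the diagonal can also be summarized numerically. The following sketch computes R squared and RMSE for the same formula using base R's lm() on the built-in iris data; it is an illustrative complement to the plot, and the variable names are assumptions of this example.

```r
data(iris)

# Same formula as above, fitted with base R's lm()
fit <- lm(Petal.Width ~ Sepal.Width + Sepal.Length, data = iris)
pred <- predict(fit, newdata = iris)

# R squared: squared correlation between ground truth and prediction
rsq <- cor(iris$Petal.Width, pred)^2

# RMSE: typical vertical distance of a point from the diagonal
rmse <- sqrt(mean((iris$Petal.Width - pred)^2))

print(c(rsq = rsq, rmse = rmse))
```

An R squared near 1 corresponds to points lying close to the diagonal; for a linear model evaluated on its own training data, this value matches the R squared reported by summary(fit).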

Enhanced prediction error analysis using ggplotly

Add marginal histograms to quickly diagnose any prediction bias your model might have.

library(plotly)
library(ggplot2)
library(tidyverse)
library(tidymodels)
data(iris)

X <- iris %>% select(Sepal.Width, Sepal.Length)
y <- iris %>% select(Petal.Width)

set.seed(0)
iris_split <- initial_split(iris, prop = 3/4)
iris_training <- iris_split %>% 
  training()
iris_test <- iris_split %>% 
  testing()

train_index <- as.integer(rownames(iris_training))
test_index <- as.integer(rownames(iris_test))

iris[train_index,'split'] = 'train'
iris[test_index,'split'] = 'test'

lm_model <- linear_reg() %>% 
  set_engine('lm') %>% 
  set_mode('regression') %>%
  fit(Petal.Width ~ Sepal.Width + Sepal.Length, data = iris_training) 

prediction <- lm_model %>%
  predict(X) 
colnames(prediction) <- c('prediction')
iris = cbind(iris, prediction)

hist_top <- ggplot(iris,aes(x=Petal.Width)) + 
  geom_histogram(data=subset(iris,split == 'train'),fill = "red", alpha = 0.2, bins = 6) +
  geom_histogram(data=subset(iris,split == 'test'),fill = "blue", alpha = 0.2, bins = 6) +
  theme(axis.title.y=element_blank(),axis.text.y=element_blank(),axis.ticks.y=element_blank())
hist_top <- ggplotly(p = hist_top)

scatter <- ggplot(iris, aes(x = Petal.Width, y = prediction, color = split)) +
  geom_point() + 
  geom_smooth(formula=y ~ x, method=lm, se=FALSE) 
scatter <- ggplotly(p = scatter, type = 'scatter')

hist_right <- ggplot(iris,aes(x=prediction)) + 
  geom_histogram(data=subset(iris,split == 'train'),fill = "red", alpha = 0.2, bins = 13) +
  geom_histogram(data=subset(iris,split == 'test'),fill = "blue", alpha = 0.2, bins = 13) +
  theme(axis.title.x=element_blank(),axis.text.x=element_blank(),axis.ticks.x=element_blank())+
  coord_flip()
hist_right <- ggplotly(p = hist_right)

s <- subplot(
  hist_top, 
  plotly_empty(), 
  scatter,
  hist_right,
  nrows = 2, heights = c(0.2, 0.8), widths = c(0.8, 0.2), margin = 0,
  shareX = TRUE, shareY = TRUE, titleX = TRUE, titleY = TRUE
)
layout(s, showlegend = FALSE)

Residual plots

Just like prediction error plots, it's easy to visualize your prediction residuals in just a few lines of code using ggplotly and tidymodels.

library(plotly)
library(ggplot2)
library(tidyverse)
library(tidymodels)

data(iris)

X <- iris %>% select(Sepal.Width, Sepal.Length)
y <- iris %>% select(Petal.Width)

set.seed(0)
iris_split <- initial_split(iris, prop = 3/4)
iris_training <- iris_split %>% 
  training()
iris_test <- iris_split %>% 
  testing()

train_index <- as.integer(rownames(iris_training))
test_index <- as.integer(rownames(iris_test))

iris[train_index,'split'] = 'train'
iris[test_index,'split'] = 'test'

lm_model <- linear_reg() %>% 
  set_engine('lm') %>% 
  set_mode('regression') %>%
  fit(Petal.Width ~ Sepal.Width + Sepal.Length, data = iris_training) 

prediction <- lm_model %>%
  predict(X) 
colnames(prediction) <- c('prediction')
iris = cbind(iris, prediction)
residual <- prediction - iris$Petal.Width
colnames(residual) <- c('residual')
iris = cbind(iris, residual)

scatter <- ggplot(iris, aes(x = prediction, y = residual, color = split)) + 
  geom_point() + 
  geom_smooth(formula=y ~ x, method=lm, se=FALSE) 

scatter <- ggplotly(p = scatter, type = 'scatter')

violin <- iris %>%
  plot_ly(x = ~split, y = ~residual, split = ~split, type = 'violin' )

s <- subplot(
  scatter,
  violin,
  nrows = 1, heights = c(1), widths = c(0.65, 0.35), margin = 0.01,
  shareX = TRUE, shareY = TRUE, titleX = TRUE, titleY = TRUE
)

layout(s, showlegend = FALSE)
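
The violin plot shows the spread of residuals per split; the same comparison can be tabulated. This is a minimal sketch under assumptions: the split uses base R's sample() instead of initial_split(), and residuals keep the sign convention of the code above (prediction minus actual).

```r
data(iris)

set.seed(0)
# Hypothetical 75/25 split with base R; initial_split() plays this role above
train_rows <- sample(nrow(iris), size = floor(0.75 * nrow(iris)))
iris$split <- "test"
iris$split[train_rows] <- "train"

fit <- lm(Petal.Width ~ Sepal.Width + Sepal.Length, data = iris[train_rows, ])

# Same sign convention as above: prediction minus actual
iris$residual <- predict(fit, newdata = iris) - iris$Petal.Width

# Mean and standard deviation of the residuals for each split
mean_resid <- tapply(iris$residual, iris$split, mean)
sd_resid <- tapply(iris$residual, iris$split, sd)
print(rbind(mean = mean_resid, sd = sd_resid))
```

For a least-squares fit with an intercept, the training residuals average to zero by construction, so a test mean far from zero points to systematic over- or under-prediction on unseen data.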

What About Dash?

Dash for R is an open-source framework for building analytical applications, with no JavaScript required, and it is tightly integrated with the Plotly graphing library.

Learn about how to install Dash for R at https://dashr.plot.ly/installation.

Everywhere on this page that you see fig, you can display the same figure in a Dash for R application by passing it to the figure argument of the Graph component from the built-in dashCoreComponents package, like this:

library(plotly)

fig <- plot_ly() 
# fig <- fig %>% add_trace( ... )
# fig <- fig %>% layout( ... ) 

library(dash)
library(dashCoreComponents)
library(dashHtmlComponents)

app <- Dash$new()
app$layout(
    htmlDiv(
        list(
            dccGraph(figure=fig) 
        )
     )
)

app$run_server(debug=TRUE, dev_tools_hot_reload=FALSE)