t-SNE and UMAP projections in R

Visualize t-SNE and UMAP in R with Plotly.




This page shows how to visualize the output of two popular dimensionality reduction techniques: t-distributed stochastic neighbor embedding (t-SNE) and Uniform Manifold Approximation and Projection (UMAP). They are useful whenever you want to visualize data with more than two or three features (i.e. dimensions).

We first show how to visualize data with more than three features using a scatter plot matrix. We then apply dimensionality reduction techniques to obtain 2D/3D representations of the data, and visualize the results with 2D and 3D scatter plots.

Basic t-SNE projections

t-SNE is a popular dimensionality reduction algorithm grounded in probability theory. Simply put, it projects high-dimensional data points (sometimes with hundreds of features) into 2D/3D. It turns pairwise distances into probability distributions over neighbors, both in the original space and in the projection, and then moves the projected points so that the two distributions match as closely as possible, by minimizing the Kullback-Leibler (KL) divergence between them.
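To make the objective concrete, here is a minimal base-R sketch of two quantities from the original t-SNE paper: the Student-t similarities q_ij between embedded points, and the divergence KL(P‖Q) that the optimizer minimizes. It assumes the high-dimensional affinities P are already available as a symmetric matrix with a zero diagonal that sums to 1; computing P (which involves calibrating each point's bandwidth to a target perplexity) is omitted. The helper names tsne_q and kl_divergence are hypothetical and are not part of the tsne package.

```r
# Hypothetical helper: Student-t similarities q_ij between embedded points
# (rows of Y), normalized over all pairs i != j as in the t-SNE paper.
tsne_q <- function(Y) {
  D2 <- as.matrix(dist(Y))^2   # squared Euclidean distances
  num <- 1 / (1 + D2)          # heavy-tailed Student-t kernel
  diag(num) <- 0               # a point is not its own neighbor
  num / sum(num)
}

# Hypothetical helper: KL(P || Q) = sum over i != j of p_ij * log(p_ij / q_ij).
# Terms with p_ij = 0 contribute nothing and are skipped.
kl_divergence <- function(P, Q) {
  keep <- P > 0
  sum(P[keep] * log(P[keep] / Q[keep]))
}

# Usage: three embedded points. P comes from a stretched copy of the layout
# purely for illustration; real t-SNE derives P from the input data.
Y <- matrix(c(0, 0,
              1, 0,
              0, 1), ncol = 2, byrow = TRUE)
Q <- tsne_q(Y)
P <- tsne_q(5 * Y)
kl_divergence(P, Q)   # positive, because the two distributions differ
kl_divergence(Q, Q)   # 0
```

The heavy tail of the Student-t kernel is what lets moderately distant points sit far apart in the projection, which helps clusters stay visually separate.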

Compared to a method like Principal Component Analysis (PCA), t-SNE takes significantly longer to converge, but its projections often reveal cluster structure that a linear projection misses. For example, when projecting the measurements of the Iris flowers, it separates the samples into distinct groups by species.
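For a point of comparison, a PCA projection of the same data takes a single call to base R's prcomp. The sketch below keeps the first two principal components; standardizing each feature with scale. = TRUE is a choice made here, not something this page prescribes. The resulting pca_coords data frame can be plotted with plot_ly exactly like the t-SNE coordinates in the examples on this page.

```r
# Linear baseline: project the four iris measurements onto their first two
# principal components, centering and standardizing each feature first.
features <- subset(iris, select = -c(Species))
pca <- prcomp(features, center = TRUE, scale. = TRUE)

# Keep PC1 and PC2 alongside the species labels for plotting
pca_coords <- data.frame(pca$x[, 1:2], Species = iris$Species)
head(pca_coords)

# Share of the total variance captured by each of the two components
summary(pca)$importance["Proportion of Variance", 1:2]
```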

Visualizing high-dimensional data with splom

First, let's visualize every feature of the Iris dataset, coloring each point by species. We use a scatter plot matrix (splom), which plots each feature against every other feature; this is convenient when a dataset has more than three dimensions.

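library(plotly)
data(iris)
# Shared styling for the extra axes (xaxis2-4, yaxis2-4) that the splom creates
axis <- list(showline = FALSE,
             zeroline = FALSE,
             gridcolor = '#ffff',
             ticklen = 4)
# One splom trace: every pair of the four measurements, colored by species
fig <- iris %>%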
  plot_ly()  %>%  
  add_trace(  
    type = 'splom',  
    dimensions = list( 
      list(label = 'sepal_width',values=~Sepal.Width),  
      list(label = 'sepal_length',values=~Sepal.Length),  
      list(label ='petal_width',values=~Petal.Width),  
      list(label = 'petal_length',values=~Petal.Length)),  
    color = ~Species, colors = c('#636EFA','#EF553B','#00CC96') 
  ) 
fig <- fig %>% 
  layout( 
    legend=list(title=list(text='species')), 
    hovermode='closest', 
    dragmode= 'select', 
    plot_bgcolor='rgba(240,240,240,0.95)', 
    xaxis=list(domain=NULL, showline=F, zeroline=F, gridcolor='#ffff', ticklen=4), 
    yaxis=list(domain=NULL, showline=F, zeroline=F, gridcolor='#ffff', ticklen=4), 
    xaxis2=axis, 
    xaxis3=axis, 
    xaxis4=axis, 
    yaxis2=axis, 
    yaxis3=axis, 
    yaxis4=axis 
  ) 
fig

Project data into 2D with t-SNE and plot_ly

Now, let's use the t-SNE algorithm to project the data shown above into two dimensions. Notice how the points of each species cluster together, apart from the other species.

library(tsne)
library(plotly)
data("iris")

features <- subset(iris, select = -c(Species))

set.seed(0)
# k is the number of output dimensions; initial_dims is the number of
# dimensions kept by the PCA preprocessing step (at most ncol(features))
embedding <- tsne(features, initial_dims = 2, k = 2)
embedding <- data.frame(embedding)   # columns X1, X2
embedding$Species <- iris$Species

fig <- plot_ly(data = embedding, x = ~X1, y = ~X2, type = 'scatter', mode = 'markers', split = ~Species)

fig <- fig %>%
  layout(
    plot_bgcolor = "#e5ecf6"
  )

fig
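To put a number on how well the species separate, one option (our choice here, not part of the original page) is the fraction of points whose nearest neighbor shares their species label. The helper name nn_label_agreement is hypothetical; it uses only base R and works on any numeric coordinates, including the X1 and X2 columns of the t-SNE result, so the same score can be compared before and after projection.

```r
# Hypothetical helper: fraction of points whose nearest neighbor
# (excluding the point itself) carries the same label.
nn_label_agreement <- function(coords, labels) {
  D <- as.matrix(dist(coords))
  diag(D) <- Inf                      # a point is not its own neighbor
  nearest <- apply(D, 1, which.min)   # index of each point's nearest neighbor
  mean(labels[nearest] == labels)
}

# Usage on the raw four-dimensional iris features; passing a data frame
# of 2D embedding coordinates works the same way.
features <- subset(iris, select = -c(Species))
nn_label_agreement(features, iris$Species)
```

A score near 1 means nearby points almost always belong to the same species; it is a rough check that complements, rather than replaces, looking at the plot.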

Project data into 3D with t-SNE and plot_ly

t-SNE can project your data into any number of dimensions; in the tsne package this is set with the k argument. Here, we project the data into 3D and visualize it with a 3D scatter plot.

library(tsne)
library(plotly)
data("iris")

features <- subset(iris, select = -c(Species))

set.seed(0)
# k = 3 requests a three-dimensional embedding
embedding <- tsne(features, initial_dims = 3, k = 3)
embedding <- data.frame(embedding)   # columns X1, X2, X3
embedding$Species <- iris$Species

fig <- plot_ly(data = embedding, x = ~X1, y = ~X2, z = ~X3, color = ~Species,
               colors = c('#636EFA','#EF553B','#00CC96')) %>%
  add_markers(marker = list(size = 8)) %>%
  layout(
    # the axes of a 3D plot are styled through the scene object
    scene = list(
      bgcolor = "#e5ecf6",
      xaxis = list(zerolinecolor = "#ffff", zerolinewidth = 2, gridcolor = '#ffff'),
      yaxis = list(zerolinecolor = "#ffff", zerolinewidth = 2, gridcolor = '#ffff')))
fig

Projections with UMAP

Just like t-SNE, UMAP is a dimensionality reduction technique designed specifically for visualizing complex data in low dimensions (2D or 3D). As the number of data points increases, UMAP becomes more time-efficient than t-SNE.
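One way to see the family resemblance: in the low-dimensional space, UMAP scores a pair of points at distance d with the curve 1 / (1 + a * d^(2b)), where a and b are fitted from the min_dist and spread parameters. With a = b = 1 this is exactly the Student-t kernel 1 / (1 + d^2) that t-SNE uses, although t-SNE then normalizes over all pairs and UMAP does not. The sketch below just evaluates the curve; the function name umap_similarity and the values a = 1.5, b = 0.9 are illustrative assumptions, not the umap package's defaults.

```r
# Low-dimensional similarity curve used by UMAP: 1 / (1 + a * d^(2b)).
umap_similarity <- function(d, a, b) {
  1 / (1 + a * d^(2 * b))
}

d <- c(0, 0.5, 1, 2, 4)

# a = b = 1 reproduces the unnormalized t-SNE kernel: 1, 0.8, 0.5, 0.2, 1/17
umap_similarity(d, a = 1, b = 1)

# Illustrative a and b values (not fitted from any real min_dist/spread)
umap_similarity(d, a = 1.5, b = 0.9)
```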

In the example below, we compute 2D and 3D UMAP embeddings of the Iris data with the umap package.

library(plotly)
library(umap)
iris.data <- iris[, grep("Sepal|Petal", colnames(iris))]
iris.labels <- iris[, "Species"]

# 2D embedding; n_components and random_state override umap.defaults
iris.umap <- umap(iris.data, n_components = 2, random_state = 15)
coords <- data.frame(iris.umap[["layout"]])   # columns X1, X2
coords$Species <- iris.labels

fig <- plot_ly(coords, x = ~X1, y = ~X2, color = ~Species, colors = c('#636EFA','#EF553B','#00CC96'), type = 'scatter', mode = 'markers') %>%
  layout(
    plot_bgcolor = "#e5ecf6",
    legend = list(title = list(text = 'species')),
    xaxis = list(title = "0"),
    yaxis = list(title = "1"))

# 3D embedding of the same data
iris.umap <- umap(iris.data, n_components = 3, random_state = 15)
coords3d <- data.frame(iris.umap[["layout"]])   # columns X1, X2, X3
coords3d$Species <- iris.labels

fig2 <- plot_ly(coords3d, x = ~X1, y = ~X2, z = ~X3, color = ~Species, colors = c('#636EFA','#EF553B','#00CC96'))
fig2 <- fig2 %>% add_markers()
fig2 <- fig2 %>% layout(scene = list(xaxis = list(title = '0'),
                                     yaxis = list(title = '1'),
                                     zaxis = list(title = '2')))

fig
fig2

Visualizing image datasets

In the following example, we show how to visualize large image datasets using UMAP.

Although there are over 1,000 data points and many more dimensions than in the previous example, UMAP still runs quickly. This is because UMAP is designed for speed, both in its theoretical formulation and in its implementation. Learn more in this comparison post.
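Each image in a dataset like this is stored as one row of pixel features, which is where the extra dimensions come from: a 28 × 28 MNIST image, for example, becomes 784 columns. The sketch below, assuming 28 × 28 images stored in row-major order (pixel0 ... pixel783, as in MNIST), turns one such row back into an image matrix, for instance to inspect a point picked from the UMAP plot. The helper name row_to_image is hypothetical.

```r
# Hypothetical helper: reshape one row of 784 pixel values into a 28 x 28
# image, assuming row-major pixel order (pixel0 ... pixel783) as in MNIST.
row_to_image <- function(pixels, side = 28) {
  stopifnot(length(pixels) == side * side)
  matrix(as.numeric(pixels), nrow = side, ncol = side, byrow = TRUE)
}

# Round trip on a synthetic row: flattening the image row by row
# recovers the original vector.
pixels <- seq_len(784) - 1
img <- row_to_image(pixels)
dim(img)                          # 28 28
all(as.vector(t(img)) == pixels)  # TRUE
```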

library(rsvd)
library(plotly)
library(umap)
data('digits')
digits.data <- digits[, grep("pixel", colnames(digits))]
digits.labels <- digits[, "label"]
# n_neighbors sets the size of the local neighborhood UMAP considers
digits.umap <- umap(digits.data, n_components = 2, n_neighbors = 10)
coords <- data.frame(digits.umap[["layout"]])
colnames(coords) <- c('X1', 'X2')
coords$label <- digits.labels

fig <- plot_ly(coords, x = ~X1, y = ~X2, split = ~label, type = 'scatter', mode = 'markers') %>%
  layout(
    plot_bgcolor = "#e5ecf6",
    legend = list(title = list(text = 'digit')),
    xaxis = list(title = "0"),
    yaxis = list(title = "1"))
fig

Reference

Plotly figures:

  • https://plotly.com/r/line-and-scatter/

  • https://plotly.com/r/3d-scatter-plots/

  • https://plotly.com/r/splom/

Details about algorithms:

  • UMAP library: https://umap-learn.readthedocs.io/en/latest/

  • t-SNE User guide: https://cran.r-project.org/web/packages/tsne/tsne.pdf

  • t-SNE paper: https://www.jmlr.org/papers/volume9/vandermaaten08a/vandermaaten08a.pdf

  • MNIST: http://yann.lecun.com/exdb/mnist/

What About Dash?

Dash for R is an open-source framework for building analytical applications, with no JavaScript required, and it is tightly integrated with the Plotly graphing library.

Learn about how to install Dash for R at https://dashr.plot.ly/installation.

Wherever you see fig on this page, you can display the same figure in a Dash for R application by passing it to the figure argument of the Graph component from the dashCoreComponents package, like this:

library(plotly)

fig <- plot_ly() 
# fig <- fig %>% add_trace( ... )
# fig <- fig %>% layout( ... ) 

library(dash)
library(dashCoreComponents)
library(dashHtmlComponents)

app <- Dash$new()
app$layout(
    htmlDiv(
        list(
            dccGraph(figure=fig) 
        )
     )
)

app$run_server(debug=TRUE, dev_tools_hot_reload=FALSE)