Machine Learning Demystified

What is this course?

  • Cut through mystery and hype surrounding machine learning (ML)
  • Introduce you to the basic principles behind machine learning
  • “Bust” some of the common myths about machine learning
  • Walk through two different machine learning tasks in R

Minimal prerequisites

  • Some experience with statistical analysis
  • A little bit of exposure to R

Advanced prerequisites

  • Ability to write model formulas in R
  • Familiarity with tidyverse for R data manipulation and plotting

How to follow the course

  • Slides and text version of lessons are online
  • Fill in code in the worksheet (replace ... with code)
  • You can always copy and paste code from the text version of the lesson if you fall behind

Conceptual learning objectives

At the end of this course, you will understand …

  • What machine learning is and how it is similar to and different from statistics
  • Why the common myths about machine learning are not true
  • What the difference between prediction and inference is as a goal
  • What training and testing data are
  • What a loss function is
  • What supervised and unsupervised learning are
  • What the basic steps of a machine learning workflow are

Practical skills

At the end of this course, you will be able to …

  • Explore and pre-process data for machine learning models
  • Fit a random forest model in R for a classification task
  • Fit a lasso regression in R for a regression task
  • Cross-validate ML models
  • Use the R package caret to fit many kinds of ML models using the same code syntax

What is machine learning?

Image credit Shutterstock

Machine learning myths: Busted!


  • Myth 1. Machine learning is way more powerful than statistics
  • Myth 2. Machine learning is too fancy/new/different for me to learn
  • Myth 3. My data aren’t big enough to do machine learning
  • Myth 4. I don’t need machine learning

Myth 1. Machine learning is way more powerful than statistics — BUSTED!

  • Machine learning is the (over)hyped hot new thing
  • Statistics has been around a long time and its limitations are well known
  • But ML models are essentially statistical models, and all models are only as good as the assumptions they make and the data you put into them

original by sandserif comics

Myth 2. Machine learning is too fancy/new/different for me to learn — BUSTED!

  • Yes, some ML models are very complex and require a lot of theoretical background to understand
  • But a lot of software tools exist to make it easy to fit ML models … almost too easy in fact
  • Having a modest level of understanding of the models you work with is a good thing

Myth 3. My data aren’t big enough to do machine learning — BUSTED!

  • Big Data: an even more overhyped buzzword than machine learning?
  • Yes, some ARS researchers have truly huge datasets that need ML techniques
  • But ML can be a useful tool even with relatively small “hand-collected” datasets

Image credit Silicon Angle

Myth 4. I don’t need machine learning — BUSTED!

  • ML is not just hot air and hype
  • ML techniques are useful across all scientific fields
  • The focus on predictive modeling is changing the way all statistics and all research are done
  • All scientists should have basic ML literacy

But what is machine learning, anyway?

  • Machine learning: any job you give a machine (computer) to do that it gets better at doing as it receives more data
  • ML models, including statistical models, are machines that take data as input and spit out predictions as output
  • Ideally, the more data you have, the better your model will be at its job of making predictions

What jobs do machine learning models do?

Regression and classification

  • Classification: predicting which discrete category something belongs to
  • Regression: predicting a continuous outcome variable

Supervised versus unsupervised classification


  • Supervised: you know which category each of your data points belongs to, and you try to find patterns in the data that predict which category new data points belong to
  • Unsupervised: you are trying to find natural groups or clusters in your data without knowing beforehand which group each data point belongs to

Supervised versus unsupervised learning: examples

  • Supervised: a dataset of patients, some with heart disease and some without; try to predict heart disease from clinical measurements on the patients
  • Unsupervised: Google News grouping together news articles that are about the same news event
  • Supervised learning is great but requires higher-quality labeled data, so large-scale projects may need an unsupervised approach (a short sketch contrasting the two follows)
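  • As a concrete illustration only (not part of this lesson's datasets), here is a minimal sketch using R's built-in iris data: the supervised model is shown each flower's known species, while kmeans() looks for clusters without ever seeing the labels
library(randomForest)

set.seed(1)

# Supervised: the species of each flower is known, and the model learns to predict it
supervised_fit <- randomForest(Species ~ ., data = iris)

# Unsupervised: ignore the species labels and look for three natural clusters
unsupervised_fit <- kmeans(scale(iris[, 1:4]), centers = 3)

# How well do the discovered clusters line up with the true species?
table(iris$Species, unsupervised_fit$cluster)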

What kind of machine learning is going on at ARS?

Image credit 2016CIAT/Neil Palmer
  • Genomic selection for plant and livestock breeding
  • Identifying weeds from drone imagery
  • Exploring the effect of diet on gut bacterial microbiome composition

What’s the big deal with prediction?

  • Inference: finding patterns in data and using them to understand processes going on in the world
  • Prediction: finding patterns in data and correctly reproducing them with a model
  • For prediction, how is more important than why

The importance of causal associations

Spurious correlation?

  • An ML model that incorporates true causal relationships between variables, and not just chance associations, is more robust.
  • Take shark attacks and ice cream sales as an example
  • The model might be predictive as long as the confounding variable (warm weather) continues to drive both, but if that relationship changes, the model will fail
  • Better to “use your noggin” and try to model true causal relationships

What’s a loss function?

  • Loss function: A function that usually represents some form of prediction error; when you fit an ML model you want to minimize it
  • In a linear regression it’s the sum of squared residuals (distance from data points to the regression line)
  • ML models sometimes use mean squared error (MSE) as the loss function; minimizing MSE gives the same fit as minimizing the sum of squared residuals in traditional linear regression
  • Mean absolute error (MAE) is another option: the mean of the absolute values of the distances from the data points to the fitted line. It is less sensitive to outliers (see the short sketch after this list)
  • Classification tasks use different loss functions
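  • As a minimal sketch, here is how MSE and MAE could be computed in R from a few made-up observed and predicted values (hypothetical numbers, not from this lesson's datasets)
# Toy example: observed values and a model's predictions (made-up numbers)
observed  <- c(2.1, 3.4, 5.0, 7.2, 9.8)
predicted <- c(2.5, 3.0, 5.5, 6.9, 12.0)

residuals <- observed - predicted

mse <- mean(residuals^2)       # mean squared error: penalizes large errors heavily
mae <- mean(abs(residuals))    # mean absolute error: less sensitive to outliers

c(MSE = mse, MAE = mae)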

What is regularization?

Regularization balances between underfitting and overfitting

  • Overfitting: Model fits too closely to data and its predictions reproduce random noise
  • Underfitting: Model doesn’t fit data closely enough and its predictions miss true patterns
  • Regularization: Anything that makes a model less complex and more general
    • Includes things like variable selection, model selection, random effects

Balance overfitting and underfitting

  • Choose ideal amount of regularization to balance between overfitting and underfitting
  • Fit the training data a little less closely to get better performance making predictions on new data
  • ML loss functions include regularization parameters that determine how much regularization is applied; tuning finds the optimal value (a generic form of a penalized loss is sketched below)
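  • Schematically (an illustrative generic form, not the exact loss of any one model in this lesson), the quantity being minimized is the prediction error plus a penalty on model complexity, with the regularization parameter \(\lambda\) controlling the balance:

\[
\text{penalized loss} = \text{prediction error} + \lambda \times (\text{penalty on model complexity})
\]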

How to fit a machine learning model

Image credit Cold Spring Harbor Lab

Training and test sets

  • Training set: data used to fit the model
  • Test set: data used to evaluate the model
  • Ideally they are as independent from one another as possible
  • Data leakage: Occurs when the training and test sets are not independent; random noise in the training set is correlated with random noise in the test set
  • There is a limit to how different the training and test set should be
  • Transferability: The ability of a model to produce accurate predictions in new contexts
  • Get data to test your model that is independent from your training data but still within the domain you want to make predictions in

Training and testing split

Internal validation

  • If no independent test set is available, you have to set aside a portion of your data as the test set
  • The model never “sees” the test data when estimating parameters; the test set is used only to evaluate its predictions
  • Even though it “wastes” data, it is good to reserve 10% to 20% of your data to test the model
  • Select the test sample to minimize data leakage (for example, a stratified random sample by category and/or by block); a minimal split sketch follows
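  • A minimal base-R sketch of a simple 90/10 holdout split, assuming a data frame called mydata (hypothetical name); the demos below use caret’s createDataPartition(), which additionally stratifies the split by class
# Simple random 90/10 holdout split (hypothetical data frame `mydata`)
set.seed(123)                                   # make the split reproducible
n <- nrow(mydata)
test_rows <- sample(n, size = round(0.1 * n))   # pick ~10% of rows at random

test_set  <- mydata[test_rows, ]                # held-out test set
train_set <- mydata[-test_rows, ]               # remaining 90% for training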

Data visualization and pre-processing

  • All data analysis and modeling should start with looking at the data!
  • Erroneous or invalid data points should be removed
  • Don’t necessarily remove “outliers” because you won’t be able to do that later when making predictions with the model

Removing useless features

  • Feature: a term used in the ML literature for predictor or x variable
  • Before fitting a model, remove features that don’t carry much variation
    • Features with no variance or very little variance (for example, a column with 998 zeros and 2 ones)
    • Features that are highly correlated with other features
  • Those features contribute nothing to prediction performance and just slow down computation (a caret-based sketch for flagging them follows this list)
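  • Both kinds of uninformative features can be flagged with helper functions from the caret package; a minimal sketch, assuming a data frame of predictors called predictors_df (hypothetical name; findCorrelation() is also used in Demo 1 below, nearZeroVar() is not)
library(caret)

# Flag features with zero or near-zero variance (e.g., a column of 998 zeros and 2 ones)
nzv_idx  <- nearZeroVar(predictors_df)
nzv_cols <- names(predictors_df)[nzv_idx]

# Flag features highly correlated (|r| > 0.95) with at least one other feature
corr_cols <- findCorrelation(cor(predictors_df), cutoff = 0.95, names = TRUE)

# Drop both sets of columns before fitting the model
drop_cols <- union(nzv_cols, corr_cols)
predictors_clean <- predictors_df[, setdiff(names(predictors_df), drop_cols)]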

Standardizing variables

  • Variables/features usually need to be transformed to a common scale
    • Often a z-transformation is used
  • Otherwise the model will give more weight to variables that happen to have larger numerical values simply because of their units
  • To avoid data leakage, estimate the standardization (means and standard deviations) from the training set only, not from the pooled data, and apply the same transformation to the test set (see the sketch below)
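  • A minimal base-R sketch of a z-transformation, assuming numeric training and test vectors x_train and x_test (hypothetical names); in the demos below, caret’s preProcess = c('center', 'scale') handles this automatically
# z-transformation: subtract the mean and divide by the standard deviation
train_mean <- mean(x_train)
train_sd   <- sd(x_train)

x_train_z <- (x_train - train_mean) / train_sd

# Scale the test set with the training-set mean and SD so that no test-set
# information leaks into the values used to fit the model
x_test_z <- (x_test - train_mean) / train_sd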

Tuning the model

  • We want to estimate parameters on our training set that will maximize prediction performance on the test set
  • Some parameters capture the effects of the different predictor variables on the outcome variable
  • Tuning parameters, or hyperparameters: special parameters that do things like set the balance between overfitting and underfitting

Image credit Rebecca Wilson

Cross-validation

  • Cross-validation: Repeatedly splitting a dataset into different training and test sets, fitting the model on a different training set each time, until every data point has been included in a test set exactly once
  • Split the dataset into k equally sized folds; each fold takes a turn as the hold-out test set
  • Repeat the cross-validation process for a lot of other sets of tuning parameters
  • The set of tuning parameters that minimizes your loss function (has the best prediction performance) is the one you will use to fit the final model

Cross-validation

Cross-validation, continued

  • It is a good idea to repeat cross-validation multiple times because results may depend on how the folds are split
  • If you have a blocking structure in your data, make sure that observations from the same block all end up in the same fold, or you will get overinflated estimates of model performance (a caret-based sketch of both ideas follows)
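  • Both ideas map onto options in the caret package; a minimal sketch with illustrative settings (not used verbatim in the demos below; mydata$block is a hypothetical blocking variable)
library(caret)

# Repeat 5-fold cross-validation 10 times with different random fold assignments
repeated_cv <- trainControl(method = 'repeatedcv', number = 5, repeats = 10)

# Keep blocks together: all rows from the same block go into the same fold
block_folds <- groupKFold(mydata$block, k = 5)
grouped_cv  <- trainControl(method = 'cv', index = block_folds)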

Fitting and validating the final model

  • Fit a final model to the entire training set using optimal tuning parameters selected by cross-validation
  • Now we can see how well the model does on data it’s never seen before!
  • Make predictions on the test set that was originally split off from the training set back at the start of this process, and evaluate performance
  • There is no universal threshold of whether performance is good enough; it is a practical question

Model tuning and validation workflow

Assessing variable importance

  • ML models can be used for inference as well
  • Different methods exist to quantify how important each feature is for predicting the outcome
  • Not a main focus of this intro lesson

Do it all over again

  • You can fit different ML models on the same training set and use the one that performs best on the test set
  • As long as you use regularization to prevent overfitting
  • More complicated is not always better!
  • The “best” model is one that gives acceptable prediction performance and whose behavior you can understand

Why don’t people use this kind of workflow in conventional statistics?

  • Good question! They should!
  • Overfitting is more of a problem if your goal is prediction instead of understanding a natural process, but it’s always a problem
  • Lots of publications are based on mining data and making conclusions based on patterns that are really random noise
  • Cross-validation and external validation are excellent tests of a model’s performance and can help with both prediction and inference

A note about software

Image credit Math is in the Air

  • This lesson uses R
  • There is a perception that Python is best for ML
  • But R, Python, and SAS all have good libraries for ML

A note of caution

Image credit Sourav Bose
  • Lots of software platforms claim to make ML easy: data goes in, result comes out!
  • It is good to make models accessible without requiring much coding knowledge
  • But it can be dangerous to get results without any understanding of how you got them
  • No real substitute for looking at the data and customizing the model for your particular situation
  • R package tidymodels is used in many intro ML tutorials, but not this one

A final big picture note

  • Machine learning models use regularization penalties to avoid overfitting to your dataset and make good general predictions
  • Mixed models use random effects/shrinkage to avoid overfitting to your dataset and make good general predictions
  • Bayesian models use skeptical priors to avoid overfitting to your dataset and make good general predictions
  • They’re all good tools to use; the more tools you know how to use, the better

Demo 1: Date variety classification

Image credit Britannica

Introducing the example data

  • From Kaggle, a site that hosts lots of free example datasets for predictive modeling competitions, including many ag datasets
  • date fruit classification dataset provided by Murat Koklu
  • Photos were taken of 898 total dates of seven different varieties
  • Image analysis software used to take 34 measurements from each photo (one row of data = one fruit)
  • Class column indicates variety

Demo 1 workflow

  • Import the dataset and visualize it
  • Remove erroneous data points
  • Remove highly correlated predictor variables
  • Standardize the predictor variables
  • Split the dataset into training and testing sets
  • Fit a random forest model, using five-fold cross-validation, for each value of a tuning parameter
  • Identify the tuning parameter that maximizes accuracy on the hold-out folds
  • Fit the final model on the full training set
  • Validate the model by calculating accuracy on the testing set

Load packages and import data

  • Load needed R packages
library(tidyverse)
library(randomForest)
library(caret)
library(ggcorrplot)
library(patchwork)
  • Import the data from CSV and convert Class to a factor variable
date_fruit <- read_csv('https://usda-ree-ars.github.io/SEAStats/machine_learning_demystified/datasets/date_fruit.csv') %>% 
  mutate(Class = factor(Class))

Inspect the data

  • The data are reasonably well balanced: there are enough fruits of each variety to represent all of them in both the training and test sets
dim(date_fruit)
table(date_fruit$Class)

Plot the data

  • Plot each of the predictor variables as boxplots grouped by Class
  • Patterns by variety? Shape of data distribution? Weird data points?
predictor_variables <- names(date_fruit)[!names(date_fruit) %in% 'Class']

date_long <- pivot_longer(date_fruit, cols = all_of(predictor_variables), names_to = 'variable')

ggplot(date_long, aes(x = Class, y = value)) +
  geom_boxplot(aes(fill = Class)) +
  facet_wrap(~ variable, scales = 'free_y') +
  theme_bw() +
  theme(legend.position = 'none', axis.text.x = element_text(angle = 45, hjust = 1))

Remove erroneous values

  • Aspect ratio (length to width ratio) should not be >500
  • Simplest in this case to remove the entire row containing the erroneous value
date_fruit <- filter(date_fruit, ASPECT_RATIO < 500)

nrow(date_fruit)

Inspect correlated features

  • Heatmap to look at the Pearson correlation between each pair of variables
  • Some of the correlations are pretty close to ± 1
date_corr <- cor(date_fruit[, predictor_variables])

ggcorrplot(date_corr, type = 'lower') 

Remove highly correlated features

  • findCorrelation() in caret package identifies any variables correlated with at least one other variable above a certain threshold
  • We’re using \(|r| > 0.95\)
correlated_variables <- findCorrelation(date_corr, cutoff = 0.95, names = TRUE)

date_fruit <- date_fruit %>% select(-all_of(correlated_variables))

Training-test split

  • Use 90% of the data for training the model, set aside 10% as the test set
  • caret function createDataPartition() uses stratified random sampling to balance training and test sets
set.seed(1)
trainIndex <- createDataPartition(date_fruit$Class, p = .9, list = FALSE)
date_train <- as.data.frame(date_fruit[trainIndex, ])
date_test <- as.data.frame(date_fruit[-trainIndex, ])

Random forest

  • We’re going to fit a random forest model to classify the dates into variety — what’s that?!?
  • Randomly samples the data with replacement many different times and fits a decision tree to each random sample
    • But what’s a decision tree?!?

Decision tree

  • Decision tree: A way of sequentially splitting up a dataset with a series of rules that ultimately lead to a predicted class or value (a minimal single-tree sketch follows)
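  • As an illustration (the rpart package is not used elsewhere in this lesson), a minimal sketch fitting a single classification tree to the date_fruit data imported earlier in this demo
library(rpart)

# Fit one classification tree predicting variety (Class) from all other columns
single_tree <- rpart(Class ~ ., data = date_fruit)

# Print the splitting rules and draw the tree
print(single_tree)
plot(single_tree, margin = 0.1)
text(single_tree, cex = 0.7)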

Decision tree

Random forest, explained

  • Random forest grows a “forest” of decision trees and uses a majority vote of the trees to make the final classification
  • For each random sample of the data, some data points are selected (in the bag) and others are not (out of the bag)
  • The algorithm assesses each tree’s prediction performance on the data points that are “out of the bag” for that tree
  • This provides some regularization, but we are also going to use cross-validation

Training the random forest model with caret

  • caret package used to train ML models
  • train() is the main workhorse function
  • First specify the model training method using trainControl()
cv_spec <- trainControl(method = 'cv', number = 5)

Using train()

date_rf_train <- train(
  form = Class ~ .,
  data = date_train,
  method = 'rf',
  preProcess = c('center', 'scale'),
  tuneGrid = data.frame(mtry = 1:10),
  metric = 'Accuracy',
  trControl = cv_spec,
  ntree = 500,
  importance = TRUE
)

Arguments to train()

  • form: the model formula
    • Class ~ . means to use all other variables to predict Class
  • data: data frame containing variables
  • method = 'rf': fit a random forest model using the R package randomForest in the back end.
  • preProcess = c('center', 'scale'): z-transform all variables

Arguments to train(), continued

  • tuneGrid: data frame with all combinations of tuning parameters
    • Here, we are only going to tune one parameter, mtry, the number of variables randomly selected each time a new split point is determined.
    • Higher mtry means a closer fit to the data at the risk of overfitting
  • metric = 'Accuracy': metric of model performance (i.e. the loss function) is 'Accuracy', the overall proportion of correct classifications
  • trControl: pass the model training specification we created using trainControl()
  • ntree is an argument to the randomForest() function, the number of trees generated for each random forest model
    • The higher the better, but takes longer to compute
  • importance = TRUE: variable importance values will be calculated for us to look at later

Training the model

  • When you run train():
    • Pre-processes data
    • Fits five cross-validation folds for each of the ten values of mtry we provide
    • Selects the optimal value of mtry that maximizes prediction accuracy of the hold-out folds
    • Fits the model with the optimal value of mtry on the full training set!

Examining the results

  • plot() shows the accuracy for each value of mtry. What do you think?
plot(date_rf_train)
  • Typing the name of the model into the console provides summary info
date_rf_train

Confusion matrix

  • Confusion matrix: A matrix cross-tabulating a classification model’s predicted classes against the true classes
  • confusionMatrix() returns one with the model’s predicted classes down the rows and true classes across the columns
  • A 100% accurate model would have zeros everywhere except the diagonal
confusionMatrix(date_rf_train)

Validating the model

  • The moment of truth: testing the model on a new dataset it’s never seen before
  • Here we use the 10% of the original dataset that we split off before tuning the model
  • predict() function has a method for caret models
  • Data passed to newdata argument must have a column for all predictor variables that are in the model
date_predict_test <- predict(date_rf_train, newdata = date_test)

confusionMatrix(date_predict_test, date_test$Class)

Variable importance

  • varImp() calculates variable importance scores for many different types of models
  • Mean decrease in accuracy (for regression models, percent increase in MSE) when the values of a variable are randomly permuted
  • type = 1 indicates to give overall scores, not separately by class
  • scale = FALSE gives us the raw importance values, otherwise they are scaled to a maximum of 100
varImp(date_rf_train, type = 1, scale = FALSE)

Fitting other machine learning models

  • train() can be used to fit other ML models with very similar arguments
  • Here we fit a type of model called support vector machine (SVM)
  • Only differences are method and tuneGrid arguments
  • C is the name of the parameter that determines balance between overfitting and underfitting
    • Small C values do not fit the data quite as closely, protecting against overfitting

Using train() to cross-validate SVM model

date_svm_train <- train(
  form = Class ~ .,
  data = date_train,
  method = 'svmLinear',
  preProcess = c('center', 'scale'),
  tuneGrid = data.frame(C = seq(0.02, 1, by = 0.02)),
  metric = 'Accuracy',
  trControl = cv_spec
)

Evaluating SVM model

plot(date_svm_train)

date_svm_predict_test <- predict(date_svm_train, newdata = date_test)

confusionMatrix(date_svm_predict_test, date_test$Class)

Demo 2: Sugarcane yield prediction

Image credit Hannah Penn

Introducing the example data

  • Predicting sugarcane yield (tons of sugar per hectare or TSH); regression not classification this time
  • Can we use imagery from a hyperspectral camera taken in the field to predict yield?
  • Real ARS data (with noise added)
  • Drones flown over crop and reflectance at many different wavelengths of light were recorded
  • Different vegetation indices calculated from combinations of the reflectance values
  • Rep column indicates which experimental block each row of data comes from
  • Information on variety and treatment excluded for this demo

Lasso regression

  • We’re using a form of linear regression called lasso
  • Similar to multiple linear regression, but with a regularization parameter \(\lambda\) (lambda) added to the loss function
  • This penalty causes the regression coefficients to be shrunk toward zero, or even drop out of the model entirely
  • The larger the \(\lambda\), the stricter the regularization and the more coefficients drop to zero
  • This results in a simpler model that doesn’t fit the training data quite as closely as standard linear regression, but is less overfit and makes better predictions on test data (the objective being minimized is written out below)
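  • Written out (an illustrative form of the standard lasso objective with \(n\) observations and \(p\) standardized predictors; with \(\lambda = 0\) this reduces to ordinary least squares):

\[
\min_{\beta_0, \beta} \; \sum_{i=1}^{n} \left( y_i - \beta_0 - \sum_{j=1}^{p} \beta_j x_{ij} \right)^2 + \lambda \sum_{j=1}^{p} |\beta_j|
\]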

Import and explore the data

  • Plot a smoothing trend of the outcome variable versus each predictor
sugarcane <- read_csv('https://usda-ree-ars.github.io/SEAStats/machine_learning_demystified/datasets/sugarcaneyield.csv')

sugarcane_predictor_variables <- setdiff(names(sugarcane), c('TSH', 'Rep'))

scatter_plots <- map(sugarcane_predictor_variables, ~ ggplot(sugarcane, aes(x = !!sym(.), y = TSH)) + 
                       geom_point() + 
                       geom_smooth(method = 'gam') +
                       theme_bw())

wrap_plots(scatter_plots, ncol = 4)

Inspect correlated features

  • Many are highly correlated, especially vegetation indices
  • Instead of removing them, we will see how lasso deals with them
sugarcane_corr <- cor(sugarcane[, sugarcane_predictor_variables])

ggcorrplot(sugarcane_corr, type = 'lower') 

Set up pre-defined cross-validation folds

  • No initial training-test split; use the whole dataset for model fitting
  • Manually specify the experimental blocks as cross-validation folds, instead of assigning data points at random to folds
  • map() function creates list of vectors of the row numbers that belong to each block
cv_folds <- map(1:6, function(x) which(sugarcane$Rep == x))
cv_spec_sugarcane <- trainControl(method = 'cv', number = 6, indexOut = cv_folds)

Train the model with cross-validation

  • train() used as before
  • Instead of a model formula with form argument, supply x (data frame) and y (vector)
  • glmnet is used as method
  • Center and scale the predictors
  • tuneGrid includes a range of lambda values increasing by factors of ten and sets alpha = 1 (a pure lasso penalty)
  • metric = 'RMSE', or root mean squared error; lower is better
  • Pass pre-defined folds to the trControl argument

Using train() to cross-validate lasso model

set.seed(3)

tsh_lasso_fit <- train(
  x = sugarcane %>% select(all_of(sugarcane_predictor_variables)) %>% as.data.frame,
  y = sugarcane$TSH,
  method = 'glmnet',
  preProcess = c('center', 'scale'),
  tuneGrid = expand.grid(
    alpha = 1,
    lambda = c(0.00001, 0.0001, 0.001, 0.01, 0.1)
  ),
  metric = 'RMSE',
  trControl = cv_spec_sugarcane
)

Evaluating the model

  • Plotting the model object shows what value of \(\lambda\) minimized RMSE in cross-validation
  • Intermediate \(\lambda\) is best but there is very little variation in RMSE between best and worst
plot(tsh_lasso_fit)
  • bestTune element of fitted model object shows optimal tuning parameters
tsh_lasso_fit$bestTune
  • Typing model name into console gives prediction performance for all tested \(\lambda\)
tsh_lasso_fit

Coefficients

  • tsh_lasso_fit contains an element called finalModel, the final model fit to the entire training set
  • coef(tsh_lasso_fit$finalModel) returns a large matrix of coefficients
  • This code finds the index of the lambda value in the final model that is closest to the optimal lambda, then extracts that column index from the matrix
  • Compare the shrunken lasso coefficients with the simple linear regression coefficients. What do you observe?
lambda_use <- min(tsh_lasso_fit$finalModel$lambda[tsh_lasso_fit$finalModel$lambda >= tsh_lasso_fit$bestTune$lambda])
position <- which(tsh_lasso_fit$finalModel$lambda == lambda_use)
best_coefs <- coef(tsh_lasso_fit$finalModel)[, position]

lm_coefs <- lm(TSH ~ ., data = sugarcane %>% select(all_of(c('TSH', sugarcane_predictor_variables))))$coefficients

data.frame(lasso_coefficient = round(best_coefs, 3),
           unshrunk_coefficient = round(lm_coefs, 3))

Coefficients: interpretation

  • No coefficients have been shrunk all the way to zero, but some have been shrunk pretty dramatically
  • Look at the ones with highest absolute value
  • Overall the prediction performance of this model is not good enough to deploy in the field
  • Other models might improve performance slightly but there is no silver bullet

Further reading