Introduction

What is this course? It’s my attempt to cut through some of the mystery and hype surrounding machine learning (ML). For some time now, machine learning has been a big buzzword, not only in our scientific research fields but across all of popular culture. But what exactly is machine learning? I’ll spend a little bit of time introducing you to the basic principles behind machine learning, “bust” some of the common myths about machine learning, and show you how you can use machine learning in your research. After that gentle introduction, we’ll walk through two different machine learning tasks using the R statistical programming language.

Download the worksheet for this lesson here.

What should you know coming into this course?

This course is designed for practicing researchers who have some experience with statistical analysis. It would be great if you already have a little bit of exposure to R, because the walkthroughs in this lesson use some common R packages for manipulating data (dplyr) and plotting data (ggplot2), and write model formulas using the basic R model formula syntax. But if you don’t know any of that, don’t sweat it: you can just follow along by copying the code.

This course is also designed for people with little to no knowledge about machine learning, who are just interested in learning what it is. If you came here expecting to learn how to tweak the tuning parameters of a multilayer perceptron for image semantic segmentation, this is probably too basic for you.

What will you learn from this course?

Conceptual learning objectives

At the end of this course, you will understand …

  • What machine learning is and how it is similar to and different from statistics
  • Why the common myths about machine learning are not true
  • What the difference is between prediction and inference as modeling goals
  • What training and testing data are
  • What a loss function is
  • What supervised and unsupervised learning are
  • What the basic steps of a machine learning workflow are

Practical skills

At the end of this course, you will be able to …

  • Explore and pre-process data for machine learning models
  • Fit a random forest model in R for a classification task
  • Fit a lasso regression in R for a regression task
  • Cross-validate ML models
  • Use the R package caret to fit many kinds of ML models using the same code syntax

What is machine learning?

Machine learning myths

Before we get into too much technical detail, I want to start by discussing some common myths about machine learning, both good and bad.

Myth 1. Machine learning is way more powerful than statistics. Machine learning has been hyped, some might say overhyped, for some years now, as the hot new thing. In contrast, statistics has been around for a while, and its limitations are well known. In our culture with its short collective attention span, we tend to gravitate toward shiny new things and think “old = bad, new = good.” So you might think that machine learning and artificial intelligence are making classic inferential statistics obsolete. The counterargument is that machine learning models are essentially statistical models. They suffer from the same limitations as all statistical models: they are only as good as the data that you put into them. If there are biases or confounding influences in your data that you don’t account for, they will be reflected in poor model performance. If you try to extend or generalize a model too far beyond the data you used to fit it, you’re going to get poor model performance. That’s true for any kind of model.

original by sandserif comics


Myth 2. Machine learning is too fancy/new/different for me to learn. Some machine learning models truly are complex and require a lot of theoretical background to fully understand. That said, there are a lot of software tools that make it easy to fit machine learning models … in fact some almost make it too easy. It’s a good idea to at least have a modest level of understanding of what kind of model you are working with. If you see it as a magic black box you will not be able to diagnose problems when things go wrong. And I hope today’s lesson convinces you that many machine learning models are really pretty simple and not as frighteningly complex as they are often hyped up to be.

Myth 3. My data aren’t big enough to do machine learning. If there’s one buzzword that’s even more overhyped than machine learning, it’s “Big Data.” From the way some people talk, you would think that you need terabytes of data, too much to load into your laptop’s RAM all at once, to do any meaningful research anymore. A lot of people at ARS do generate truly big data, especially those working with things like UAV imagery and single-cell omics. But machine learning can be a useful tool even on relatively small datasets, which I hope we will demonstrate today with our example models.

Image credit Silicon Angle


Myth 4. I don’t need machine learning. OK, I may have spent a while hating on machine learning and saying it’s a bunch of hot air and hype. But don’t get the wrong idea. The techniques of machine learning are really useful across all kinds of scientific research. In fact I think that the focus on prediction has been a real wake-up call for statistics as a whole and is changing the way statistics is done even if people aren’t doing ML in the strict sense. Also, even if you never fit a machine learning model yourself, you will definitely hear about them in the scientific papers you read and review, and even in the news. It can’t hurt to become a little more literate about ML so that people can’t pull the wool over your eyes with their fancy jargon about variational autoencoders and stochastic gradient descent.

But what is machine learning, anyway?

Now that we’ve seen what machine learning is not, let’s think about what it is. A simple definition of machine learning is any task you give a machine (computer) that it gets better at doing the more data it gets. Well, that sounds a lot like statistics, doesn’t it? It can be argued that any time you use a computer to fit a statistical model, you are doing machine learning. Models in general, whether you call them statistical models or ML models, are machines that take data as input and spit out predictions as output. Ideally, the more data you have to pull a signal out of, the better your model will be at its job of making predictions.

Regression and classification

What jobs do machine learning models do?

There are two main prediction “jobs,” or tasks, that machine learning models do: classification and regression. Classification means predicting which discrete category something belongs to, and regression means predicting a continuous outcome variable. Again, this is nothing new if you’ve fit statistical models before. In particular, regression should be familiar to you because even the simplest ANOVA is a type of regression model.
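To see that last point in action, here is a small sketch using the PlantGrowth dataset that ships with base R. This is my choice of example; it is not part of the lesson worksheet. It fits the same one-way ANOVA twice, once with aov() and once as a linear regression with lm(), and compares the F statistics.

```r
# Built-in PlantGrowth data: plant weight (continuous) for three treatment groups
data(PlantGrowth)

# One-way ANOVA fitted with aov()
anova_fit <- aov(weight ~ group, data = PlantGrowth)
anova_F <- summary(anova_fit)[[1]][["F value"]][1]

# The same model written as a linear regression with lm();
# R turns the factor 'group' into dummy (indicator) predictor columns
lm_fit <- lm(weight ~ group, data = PlantGrowth)
lm_F <- summary(lm_fit)$fstatistic[["value"]]

# Both approaches give the same F statistic
print(c(anova = anova_F, regression = lm_F))
```

The two fits produce the same F statistic because the ANOVA is just a regression on group indicator variables.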

Supervised and unsupervised classification

Supervised versus unsupervised classification

Classification might need a little more explaining. When it comes to classification, there are two types of learning we can do: supervised and unsupervised. Supervised learning means that you know which category each of your data points belongs to, and you are trying to find patterns in the data that enable you to predict which of those previously defined categories new data points belong to. Unsupervised learning means you are trying to find natural groups or clusters in your data without knowing beforehand which group each data point belongs to, or even how many groups there are.

An example of supervised learning would be if you took a lot of clinical measurements on a group of people and followed them over time to see which individuals developed heart disease and which didn’t. You could develop a model to predict the probability of developing heart disease based on those measurements. Then you could use that model to predict heart disease in new individuals based on their measurements, and use that prediction to prescribe medicine or interventions.

An example of unsupervised learning is the model that Google News uses to categorize news articles into topics. The model gets text data from a bunch of articles from around the internet. It tries to cluster them into groups of stories related to the same topic or event, using a language model. This is unsupervised because the model does not have any information about what events any of the stories are written about; it just finds articles whose text is more closely related to each other than it is to articles not within that group. Any articles that end up within the same cluster should be about the same news event. But unsupervised learning models are not always that complicated and fancy. If you’ve ever done a PCA, you’ve used an unsupervised learning algorithm!
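Here is a small sketch of unsupervised learning in base R. It uses the built-in iris flower measurements as stand-in data, and the species labels are deliberately left out while the models are fit. prcomp() runs a PCA, and kmeans() looks for three clusters. Only afterwards do we cross-tabulate the clusters against the true species to see how well the natural groupings line up. Asking for three clusters and setting nstart = 25 are choices I made for illustration.

```r
# Use only the four numeric measurements from the built-in iris data;
# the Species labels are left out on purpose (unsupervised learning)
measurements <- scale(iris[, 1:4])

# PCA: an unsupervised method that finds the directions of greatest variation
pca <- prcomp(measurements)
summary(pca)   # proportion of variance explained by each component

# k-means clustering: asks for 3 groups but is never told what they are
set.seed(42)
km <- kmeans(measurements, centers = 3, nstart = 25)

# Only now do we compare the clusters with the known species
table(cluster = km$cluster, species = iris$Species)
```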

In general, supervised learning is great but requires more “ground-truthed” or labeled datasets, so it might not always be possible, especially at large scale. For example, if you want to classify huge numbers of remotely sensed images but don’t have the time or ability to manually label a huge training set of images, you might need to go the unsupervised route.

Image credit 2016CIAT/Neil Palmer


What kind of machine learning is going on at ARS?

Heart attacks and news stories are great and all, but I am not a doctor or a software engineer in Silicon Valley. “How is all of this relevant to my work as an agricultural scientist?” you might be asking yourself at this point. In fact, there are a lot of ARS scientists fitting machine learning models to all kinds of data. Here are just a few quick examples.

  • Genomic selection. Many crop and livestock breeders in ARS are using genomic selection models that are some flavor of machine learning model. These models often outperform conventional mixed models in predicting traits of interest from sequence data.

  • Identifying weeds from drone imagery. ARS scientists are using imagery collected by unmanned aerial vehicles to train models that can identify different species of weeds within production fields, allowing for more precise and targeted weed control.

  • Exploring the effect of diet on gut bacteria. ARS scientists are gathering data on the bacterial microbiome communities that live in the gut and using machine learning to explore whether modifying diets can improve microbiome health.

What’s the big deal with prediction?

Probably the biggest difference between the machine learning approach and the way statistics is traditionally done in the life sciences is that machine learning’s goal is prediction whereas classical statistics’ goal is inference. What does that mean? Both prediction and inference are trying to learn something from patterns in the data. Prediction is using that pattern to get an answer. With prediction, you don’t much care about the why as long as the model performs well. Inference is trying to use the patterns in the data to understand processes going on in nature. You really do care about the why. To me, that means inference is harder than prediction. To make a good inference, you not only have to have a model that makes reasonably good predictions, you also have to have a deeper understanding of how the different variables translate to actual physical causal relationships in the real world. Let’s leave all that difficult thinking to scientists and philosophers for now, and just pretend we’re engineers trying to build something that works!

Spurious correlation?


However, there’s no reason why machine learning models couldn’t be used for inferential purposes as well. Whether the goal is inference or prediction, we want to get parameter estimates that are as unbiased as possible. Plus, if your machine learning model actually incorporates true causal relationships between variables, and not just chance associations, it will be a more robust model. For example, shark attacks and ice cream sales might be highly correlated, so you could build a model that predicts where to build the next ice cream stand based on where the most shark bites have been reported. But there is no true causal relationship between those variables, and they are just confounded by things like season of the year and number of visitors to the beach. Your model would start to fail if, let’s say, overfishing decimates shark populations in the area and the correlation is broken. Better to use your noggin and incorporate variables into your model that are not only correlated with your outcome of interest, but also have a true causal association with it.
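If you want to convince yourself of this, here is a quick simulation with entirely made-up numbers. In this hypothetical setup, beach visitors drive both shark bites and ice cream sales, and neither causes the other. One model predicts sales from shark bites and the other predicts sales from visitors. The first model works only until the sharks disappear, while the model built on the causal driver keeps working.

```r
set.seed(1)
n <- 200
mse <- function(obs, pred) mean((obs - pred)^2)

# Made-up confounder: daily beach visitors
visitors <- runif(n, 100, 1000)
# Shark bites and ice cream sales both depend on visitors, not on each other
shark_bites <- rpois(n, lambda = visitors / 200)
ice_cream   <- 2 * visitors + rnorm(n, sd = 100)
cor(shark_bites, ice_cream)  # clearly positive

# One model uses the spurious association, the other the causal driver
spurious_fit <- lm(ice_cream ~ shark_bites)
causal_fit   <- lm(ice_cream ~ visitors)

# New conditions: overfishing cuts shark numbers tenfold; visitors behave as before
new_visitors  <- runif(n, 100, 1000)
new_bites     <- rpois(n, lambda = new_visitors / 2000)
new_ice_cream <- 2 * new_visitors + rnorm(n, sd = 100)

mse_spurious <- mse(new_ice_cream, predict(spurious_fit, data.frame(shark_bites = new_bites)))
mse_causal   <- mse(new_ice_cream, predict(causal_fit, data.frame(visitors = new_visitors)))
c(spurious = mse_spurious, causal = mse_causal)
```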

What’s a loss function?

Any machine learning model we fit will have a loss function. Usually this represents some form of prediction error, and we want to minimize it. You are already familiar with loss functions if you’ve ever done a linear regression. In a linear regression, you are trying to find the regression line that minimizes the sum of the squared residuals (distances from your data points to the fitted line). The lower that sum, the better the model (in this case the regression line) is doing at predicting the dependent variable. Machine learning models often use mean squared error (MSE) as the loss function for regression tasks. Minimizing MSE is exactly what ordinary least squares linear regression does, because MSE is just the sum of squared residuals divided by the number of data points. However, MSE isn’t the only loss function used. Mean absolute error (MAE) is another one. With MAE, you are trying to minimize the average of the absolute values of the distances from the data points to the fitted line. Some people use it because it’s more tolerant of outliers: if you square the errors, outliers with large errors are penalized very strongly, but this isn’t the case if absolute values are used. Classification tasks use different loss functions. We won’t get into the details here, but it’s something to be aware of as you get experience with ML models.
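Here is a tiny worked example with made-up numbers. It shows how MSE and MAE react when a single outlier is added. Without the outlier, the two losses are equal. With one error of 10 added, MSE jumps to 17 while MAE only rises to 2.

```r
# Made-up observed values and model predictions
observed  <- c(10, 12, 11, 13, 12)
predicted <- c(11, 12, 10, 13, 12)

mse <- function(obs, pred) mean((obs - pred)^2)   # mean squared error
mae <- function(obs, pred) mean(abs(obs - pred))  # mean absolute error

mse(observed, predicted)  # (1 + 0 + 1 + 0 + 0) / 5 = 0.4
mae(observed, predicted)  # (1 + 0 + 1 + 0 + 0) / 5 = 0.4

# Add one outlier whose prediction is off by 10
observed_out  <- c(observed, 30)
predicted_out <- c(predicted, 20)
mse(observed_out, predicted_out)  # (2 + 100) / 6 = 17
mae(observed_out, predicted_out)  # (2 + 10) / 6 = 2
```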

Regularization balances between underfitting and overfitting


What is regularization?

You may have heard the terms overfitting and underfitting before in the context of models. Overfitting happens when your model fits too closely to your training data, and its predictions reproduce random noise in that specific dataset. Underfitting happens when your model doesn’t fit your training data closely enough, and its predictions miss true patterns. Both overfitting and underfitting are problems for a model because they result in poor prediction performance. In traditional statistics courses, the term regularization is not used that often, but you do regularization any time you do variable selection or model selection, or even whenever you include random effects in a mixed model. Regularization means making your model less complex and more general. A regularized model will fit your training data less closely, but perform better when making predictions on new data.

In the machine learning context, the loss function often includes a penalty term for model complexity, and a regularization parameter sets how strongly that penalty is applied. The value chosen for that parameter determines the balance between overfitting and underfitting. We’ll discuss how you find the ideal value later on.
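To make the idea of a penalty concrete, here is a sketch of ridge regression, one of the simplest regularized models, written out by hand in base R on simulated data. Its loss is the sum of squared residuals plus lambda times the sum of squared coefficients, and it has a closed-form solution. The lasso we use later in this lesson penalizes the sum of the absolute values of the coefficients instead, and it has no closed form, so ridge is simply easier to show here. As lambda grows, the coefficients shrink toward zero, giving a simpler, more regularized model.

```r
set.seed(123)
n <- 50
# Simulated data: 5 standardized predictors, only the first two matter
X <- scale(matrix(rnorm(n * 5), nrow = n))
y <- 3 * X[, 1] - 2 * X[, 2] + rnorm(n)
y <- y - mean(y)  # center the response so no intercept is needed

# Ridge loss: sum((y - X %*% b)^2) + lambda * sum(b^2)
# Closed-form minimizer: b = (X'X + lambda * I)^(-1) X'y
ridge_coef <- function(X, y, lambda) {
  solve(crossprod(X) + lambda * diag(ncol(X)), crossprod(X, y))
}

lambdas <- c(0, 10, 100, 1000)
coefs <- sapply(lambdas, function(l) ridge_coef(X, y, l))  # one column per lambda
colnames(coefs) <- paste0("lambda_", lambdas)
round(coefs, 3)  # coefficients shrink toward zero as lambda grows
```

With lambda = 0, this is ordinary least squares. The cross-validation procedure described later is how you would pick a value in between.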

How to fit a machine learning model

Now that you know a little about machine learning, let’s go through the basic “recipe,” or steps you would need to follow when you get some data and want to do some machine learning on it.

Training and test sets

The first thing you need to decide is what data will be used to fit the model (the training set) and what data will be used to test whether the model is good (the test set). In an ideal world, the test data and training data would be as independent of each other as possible. That way, you can be sure that your model is making generalizable predictions and not just fitting random noise in your dataset. If the random noise in the training dataset is correlated with random noise in the test set, that’s called data leakage, and it’s not a good thing.

Of course, there is a limit to just how independent the training and test sets should be from each other. For instance, a genomic selection model developed in Alabama to predict which varieties of a crop will have the highest yield might perform well in Mississippi, but no one would be surprised if it didn’t make good predictions in Michigan. You might see the word transferability used, which is the ability of a model to produce accurate predictions on data that were not used to fit the model. It’s important to think about what contexts you want your model to be applied in, and to get validation data that will be as strict a test as possible to make sure your model can do well across all those contexts.

Training and testing split


If there’s no second independent test set available, the best you can do is set aside a portion of your data as the test set. Your model will never get to see that part of the data when you’re estimating the parameters. It will only be used to see if the model makes good predictions. Some people would ask, isn’t that a waste of perfectly good data? Well, yes, it isn’t ideal to have to fit your model to a reduced dataset, but it is essential to be able to validate your model. So it is usually a good idea to hold back a pretty hefty proportion of the data, like 10% to 20%, as a test set.
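Here is a minimal sketch of holding back 20% of the rows as a test set in base R. The built-in iris data stand in for your own dataset.

```r
set.seed(2024)
dat <- iris  # built-in example data standing in for your own dataset

# Randomly choose 20% of the rows for the test set
test_rows <- sample(nrow(dat), size = round(0.2 * nrow(dat)))
test_set  <- dat[test_rows, ]
train_set <- dat[-test_rows, ]  # everything else is the training set

c(train = nrow(train_set), test = nrow(test_set))
```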

The test set should not necessarily be a completely random sample from your dataset. If you have a classification problem, it’s a good idea to take a stratified random sample to ensure you have adequate representation of all your categories in both the training and test sets. Also, if your data have a multilevel structure, for example if they come from a block design experiment, you might want to assign training and test at the block level. If a single block has data points in both the training set and the test set, this will cause data leakage and lead to overoptimistic estimates of model performance. Observations from the same block are correlated because they share the same unmeasured environmental influences, so your predictions for the test set would get a boost from that correlation. Once the model gets used in the real world in new contexts, you won’t have that correlation any more to prop up your predictions, and the model will perform poorly. Putting entire blocks in the test set will alleviate that problem.
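Here is a sketch of both ideas in base R. The stratified split uses the built-in iris species as the classes. The block-level split uses a made-up field trial with 10 blocks of 12 plots each. The plots data frame and its columns are hypothetical.

```r
set.seed(2024)

# Stratified split: sample 20% of the rows within each class separately
rows_by_class <- split(seq_len(nrow(iris)), iris$Species)
test_rows <- unlist(lapply(rows_by_class, function(r) sample(r, size = round(0.2 * length(r)))))
table(iris$Species[test_rows])  # every species is represented equally

# Block-level split: hypothetical field trial with 10 blocks of 12 plots each
plots <- data.frame(block = rep(paste0("B", 1:10), each = 12),
                    yield = rnorm(120, mean = 50, sd = 5))
test_blocks <- sample(unique(plots$block), size = 2)  # hold out whole blocks
plots_test  <- plots[plots$block %in% test_blocks, ]
plots_train <- plots[!plots$block %in% test_blocks, ]
```

No block contributes plots to both sets, so the block-level correlation cannot leak from training into testing.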

Data visualization and pre-processing

Before you fit a model, you need to take a look at your data. Making plots is an important step in any kind of data analysis. If there are any data points that are clearly errors or otherwise invalid, you can remove them. This is not the same thing as removing “outliers” that have extreme values. When you are making predictions later on, you might not be able to prune away those extreme values. It’s better to make sure your model can handle extreme values, than to pretend they don’t exist (of course as long as those extreme values are real data and not faulty readings from an instrument, for example).

Another important step is removing predictor variables that don’t carry much information. By the way, you can sound more like a machine learning engineer if you call your variables “features.” Features that have no or very little variance (for example, a column with 999 zeros and 1 one) and features that are highly correlated with others in your dataset (for example, r > 0.95) contribute little or nothing to the prediction performance of your model. All they will do is slow down the computation. Removing those columns before fitting the model is a good idea. Of course, this doesn’t say anything about whether the variables you choose to keep or remove are causally associated with the response. As we have been saying, the goal of ML is typically prediction rather than inferring a process from a pattern.
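Here is a base R sketch of both filters on a made-up feature table. The 95% rule for near-zero variance is a simple rule of thumb I chose for illustration. The correlation filter just drops the later column of any pair with |r| > 0.95. The caret package provides more refined versions of both filters in its nearZeroVar() and findCorrelation() functions.

```r
set.seed(7)
n <- 1000
# Made-up feature table
features <- data.frame(
  a = rnorm(n),
  b = c(rep(0, n - 1), 1),  # 999 zeros and a single 1: almost no variance
  c = rnorm(n)
)
features$d <- features$a + rnorm(n, sd = 0.05)  # nearly a copy of 'a'

# 1. Near-zero variance: flag features where one value makes up more than
#    95% of the rows (a simple rule of thumb assumed for this sketch)
top_share <- sapply(features, function(x) max(table(x)) / length(x))
low_info  <- names(top_share)[top_share > 0.95]

# 2. High correlation: for each pair with |r| > 0.95, flag the later column
r <- abs(cor(features))
r[upper.tri(r, diag = TRUE)] <- 0   # keep each pair only once (lower triangle)
too_correlated <- rownames(r)[apply(r, 1, function(row) any(row > 0.95))]

features_reduced <- features[, setdiff(names(features), c(low_info, too_correlated))]
names(features_reduced)
```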

Another thing that usually needs to be done is normalizing or standardizing your variables. If you have many predictors that are measured in different units, it is a good idea to standardize them somehow so that they are all on a common scale. That way, it is the relative changes in the variables that will be used to predict changes in the outcome. If you don’t, many kinds of models will give more weight to the variables that happen to be measured in larger numbers. To avoid data leakage, calculate the standardization values (for example, each variable’s mean and standard deviation) from the training set only, and then apply those same values to the test set. If you standardize the whole dataset before splitting it, information about the test data leaks into the training data.
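Here is a minimal base R sketch of that approach. scale() computes centering and scaling values from the training rows and stores them as attributes. Those same values are then reused for the test rows. The predictor names and values are made up.

```r
set.seed(99)
# Made-up predictors measured on very different scales
dat <- data.frame(length_mm = rnorm(100, 30, 5),
                  weight_g  = rnorm(100, 900, 150))
test_rows <- sample(nrow(dat), 20)
train <- as.matrix(dat[-test_rows, ])
test  <- as.matrix(dat[test_rows, ])

# Learn centering (mean) and scaling (sd) values from the training set only
train_scaled <- scale(train)
centers <- attr(train_scaled, "scaled:center")
spreads <- attr(train_scaled, "scaled:scale")

# Apply those same training-set values to the test set
test_scaled <- scale(test, center = centers, scale = spreads)
```

The test columns will not have a mean of exactly 0 and an SD of exactly 1. That is expected: they are expressed on the training set’s scale, which is what the model was fit on.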

Tuning the model

Now it’s time to actually fit some models. We want to estimate parameters on our training set that will maximize prediction performance on the test set. ML models have parameters just like the traditional statistical models you’re familiar with. These parameters capture the effects of the different predictor variables on the outcome variable. But they also have tuning parameters, or hyperparameters. These are special parameters that do things like set the level of regularization to balance between overfitting and underfitting, as we discussed above.

Image credit Rebecca Wilson


But how do we figure out that optimal balance that will give us the best performance on our testing set, without “peeking” at the test set? We usually use cross-validation. In cross-validation, we split up our training data once again, this time into equally sized folds. Let’s say we are doing 5-fold cross-validation. First, we fit a model with a particular set of tuning parameters to folds 2, 3, 4, and 5 combined, and use that model to predict the outcome variable for fold 1. Then we fit the model with the same set of tuning parameters to folds 1, 3, 4, and 5 combined, and use that model to predict the outcome for fold 2. Repeat until all five folds have been predicted so that every data point has been predicted once. Calculate your preferred metric of prediction performance (such as MSE) on those combined predictions. Then, repeat the cross-validation process for a lot of other sets of tuning parameters. The set of tuning parameters that gives the lowest cross-validated error (the best prediction performance) is the one you will use to fit the final model.
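Here is a base R sketch of 5-fold cross-validation over a grid of tuning parameter values. The data are simulated, and the "tuning parameter" is the degree of a polynomial regression, a simple stand-in I chose. A low degree underfits the curved relationship, a very high degree overfits, and cross-validation picks a value in between.

```r
set.seed(10)
# Simulated training data with a curved relationship
n <- 100
train <- data.frame(x = runif(n, 0, 10))
train$y <- sin(train$x) + rnorm(n, sd = 0.3)

k <- 5
fold <- sample(rep(1:k, length.out = n))  # assign each row to one of 5 folds

degrees <- 1:10   # grid of candidate values for the tuning parameter
cv_mse <- sapply(degrees, function(d) {
  pred <- numeric(n)
  for (i in 1:k) {
    in_fold <- fold == i
    # Fit on the other k - 1 folds, then predict the held-out fold
    fit <- lm(y ~ poly(x, d), data = train[!in_fold, ])
    pred[in_fold] <- predict(fit, newdata = train[in_fold, ])
  }
  mean((train$y - pred)^2)  # MSE over all out-of-fold predictions
})

best_degree <- degrees[which.min(cv_mse)]
best_degree
```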

Cross-validation


Because you can get better or worse performance by chance depending on how your folds are split up, it is usually a good idea to repeat the entire cross-validation process across your grid of tuning parameter values many times and select the tuning parameters with the best average performance over the many repeats. If that sounds like it might take a lot of time and computing power, it often does! Luckily, at ARS we have access to SCINet high performance computing, which can be a big help.
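Here is a sketch of repeated cross-validation in base R, using the same simulated setup and polynomial-degree stand-in for a tuning parameter. Each repeat draws a fresh 5-fold split and scores every candidate value on that same split. The final choice is the value with the best average error across repeats. Twenty repeats is an arbitrary choice for illustration.

```r
set.seed(11)
# Simulated training data with a curved relationship
n <- 100
train <- data.frame(x = runif(n, 0, 10))
train$y <- sin(train$x) + rnorm(n, sd = 0.3)

# MSE from one round of k-fold CV, given a degree and a fold assignment
cv_mse <- function(data, degree, fold) {
  pred <- numeric(nrow(data))
  for (i in unique(fold)) {
    held_out <- fold == i
    fit <- lm(y ~ poly(x, degree), data = data[!held_out, ])
    pred[held_out] <- predict(fit, newdata = data[held_out, ])
  }
  mean((data$y - pred)^2)
}

degrees   <- 1:10
n_repeats <- 20
# Result: a matrix with one row per degree and one column per repeat
results <- replicate(n_repeats, {
  fold <- sample(rep(1:5, length.out = n))  # a fresh split for this repeat
  sapply(degrees, function(d) cv_mse(train, d, fold))
})

mean_mse    <- rowMeans(results)  # average performance over the repeats
best_degree <- degrees[which.min(mean_mse)]
```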

Note: As we mentioned above when discussing training-test set split, if you have a blocking structure in your data, you have to make sure that your cross-validation folds are split up at the block level and not at the level of individual data points within blocks. Otherwise you will get overinflated estimates of model performance.
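Here is a base R sketch of assigning cross-validation folds at the block level for a made-up trial with 10 blocks of 8 plots each. Whole blocks are assigned to folds first, and each plot then inherits its block’s fold. The caret package also has a groupKFold() function for building group-wise folds.

```r
set.seed(5)
# Hypothetical training data: 10 blocks with 8 plots each
train <- data.frame(block = rep(sprintf("block%02d", 1:10), each = 8))
train$x <- rnorm(nrow(train))

k <- 5
blocks <- unique(train$block)
# Assign folds to whole blocks, then give each plot its block's fold
block_fold <- setNames(sample(rep(1:k, length.out = length(blocks))), blocks)
train$fold <- unname(block_fold[train$block])

# Each block appears in exactly one fold
table(train$block, train$fold)
```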

Fitting and validating the final model

Now that you’ve dialed in your tuning parameters, the next step is to fit a final model to your entire training set using those optimal tuning parameters selected by cross-validation. Now you’re ready to hit the big time and see how well your model does on data it’s never seen before.

Use your model to make predictions on the test set that we split off from the training set back at the start of this process (making sure the variables in the test set are standardized or normalized in the same way as the training set). Then evaluate prediction performance. That’s it!
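Here is a base R sketch of the last two steps for a two-class problem built from the iris data (versicolor vs. virginica). It fits a final logistic regression on the training set, predicts the held-out test set, and reports a confusion matrix and accuracy. This model has no tuning parameter, so the sketch skips cross-validation. The 0.5 probability cutoff is an assumption.

```r
set.seed(3)
# Two-class example from the built-in iris data: versicolor vs. virginica
dat <- droplevels(subset(iris, Species != "setosa"))
test_rows <- sample(nrow(dat), 20)
train <- dat[-test_rows, ]
test  <- dat[test_rows, ]

# Fit the final model on the full training set
final_fit <- glm(Species ~ Petal.Length + Petal.Width, data = train, family = binomial)

# Predicted probabilities are for the second factor level, "virginica"
prob <- predict(final_fit, newdata = test, type = "response")
pred_class <- factor(ifelse(prob > 0.5, "virginica", "versicolor"),
                     levels = levels(dat$Species))

# Confusion matrix and overall accuracy on the test set
confusion <- table(predicted = pred_class, observed = test$Species)
accuracy  <- mean(pred_class == test$Species)
confusion
accuracy
```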

Now it’s important to keep in mind there is no significance threshold that you can mindlessly apply in all situations to determine whether your model is “good” or “bad.” Whether a predictive model is good or bad is a completely practical consideration. A classification model that classifies new observations correctly 90% of the time might be great for some applications, but even 99% accuracy might not be good enough for other more sensitive applications.

Model tuning and validation workflow


Assessing variable importance

Even though prediction is the main goal with ML models, we can actually use them to make inference as well. There are some things we can do post hoc to examine which of our predictor variables are most important for predicting the outcome — those variables would be the best candidates to examine further for potential causal relationships. We aren’t going to focus on these too much in this intro lesson, but we will look at one type of variable importance measure which ranks the variables in order of how much the model’s accuracy would decline if that variable were excluded.
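One common way to compute this kind of measure is permutation importance. You randomly shuffle one predictor’s values so it no longer carries any information, and then record how much the prediction error increases. The randomForest package reports a permutation-based measure of this kind. Here is a base R sketch with a linear model on simulated data. For brevity it computes importance on the training data; in practice you would use held-out data.

```r
set.seed(8)
# Simulated data: y depends strongly on x1, weakly on x2, not at all on x3
n <- 500
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
dat$y <- 3 * dat$x1 + 1 * dat$x2 + rnorm(n)

fit <- lm(y ~ x1 + x2 + x3, data = dat)
mse <- function(obs, pred) mean((obs - pred)^2)
baseline <- mse(dat$y, predict(fit, dat))

# Shuffle one predictor at a time and record how much the error increases
importance <- sapply(c("x1", "x2", "x3"), function(v) {
  shuffled <- dat
  shuffled[[v]] <- sample(shuffled[[v]])
  mse(dat$y, predict(fit, shuffled)) - baseline
})
sort(importance, decreasing = TRUE)
```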

Do it all over again

There is no reason not to fit several different ML models on your training set and compare which one performs best on the test set, as long as you use regularization to guard against overfitting. There are tons of options out there. Remember that more complicated is not always better. You may find that a plain-vanilla logistic regression does just as well as an extremely sophisticated neural network. And even if a more complicated model performs better, it might be better, especially when you are just starting out, to stick with simpler models that you understand better. That way, you will be better able to make sense of the results and fix things if something goes wrong.

Some more thoughts on machine learning

Why don’t people use this kind of workflow in conventional statistics?

After reading this, you might be thinking something along these lines: “This whole routine of training-test split and cross-validation sounds like a really good idea. Why don’t people do this when they fit conventional statistical models, too? Why haven’t I heard of this before in my stats classes?” Well, that’s a great question! Honestly, they should. Yes, overfitting is more of a problem if prediction is the goal, compared to the situation where we care more about understanding a natural process, but it’s still a problem in all cases. Plenty of published research consists of little more than mining data and making conclusions about natural phenomena based on patterns that are just illusions caused by random noise. We have that to thank for the reproducibility crisis in science, at least in part. Cross-validation and testing a model on an external validation dataset is a really strict test of whether a model performs well in the real world. There is no reason why we should not subject our statistical models to that kind of “tough love” as well. If the model can take that kind of punishment and still survive, we can be that much more confident that it is capturing a true and meaningful process. A big part of the reason that machine learning has become so successful in recent years is because of this focus on avoiding overfitting and making prediction performance king. This should be a wake-up call to anyone who works with any kind of model!

A note about software

Image credit Math is in the Air


We are going to use R for the demos in this lesson because all my other courses use R. There is a perception that Python is the best programming language for machine learning. It’s true that Python has some great ML libraries. But R, and yes SAS, do very well too. In the future I may provide Python code to compare alongside the R code in these demos.

A note of caution

Because machine learning models are now so widespread, lots of software platforms have been developed that claim to “streamline” or automate the model fitting process. Just throw in the data, press a few buttons, and the black box spits out a result! On one hand, it is nice to make the models more accessible to people that might not have a strong background in coding. On the other hand, there is some danger in making it too easy to get results from a model without having any clue what it is doing under the hood.

What we’re going to do today involves a little bit of “black boxing” because I don’t want you to get bogged down in the code. However, there is no real substitute for looking at the data and customizing the model to the dataset at hand. If you try to abstract too far from the data, you can run into problems. Either your model will not give satisfactory results and you won’t know why, or your model will seem to be perfect but there will be an underlying issue that you won’t be able to diagnose.

Image credit Sourav Bose


Here I am thinking specifically of the R package tidymodels (really a set of related packages analogous to tidyverse — it was made by the same group of people who made tidyverse). A lot of intro ML tutorials use this package so you may see it if you search for material on introductory machine learning in R. Although tidymodels has a lot of advantages, I think using it is not great for beginners because it makes it hard to see exactly what the model is doing. When you become more advanced with ML, you may be fitting large batches of models to lots of datasets. In that case a streamlined approach, such as you get with tidymodels, might be good. But because the R world is moving in that direction, I may create a tidymodels tutorial in the future. Check back later for that.

A final big picture note


Instead of ending on that negative note before we transition to actually looking at some ML code, I want to end on a positive and inspiring note that takes a big picture view. People working with data these days, whether you call what you’re doing statistics, data science, modeling, or just plain science, are spoiled for options. There are so many cool models and approaches out there to choose from! In my mind, the three biggies are mixed/multilevel models, Bayesian statistics, and of course machine learning. (There is a lot of overlap between those — you could have a multilevel Bayesian machine learning model — but I’m oversimplifying here to make a point.) They may seem very different superficially, but fundamentally they are all very similar:

  • Machine learning models use regularization penalties to avoid overfitting to your dataset and make good general predictions
  • Mixed models use random effects/shrinkage to avoid overfitting to your dataset and make good general predictions
  • Bayesian models use skeptical priors to avoid overfitting to your dataset and make good general predictions
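For the simplest case, a linear model, the link between the first and third bullets can be written out exactly. Below is a sketch of a standard result. It assumes Gaussian errors with known variance σ² and independent Gaussian priors with mean 0 and variance τ² on the coefficients. Under those assumptions, the ridge penalty is the Gaussian prior in disguise, with λ = σ²/τ². A smaller τ (a more skeptical prior) means a larger λ (stronger regularization). Random effects in a mixed model are also assumed to be normally distributed around zero, which is where their shrinkage comes from.

```latex
\begin{aligned}
\text{Ridge loss:}\quad & L(\beta) = \sum_{i=1}^{n}\left(y_i - x_i^\top \beta\right)^2 + \lambda \sum_{j=1}^{p} \beta_j^2 \\
\text{Model and prior:}\quad & y_i \sim \mathcal{N}\!\left(x_i^\top \beta,\ \sigma^2\right), \qquad \beta_j \sim \mathcal{N}\!\left(0,\ \tau^2\right) \\
\text{Negative log posterior:}\quad & -\log p(\beta \mid y) = \frac{1}{2\sigma^2}\sum_{i=1}^{n}\left(y_i - x_i^\top \beta\right)^2 + \frac{1}{2\tau^2}\sum_{j=1}^{p}\beta_j^2 + \text{const} \\
\text{Multiply by } 2\sigma^2:\quad & 2\sigma^2\left[-\log p(\beta \mid y)\right] = \sum_{i=1}^{n}\left(y_i - x_i^\top \beta\right)^2 + \frac{\sigma^2}{\tau^2}\sum_{j=1}^{p}\beta_j^2 + \text{const}
\end{aligned}
```

The coefficients that minimize the ridge loss with λ = σ²/τ² are therefore the same coefficients that maximize the Bayesian posterior.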

So you see, they’re all the same (at least in a way)! As scientists we want to be able to abstract from our data and uncover general truths about the world. All these approaches have ways of helping us do that, so they are all great tools to use. But just think of them all as tools in your toolbox! The more different tools you know how to use, the better.

Demo 1: Date variety classification

Introducing the example data

This example dataset comes from Kaggle, a site that hosts lots of free example datasets that you can fit statistical and machine learning models to, and compete with others to build the best predictive models. There are a lot of food and agriculture-themed datasets on Kaggle.

Today we’re going to use the date fruit classification dataset provided by Murat Koklu. Each row of the dataset has 34 measurements that were taken from images of a single date fruit. In addition, there is a Class column indicating one of seven date varieties that the fruits belong to. In total there were 898 fruits measured. Our task here is to develop a model that accurately predicts the variety of a date from the measurements alone. The model could be used later on to classify dates of unknown variety.

Image credit Britannica


What we are going to do:

  • Import the dataset and visualize it
  • Remove erroneous data points
  • Remove highly correlated predictor variables
  • Standardize the predictor variables
  • Split the dataset into training and testing sets
  • Fit a random forest model, using five-fold cross-validation, for each value of a tuning parameter
  • Identify the tuning parameter that maximizes accuracy on the hold-out folds
  • Fit the final model on the full training set
  • Validate the model by calculating accuracy on the testing set

Import and explore the data

The first thing we need to do is load any R packages we’ll need. The tidyverse packages are for manipulating and plotting data, randomForest is for fitting a random forest machine learning model, caret is for training the model with cross-validation, ggcorrplot is for plotting a correlation matrix to see whether any variables are highly correlated, and patchwork is for making a multi-panel plot.

library(tidyverse)
library(randomForest)
library(caret)
library(ggcorrplot)
library(patchwork)

Now import the data. Go ahead and convert the target or outcome variable, Class, to a factor variable.

date_fruit <- read_csv('https://usda-ree-ars.github.io/SEAStats/machine_learning_demystified/datasets/date_fruit.csv') %>% mutate(Class = factor(Class))

How many rows and columns of data do we have, and how many fruit are in each class? Out of the 898 fruit, there are seven classes. The classes are not perfectly balanced (they range from 65 Berhi to 204 Dokol fruit), but we have at least 65 fruit from each class. This means we should have enough data to have adequate numbers from every class in both the training and test sets.

dim(date_fruit)
## [1] 898  35
table(date_fruit$Class)
## 
##  BERHI DEGLET  DOKOL  IRAQI ROTANA SAFAVI  SOGAY 
##     65     98    204     72    166    199     94

Let’s examine the data. Plot each of the predictor variables as a boxplot grouped by the outcome variable, Class.

predictor_variables <- names(date_fruit)[!names(date_fruit) %in% 'Class']

date_long <- pivot_longer(date_fruit, cols = all_of(predictor_variables), names_to = 'variable')

ggplot(date_long, aes(x = Class, y = value)) +
  geom_boxplot(aes(fill = Class)) +
  facet_wrap(~ variable, scales = 'free_y') +
  theme_bw() +
  theme(legend.position = 'none', axis.text.x = element_text(angle = 45, hjust = 1))

We can see from the boxplots that some of the variables related to the shape and size of the fruit are more informative for distinguishing between the varieties than others. There is a pretty strong variety signal in quite a few of the variables, which is promising for our model. We can also see that most of the predictor variables are reasonably symmetrically distributed within most of the varieties. But there is one outlier in the Safavi variety for aspect ratio and a couple of the other predictor variables. The aspect ratio (length to width ratio) for most of the fruit is around 1 to 2. For one data point it is more than 500! That is not an extreme value that we need to account for in our model; it is clearly an erroneous value, because a date that is 500 times longer than it is wide is impossible.

Let’s remove the extreme value. This reduces the number of rows in the dataset by one.

date_fruit <- filter(date_fruit, ASPECT_RATIO < 500)

nrow(date_fruit)
## [1] 897

Removing highly correlated features

Make a correlation heatmap to look at the correlation between each pair of variables. Here we are using Pearson correlation. It looks like some of the correlations are pretty close to ± 1.

date_corr <- cor(date_fruit[, predictor_variables])

ggcorrplot(date_corr, type = 'lower') 

There is a handy function in the caret package called findCorrelation() which takes a correlation matrix and returns the variables that should be removed to bring pairwise correlations below a certain threshold. We’ll use \(|r| > 0.95\) as our threshold. If two variables are correlated that strongly, one of the two adds very little predictive power beyond the other, so it can be removed. 16 variables are removed in this way, leaving 18 predictors.

correlated_variables <- findCorrelation(date_corr, cutoff = 0.95, names = TRUE)

date_fruit <- date_fruit %>% select(-all_of(correlated_variables))
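findCorrelation() has its own rule for choosing which member of a highly correlated pair to drop: according to its documentation, it removes the variable with the larger mean absolute correlation. The base-R function below, drop_correlated(), is a hypothetical and simplified greedy version of that idea for illustration only. It is not caret's exact algorithm, so on real data it may not drop exactly the same variables.

```r
# drop_correlated() is a hypothetical helper written for illustration;
# it is a simplified greedy version of the idea behind caret::findCorrelation().
drop_correlated <- function(df, cutoff = 0.95) {
  corr <- abs(cor(df))   # absolute pairwise Pearson correlations
  diag(corr) <- 0        # ignore each variable's correlation with itself
  keep <- colnames(df)
  repeat {
    sub <- corr[keep, keep, drop = FALSE]
    if (max(sub) <= cutoff) break
    # Locate the most strongly correlated remaining pair
    pair <- keep[which(sub == max(sub), arr.ind = TRUE)[1, ]]
    # Drop whichever member has the larger mean absolute correlation
    mean_abs <- colMeans(sub[, pair, drop = FALSE])
    keep <- setdiff(keep, pair[which.max(mean_abs)])
  }
  keep
}

# Toy data: a_copy is a near-duplicate of a; b and c are independent noise
set.seed(42)
n <- 200
a <- rnorm(n)
toy <- data.frame(a = a, a_copy = a + rnorm(n, sd = 0.01),
                  b = rnorm(n), c = rnorm(n))
kept <- drop_correlated(toy, cutoff = 0.95)
kept
```

Only one of the two near-duplicate columns survives, while the unrelated columns are kept.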

Training-test split

We will use 90% of the data for training the model and set aside 10% of the data as the test set. The model will never access the test data while we are training it (estimating its parameters). The caret function createDataPartition() will split the data for us using stratified random sampling so that roughly the same proportions of each class are in the training set and the test set. Note that I used set.seed() to initialize the pseudo-random number generator at a specific state. If you set the same seed as me, you should get the same results as me.

set.seed(1)
trainIndex <- createDataPartition(date_fruit$Class, p = .9, list = FALSE)
date_train <- as.data.frame(date_fruit[trainIndex, ])
date_test <- as.data.frame(date_fruit[-trainIndex, ])
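Stratified sampling just means sampling separately within each class. Here is a minimal base-R sketch of the idea. stratified_split() is a hypothetical helper, and createDataPartition() has its own rules (for example for rounding, and for numeric outcomes), so the exact rows it picks will differ.

```r
# stratified_split() is a hypothetical helper for illustration only.
# It returns the row indices of a training set that keeps roughly
# proportion p of the rows from every class.
stratified_split <- function(y, p = 0.9) {
  rows_by_class <- split(seq_along(y), as.factor(y))
  train_idx <- lapply(rows_by_class, function(rows) {
    n_train <- round(length(rows) * p)
    rows[sample.int(length(rows), n_train)]   # sample within this class only
  })
  sort(unlist(train_idx, use.names = FALSE))
}

# Toy, unbalanced outcome with three classes
set.seed(1)
y_toy <- rep(c("A", "B", "C"), times = c(60, 30, 10))
train_rows <- stratified_split(y_toy, p = 0.9)
table(y_toy[train_rows])   # 54, 27, and 9 rows: 90% of each class
```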

Tuning the model with cross-validation

We are going to fit a random forest model to classify date fruits into variety based on measurements taken from photographs of the fruit.

A digression: explanation of random forest

What is a random forest model, exactly? It is a type of model that randomly samples the data with replacement many different times and fits a decision tree to each random sample. But what is a decision tree? It’s a way of sequentially splitting up a dataset with a series of rules that ultimately lead to a class or value. If that doesn’t make sense, here’s an example. We use a series of rules to classify five different animals based on their characteristics. Notice that it may take more decisions to identify one animal than another. Also notice that we could have used other traits to construct the tree or put them in a different order, and still correctly classified the animals.

Decision tree


Random forest is called “forest” because it grows a “forest” of decision trees. To classify a data point, each tree casts a vote, and the majority vote becomes the final classification. For each random sample of the data, some data points are selected and others are not. The data points that are selected are called “in the bag” and those that aren’t selected are “out of bag.” Each tree can be checked against its own out-of-bag data points, which it never saw, so the random forest algorithm gets a built-in estimate of how well it predicts new data (the out-of-bag error). Together with averaging over many trees, each grown on a different sample, this gives random forest some internal safeguards against overfitting, because it is effectively making a training-test split with each tree it grows. But we are going to use cross-validation to make extra sure that we are getting just the right amount of regularization.

That was a very superficial description of random forest that doesn’t really give the full technical details of how it works. See the further reading section at the bottom of this lesson for some great resources to learn more.
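To make the in-bag/out-of-bag and majority-vote ideas concrete, here is a small base-R illustration. Drawing one bootstrap sample shows that roughly a third of the rows end up out of bag (the expected fraction is \((1 - 1/n)^n\), which approaches \(e^{-1} \approx 0.368\) for large \(n\)). The tree votes are made up for illustration; this is not how randomForest is implemented internally.

```r
set.seed(1)
n <- 10000

# One bootstrap sample: n draws of row numbers, with replacement
in_bag <- sample.int(n, size = n, replace = TRUE)
out_of_bag <- setdiff(seq_len(n), in_bag)
oob_fraction <- length(out_of_bag) / n   # should be close to exp(-1)
oob_fraction

# Majority vote for a single fruit across five hypothetical trees
tree_votes <- c("DOKOL", "SAFAVI", "DOKOL", "DOKOL", "ROTANA")
vote_counts <- table(tree_votes)
majority <- names(vote_counts)[which.max(vote_counts)]
majority
```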

Training the random forest model with caret

The caret package makes it fairly easy to train our machine learning model. First, we specify our model training method using trainControl(). The argument method = 'cv' indicates cross-validation, with number = 5 folds. Now, the train() function is where the real magic happens. It has many arguments of which we only supply some in this case.

  • form is the model formula. Class ~ . is shorthand for “predict the Class variable using all other variables in the dataset.”
  • data provides the data frame where the model fitting function will look for the variables in the model formula.
  • method = 'rf' means to fit a random forest model using the R package randomForest in the back end. The caret documentation has a list of all the different models that can be specified with method.
  • preProcess = c('center', 'scale') is a handy way to standardize all your predictor variables with a z-transformation. center subtracts the mean from each column and scale divides each column by the standard deviation.
  • tuneGrid is a data frame where each row has a combination of tuning parameters that we are going to cross-validate. Here, we are only going to tune one parameter, mtry. In a random forest model, that is the number of variables randomly sampled as candidates each time a tree makes a new split. A higher number means a closer fit to the data at the potential risk of overfitting.
  • metric is the metric of model performance (i.e. the loss function) that we are going to use to compare how well each combination of the tuning parameters does. Here we say 'Accuracy', the overall proportion of correct classifications. Thus the goal of the cross-validation is to find the mtry value that maximizes accuracy.
  • trControl is the argument where we pass the model training specification we created using trainControl().
  • ntree is an argument to the randomForest() function that is fit behind the scenes. It sets the number of trees grown in each random forest model. More trees give more stable results, with diminishing returns, but a very large number will take a long time to run.
  • importance = TRUE means that variable importance values will be calculated for us to look at later.

cv_spec <- trainControl(method = 'cv', number = 5)

date_rf_train <- train(
  form = Class ~ .,
  data = date_train,
  method = 'rf',
  preProcess = c('center', 'scale'),
  tuneGrid = data.frame(mtry = 1:10),
  metric = 'Accuracy',
  trControl = cv_spec,
  ntree = 500,
  importance = TRUE
)

The code above should run in less than a minute. train() does a lot in one function call. It pre-processes the data; for each of the ten values of mtry we provide, it fits a model to each of the five cross-validation training sets (50 random forests in all); it selects the optimal value of mtry based on the accuracy of the predictions for the fold that is held out each time; and it fits the model with the optimal value of mtry on the full training set!
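The preProcess = c('center', 'scale') step is a z-transformation. A minimal base-R sketch of the idea, using made-up numbers: the means and standard deviations are computed from the training rows only, then reused unchanged on new rows. caret follows the same principle, storing the training-set values and applying them when it predicts on new data.

```r
# Made-up predictor values for five training fruit and two new fruit
train_x <- data.frame(area = c(10, 12, 14, 16, 18),
                      roundness = c(0.80, 0.82, 0.85, 0.83, 0.90))
test_x <- data.frame(area = c(11, 20), roundness = c(0.81, 0.95))

# Learn the centering and scaling values from the training data only
train_mean <- colMeans(train_x)
train_sd <- apply(train_x, 2, sd)

# z = (x - training mean) / training SD, applied to both sets
train_z <- scale(train_x, center = train_mean, scale = train_sd)
test_z <- scale(test_x, center = train_mean, scale = train_sd)
test_z
```

Using the training statistics for the test rows keeps any information about the test set out of the pre-processing.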

Let’s look at the results. plot() shows the accuracy for each value of mtry. We see that our model performance was not very sensitive to the choice of mtry in this case. The worst performance was a little under 88.6% accurate and the best a little over 89.2%. That is close enough that you might get a different result if you did a different random split into cross-validation folds.

plot(date_rf_train)

We can also look at summary information about the model by typing the name of the model into the console.

date_rf_train
## Random Forest 
## 
## 811 samples
##  18 predictor
##   7 classes: 'BERHI', 'DEGLET', 'DOKOL', 'IRAQI', 'ROTANA', 'SAFAVI', 'SOGAY' 
## 
## Pre-processing: centered (18), scaled (18) 
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 649, 649, 649, 649, 648 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    1    0.8865712  0.8628787
##    2    0.8902446  0.8676050
##    3    0.8927214  0.8708087
##    4    0.8927365  0.8708251
##    5    0.8878058  0.8649237
##    6    0.8853367  0.8618949
##    7    0.8890252  0.8664363
##    8    0.8877982  0.8649681
##    9    0.8865637  0.8634920
##   10    0.8877982  0.8649621
## 
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 4.

Below we produce a confusion matrix using the aptly named function confusionMatrix(). A confusion matrix has the model predicted classes down the rows and the true “reference” classes across the columns. Percentage values are in the cells. A model with 100% accuracy would have only nonzero values along the diagonal, where predicted class = true class, and zero values everywhere else. Here, most of the values off the diagonal are close to zero because the cross-validated accuracy on the training set is almost 90%. It is informative to see what the model got wrong. For instance, if the true class was Berhi (first column), the prediction was mostly correct but if it was wrong, the model mostly called it Iraqi. If you look back at the raw data you can see why this might be — Berhi and Iraqi are quite similar on a number of metrics. For example, they both have the largest area among the seven varieties.

confusionMatrix(date_rf_train)
## Cross-Validated (5 fold) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction BERHI DEGLET DOKOL IRAQI ROTANA SAFAVI SOGAY
##     BERHI    5.4    0.0   0.0   0.6    0.0    0.0   0.0
##     DEGLET   0.0    7.5   1.4   0.0    0.1    0.0   1.1
##     DOKOL    0.0    1.4  21.1   0.0    0.0    0.0   0.0
##     IRAQI    1.2    0.0   0.0   7.3    0.0    0.0   0.0
##     ROTANA   0.2    0.5   0.0   0.0   17.6    0.0   0.6
##     SAFAVI   0.0    0.1   0.0   0.1    0.0   21.7   0.1
##     SOGAY    0.4    1.5   0.2   0.0    0.7    0.4   8.6
##                             
##  Accuracy (average) : 0.8927

Validating the model

Now that we have built a model and tuned its parameters using cross-validation, here comes the moment of truth. Let’s test the model on a “new” dataset that it has never seen before. How will it do? Unfortunately we do not have a truly independent dataset of date images to test the model on, so the best we can do is give it the 10% of the original dataset that we held back and didn’t use for the cross-validation tuning.

We use the predict() function, which has a method for models trained with caret, and pass date_test to the newdata argument. Whatever we pass to newdata has to have a column for every predictor variable in the model. Then we use the function confusionMatrix() to make a confusion matrix for the test data. As it turns out, the prediction accuracy on the test set (~89.5%) is actually very slightly better than the cross-validated accuracy on the training set (~89.3%)! This is a great sign for the model’s potential applicability to new datasets. But keep in mind that the test set contains only 86 fruit, so the 95% confidence interval for its accuracy is wide (about 81% to 95%). And again, this is not a truly independent test set, so we have to be cautious and not get overoptimistic.

Also notice that we get a variety of other prediction metrics on a class-by-class basis in the output of confusionMatrix(). I won’t go into detail about them here.

date_predict_test <- predict(date_rf_train, newdata = date_test)

confusionMatrix(date_predict_test, date_test$Class)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction BERHI DEGLET DOKOL IRAQI ROTANA SAFAVI SOGAY
##     BERHI      4      0     0     0      0      0     0
##     DEGLET     0      8     1     0      0      0     3
##     DOKOL      0      1    19     0      0      0     0
##     IRAQI      2      0     0     6      0      0     0
##     ROTANA     0      0     0     1     16      0     1
##     SAFAVI     0      0     0     0      0     19     0
##     SOGAY      0      0     0     0      0      0     5
## 
## Overall Statistics
##                                          
##                Accuracy : 0.8953         
##                  95% CI : (0.8106, 0.951)
##     No Information Rate : 0.2326         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.8734         
##                                          
##  Mcnemar's Test P-Value : NA             
## 
## Statistics by Class:
## 
##                      Class: BERHI Class: DEGLET Class: DOKOL Class: IRAQI
## Sensitivity               0.66667       0.88889       0.9500      0.85714
## Specificity               1.00000       0.94805       0.9848      0.97468
## Pos Pred Value            1.00000       0.66667       0.9500      0.75000
## Neg Pred Value            0.97561       0.98649       0.9848      0.98718
## Prevalence                0.06977       0.10465       0.2326      0.08140
## Detection Rate            0.04651       0.09302       0.2209      0.06977
## Detection Prevalence      0.04651       0.13953       0.2326      0.09302
## Balanced Accuracy         0.83333       0.91847       0.9674      0.91591
##                      Class: ROTANA Class: SAFAVI Class: SOGAY
## Sensitivity                 1.0000        1.0000      0.55556
## Specificity                 0.9714        1.0000      1.00000
## Pos Pred Value              0.8889        1.0000      1.00000
## Neg Pred Value              1.0000        1.0000      0.95062
## Prevalence                  0.1860        0.2209      0.10465
## Detection Rate              0.1860        0.2209      0.05814
## Detection Prevalence        0.2093        0.2209      0.05814
## Balanced Accuracy           0.9857        1.0000      0.77778
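To see where the headline numbers come from, here is a base-R sketch that recomputes them from the test-set counts printed above. The formulas are the standard definitions; I am assuming caret uses the same ones (including an exact binomial test for the confidence interval), which is consistent with the values it prints.

```r
# Random forest test-set confusion matrix, copied from the output above
classes <- c("BERHI", "DEGLET", "DOKOL", "IRAQI", "ROTANA", "SAFAVI", "SOGAY")
cm <- matrix(c(
  4, 0,  0, 0,  0,  0, 0,
  0, 8,  1, 0,  0,  0, 3,
  0, 1, 19, 0,  0,  0, 0,
  2, 0,  0, 6,  0,  0, 0,
  0, 0,  0, 1, 16,  0, 1,
  0, 0,  0, 0,  0, 19, 0,
  0, 0,  0, 0,  0,  0, 5
), nrow = 7, byrow = TRUE,
dimnames = list(Prediction = classes, Reference = classes))

n <- sum(cm)                                  # number of test fruit
accuracy <- sum(diag(cm)) / n                 # correct predictions / total
sensitivity <- diag(cm) / colSums(cm)         # per class: correct / true count
nir <- max(colSums(cm)) / n                   # always guess the largest class
chance <- sum(rowSums(cm) * colSums(cm)) / n^2   # agreement expected by chance
cohen_kappa <- (accuracy - chance) / (1 - chance)
acc_ci <- binom.test(sum(diag(cm)), n)$conf.int  # exact binomial 95% CI

round(c(n = n, accuracy = accuracy, nir = nir, kappa = cohen_kappa), 4)
round(sensitivity, 4)
```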

Variable importance

Even though prediction is what we care about most, we also might be interested in how much each of the variables in our dataset contributes to the prediction of date variety. The varImp() function from the caret package will calculate variable importance scores for many different types of models. There are several different importance measures for random forest models. Here, we are looking at the mean decrease in accuracy: for each tree, the values of one variable are randomly shuffled in that tree’s out-of-bag data, and the resulting drop in classification accuracy is averaged over all the trees. The higher the value, the more the prediction accuracy decreases when the information in that variable is scrambled, and thus the more important the variable.

The argument type = 1 selects the mean decrease in accuracy measure (type = 2 would give the mean decrease in node impurity instead). If you omit it, you can see separate variable importance scores for each of the varieties. The argument scale = FALSE tells caret not to rescale the scores to a 0 to 100 range. Note that randomForest, by default, divides each mean decrease in accuracy by its standard error, so these values are not percentages. We can see that AREA is the most important variable by a wide margin. A table like this would be informative to include in a publication or presentation along with the prediction performance results.

varImp(date_rf_train, type = 1, scale = FALSE)
## rf variable importance
## 
##               Overall
## AREA            49.69
## EntropyRB       34.70
## COMPACTNESS     31.58
## ROUNDNESS       30.41
## SkewRB          27.52
## SHAPEFACTOR_2   23.29
## SkewRG          22.68
## SOLIDITY        22.39
## MeanRB          22.22
## SkewRR          20.12
## KurtosisRG      19.07
## SHAPEFACTOR_4   18.97
## KurtosisRB      18.77
## StdDevRG        16.59
## StdDevRB        16.52
## KurtosisRR      14.38
## EXTENT          12.63
## StdDevRR        11.80
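The idea behind permutation importance can be sketched in a few lines of base R. To keep it short, this hypothetical example uses simulated data, an ordinary linear model, mean squared error, and in-sample predictions. randomForest instead shuffles each variable within every tree's out-of-bag data and, for classification, measures the drop in accuracy.

```r
set.seed(7)
n <- 500
# Simulated data: y depends on signal but not on noise
toy <- data.frame(signal = rnorm(n), noise = rnorm(n))
toy$y <- 3 * toy$signal + rnorm(n)

fit <- lm(y ~ signal + noise, data = toy)
mse <- function(model, data) mean((data$y - predict(model, newdata = data))^2)
baseline_mse <- mse(fit, toy)

# Shuffle one predictor at a time and record how much the error goes up
perm_importance <- sapply(c("signal", "noise"), function(v) {
  shuffled <- toy
  shuffled[[v]] <- sample(shuffled[[v]])   # breaks the link between v and y
  mse(fit, shuffled) - baseline_mse
})
perm_importance
```

Shuffling the informative variable inflates the error a lot; shuffling the irrelevant one barely changes it.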

Longhand version

If you are interested, I’ve included some code to do a “longhand” version of what we did above with caret functions, so that you can see a little bit more explicitly how it all works. This includes the training-test split, the centering and scaling of the data, the five-fold cross-validation for each of the ten different values of mtry, and the evaluation of the model using the test set.

library(randomForest)
library(caret)  # for confusionMatrix()

set.seed(1)

# Predictors that remain after removing the highly correlated variables
X_all <- as.matrix(select(date_fruit, -Class))
Y <- factor(date_fruit$Class)

# Let's hold 10% of the data back as a test set. We will do cross-validation and then test on this final boss
test_idx <- sample(1:nrow(X_all), size = round(nrow(X_all) * .1))

# Scale the predictors, using the training set means and SDs for both sets
X_train <- scale(X_all[-test_idx, ])
X_test <- scale(X_all[test_idx, ],
                center = attr(X_train, 'scaled:center'),
                scale = attr(X_train, 'scaled:scale'))

Y_train <- Y[-test_idx]
Y_test <- Y[test_idx]

# Next, let's try to fit a couple of models to classify the different types of date.

# What tuning parameters will we use?
mtrys <- 1:10

# Split the training set into five folds at random
fold_id <- sample(rep_len(1:5, length.out = nrow(X_train)))

# Loop through the mtry values. For each mtry value, loop through the folds and fit k-fold CV
date_cv <- data.frame(mtry = mtrys, error_rate = NA)

for (mtry in mtrys) {
  # Hold-out predictions, stored as a factor with the same levels as Y_train
  Y_pred <- factor(rep(NA, length(Y_train)), levels = levels(Y_train))
  for (fold in 1:5) {
    rf_fit <- randomForest(x = X_train[fold_id != fold, ], 
                           y = Y_train[fold_id != fold],
                           ntree = 1000, mtry = mtry)
    Y_pred[fold_id == fold] <- predict(rf_fit, newdata = X_train[fold_id == fold, ])
  }
  date_cv$error_rate[date_cv$mtry == mtry] <- mean(Y_pred != Y_train)
}

# Plot the cross-validated error rate and pick the mtry that minimizes it
with(date_cv, plot(mtry, error_rate, type = 'l'))
best_mtry <- date_cv$mtry[which.min(date_cv$error_rate)]

# Fit the final model to the full training set
date_finalmodel <- randomForest(x = X_train, y = Y_train, ntree = 1000, mtry = best_mtry)

# Out-of-bag predictions for the training set, and predictions for the test set
Y_train_pred <- predict(date_finalmodel)
Y_test_pred <- predict(date_finalmodel, newdata = X_test)

# Confusion matrices for training and test data
confusionMatrix(Y_train_pred, Y_train)
confusionMatrix(Y_test_pred, Y_test)

Fitting other machine learning models

As we discussed above, if prediction is the goal, it is okay to try other types of classification models on the same dataset. For example, you can fit another type of model called a support vector machine (SVM). I won’t get into the details here, but I just want to illustrate that the train() function can be used to fit another type of model with basically the same arguments. The only differences are the method and tuneGrid arguments. (The 'svmLinear' method uses the kernlab package behind the scenes, so that package needs to be installed.) Here we are trying out different values of the cost parameter C. Like the tuning parameters of many machine learning models, C determines the balance between overfitting and underfitting. Small values of C result in a model that does not fit the data quite as closely, guarding against overfitting but running the risk of underfitting. Larger values of C do the opposite.

date_svm_train <- train(
  form = Class ~ .,
  data = date_train,
  method = 'svmLinear',
  preProcess = c('center', 'scale'),
  tuneGrid = data.frame(C = seq(0.02, 1, by = 0.02)),
  metric = 'Accuracy',
  trControl = cv_spec
)

plot(date_svm_train)

date_svm_predict_test <- predict(date_svm_train, newdata = date_test)

confusionMatrix(date_svm_predict_test, date_test$Class)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction BERHI DEGLET DOKOL IRAQI ROTANA SAFAVI SOGAY
##     BERHI      5      0     0     0      0      0     0
##     DEGLET     0      8     0     0      0      0     2
##     DOKOL      0      1    20     0      0      0     0
##     IRAQI      1      0     0     6      0      0     0
##     ROTANA     0      0     0     1     15      0     0
##     SAFAVI     0      0     0     0      0     18     0
##     SOGAY      0      0     0     0      1      1     7
## 
## Overall Statistics
##                                           
##                Accuracy : 0.9186          
##                  95% CI : (0.8395, 0.9666)
##     No Information Rate : 0.2326          
##     P-Value [Acc > NIR] : < 2.2e-16       
##                                           
##                   Kappa : 0.9018          
##                                           
##  Mcnemar's Test P-Value : NA              
## 
## Statistics by Class:
## 
##                      Class: BERHI Class: DEGLET Class: DOKOL Class: IRAQI
## Sensitivity               0.83333       0.88889       1.0000      0.85714
## Specificity               1.00000       0.97403       0.9848      0.98734
## Pos Pred Value            1.00000       0.80000       0.9524      0.85714
## Neg Pred Value            0.98765       0.98684       1.0000      0.98734
## Prevalence                0.06977       0.10465       0.2326      0.08140
## Detection Rate            0.05814       0.09302       0.2326      0.06977
## Detection Prevalence      0.05814       0.11628       0.2442      0.08140
## Balanced Accuracy         0.91667       0.93146       0.9924      0.92224
##                      Class: ROTANA Class: SAFAVI Class: SOGAY
## Sensitivity                 0.9375        0.9474       0.7778
## Specificity                 0.9857        1.0000       0.9740
## Pos Pred Value              0.9375        1.0000       0.7778
## Neg Pred Value              0.9857        0.9853       0.9740
## Prevalence                  0.1860        0.2209       0.1047
## Detection Rate              0.1744        0.2093       0.0814
## Detection Prevalence        0.1860        0.2093       0.1047
## Balanced Accuracy           0.9616        0.9737       0.8759

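Overall, the accuracy of this model on the test set (~91.9%) is slightly better than that of the random forest model (~89.5%). That difference amounts to 79 versus 77 correctly classified fruit out of 86, and the two 95% confidence intervals overlap heavily, so a test set this small can’t separate the two models with much confidence. Feel free to explore some other options!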

Demo 2: Sugarcane yield prediction

Introducing the example data

The next example we’ll walk through together illustrates a regression task rather than classification. Here, the variable we want to predict, sugarcane yield, is a continuous variable rather than categorical like fruit variety in the previous example. In particular, we are interested in whether we can use photographs taken of the crop in the field to predict sugar yield. This is real data (with random noise added) from an ARS researcher in Florida who flew drones over the crop and took imagery with a hyperspectral camera. There are many columns of data representing reflectance at different wavelengths of light. There are also columns of data with different vegetation indices that are calculated by transforming and combining the reflectance values in different ways; these vegetation indices are thought to be better representations of crop productivity than the raw wavelength data. We will let the regression model decide which of these potential predictors is best at predicting sugar yield.

Image credit Hannah Penn


The outcome variable is an important component of sugarcane yield, tons of sugar per hectare (TSH). There are eight columns of reflectance values for different wavelengths, with names beginning with R, and ten columns of vegetation index values calculated from the reflectances. There is also a numeric column, Rep, with values 1 to 6 indicating which experimental block, or replicate, each row comes from. I’ve removed some extraneous variables from the dataset to simplify the example: the original dataset had multiple crop years (sugarcane is a multi-year crop), as well as information on the variety and fertilization treatment for each row.

We are going to use a form of linear regression called lasso to predict sugar yield from the reflectance and vegetation index values. A lasso regression is very similar to a multiple linear regression, except that a penalty on the sum of the absolute values of the coefficients, controlled by a regularization parameter \(\lambda\) (lambda), is added to the quantity being minimized. This penalty shrinks the regression coefficients toward zero, and some become exactly zero and drop out of the model entirely. Because the coefficients are smaller in magnitude and fewer in number than in standard linear regression without the penalty, we end up with a simpler model that does not fit the training data quite as closely, but is less overfit and often makes better predictions on new data. The larger the \(\lambda\), the stricter the regularization and the more coefficients drop to zero. If \(\lambda = 0\), there is no shrinkage and lasso becomes equivalent to an ordinary linear regression.
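In glmnet's formulation, with alpha = 1, lasso chooses the intercept \(\beta_0\) and coefficients \(\beta_j\) to minimize \(\frac{1}{2n}\sum_{i=1}^{n}\left(y_i - \beta_0 - \sum_{j} x_{ij}\beta_j\right)^2 + \lambda\sum_{j}\lvert\beta_j\rvert\). The absolute-value penalty is what lets coefficients land exactly on zero. glmnet minimizes this with coordinate descent; the base-R function lasso_cd() below is a hypothetical, bare-bones version of that approach for illustration, without glmnet's refinements. Each update soft-thresholds one coefficient: it is pulled toward zero by \(\lambda\) and set to exactly zero if it would cross zero.

```r
# Soft-thresholding operator: shrink z toward zero by lambda, stopping at zero
soft_threshold <- function(z, lambda) sign(z) * pmax(abs(z) - lambda, 0)

# lasso_cd() is a hypothetical, minimal coordinate-descent lasso for illustration.
# It standardizes the predictors and centers y, so no intercept is needed.
lasso_cd <- function(X, y, lambda, max_iter = 1000, tol = 1e-10) {
  X <- scale(X)
  y <- y - mean(y)
  n <- nrow(X)
  beta <- setNames(rep(0, ncol(X)), colnames(X))
  col_ss <- colSums(X^2) / n
  for (iter in seq_len(max_iter)) {
    beta_old <- beta
    for (j in seq_along(beta)) {
      # Partial residual: what is left after using every predictor except j
      r_j <- y - X[, -j, drop = FALSE] %*% beta[-j]
      z <- sum(X[, j] * r_j) / n
      beta[j] <- soft_threshold(z, lambda) / col_ss[j]
    }
    if (max(abs(beta - beta_old)) < tol) break
  }
  beta
}

# Simulated data: only x1 is related to y
set.seed(10)
n <- 100
X <- matrix(rnorm(n * 3), nrow = n, dimnames = list(NULL, c("x1", "x2", "x3")))
y <- 2 * X[, "x1"] + rnorm(n)

round(cbind(lambda_0 = lasso_cd(X, y, 0),
            lambda_0.5 = lasso_cd(X, y, 0.5),
            lambda_10 = lasso_cd(X, y, 10)), 3)
```

With \(\lambda = 0\) the result matches ordinary least squares on the standardized predictors; as \(\lambda\) grows, the coefficients shrink and the irrelevant ones drop to zero first.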

sugarcane <- read_csv('https://usda-ree-ars.github.io/SEAStats/machine_learning_demystified/datasets/sugarcaneyield.csv')

Explore the data

Here we plot the outcome variable versus each of the predictors. A smoothing trend (a generalized additive model smoother, specified by method = 'gam') is plotted over each scatterplot to help visualize any relationship there may be. Some of the predictors do seem to have a weak relationship with the outcome.

sugarcane_predictor_variables <- setdiff(names(sugarcane), c('TSH', 'Rep'))

scatter_plots <- map(sugarcane_predictor_variables, ~ ggplot(sugarcane, aes(x = !!sym(.), y = TSH)) + 
                       geom_point() + 
                       geom_smooth(method = 'gam') +
                       theme_bw())

wrap_plots(scatter_plots, ncol = 4)

Let’s also look at the correlation coefficients between the predictors in the same way as we did for the date fruit variables.

sugarcane_corr <- cor(sugarcane[, sugarcane_predictor_variables])

ggcorrplot(sugarcane_corr, type = 'lower') 

Many of the predictors are highly correlated with one another. This is especially unsurprising for the vegetation indices: after all, the different indices are just different ways of multiplying and adding the same underlying data. Here, instead of removing correlated variables ahead of model fitting, we will leave them in and see how the lasso regularization deals with the collinear variables.

Set up pre-defined cross-validation folds

Here, instead of letting the data points be assigned at random to the cross-validation folds, we are going to specify the folds manually. Because the experiment this dataset comes from had a randomized block design, we will put all data points from the same Rep in the same CV fold, giving six folds, one for each block. We are not going to hold out a separate test set, because here we are more interested in inference. Lasso is essentially linear regression with a penalty added, so we are fitting a statistical model while taking a page out of the ML playbook.

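Here we use the map() function, applied to each rep number, to create two lists: one with the row indices used to fit the model in each fold (every rep except one), and one with the row indices held out in that fold (the remaining rep). Then we pass those lists to the index and indexOut arguments of trainControl() to specify how the cross-validation will be done. Supplying index matters: if only indexOut is given, caret still picks the training rows for each fold at random, so the training and hold-out rows can overlap.

cv_train_rows <- map(1:6, function(x) which(sugarcane$Rep != x))
cv_folds <- map(1:6, function(x) which(sugarcane$Rep == x))
cv_spec_sugarcane <- trainControl(method = 'cv', number = 6, index = cv_train_rows, indexOut = cv_folds)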
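However predefined folds are built, a quick check is worthwhile: every row should be held out exactly once, and no fold should be trained on the rows it is evaluated on. In trainControl(), index holds the training rows for each fold and indexOut the held-out rows. The base-R sketch below runs that check on a hypothetical rep_id vector; in the real analysis you would use sugarcane$Rep and the lists passed to trainControl().

```r
# Hypothetical block labels: 6 reps with 4 rows each
rep_id <- rep(1:6, each = 4)

# Held-out rows for each fold (one rep) and training rows (all other reps)
held_out <- lapply(1:6, function(r) which(rep_id == r))
training <- lapply(1:6, function(r) which(rep_id != r))

# Every row should be held out exactly once across the six folds
held_out_counts <- tabulate(unlist(held_out), nbins = length(rep_id))
# No fold should be trained on any of its own held-out rows
no_overlap <- mapply(function(tr, ho) length(intersect(tr, ho)) == 0,
                     training, held_out)

all(held_out_counts == 1)   # TRUE
all(no_overlap)             # TRUE
```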

Train the model with cross-validation

Once again we are using train() from the caret package to train the model, so this should look familiar to you. An alternative to using the form argument to specify the model formula is to use the arguments x (a data frame of predictors) and y (a vector of outcomes). The glmnet package, which implements lasso regression among other techniques, is used as the model fitting method. As before, we will tell train() to center and scale the predictors. Here tuneGrid includes a range of lambda values increasing by factors of ten, with alpha = 1. (If alpha were set to 0, or allowed to vary between 0 and 1, we would have ridge or elastic net regression, related techniques which we won’t cover in this lesson.) The metric we are going to use to decide which model is the best fit is the root mean squared error: metric = 'RMSE'. This is the square root of the average of the squared prediction errors, so lower RMSE is better. Lastly, we pass cv_spec_sugarcane, the cross-validation folds we just predefined, to the trControl argument.

set.seed(3)

tsh_lasso_fit <- train(
  x = sugarcane %>% select(all_of(sugarcane_predictor_variables)) %>% as.data.frame,
  y = sugarcane$TSH,
  method = 'glmnet',
  preProcess = c('center', 'scale'),
  tuneGrid = expand.grid(
    alpha = 1,
    lambda = c(0.00001, 0.0001, 0.001, 0.01, 0.1)
  ),
  metric = 'RMSE',
  trControl = cv_spec_sugarcane
)
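For reference, RMSE and MAE, both of which train() reports for regression models, are easy to compute by hand. A small base-R sketch with made-up observed and predicted values:

```r
# Made-up observed and predicted yields (tons of sugar per hectare)
observed  <- c(10, 20, 30, 40)
predicted <- c(12, 18, 33, 39)

errors <- observed - predicted     # -2, 2, -3, 1
rmse <- sqrt(mean(errors^2))       # square root of the mean squared error
mae <- mean(abs(errors))           # mean absolute error
c(RMSE = rmse, MAE = mae)
```

Because the errors are squared before averaging, RMSE weights large errors more heavily than MAE does, and it is never smaller than MAE.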

Evaluating the model

We can see what value of \(\lambda\) minimized root mean squared error in the hold-out folds by plotting the model object.

plot(tsh_lasso_fit)

The plot shows that the smallest values of \(\lambda\), with the least amount of shrinkage, are slightly overfit. The best prediction performance (lowest RMSE) is at \(\lambda = 0.01\). Above that, the model becomes underfit and the prediction performance gets worse and worse. Of course, it’s important to note the scale of the y-axis. The errors only vary between ~5.4 and ~5.7, where the range of the data is between about 10 and 45 tons of sugar per hectare. This is a sign that we have a relatively weak signal in our data because even our optimal model can only reduce the average error a tiny bit below what we get with a very high \(\lambda\) value that shrinks almost all the coefficients to zero.

We can see what the optimal \(\lambda\) was by printing the bestTune element from the fitted model object.

tsh_lasso_fit$bestTune
##   alpha lambda
## 4     1   0.01

If you type the name of the model object into the console, it will show the prediction performance for each value of lambda we tested.

tsh_lasso_fit
## glmnet 
## 
## 144 samples
##  18 predictor
## 
## Pre-processing: centered (18), scaled (18) 
## Resampling: Cross-Validated (6 fold) 
## Summary of sample sizes: 120, 120, 120, 120, 120, 120, ... 
## Resampling results across tuning parameters:
## 
##   lambda  RMSE      Rsquared   MAE     
##   1e-05   5.460200  0.3051896  4.480154
##   1e-04   5.460200  0.3051896  4.480154
##   1e-03   5.457436  0.3052400  4.477154
##   1e-02   5.441517  0.3037141  4.449482
##   1e-01   5.727710  0.2408323  4.681241
## 
## Tuning parameter 'alpha' was held constant at a value of 1
## RMSE was used to select the optimal model using the smallest value.
## The final values used for the model were alpha = 1 and lambda = 0.01.
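The last lines of the printout say that RMSE was used to select the optimal model using the smallest value. As a minimal sketch of that selection rule, the base R code below copies the lambda and RMSE columns from the printed output into a data frame (it does not re-run the model), picks the row with the smallest RMSE, and also computes how much the best lambda improves on the most heavily penalized one in relative terms.

```r
# Cross-validated results copied from the printed tsh_lasso_fit output
cv_results <- data.frame(
  lambda = c(1e-05, 1e-04, 1e-03, 1e-02, 1e-01),
  RMSE   = c(5.460200, 5.460200, 5.457436, 5.441517, 5.727710)
)

# metric = 'RMSE' with the default selection rule: take the smallest RMSE
best_row <- which.min(cv_results$RMSE)
cv_results$lambda[best_row]

# Relative improvement of the best lambda over the largest lambda tried
worst_rmse <- cv_results$RMSE[cv_results$lambda == 0.1]
best_rmse  <- cv_results$RMSE[best_row]
rel_improvement <- (worst_rmse - best_rmse) / worst_rmse
rel_improvement
```

With these numbers the best lambda lowers the cross-validated RMSE by only about 5% relative to lambda = 0.1, which puts a figure on how small the gain from tuning is here.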

You can see that for lambda values up to 0.01 the RMSE values all hover a little below 5.5, while the R-squared values never make it much above 0.3. The model performance is fairly insensitive to \(\lambda\) until \(\lambda = 0.1\), where the RMSE jumps up and R-squared drops. I thought this was a nice example dataset to give you, because many machine learning tutorials use simulated or cherry-picked datasets where you can effortlessly build a model with near-perfect predictive performance. This one is more illustrative of a real-world situation where there just isn’t much signal in the data. Other modeling approaches might decrease the prediction error slightly, but there are real limits imposed by the fact that the natural world is a mess. An R-squared of 0.3 means the model captures a meaningful share of the variation in yield, which could be interesting for understanding what is going on, but it is probably inadequate if reliable yield predictions are what you want.

Coefficients

What are the coefficients associated with the optimal model, anyway? The object tsh_lasso_fit contains an element called finalModel, which is the final model fit to the entire training set. But coef(tsh_lasso_fit$finalModel) does not return just the coefficients for the optimal lambda. Because glmnet fits a whole sequence (path) of lambda values, it returns a large matrix of coefficients with one column per lambda. To extract the column we want, we first find the smallest lambda in the final model’s path that is at least as large as the optimal lambda, and then extract the column at that index from the large matrix of coefficients.

I’ve also included some code to illustrate how the lasso regression shrinks the parameter estimates. It fits an ordinary linear regression model with lm() to the same data, without any cross-validation or shrinkage, and displays those coefficients in a column next to the lasso coefficients.

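# smallest lambda in glmnet's path that is at least as large as the tuned lambda
lambda_use <- min(tsh_lasso_fit$finalModel$lambda[tsh_lasso_fit$finalModel$lambda >= tsh_lasso_fit$bestTune$lambda])
# column of the coefficient matrix that corresponds to that lambda
position <- which(tsh_lasso_fit$finalModel$lambda == lambda_use)
best_coefs <- coef(tsh_lasso_fit$finalModel)[, position]

# unpenalized least-squares fit with the same predictors, for comparison
lm_coefs <- lm(TSH ~ ., data = sugarcane %>% select(all_of(c('TSH', sugarcane_predictor_variables))))$coefficients

# show the two sets of coefficients side by side
data.frame(lasso_coefficient = round(best_coefs, 3),
           unshrunk_coefficient = round(lm_coefs, 3))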
##                     lasso_coefficient unshrunk_coefficient
## (Intercept)                    24.042               -2.174
## R480                            8.223               13.668
## R550                           -0.392               -1.050
## R680                           -4.278               -4.889
## R710                           -1.510               -1.488
## R800                           -3.462               -0.548
## R980                            9.264                1.525
## R1200                          -4.766               -1.084
## R1660                          -3.339               -1.066
## NDVIRE800_710                  -5.306             -107.589
## NDVI800_680                    -0.830              -62.799
## CI800_710                      -2.566               -5.251
## MTCI800_710_710_680             1.541                3.365
## SAVI800_680                     0.247               46.189
## CCCI710_680                    -1.441              -65.323
## SR800_680                       2.364                0.987
## SR800_710                       0.000               -1.679
## GNDVI800_550                    6.708              201.650
## NGRDI550_680                    0.323               21.259
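The column-lookup step above can be isolated with a toy example in base R. In this minimal sketch, lambda_path is a made-up, decreasing sequence standing in for tsh_lasso_fit$finalModel$lambda (glmnet stores its path from largest to smallest lambda), best_lambda stands in for tsh_lasso_fit$bestTune$lambda, and coef_matrix is a hypothetical two-row matrix standing in for the real coefficient matrix.

```r
# Made-up lambda path (largest to smallest), standing in for finalModel$lambda
lambda_path <- c(2.5, 1.2, 0.6, 0.3, 0.12, 0.05, 0.021, 0.009, 0.004)
best_lambda <- 0.01  # stands in for tsh_lasso_fit$bestTune$lambda

# Same rule as the lesson's code: smallest path value >= the tuned lambda
lambda_use <- min(lambda_path[lambda_path >= best_lambda])
position   <- which(lambda_path == lambda_use)

# Hypothetical coefficient matrix: one row per coefficient, one column per lambda
coef_matrix <- matrix(seq_len(2 * length(lambda_path)), nrow = 2,
                      dimnames = list(c("(Intercept)", "x1"), NULL))
coef_matrix[, position]
```

Here the rule picks 0.021, the path value just above 0.01, rather than 0.009, which is slightly closer. glmnet’s coef() method also accepts an s argument, as in coef(tsh_lasso_fit$finalModel, s = tsh_lasso_fit$bestTune$lambda), which interpolates between path values; that may be a simpler route, although I have not checked that it reproduces the table above exactly.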

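In this case, only one coefficient, SR800_710, has been shrunk all the way to zero (it prints as 0.000). Many of the others look dramatically smaller than their unshrunk counterparts, and a few are even larger. Keep in mind, though, that the two columns are not on the same scale: caret centered and scaled the predictors before fitting the lasso, so its coefficients are per standard deviation of each predictor, while lm() used the raw predictor values (this is also why the two intercepts are so different). The side-by-side comparison is therefore only a rough illustration. Because all the lasso coefficients share the same standardized scale, their magnitudes can be compared with each other, and you could say that the predictors with the largest lasso coefficients (R480, R980, and GNDVI800_550) might be the most important for predicting yield. Apart from SR800_710, the lasso regression didn’t eliminate any variables from consideration.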
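Why can the lasso set a coefficient exactly to zero, as it did for SR800_710? In the idealized case where the standardized predictors are completely uncorrelated, the lasso solution for the penalty shown by glmnet reduces to “soft-thresholding” the ordinary least-squares estimates: each one is pulled toward zero by \(\lambda\), and any estimate smaller than \(\lambda\) in magnitude becomes exactly zero. The minimal sketch below assumes that idealized case; soft_threshold() is a hypothetical helper and the estimates are made up. Real predictors such as these reflectance bands and vegetation indices, many of which are computed from the same bands, are strongly correlated, so the lasso does not reduce to this simple rule, which is how some coefficients can end up larger than their unshrunk counterparts.

```r
# Lasso solution for uncorrelated, standardized predictors: shift each OLS
# estimate toward zero by lambda, setting it to zero if it would cross zero
soft_threshold <- function(b, lambda) {
  sign(b) * pmax(abs(b) - lambda, 0)
}

# Made-up OLS estimates for illustration (not the sugarcane coefficients)
b <- c(3, -1.5, 0.4, -0.2)
soft_threshold(b, lambda = 0.5)
```

The two small estimates are dropped entirely while the two large ones survive, shrunk by 0.5 each.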

However, I would argue that the prediction performance of this model is not good enough to take it to the big time and deploy it in the field to predict yield from other drone imagery datasets. I would tend to advise against spending a lot of time with a kitchen-sink approach, trying out every model under the sun. No model is magic. If there’s a strong relationship, even a simple model should be able to pull it out and make decent predictions. But if there’s no relationship, the most sophisticated model in the world will not save you.

Further reading

This lesson is only a starting point. Here are some books and tutorials that might be helpful resources if you want to learn more about machine learning. The internet is chock-full of machine learning tutorials, which vary widely in quality. These are a few examples that I know are pretty good, focused on R.

  • Introduction to Statistical Learning: This is really the one book you should read if you want an accessible and easy-to-read introduction to machine learning. It was written by some very accomplished statisticians who pioneered some of the foundational techniques of machine learning. It is rare to find true experts who are also that good at explaining the topic of their expertise in such an understandable way. Better yet, it has versions with both R and Python examples and is available as a free download.
  • Applied Predictive Modeling: This book was co-written by Max Kuhn, the developer of the caret package that we used extensively in this lesson. It covers a wide range of models.
  • Series of introductory tidymodels articles: Though I didn’t use the tidymodels framework in this lesson, it is the direction a lot of machine learning practitioners who use R are going. This series of tutorial articles will introduce you to their recommended “pipeline” for preprocessing, evaluating, and tuning a predictive model.

Exercises

Exercise 1

Difficulty: easy

Set a new seed for the random number generator by replacing the line set.seed(1) with set.seed(2). Then train the date variety classification random forest model again and display the results for the training and testing sets. Do the results differ? Why or why not?

Exercise 2

Difficulty: challenging

Train another ML model to classify the date varieties. How do the results compare to the random forest and SVM?

Hint: Try one of the classification models from the list of available caret models. You may need to search online for documentation to find what tuning parameters you need to specify, because the default search grid may not always give you the best values.

Exercise 3

Difficulty: easy

Train the sugarcane yield regression model again using alpha = 0, which is ridge regression, a different form of regularization than the lasso. Do the results differ? Why or why not?

Exercise 4

Difficulty: challenging

Train another ML model to predict sugarcane yield. How do the results compare to lasso and ridge?

Hint: As in exercise 2, try one of the regression models from the list of available caret models and do some research in the documentation to find the best range for the tuning parameters.

Click here for answers