A Convolutional Neural Net for Ordinal Regression using CORN -- MNIST Dataset
In this tutorial, we implement a convolutional neural network for ordinal regression based on the CORN method. To learn more about CORN, please have a look at our preprint:
- Xintong Shi, Wenzhi Cao, and Sebastian Raschka (2021). Deep Neural Networks for Rank-Consistent Ordinal Regression Based On Conditional Probabilities. Arxiv preprint; https://arxiv.org/abs/2111.08851
Please note that MNIST is not an ordinal dataset. We use it in this tutorial because it is included in PyTorch's torchvision
library and is thus easy to work with, since it doesn't require extra data downloading and preprocessing steps.
General settings and hyperparameters
- Here, we specify some general hyperparameter values and general settings.
- Note that for small datasets, using multiple workers is unnecessary and can sometimes cause issues with too many open files in PyTorch. So, if you have problems with the data loader later, try setting
NUM_WORKERS = 0
instead.
BATCH_SIZE = 256
NUM_EPOCHS = 20
LEARNING_RATE = 0.005
NUM_WORKERS = 4
DATA_BASEPATH = "./"
Converting a regular classifier into a CORN ordinal regression model
Changing a classifier to a CORN model for ordinal regression is simple and requires only three changes:
1)
Consider the following output layer used by a neural network classifier:
output_layer = torch.nn.Linear(hidden_units[-1], num_classes)
In CORN we reduce the number of classes by 1:
output_layer = torch.nn.Linear(hidden_units[-1], num_classes-1)
2)
We swap the cross entropy loss from PyTorch,
torch.nn.functional.cross_entropy(logits, true_labels)
with the CORN loss (also provided via coral_pytorch):
loss = corn_loss(logits, true_labels,
num_classes=num_classes)
Note that we pass num_classes instead of num_classes-1 to corn_loss, as it takes care of the rest internally.
3)
In a regular classifier, we usually obtain the predicted class labels as follows:
predicted_labels = torch.argmax(logits, dim=1)
In CORN, we replace this with the following code to convert the logits into the predicted labels:
predicted_labels = corn_label_from_logits(logits)
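To make the label conversion in step 3 more concrete, here is a minimal sketch in plain Python (no PyTorch) of the idea behind CORN as described in the preprint. Each of the num_classes-1 logits is mapped through a sigmoid to a conditional probability, the running product of these probabilities estimates the probability that the rank exceeds each threshold, and the predicted label is the number of thresholds whose probability is above 0.5. The function name corn_label_sketch and the example logits are made up for illustration; this is not the coral_pytorch implementation itself.

```python
import math


def corn_label_sketch(logits):
    """Convert one example's num_classes-1 logits into a rank label."""
    label = 0
    cumulative_proba = 1.0
    for logit in logits:
        # Conditional probability P(y > k | y > k-1) from the k-th logit
        conditional_proba = 1.0 / (1.0 + math.exp(-logit))
        # Chain rule: P(y > k) is the running product of the conditionals
        cumulative_proba *= conditional_proba
        # Count every threshold that is exceeded with probability > 0.5
        if cumulative_proba > 0.5:
            label += 1
    return label


# Hypothetical logits for a 5-class problem (4 output nodes)
print(corn_label_sketch([2.0, 1.0, -1.0, -3.0]))  # prints 2
```

Because the running product can only shrink, the thresholds are exceeded in order, which is what makes the predictions rank-consistent.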
Implementing a ConvNet using PyTorch Lightning's LightningModule
- In this section, we set up the main model architecture using the LightningModule from PyTorch Lightning.
- We start with defining our convolutional neural network ConvNet model in pure PyTorch, and then we use it in the LightningModule to get all the extra benefits that PyTorch Lightning provides.
import torch


# Regular PyTorch Module
class ConvNet(torch.nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        # num_classes is used by the corn loss function
        self.num_classes = num_classes

        # Initialize CNN layers
        all_layers = [
            torch.nn.Conv2d(in_channels=in_channels, out_channels=3,
                            kernel_size=(3, 3), stride=(1, 1),
                            padding=1),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            torch.nn.Conv2d(in_channels=3, out_channels=6,
                            kernel_size=(3, 3), stride=(1, 1),
                            padding=1),
            torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            torch.nn.Flatten()
        ]

        # CORN output layer --------------------------------------
        # Regular classifier would use num_classes instead of
        # num_classes-1 below
        # (294 = 6 channels * 7 * 7 for 28x28 input images)
        output_layer = torch.nn.Linear(294, num_classes-1)
        # ---------------------------------------------------------

        all_layers.append(output_layer)
        self.model = torch.nn.Sequential(*all_layers)

    def forward(self, x):
        x = self.model(x)
        return x
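The hard-coded 294 input features of the output layer follow from the layer shapes. As a rough check, the sketch below recomputes the flattened size for 28x28 MNIST inputs using the standard output-size formula for convolution and pooling layers. The helper name out_size is made up for this illustration.

```python
def out_size(size, kernel, stride, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                             # MNIST images are 28x28
size = out_size(size, kernel=3, stride=1, padding=1)  # conv 1     -> 28
size = out_size(size, kernel=2, stride=2)             # max pool 1 -> 14
size = out_size(size, kernel=3, stride=1, padding=1)  # conv 2     -> 14
size = out_size(size, kernel=2, stride=2)             # max pool 2 -> 7

out_channels = 6  # channels produced by the second convolution
flattened = out_channels * size * size
print(flattened)  # prints 294

# Parameter count (weights + biases) for num_classes = 10
num_classes = 10
conv1 = 1 * 3 * 3 * 3 + 3
conv2 = 3 * 6 * 3 * 3 + 6
linear = flattened * (num_classes - 1) + (num_classes - 1)
total_params = conv1 + conv2 + linear
print(total_params)  # prints 2853
```

The total of 2,853 parameters agrees with the 2.9 K reported in the Trainer's model summary later in this tutorial.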
- In our LightningModule we use loggers to track mean absolute errors for both the training and validation set during training; this allows us to select the best model based on validation set performance later.
- Given a CNN classifier with cross-entropy loss, it is very easy to change this classifier into an ordinal regression model using CORN. In essence, it only requires three changes:
  - Instead of using num_classes in the output layer, use num_classes-1 as shown above.
  - Change the loss from
    loss = torch.nn.functional.cross_entropy(logits, y)
    to
    loss = corn_loss(logits, y, num_classes=self.num_classes)
  - To obtain the class/rank labels from the logits, change
    predicted_labels = torch.argmax(logits, dim=1)
    to
    predicted_labels = corn_label_from_logits(logits)
from coral_pytorch.losses import corn_loss
from coral_pytorch.dataset import corn_label_from_logits

import pytorch_lightning as pl
import torchmetrics


# LightningModule that receives a PyTorch model as input
class LightningCNN(pl.LightningModule):
    def __init__(self, model, learning_rate):
        super().__init__()

        self.learning_rate = learning_rate
        # The inherited PyTorch module
        self.model = model

        # Save settings and hyperparameters to the log directory
        # but skip the model parameters
        self.save_hyperparameters(ignore=['model'])

        # Set up attributes for computing the MAE
        self.train_mae = torchmetrics.MeanAbsoluteError()
        self.valid_mae = torchmetrics.MeanAbsoluteError()
        self.test_mae = torchmetrics.MeanAbsoluteError()

    # Defining the forward method is only necessary
    # if you want to use a Trainer's .predict() method (optional)
    def forward(self, x):
        return self.model(x)

    # A common forward step to compute the loss and labels
    # this is used for training, validation, and testing below
    def _shared_step(self, batch):
        features, true_labels = batch
        logits = self(features)

        # Use CORN loss --------------------------------------
        # A regular classifier uses:
        # loss = torch.nn.functional.cross_entropy(logits, y)
        loss = corn_loss(logits, true_labels,
                         num_classes=self.model.num_classes)
        # ----------------------------------------------------

        # CORN logits to labels ------------------------------
        # A regular classifier uses:
        # predicted_labels = torch.argmax(logits, dim=1)
        predicted_labels = corn_label_from_logits(logits)
        # ----------------------------------------------------

        return loss, true_labels, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss)
        self.train_mae(predicted_labels, true_labels)
        self.log("train_mae", self.train_mae, on_epoch=True, on_step=False)
        return loss  # this is passed to the optimizer for training

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("valid_loss", loss)
        self.valid_mae(predicted_labels, true_labels)
        self.log("valid_mae", self.valid_mae,
                 on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_mae(predicted_labels, true_labels)
        self.log("test_mae", self.test_mae, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
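To give an intuition for what the corn_loss call in _shared_step computes, below is a minimal plain-Python sketch based on our reading of the CORN preprint. For each of the num_classes-1 binary tasks k, only examples with a label larger than k-1 are used, the binary target is whether the label exceeds k, and the summed binary cross-entropy is divided by the number of (example, task) pairs used. The names corn_loss_sketch and log_sigmoid as well as the toy inputs are assumptions for illustration; in practice, use corn_loss from coral_pytorch.

```python
import math


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def corn_loss_sketch(logits, labels, num_classes):
    """logits: one list of num_classes-1 values per example; labels: ints from 0."""
    total, num_used = 0.0, 0
    for k in range(num_classes - 1):
        for example_logits, y in zip(logits, labels):
            if y <= k - 1:
                continue  # not part of the conditional subset for task k
            target = 1.0 if y > k else 0.0
            z = example_logits[k]
            # Binary cross-entropy; log(1 - sigmoid(z)) = log_sigmoid(z) - z
            total -= target * log_sigmoid(z) + (1.0 - target) * (log_sigmoid(z) - z)
            num_used += 1
    return total / num_used


# Toy example: 3 classes, hence 2 logits per example
loss = corn_loss_sketch([[0.0, 0.0], [0.0, 0.0]], labels=[0, 2], num_classes=3)
print(round(loss, 4))  # prints 0.6931
```

With all logits at zero, every conditional probability is 0.5, so each term contributes log(2), which is the value printed above.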
Setting up the dataset
- In this section, we are going to set up our dataset.
- Please note that MNIST is not an ordinal dataset. We use it in this tutorial because it is included in PyTorch's torchvision library and is thus easy to work with, since it doesn't require extra data downloading and preprocessing steps.
Inspecting the dataset
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
train_dataset = datasets.MNIST(root=DATA_BASEPATH,
train=True,
transform=transforms.ToTensor(),
download=True)
train_loader = DataLoader(dataset=train_dataset,
batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS,
drop_last=True,
shuffle=True)
test_dataset = datasets.MNIST(root=DATA_BASEPATH,
train=False,
transform=transforms.ToTensor())
test_loader = DataLoader(dataset=test_dataset,
batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS,
drop_last=False,
shuffle=False)
# Checking the dataset
all_train_labels = []
all_test_labels = []

for images, labels in train_loader:
    all_train_labels.append(labels)
all_train_labels = torch.cat(all_train_labels)

for images, labels in test_loader:
    all_test_labels.append(labels)
all_test_labels = torch.cat(all_test_labels)
print('Training labels:', torch.unique(all_train_labels))
print('Training label distribution:', torch.bincount(all_train_labels))
print('\nTest labels:', torch.unique(all_test_labels))
print('Test label distribution:', torch.bincount(all_test_labels))
Training labels: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Training label distribution: tensor([5911, 6730, 5949, 6125, 5832, 5410, 5911, 6254, 5841, 5941])
Test labels: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Test label distribution: tensor([ 980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009])
- Above, we can see that the labels are the ten digit classes 0 to 9. The counts cover 59,904 training examples (drop_last=True discards the final incomplete batch of the 60,000 training images) and 10,000 test examples.
- The labels already start at zero (0 to 9), so no label normalization is needed here.
- Notice also that the classes are roughly balanced: the training counts range from 5,410 to 6,730 per class.
Performance baseline
- Especially for imbalanced datasets, it's quite useful to compute a performance baseline.
- In classification contexts, a useful baseline is to compute the accuracy for a scenario where the model always predicts the majority class -- you want your model to be better than that!
- Note that if you are interested in a single number that minimizes the dataset mean squared error (MSE), that's the mean; similarly, the median is a number that minimizes the mean absolute error (MAE).
- So, if we use the mean absolute error, $\mathrm{MAE}=\frac{1}{N} \sum_{i=1}^{N}\left|y_i-\hat{y}_i\right|$, to evaluate the model, it is useful to compute the MAE pretending the predicted label is always the median:
all_test_labels = all_test_labels.float()
avg_prediction = torch.median(all_test_labels) # median minimizes MAE
baseline_mae = torch.mean(torch.abs(all_test_labels - avg_prediction))
print(f'Baseline MAE: {baseline_mae:.2f}')
Baseline MAE: 2.52
- In other words, a model that always predicts the dataset median would achieve an MAE of 2.52. A model with an MAE above 2.52 is certainly a bad model.
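As a quick cross-check of this baseline, the MAE of any constant prediction can be recomputed in plain Python directly from the test label distribution printed earlier. The sketch below does this for every digit 0 to 9 and confirms that the median (4) gives the smallest MAE. The variable names are chosen for illustration only.

```python
# Test label counts for digits 0..9, copied from the output above
test_counts = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]
num_examples = sum(test_counts)

# MAE obtained when always predicting the same label c
mae_per_constant = [
    sum(count * abs(label - c) for label, count in enumerate(test_counts))
    / num_examples
    for c in range(len(test_counts))
]

best_constant = min(range(len(test_counts)), key=lambda c: mae_per_constant[c])
print(best_constant)                             # prints 4
print(f"{mae_per_constant[best_constant]:.2f}")  # prints 2.52
```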
Setting up a DataModule
- There are three main ways we can prepare the dataset for Lightning. We can
  1. make the dataset part of the model;
  2. set up the data loaders as usual and feed them to the fit method of a Lightning Trainer -- the Trainer is introduced in the next subsection;
  3. create a LightningDataModule.
- Here, we are going to use approach 3, which is the most organized approach. The LightningDataModule consists of several self-explanatory methods as we can see below:
import os
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader
class DataModule(pl.LightningDataModule):
    def __init__(self, data_path='./'):
        super().__init__()
        self.data_path = data_path

    def prepare_data(self):
        datasets.MNIST(root=self.data_path,
                       download=True)
        return

    def setup(self, stage=None):
        # Note transforms.ToTensor() scales input images
        # to 0-1 range
        train = datasets.MNIST(root=self.data_path,
                               train=True,
                               transform=transforms.ToTensor(),
                               download=False)

        self.test = datasets.MNIST(root=self.data_path,
                                   train=False,
                                   transform=transforms.ToTensor(),
                                   download=False)

        self.train, self.valid = random_split(train, lengths=[55000, 5000])

    def train_dataloader(self):
        train_loader = DataLoader(dataset=self.train,
                                  batch_size=BATCH_SIZE,
                                  drop_last=True,
                                  shuffle=True,
                                  num_workers=NUM_WORKERS)
        return train_loader

    def val_dataloader(self):
        valid_loader = DataLoader(dataset=self.valid,
                                  batch_size=BATCH_SIZE,
                                  drop_last=False,
                                  shuffle=False,
                                  num_workers=NUM_WORKERS)
        return valid_loader

    def test_dataloader(self):
        test_loader = DataLoader(dataset=self.test,
                                 batch_size=BATCH_SIZE,
                                 drop_last=False,
                                 shuffle=False,
                                 num_workers=NUM_WORKERS)
        return test_loader
- Note that the prepare_data method is usually used for steps that only need to be executed once, for example, downloading the dataset; the setup method defines the dataset loading -- if you run your code in a distributed setting, this will be called on each node / GPU.
- Next, let's initialize the DataModule; we use a random seed for reproducibility (so that the data set is shuffled the same way when we re-execute this code):
torch.manual_seed(1)
data_module = DataModule(data_path=DATA_BASEPATH)
Training the model using the PyTorch Lightning Trainer class
- Next, we initialize our CNN (ConvNet) model.
- Also, we define a callback so that we can obtain the model with the best validation set performance after training.
- PyTorch Lightning supports many advanced loggers such as Weights & Biases. Here, we will keep things simple and use the CSVLogger:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
pytorch_model = ConvNet(
    in_channels=1,
    num_classes=torch.unique(all_test_labels).shape[0])

lightning_model = LightningCNN(
    pytorch_model, learning_rate=LEARNING_RATE)

callbacks = [ModelCheckpoint(
    save_top_k=1, mode='min', monitor="valid_mae")]  # save top 1 model
logger = CSVLogger(save_dir="logs/", name="cnn-corn-mnist")
- Now it's time to train our model:
import time

trainer = pl.Trainer(
    max_epochs=NUM_EPOCHS,
    callbacks=callbacks,
    # Note: progress_bar_refresh_rate is deprecated since v1.5 (see the
    # warning in the output below); newer versions expect a
    # TQDMProgressBar(refresh_rate=50) callback instead
    progress_bar_refresh_rate=50,  # recommended for notebooks
    accelerator="auto",  # Uses GPUs or TPUs if available
    devices="auto",  # Uses all available GPUs/TPUs if applicable
    logger=logger,
    deterministic=True,
    log_every_n_steps=10)

start_time = time.time()
trainer.fit(model=lightning_model, datamodule=data_module)

runtime = (time.time() - start_time)/60
print(f"Training took {runtime:.2f} min in total.")
/home/jovyan/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:96: LightningDeprecationWarning: Setting `Trainer(progress_bar_refresh_rate=50)` is deprecated in v1.5 and will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress bar pass `enable_progress_bar = False` to the Trainer.
rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs/cnn-corn-mnist
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
------------------------------------------------
0 | model | ConvNet | 2.9 K
1 | train_mae | MeanAbsoluteError | 0
2 | valid_mae | MeanAbsoluteError | 0
3 | test_mae | MeanAbsoluteError | 0
------------------------------------------------
2.9 K Trainable params
0 Non-trainable params
2.9 K Total params
0.011 Total estimated model params size (MB)
Training took 1.38 min in total.
Evaluating the model
- After training, let's plot our training MAE and validation MAE using pandas, which, in turn, uses matplotlib for plotting (you may want to consider a more advanced logger that does that for you):
import pandas as pd
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
aggreg_metrics = []
agg_col = "epoch"
for i, dfg in metrics.groupby(agg_col):
    agg = dict(dfg.mean())
    agg[agg_col] = i
    aggreg_metrics.append(agg)

df_metrics = pd.DataFrame(aggreg_metrics)
df_metrics[["train_loss", "valid_loss"]].plot(
    grid=True, legend=True, xlabel='Epoch', ylabel='Loss')
df_metrics[["train_mae", "valid_mae"]].plot(
    grid=True, legend=True, xlabel='Epoch', ylabel='MAE')
[Plots: training and validation loss per epoch; training and validation MAE per epoch]
- As we can see from the loss plot above, the model starts overfitting pretty quickly; however, the validation set MAE keeps improving. Based on the MAE plot, we can see that the best model, based on the validation set MAE, may be around epoch 16.
- The trainer saved this model automatically for us, and we can load it from the checkpoint via the ckpt_path='best' argument; below we use the trainer instance to evaluate the best model on the test set:
trainer.test(model=lightning_model, datamodule=data_module, ckpt_path='best')
Restoring states from the checkpoint path at logs/cnn-corn-mnist/version_0/checkpoints/epoch=17-step=3852.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at logs/cnn-corn-mnist/version_0/checkpoints/epoch=17-step=3852.ckpt
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_mae          │    0.11959999799728394    │
└───────────────────────────┴───────────────────────────┘
[{'test_mae': 0.11959999799728394}]
- The MAE of our model is quite good, especially compared to the 2.52 MAE baseline earlier.
Predicting labels of new data
- You can use the trainer.predict method on a new DataLoader or DataModule to apply the model to new data.
- Alternatively, you can also manually load the best model from a checkpoint as shown below:
path = trainer.checkpoint_callback.best_model_path
print(path)
logs/cnn-corn-mnist/version_0/checkpoints/epoch=17-step=3852.ckpt
lightning_model = LightningCNN.load_from_checkpoint(path, model=pytorch_model)
lightning_model.eval();
- Note that LightningCNN requires input arguments. The learning_rate is restored automatically since we used self.save_hyperparameters() in LightningCNN's __init__ method; because we excluded the model there via ignore=['model'], we pass our ConvNet (pytorch_model) to load_from_checkpoint explicitly.
- Now, below is an example applying the model manually. Here, pretend that the test_dataloader is a new data loader.
test_dataloader = data_module.test_dataloader()
all_predicted_labels = []
for batch in test_dataloader:
    features, _ = batch
    with torch.no_grad():  # no gradients are needed for inference
        logits = lightning_model(features)
    predicted_labels = corn_label_from_logits(logits)
    all_predicted_labels.append(predicted_labels)

all_predicted_labels = torch.cat(all_predicted_labels)
all_predicted_labels[:5]
tensor([7, 2, 1, 0, 4])