Skip to content

A Multilayer Perceptron for Ordinal Regression using CORAL -- Cement Dataset

In this tutorial, we implement a multilayer perceptron for ordinal regression based on the CORAL method. To learn more about CORAL, please have a look at our paper:

General settings and hyperparameters

  • Here, we specify some general hyperparameter values and general settings
  • Note that for small datatsets, it is not necessary and better not to use multiple workers as it can sometimes cause issues with too many open files in PyTorch
BATCH_SIZE = 32
NUM_EPOCHS = 200
LEARNING_RATE = 0.01
NUM_WORKERS = 0

DATA_BASEPATH = "./data"

Converting a regular classifier into a CORAL ordinal regression model

Changing a classifier to a CORAL model for ordinal regression is actually really simple and only requires a few changes:

1) We replace the output layer

output_layer = torch.nn.Linear(hidden_units[-1], num_classes)

by a CORAL layer (available through coral_pytorch):

output_layer = CoralLayer(size_in=hidden_units[-1], num_classes=num_classes)`

2)

Convert the integer class labels into the extended binary label format using the levels_from_labelbatch provided via coral_pytorch:

levels = levels_from_labelbatch(class_labels, 
                                num_classes=num_classes)

3)

Swap the cross entropy loss from PyTorch,

torch.nn.functional.cross_entropy(logits, true_labels)

with the CORAL loss (also provided via coral_pytorch):

loss = coral_loss(logits, levels)

4)

In a regular classifier, we usually obtain the predicted class labels as follows:

predicted_labels = torch.argmax(logits, dim=1)

Replace this with the following code to convert the predicted probabilities into the predicted labels:

predicted_labels = proba_to_label(probas)

Implementing a MultiLayerPerceptron using PyTorch Lightning's LightningModule

  • In this section, we set up the main model architecture using the LightningModule from PyTorch Lightning.
  • We start with defining our MultiLayerPerceptron model in pure PyTorch, and then we use it in the LightningModule to get all the extra benefits that PyTorch Lightning provides.
  • Given a multilayer perceptron classifier with cross-entropy loss, it is very easy to change this classifier into a ordinal regression model using CORAL as explained in the previous section. In the code example below, we use "1) the CoralLayer".
import torch
from coral_pytorch.layers import CoralLayer


# Regular PyTorch Module
class MultiLayerPerceptron(torch.nn.Module):
    def __init__(self, input_size, hidden_units, num_classes):
        super().__init__()

        # num_classes is used by the CORAL loss function
        self.num_classes = num_classes

        # Initialize MLP layers
        all_layers = []
        for hidden_unit in hidden_units:
            layer = torch.nn.Linear(input_size, hidden_unit)
            all_layers.append(layer)
            all_layers.append(torch.nn.ReLU())
            input_size = hidden_unit

        # CORAL: output layer -------------------------------------------
        # Regular classifier would use the following output layer:
        # output_layer = torch.nn.Linear(hidden_units[-1], num_classes)

        # We replace it by the CORAL layer:
        output_layer = CoralLayer(size_in=hidden_units[-1],
                                  num_classes=num_classes)
        # ----------------------------------------------------------------

        all_layers.append(output_layer)
        self.model = torch.nn.Sequential(*all_layers)

    def forward(self, x):
        x = self.model(x)
        return x
  • In our LightningModule we use loggers to track mean absolute errors for both the training and validation set during training; this allows us to select the best model based on validation set performance later.
  • Note that we make changes 2) (levels_from_labelbatch), 3) (coral_loss), and 4) (proba_to_label) to implement a CORAL model instead of a regular classifier:
from coral_pytorch.losses import coral_loss
from coral_pytorch.dataset import levels_from_labelbatch
from coral_pytorch.dataset import proba_to_label

import pytorch_lightning as pl
import torchmetrics


# LightningModule that receives a PyTorch model as input
class LightningMLP(pl.LightningModule):
    def __init__(self, model, learning_rate):
        super().__init__()

        self.learning_rate = learning_rate
        # The inherited PyTorch module
        self.model = model

        # Save settings and hyperparameters to the log directory
        # but skip the model parameters
        self.save_hyperparameters(ignore=['model'])

        # Set up attributes for computing the MAE
        self.train_mae = torchmetrics.MeanAbsoluteError()
        self.valid_mae = torchmetrics.MeanAbsoluteError()
        self.test_mae = torchmetrics.MeanAbsoluteError()

    # Defining the forward method is only necessary 
    # if you want to use a Trainer's .predict() method (optional)
    def forward(self, x):
        return self.model(x)

    # A common forward step to compute the loss and labels
    # this is used for training, validation, and testing below
    def _shared_step(self, batch):
        features, true_labels = batch

        # Convert class labels for CORAL ------------------------
        levels = levels_from_labelbatch(
            true_labels, num_classes=self.model.num_classes)
        # -------------------------------------------------------

        logits = self(features)

        # CORAL Loss --------------------------------------------
        # A regular classifier uses:
        # loss = torch.nn.functional.cross_entropy(logits, true_labels)
        loss = coral_loss(logits, levels.type_as(logits))
        # -------------------------------------------------------

        # CORAL Prediction to label -----------------------------
        # A regular classifier uses:
        # predicted_labels = torch.argmax(logits, dim=1)
        probas = torch.sigmoid(logits)
        predicted_labels = proba_to_label(probas)
        # -------------------------------------------------------
        return loss, true_labels, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss)
        self.train_mae(predicted_labels, true_labels)
        self.log("train_mae", self.train_mae, on_epoch=True, on_step=False)
        return loss  # this is passed to the optimzer for training

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("valid_loss", loss)
        self.valid_mae(predicted_labels, true_labels)
        self.log("valid_mae", self.valid_mae,
                 on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels = self._shared_step(batch)
        self.test_mae(predicted_labels, true_labels)
        self.log("test_mae", self.test_mae, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

Setting up the dataset

  • In this section, we are going to set up our dataset.
  • We start by downloading and taking a look at the Cement dataset:

Inspecting the dataset

import pandas as pd
import numpy as np


data_df = pd.read_csv("https://raw.githubusercontent.com/gagolews/"
                      "ordinal_regression_data/master/cement_strength.csv")
data_df["response"] = data_df["response"]-1  # labels should start at 0

data_labels = data_df["response"]
data_features = data_df.loc[:, [
    "V1", "V2", "V3", "V4", "V5", "V6", "V7", "V8"]]

data_df.head()
response V1 V2 V3 V4 V5 V6 V7 V8
0 4 540.0 0.0 0.0 162.0 2.5 1040.0 676.0 28
1 4 540.0 0.0 0.0 162.0 2.5 1055.0 676.0 28
2 2 332.5 142.5 0.0 228.0 0.0 932.0 594.0 270
3 2 332.5 142.5 0.0 228.0 0.0 932.0 594.0 365
4 2 198.6 132.4 0.0 192.0 0.0 978.4 825.5 360
print('Number of features:', data_features.shape[1])
print('Number of examples:', data_features.shape[0])
print('Labels:', np.unique(data_labels.values))
print('Label distribution:', np.bincount(data_labels))
Number of features: 8
Number of examples: 998
Labels: [0 1 2 3 4]
Label distribution: [196 310 244 152  96]
  • Above, we can see that the dataset consists of 8 features, and there are 998 examples in total.
  • The labels are in range from 1 (weakest) to 5 (strongest), and we normalize them to start at zero (hence, the normalized labels are in the range 0 to 4).
  • Notice also that the dataset is quite imbalanced.

Performance baseline

  • Especially for imbalanced datasets, it's quite useful to compute a performance baseline.
  • In classification contexts, a useful baseline is to compute the accuracy for a scenario where the model always predicts the majority class -- you want your model to be better than that!
  • Note that if you are intersted in a single number that minimized the dataset mean squared error (MSE), that's the mean; similary, the median is a number that minimzes the mean absolute error (MAE).
  • So, if we use the mean absolute error, , to evaluate the model, it is useful to compute the MAE pretending the predicted label is always the median:
avg_prediction = np.median(data_labels.values)  # median minimizes MAE
baseline_mae = np.mean(np.abs(data_labels.values - avg_prediction))
print(f'Baseline MAE: {baseline_mae:.2f}')
Baseline MAE: 1.03
  • In other words, a model that would always predict the dataset median would achieve a MAE of 1.03. A model that has an MAE of > 1 is certainly a bad model.

Creating a Dataset class

  • Next, let us set up a data loading mechanism for our model.
  • Note that the Cement dataset is a relatively small dataset that fits into memory quite comfortably so this may seem like overkill. However, the following steps are useful as a template since you can use those for arbitrarily-sized datatsets.
  • First, we define a PyTorch Dataset class that returns the features (inputs) and labels:
from torch.utils.data import Dataset


class MyDataset(Dataset):

    def __init__(self, feature_array, label_array, dtype=np.float32):
        self.features = feature_array.astype(dtype)
        self.labels = label_array

    def __getitem__(self, index):
        inputs = self.features[index]
        label = self.labels[index]
        return inputs, label

    def __len__(self):
        return self.features.shape[0]

Setting up a DataModule

  • There are three main ways we can prepare the dataset for Lightning. We can
  • make the dataset part of the model;
  • set up the data loaders as usual and feed them to the fit method of a Lightning Trainer -- the Trainer is introduced in the next subsection;
  • create a LightningDataModule.
  • Here, we are going to use approach 3, which is the most organized approach. The LightningDataModule consists of several self-explanatory methods as we can see below:
import os
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader


class DataModule(pl.LightningDataModule):
    def __init__(self, data_path='./'):
        super().__init__()
        self.data_path = data_path

    def prepare_data(self):
        data_df = pd.read_csv(
            'https://raw.githubusercontent.com/gagolews/'
            'ordinal_regression_data/master/cement_strength.csv')
        data_df.to_csv(
            os.path.join(self.data_path, 'cement_strength.csv'), index=None)
        return

    def setup(self, stage=None):
        data_df = pd.read_csv(
            os.path.join(self.data_path, 'cement_strength.csv'))
        data_df["response"] = data_df["response"]-1  # labels should start at 0
        self.data_labels = data_df["response"]
        self.data_features = data_df.loc[:, [
            "V1", "V2", "V3", "V4", "V5", "V6", "V7", "V8"]]

        # Split into
        # 70% train, 10% validation, 20% testing

        X_temp, X_test, y_temp, y_test = train_test_split(
            self.data_features.values,
            self.data_labels.values,
            test_size=0.2,
            random_state=1,
            stratify=self.data_labels.values)

        X_train, X_valid, y_train, y_valid = train_test_split(
            X_temp,
            y_temp,
            test_size=0.1,
            random_state=1,
            stratify=y_temp)

        # Standardize features
        sc = StandardScaler()
        X_train_std = sc.fit_transform(X_train)
        X_valid_std = sc.transform(X_valid)
        X_test_std = sc.transform(X_test)

        self.train = MyDataset(X_train_std, y_train)
        self.valid = MyDataset(X_valid_std, y_valid)
        self.test = MyDataset(X_test_std, y_test)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=BATCH_SIZE,
                          num_workers=NUM_WORKERS,
                          drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.valid, batch_size=BATCH_SIZE,
                          num_workers=NUM_WORKERS)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=BATCH_SIZE,
                          num_workers=NUM_WORKERS)
  • Note that the prepare_data method is usually used for steps that only need to be executed once, for example, downloading the dataset; the setup method defines the the dataset loading -- if you run your code in a distributed setting, this will be called on each node / GPU.
  • Next, lets initialize the DataModule; we use a random seed for reproducibility (so that the data set is shuffled the same way when we re-execute this code):
torch.manual_seed(1) 
data_module = DataModule(data_path=DATA_BASEPATH)

Training the model using the PyTorch Lightning Trainer class

  • Next, we initialize our multilayer perceptron model (here, a 2-layer MLP with 24 units in the first hidden layer, and 16 units in the second hidden layer).
  • We wrap the model in our LightningMLP so that we can use PyTorch Lightning's powerful Trainer API.
  • Also, we define a callback so that we can obtain the model with the best validation set performance after training.
  • Note PyTorch Lightning offers many advanced logging services like Weights & Biases. However, here, we will keep things simple and use the CSVLogger:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger


pytorch_model = MultiLayerPerceptron(
    input_size=data_features.shape[1],
    hidden_units=(24, 16),
    num_classes=np.bincount(data_labels).shape[0])

lightning_model = LightningMLP(
    model=pytorch_model,
    learning_rate=LEARNING_RATE)


callbacks = [ModelCheckpoint(
    save_top_k=1, mode="min", monitor="valid_mae")]  # save top 1 model 
logger = CSVLogger(save_dir="logs/", name="mlp-coral-cement")
  • Now it's time to train our model:
import time


trainer = pl.Trainer(
    max_epochs=NUM_EPOCHS,
    callbacks=callbacks,
    progress_bar_refresh_rate=50,  # recommended for notebooks
    accelerator="auto",  # Uses GPUs or TPUs if available
    devices="auto",  # Uses all available GPUs/TPUs if applicable
    logger=logger,
    deterministic=True,
    log_every_n_steps=10)

start_time = time.time()
trainer.fit(model=lightning_model, datamodule=data_module)

runtime = (time.time() - start_time)/60
print(f"Training took {runtime:.2f} min in total.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                 | Params
---------------------------------------------------
0 | model     | MultiLayerPerceptron | 636   
1 | train_mae | MeanAbsoluteError    | 0     
2 | valid_mae | MeanAbsoluteError    | 0     
3 | test_mae  | MeanAbsoluteError    | 0     
---------------------------------------------------
636       Trainable params
0         Non-trainable params
636       Total params
0.003     Total estimated model params size (MB)



Validation sanity check: 0it [00:00, ?it/s]


/home/jovyan/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/jovyan/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(



Training: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]


Training took 0.94 min in total.

Evaluating the model

  • After training, let's plot our training MAE and validation MAE using pandas, which, in turn, uses matplotlib for plotting (you may want to consider a more advanced logger that does that for you):
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")

aggreg_metrics = []
agg_col = "epoch"
for i, dfg in metrics.groupby(agg_col):
    agg = dict(dfg.mean())
    agg[agg_col] = i
    aggreg_metrics.append(agg)

df_metrics = pd.DataFrame(aggreg_metrics)
df_metrics[["train_loss", "valid_loss"]].plot(
    grid=True, legend=True, xlabel='Epoch', ylabel='Loss')
df_metrics[["train_mae", "valid_mae"]].plot(
    grid=True, legend=True, xlabel='Epoch', ylabel='MAE')
<AxesSubplot:xlabel='Epoch', ylabel='MAE'>

png

png

  • As we can see from the loss plot above, the model starts overfitting pretty quickly; however the validation set MAE keeps improving. Based on the MAE plot, we can see that the best model, based on the validation set MAE, may be around epoch 110.
  • The trainer saved this model automatically for us, we which we can load from the checkpoint via the ckpt_path='best' argument; below we use the trainer instance to evaluate the best model on the test set:
trainer.test(model=lightning_model, datamodule=data_module, ckpt_path='best')
Restoring states from the checkpoint path at logs/mlp-coral-cement/version_3/checkpoints/epoch=114-step=2529.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at logs/mlp-coral-cement/version_3/checkpoints/epoch=114-step=2529.ckpt
/home/jovyan/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(



Testing: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_mae': 0.25}
--------------------------------------------------------------------------------





[{'test_mae': 0.25}]
  • The MAE of our model is quite good, especially compared to the 1.03 MAE baseline earlier.

Predicting labels of new data

  • You can use the trainer.predict method on a new DataLoader or DataModule to apply the model to new data.
  • Alternatively, you can also manually load the best model from a checkpoint as shown below:
path = trainer.checkpoint_callback.best_model_path
print(path)
logs/mlp-coral-cement/version_3/checkpoints/epoch=114-step=2529.ckpt
lightning_model = LightningMLP.load_from_checkpoint(
    path, model=pytorch_model)
lightning_model.eval();
  • Note that our MultilayerPerceptron, which is passed to LightningMLP requires input arguments. However, this is automatically being taken care of since we used self.save_hyperparameters() in LightningMLP's __init__ method.
  • Now, below is an example applying the model manually. Here, pretend that the test_dataloader is a new data loader.
test_dataloader = data_module.test_dataloader()

all_predicted_labels = []
for batch in test_dataloader:
    features, _ = batch
    logits = lightning_model(features)
    probas = torch.sigmoid(logits)
    predicted_labels = proba_to_label(probas)
    all_predicted_labels.append(predicted_labels)

all_predicted_labels = torch.cat(all_predicted_labels)
all_predicted_labels[:5]
tensor([0, 4, 0, 3, 1])