Exemple d’un classifieur utilisant le processus SWA

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Exemple d'utilisation du module SWA (Stochastic Weight Averaging) pour un
classifier.

Created on Sun Jan 31 14:32:46 2021

@author: Cyrile Delestre
"""

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_moons

from torch import float32
from torch.nn import Module, Softmax
from torch.nn.functional import cross_entropy
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split

from dstk.pytorch.networks import MLP
from dstk.pytorch import (BaseClassifier, check_tensor,
                          StochasticWeightAveraging, cut_dataset)

class dataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        return data.shape[0]

    def __getitem__(self, idx):
        if idx > self.__len__()-1:
            raise IndexError()
        return {'data': check_tensor(self.data[idx,:], float32),
                'target': self.target[idx]}

class MonClassifieur(Module, BaseClassifier):
    def __init__(self,
                 dim_in,
                 n_layers=3):
        super().__init__()
        self.dim_in = dim_in
        self.n_layers = n_layers
        self.build()

    def build(self):
        self.layer = MLP(
            dim_in=self.dim_in,
            dim_out=2,
            dim_first_lay=32,
            n_layers=self.n_layers,
            activation_last_layer=Softmax(dim=1),
            batchnorm=False,
            dropout_last_layer=False
        )

    def forward(self, data, **kargs):
        return self.layer(data)

#%%
if __name__=="__main__":
    # Génération d'un dataset de classification binaire test
    data, target = make_moons(
        n_samples=2000,
        shuffle=True,
        noise=0.5,
        random_state=None
    )

    # Création d'un dataset PyTorch
    data_ds = dataset(data, target)
    data_train, data_eval = random_split(data_ds, [1500, 500])

    # Chargement du modèle
    model_mlp = MonClassifieur(dim_in=2, n_layers=10)

    # Sépare le set d'entrainement en deux avec une proportion d'intersection 
    # de 50% ce qui garanti que la moitié des observations sont nouvelles 
    # dans la phase d'entrainement du SWA.
    data_train, data_swa = cut_dataset(
        model=model_mlp,
        data=data_train,
        ratio=0,
        stratified=True
    )
    train_dataset = DataLoader(data_train, batch_size=10, shuffle=True)
    train_swa = DataLoader(data_swa, batch_size=10, shuffle=True)
    eval_dataset = DataLoader(data_eval, batch_size=100)

    # Utilisation du module SWA (Stochastic Weight Averaging) permettant 
    # de rendre le modèle plus robuste en moyennant les poids du réseau et en 
    # lui appliquant une perturbation sur le learning rate via la technique de 
    # recuisson et une perturbation sur les observations d'apprentissage.
    swa_module = StochasticWeightAveraging(
        nb_epoch=5,
        lr=1e-5,
        nb_anneal=3,
        strategy='linear',
        inner_ratio=0.7,
        verbose=True
    )

    from dstk.pytorch import EarlyStoppingCallback, ProgressBarCallback
    model_mlp.fit(
        X=train_dataset,
        eval_dataset=eval_dataset,
        loss_fn=cross_entropy,
        optimizer=Adam,
        optimizer_kargs=dict(lr=1e-3),
        nb_epoch=20,
        target_field='target',
        callbacks=[EarlyStoppingCallback(verbose=True), ProgressBarCallback()]
    )

    # Visualisation du résultat du modèle
    xx, yy = np.meshgrid(np.arange(-2, 3, 0.01), np.arange(-1.5, 2, 0.01))
    c_torch = model_mlp.predict_proba(
        np.c_[xx.ravel(), yy.ravel()].astype(np.float32)
    )

    plt.figure()
    plt.contourf(xx, yy, c_torch[:,1].reshape(xx.shape), 50, cmap="RdBu_r")
    plt.scatter(data[target==0,0], data[target==0,1], c='blue', s=8)
    plt.scatter(data[target==1,0], data[target==1,1], c='red', s=8)
    plt.axis('square')
    plt.grid()

    # Sauvegarde du modèle
    model_mlp.save_weights('classif.pt');

    # Reload du modèle
    model = MonClassifieur.load_model('classif.pt')