Exemple d’un callback utilisateur

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Exemple de création d'un callback TensorBoard pour visualiser l'évolution de
la fonction de coût et la performance de prédiction. Cet exemple s'appuie sur
l'exemple du classifieur.

Created on Sat Feb  6 17:41:38 2021

@author: Cyrile Delestre
"""
from dataclasses import dataclass

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_moons
from sklearn.metrics import accuracy_score

from torch import float32
from torch.nn import Module, Softmax
from torch.nn.functional import cross_entropy
from torch.optim import Adam
from torch.utils.data import Dataset, random_split
from torch.utils.tensorboard import SummaryWriter

from dstk.pytorch.networks import MLP
from dstk.pytorch import (BaseClassifier, check_tensor, CallbackInterface)

class dataset(Dataset):
    """Map-style PyTorch dataset over a feature matrix and its targets.

    Each item is a dict ``{'data': ..., 'target': ...}`` where ``'data'`` is
    row ``idx`` of the feature matrix passed through ``check_tensor`` with
    dtype ``float32`` and ``'target'`` is ``target[idx]``.

    Parameters
    ----------
    data : array-like
        Feature matrix indexed as ``data[idx, :]``; its first axis is the
        sample axis (``data.shape[0]`` is the dataset length).
    target : array-like
        Targets, indexed as ``target[idx]``.
    """

    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __len__(self):
        # Read the instance's own array: the previous version used the global
        # ``data`` defined under ``__main__``, which fails (NameError) when the
        # class is imported and is wrong for any other array.
        return self.data.shape[0]

    def __getitem__(self, idx):
        # Explicit upper-bound check so iteration stops with IndexError.
        if idx > len(self) - 1:
            raise IndexError(idx)
        return {'data': check_tensor(self.data[idx, :], float32),
                'target': self.target[idx]}

class MonClassifieur(Module, BaseClassifier):
    """Binary classifier built on an ``MLP`` with a 2-way softmax output.

    Parameters
    ----------
    dim_in : int
        Number of input features.
    n_layers : int, default 2
        Number of layers forwarded to ``MLP``.
    """

    def __init__(self, dim_in, n_layers=2):
        super().__init__()
        # Hyper-parameters must be stored before ``build`` reads them.
        self.dim_in, self.n_layers = dim_in, n_layers
        self.build()

    def build(self):
        """Instantiate the network (``self.layer``) and its Adam optimizer."""
        # NOTE(review): the last layer applies a softmax while the example
        # trains with ``cross_entropy``, which expects raw logits — confirm
        # against BaseClassifier.predict_proba before changing it.
        mlp_config = dict(
            dim_in=self.dim_in,
            dim_out=2,
            dim_first_lay=32,
            n_layers=self.n_layers,
            activation_last_layer=Softmax(dim=1),
            batchnorm=False,
            dropout_last_layer=False,
        )
        self.layer = MLP(**mlp_config)
        self.optimizer = Adam(self.parameters(), lr=1e-3)

    def forward(self, data, **kwargs):
        """Return the MLP output for ``data``; extra keyword arguments are ignored."""
        return self.layer(data)

@dataclass
class TensorBoeadCallback(CallbackInterface):
    """Simple callback that logs evaluation metrics to TensorBoard.

    At the end of each evaluation phase it writes the evaluation loss and the
    accuracy of ``model.predict`` over the evaluation data. Both are indexed
    by a global step computed from the training state.

    Parameters
    ----------
    log_dir : str, default "./tensorboard"
        Directory handed to ``SummaryWriter``.
    """
    log_dir: str = "./tensorboard"

    def __post_init__(self):
        # The writer used to be an unannotated class attribute: it was created
        # when the module was imported and shared by every instance. Building
        # it here keeps import free of side effects and gives each callback
        # its own writer. Any parent ``__post_init__`` is still honoured.
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()
        self.writer = SummaryWriter(self.log_dir)

    def end_eval(self, model, data, state, control, res_loss):
        """Log the evaluation loss and accuracy to TensorBoard.

        Parameters
        ----------
        model :
            Model being trained; must expose ``predict(batch)``.
        data :
            Iterable of batches, each a dict with a ``'target'`` tensor
            (presumably the evaluation loader — confirm in BaseClassifier.fit).
        state :
            Training state providing ``n_epoch``, ``epoch_max_steps`` and
            ``n_iter``, used to build the global step.
        control :
            Unused here.
        res_loss :
            Evaluation loss value to log.
        """
        step = state.n_epoch * state.epoch_max_steps + state.n_iter
        self.writer.add_scalar('Loss/eval', res_loss, step)
        y_pred, y_true = [], []
        for batch in data:
            y_pred.append(model.predict(batch))
            y_true.append(batch['target'].numpy())
        # Ground truth and predictions were previously swapped; accuracy is
        # symmetric so the value is unchanged, but the arguments are now correct.
        acc = accuracy_score(
            y_true=np.concatenate(y_true, axis=0).reshape(-1),
            y_pred=np.concatenate(y_pred, axis=0).reshape(-1),
        )
        self.writer.add_scalar('Acc/eval', acc, step)
        # Push pending events to disk now rather than waiting for the
        # writer's periodic flush, so TensorBoard shows them promptly.
        self.writer.flush()

#%%
if __name__ == "__main__":
    # Generate a toy binary classification dataset (two interleaving moons).
    n_samples = 2000
    n_train = 1500
    data, target = make_moons(
        n_samples=n_samples,
        shuffle=True,
        noise=0.4,
        random_state=None
    )

    # Wrap the arrays in a PyTorch dataset and split it into train/eval sets.
    # The eval size is derived from n_samples so both stay in sync.
    data_ds = dataset(data, target)
    data_train, data_eval = random_split(
        data_ds, [n_train, n_samples - n_train]
    )

    # Build the model.
    model_mlp = MonClassifieur(dim_in=2, n_layers=10)

    # TensorBoeadCallback logs the evaluation loss and accuracy to
    # TensorBoard so training can be monitored there.
    model_mlp.fit(
        X=data_train,
        eval_set=data_eval,
        loss_fn=cross_entropy,
        nb_epoch=30,
        batch_size=10,
        field_target='target',
        callbacks=[TensorBoeadCallback()]
    )

    # Plot the predicted probability of class 1 over a grid, with the samples.
    xx, yy = np.meshgrid(np.arange(-2, 3, 0.01), np.arange(-1.5, 2, 0.01))
    c_torch = model_mlp.predict_proba(
        np.c_[xx.ravel(), yy.ravel()].astype(np.float32)
    )

    plt.figure()
    plt.contourf(xx, yy, c_torch[:, 1].reshape(xx.shape), 50, cmap="RdBu_r")
    plt.scatter(data[target == 0, 0], data[target == 0, 1], c='blue', s=8)
    plt.scatter(data[target == 1, 0], data[target == 1, 1], c='red', s=8)
    plt.axis('square')
    plt.grid()

    # Save the model weights.
    model_mlp.save_weights('classif.pt')

    # Reload the model from the saved file.
    # NOTE(review): weights are written with ``save_weights`` but read back
    # with ``load_model`` — confirm both use a compatible format in dstk.
    model = MonClassifieur.load_model('classif.pt')

    # Without this the figure is never displayed when run as a plain script.
    plt.show()
Callback TensorBoard