#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Exemple d'utilisation du module SWA (Stochastic Weight Averaging) pour un
classifier.
Created on Sun Jan 31 14:32:46 2021
@author: Cyrile Delestre
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from torch import float32
from torch.nn import Module, Softmax
from torch.nn.functional import cross_entropy
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split
from dstk.pytorch.networks import MLP
from dstk.pytorch import (BaseClassifier, check_tensor,
StochasticWeightAveraging, cut_dataset)
class dataset(Dataset):
def __init__(self, data, target):
self.data = data
self.target = target
def __len__(self):
return data.shape[0]
def __getitem__(self, idx):
if idx > self.__len__()-1:
raise IndexError()
return {'data': check_tensor(self.data[idx,:], float32),
'target': self.target[idx]}
class MonClassifieur(Module, BaseClassifier):
def __init__(self,
dim_in,
n_layers=3):
super().__init__()
self.dim_in = dim_in
self.n_layers = n_layers
self.build()
def build(self):
self.layer = MLP(
dim_in=self.dim_in,
dim_out=2,
dim_first_lay=32,
n_layers=self.n_layers,
activation_last_layer=Softmax(dim=1),
batchnorm=False,
dropout_last_layer=False
)
def forward(self, data, **kargs):
return self.layer(data)
#%%
if __name__=="__main__":
# Génération d'un dataset de classification binaire test
data, target = make_moons(
n_samples=2000,
shuffle=True,
noise=0.5,
random_state=None
)
# Création d'un dataset PyTorch
data_ds = dataset(data, target)
data_train, data_eval = random_split(data_ds, [1500, 500])
# Chargement du modèle
model_mlp = MonClassifieur(dim_in=2, n_layers=10)
# Sépare le set d'entrainement en deux avec une proportion d'intersection
# de 50% ce qui garanti que la moitié des observations sont nouvelles
# dans la phase d'entrainement du SWA.
data_train, data_swa = cut_dataset(
model=model_mlp,
data=data_train,
ratio=0,
stratified=True
)
train_dataset = DataLoader(data_train, batch_size=10, shuffle=True)
train_swa = DataLoader(data_swa, batch_size=10, shuffle=True)
eval_dataset = DataLoader(data_eval, batch_size=100)
# Utilisation du module SWA (Stochastic Weight Averaging) permettant
# de rendre le modèle plus robuste en moyennant les poids du réseau et en
# lui appliquant une perturbation sur le learning rate via la technique de
# recuisson et une perturbation sur les observations d'apprentissage.
swa_module = StochasticWeightAveraging(
nb_epoch=5,
lr=1e-5,
nb_anneal=3,
strategy='linear',
inner_ratio=0.7,
verbose=True
)
from dstk.pytorch import EarlyStoppingCallback, ProgressBarCallback
model_mlp.fit(
X=train_dataset,
eval_dataset=eval_dataset,
loss_fn=cross_entropy,
optimizer=Adam,
optimizer_kargs=dict(lr=1e-3),
nb_epoch=20,
target_field='target',
callbacks=[EarlyStoppingCallback(verbose=True), ProgressBarCallback()]
)
# Visualisation du résultat du modèle
xx, yy = np.meshgrid(np.arange(-2, 3, 0.01), np.arange(-1.5, 2, 0.01))
c_torch = model_mlp.predict_proba(
np.c_[xx.ravel(), yy.ravel()].astype(np.float32)
)
plt.figure()
plt.contourf(xx, yy, c_torch[:,1].reshape(xx.shape), 50, cmap="RdBu_r")
plt.scatter(data[target==0,0], data[target==0,1], c='blue', s=8)
plt.scatter(data[target==1,0], data[target==1,1], c='red', s=8)
plt.axis('square')
plt.grid()
# Sauvegarde du modèle
model_mlp.save_weights('classif.pt');
# Reload du modèle
model = MonClassifieur.load_model('classif.pt')