pytorch c’est quoi ?¶
- Le module dstk.pytorch est un module facilitant l’utilisation de PyTorch tout en garantissant une compatilibité avec Scikit-Learn tout en préservant la flexibilité de développement d’un module PyTorch from scratch (choix du device, implémentation de ses propres couches, etc.). Le module est constitué de plusieurs parties :
- La partie principale est construite autour des classes wrappées à Scikit-Learn :
BaseClassifieretBaseRegressorqui sont des modules d’apprentissage batch ;BaseClassifierOnlineetBaseRegressorOnlinequi sont des modules d’apprentissage online.
D’une méthode de random search
RandomizedSearchOnlinepermettant de chercher les hyper-paramètres du modèle par approche de Monte-Carlo pour les modèles online ;D’un réseau générique MLP
MLPpermettant de modéliser un réseau feed forward à l’aide de peu d’hyper-paramètres ;D’un module permettant d’appliquer une procedure de Stochastic Weight Averaging (SWA) permettant de robustifier la phase d’inférence d’un modèle à l’aide d’une procédure de moyennage des poids durant l’apprentissage via la classe
StochasticWeightAveraging;D’un module de
supervisionpermettant une utilisation simplifiée de TensorBoard ;etc.
- La phase d’apprentissage via la méthode
fit()permet une grande fléxibilité des actions possibles durant cette phase importante grâce au principe des callbacks offrant une interface flexibleCallbackInterface. Il est facile de créer son propre callback (voir l’exemple de création d’un callback utilisateur pour utiliser TensorBoard). Il existe déjà tout une liste de callback déjà disponible : PrintCallback;EarlyStoppingCallback;LRSchedulerCallback;ProgressBarCallback.
Squelette d’un module PyTorch¶
Voici le squelette d’un module PyTorch pour être compatible avec les classes BaseClassifier, BaseRegressor, BaseClassifierOnline et BaseRegressorOnline :
from torch.nn import Module
from dstk.pytorch import BaseClassifier
class MonClassifieur(Module, BaseClassifier):
def __init__(self, karg1, karg2, ...):
super().__init__()
self.karg1 = karg1
self.karg2 = karg2
################################
# Mettre en attribut tous les #
# arguments d'initialisation #
################################
self.build()
def build(self):
###########################################
# Instanciation des layers du réseau #
# avec les paramètres de l'initialisation #
###########################################
self.optimizer = Optimizer(self.parameters(), ...)
def forward(self, karg_in_1, karg_in_2, ..., **kargs):
######################
# Exécuter le réseau #
######################
return resultat
model = MonClassifieur()
- __init__
La création de réseau profond se fait par l’héritage de Module (torch.nn.Module) et pour cet exemple nous avons choisi d’utiliser
BaseClassifier. Il faut savoir que Scikit-Learn impose une certaine rigueur dans la création et la gestion des classes de machine learning, de transformation, etc. Par exemple, il faut que dans l’initialisation tous les arguments en entrée de l’instanciation de la classe deviennent des attributs possédant le même nom. Ceci est dû au fait que Scikit-Learn (pour des raisons de performance) ne réinstancie jamais une classe pour une copie d’un objet. Scikit-Learn utilise set_params et get_params issus de la classe sklearn.base.BaseEstimator. C’est une notion très importante, car pour un réseau de neurones, en cas de changement de paramètre, il nous faut réinstancier les layers qui composent le réseau. C’est ce que fait la méthode build. Donc dans l’initialisation et dans l’ordre, il faut :initier le réseau super().__init__ pour le module torch.nn.Module,
créer les attributs correspondants aux arguments de la classe MonClassifier,
appeler la méthode self.build() permettant d’instancier le réseau.
- build
Méthode de construction du réseau. Elle peut utiliser les attributs déclarés en entrée de l’instance pour paramétrer le réseau. Cette méthode doit porter également l’optimiseur via l’attribut optimizer. Le nom de l’attribut est important car s’il est mal renseigné l’entrainement via la méthode
fit()génèrera une erreur. Le choix de mettre l’optimiseur ici est motivé par 2 raisons :dans le cadre d’une recherche d’hyper-paramètres les options de l’optimiseur (comme le lerning rate, par exemple) vont être très souvent déterminantes ;
et dans un contexte online il est important que le module possède en interne l’état de l’optimiseur simplifant son utilisation.
- forward
Méthode imposée par PyTorch. Il est important de mettre **kargs à la fin du prototypage de la méthode car il permet de faire passer des arguments inutiles pour l’application du réseau et indispensables pour le processus de la méthode
fit()(comme la target, par exemple).
Multi-GPU et multi-CPU¶
Depuis la version 3.8.0 de DSTK il est possible d’entrainer les modélisations en multi-GPU/CPU sur des noeuds de machines via l’API Distributed Data Parallel de PyTorch. La procédure est simplifié au maximum afin que ce soit le plus transparant possibilité pour l’utilisateur. En reprennant l’exemple précédent, voici un exemple type de la mise en parallèle d’un apprentissage :
from dstk.pytorch import auto_init_distributed
model, train_dataloader, eval_dataloader = auto_init_distributed(
model=model,
train_dataset=train_dataset,
kwargs_train_dataloader=dict(
shuffle=True,
batch_size=128,
num_workers=4,
pin_memory=True
),
eval_dataset=eval_dataset,
kwargs_eval_dataloader=dict(
shuffle=False,
batch_size=256,
num_workers=4,
pin_memory=True
)
)
Et utiliser la méthode fit() de manière transparente. Cependant la procédure de lancement du scipt d’apprentissage ce fait de manière différente. Il faut que le code soit lancé via TorchElastic, il faudra donc lancer le script via un terminal :
torchrun (--args of torchrun) train_script.py (--args of train_script)
Exemple d'utilisation de la classe BaseClassifier
Exemple d'utilisation de la classe BaseRegressorOnline
Exemple d'utilisation du module Stochastic Weight Averaging (SWA)
Exemple d'une implémentation d'un callback utilisateur pour TensorBoard
Exemple d'utilisation de la classe EmbeddingProjector
Warning
PyTorch ne fait pas partie des dépendances du package DSTK. En effet, PyTorch est une librairie relativement lourde et inutile si on ne fait pas de réseaux de neurones complexes, donc pour ne pas générer un projet trop conséquent il est laissé à la charge du data scientist d’installer PyTorch ou non, laissant le reste du package DSTK utilisable. Si un chargement dans le projet du module dstk.pytorch est réalisé alors que PyTorch n’est pas installé, une erreur ModuleNotFoundError est générée invitant l’utilisateur à installer PyTorch. Une fois la procédure d’installation réalisée le module dstk.pytorch fonctionnera convenablement.