Utiliser un callback EarlyStopping
L’utilisation du callback EarlyStopping
est une technique courante pour éviter le surapprentissage (overfitting) en arrêtant l’entraînement du modèle lorsque la performance sur les données de validation ne s’améliore plus. Voici comment vous pouvez le faire en utilisant TensorFlow et Keras.
Étape 1 : Importer les bibliothèques nécessaires
Tout d’abord, assurez-vous d’importer les bibliothèques nécessaires.
import tensorflow as tffrom tensorflow.keras import layersfrom tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
Étape 2 : Définir le modèle
Définissez votre modèle. Pour cet exemple, nous utiliserons un modèle simple avec quelques couches de convolution et de dense.
class SimpleModel(tf.keras.Model): def __init__(self): super(SimpleModel, self).__init__() self.conv1 = layers.Conv2D(32, (3, 3), activation='relu') self.flatten = layers.Flatten() self.dense1 = layers.Dense(128, activation='relu') self.dense2 = layers.Dense(10)
def call(self, inputs): x = self.conv1(inputs) x = self.flatten(x) x = self.dense1(x) return self.dense2(x)
Étape 3 : Instancier et compiler le modèle
Instanciez et compilez votre modèle.
model = SimpleModel()
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
Étape 4 : Préparer les données
Préparez vos données d’entraînement et de test. Pour cet exemple, nous utiliserons le jeu de données MNIST.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0
Étape 5 : Définir les callbacks EarlyStopping et ModelCheckpoint
Définissez un objet EarlyStopping
pour arrêter l’entraînement lorsque la performance sur les données de validation ne s’améliore plus, et un objet ModelCheckpoint
pour sauvegarder le meilleur modèle.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)checkpoint = ModelCheckpoint(filepath='best_model.h5', save_best_only=True, monitor='val_loss', mode='min', verbose=1)
Étape 6 : Entraîner le modèle avec les callbacks
Entraînez votre modèle en utilisant les callbacks EarlyStopping
et ModelCheckpoint
.
model.fit(x_train, y_train, epochs=50, batch_size=32, validation_split=0.2, callbacks=[early_stopping, checkpoint])
Code complet
Voici le code complet pour définir un modèle, préparer les données, définir les callbacks EarlyStopping
et ModelCheckpoint
, et entraîner le modèle avec ces callbacks.
import tensorflow as tffrom tensorflow.keras import layersfrom tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
# Définir le modèleclass SimpleModel(tf.keras.Model): def __init__(self): super(SimpleModel, self).__init__() self.conv1 = layers.Conv2D(32, (3, 3), activation='relu') self.flatten = layers.Flatten() self.dense1 = layers.Dense(128, activation='relu') self.dense2 = layers.Dense(10)
def call(self, inputs): x = self.conv1(inputs) x = self.flatten(x) x = self.dense1(x) return self.dense2(x)
# Instancier le modèlemodel = SimpleModel()
# Compiler le modèlemodel.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
# Préparer les données(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0
# Définir les callbacks EarlyStopping et ModelCheckpointearly_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)checkpoint = ModelCheckpoint(filepath='best_model.h5', save_best_only=True, monitor='val_loss', mode='min', verbose=1)
# Entraîner le modèle avec les callbacksmodel.fit(x_train, y_train, epochs=50, batch_size=32, validation_split=0.2, callbacks=[early_stopping, checkpoint])
Explication des paramètres de EarlyStopping
monitor
: La métrique à surveiller pour déterminer l’arrêt prématuré (par exemple,val_loss
).patience
: Le nombre d’époques sans amélioration après lesquelles l’entraînement sera arrêté.restore_best_weights
: SiTrue
, les poids du modèle seront restaurés à ceux de l’époque où la meilleure valeur de la métrique surveillée a été observée.
Explication des paramètres de ModelCheckpoint
filepath
: Le chemin où le modèle sera sauvegardé.save_best_only
: SiTrue
, seul le meilleur modèle (en termes de la métrique surveillée) sera sauvegardé.monitor
: La métrique à surveiller pour déterminer le meilleur modèle.mode
: Le mode de surveillance (‘min’ pour minimiser la métrique, ‘max’ pour maximiser la métrique).verbose
: Le niveau de verbosité (0 pour silencieux, 1 pour afficher les messages).
En utilisant ces callbacks, vous pouvez éviter le surapprentissage et sauvegarder le meilleur modèle pendant l’entraînement.