Utiliser un callback ReduceLROnPlateau
Pour détecter le surapprentissage (overfitting) pendant l’entraînement d’un modèle, vous pouvez créer un callback personnalisé qui surveille les métriques de performance sur les données d’entraînement et de validation. Si la performance sur les données de validation commence à se dégrader alors que la performance sur les données d’entraînement continue de s’améliorer, cela peut indiquer un surapprentissage.
Voici un exemple de callback personnalisé pour détecter le surapprentissage :
Étape 1 : Importer les bibliothèques nécessaires
Tout d’abord, assurez-vous d’importer les bibliothèques nécessaires.
import tensorflow as tffrom tensorflow.keras import layersfrom tensorflow.keras.callbacks import Callback
Étape 2 : Définir le callback personnalisé
Définissez une classe de callback personnalisée en héritant de tf.keras.callbacks.Callback
. Vous pouvez surcharger les méthodes on_epoch_end
pour surveiller les métriques de performance.
class OverfittingCallback(Callback): def __init__(self, patience=3, delta=0.01): super(OverfittingCallback, self).__init__() self.patience = patience self.delta = delta self.best_val_loss = None self.wait = 0
def on_epoch_end(self, epoch, logs=None): logs = logs or {} val_loss = logs.get('val_loss') train_loss = logs.get('loss')
if self.best_val_loss is None: self.best_val_loss = val_loss elif val_loss > self.best_val_loss + self.delta: self.wait += 1 if self.wait >= self.patience: print(f"Overfitting detected at epoch {epoch}. Stopping training.") self.model.stop_training = True else: self.wait = 0 self.best_val_loss = val_loss
print(f"End of epoch {epoch}, train_loss: {train_loss}, val_loss: {val_loss}, wait: {self.wait}")
Étape 3 : Définir le modèle
Définissez votre modèle. Pour cet exemple, nous utiliserons un modèle simple avec quelques couches de convolution et de dense.
class SimpleModel(tf.keras.Model): def __init__(self): super(SimpleModel, self).__init__() self.conv1 = layers.Conv2D(32, (3, 3), activation='relu') self.flatten = layers.Flatten() self.dense1 = layers.Dense(128, activation='relu') self.dense2 = layers.Dense(10)
def call(self, inputs): x = self.conv1(inputs) x = self.flatten(x) x = self.dense1(x) return self.dense2(x)
Étape 4 : Instancier et compiler le modèle
Instanciez et compilez votre modèle.
model = SimpleModel()
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
Étape 5 : Préparer les données
Préparez vos données d’entraînement et de test. Pour cet exemple, nous utiliserons le jeu de données MNIST.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0
Étape 6 : Entraîner le modèle avec le callback personnalisé
Entraînez votre modèle en utilisant le callback personnalisé pour détecter le surapprentissage.
overfitting_callback = OverfittingCallback(patience=3, delta=0.01)
model.fit(x_train, y_train, epochs=50, batch_size=32, validation_split=0.2, callbacks=[overfitting_callback])
Code complet
Voici le code complet pour définir un modèle, préparer les données, définir le callback personnalisé pour détecter le surapprentissage, et entraîner le modèle avec ce callback.
import tensorflow as tffrom tensorflow.keras import layersfrom tensorflow.keras.callbacks import Callback
# Définir le callback personnalisé pour détecter le surapprentissageclass OverfittingCallback(Callback): def __init__(self, patience=3, delta=0.01): super(OverfittingCallback, self).__init__() self.patience = patience self.delta = delta self.best_val_loss = None self.wait = 0
def on_epoch_end(self, epoch, logs=None): logs = logs or {} val_loss = logs.get('val_loss') train_loss = logs.get('loss')
if self.best_val_loss is None: self.best_val_loss = val_loss elif val_loss > self.best_val_loss + self.delta: self.wait += 1 if self.wait >= self.patience: print(f"Overfitting detected at epoch {epoch}. Stopping training.") self.model.stop_training = True else: self.wait = 0 self.best_val_loss = val_loss
print(f"End of epoch {epoch}, train_loss: {train_loss}, val_loss: {val_loss}, wait: {self.wait}")
# Définir le modèleclass SimpleModel(tf.keras.Model): def __init__(self): super(SimpleModel, self).__init__() self.conv1 = layers.Conv2D(32, (3, 3), activation='relu') self.flatten = layers.Flatten() self.dense1 = layers.Dense(128, activation='relu') self.dense2 = layers.Dense(10)
def call(self, inputs): x = self.conv1(inputs) x = self.flatten(x) x = self.dense1(x) return self.dense2(x)
# Instancier le modèlemodel = SimpleModel()
# Compiler le modèlemodel.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
# Préparer les données(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0
# Entraîner le modèle avec le callback personnaliséoverfitting_callback = OverfittingCallback(patience=3, delta=0.01)
model.fit(x_train, y_train, epochs=50, batch_size=32, validation_split=0.2, callbacks=[overfitting_callback])
Explication des paramètres du callback
patience
: Le nombre d’époques sans amélioration après lesquelles l’entraînement sera arrêté.delta
: La tolérance pour déterminer si la perte de validation a augmenté de manière significative.best_val_loss
: La meilleure perte de validation observée jusqu’à présent.wait
: Le nombre d’époques consécutives sans amélioration de la perte de validation.
En utilisant ce callback personnalisé, vous pouvez détecter le surapprentissage et arrêter l’entraînement prématurément pour éviter de surajuster le modèle aux données d’entraînement.