Guide to Distillation


In this guide, you will learn how to distill a large model into a smaller model using Masterful. For a more conceptual discussion, see the concepts documentation.

This guide closely follows the Keras Knowledge Distillation guide, which can be found here; its main goal is to show you how to replicate that work using Masterful.
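For reference, the Keras guide implements distillation by blending a hard-label cross-entropy term with a KL-divergence term between temperature-softened teacher and student predictions. Masterful manages its own distillation loss for you, but the following minimal sketch shows the idea; the temperature and alpha values are illustrative only.

[ ]:
# For reference only: a minimal sketch of Hinton-style knowledge
# distillation as used in the Keras guide. Masterful handles this
# internally; temperature and alpha here are illustrative values.
import tensorflow as tf

def distillation_loss(teacher_logits, student_logits, labels,
                      temperature=3.0, alpha=0.1):
    # Hard-label term: standard cross-entropy against the ground truth.
    hard_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)(
        labels, student_logits)

    # Soft-label term: KL divergence between temperature-softened
    # teacher and student distributions.
    soft_teacher = tf.nn.softmax(teacher_logits / temperature)
    soft_student = tf.nn.softmax(student_logits / temperature)
    soft_loss = tf.keras.losses.KLDivergence()(soft_teacher, soft_student)

    # Blend the two terms; the temperature**2 factor keeps the soft-label
    # gradients on a comparable scale.
    return alpha * hard_loss + (1.0 - alpha) * soft_loss * temperature**2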

Prerequisites

Please follow the Masterful installation instructions here in order to run this guide.

Imports

Import tensorflow and masterful, and register the Masterful package.

[ ]:
import tensorflow as tf

import masterful
masterful = masterful.register()

You are going to use the MNIST dataset for this guide. You should limit yourself to very simple preprocessing, as required by the model you are distilling into.

[3]:
NUM_CLASSES = 10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize data into the range [0, 1].
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Masterful needs an explicit channels dimension, so for single-channel
# data like MNIST, add it explicitly.
x_train = tf.reshape(x_train, (-1, 28, 28, 1))
x_test = tf.reshape(x_test, (-1, 28, 28, 1))

# Masterful performs best with one-hot labels.
y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

# Convert to Tensorflow Datasets for fast pipeline processing.
training_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
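
If you want to verify the pipeline before continuing, you can inspect the resulting datasets. Each element should be a (28, 28, 1) float image paired with a 10-class one-hot label.

[ ]:
# Sanity-check the pipeline: images are (28, 28, 1) float32 in [0, 1],
# labels are one-hot vectors of length NUM_CLASSES.
print(training_dataset.element_spec)
print(training_dataset.cardinality().numpy())  # 60000 training examples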

This guide follows the same experimental setup as the Keras guide, so set up the teacher and student models, which can also be called the source and target models. The teacher is a simple convolutional neural network, sized for the MNIST data.

[4]:
teacher_model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2), strides=(1, 1), padding="same"),
        tf.keras.layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(NUM_CLASSES),
    ],
    name="teacher",
)
teacher_model.summary()
Model: "teacher"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 14, 14, 256)       2560
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 14, 14, 256)       0
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 256)       0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 7, 7, 512)         1180160
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0
_________________________________________________________________
dense (Dense)                (None, 10)                250890
=================================================================
Total params: 1,433,610
Trainable params: 1,433,610
Non-trainable params: 0
_________________________________________________________________

The student model is an even simpler convolutional neural network, containing fewer parameters than the teacher network.

[5]:
student_model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2), strides=(1, 1), padding="same"),
        tf.keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(NUM_CLASSES),
    ],
    name="student",
)
student_model.summary()
Model: "student"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_2 (Conv2D)            (None, 14, 14, 16)        160
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 16)        0
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 16)        0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 32)          4640
_________________________________________________________________
flatten_1 (Flatten)          (None, 1568)              0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                15690
=================================================================
Total params: 20,490
Trainable params: 20,490
Non-trainable params: 0
_________________________________________________________________

Train the Teacher

Typically, you would start from an already trained teacher model. In this guide, you need to train the teacher explicitly before you can perform distillation. The teacher should achieve 97-98% accuracy in just five epochs.

[6]:
batch_size = 64
teacher_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)
teacher_model.fit(training_dataset.batch(batch_size), epochs=5)
teacher_evaluation_metrics = teacher_model.evaluate(
    test_dataset.batch(batch_size), return_dict=True)
print(f'Teacher evaluation metrics: {teacher_evaluation_metrics}')
Epoch 1/5
938/938 [==============================] - 29s 29ms/step - loss: 0.1576 - categorical_accuracy: 0.9527
Epoch 2/5
938/938 [==============================] - 27s 29ms/step - loss: 0.0812 - categorical_accuracy: 0.9755
Epoch 3/5
938/938 [==============================] - 27s 29ms/step - loss: 0.0677 - categorical_accuracy: 0.9795
Epoch 4/5
938/938 [==============================] - 27s 29ms/step - loss: 0.0606 - categorical_accuracy: 0.9818
Epoch 5/5
938/938 [==============================] - 27s 29ms/step - loss: 0.0572 - categorical_accuracy: 0.9827
157/157 [==============================] - 2s 10ms/step - loss: 0.1036 - categorical_accuracy: 0.9755
Teacher evaluation metrics: {'loss': 0.10363840311765671, 'categorical_accuracy': 0.9754999876022339}
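
In a real workflow, the teacher would usually already exist as a saved artifact rather than being trained inline. As a minimal sketch (the file name is just an example, and the best save format depends on your TensorFlow version), you could persist and later reload the trained teacher with the standard Keras APIs.

[ ]:
# Save the trained teacher so it can be reused for distillation later.
# "teacher_model.h5" is an example path; newer TensorFlow versions also
# support the ".keras" format.
teacher_model.save("teacher_model.h5")

# Later, load it back instead of retraining.
teacher_model = tf.keras.models.load_model("teacher_model.h5")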

Distill to the Student

Now that you have a teacher model, you can distill that knowledge into the student model. The first step is to set up the model and data parameters that you will pass to Masterful. This lets Masterful know a little bit more about the model, data, and the task you are trying to perform.

[7]:
training_dataset_params = masterful.data.learn_data_params(
  dataset=training_dataset,
  task=masterful.enums.Task.CLASSIFICATION,
  image_range=masterful.enums.ImageRange.ZERO_ONE,
  num_classes=NUM_CLASSES,
  sparse_labels=False,
)

teacher_model_params = masterful.architecture.learn_architecture_params(
  model=teacher_model,
  task=masterful.enums.Task.CLASSIFICATION,
  input_range=masterful.enums.ImageRange.ZERO_ONE,
  num_classes=NUM_CLASSES,
  prediction_logits=True,
)

student_model_params = masterful.architecture.learn_architecture_params(
  model=student_model,
  task=masterful.enums.Task.CLASSIFICATION,
  input_range=masterful.enums.ImageRange.ZERO_ONE,
  num_classes=NUM_CLASSES,
  prediction_logits=True,
)

The final step is to call into Masterful to initiate the distillation process. By default, Masterful learns an optimal training schedule for distillation based on the latest research, covering warmup, the learning rate schedule, the optimizer, and other optimization parameters. However, since this is a demonstration of the technique, you can save time by directly providing the optimization parameters to use during training.

[8]:
optimization_params = masterful.optimization.OptimizationParams(
  batch_size=64,
  epochs=20,
  metrics=[tf.keras.metrics.CategoricalAccuracy()],
  optimizer=tf.keras.optimizers.SGD(
    learning_rate=tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=1e-2,
      decay_steps=20 * len(x_train),
      end_learning_rate=1e-4
    ),
    momentum=0.9),
  loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True),
)

training_report = masterful.ssl.distill(
  teacher_model,
  teacher_model_params,
  student_model,
  student_model_params,
  optimization_params,
  training_dataset,
  training_dataset_params)
Epoch 1/20
844/844 [==============================] - 45s 46ms/step - loss: 0.6626 - val_categorical_accuracy: 0.8882 - val_distillation_loss: 0.3587 - val_student_loss: 0.2855
Epoch 2/20
844/844 [==============================] - 43s 45ms/step - loss: 0.3548 - val_categorical_accuracy: 0.9507 - val_distillation_loss: 0.1526 - val_student_loss: 0.2989
Epoch 3/20
844/844 [==============================] - 44s 45ms/step - loss: 0.1878 - val_categorical_accuracy: 0.9707 - val_distillation_loss: 0.1037 - val_student_loss: 0.2367
Epoch 4/20
844/844 [==============================] - 43s 45ms/step - loss: 0.1328 - val_categorical_accuracy: 0.9765 - val_distillation_loss: 0.0828 - val_student_loss: 0.2153
Epoch 5/20
844/844 [==============================] - 45s 46ms/step - loss: 0.1109 - val_categorical_accuracy: 0.9770 - val_distillation_loss: 0.0692 - val_student_loss: 0.2035
Epoch 6/20
844/844 [==============================] - 44s 45ms/step - loss: 0.0996 - val_categorical_accuracy: 0.9782 - val_distillation_loss: 0.0587 - val_student_loss: 0.1756
Epoch 7/20
844/844 [==============================] - 45s 46ms/step - loss: 0.0908 - val_categorical_accuracy: 0.9782 - val_distillation_loss: 0.0581 - val_student_loss: 0.1897
Epoch 8/20
844/844 [==============================] - 44s 46ms/step - loss: 0.0844 - val_categorical_accuracy: 0.9795 - val_distillation_loss: 0.0575 - val_student_loss: 0.1936
Epoch 9/20
844/844 [==============================] - 45s 47ms/step - loss: 0.0804 - val_categorical_accuracy: 0.9790 - val_distillation_loss: 0.0532 - val_student_loss: 0.1788
Epoch 10/20
844/844 [==============================] - 45s 47ms/step - loss: 0.0761 - val_categorical_accuracy: 0.9802 - val_distillation_loss: 0.0537 - val_student_loss: 0.1793
Epoch 11/20
844/844 [==============================] - 44s 46ms/step - loss: 0.0727 - val_categorical_accuracy: 0.9793 - val_distillation_loss: 0.0480 - val_student_loss: 0.1696
Epoch 12/20
844/844 [==============================] - 44s 46ms/step - loss: 0.0712 - val_categorical_accuracy: 0.9795 - val_distillation_loss: 0.0489 - val_student_loss: 0.1709
Epoch 13/20
844/844 [==============================] - 45s 47ms/step - loss: 0.0660 - val_categorical_accuracy: 0.9813 - val_distillation_loss: 0.0436 - val_student_loss: 0.1619
Epoch 14/20
844/844 [==============================] - 44s 46ms/step - loss: 0.0647 - val_categorical_accuracy: 0.9817 - val_distillation_loss: 0.0445 - val_student_loss: 0.1726
Epoch 15/20
844/844 [==============================] - 45s 46ms/step - loss: 0.0628 - val_categorical_accuracy: 0.9820 - val_distillation_loss: 0.0396 - val_student_loss: 0.1555
Epoch 16/20
844/844 [==============================] - 45s 47ms/step - loss: 0.0602 - val_categorical_accuracy: 0.9810 - val_distillation_loss: 0.0366 - val_student_loss: 0.1409
Epoch 17/20
844/844 [==============================] - 45s 46ms/step - loss: 0.0584 - val_categorical_accuracy: 0.9810 - val_distillation_loss: 0.0406 - val_student_loss: 0.1701
Epoch 18/20
844/844 [==============================] - 44s 46ms/step - loss: 0.0576 - val_categorical_accuracy: 0.9803 - val_distillation_loss: 0.0424 - val_student_loss: 0.1651
Epoch 19/20
844/844 [==============================] - 45s 46ms/step - loss: 0.0552 - val_categorical_accuracy: 0.9803 - val_distillation_loss: 0.0359 - val_student_loss: 0.1400
Epoch 20/20
844/844 [==============================] - 48s 50ms/step - loss: 0.0542 - val_categorical_accuracy: 0.9802 - val_distillation_loss: 0.0329 - val_student_loss: 0.1365
DistillerRestoreStudentWeightsToBest: Restoring model weights from epoch 20 with val_loss 0.03294413164258003.
94/94 [==============================] - 5s 53ms/step - categorical_accuracy: 0.9802 - distillation_loss: 0.0214 - student_loss: 0.0915

Measure Results

Let’s see how well you did. The DistillationReport returned by Masterful contains information about the distillation process. Your student model is achieving nearly the same accuracy as the teacher model using roughly 20,000 parameters instead of 1.4 million. You can also evaluate the student model directly on your holdout set.

[9]:
student_evaluation_metrics = student_model.evaluate(
    test_dataset.batch(batch_size), return_dict=True)
print(f'Teacher Evaluation metrics: {teacher_evaluation_metrics}')
print(f'Student Evaluation metrics: {student_evaluation_metrics}')
157/157 [==============================] - 1s 3ms/step - loss: 0.1102 - categorical_accuracy: 0.9746
Teacher Evaluation metrics: {'loss': 0.10363840311765671, 'categorical_accuracy': 0.9754999876022339}
Student Evaluation metrics: {'loss': 0.11019360274076462, 'categorical_accuracy': 0.9745625257492065}
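
As a final check, you can compare the parameter counts directly and, if the accuracy meets your needs, export the distilled student for deployment. A minimal sketch follows; the export path is illustrative.

[ ]:
# Compare model sizes: the student is roughly 70x smaller than the teacher.
print(f'Teacher parameters: {teacher_model.count_params():,}')
print(f'Student parameters: {student_model.count_params():,}')

# Export the distilled student for serving or further fine-tuning.
student_model.save('distilled_student.h5')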