{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Guide to Distillation\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)][1]        [![Download](images/download.png)][2][Download this Notebook][2]\n", "\n", "[1]:https://colab.research.google.com/github/masterfulai/masterful-docs/blob/main/notebooks/guide_distillation.ipynb\n", "[2]:https://docs.masterfulai.com/0.5/notebooks/guide_distillation.ipynb\n", "\n", "In this guide, you will learn how to distill a large model into a smaller model using Masterful. For a more conceptual discussion, see the concepts documents.\n", "\n", "This guide closely follows the [Keras Knowledge Distillation guide](https://keras.io/examples/vision/knowledge_distillation/). Its main goal is to show you how to replicate that work using Masterful." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prerequisites\n", "\n", "Please follow the Masterful installation instructions [here](../markdown/tutorial_installation.md) before running this guide." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports\n", "\n", "Import TensorFlow and Masterful, and register the Masterful package." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "import masterful\n", "masterful = masterful.register()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You are going to use the MNIST dataset for this guide. Keep the preprocessing very simple, limited to what the model you are distilling into requires." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "NUM_CLASSES = 10\n", "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", "\n", "# Normalize pixel values into the range [0, 1].\n", "x_train = x_train.astype(\"float32\") / 255.0\n", "x_test = x_test.astype(\"float32\") / 255.0\n", "\n", "# Masterful needs an explicit channels dimension, so for single-channel\n", "# data like MNIST we add the channels dimension explicitly.\n", "x_train = tf.reshape(x_train, (-1, 28, 28, 1))\n", "x_test = tf.reshape(x_test, (-1, 28, 28, 1))\n", "\n", "# Masterful performs best with one-hot labels.\n", "y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)\n", "y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)\n", "\n", "# Convert to TensorFlow Datasets for fast pipeline processing.\n", "training_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", "test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This guide follows the same experimental setup as the Keras guide, so set up the teacher and student models next. These are also called the source and target models. The teacher is a simple convolutional neural network, sized for the MNIST data. 
" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"teacher\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "conv2d (Conv2D) (None, 14, 14, 256) 2560 \n", "_________________________________________________________________\n", "leaky_re_lu (LeakyReLU) (None, 14, 14, 256) 0 \n", "_________________________________________________________________\n", "max_pooling2d (MaxPooling2D) (None, 14, 14, 256) 0 \n", "_________________________________________________________________\n", "conv2d_1 (Conv2D) (None, 7, 7, 512) 1180160 \n", "_________________________________________________________________\n", "flatten (Flatten) (None, 25088) 0 \n", "_________________________________________________________________\n", "dense (Dense) (None, 10) 250890 \n", "=================================================================\n", "Total params: 1,433,610\n", "Trainable params: 1,433,610\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "teacher_model = tf.keras.Sequential(\n", " [\n", " tf.keras.Input(shape=(28, 28, 1)),\n", " tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding=\"same\"),\n", " tf.keras.layers.LeakyReLU(alpha=0.2),\n", " tf.keras.layers.MaxPooling2D(\n", " pool_size=(2, 2), strides=(1, 1), padding=\"same\"),\n", " tf.keras.layers.Conv2D(512, (3, 3), strides=(2, 2), padding=\"same\"),\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dense(NUM_CLASSES),\n", " ],\n", " name=\"teacher\",\n", ")\n", "teacher_model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The student model is an even simpler convolutional neural network, containing fewer parameters than the teacher network." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"student\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "conv2d_2 (Conv2D) (None, 14, 14, 16) 160 \n", "_________________________________________________________________\n", "leaky_re_lu_1 (LeakyReLU) (None, 14, 14, 16) 0 \n", "_________________________________________________________________\n", "max_pooling2d_1 (MaxPooling2 (None, 14, 14, 16) 0 \n", "_________________________________________________________________\n", "conv2d_3 (Conv2D) (None, 7, 7, 32) 4640 \n", "_________________________________________________________________\n", "flatten_1 (Flatten) (None, 1568) 0 \n", "_________________________________________________________________\n", "dense_1 (Dense) (None, 10) 15690 \n", "=================================================================\n", "Total params: 20,490\n", "Trainable params: 20,490\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "student_model = tf.keras.Sequential(\n", " [\n", " tf.keras.Input(shape=(28, 28, 1)),\n", " tf.keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding=\"same\"),\n", " tf.keras.layers.LeakyReLU(alpha=0.2),\n", " tf.keras.layers.MaxPooling2D(\n", " pool_size=(2, 2), strides=(1, 1), padding=\"same\"),\n", " tf.keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding=\"same\"),\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dense(NUM_CLASSES),\n", " ],\n", " name=\"student\",\n", ")\n", "student_model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the Teacher \n", "Typically, you would use an already trained teacher model. In this guide, you need to explicitly \n", "train the teacher first before you can perform distillation. 
The teacher should achieve 97-98% accuracy in just five epochs. " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "938/938 [==============================] - 29s 29ms/step - loss: 0.1576 - categorical_accuracy: 0.9527\n", "Epoch 2/5\n", "938/938 [==============================] - 27s 29ms/step - loss: 0.0812 - categorical_accuracy: 0.9755\n", "Epoch 3/5\n", "938/938 [==============================] - 27s 29ms/step - loss: 0.0677 - categorical_accuracy: 0.9795\n", "Epoch 4/5\n", "938/938 [==============================] - 27s 29ms/step - loss: 0.0606 - categorical_accuracy: 0.9818\n", "Epoch 5/5\n", "938/938 [==============================] - 27s 29ms/step - loss: 0.0572 - categorical_accuracy: 0.9827\n", "157/157 [==============================] - 2s 10ms/step - loss: 0.1036 - categorical_accuracy: 0.9755\n", "Teacher evaluation metrics: {'loss': 0.10363840311765671, 'categorical_accuracy': 0.9754999876022339}\n" ] } ], "source": [ "batch_size = 64\n", "teacher_model.compile(\n", " optimizer=tf.keras.optimizers.Adam(),\n", " loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n", " metrics=[tf.keras.metrics.CategoricalAccuracy()],\n", ")\n", "teacher_model.fit(training_dataset.batch(batch_size), epochs=5)\n", "teacher_evaluation_metrics = teacher_model.evaluate(\n", " test_dataset.batch(batch_size), return_dict=True)\n", "print(f'Teacher evaluation metrics: {teacher_evaluation_metrics}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Distill to the student\n", "Now that you have a teacher model, you can distill that knowledge into the student model. The first step is to set up the model and data parameters that you will pass to Masterful. This lets Masterful know a little bit more about the model, data, and the task you are trying to perform." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "training_dataset_params = masterful.data.learn_data_params(\n", "    dataset=training_dataset,\n", "    task=masterful.enums.Task.CLASSIFICATION,\n", "    image_range=masterful.enums.ImageRange.ZERO_ONE,\n", "    num_classes=NUM_CLASSES,\n", "    sparse_labels=False,\n", ")\n", "\n", "teacher_model_params = masterful.architecture.learn_architecture_params(\n", "    model=teacher_model,\n", "    task=masterful.enums.Task.CLASSIFICATION,\n", "    input_range=masterful.enums.ImageRange.ZERO_ONE,\n", "    num_classes=NUM_CLASSES,\n", "    prediction_logits=True,\n", ")\n", "\n", "student_model_params = masterful.architecture.learn_architecture_params(\n", "    model=student_model,\n", "    task=masterful.enums.Task.CLASSIFICATION,\n", "    input_range=masterful.enums.ImageRange.ZERO_ONE,\n", "    num_classes=NUM_CLASSES,\n", "    prediction_logits=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The final step is to call into Masterful to initiate the distillation process. By default, Masterful learns an optimal training schedule for distillation, based on the latest research on warmup, learning rate schedules, optimizers, and other optimization parameters. Because this guide is only a demonstration of the technique, you can save time by providing the optimization parameters directly." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/20\n", "844/844 [==============================] - 45s 46ms/step - loss: 0.6626 - val_categorical_accuracy: 0.8882 - val_distillation_loss: 0.3587 - val_student_loss: 0.2855\n", "Epoch 2/20\n", "844/844 [==============================] - 43s 45ms/step - loss: 0.3548 - val_categorical_accuracy: 0.9507 - val_distillation_loss: 0.1526 - val_student_loss: 0.2989\n", "Epoch 3/20\n", "844/844 [==============================] - 44s 45ms/step - loss: 0.1878 - val_categorical_accuracy: 0.9707 - val_distillation_loss: 0.1037 - val_student_loss: 0.2367\n", "Epoch 4/20\n", "844/844 [==============================] - 43s 45ms/step - loss: 0.1328 - val_categorical_accuracy: 0.9765 - val_distillation_loss: 0.0828 - val_student_loss: 0.2153\n", "Epoch 5/20\n", "844/844 [==============================] - 45s 46ms/step - loss: 0.1109 - val_categorical_accuracy: 0.9770 - val_distillation_loss: 0.0692 - val_student_loss: 0.2035\n", "Epoch 6/20\n", "844/844 [==============================] - 44s 45ms/step - loss: 0.0996 - val_categorical_accuracy: 0.9782 - val_distillation_loss: 0.0587 - val_student_loss: 0.1756\n", "Epoch 7/20\n", "844/844 [==============================] - 45s 46ms/step - loss: 0.0908 - val_categorical_accuracy: 0.9782 - val_distillation_loss: 0.0581 - val_student_loss: 0.1897\n", "Epoch 8/20\n", "844/844 [==============================] - 44s 46ms/step - loss: 0.0844 - val_categorical_accuracy: 0.9795 - val_distillation_loss: 0.0575 - val_student_loss: 0.1936\n", "Epoch 9/20\n", "844/844 [==============================] - 45s 47ms/step - loss: 0.0804 - val_categorical_accuracy: 0.9790 - val_distillation_loss: 0.0532 - val_student_loss: 0.1788\n", "Epoch 10/20\n", "844/844 [==============================] - 45s 47ms/step - loss: 0.0761 - val_categorical_accuracy: 0.9802 - val_distillation_loss: 0.0537 - val_student_loss: 
0.1793\n", "Epoch 11/20\n", "844/844 [==============================] - 44s 46ms/step - loss: 0.0727 - val_categorical_accuracy: 0.9793 - val_distillation_loss: 0.0480 - val_student_loss: 0.1696\n", "Epoch 12/20\n", "844/844 [==============================] - 44s 46ms/step - loss: 0.0712 - val_categorical_accuracy: 0.9795 - val_distillation_loss: 0.0489 - val_student_loss: 0.1709\n", "Epoch 13/20\n", "844/844 [==============================] - 45s 47ms/step - loss: 0.0660 - val_categorical_accuracy: 0.9813 - val_distillation_loss: 0.0436 - val_student_loss: 0.1619\n", "Epoch 14/20\n", "844/844 [==============================] - 44s 46ms/step - loss: 0.0647 - val_categorical_accuracy: 0.9817 - val_distillation_loss: 0.0445 - val_student_loss: 0.1726\n", "Epoch 15/20\n", "844/844 [==============================] - 45s 46ms/step - loss: 0.0628 - val_categorical_accuracy: 0.9820 - val_distillation_loss: 0.0396 - val_student_loss: 0.1555\n", "Epoch 16/20\n", "844/844 [==============================] - 45s 47ms/step - loss: 0.0602 - val_categorical_accuracy: 0.9810 - val_distillation_loss: 0.0366 - val_student_loss: 0.1409\n", "Epoch 17/20\n", "844/844 [==============================] - 45s 46ms/step - loss: 0.0584 - val_categorical_accuracy: 0.9810 - val_distillation_loss: 0.0406 - val_student_loss: 0.1701\n", "Epoch 18/20\n", "844/844 [==============================] - 44s 46ms/step - loss: 0.0576 - val_categorical_accuracy: 0.9803 - val_distillation_loss: 0.0424 - val_student_loss: 0.1651\n", "Epoch 19/20\n", "844/844 [==============================] - 45s 46ms/step - loss: 0.0552 - val_categorical_accuracy: 0.9803 - val_distillation_loss: 0.0359 - val_student_loss: 0.1400\n", "Epoch 20/20\n", "844/844 [==============================] - 48s 50ms/step - loss: 0.0542 - val_categorical_accuracy: 0.9802 - val_distillation_loss: 0.0329 - val_student_loss: 0.1365\n", "DistillerRestoreStudentWeightsToBest: Restoring model weights from epoch 20 with val_loss 
0.03294413164258003.\n", "94/94 [==============================] - 5s 53ms/step - categorical_accuracy: 0.9802 - distillation_loss: 0.0214 - student_loss: 0.0915\n" ] } ], "source": [ "optimization_params = masterful.optimization.OptimizationParams(\n", "    batch_size=64,\n", "    epochs=20,\n", "    metrics=[tf.keras.metrics.CategoricalAccuracy()],\n", "    optimizer=tf.keras.optimizers.SGD(\n", "        learning_rate=tf.keras.optimizers.schedules.PolynomialDecay(\n", "            initial_learning_rate=1e-2,\n", "            decay_steps=20 * len(x_train),\n", "            end_learning_rate=1e-4\n", "        ),\n", "        momentum=0.9),\n", "    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n", ")\n", "\n", "training_report = masterful.ssl.distill(\n", "    teacher_model,\n", "    teacher_model_params,\n", "    student_model,\n", "    student_model_params,\n", "    optimization_params,\n", "    training_dataset,\n", "    training_dataset_params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Measure Results\n", "Let's see how well you did. The DistillationReport returned by Masterful contains information about the distillation process. Your student model achieves nearly the same accuracy as the teacher model using about 20,000 weights instead of about 1.4 million. You can also evaluate the student model directly on your holdout set. 
" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "157/157 [==============================] - 1s 3ms/step - loss: 0.1102 - categorical_accuracy: 0.9746\n", "Teacher Evaluation metrics: {'loss': 0.10363840311765671, 'categorical_accuracy': 0.9754999876022339}\n", "Student Evaluation metrics: {'loss': 0.11019360274076462, 'categorical_accuracy': 0.9745625257492065}\n" ] } ], "source": [ "student_evaluation_metrics = student_model.evaluate(\n", " test_dataset.batch(batch_size), return_dict=True)\n", "print(f'Teacher Evaluation metrics: {teacher_evaluation_metrics}')\n", "print(f'Student Evaluation metrics: {student_evaluation_metrics}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "9716fef025f64a2b69cf36238ef46ef956cc7030912a8258187d9e0c43537004" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 2 }