{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Guide To Ensembling\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)][1]        [![Download](images/download.png)][2][Download this Notebook][2]\n", "\n", "[1]:https://colab.research.google.com/github/masterfulai/masterful-docs/blob/main/notebooks/guide_ensembling.ipynb\n", "[2]:https://docs.masterfulai.com/0.5.2/notebooks/guide_ensembling.ipynb\n", "\n", "In this guide, you'll learn how to create an ensemble model. Ensembles are a common way to train a larger model successfully. In this guide, you'll run an experiment to determine if ensembling is useful for your problem of classifying CIFAR10. You will generate three scenarios to compare:\n", "\n", "1) As a first baseline, a model trained without Masterful. \n", "2) As a second baseline, the same model architecture trained by Masterful including Masterful's regularization techniques. This baseline allows you to to ablate the effects of regularization from ensembling. \n", "3) An ensemble model created by Masterful, including Masterful's regularization techniques. The null hypothesis you seek to reject is that the Masterful trained ensemble does not deliver better accuracy than the two baselines. \n", "\n", "For more information on the theory of ensembles, see the concepts doc." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prerequisites\n", "\n", "Please follow the Masterful installation instructions [here](../markdown/tutorial_installation.md) in order to run this Quickstart." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports\n", "\n", "This guide uses [Tensorflow](https://www.tensorflow.org) for training and [Tensorflow Datasets](https://www.tensorflow.org/datasets) for the training data." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "import tensorflow_addons as tfa\n", "import masterful\n", "masterful = masterful.activate()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prepare your dataset. You'll use the MNIST dataset for this guide. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "NUM_CLASSES = 10\n", "\n", "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", "\n", "# Normalize data into the range [0,1]\n", "x_train = x_train.astype(\"float32\") / 255.0\n", "x_test = x_test.astype(\"float32\") / 255.0\n", "\n", "# Reshape the data as single channel images.\n", "x_train = tf.reshape(x_train, (-1, 28, 28, 1))\n", "x_test = tf.reshape(x_test, (-1, 28, 28, 1))\n", "\n", "# Convert labels to one-hot. \n", "y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)\n", "y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)\n", "\n", "training_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", "test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, build a simple model architecture. 
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "conv2d (Conv2D) (None, 14, 14, 5) 50 \n", "_________________________________________________________________\n", "re_lu (ReLU) (None, 14, 14, 5) 0 \n", "_________________________________________________________________\n", "flatten (Flatten) (None, 980) 0 \n", "_________________________________________________________________\n", "dense (Dense) (None, 10) 9810 \n", "=================================================================\n", "Total params: 9,860\n", "Trainable params: 9,860\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "def get_model():\n", " model = tf.keras.Sequential(\n", " [\n", " tf.keras.Input(shape=(28, 28, 1)),\n", " tf.keras.layers.Conv2D(5, (3, 3), strides=(2, 2), padding=\"same\"),\n", " tf.keras.layers.ReLU(),\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dense(NUM_CLASSES),\n", " ]\n", " )\n", " return model\n", "\n", "model = get_model()\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To implement scenario 1, train the model without using Masterful. This will train to 0.96-0.97 accuracy on the validation set. This is a sanity check to ensure that Masterful training improves the model accuracy. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/20\n", "938/938 [==============================] - 16s 4ms/step - loss: 0.8121 - categorical_accuracy: 0.7878 - val_loss: 0.2886 - val_categorical_accuracy: 0.9167\n", "Epoch 2/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.2836 - categorical_accuracy: 0.9205 - val_loss: 0.2365 - val_categorical_accuracy: 0.9335\n", "Epoch 3/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.2374 - categorical_accuracy: 0.9330 - val_loss: 0.2016 - val_categorical_accuracy: 0.9437\n", "Epoch 4/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.2052 - categorical_accuracy: 0.9425 - val_loss: 0.1798 - val_categorical_accuracy: 0.9489\n", "Epoch 5/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1846 - categorical_accuracy: 0.9486 - val_loss: 0.1657 - val_categorical_accuracy: 0.9534\n", "Epoch 6/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1712 - categorical_accuracy: 0.9519 - val_loss: 0.1558 - val_categorical_accuracy: 0.9558\n", "Epoch 7/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1619 - categorical_accuracy: 0.9544 - val_loss: 0.1493 - val_categorical_accuracy: 0.9570\n", "Epoch 8/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1551 - categorical_accuracy: 0.9563 - val_loss: 0.1443 - val_categorical_accuracy: 0.9576\n", "Epoch 9/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1499 - categorical_accuracy: 0.9573 - val_loss: 0.1408 - val_categorical_accuracy: 0.9584\n", "Epoch 10/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1457 - categorical_accuracy: 0.9583 - val_loss: 0.1383 - val_categorical_accuracy: 0.9588\n", "Epoch 11/20\n", "938/938 
[==============================] - 3s 3ms/step - loss: 0.1422 - categorical_accuracy: 0.9591 - val_loss: 0.1365 - val_categorical_accuracy: 0.9591\n", "Epoch 12/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1392 - categorical_accuracy: 0.9599 - val_loss: 0.1349 - val_categorical_accuracy: 0.9595\n", "Epoch 13/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1366 - categorical_accuracy: 0.9604 - val_loss: 0.1337 - val_categorical_accuracy: 0.9591\n", "Epoch 14/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1343 - categorical_accuracy: 0.9609 - val_loss: 0.1328 - val_categorical_accuracy: 0.9593\n", "Epoch 15/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1321 - categorical_accuracy: 0.9614 - val_loss: 0.1320 - val_categorical_accuracy: 0.9594\n", "Epoch 16/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1301 - categorical_accuracy: 0.9622 - val_loss: 0.1314 - val_categorical_accuracy: 0.9597\n", "Epoch 17/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1282 - categorical_accuracy: 0.9628 - val_loss: 0.1308 - val_categorical_accuracy: 0.9597\n", "Epoch 18/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1263 - categorical_accuracy: 0.9633 - val_loss: 0.1301 - val_categorical_accuracy: 0.9598\n", "Epoch 19/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1241 - categorical_accuracy: 0.9638 - val_loss: 0.1288 - val_categorical_accuracy: 0.9605\n", "Epoch 20/20\n", "938/938 [==============================] - 3s 3ms/step - loss: 0.1215 - categorical_accuracy: 0.9646 - val_loss: 0.1268 - val_categorical_accuracy: 0.9610\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_size = 64\n", "model.compile(\n", " optimizer=tf.keras.optimizers.Adam(),\n", " loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n", " metrics=[tf.keras.metrics.CategoricalAccuracy()],\n", ")\n", "model.fit(training_dataset.batch(batch_size), validation_data=test_dataset.batch(batch_size), epochs=20)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "79/79 [==============================] - 0s 3ms/step - loss: 0.1268 - categorical_accuracy: 0.9610\n" ] } ], "source": [ "baseline_eval_metrics = model.evaluate(test_dataset.batch(128))\n", "baseline_val_loss, baseline_val_accuracy = baseline_eval_metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train with Masterful\n", "\n", "You can implement scenarios 2 and 3 at the same time. The Masterful ensembling API trains multiple individual models, reports each child model's accuracy, and then reports the accuracy of the combined ensemble. \n", "\n", "The Masterful AutoML platform learns how to train your model by focusing on five core organizational principles in deep learning: architecture, data, optimization, regularization, and semi-supervision.\n", "\n", "**Architecture** is the structure of weights, biases, and activations that defines a model. In this example, the architecture is defined by the model you created above. Importantly, you will use the [ensemble_multiplier](../api/api_architecture.rst#masterful.architecture.ArchitectureParams) parameter of the [ArchitectureParams](../api/api_architecture.rst#masterful.architecture.ArchitectureParams) structure to specify how many models to ensemble together.\n", "\n", "**Data** is the input used to train the model. In this example, you are using the labeled MNIST dataset of handwritten digits. More advanced usages of the Masterful AutoML platform can take into account unlabeled and synthetic data as well, using a variety of different techniques.\n", "\n", "**Optimization** means finding the best weights for a model and training data. Optimization is different from regularization because optimization does not consider generalization to unseen data. The central challenge of optimization is speed: finding the best weights faster.\n", "\n", "**Regularization** means helping a model generalize to data it has not yet seen. Another way of saying this is that regularization is about fighting overfitting.\n", "\n", "**Semi-Supervision** is the process by which a model can be trained using both labeled and unlabeled data.\n", "\n", "The first step when using Masterful is to learn the optimal set of parameters for each of the five buckets above. You start by learning the architecture and data parameters of the model and training dataset. In the code below, you are telling Masterful that your model is performing a classification task (`masterful.enums.Task.CLASSIFICATION`) with 10 labels (`num_classes=NUM_CLASSES`), and that the image features going into your model are in the range [0,1] (`input_range=masterful.enums.ImageRange.ZERO_ONE`). Also, the model outputs logits rather than a softmax classification (`prediction_logits=True`). Note the `ensemble_multiplier` parameter, set to 5, which tells Masterful to train and ensemble 5 different models and average the results of the joint prediction.\n", "\n", "Furthermore, in the training dataset, you are providing dense labels (`sparse_labels=False`) rather than sparse labels.\n", "\n", "For more details on architecture and data parameters, see the API specifications for [ArchitectureParams](../api/api_architecture.rst#masterful.architecture.ArchitectureParams) and [DataParams](../api/api_data.rst#masterful.data.DataParams). " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "model_params = masterful.architecture.learn_architecture_params(\n", " model=model,\n", " task=masterful.enums.Task.CLASSIFICATION,\n", " input_range=masterful.enums.ImageRange.ZERO_ONE,\n", " num_classes=NUM_CLASSES,\n", " prediction_logits=True,\n", " ensemble_multiplier=5,\n", ")\n", "training_dataset_params = masterful.data.learn_data_params(\n", " dataset=training_dataset,\n", " task=masterful.enums.Task.CLASSIFICATION,\n", " image_range=masterful.enums.ImageRange.ZERO_ONE,\n", " num_classes=NUM_CLASSES,\n", " sparse_labels=False,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, you learn the optimization parameters that will be used to train the model. Below, you use Masterful to learn the standard set of optimization parameters to train your model for a classification task.\n", "\n", "For more details on the optimization parameters, please see the [OptimizationParams](../api/api_optimization.rst#masterful.optimization.OptimizationParams) API specification." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "optimization_params = masterful.optimization.learn_optimization_params(\n", " model,\n", " model_params,\n", " training_dataset,\n", " training_dataset_params,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The regularization parameters used can have a dramatic impact on the final performance of your trained model. Learning these parameters can be a time-consuming and domain specific challenge. Masterful can speed up this process by learning these parameters for you. In general, this can be an expensive operation. A rough order of magnitude for learning these parameters is 2x the time it takes to train your model. However, this is still dramatically faster than manually finding these parameters yourself. In the example below, you will use the [learn_regularization_params](../api/api_regularization.rst#masterful.regularization.learn_regularization_params) API to learn these parameters directly from your dataset and model.\n", "\n", "For more details on the regularization parameters, please see the [RegularizationParams](../api/api_regularization.rst#masterful.regularization.RegularizationParams) API specification." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "regularization_params = masterful.regularization.learn_regularization_params(\n", " model,\n", " model_params,\n", " optimization_params,\n", " training_dataset,\n", " training_dataset_params,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The final step before training is to learn the optimal set of semi-supervision parameters. For this Quickstart, we are not using any unlabeled or synthetic data as part of training, so most forms of semi-supervision will be disabled by default.\n", "\n", "For more details on the semi-supervision parameters, please see the [SemiSupervisedParams](../api/api_ssl.rst#masterful.ssl.SemiSupervisedParams) API specification." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ssl_params = masterful.ssl.learn_ssl_params(\n", " training_dataset,\n", " training_dataset_params,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, you are ready to train your model using the Masterful AutoML platform. In the next cell, you will see the call to [masterful.training.train](../api/api_training.rst#masterful.training.train), which is the entry point to the meta-learning engine of the Masterful AutoML platform. Notice there is no need to batch your data (Masterful will find the optimal batch size for you). No need to shuffle your data (Masterful handles this for you). You don't even need to pass in a validation dataset (Masterful finds one for you). You hand Masterful a model and a dataset, and Masterful will figure the rest out for you. Ensembling with Masterful looks exactly the same as non-ensemble training with Masterful. Masterful handles everything for you inside the API." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting model phase 2 of 2 DONE.\n", "Masterful backend.fit() DONE.\n" ] } ], "source": [ "training_report = masterful.training.train(\n", " model, \n", " model_params, \n", " optimization_params,\n", " regularization_params,\n", " ssl_params,\n", " training_dataset,\n", " training_dataset_params,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that the ensemble is trained, access it using the report returned and evaluate on the `test_dataset`. " ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluating ensembled model...\n", "157/157 [==============================] - 1s 4ms/step - loss: 0.0560 - categorical_accuracy: 0.9824\n" ] } ], "source": [ "print('Evaluating ensembled model...')\n", "ensembled_model = training_report.model\n", "ensembled_results = ensembled_model.evaluate(test_dataset.batch(64), return_dict=True)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'loss': 0.0560285821557045,\n", " 'categorical_accuracy': 0.9824000000953674,\n", " 'child_models': {0: {'loss': 0.07878458499908447,\n", " 'categorical_accuracy': 0.9750999808311462},\n", " 1: {'loss': 0.06273771077394485, 'categorical_accuracy': 0.98089998960495},\n", " 2: {'loss': 0.069137804210186, 'categorical_accuracy': 0.9787999987602234},\n", " 3: {'loss': 0.061786066740751266, 'categorical_accuracy': 0.980400025844574},\n", " 4: {'loss': 0.06251046806573868,\n", " 'categorical_accuracy': 0.9807000160217285}}}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "training_report.validation_results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The report also holds metrics on the child models. Access them, average, and plot the results. You've successfully demonstrated that an ensemble outperforms both individual child models, as well as a model trained without any of Masterful's techniques. 
" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "157/157 [==============================] - 1s 3ms/step - loss: 0.0943 - categorical_accuracy: 0.9708\n", "157/157 [==============================] - 1s 3ms/step - loss: 0.0753 - categorical_accuracy: 0.9775\n", "157/157 [==============================] - 1s 3ms/step - loss: 0.0833 - categorical_accuracy: 0.9754\n", "157/157 [==============================] - 1s 3ms/step - loss: 0.0745 - categorical_accuracy: 0.9758\n", "157/157 [==============================] - 1s 3ms/step - loss: 0.0750 - categorical_accuracy: 0.9781\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAEWCAYAAAAn550kAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgbklEQVR4nO3debglVX3u8e8LzdgggwgyNLQCziCjSuQiCsEBBXJFiRFFUQxJrmLihEMUo0YSvFdMnCAOoOCAqISACgkKCAhCMw+iCC2jQsssjvC7f9Q69OZwpj6nT50evp/nOU/vXVVr1aq1a9e7a+3q2qkqJElSf1aY6QZIkrS8MXwlSeqZ4StJUs8MX0mSemb4SpLUM8NXkqSeGb7SEiDJmUneONPtWJIk+XCSBUl+uZjqqyRbLI66Jri+uW2ds9rz7yY5oK/1L2uSHJPkw2PM7/X1nSrDV4tFC4+7kqwy021Z0iU5LMlx01j//CS7L4Z6XpfknMXRphHqHvNAmWQO8DbgaVX1+OloQ9+q6sVVdexMt0NLBsNXU5ZkLvC/gAL26nnds/pcn3qzGfDrqrp9UQu6T2hpYPhqcXgtcD5wDPCIYbUkc5J8K8kdSX6d5JMD8w5Kck2S+5JcnWS7Nv0RZ0WDw01Jdk1yc5J3teHILyZZJ8kpbR13tcebDJRfN8kXk9za5p/Upl+Z5GUDy63Uhjm3GWkjW3uvS3JnkpOTbDQwr5IcnORnbR2fSpIR6ngR8B5gvyT3J7lsYPZmSc5t/XF6kvUGyj0nyXlJ7k5yWZJdR2njl4FNgf9q9b9zvPLtDPf6tt4bkrw6yVOBzwI7tXruHmV9jyo7MO/A9vreleS0JJu16We3RS5rde83rM7dgf8GNmrzj2nT90pyVduGM1sbh8rMb/vE5cBvxgjgl7T2LkhyRJIVWvnNk3y/7aMLkhyfZO2B+t+V5Ja2ndcm2a1NXyHJoUl+3sqekGTdUfrq4a8WWr+dk+RjrX9uSPLigWXXSvL5JLe19X44yYoj1LlRkt8OrjPJtm0bVkqyRZKzktzTpn19lH4Zbx85M8mHRto/k6ya5Li2/XcnuTDJBuNtR+uDc5N8vJW7Psmftek3Jbk9jx6mXy/Jf7c2nDW0T42wLau0vr0xya+SfDbJaqNt+4yoKv/8m9IfcB3wt8D2wB+BDdr0FYHLgI8Ds4FVgZ3bvFcAtwA7AgG2ADZr8wrYYqD+Y4APt8e7An8C/gVYBVgNeCzwcmB1YE3gG8BJA+VPBb4OrAOsBDyvTX8n8PWB5fYGrhhlG18ALAC2a+v9d+DsgfkFnAKsTRd+dwAvGqWuw4Djhk07E/g58KS2TWcCh7d5GwO/Bl5C94H5z9vzx41S/3xg94Hno5Zvr8u9wJPbshsCT2+PXwecM8brPlbZfdp+8VRgFvA+4Lxh/bXFGHXvCtw88PxJwG9a21dqr911wMoD23wpMAdYbZQ6C/gBsG57jX4KvLHN26LVvUrrl7OBI9u8JwM3ARu153OBzdvjt9J98NyklT0K+OrAcgXMGniN3zjQt38EDqJ7n/wNcCuQNv+kVtdsYH3gx8Bfj7Jd3wcOGnh+BPDZ9virwHvb6/7w+2+EOsbcxxh7//xr4L/o3n8r0h0HHjPedrQ++BPw+lbuw8CNwKdaX+4B3AesMXAcuA/Ypc3/BAP7JwP7FHAkcHJ7rdds7fvoTB8rH9HnM90A/5buP2DndhBZrz3/CfD37fFOdCE0a4RypwGHjFLneOH7B2DVMdq0DXBXe7wh8BCwzgjLbdTezEMHihOBd45S5+eBfx14vkbb7rkDbd55YP4JwKGj1HUYI4fv+wae/y3wvfb4XcCXR+i/A0apfz6PDN9Ry7eD4t10H15WG7bM6xg/fEcr+13gDQPPVwAeYJQPWCPUvSuPDN9/BE4YVt8twK4D23zgOPtqMfCBqPXxGaMsuw9wSXu8BXA7sDuw0rDlrgF2G3i+YdsvZjF++F43UG71tuzjgQ2A3w/2KfAq4AejtPWNwPfb49B9UNilPf8ScDSwyTh9M+Y+Ns7+eSBwHrD1sPJjbkfrg58NzNuq9cEGA9N+DWzTHh8DfG3Ye/BBYM7gPtX64De0D0ht3k7ADWP1Qd9/Djtrqg4ATq+qBe35V1g49DwH+EVV/WmEcnPoPklPxh1V9buhJ0lWT3JUkl8kuZfurGXtNrw1B7izqu4aXklV3QqcC7y8DTG+GDh+lHVuBPxioOz9dAeGjQeWGbwq9wG6g8OiGK38ZsAr2tDc3emGgHemO9BPxKjlq+o3wH7AwcBtSU5N8pSJVDpO2c2ATwys7066g+LGI1Y2vuH9/xBdyAzWd9ME6hlc5hetXpKsn+RrbWj0XuA4YL22ruvoznAPA25vyw195bAZ8O2B7byGLhA2mEBbHn69q+qB9nCNVudKdH06VO9RdGeOIzmR7uuBjejOCgv4YZv3Trp+/3G6IfsDR6ljIvvYaPvnl+mC+mvpvtr51yQrTXA7fjXw+LetL4ZPG3wfPfz6tffgnbTXcMDj6D7MzBtY7/fa9CWGFyZo0tp3KK8EVszC/w6yCl3wPZPujbJpklkjBPBNwOajVP0A3ZtnyOOBmwee17Dl30Y3NPjsqvpluu9sL2HhWcC6SdauqrtHWNexdGcOs4AfVdUto7TpVrqDCQBJZtMNd4+2/FiGt388N9GdlRw0yfrHLF9VpwGntdfzw8B/sPACurFXNHrZm4CPVNVoH2YW1a10Z0YAJAndB6vB/p9Iv84BrmqPN231Any0ld+6qn6dZB/g4esTquorwFeSPIYuQP4FeA3ddh5YVecOX1G6CxEn4ya6M8b1Rvng+ghVdXeS0+nei0
+lG/auNu+XdEPbJNkZ+J8kZ7cPFMPXuSj72OD6/wh8EPhg2+bvANe2fye8HRM0Z+hBkjXohpVvHbbMArrQfvoY7+cZ55mvpmIfuk/5T6Mb6t2G7s3/Q7qLsH4M3AYcnmR2uzDjua3s54C3J9k+nS0GLp64FPirJCumu0DpeeO0Y026N9vd6S48+cDQjKq6jW4I9NPpLsxaKckuA2VPovse9xC6IbrRfAV4fZJt0v13qn8GLqiq+eO0bSS/AuamXewzAccBL0vywtYnq6a78GyTUZb/FfDEiZRPskG6C5lm0x0o76d7TYfq2STJyiOtZJyynwXeneTpbdm1krxijDaO5wRgzyS7tbOqt7V1nrcIdQC8o+0Hc+he86ELkNZs7b87ycbAOwa288lJXtBe99/R7WuD2/mRLLyY7HFJ9l7ENj1C22dPB/5vkseku6hr8yRjvQ++Qveee3l7PNT2VwzsJ3fRfcB48NHFF3kfe1iS5yfZqo003Us37P7gJLdjPC9JsnPbJz9E9x58xIhHGxX5D+DjSdZvbdw4yQunsN7FzvDVVBwAfLGqbqyqXw790Z0xvJruzPNldN/D3Eh39rofQFV9A/gI3YHiProQHLpi85BW7u5Wz0njtONIuotAFtBd/PK9YfNfQ3dA+Andd3dvHZpRVb8Fvgk8AfjWaCuoqjPovnf8Jt0His2BvxynXaP5Rvv310kuHm/hdnDZm+4q6TvozlLewejv348C72tDbm8fp/wKdEF2K90Q3vPovs+D7kKeq4BfJlnAo41atqq+TXd2+LU2jHsl3bD+kMOAY1sbXzmBPrgW2J/uQrcFdPvHy6rqD+OVHeY/gXl0H/BOpfsuH7ozt+2Ae9r0wX1hFeDwtt5f0g2bvqfN+wTdhT2nJ7mPbv979iK2aSSvBVYGrqYLzRMZ+2uGk4EtgV9V1eAV9DsCFyS5vy1zSFXdMLzwJPaxQY9v7buXbtj9LLown8x2jOcrdB+u76S7sOvVoyz3LroL8s5v+9//0I2OLTGGrqyTlltJ3g88qar2n+m2SFo++J2vlmttmPoNdGfHktQLh5213EpyEN3w2ner6uzxlpekxcVhZ0mSeuaZryRJPfM73+XYeuutV3Pnzp3pZkjSUmXevHkLqmpKN+0wfJdjc+fO5aKLLprpZkjSUiXJL8ZfamwOO0uS1DPDV5Kknhm+kiT1zPCVJKlnhq8kST0zfCVJ6pnhK0lSzwxfSZJ65k02lmNX3HIPcw89daabIWkZMP/wPWe6CUsVz3wlSeqZ4StJUs8MX0mSemb4SpLUM8NXkqSeGb6SJPXM8JUkqWeGryRJPTN8JUnqmeErSVLPDF9Jknpm+EqS1DPDV5Kknhm+kiT1zPCVJKlnhq8kST0zfCVJ6pnhK0lSzwxfSZJ6ZvhKktQzw1eSpJ4ZvpIk9czwlSSpZ4avJEk9M3yXcEmOSbLvCNN3TXLKTLRJkjQ1hq8kST0zfMeRZP8kP05yaZKjkqyY5P4kH0lyWZLzk2zQln1Fkivb9LPbtBWTHJHkwiSXJ/nrNn3XJGclOSHJT5McnuTVbV1XJNl8oBm7J/lhW+6lI7RxdpIvtHVckmTvXjpHkjQphu8YkjwV2A94blVtAzwIvBqYDZxfVc8EzgYOakXeD7ywTd+rTXsDcE9V7QjsCByU5Alt3jOBQ4CtgNcAT6qqZwGfA9480JS5wPOAPYHPJll1WFPfC3y/reP5wBFJZo+yTW9KclGSix584J5F7RJJ0mIwa6YbsITbDdgeuDAJwGrA7cAfgKHvW+cBf94enwsck+QE4Ftt2h7A1gPf264FbNnquLCqbgNI8nPg9LbMFXQhOuSEqnoI+FmS64GnDGvnHsBeSd7enq8KbApcM3yDqupo4GiAVTbcsibWDZKkxcnwHVuAY6vq3Y+YmLy9qoaC60FaP1bVwUmeTXeGemmSbVodb66q04bVsSvw+4FJDw08f4hHvjbDQ3L48wAvr6prJ7xlkqQZ47Dz2M4A9k2yPkCSdZNsNtrCSTavqguq6v3AAmAOcBrwN0lWass8abQh4TG8IskK7XvgJwLDQ/Y04M1pp+dJtl3E+iVJPfLMdwxVdXWS9wGnJ1kB+CPwd2MUOSLJlnRnomcAlwGX031ne3ELxzuAfRaxKdcCZwEbAAdX1e9azg75EHAkcHlbx3zgURdmSZKWDFk4eqrlzSobblkbHnDkTDdD0jJg/uF7znQTepNkXlXtMJU6HHaWJKlnhq8kST0zfCVJ6pnhK0lSzwxfSZJ6ZvhKktQzw1eSpJ4ZvpIk9czwlSSpZ4avJEk9M3wlSeqZ4StJUs8MX0mSemb4SpLUM8NXkqSeGb6SJPXM8JUkqWeGryRJPTN8JUnqmeErSVLPDF9Jknpm+EqS1DPDV5Kkns2a6QZo5my18VpcdPieM90MSVrueOYrSVLPDF9Jknpm+EqS1DPDV5Kknhm+kiT1zPCVJKlnhq8kST0zfCVJ6pnhK0lSzwxfSZJ6ZvhKktQzw1eSpJ4ZvpIk9cxfNVqOXXHLPcw99NSZboakJdR8f/Vs2njmK0lSzwxfSZJ6ZvhKktQzw1eSpJ4ZvpIk9czwlSSpZ4avJEk9M3wlSeqZ4StJUs8MX0mSemb4SpLUM8NXkqSeTSh8k7wiyZrt8fuSfCvJdtPbNEmSlk0TPfP9x6q6L8nOwAuBY4HPTF+zJEladk00fB9s/+4JfKaq/hNYeXqaJEnSsm2i4XtLkqOAVwLfSbLKIpSVJEkDJhqgrwROA15UVXcD6wLvmK5GSZK0LJs11swk6w48PXNg2u+Bi6avWZIkLbvGDF9gHlBARphXwBMXe4skSVrGjRm+VfWEvhoiSdLyYqL/zzdJ9k/yj+35pkmeNb1NkyRp2TTRC64+DewE/FV7fh/wqWlpkSRJy7jxvvMd8uyq2i7JJQBVdVcS/5+vJEmTMNEz3z8mWZHuIiuSPA54aNpaJUnSMmyi4ftvwLeB9ZN8BDgH+Odpa5UkScuwCYVvVR0PvBP4KHAbsE9VfWNRV5bkmCT7jjB9oyQntse7JjlllPLzk6y3qOudjImsq8/2SJKWHYtyk43bga8OzquqOxdHI6rqVuBRoSxJ0rJovDPfeXR3spoH3AH8FPhZezxvvMqTvDbJ5UkuS/LlNnmXJOcluX7oLDjJ3CRXjlD+sUlOT3JJu7f0SDf7GFp2bpKfJPlckiuTHJ9k9yTnJvnZ0H+NSrJukpNau85PsvV462r/zerHSS5NclT7/ntw3bOTnNq288ok+43RzvlJ/jnJj5JclGS7JKcl+XmSgweWe0eSC1s7Pzgw/aQk85JcleRNA9PvT/KR1obzk2wwyvrf1NZ70YMP3DNaMyVJ02jM8K2qJ1TVE+nu6/yyqlqvqh4LvBT41lhlkzwdeC/wgqp6JnBIm7UhsHOr4/Bx2vcB4Jyq2hY4Gdh0nOW3AD4BbA08he6/Ru0MvB14T1vmg8AlVbV1m/alsdaV5KnAfsBzq2obul94evWw9b4IuLWqnllVzwC+N047b6qqnYAfAsfQnfU/B/int
s49gC2BZwHbANsn2aWVPbCqtgd2AN6S5LFt+mzg/NbXZwMHjbTiqjq6qnaoqh1WXH2tcZopSZoOE73gaseq+s7Qk6r6LvC8ccq8ADixqha0MkND1CdV1UNVdTUw4tnZgF2A41r5U4G7xln+hqq6oqoeAq4CzqiqAq4A5rZldga+3Or8PvDYJGuNsa7dgO2BC5Nc2p4Pv63mFcDuSf4lyf+qqvFOKU8eKHdBVd1XVXcAv0uyNrBH+7sEuJjug8SWrcxbklwGnA/MGZj+B2Dou/J5A9srSVrCTPT/+S5I8j66cCpgf+DX45RJW3a43w9bZjwj1TGawbofGnj+EAu3dbT7VI+2rgDHVtW7R21g1U+TbA+8BPhoktOr6p8m0M7BNg62M8BHq+qoRzQk2RXYHdipqh5Iciawapv9x/ZBA7qz84m+tpKknk30zPdVwOPo/rvRScD6bdpYzgBeOTQsOuzirYk6mzbEm+TFwDqTqGOsOncFFlTVvWOs6wxg3yTrt3nrJtlssMIkGwEPVNVxwMeA7abYxtOAA5Os0erfuK1/LeCuFrxPoRuqliQtZSZ0dtSGjA9J8hjgoaq6fwJlrmr/J/isJA/SDaEuqg8CX01yMXAWcOMk6hjuMOCLSS4HHgAOGGtdVXV1O+s/PckKwB+BvwN+MVDnVsARSR5q8/9mKg2sqtPbd80/SgJwP91ow/eAg1vbr6UbepYkLWWycKRyjIWSreguTBo6e10AHFBVj7pCWUuPVTbcsjY84MiZboakJdT8w/ec6SYskZLMq6odplLHRIedjwL+oao2q6rNgLcBR09lxZIkLa8melHO7Kr6wdCTqjozyexpatOY2nfIZ4wwa7eqGu8isN4k+TYw/PeQ31VVp81EeyRJS46Jhu/16X7Ld+hGGfsDN0xPk8bWAnabmVj3oqiqv5jpNkiSlkwTHXY+kO5q52/S3VxjPeB109QmSZKWaRMN383pbuiwArAS3Y0mzp6uRkmStCyb6LDz8XS3aLwSf8dXkqQpmWj43lFV/zWtLZEkaTkx0fD9QJLP0V1l/PDtEKtqzB9XkCRJjzbR8H093c39V2LhsHMxzi8bSZKkR5to+D6zqraa1pZIkrScmOjVzucnedq0tkSSpOXERM98dwYOSHID3Xe+Aar9IL0kSVoEEw3fF01rKyRJWo5M9CcFfzH+UpIkaSIm+p2vJElaTAxfSZJ6ZvhKktQzw1eSpJ4ZvpIk9czwlSSpZ4avJEk9m+hNNrQM2mrjtbjo8D1nuhmStNzxzFeSpJ4ZvpIk9czwlSSpZ4avJEk9M3wlSeqZ4StJUs8MX0mSemb4SpLUM8NXkqSeGb6SJPXM8JUkqWeGryRJPTN8JUnqmb9qtBy74pZ7mHvoqTPdDGmpMN9fANNi5JmvJEk9M3wlSeqZ4StJUs8MX0mSemb4SpLUM8NXkqSeGb6SJPXM8JUkqWeGryRJPTN8JUnqmeErSVLPDF9Jknpm+EqS1DPDV5Kknhm+kiT1zPCVJKlnhq8kST0zfCVJ6pnhK0lSzwxfSZJ6ZvhKktQzw1eSpJ4ZvpIk9czwlSSpZ4avJEk9M3ynIMncJFdOU927JjmlPd4ryaHTsR5JUv9mzXQDNL6qOhk4eabbIUlaPDzznbpZSY5NcnmSE5OsnuT9SS5McmWSo5MEIMlbklzdlv1amzY7yRfa8pck2Xv4CpK8Lskn2+NjkvxbkvOSXJ9k34Hl3tHquTzJB/vqAEnSojF8p+7JwNFVtTVwL/C3wCeraseqegawGvDStuyhwLZt2YPbtPcC36+qHYHnA0ckmT3OOjcEdm71Hg6QZA9gS+BZwDbA9kl2GV4wyZuSXJTkogcfuGey2yxJmgLDd+puqqpz2+Pj6ELx+UkuSHIF8ALg6W3+5cDxSfYH/tSm7QEcmuRS4ExgVWDTcdZ5UlU9VFVXAxsM1LMHcAlwMfAUujB+hKo6uqp2qKodVlx9rUXeWEnS1Pmd79TVCM8/DexQVTclOYwuUAH2BHYB9gL+McnTgQAvr6prBytJsgGj+/3gogP/frSqjprUVkiSeuOZ79RtmmSn9vhVwDnt8YIkawD7AiRZAZhTVT8A3gmsDawBnAa8eeB74W0n2Y7TgAPbOkmycZL1J1mXJGkaeeY7ddcAByQ5CvgZ8BlgHeAKYD5wYVtuReC4JGvRnaV+vKruTvIh4Ejg8hbA81n4HfGEVdXpSZ4K/Kjl+P3A/sDtk94ySdK0SNXwUVMtL1bZcMva8IAjZ7oZ0lJh/uF7znQTtIRIMq+qdphKHQ47S5LUM8NXkqSeGb6SJPXM8JUkqWeGryRJPTN8JUnqmeErSVLPDF9Jknpm+EqS1DPDV5Kknhm+kiT1zPCVJKlnhq8kST0zfCVJ6pnhK0lSzwxfSZJ6ZvhKktQzw1eSpJ4ZvpIk9czwlSSpZ4avJEk9M3wlSerZrJlugGbOVhuvxUWH7znTzZCk5Y5nvpIk9czwlSSpZ4avJEk9M3wlSeqZ4StJUs8MX0mSemb4SpLUM8NXkqSeGb6SJPUsVTXTbdAMSXIfcO1Mt2MJsR6wYKYbsYSwLxayLxayLxZ6clWtOZUKvL3k8u3aqtphphuxJEhykX3RsS8Wsi8Wsi8WSnLRVOtw2FmSpJ4ZvpIk9czwXb4dPdMNWILYFwvZFwvZFwvZFwtNuS+84EqSpJ555itJUs8MX0mSemb4LoOSvCjJtUmuS3LoCPPXSfLtJJcn+XGSZ0y07NJmsn2RZE6SHyS5JslVSQ7pv/WL11T2izZ/xSSXJDmlv1ZPjym+R9ZOcmKSn7T9Y6d+W794TbEv/r69P65M8tUkq/bb+sUnyReS3J7kylHmJ8m/tX66PMl2A/MW/bhZVf4tQ3/AisDPgScCKwOXAU8btswRwAfa46cAZ0y07NL0N8W+2BDYrj1eE/jp8toXA/P/AfgKcMpMb89M9gVwLPDG9nhlYO2Z3qaZ6AtgY+AGYLX2/ATgdTO9TVPoi12A7YArR5n/EuC7QIDnABdMtA9H+vPMd9nzLOC6qrq+qv4AfA3Ye9gyTwPOAKiqnwBzk2wwwbJLk0n3RVXdVlUXt+n3AdfQHWyWVlPZL0iyCbAn8Ln+mjxtJt0XSR5Dd5D+fJv3h6q6u7eWL35T2i/obtS0WpJZwOrArf00e/GrqrOBO8dYZG/gS9U5H1g7yYZM8rhp+C57NgZuGnh+M48OjcuA/w2Q5FnAZsAmEyy7NJlKXzwsyVxgW+CC6WpoD6baF0cC7wQemtZW9mMqffFE4A7gi20I/nNJZk9/k6fNpPuiqm4BPgbcCNwG3FNVp097i2fOaH01qeOm4bvsyQjThv9/ssOBdZJcCrwZuAT40wTLLk2m0hddBckawDeBt1bVvdPUzj5Mui+SvBS4varmTW8TezOV/WIW3dDkZ6pqW+A3wNJ8bcRU9ot16M7wngBsBMxOsv80tnWmjdZXkzpuem/nZc/NwJyB55swbCiohcjrobuIgO57mxvoho3GLLuUmUpfkGQluuA9vqq+1UeDp9FU+uIvgb2SvARYFXhM
kuOqamk90E71PXJzVQ2NgpzI0h2+U+mLFwI3VNUdbd63gD8Djpv+Zs+I0fpq5VGmj8kz32XPhcCWSZ6QZGW6A+fJgwu0qzVXbk/fCJzd3mDjll3KTLov2kHm88A1VfX/em319Jh0X1TVu6tqk6qa28p9fykOXphaX/wSuCnJk9u83YCr+2r4NJjK8eJG4DlJVm/vl93oro1YVp0MvLZd9fwcumH225jkcdMz32VMVf0pyf8BTqO7Cu8LVXVVkoPb/M8CTwW+lORBugPHG8YqOxPbsThMpS+A5wKvAa5ow20A76mq7/S5DYvLFPtimbIY+uLNwPHtQHs97axwaTTF48UFSU4ELqYbkr+EpfgWlEm+CuwKrJfkZuADwErwcD98h+6K5+uAB2iv+2SPm95eUpKknjnsLElSzwxfSZJ6ZvhKktQzw1eSpJ4ZvpIk9czwlTQlSf4iSSV5yky3RVpaGL6SpupVwDl0NxeYFklWnK66pZlg+EqatHbv6+fS3XjhL9u0FZN8LMkV7XdP39ym75jkvCSXpftd2DWTvC7JJwfqOyXJru3x/Un+KckFwE5J3p/kwnS/HXt0u6sSSbZI8j+t3ouTbJ7ky0n2Hqj3+CR79dUv0ngMX0lTsQ/wvar6KXBnuh8YfxPdzfa3raqtWXg3qK8Dh1TVM4Hdgd+OU/dsut9WfXZVnQN8sqp2rKpnAKsBL23LHQ98qtX7Z3S/sPM5Ft6PeK02fam8O5mWTYavpKl4Fd3vl9L+fRVdsH62qv4EUFV3Ak8GbquqC9u0e4fmj+FBuh+2GPL8JBckuQJ4AfD0JGsCG1fVt1u9v6uqB6rqLGCLJOu3Nn1zAuuTeuO9nSVNSpLH0oXgM5IU3X1tC5jHo39SLSNMg+6ewIMnAasOPP5dVT3Y1rUq8Glgh6q6KclhbdmRfs5tyJeBV9MNhx84wc2SeuGZr6TJ2hf4UlVtVlVzq2oO3U/NXQwcnGQWQJJ1gZ8AGyXZsU1bs82fD2yTZIUkc4BnjbKuoVBe0L5n3hce/rm7m5Ps0+pdJcnqbdljgLe25ZbaHwjRssnwlTRZrwK+PWzaN+l+WP1G4PIklwF/VVV/APYD/r1N+2+6QD2XLrCvAD5GF9yPUlV3A//RljuJ7mfchrwGeEuSy4HzgMe3Mr+i+4m7L05xO6XFzl81krRMamfAVwDbVdU9M90eaZBnvpKWOUl2pxvq/neDV0siz3wlSeqZZ76SJPXM8JUkqWeGryRJPTN8JUnqmeErSVLP/j9pyDxpdAmECgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "models = ['baseline', 'child_models_mean', 'ensemble']\n", "accuracies = [baseline_val_accuracy]\n", "accuracies_children = []\n", "\n", "for ensemble_child_model in ensembled_model.layers[1:6]:\n", " ensemble_model_child_results = ensemble_child_model.evaluate(test_dataset.batch(batch_size), return_dict=True)\n", " accuracies_children.append(ensemble_model_child_results['categorical_accuracy'])\n", "\n", "mean_accuracy_children = np.mean(np.array(accuracies_children))\n", "accuracies.append(mean_accuracy_children)\n", "\n", "accuracies.append(ensembled_results['categorical_accuracy'])\n", "\n", "plt.barh(models,accuracies)\n", "plt.title('Accuracy on the test set for baseline vs ensemble')\n", "plt.ylabel('models')\n", "plt.xlabel('Accuracy')\n", "plt.xlim([0.90, 1.0])\n", "plt.show()" ] } ], "metadata": { "interpreter": { "hash": "e11de040a44de2599d5826916dec5532a989d7fc6a7daf05571191351ea2bbfc" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 2 }