Unlabeled Data with Masterful (Part 2)

Author: sam
Date created: 2022/03/29
Last modified: 2022/03/29
Description: Part 2 of using unlabeled data with Masterful.


Introduction

In this guide, you will improve upon the results you achieved in Part 1 using semi-supervised learning in the Masterful AutoML platform. You will learn how to use a pretraining technique that generates an optimal set of initial weights for your model from only unlabeled data. The pretraining algorithm you will use is a form of self-supervision, which can learn meaningful representations of your dataset without any labels at all. In other words, by the end of this guide, you will have learned how to take a model that was only 15% accurate using 500 labeled training examples and improve its accuracy to over 75% using only additional unlabeled training examples!

Prerequisites

Please follow the Masterful installation instructions before running this guide.

Imports

First, import the necessary libraries and register the Masterful package.

[1]:
import numpy as np
import os
import tempfile
import tensorflow as tf
import urllib.request

import masterful

masterful = masterful.register()
MASTERFUL: Your account has been successfully registered. Masterful v0.4.1.dev202204071649294505 is loaded.

Prepare the Data

For this guide, you will use only 1% of the CIFAR-10 data as your labeled dataset, in order to simulate a small amount of labeled training data. You will also use the full CIFAR-10 training dataset without labels, as a pretraining step to learn an accurate representation of this dataset in an unsupervised way (as part of self-supervision using Masterful).

Finally, you will then use the same 10x unlabeled data as in Part 1 to train a classification model from the representation you learned.

Below, you will create training, validation, and test sets that are duplicates of the Part 1 datasets. In addition, you will create a much larger unlabeled dataset to use with learn_representation.

[2]:
NUM_CLASSES = 10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Normalize into the [0,1] range for numerical stability.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Masterful does not recommend sparse labels so convert to categorical.
y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

# Shuffle the data, and take 1% for the labeled data set,
# and 10x that amount for the unlabeled dataset.
training_percentage = 0.01
unlabeled_multiplier = 10
dataset_size = len(x_train)
indices = np.array(range(dataset_size))
generator = np.random.default_rng(seed=42)
generator.shuffle(indices)
cut = int(training_percentage * dataset_size)
part1_train_indices = indices[:cut]
part1_unlabeled_indices = indices[
    cut : cut + int(dataset_size * training_percentage * unlabeled_multiplier)
]

# Create the datasets from the splits. In order to
# compare against Part 1, you will explicitly create
# the same 1% training dataset and 10x unlabeled dataset.
# In addition, you will create a much larger unlabeled
# dataset you will use for representation learning.
part1_training_dataset = tf.data.Dataset.from_tensor_slices(
    (x_train[part1_train_indices], y_train[part1_train_indices])
)
part1_unlabeled_dataset = tf.data.Dataset.from_tensor_slices(
    (x_train[part1_unlabeled_indices],)
)
full_unlabeled_dataset = tf.data.Dataset.from_tensor_slices((x_train,))

# Split the test dataset into a test and validation dataset.
# The validation dataset is used for measuring training performance.
indices = np.array(range(len(x_test)))
generator.shuffle(indices)
test_indices = indices[:5000]
validation_indices = indices[5000:]
test_dataset = tf.data.Dataset.from_tensor_slices(
    (x_test[test_indices], y_test[test_indices])
)
validation_dataset = tf.data.Dataset.from_tensor_slices(
    (x_test[validation_indices], y_test[validation_indices])
)
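As a quick sanity check on the split arithmetic above, the following sketch reproduces only the index bookkeeping with NumPy. It assumes the standard CIFAR-10 training split of 50,000 images and does not load any image data; the variable names `labeled` and `unlabeled` are illustrative and not part of the original notebook.

```python
import numpy as np

dataset_size = 50_000  # assumed size of the CIFAR-10 training split
training_percentage = 0.01
unlabeled_multiplier = 10

# Shuffle the indices the same way the notebook cell does.
indices = np.arange(dataset_size)
rng = np.random.default_rng(seed=42)
rng.shuffle(indices)

# The first 1% of shuffled indices are the labeled set; the next
# 10% (10x the labeled amount) are the Part 1 unlabeled set.
cut = int(training_percentage * dataset_size)
labeled = indices[:cut]
unlabeled = indices[
    cut : cut + int(dataset_size * training_percentage * unlabeled_multiplier)
]

# The two slices are taken from disjoint ranges of the shuffled indices.
overlap = np.intersect1d(labeled, unlabeled)
print(len(labeled), len(unlabeled), overlap.size)
```

The two counts match the "labeled examples" and "unlabeled examples" figures that Masterful reports later in the training log.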

Create the Model

You are going to use the same model architecture from Part 1 of this guide. However, you are going to separate the model into two parts: an embedding module and a classification head. This just means the embedding model will not include a classification head (typically a pooling layer followed by a final dense layer). Why do this? In the first step, you will be learning an embedded representation of the unlabeled dataset using self-supervision, and for this task, you do not need the classification layer. In fact, since there are no labels, there is no way to make use of a classification layer…yet. So below, you will create the same model as in Part 1, and then slice off the last two layers (GlobalAveragePooling2D and Dense) to create an “embedding model”.

[3]:
from tensorflow.keras.layers import (
    Input,
    Add,
    Conv2D,
    GlobalAveragePooling2D,
    MaxPooling2D,
    ReLU,
    ZeroPadding2D,
    BatchNormalization,
    Dense,
)


def identity_block(x, name, stage, unit, n_filters):
    shortcut = x

    x = BatchNormalization(name=name.format(stage, unit, "bn", 1))(x)
    x = ReLU(name=name.format(stage, unit, "relu", 1))(x)
    x = Conv2D(
        n_filters,
        (3, 3),
        strides=(1, 1),
        padding="same",
        kernel_initializer="he_uniform",
        name=name.format(stage, unit, "conv", 1),
    )(x)

    x = BatchNormalization(name=name.format(stage, unit, "bn", 2))(x)
    x = ReLU(name=name.format(stage, unit, "relu", 2))(x)
    x = Conv2D(
        n_filters,
        (3, 3),
        strides=(1, 1),
        padding="same",
        kernel_initializer="he_uniform",
        name=name.format(stage, unit, "conv", 2),
    )(x)

    x = Add(name=name.format(stage, unit, "add", 1))([shortcut, x])
    return x


def projection_block(x, name, stage, unit, strides, n_filters):
    x = BatchNormalization(name=name.format(stage, unit, "bn", 1))(x)
    x = ReLU(name=name.format(stage, unit, "relu", 1))(x)
    shortcut = Conv2D(
        n_filters,
        (1, 1),
        strides=strides,
        kernel_initializer="he_uniform",
        name=name.format(stage, unit, "sc", 1),
    )(x)

    x = Conv2D(
        n_filters,
        (3, 3),
        strides=strides,
        padding="same",
        kernel_initializer="he_uniform",
        name=name.format(stage, unit, "conv", 1),
    )(x)
    x = BatchNormalization(name=name.format(stage, unit, "bn", 2))(x)
    x = ReLU(name=name.format(stage, unit, "relu", 2))(x)
    x = Conv2D(
        n_filters,
        (3, 3),
        strides=(1, 1),
        padding="same",
        kernel_initializer="he_uniform",
        name=name.format(stage, unit, "conv", 2),
    )(x)

    x = Add(name=name.format(stage, unit, "add", 1))([x, shortcut])
    return x


def group(x, name, stage, strides, n_blocks, n_filters):
    x = projection_block(
        x, name=name, stage=stage, unit=1, strides=strides, n_filters=n_filters
    )
    for unit in range(n_blocks - 1):
        x = identity_block(
            x, name=name, stage=stage, unit=unit + 2, n_filters=n_filters
        )
    return x


def resnet18(input_shape, num_classes):
    inputs = Input(input_shape)
    x = ZeroPadding2D(padding=(3, 3))(inputs)

    x = Conv2D(
        64, (3, 3), strides=(1, 1), padding="valid", kernel_initializer="he_uniform"
    )(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = ZeroPadding2D(padding=(1, 1))(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = group(
        x, strides=(1, 1), name="stage{}_unit{}_{}{}", stage=1, n_blocks=2, n_filters=64
    )
    x = group(
        x,
        strides=(2, 2),
        name="stage{}_unit{}_{}{}",
        stage=2,
        n_blocks=2,
        n_filters=128,
    )
    x = group(
        x,
        strides=(2, 2),
        name="stage{}_unit{}_{}{}",
        stage=3,
        n_blocks=2,
        n_filters=256,
    )
    x = group(
        x,
        strides=(2, 2),
        name="stage{}_unit{}_{}{}",
        stage=4,
        n_blocks=2,
        n_filters=512,
    )

    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = GlobalAveragePooling2D()(x)
    x = Dense(num_classes, kernel_initializer="he_normal")(x)
    return tf.keras.Model(inputs=inputs, outputs=x)


INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 10

classification_model = resnet18(INPUT_SHAPE, NUM_CLASSES)

# Slice off the final two layers to create an embedding model.
embedding_output = classification_model.layers[-3].output
embedding_model = tf.keras.Model(
    inputs=classification_model.input, outputs=embedding_output
)
embedding_model.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 32, 32, 3)]  0
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D)  (None, 38, 38, 3)    0           input_1[0][0]
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 36, 36, 64)   1792        zero_padding2d[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 36, 36, 64)   256         conv2d[0][0]
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 36, 36, 64)   0           batch_normalization[0][0]
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 38, 38, 64)   0           re_lu[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 18, 18, 64)   0           zero_padding2d_1[0][0]
__________________________________________________________________________________________________
stage1_unit1_bn1 (BatchNormaliz (None, 18, 18, 64)   256         max_pooling2d[0][0]
__________________________________________________________________________________________________
stage1_unit1_relu1 (ReLU)       (None, 18, 18, 64)   0           stage1_unit1_bn1[0][0]
__________________________________________________________________________________________________
stage1_unit1_conv1 (Conv2D)     (None, 18, 18, 64)   36928       stage1_unit1_relu1[0][0]
__________________________________________________________________________________________________
stage1_unit1_bn2 (BatchNormaliz (None, 18, 18, 64)   256         stage1_unit1_conv1[0][0]
__________________________________________________________________________________________________
stage1_unit1_relu2 (ReLU)       (None, 18, 18, 64)   0           stage1_unit1_bn2[0][0]
__________________________________________________________________________________________________
stage1_unit1_conv2 (Conv2D)     (None, 18, 18, 64)   36928       stage1_unit1_relu2[0][0]
__________________________________________________________________________________________________
stage1_unit1_sc1 (Conv2D)       (None, 18, 18, 64)   4160        stage1_unit1_relu1[0][0]
__________________________________________________________________________________________________
stage1_unit1_add1 (Add)         (None, 18, 18, 64)   0           stage1_unit1_conv2[0][0]
                                                                 stage1_unit1_sc1[0][0]
__________________________________________________________________________________________________
stage1_unit2_bn1 (BatchNormaliz (None, 18, 18, 64)   256         stage1_unit1_add1[0][0]
__________________________________________________________________________________________________
stage1_unit2_relu1 (ReLU)       (None, 18, 18, 64)   0           stage1_unit2_bn1[0][0]
__________________________________________________________________________________________________
stage1_unit2_conv1 (Conv2D)     (None, 18, 18, 64)   36928       stage1_unit2_relu1[0][0]
__________________________________________________________________________________________________
stage1_unit2_bn2 (BatchNormaliz (None, 18, 18, 64)   256         stage1_unit2_conv1[0][0]
__________________________________________________________________________________________________
stage1_unit2_relu2 (ReLU)       (None, 18, 18, 64)   0           stage1_unit2_bn2[0][0]
__________________________________________________________________________________________________
stage1_unit2_conv2 (Conv2D)     (None, 18, 18, 64)   36928       stage1_unit2_relu2[0][0]
__________________________________________________________________________________________________
stage1_unit2_add1 (Add)         (None, 18, 18, 64)   0           stage1_unit1_add1[0][0]
                                                                 stage1_unit2_conv2[0][0]
__________________________________________________________________________________________________
stage2_unit1_bn1 (BatchNormaliz (None, 18, 18, 64)   256         stage1_unit2_add1[0][0]
__________________________________________________________________________________________________
stage2_unit1_relu1 (ReLU)       (None, 18, 18, 64)   0           stage2_unit1_bn1[0][0]
__________________________________________________________________________________________________
stage2_unit1_conv1 (Conv2D)     (None, 9, 9, 128)    73856       stage2_unit1_relu1[0][0]
__________________________________________________________________________________________________
stage2_unit1_bn2 (BatchNormaliz (None, 9, 9, 128)    512         stage2_unit1_conv1[0][0]
__________________________________________________________________________________________________
stage2_unit1_relu2 (ReLU)       (None, 9, 9, 128)    0           stage2_unit1_bn2[0][0]
__________________________________________________________________________________________________
stage2_unit1_conv2 (Conv2D)     (None, 9, 9, 128)    147584      stage2_unit1_relu2[0][0]
__________________________________________________________________________________________________
stage2_unit1_sc1 (Conv2D)       (None, 9, 9, 128)    8320        stage2_unit1_relu1[0][0]
__________________________________________________________________________________________________
stage2_unit1_add1 (Add)         (None, 9, 9, 128)    0           stage2_unit1_conv2[0][0]
                                                                 stage2_unit1_sc1[0][0]
__________________________________________________________________________________________________
stage2_unit2_bn1 (BatchNormaliz (None, 9, 9, 128)    512         stage2_unit1_add1[0][0]
__________________________________________________________________________________________________
stage2_unit2_relu1 (ReLU)       (None, 9, 9, 128)    0           stage2_unit2_bn1[0][0]
__________________________________________________________________________________________________
stage2_unit2_conv1 (Conv2D)     (None, 9, 9, 128)    147584      stage2_unit2_relu1[0][0]
__________________________________________________________________________________________________
stage2_unit2_bn2 (BatchNormaliz (None, 9, 9, 128)    512         stage2_unit2_conv1[0][0]
__________________________________________________________________________________________________
stage2_unit2_relu2 (ReLU)       (None, 9, 9, 128)    0           stage2_unit2_bn2[0][0]
__________________________________________________________________________________________________
stage2_unit2_conv2 (Conv2D)     (None, 9, 9, 128)    147584      stage2_unit2_relu2[0][0]
__________________________________________________________________________________________________
stage2_unit2_add1 (Add)         (None, 9, 9, 128)    0           stage2_unit1_add1[0][0]
                                                                 stage2_unit2_conv2[0][0]
__________________________________________________________________________________________________
stage3_unit1_bn1 (BatchNormaliz (None, 9, 9, 128)    512         stage2_unit2_add1[0][0]
__________________________________________________________________________________________________
stage3_unit1_relu1 (ReLU)       (None, 9, 9, 128)    0           stage3_unit1_bn1[0][0]
__________________________________________________________________________________________________
stage3_unit1_conv1 (Conv2D)     (None, 5, 5, 256)    295168      stage3_unit1_relu1[0][0]
__________________________________________________________________________________________________
stage3_unit1_bn2 (BatchNormaliz (None, 5, 5, 256)    1024        stage3_unit1_conv1[0][0]
__________________________________________________________________________________________________
stage3_unit1_relu2 (ReLU)       (None, 5, 5, 256)    0           stage3_unit1_bn2[0][0]
__________________________________________________________________________________________________
stage3_unit1_conv2 (Conv2D)     (None, 5, 5, 256)    590080      stage3_unit1_relu2[0][0]
__________________________________________________________________________________________________
stage3_unit1_sc1 (Conv2D)       (None, 5, 5, 256)    33024       stage3_unit1_relu1[0][0]
__________________________________________________________________________________________________
stage3_unit1_add1 (Add)         (None, 5, 5, 256)    0           stage3_unit1_conv2[0][0]
                                                                 stage3_unit1_sc1[0][0]
__________________________________________________________________________________________________
stage3_unit2_bn1 (BatchNormaliz (None, 5, 5, 256)    1024        stage3_unit1_add1[0][0]
__________________________________________________________________________________________________
stage3_unit2_relu1 (ReLU)       (None, 5, 5, 256)    0           stage3_unit2_bn1[0][0]
__________________________________________________________________________________________________
stage3_unit2_conv1 (Conv2D)     (None, 5, 5, 256)    590080      stage3_unit2_relu1[0][0]
__________________________________________________________________________________________________
stage3_unit2_bn2 (BatchNormaliz (None, 5, 5, 256)    1024        stage3_unit2_conv1[0][0]
__________________________________________________________________________________________________
stage3_unit2_relu2 (ReLU)       (None, 5, 5, 256)    0           stage3_unit2_bn2[0][0]
__________________________________________________________________________________________________
stage3_unit2_conv2 (Conv2D)     (None, 5, 5, 256)    590080      stage3_unit2_relu2[0][0]
__________________________________________________________________________________________________
stage3_unit2_add1 (Add)         (None, 5, 5, 256)    0           stage3_unit1_add1[0][0]
                                                                 stage3_unit2_conv2[0][0]
__________________________________________________________________________________________________
stage4_unit1_bn1 (BatchNormaliz (None, 5, 5, 256)    1024        stage3_unit2_add1[0][0]
__________________________________________________________________________________________________
stage4_unit1_relu1 (ReLU)       (None, 5, 5, 256)    0           stage4_unit1_bn1[0][0]
__________________________________________________________________________________________________
stage4_unit1_conv1 (Conv2D)     (None, 3, 3, 512)    1180160     stage4_unit1_relu1[0][0]
__________________________________________________________________________________________________
stage4_unit1_bn2 (BatchNormaliz (None, 3, 3, 512)    2048        stage4_unit1_conv1[0][0]
__________________________________________________________________________________________________
stage4_unit1_relu2 (ReLU)       (None, 3, 3, 512)    0           stage4_unit1_bn2[0][0]
__________________________________________________________________________________________________
stage4_unit1_conv2 (Conv2D)     (None, 3, 3, 512)    2359808     stage4_unit1_relu2[0][0]
__________________________________________________________________________________________________
stage4_unit1_sc1 (Conv2D)       (None, 3, 3, 512)    131584      stage4_unit1_relu1[0][0]
__________________________________________________________________________________________________
stage4_unit1_add1 (Add)         (None, 3, 3, 512)    0           stage4_unit1_conv2[0][0]
                                                                 stage4_unit1_sc1[0][0]
__________________________________________________________________________________________________
stage4_unit2_bn1 (BatchNormaliz (None, 3, 3, 512)    2048        stage4_unit1_add1[0][0]
__________________________________________________________________________________________________
stage4_unit2_relu1 (ReLU)       (None, 3, 3, 512)    0           stage4_unit2_bn1[0][0]
__________________________________________________________________________________________________
stage4_unit2_conv1 (Conv2D)     (None, 3, 3, 512)    2359808     stage4_unit2_relu1[0][0]
__________________________________________________________________________________________________
stage4_unit2_bn2 (BatchNormaliz (None, 3, 3, 512)    2048        stage4_unit2_conv1[0][0]
__________________________________________________________________________________________________
stage4_unit2_relu2 (ReLU)       (None, 3, 3, 512)    0           stage4_unit2_bn2[0][0]
__________________________________________________________________________________________________
stage4_unit2_conv2 (Conv2D)     (None, 3, 3, 512)    2359808     stage4_unit2_relu2[0][0]
__________________________________________________________________________________________________
stage4_unit2_add1 (Add)         (None, 3, 3, 512)    0           stage4_unit1_add1[0][0]
                                                                 stage4_unit2_conv2[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 3, 3, 512)    2048        stage4_unit2_add1[0][0]
__________________________________________________________________________________________________
re_lu_1 (ReLU)                  (None, 3, 3, 512)    0           batch_normalization_1[0][0]
==================================================================================================
Total params: 11,184,064
Trainable params: 11,176,128
Non-trainable params: 7,936
__________________________________________________________________________________________________
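The spatial sizes in the summary above can be reproduced by hand. The sketch below is plain arithmetic, not a Keras call: it applies the standard output-size rules for "valid" and "same" padding to the layers defined in resnet18. The helper names `valid_out` and `same_out` are hypothetical and exist only for this illustration.

```python
import math


def valid_out(size, kernel, stride):
    # Output size of a convolution or pooling layer with "valid" padding.
    return (size - kernel) // stride + 1


def same_out(size, stride):
    # Output size of a convolution with "same" padding.
    return math.ceil(size / stride)


size = 32                      # CIFAR-10 input height/width
size += 2 * 3                  # ZeroPadding2D(padding=(3, 3))
size = valid_out(size, 3, 1)   # 3x3 stem convolution, "valid" padding
stem_size = size
size += 2 * 1                  # ZeroPadding2D(padding=(1, 1))
size = valid_out(size, 3, 2)   # 3x3 max pooling with stride 2

# Each stage starts with a projection block using the listed stride.
stage_sizes = []
for stride in (1, 2, 2, 2):
    size = same_out(size, stride)
    stage_sizes.append(size)

print(stem_size, stage_sizes)
```

The final stage size, combined with the 512 filters of stage 4, gives the (None, 3, 3, 512) output of the embedding model.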

Learn Unlabeled Representation

In this section, you will use a self-supervised technique called Barlow Twins to learn a representation of your unlabeled data. In essence, this will create a better set of initial weights for your model, which will allow it to train better.

You will use the Masterful API learn_representation to learn from the unlabeled dataset and create an optimal set of initial weights for your model.

This is the same set of steps as provided in the Pretraining Guide.

[4]:
full_unlabeled_dataset_params = masterful.data.learn_data_params(
    task=masterful.enums.Task.CLASSIFICATION,
    dataset=full_unlabeled_dataset,
    image_range=masterful.enums.ImageRange.ZERO_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=False,
)
model_params = masterful.architecture.learn_architecture_params(
    model=embedding_model,
    task=masterful.enums.Task.CLASSIFICATION,
    input_range=masterful.enums.ImageRange.ZERO_ONE,
    num_classes=NUM_CLASSES,
    prediction_logits=True,
    backbone_only=True,
)

# For demonstration purposes, you are only training for
# two epochs here. In general though, you will want to train
# for a lot longer. For example, Barlow Twins will typically
# require 800 epochs of training, with 5 warmup epochs
# for stability.
optimization_params = masterful.optimization.OptimizationParams(
    batch_size=512,
    epochs=2,
    warmup_epochs=1,
)

# Explicitly use Barlow Twins as the representation learner.
ssl_params = masterful.ssl.SemiSupervisedParams(algorithms=["barlow_twins"])

_ = masterful.ssl.learn_representation(
    model=embedding_model,
    model_params=model_params,
    optimization_params=optimization_params,
    ssl_params=ssl_params,
    unlabeled_datasets=[(full_unlabeled_dataset, full_unlabeled_dataset_params)],
)
Epoch 1/2
98/98 [==============================] - 113s 992ms/step - loss: 2041.6415
Epoch 2/2
98/98 [==============================] - 103s 995ms/step - loss: 1973.8130
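To give some intuition for the loss values printed above, here is a minimal NumPy sketch of the Barlow Twins objective as described in the original paper: embeddings of two augmented views are normalized per feature, their cross-correlation matrix is computed, and the loss pushes that matrix toward the identity. This is an illustration only. It is not Masterful's implementation, and the function name, the `lambda_offdiag` default, and the toy data are assumptions.

```python
import numpy as np


def barlow_twins_loss(z_a, z_b, lambda_offdiag=5e-3, eps=1e-12):
    """Barlow Twins loss for two (batch, features) embedding arrays."""
    n = z_a.shape[0]
    # Normalize each feature to zero mean and unit variance over the batch.
    z_a = (z_a - z_a.mean(axis=0)) / (z_a.std(axis=0) + eps)
    z_b = (z_b - z_b.mean(axis=0)) / (z_b.std(axis=0) + eps)
    # Cross-correlation matrix between the two views.
    c = z_a.T @ z_b / n
    diag = np.diag(c)
    # Invariance term: diagonal entries should be 1.
    on_diag = np.sum((1.0 - diag) ** 2)
    # Redundancy-reduction term: off-diagonal entries should be 0.
    off_diag = np.sum(c ** 2) - np.sum(diag ** 2)
    return on_diag + lambda_offdiag * off_diag


rng = np.random.default_rng(0)
z = rng.normal(size=(256, 8))
slightly_perturbed = z + 0.01 * rng.normal(size=z.shape)
unrelated = rng.normal(size=z.shape)

loss_similar = barlow_twins_loss(z, slightly_perturbed)
loss_unrelated = barlow_twins_loss(z, unrelated)
print(loss_similar, loss_unrelated)
```

Embeddings that agree across views produce a much smaller loss than unrelated embeddings, which is the signal the pretraining step optimizes without using any labels.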

Build the Classification Model

In the previous step, you created an embedding model that learned a representation from the full unlabeled dataset. But what should you do with this? The next step is to take this representation and train a classification model created from it on your labeled training dataset, in this instance 1% of the CIFAR-10 dataset.

The first step is to create a classification model from your embedding model. This involves attaching a dense head to the embedding model. The code below also loads a set of pretrained weights that were created for this guide by running learn_representation for a full 800 epochs. You can reproduce these results using the same setup as above with the following optimization parameters:

optimization_params = masterful.optimization.OptimizationParams(
    batch_size=512,
    epochs=800,
    warmup_epochs=5,
)
[5]:

# Load the pretrained weights into the model backbone.
# These weights were trained on the above model using
# masterful.ssl.learn_representation with the above
# documented optimization parameters.
def load_pretrained_weights(model):
    with tempfile.TemporaryDirectory() as tempdir:
        # Download the weights from AWS
        urllib.request.urlretrieve(
            "https://masterful-public.s3.us-west-1.amazonaws.com/933013963/static-data/resnet18_cifar10/trained_model_weights.index",
            os.path.join(tempdir, "trained_model_weights.index"),
        )
        urllib.request.urlretrieve(
            "https://masterful-public.s3.us-west-1.amazonaws.com/933013963/static-data/resnet18_cifar10/trained_model_weights.data-00000-of-00001",
            os.path.join(tempdir, "trained_model_weights.data-00000-of-00001"),
        )
        model.load_weights(os.path.join(tempdir, "trained_model_weights"))


tf.keras.backend.clear_session()
load_pretrained_weights(embedding_model)

# Attach the classifier head to the embedding model.
# Named `inputs` to avoid shadowing the Python builtin `input`.
inputs = tf.keras.layers.Input(shape=INPUT_SHAPE)
x = embedding_model(inputs)
x = GlobalAveragePooling2D()(x)
x = Dense(NUM_CLASSES, kernel_initializer="he_normal")(x)
classification_model = tf.keras.Model(inputs=inputs, outputs=x)
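To make the classifier head concrete, the sketch below mimics what the two attached layers compute, using NumPy on random stand-in feature maps. It is an illustration under stated assumptions, not the Keras code path: the random features replace the real embedding output, and the weight scale only approximates Keras' he_normal initializer (which draws from a truncated normal).

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_features, num_classes = 4, 512, 10

# Stand-in for the embedding model output, shaped (batch, 3, 3, 512).
feature_maps = rng.normal(size=(batch_size, 3, 3, num_features))

# GlobalAveragePooling2D: average over the two spatial axes.
pooled = feature_maps.mean(axis=(1, 2))

# Dense head without an activation, so the outputs are logits.
# The scale sqrt(2 / fan_in) is a He-style approximation.
weights = rng.normal(
    scale=np.sqrt(2.0 / num_features), size=(num_features, num_classes)
)
bias = np.zeros(num_classes)
logits = pooled @ weights + bias

print(pooled.shape, logits.shape)
```

Because the head has no softmax, the model emits logits, which is why prediction_logits=True is passed when describing the model to Masterful.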

Setup Training Parameters

Next, you need to set up all of the parameters for training with Masterful. These parameters are different from the ones you used to learn the representation, because you will be training on different datasets: the small labeled dataset and the larger (but not full) unlabeled dataset from Part 1.

[6]:
model_params = masterful.architecture.learn_architecture_params(
    model=classification_model,
    task=masterful.enums.Task.CLASSIFICATION,
    input_range=masterful.enums.ImageRange.ZERO_ONE,
    num_classes=NUM_CLASSES,
    prediction_logits=True,
)
training_dataset_params = masterful.data.learn_data_params(
    dataset=part1_training_dataset,
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.ZERO_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=False,
)
validation_dataset_params = masterful.data.learn_data_params(
    dataset=validation_dataset,
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.ZERO_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=False,
)
unlabeled_dataset_params = masterful.data.learn_data_params(
    dataset=part1_unlabeled_dataset,
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.ZERO_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=None,
)
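The sparse_labels=False setting above reflects the one-hot (categorical) labels created in the data preparation step. As a small illustration of the difference between the two label formats, the sketch below converts a few made-up sparse class indices to one-hot rows with NumPy and back again. The example labels are arbitrary and not taken from the dataset.

```python
import numpy as np

num_classes = 10

# Sparse labels: one integer class index per example.
sparse_labels = np.array([3, 0, 9])

# One-hot (categorical) labels: one row per example with a single 1.
one_hot_labels = np.eye(num_classes, dtype="float32")[sparse_labels]

# The class index can always be recovered with argmax.
recovered = one_hot_labels.argmax(axis=1)

print(one_hot_labels.shape, recovered)
```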

Next, you learn the optimization parameters that will be used to train the classification model. These are different from the optimization parameters you used above to learn the representation.

For more details on the optimization parameters, please see the OptimizationParams API specification.

[7]:
optimization_params = masterful.optimization.learn_optimization_params(
    classification_model,
    model_params,
    part1_training_dataset,
    training_dataset_params,
)
MASTERFUL: Learning optimal batch size.
MASTERFUL: Learning optimal initial learning rate for batch size 32.
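The batch sizes in this guide explain the step counts shown in the progress bars. The sketch below is simple arithmetic, assuming the final partial batch is kept (which is consistent with the counts Keras prints); the helper name `steps_per_epoch` is hypothetical.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    # Number of batches needed to see every example once,
    # keeping the final partial batch.
    return math.ceil(num_examples / batch_size)


# Representation learning: full CIFAR-10 training set, batch size 512.
representation_steps = steps_per_epoch(50_000, 512)

# Supervised training: 500 labeled examples, learned batch size 32.
labeled_steps = steps_per_epoch(500, 32)

# Evaluation: 5,000 test examples, batch size 32.
test_steps = steps_per_epoch(5_000, 32)

print(representation_steps, labeled_steps, test_steps)
```

The first and last values match the 98/98 and 157/157 progress bars in the representation learning and evaluation cell outputs.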

The regularization parameters used can have a dramatic impact on the final performance of your trained model. Learning these parameters can be a time-consuming and domain specific challenge. Masterful can speed up this process by learning these parameters for you. In general, this can be an expensive operation. A rough order of magnitude for learning these parameters is 2x the time it takes to train your model. However, this is still dramatically faster than manually finding these parameters yourself. In the example below, you will use one of the many sets of pre-learned regularization parameters that are shipped in the Masterful API. In most instances, you should learn these parameters directly using the learn_regularization_params API.

For more details on the regularization parameters, please see the RegularizationParams API specification.

[8]:
# This is a set of parameters learned on CIFAR10
# for ResNet18 models.
regularization_params = masterful.regularization.parameters.CIFAR10_RESNET18

The final step before training is to learn the optimal set of semi-supervision parameters. In this example, Masterful will apply Noisy Student Training to improve your model during training with the provided unlabeled data.

For more details on the semi-supervision parameters, please see the SemiSupervisedParams API specification.

[9]:
ssl_params = masterful.ssl.learn_ssl_params(
    part1_training_dataset,
    training_dataset_params,
    unlabeled_datasets=[(part1_unlabeled_dataset, unlabeled_dataset_params)],
)
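Noisy Student Training is a form of self-training: a teacher model labels the unlabeled data, and a student model then trains on both the real and the generated labels. The sketch below illustrates only the pseudo-labeling step, using hand-written teacher logits and a confidence threshold. It is a generic illustration, not Masterful's internal procedure, and the logits and the 0.9 threshold are assumptions chosen for the example.

```python
import numpy as np


def softmax(logits):
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Hypothetical teacher logits for three unlabeled examples, three classes.
teacher_logits = np.array(
    [
        [4.0, 0.0, 0.0],  # confident in class 0
        [0.2, 0.1, 0.0],  # uncertain
        [0.0, 0.0, 5.0],  # confident in class 2
    ]
)

probabilities = softmax(teacher_logits)
confidence = probabilities.max(axis=1)

# Keep only the examples the teacher is confident about.
keep = confidence >= 0.9
pseudo_labels = probabilities.argmax(axis=1)[keep]

print(keep, pseudo_labels)
```

The student would then train on the labeled examples together with the kept pseudo-labeled examples, typically with added noise such as data augmentation.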

Training with Unlabeled Data

Now you are ready to train your model using the Masterful AutoML platform. In the next cell, you will see the call to masterful.training.train, which is the entry point to the meta-learning engine of the Masterful AutoML platform. Notice that there is no need to batch your data (Masterful will find the optimal batch size for you) or to shuffle it (Masterful handles this for you). You don’t even need to pass in a validation dataset (Masterful finds one for you), although this guide passes one explicitly. You hand Masterful a model and a dataset, and Masterful figures out the rest.

[10]:
training_report = masterful.training.train(
    classification_model,
    model_params,
    optimization_params,
    regularization_params,
    ssl_params,
    part1_training_dataset,
    training_dataset_params,
    validation_dataset,
    validation_dataset_params,
    unlabeled_datasets=[(part1_unlabeled_dataset, unlabeled_dataset_params)],
)
MASTERFUL: Training model with semi-supervised learning enabled.
MASTERFUL: Performing basic dataset analysis.
MASTERFUL: Training model with:
MASTERFUL:      500 labeled examples.
MASTERFUL:      5000 validation examples.
MASTERFUL:      0 synthetic examples.
MASTERFUL:      5000 unlabeled examples.
MASTERFUL: Training model with learned parameters quesadilla-cyclic-offer in two phases.
MASTERFUL: The first phase is supervised training with the learned parameters.
MASTERFUL: The second phase is semi-supervised training to boost performance.
MASTERFUL: Warming up model for supervised training.
MASTERFUL:      Warming up batch norm statistics (this could take a few minutes).
MASTERFUL:      Warming up training for 500 steps.
100%|██████████| 500/500 [02:01<00:00,  4.12steps/s]
MASTERFUL:      Validating batch norm statistics after warmup for stability (this could take a few minutes).
MASTERFUL: Starting Phase 1: Supervised training until the validation loss stabilizes...
Supervised Training: 100%|██████████| 7612/7612 [02:22<00:00, 53.33steps/s]
MASTERFUL: Starting Phase 2: Semi-supervised training until the validation loss stabilizes...
MASTERFUL: Warming up model for semi-supervised training.
MASTERFUL:      Warming up batch norm statistics (this could take a few minutes).
MASTERFUL:      Warming up training for 500 steps.
100%|██████████| 500/500 [01:12<00:00,  6.92steps/s]
MASTERFUL:      Validating batch norm statistics after warmup for stability (this could take a few minutes).
Semi-Supervised Training: 100%|██████████| 21385/21385 [26:15<00:00, 13.57steps/s]
MASTERFUL: Semi-Supervised training complete.
MASTERFUL: Training complete in 32.63983148733775 minutes.

The model you passed into masterful.training.train is now trained and updated in place, so you are able to evaluate it just like any other trained Keras model.

[11]:
masterful_metrics = classification_model.evaluate(
    test_dataset.batch(optimization_params.batch_size), return_dict=True
)
print(f"Masterful model accuracy: {masterful_metrics['categorical_accuracy']}")
157/157 [==============================] - 1s 5ms/step - loss: 0.7355 - categorical_accuracy: 0.7712
Masterful model accuracy: 0.7712000012397766
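The categorical_accuracy metric reported above compares the highest-scoring class in each prediction with the class marked in the one-hot label. The following sketch shows that calculation with NumPy on a few made-up logits and labels. The values are illustrative only and unrelated to the model's real predictions.

```python
import numpy as np

# Hypothetical logits for four examples over three classes.
logits = np.array(
    [
        [2.0, 0.1, -1.0],
        [0.3, 0.2, 0.9],
        [1.5, 1.7, 0.0],
        [0.0, -0.5, -0.2],
    ]
)

# Matching one-hot labels.
one_hot_labels = np.array(
    [
        [1, 0, 0],
        [0, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
    ]
)

# Logits do not need a softmax here, because softmax preserves the argmax.
predicted_classes = logits.argmax(axis=1)
true_classes = one_hot_labels.argmax(axis=1)

accuracy = float((predicted_classes == true_classes).mean())
print(accuracy)
```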

If you recall, in Part 1 you ended up with a model that was 25-35% accurate. Now, using self-supervision on a much larger unlabeled dataset, and self-training on the original dataset, you increased that accuracy to over 75%! And all of this occurred with no additional labels in your dataset.