A Simple SSL Recipe¶
In this recipe, you’ll perform simple data manipulation and utilize Masterful utilities to improve your model’s accuracy with SSL techniques.
SSL, or Semi-Supervised Learning, means allowing your model to learn from both labeled and unlabeled data. Normally, SSL techniques require custom training loops and multiple losses. This recipe allows you to quickly implement SSL and get potentially significant improvements without custom training loops or multiple losses.
Consider using this recipe if:
You want to quickly try an SSL technique.
You want to keep your own training loop.
You have a very well tuned regularization policy.
Your model is classification (binary, single-label, or multilabel) or semantic segmentation.
Power users may want to skip this recipe and go straight to the full Masterful Platform if:
You want to maximize accuracy.
Your regularization policy is not optimally tuned.
Your model is detection or instance segmentation.
You are using Label Smoothing Regularization.
First, set up a standard supervised training pipeline.¶
This will not do any SSL yet. It should resemble most supervised training pipelines you’ve developed.
Implement functions to:
Get your dataset (get_labeled_datasets())
Create your model architecture (get_model())
Augment your data (augment_image())
Train your model (train_model()).
[1]:
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
def get_labeled_datasets(train_percentage=1):
    """Simple function to get cifar10 as a `tf.data.Dataset`"""
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    # Take the first train_percentage of the training data.
    train_cardinality = train_percentage * 50000 // 100
    x_train = x_train[0:train_cardinality]
    y_train = y_train[0:train_cardinality]
    # Normalize data into the range [0,1]
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    # Convert labels to one-hot.
    y_train = tf.keras.utils.to_categorical(y_train, 10)
    y_test = tf.keras.utils.to_categorical(y_test, 10)
    # Split test into a val and test dataset.
    x_val = x_test[:5000]
    y_val = y_test[:5000]
    x_test = x_test[5000:]
    y_test = y_test[5000:]
    # Convert the data to tf.data.Dataset.
    train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    # Shuffle just the training dataset.
    train = train.shuffle(1000)
    # Batch the data. The batch size is a crucial hyperparameter to
    # take advantage of your GPU hardware. See the guide to the
    # optimization metalearner to find out how to learn an optimal batch size.
    train = train.batch(256)
    val = val.batch(256)
    test = test.batch(256)
    train = train.prefetch(tf.data.AUTOTUNE)
    return train, val, test
def get_model():
    """Returns a minimal convnet."""
    inp = tf.keras.Input((32, 32, 3))
    x = inp
    x = tf.keras.layers.Conv2D(16, 3, activation='relu')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Conv2D(32, 3, activation='relu')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Conv2D(64, 3, activation='relu')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(10, activation='softmax')(x)
    return tf.keras.Model(inp, x)
def augment_image(image):
    """A simple augmentation pipeline."""
    image = tf.image.random_brightness(image, 0.1)
    image = tf.image.random_hue(image, 0.1)
    image = tf.image.resize(image, size=[32, 32])
    image = tf.image.random_flip_left_right(image)
    return image
def train_model(model, augmented_train, validation_data, epochs=100):
    """A simple training loop."""
    early_stopping = tf.keras.callbacks.EarlyStopping(patience=25,
                                                      verbose=2,
                                                      restore_best_weights=True)
    # The learning rate used by the optimizer (in this case LAMB)
    # is a crucial hyperparameter to take advantage of your GPU hardware.
    # See the guide to the optimization metalearner to find out how to
    # learn an optimal learning rate.
    model.compile(
        optimizer=tfa.optimizers.LAMB(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['acc'],
    )
    model.fit(augmented_train,
              validation_data=validation_data,
              epochs=epochs,
              callbacks=early_stopping)
# Now use the functions you just defined to train a model start to finish.
train, val, test = get_labeled_datasets()
augmented_train = train.map(lambda image, label: (augment_image(image), label))
model = get_model()
train_model(model, augmented_train, val)
Epoch 1/100
2/2 [==============================] - 3s 904ms/step - loss: 3.2753 - acc: 0.0780 - val_loss: 2.3118 - val_acc: 0.1010
Epoch 2/100
2/2 [==============================] - 1s 762ms/step - loss: 3.1935 - acc: 0.0860 - val_loss: 2.3098 - val_acc: 0.1016
Epoch 3/100
...
Epoch 51/100
2/2 [==============================] - 1s 765ms/step - loss: 1.6426 - acc: 0.4180 - val_loss: 2.3094 - val_acc: 0.1128
Epoch 52/100
2/2 [==============================] - 1s 770ms/step - loss: 1.6359 - acc: 0.4020 - val_loss: 2.3108 - val_acc: 0.1136
Epoch 53/100
2/2 [==============================] - ETA: 0s - loss: 1.6156 - acc: 0.4300Restoring model weights from the end of the best epoch: 28.
2/2 [==============================] - 1s 770ms/step - loss: 1.6156 - acc: 0.4300 - val_loss: 2.3122 - val_acc: 0.1142
Epoch 53: early stopping
[2]:
baseline_eval_metrics = model.evaluate(test)
20/20 [==============================] - 1s 32ms/step - loss: 2.2884 - acc: 0.1200
Now you’ll improve the accuracy of your model using SSL techniques.¶
First, set up your unlabeled data as a batched tf.data.Dataset. Typically, each element of a batched Dataset is a tuple of tensors: (images, labels). Since unlabeled data doesn’t have a label, just make each element of your batched dataset a single tensor: images.
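For instance, here is a minimal sketch of that difference in element structure; it uses random stand-in arrays rather than this recipe’s data, purely for illustration.
[ ]:
import numpy as np
import tensorflow as tf

# Illustrative only: 100 random 32x32 RGB arrays standing in for unlabeled images.
fake_images = np.random.rand(100, 32, 32, 3).astype("float32")
unlabeled_example = tf.data.Dataset.from_tensor_slices(fake_images).batch(256)

# Each element is a single image tensor, not an (images, labels) tuple.
print(unlabeled_example.element_spec)
# TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None)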
[3]:
# Normally, the unlabeled dataset comes from images that are not yet labeled.
# To simulate that with CIFAR10, you will use 5% of the training data, but
# remove the labels. Be sure to use the end of the training data, not the
# beginning, to ensure that the labeled and unlabeled sets are disjoint.
def get_unlabeled_data(train_percentage=5):
    """A simple function to get unlabeled CIFAR10 data."""
    (x_train, _), (_, _) = tf.keras.datasets.cifar10.load_data()
    # Take the train_percentage of the training data.
    # Take it from the end of the numpy array, not the beginning, to prevent
    # overlap with the labeled data.
    train_cardinality = train_percentage * 50000 // 100
    x_train = x_train[-train_cardinality:]
    # Perform the same processing as the `get_labeled_datasets()` function.
    x_train = x_train.astype("float32") / 255.0
    train = tf.data.Dataset.from_tensor_slices(x_train)
    train = train.shuffle(1000)
    # Batch the data. The batch size is a crucial hyperparameter to
    # take advantage of your GPU hardware. See the guide to the
    # optimization metalearner to find out how to learn an optimal batch size.
    train = train.batch(256)
    return train
Now call the Masterful SSL utility, which analyzes your data and stores the analysis to disk. The utility ensures consistent batch sizes and a consistent ratio of labeled to unlabeled data, takes care of the complexities of running and training a model, and optionally allows you to weight the labeled and unlabeled data.
The function call will take some time to iterate through each example from both datasets, analyze them, and save the interim results to disk.
[4]:
import masterful
masterful = masterful.register()
unlabeled = get_unlabeled_data()
masterful.ssl.analyze_data_then_save_to(model,
                                        train,
                                        unlabeled,
                                        path='/tmp/ssl',
                                        )
Loaded Masterful version 0.4.0. This software is distributed free of
charge for personal projects and evaluation purposes.
See http://www.masterfulai.com/personal-and-evaluation-agreement for details.
Sign up in the next 39 days at https://www.masterfulai.com/get-it-now
to continue using Masterful.
3000it [00:02, 1413.45it/s]
Now load the record from disk into a tf.data.Dataset, apply your augmentation function, and train. The record includes both labeled and unlabeled data, so each epoch will take longer to run.
[5]:
new_model = get_model()
ssl_training_data = masterful.ssl.load_from(path='/tmp/ssl').batch(256)
augmented_ssl_training_data = ssl_training_data.map(
    lambda image, label: (augment_image(image), label))
train_model(new_model, augmented_ssl_training_data, val)
Epoch 1/100
12/12 [==============================] - 4s 233ms/step - loss: 3.1790 - acc: 0.1177 - val_loss: 2.2992 - val_acc: 0.1396
Epoch 2/100
12/12 [==============================] - 2s 169ms/step - loss: 2.9010 - acc: 0.1137 - val_loss: 2.3001 - val_acc: 0.0976
Epoch 3/100
12/12 [==============================] - 2s 168ms/step - loss: 2.7171 - acc: 0.1373 - val_loss: 2.3025 - val_acc: 0.0990
...
Epoch 98/100
12/12 [==============================] - 2s 175ms/step - loss: 2.1727 - acc: 0.3537 - val_loss: 2.0497 - val_acc: 0.3080
Epoch 99/100
12/12 [==============================] - 2s 178ms/step - loss: 2.1730 - acc: 0.3437 - val_loss: 2.0553 - val_acc: 0.3084
Epoch 100/100
12/12 [==============================] - 2s 177ms/step - loss: 2.1683 - acc: 0.3507 - val_loss: 2.0504 - val_acc: 0.3050
Evaluate your newly trained model against the old one. You should see an improvement in accuracy now that you are applying SSL techniques to learn from unlabeled data.
[6]:
ssl_eval_metrics = new_model.evaluate(test)
print(f'run , test loss, test accuracy')
print(f'baseline, {baseline_eval_metrics[0]:.4}, {baseline_eval_metrics[1]:.4}')
print(f'ssl , {ssl_eval_metrics[0]:.4}, {ssl_eval_metrics[1]:.5}')
20/20 [==============================] - 1s 35ms/step - loss: 2.0740 - acc: 0.2952
run , test loss, test accuracy
baseline, 2.288, 0.12
ssl , 2.074, 0.2952
Advanced Tuning¶
To improve the results, two hyperparameters to tune are the intensity of augmentations and the weighting of the unlabeled data.
The intensity of augmentations is generally empirically discovered by a search algorithm, such as guessing and checking or grid search. If your augmentations are suboptimally tuned, consider using the full Masterful API to manage SSL end to end.
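For instance, here is a minimal grid-search sketch. The intensity scaling below is purely illustrative (it is not a Masterful parameter), and the candidate values are arbitrary.
[ ]:
# Illustrative sketch: scale the magnitudes in the augmentation pipeline by an
# `intensity` factor and keep the value with the best validation accuracy.
def augment_image_with_intensity(image, intensity):
    image = tf.image.random_brightness(image, 0.1 * intensity)
    image = tf.image.random_hue(image, 0.1 * intensity)
    image = tf.image.random_flip_left_right(image)
    return image

best_accuracy, best_intensity = 0.0, None
for intensity in [0.5, 1.0, 2.0]:
    candidate = get_model()
    candidate_data = masterful.ssl.load_from(path='/tmp/ssl').batch(256).map(
        lambda image, label, i=intensity: (augment_image_with_intensity(image, i), label))
    train_model(candidate, candidate_data, val)
    _, accuracy = candidate.evaluate(val)
    if accuracy > best_accuracy:
        best_accuracy, best_intensity = accuracy, intensity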
The weighting of unlabeled data is also generally empirically discovered by a search algorithm. As a rule of thumb, a 1:5 ratio of labeled to unlabeled data often works well. If you have more unlabeled data than that, you’ll want to downweight each unlabeled example (and vice versa).
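As a rough illustration of that rule of thumb, here is a sketch; starting_unlabeled_weight is a hypothetical helper, not part of the Masterful API.
[ ]:
# Estimate a starting unlabeled weight from the dataset cardinalities,
# relative to the 1:5 labeled-to-unlabeled rule of thumb.
def starting_unlabeled_weight(num_labeled, num_unlabeled, target_ratio=5.0):
    # Exactly target_ratio times as much unlabeled data -> weight 1.0.
    # More unlabeled data than that -> weight below 1.0, and vice versa.
    return (num_labeled * target_ratio) / num_unlabeled

# In this recipe: 500 labeled examples (1% of CIFAR10) and 2,500 unlabeled (5%).
print(starting_unlabeled_weight(500, 2500))  # 1.0, so no reweighting is needed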
Examples are below.
[ ]:
# You can quickly increase your augmentation intensity by augmenting twice.
new_model = get_model()
ssl_training_data = masterful.ssl.load_from(path='/tmp/ssl').batch(256)
ssl_training_data = ssl_training_data.map(
    lambda image, label: (augment_image(augment_image(image)), label))
train_model(new_model, ssl_training_data, val)
# If your unlabeled training data is 4x or less the cardinality of your labeled
# training data, you can increase the weight of the unlabeled training data.
new_model = get_model()
ssl_training_data = masterful.ssl.load_from(path='/tmp/ssl', unlabeled_weight=2.0).batch(256)
ssl_training_data = ssl_training_data.map(
    lambda image, label, weight: (augment_image(image), label, weight))
train_model(new_model, ssl_training_data, val)
# Alternatively, if your unlabeled training data is 6x or more
# the cardinality of your labeled training data, you can decrease the
# weight of the unlabeled training data (for example, to 0.5).
new_model = get_model()
ssl_training_data = masterful.ssl.load_from(path='/tmp/ssl', unlabeled_weight=0.5).batch(256)
ssl_training_data = ssl_training_data.map(
    lambda image, label, weight: (augment_image(image), label, weight))
train_model(new_model, ssl_training_data, val)