SSL via Automatic Labeling


In this recipe, you’ll perform simple data manipulation and use Masterful’s stand-alone automatic labeling utility to improve your model’s accuracy with SSL techniques.

SSL, or Semi-Supervised Learning, means allowing your model to learn from both labeled and unlabeled data. Normally, SSL techniques require custom training loops and multiple losses. This recipe allows you to quickly implement one of the simplest forms of SSL: Automatic Labeling.

Automatic labeling draws from the research lineage of self-training, teacher-student training, and feature consistency. It can deliver material improvements to your model’s accuracy, but it will not match the performance or reliability of the full Masterful platform.

Note that the stand-alone automatic labeling utility is just that: stand-alone. It sits apart from the rest of the Masterful platform, and you will not see it used in conjunction with the rest of the API.

Consider using this recipe if:

  • You want to quickly try an SSL technique.

  • You want to keep your own training loop.

  • You have a very well-tuned regularization policy.

Power users may want to skip this recipe and go straight to the full Masterful CLI Trainer or Masterful Python API if:

  • You want to maximize accuracy.

  • Your regularization policy is not optimally tuned.

  • Your model performs detection or instance segmentation.

First, set up a standard supervised training pipeline.

This will not do any SSL yet. It should resemble most supervised training pipelines you’ve developed.

Implement functions to:

  • Get your dataset (get_labeled_datasets())

  • Create your model architecture (get_model())

  • Augment your data (augment_image())

  • Train your model (train_model())

[1]:
import tensorflow as tf
import tensorflow_addons as tfa


def get_labeled_datasets(train_percentage=1):
  """Simple function to get cifar10 as a `tf.data.Dataset`"""
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

  # Take the first train_percentage percent of the training data.
  train_cardinality = train_percentage * 50000 // 100
  x_train = x_train[0:train_cardinality]
  y_train = y_train[0:train_cardinality]

  # Normalize data into the range [0,1]
  x_train = x_train.astype("float32") / 255.0
  x_test = x_test.astype("float32") / 255.0

  # Convert labels to one-hot.
  y_train = tf.keras.utils.to_categorical(y_train, 10)
  y_test = tf.keras.utils.to_categorical(y_test, 10)

  # Split test into a val and test dataset.
  x_val = x_test[:5000]
  y_val = y_test[:5000]

  x_test = x_test[5000:]
  y_test = y_test[5000:]

  # Convert the data to tf.data.Dataset.
  train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
  test = tf.data.Dataset.from_tensor_slices((x_test, y_test))

  # Shuffle just the training dataset.
  train = train.shuffle(1000)

  # Batch the data. The batch size is a crucial hyperparameter to
  # take advantage of your GPU hardware. See the guide to the
  # optimization metalearner to find out how to learn an optimal batch size.
  train = train.batch(256)
  val = val.batch(256)
  test = test.batch(256)

  train = train.prefetch(tf.data.AUTOTUNE)

  return train, val, test

def get_model():
  """Returns a minimal convnet. """
  inp = tf.keras.Input((32, 32, 3))
  x = inp
  x = tf.keras.layers.Conv2D(16, 3, activation='relu')(x)
  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.MaxPooling2D()(x)

  x = tf.keras.layers.Conv2D(32, 3, activation='relu')(x)
  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.MaxPooling2D()(x)

  x = tf.keras.layers.Conv2D(64, 3, activation='relu')(x)
  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.MaxPooling2D()(x)

  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(10, activation='softmax')(x)
  return tf.keras.Model(inp, x)

def augment_image(image):
  """A simple augmentation pipeline."""
  image = tf.image.random_brightness(image, 0.1)
  image = tf.image.random_hue(image, 0.1)
  image = tf.image.resize(image, size=[32,32])
  image = tf.image.random_flip_left_right(image)

  return image

def train_model(model, augmented_train, validation_data, epochs=100):
  """A simple training loop. """

  early_stopping = tf.keras.callbacks.EarlyStopping(patience=25,
                                                    verbose=2,
                                                    restore_best_weights=True)

  # The learning rate used by the optimizer (in this case LAMB)
  # is a crucial hyperparameter to take advantage of your GPU hardware.
  # See the guide to the optimization metalearner to find out how to
  # learn an optimal learning rate.
  model.compile(
      optimizer=tfa.optimizers.LAMB(learning_rate=0.001),
      loss='categorical_crossentropy',
      metrics=['acc'],
  )

  model.fit(augmented_train,
            validation_data=validation_data,
            epochs=epochs,
            callbacks=early_stopping)

# Now use the functions you just defined to train a model start to finish.
train, val, test = get_labeled_datasets()
augmented_train = train.map(lambda image, label: (augment_image(image), label))
model = get_model()
train_model(model, augmented_train, val)
Epoch 1/100
2/2 [==============================] - 5s 1s/step - loss: 2.8249 - acc: 0.1377 - val_loss: 2.2997 - val_acc: 0.1374
Epoch 2/100
2/2 [==============================] - 1s 489ms/step - loss: 2.7732 - acc: 0.1205 - val_loss: 2.2985 - val_acc: 0.1352
Epoch 3/100

...

Epoch 41/100
2/2 [==============================] - 0s 344ms/step - loss: 1.6290 - acc: 0.4312 - val_loss: 2.2970 - val_acc: 0.1282
Epoch 42/100
2/2 [==============================] - 0s 347ms/step - loss: 1.6206 - acc: 0.4550 - val_loss: 2.2982 - val_acc: 0.1288
Epoch 43/100
2/2 [==============================] - 0s 350ms/step - loss: 1.6054 - acc: 0.4368 - val_loss: 2.2992 - val_acc: 0.1274
Restoring model weights from the end of the best epoch.
Epoch 00043: early stopping
[2]:
baseline_eval_metrics = model.evaluate(test)
20/20 [==============================] - 0s 8ms/step - loss: 2.2876 - acc: 0.1556

Now you’ll improve the accuracy of your model using automatic labeling.

First, set up your unlabeled data as a batched tf.data.Dataset. Typically, each element of a batched Dataset is a tuple of tensors: (images, labels). Since unlabeled data doesn’t have a label, just make each element of your batched dataset a tensor: images.

[3]:
# Normally, the unlabeled dataset comes from images that are not yet labeled.
# To simulate that with CIFAR10, you will use 5% of the training data, but
# remove the labels. Be sure to use the end of the training data, not the
# beginning, to ensure that the labeled and unlabeled sets are disjoint.
def get_unlabeled_data(train_percentage=5):
  """A simple function get unlabeled CIFAR10 data."""
  (x_train, _), (_, _) = tf.keras.datasets.cifar10.load_data()

  # Take train_percentage percent of the training data.
  # Take it from the end of the numpy array, not the beginning, to prevent
  # overlap with the labeled data.
  train_cardinality = train_percentage * 50000 // 100
  x_train = x_train[-train_cardinality:]

  # Perform the same processing as the `get_labeled_datasets()` function.
  x_train = x_train.astype("float32") / 255.0
  train = tf.data.Dataset.from_tensor_slices(x_train)
  train = train.shuffle(1000)

  # Batch the data. The batch size is a crucial hyperparameter to
  # take advantage of your GPU hardware. See the guide to the
  # optimization metalearner to find out how to learn an optimal batch size.
  train = train.batch(256)

  return train
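
As a quick check of the structure described above, you can compare the element specs of the labeled and unlabeled datasets. This is a minimal sketch using the datasets already built in this recipe; the specs in the comments are what you should expect to see.

# Labeled batches are (images, labels) tuples; unlabeled batches are just images.
print(train.element_spec)
# (TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None),
#  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))
print(get_unlabeled_data().element_spec)
# TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None)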

After activating with masterful.activate, call the masterful.ssl.analyze_data_then_save_to utility function. The function will ensure consistent batch sizes and a consistent ratio of labeled to unlabeled data, and optionally allow for weighting of the labeled and unlabeled data.

The function will take some time to iterate through each example from both datasets, analyze them, and save the interim results to disk.

[4]:
import masterful

masterful = masterful.activate()

unlabeled = get_unlabeled_data()

masterful.ssl.analyze_data_then_save_to(model,
                                        train,
                                        unlabeled,
                                        path='/tmp/ssl')
Loaded Masterful version 0.4.1. This software is distributed free of
charge for personal projects and evaluation purposes.
See http://www.masterfulai.com/personal-and-evaluation-agreement for details.
Sign up in the next 44 days at https://www.masterfulai.com/get-it-now
to continue using Masterful.
3000it [00:02, 1030.33it/s]

Now call masterful.ssl.load_from to load the record from disk into a tf.data.Dataset, apply your augmentation function, and train. The record includes both (a) the labeled data and (b) the unlabeled data with automatic labels, so each epoch will take longer to run.
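
Before training, you can optionally sanity-check the size of the merged record. This is a minimal sketch; the expected count assumes the 500 labeled and 2,500 unlabeled CIFAR10 examples used in this recipe.

# Count the elements in the merged record; expect roughly 3,000 examples
# (500 labeled + 2,500 automatically labeled).
merged = masterful.ssl.load_from(path='/tmp/ssl')
print('Merged examples:', sum(1 for _ in merged))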

[5]:
ssl_training_data = masterful.ssl.load_from(path='/tmp/ssl').batch(256)

augmented_ssl_training_data = ssl_training_data.map(
    lambda image, label: (augment_image(image), label))

new_model = get_model()
train_model(new_model, augmented_ssl_training_data, val)
Epoch 1/100
12/12 [==============================] - 3s 132ms/step - loss: 3.0270 - acc: 0.0470 - val_loss: 2.3087 - val_acc: 0.0836
Epoch 2/100
12/12 [==============================] - 1s 110ms/step - loss: 2.7617 - acc: 0.0616 - val_loss: 2.3102 - val_acc: 0.0974
Epoch 3/100

...

12/12 [==============================] - 1s 100ms/step - loss: 2.2064 - acc: 0.3250 - val_loss: 2.0833 - val_acc: 0.2986
Epoch 98/100
12/12 [==============================] - 1s 100ms/step - loss: 2.1699 - acc: 0.3193 - val_loss: 2.0858 - val_acc: 0.2804
Epoch 99/100
12/12 [==============================] - 1s 101ms/step - loss: 2.1651 - acc: 0.3252 - val_loss: 2.0820 - val_acc: 0.2768
Epoch 100/100
12/12 [==============================] - 1s 100ms/step - loss: 2.1664 - acc: 0.3269 - val_loss: 2.0744 - val_acc: 0.2846

Evaluate your newly trained model against the old one. You should see an improvement in accuracy now that you are applying SSL techniques to learn from unlabeled data.

[10]:
def show_eval_metrics(baseline_metrics, ssl_metrics):
  print('Run     | Test Loss | Test Accuracy')
  print('-----------------------------------')
  print(f'baseline| {baseline_metrics[0]:.4}     |{baseline_metrics[1]:.4}')
  print(f'ssl     | {ssl_metrics[0]:.4}     |{ssl_metrics[1]:.4}')

ssl_eval_metrics = new_model.evaluate(test)

show_eval_metrics(baseline_eval_metrics, ssl_eval_metrics)
20/20 [==============================] - 0s 8ms/step - loss: 2.0755 - acc: 0.2944
Run     | Test Loss | Test Accuracy
-----------------------------------
baseline| 2.288     |0.1556
ssl     | 2.076     |0.2944

Advanced Tuning

To improve the results, two hyperparameters to tune are the intensity of augmentations and the weighting of the unlabeled data.

The intensity of augmentations is generally discovered empirically by a search algorithm, such as guess-and-check or grid search. If your augmentations are suboptimally tuned, consider using the full Masterful API to manage SSL end to end.
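
For illustration, here is a minimal sketch of such a search, reusing the helpers defined earlier in this recipe. The intensity-parameterized augmentation function and the candidate intensity values are assumptions made for this example, not part of the Masterful API.

# A small grid search over augmentation intensity (hypothetical helper;
# intensity values chosen arbitrarily for illustration).
def augment_image_with_intensity(image, intensity):
  image = tf.image.random_brightness(image, 0.1 * intensity)
  image = tf.image.random_hue(image, 0.1 * intensity)
  image = tf.image.random_flip_left_right(image)
  return image

best_val_acc = 0.0
best_intensity = None
for intensity in [0.5, 1.0, 2.0]:
  candidate_data = masterful.ssl.load_from(path='/tmp/ssl').map(
      lambda image, label, i=intensity: (augment_image_with_intensity(image, i), label)).batch(256)
  candidate_model = get_model()
  train_model(candidate_model, candidate_data, val, epochs=25)
  _, val_acc = candidate_model.evaluate(val)
  if val_acc > best_val_acc:
    best_val_acc = val_acc
    best_intensity = intensity
print('Best intensity:', best_intensity)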

The weighting of unlabeled data is also generally empirically discovered by a search algorithm. As a rule of thumb, a 1:5 ratio of labeled to unlabeled data often works well. If you have more unlabeled data than that, you’ll want to downweight each unlabeled example (and vice versa).
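
To make that rule of thumb concrete, here is a minimal sketch that scales the weight linearly around the 1:5 reference point. The helper name and the linear scaling are illustrative assumptions, not part of the Masterful API.

# Hypothetical helper: derive an unlabeled_weight from the labeled:unlabeled
# ratio, treating 1:5 as the point where the weight is 1.0.
def suggested_unlabeled_weight(num_labeled, num_unlabeled):
  return 5.0 * num_labeled / num_unlabeled

print(suggested_unlabeled_weight(500, 2500))   # exactly 1:5 -> 1.0
print(suggested_unlabeled_weight(500, 10000))  # 1:20 -> 0.25, downweight
print(suggested_unlabeled_weight(500, 1500))   # 1:3  -> ~1.67, upweight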

Examples are below.

[11]:
# You can quickly increase your augmentation intensity by augmenting twice.
ssl_training_data = masterful.ssl.load_from(path='/tmp/ssl')
augmented_ssl_training_data = ssl_training_data.map(lambda image, label: (augment_image(image), label))
augmented_ssl_training_data = augmented_ssl_training_data.map(lambda image, label: (augment_image(image), label))
augmented_ssl_training_data = augmented_ssl_training_data.batch(256)

new_model = get_model()
train_model(new_model, augmented_ssl_training_data, val)

ssl_eval_metrics = new_model.evaluate(test)
show_eval_metrics(baseline_eval_metrics, ssl_eval_metrics)

# If your unlabeled training data is at most 4x the cardinality of your labeled
# training data, you can increase the weight of the unlabeled training data.
ssl_training_data = masterful.ssl.load_from(path='/tmp/ssl', unlabeled_weight=2.0)
augmented_ssl_training_data = ssl_training_data.map(lambda image, label, weight: (augment_image(image), label, weight))
augmented_ssl_training_data = augmented_ssl_training_data.batch(256)

new_model = get_model()
train_model(new_model, augmented_ssl_training_data, val)

ssl_eval_metrics = new_model.evaluate(test)
show_eval_metrics(baseline_eval_metrics, ssl_eval_metrics)

# Alternatively, if your unlabeled training data is at least 6x the cardinality
# of your labeled training data, you can decrease the weight of the unlabeled
# training data.
ssl_training_data = masterful.ssl.load_from(path='/tmp/ssl', unlabeled_weight=0.5)
augmented_ssl_training_data = ssl_training_data.map(lambda image, label, weight: (augment_image(image), label, weight))
augmented_ssl_training_data = augmented_ssl_training_data.batch(256)

new_model = get_model()
train_model(new_model, augmented_ssl_training_data, val)

ssl_eval_metrics = new_model.evaluate(test)
show_eval_metrics(baseline_eval_metrics, ssl_eval_metrics)
Epoch 1/100
12/12 [==============================] - 3s 121ms/step - loss: 3.3689 - acc: 0.2178 - val_loss: 2.3050 - val_acc: 0.1028
Epoch 2/100
12/12 [==============================] - 1s 105ms/step - loss: 3.0205 - acc: 0.1672 - val_loss: 2.3071 - val_acc: 0.1024
Epoch 3/100

...

Epoch 98/100
12/12 [==============================] - 1s 102ms/step - loss: 2.1763 - acc: 0.3229 - val_loss: 2.0621 - val_acc: 0.2908
Epoch 99/100
12/12 [==============================] - 1s 105ms/step - loss: 2.1764 - acc: 0.3292 - val_loss: 2.0739 - val_acc: 0.2986
Epoch 100/100
12/12 [==============================] - 1s 102ms/step - loss: 2.1805 - acc: 0.3320 - val_loss: 2.0633 - val_acc: 0.2906
20/20 [==============================] - 0s 8ms/step - loss: 2.0709 - acc: 0.2942

Run     | Test Loss | Test Accuracy
-----------------------------------
baseline| 2.288     |0.1556
ssl     | 2.071     |0.2942

Epoch 1/100
12/12 [==============================] - 3s 119ms/step - loss: 5.5107 - acc: 0.0489 - val_loss: 2.3053 - val_acc: 0.1000
Epoch 2/100
12/12 [==============================] - 1s 102ms/step - loss: 5.0825 - acc: 0.0871 - val_loss: 2.3054 - val_acc: 0.1090
Epoch 3/100

...

Epoch 98/100
12/12 [==============================] - 1s 101ms/step - loss: 4.1386 - acc: 0.3869 - val_loss: 2.1388 - val_acc: 0.2810
Epoch 99/100
12/12 [==============================] - 1s 103ms/step - loss: 4.1368 - acc: 0.3770 - val_loss: 2.1407 - val_acc: 0.2842
Epoch 100/100
12/12 [==============================] - 1s 102ms/step - loss: 4.1367 - acc: 0.3685 - val_loss: 2.1330 - val_acc: 0.2834
20/20 [==============================] - 0s 8ms/step - loss: 2.1373 - acc: 0.2842

Run     | Test Loss | Test Accuracy
-----------------------------------
baseline| 2.288     |0.1556
ssl     | 2.137     |0.2842

Epoch 1/100
12/12 [==============================] - 3s 117ms/step - loss: 1.7747 - acc: 0.0883 - val_loss: 2.3029 - val_acc: 0.1126
Epoch 2/100
12/12 [==============================] - 1s 103ms/step - loss: 1.5988 - acc: 0.0992 - val_loss: 2.3033 - val_acc: 0.0868
Epoch 3/100

...

Epoch 84/100
12/12 [==============================] - 1s 103ms/step - loss: 1.1847 - acc: 0.3145 - val_loss: 1.9927 - val_acc: 0.3098
Epoch 85/100
12/12 [==============================] - 1s 102ms/step - loss: 1.1864 - acc: 0.3123 - val_loss: 2.0123 - val_acc: 0.3042
Epoch 86/100
12/12 [==============================] - 1s 102ms/step - loss: 1.1803 - acc: 0.2990 - val_loss: 2.0239 - val_acc: 0.2890
Restoring model weights from the end of the best epoch.
Epoch 00086: early stopping
20/20 [==============================] - 0s 8ms/step - loss: 1.9777 - acc: 0.3186

Run     | Test Loss | Test Accuracy
-----------------------------------
baseline| 2.288     |0.1556
ssl     | 1.978     |0.3186