TPUs + Cassava Leaf Disease

By Jesse Mostipak in technical

January 4, 2021

Note: this post was originally published as a Kaggle notebook


Who this notebook is for

This notebook is for anyone interested in creating a baseline model using Tensor Processing Units (TPUs) and beginning to make submissions to the Cassava Leaf Disease Classification competition. If you’ve taken the Kaggle Intro to Deep Learning and/or Computer Vision courses, you’ll find this notebook a good starting place for bridging what you’ve learned in our micro-courses and applying that knowledge in a competition.

How to use this notebook

Feel free to use this notebook as a walkthrough on how to build a preliminary image classification model using TensorFlow and Tensor Processing Units (TPUs). You can copy and edit the notebook by clicking on the corresponding button in the top right, which will make your own personal copy of the notebook in your Kaggle account. From there any edits you make will be unique to your own copy of the notebook!

TPUs with TensorFlow

We’ll be using TensorFlow and Keras to build our computer vision model, and using TPUs both to train our model and to make predictions. If you’d like to learn more about TPUs, be sure to check out our Learn With Me: Getting Started with Tensor Processing Units (TPUs) video.


This notebook was built using the following amazing resources created by Kagglers:

Tensor Processing Units (TPUs)

Tensor Processing Units (TPUs) are hardware accelerators specialized for deep learning tasks. All Kagglers have 30 hours of free TPU time each week and can use up to 3 hours in a single session. If you’d like to increase your TPU quota, consider submitting an exemplary TPU notebook to our TPU Star program!

You can read through the Kaggle documentation on TPUs here, and check out the TPU Star notebooks here.

Set up environment

import math, re, os
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
from tensorflow import keras
from functools import partial
from sklearn.model_selection import train_test_split
print("Tensorflow version " + tf.__version__)

Detect TPU

The code below detects whether a TPU is attached to the notebook and, if so, creates a distribution strategy that spreads our work across its cores. What you’re looking for is a printout of Number of replicas: 8, corresponding to the 8 cores of a TPU. If your printout instead says Number of replicas: 1, you likely do not have TPUs enabled in your notebook.

To enable TPUs, navigate to the panel on the right and click on Accelerator. Choose TPU from the dropdown.

If you’d like more TPU troubleshooting and optimization guidelines check out our Learn With Me: Troubleshooting and Optimizing TPUs video.

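try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # raises ValueError if no TPU is attached
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.get_strategy()  # no TPU: fall back to the default CPU/GPU strategy
print('Number of replicas:', strategy.num_replicas_in_sync)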

Set up variables

We’ll set up some of our variables for our notebook here.

If by chance you’re using a private dataset, you’ll also want to make sure that you have the Google Cloud Software Development Kit (SDK) attached to your notebook. You can find the Google Cloud SDK under the Add-ons dropdown menu at the top of your notebook. Documentation for the Google Cloud Software Development Kit (SDK) can be found here.

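AUTOTUNE = tf.data.experimental.AUTOTUNE  # let tf.data choose parallelism and prefetch buffer sizes
GCS_PATH = KaggleDatasets().get_gcs_path()
BATCH_SIZE = 16 * strategy.num_replicas_in_sync  # 16 images per TPU core
IMAGE_SIZE = [512, 512]
CLASSES = ['0', '1', '2', '3', '4']
EPOCHS = 25  # number of passes over the training data; an illustrative value, adjust as needed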
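BATCH_SIZE is the global batch size. On an 8-core TPU, each training step processes 16 * 8 = 128 images, and the TPU strategy gives each core its own slice of 16. The plain-Python sketch below is not TensorFlow code. It only illustrates that arithmetic, assuming the global batch divides evenly across replicas, and split_global_batch is a hypothetical helper.

```python
def split_global_batch(batch, num_replicas):
    """Split one global batch into equal per-replica shards (illustrative only)."""
    if len(batch) % num_replicas != 0:
        raise ValueError("global batch size must be divisible by the number of replicas")
    per_replica = len(batch) // num_replicas
    return [batch[i * per_replica:(i + 1) * per_replica] for i in range(num_replicas)]


PER_REPLICA_BATCH = 16
num_replicas = 8  # the value you should see printed on an 8-core TPU
global_batch = list(range(PER_REPLICA_BATCH * num_replicas))  # stand-in for 128 images
shards = split_global_batch(global_batch, num_replicas)
print(len(global_batch), [len(s) for s in shards])  # 128 [16, 16, 16, 16, 16, 16, 16, 16]
```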

Load the data

If you’ve primarily worked with notebooks in Learn, you may have noticed that data import and formatting are taken care of for you. Because we’re working with competition data, though, we’ll have to handle this part of the pipeline ourselves.

The data we’re working with have been formatted into TFRecords, which are a format for storing a sequence of binary records. TFRecords work really well with TPUs, and allow us to send a small number of large files across the TPU for processing.

If you’d like to learn more about TFRecords and maybe even try creating them yourself, check out this TFRecords Basics notebook and corresponding video from Kaggle Data Scientist Ryan Holbrook.
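A TFRecord file is a sequence of binary records laid end to end. Each record is framed as follows:

- an 8-byte little-endian length
- a 4-byte masked CRC32C of that length
- the payload bytes
- a 4-byte masked CRC32C of the payload

The pure-Python sketch below writes and reads that framing for illustration only. In practice you would use tf.io.TFRecordWriter and tf.data.TFRecordDataset, and the payloads would be serialized tf.train.Example protos rather than the raw bytes used here. The helper names are hypothetical.

```python
import io
import struct


def _make_crc32c_table():
    # Table-driven CRC-32C (Castagnoli), reflected polynomial 0x82F63B78
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    # TFRecord stores a "masked" CRC: rotate right by 15 bits, then add a constant (mod 2**32)
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(stream, records):
    for payload in records:
        length = struct.pack("<Q", len(payload))
        stream.write(length)
        stream.write(struct.pack("<I", masked_crc(length)))
        stream.write(payload)
        stream.write(struct.pack("<I", masked_crc(payload)))


def read_records(stream):
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of file
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length field")
        payload = stream.read(length)
        (payload_crc,) = struct.unpack("<I", stream.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupt payload")
        yield payload


buffer = io.BytesIO()
write_records(buffer, [b"first record", b"", b"\x00\xff binary bytes"])
buffer.seek(0)
print(list(read_records(buffer)))
```

Each record carries 16 bytes of framing. That framing lets a reader stream through a large file record by record without loading it all at once.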

Because our data consists of training and test images only, we’re going to split our training data into training and validation data using the train_test_split() function.
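In this notebook, train_test_split() operates on the list of TFRecord filenames rather than on individual images. The validation share is therefore 35% of the files, not exactly 35% of the images. scikit-learn rounds the test count up, so a hypothetical list of 16 files would split into 10 training files and 6 validation files.

The sketch below mimics that split in plain Python. It is not scikit-learn's implementation: split_files is a hypothetical helper, and its shuffled order will differ from scikit-learn's for the same seed.

```python
import math
import random


def split_files(filenames, test_size=0.35, seed=5):
    """Shuffle a file list and split it, rounding the validation count up (illustrative)."""
    files = list(filenames)
    random.Random(seed).shuffle(files)
    n_valid = math.ceil(test_size * len(files))
    return files[n_valid:], files[:n_valid]


# hypothetical filenames; the real ones come from tf.io.gfile.glob(...)
all_files = [f"train{i:02d}.tfrec" for i in range(16)]
train_files, valid_files = split_files(all_files)
print(len(train_files), len(valid_files))  # 10 6
```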

Decode the data

In the code chunk below we’ll set up a series of functions that convert our images into tensors so that we can use them in our model. We’ll also normalize our data. Our images use a “Red, Green, Blue (RGB)” scale where each channel value ranges over [0, 255]. Normalizing divides each value by 255, so every pixel value falls in the range [0, 1].

def decode_image(image):
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.reshape(image, [*IMAGE_SIZE, 3])
    return image
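After JPEG decoding, decode_image() does two things. It scales every value by 1/255, and tf.reshape gives the result a fixed [height, width, channels] shape.

The plain-Python sketch below applies the same arithmetic to a tiny made-up image. normalize_and_reshape is a hypothetical helper, and TensorFlow does this work on whole tensors.

```python
def normalize_and_reshape(flat_pixels, height, width, channels=3):
    """Scale 0-255 values to [0, 1] and nest them as [height][width][channels]."""
    if len(flat_pixels) != height * width * channels:
        raise ValueError("pixel count does not match the requested shape")
    scaled = [value / 255.0 for value in flat_pixels]
    row_len = width * channels
    return [
        [scaled[r * row_len + c * channels:r * row_len + (c + 1) * channels] for c in range(width)]
        for r in range(height)
    ]


# a made-up 2x2 RGB image: black, white, pure red, pure blue
flat = [0, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 255]
image = normalize_and_reshape(flat, height=2, width=2)
print(image)  # [[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]]
```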

If you think back to Intro to Machine Learning you might remember how we set up variables like X and y, representing our features, X, and prediction target, y. This code is accomplishing something similar, although instead of using the labels X and y, our features are represented by the term image and our prediction target by the term target.

You might also notice that this function accounts for unlabeled images. This is because our test images don’t come with labels. Instead, each test record stores its image name, which we’ll need for the submission file.

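def read_tfrecord(example, labeled):
    # Labeled records carry an integer target; unlabeled (test) records carry the image name instead
    tfrecord_format = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.int64)
    } if labeled else {
        "image": tf.io.FixedLenFeature([], tf.string),
        "image_name": tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example['image'])
    if labeled:
        label = tf.cast(example['target'], tf.int32)
        return image, label
    idnum = example['image_name']
    return image, idnum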

We’ll use the following function to load our dataset. With a TPU we can read from multiple files at once, which helps keep the TPU fed with data and is part of where its speed comes from. To capitalize on that, we want to use data as soon as it streams in rather than insisting on its original order, which would create a data streaming bottleneck.

def load_dataset(filenames, labeled=True, ordered=False):
    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False # disable order, increase speed
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) # automatically interleaves reads from multiple files
    dataset = dataset.with_options(ignore_order) # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(partial(read_tfrecord, labeled=labeled), num_parallel_calls=AUTOTUNE)
    return dataset
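The num_parallel_reads argument makes TFRecordDataset interleave records from several files instead of reading them one after another. The plain-Python sketch below shows the idea with a simple round-robin over made-up files.

tf.data's real scheduling is not a fixed round-robin, especially with experimental_deterministic = False. In that mode it hands back whichever records are ready first. The function names here are hypothetical.

```python
from itertools import chain, zip_longest

_SKIP = object()  # sentinel for files that have run out of records


def sequential(files):
    """Read file 1 completely, then file 2, and so on."""
    return list(chain.from_iterable(files))


def round_robin(files):
    """Take one record from each file in turn until every file is exhausted."""
    return [rec for group in zip_longest(*files, fillvalue=_SKIP) for rec in group if rec is not _SKIP]


files = [["a1", "a2", "a3"], ["b1", "b2"], ["c1"]]
print(sequential(files))   # ['a1', 'a2', 'a3', 'b1', 'b2', 'c1']
print(round_robin(files))  # ['a1', 'b1', 'c1', 'a2', 'b2', 'a3']
```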

A note on using train_test_split()

While I used train_test_split() to create both a training and validation dataset, consider exploring cross validation instead.

TRAINING_FILENAMES, VALID_FILENAMES = train_test_split(
    tf.io.gfile.glob(GCS_PATH + '/train_tfrecords/ld_train*.tfrec'),
    test_size=0.35, random_state=5
)

TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test_tfrecords/ld_test*.tfrec')
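One simple form of cross validation that fits this file-based pipeline is K-fold over the TFRecord files. Each file serves as validation data in exactly one fold and as training data in all the others. You would then train one model per fold and compare or average their validation scores.

The sketch below builds the folds in plain Python without shuffling. kfold_files, the fold count, and the filenames are hypothetical. sklearn.model_selection.KFold implements the same idea with more options.

```python
def kfold_files(filenames, k=5):
    """Yield (train_files, valid_files) for each of k folds; fold sizes differ by at most one."""
    files = list(filenames)
    fold_sizes = [len(files) // k + (1 if i < len(files) % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        valid = files[start:start + size]
        train = files[:start] + files[start + size:]
        yield train, valid
        start += size


all_files = [f"train{i:02d}.tfrec" for i in range(16)]  # hypothetical names
folds = list(kfold_files(all_files, k=5))
print([len(valid) for _, valid in folds])  # [4, 3, 3, 3, 3]
```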

Adding in augmentations

You learned about augmentations in the Computer Vision: Data Augmentation lesson on Kaggle Learn, and here I’ve applied an augmentation available to us through TensorFlow. You can read more about these augmentations (as well as all of the other augmentations available to you!) in the TensorFlow tf.image documentation.

If you’re interested in learning how to create and use custom augmentations, check out the Rotation Augmentation GPU/TPU and CutMix and MixUp on GPU/TPU notebooks from Kaggle Grandmaster Chris Deotte.

def data_augment(image, label):
    # Thanks to the dataset.prefetch(AUTOTUNE) call in get_training_dataset() below, this happens essentially for free on TPU.
    # Data pipeline code runs on the TPU host's CPU while the TPU cores are computing gradients.
    image = tf.image.random_flip_left_right(image)
    return image, label
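tf.image.random_flip_left_right mirrors an image horizontally with probability 0.5 and leaves the label unchanged, since a mirrored diseased leaf is still diseased. The plain-Python sketch below applies the same idea to a nested-list image. It uses a seeded random.Random so the result is reproducible, and the function names are illustrative, not part of TensorFlow.

```python
import random


def flip_left_right(image):
    """Mirror an image stored as [height][width][channels] nested lists."""
    return [list(reversed(row)) for row in image]


def random_flip_left_right(image, label, rng):
    """Flip with probability 0.5; the label is passed through untouched."""
    if rng.random() < 0.5:
        image = flip_left_right(image)
    return image, label


image = [[[1], [2], [3]],
         [[4], [5], [6]]]  # a 2x3 single-channel image
print(flip_left_right(image))  # [[[3], [2], [1]], [[6], [5], [4]]]

rng = random.Random(0)
flips = sum(random_flip_left_right(image, 0, rng)[0] != image for _ in range(1000))
print(flips)  # roughly half of the 1000 calls flip the image
```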

Define data loading methods

The following functions will be used to load our training, validation, and test datasets, as well as print out the number of images in each dataset.

def get_training_dataset():
    dataset = load_dataset(TRAINING_FILENAMES, labeled=True)
    dataset = dataset.map(data_augment, num_parallel_calls=AUTOTUNE)
    dataset = dataset.repeat()  # loop forever; model.fit() uses steps_per_epoch to define an epoch
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset

def get_validation_dataset(ordered=False):
    dataset = load_dataset(VALID_FILENAMES, labeled=True, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset

def get_test_dataset(ordered=False):
    dataset = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset

def count_data_items(filenames):
    # the number of images is encoded at the end of each filename: ...-<count>.tfrec
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = count_data_items(VALID_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)

print('Dataset: {} training images, {} validation images, {} (unlabeled) test images'.format(
    NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))
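count_data_items() relies on a naming convention: each TFRecord filename ends with a dash, the number of images in that file, and .tfrec. The sketch below runs the same regular expression on hypothetical filenames with made-up counts, so you can see what gets extracted. If a filename doesn't follow the convention, search() returns None and the function raises an AttributeError.

```python
import re

COUNT_PATTERN = re.compile(r"-([0-9]*)\.")


def count_data_items(filenames):
    # Pull the image count out of names shaped like "<prefix>-<count>.tfrec" and add them up
    return sum(int(COUNT_PATTERN.search(name).group(1)) for name in filenames)


# hypothetical filenames with made-up counts
names = ["ld_train00-1000.tfrec", "ld_train01-998.tfrec"]
print(count_data_items(names))  # 1998
```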

Brief Exploratory Data Analysis (EDA)

First we’ll print out the shapes and labels for a sample of each of our three datasets:

print("Training data shapes:")
for image, label in get_training_dataset().take(3):
    print(image.numpy().shape, label.numpy().shape)
print("Training data label examples:", label.numpy())
print("Validation data shapes:")
for image, label in get_validation_dataset().take(3):
    print(image.numpy().shape, label.numpy().shape)
print("Validation data label examples:", label.numpy())
print("Test data shapes:")
for image, idnum in get_test_dataset().take(3):
    print(image.numpy().shape, idnum.numpy().shape)
print("Test data IDs:", idnum.numpy().astype('U')) # U=unicode string

The following code chunk sets up a series of functions that will print out a grid of images. The grid of images will contain images and their corresponding labels.

# numpy and matplotlib defaults
np.set_printoptions(threshold=15, linewidth=80)

def batch_to_numpy_images_and_labels(data):
    images, labels = data
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
    if numpy_labels.dtype == object: # binary string in this case, these are image ID strings
        numpy_labels = [None for _ in enumerate(numpy_images)]
    # If no labels, only image IDs, return None for labels (this is the case for test data)
    return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
    if correct_label is None:
        return CLASSES[label], True
    correct = (label == correct_label)
    return "{} [{}{}{}]".format(CLASSES[label], 'OK' if correct else 'NO', u"\u2192" if not correct else '',
                                CLASSES[correct_label] if not correct else ''), correct

def display_one_plant(image, title, subplot, red=False, titlesize=16):
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2), color='red' if red else 'black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    return (subplot[0], subplot[1], subplot[2]+1)

def display_batch_of_images(databatch, predictions=None):
    """This will work with:
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)
    display_batch_of_images((images, idnums))  # test data: shown without labels
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]
    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows
    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows, cols, 1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols, FIGSIZE))
    # display
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        subplot = display_one_plant(image, title, subplot, not correct, titlesize=dynamic_titlesize)
    # layout
    plt.tight_layout()
    if label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

# load our training dataset for EDA
training_dataset = get_training_dataset()
training_dataset = training_dataset.unbatch().batch(20)
train_batch = iter(training_dataset)
# run this cell again for another randomized set of training images
display_batch_of_images(next(train_batch))

You can also modify the above code to look at your validation and test data, like this:

# load our validation dataset for EDA
validation_dataset = get_validation_dataset()
validation_dataset = validation_dataset.unbatch().batch(20)
valid_batch = iter(validation_dataset)
# run this cell again for another set of validation images
display_batch_of_images(next(valid_batch))

# load our test dataset for EDA
testing_dataset = get_test_dataset()
testing_dataset = testing_dataset.unbatch().batch(20)
test_batch = iter(testing_dataset)
# we only have one test image
display_batch_of_images(next(test_batch))
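The auto-squaring step in display_batch_of_images() picks rows = int(sqrt(n)) and cols = n // rows, then shows only rows * cols images. That is why unbatching into groups of 20 gives a clean 4 x 5 grid, while a batch of 10 would show only 9 images. The plain-Python check below tests that arithmetic, and grid_shape is a hypothetical helper.

```python
import math


def grid_shape(n_images):
    """Rows and columns chosen by the auto-squaring logic, plus how many images get shown."""
    rows = int(math.sqrt(n_images))
    cols = n_images // rows
    return rows, cols, rows * cols


for n in (1, 10, 20):
    print(n, grid_shape(n))
# 1 (1, 1, 1)
# 10 (3, 3, 9)
# 20 (4, 5, 20)
```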

Building the model

Learning rate schedule

We learned about learning rates in the Intro to Deep Learning: Stochastic Gradient Descent lesson, and here I’ve created a learning rate schedule mostly using the defaults in the Keras Exponential Decay Learning Rate Scheduler documentation (I did change the initial_learning_rate). You can adjust the learning rate scheduler below, and read more about the other types of schedulers available to you in the Keras learning rate schedules API.

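lr_scheduler = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-5, decay_steps=10000, decay_rate=0.9)  # the docs example uses 1e-2; this value is illustrative, tune as needed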
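ExponentialDecay computes the learning rate at a given optimizer step with this formula (assuming staircase=False, the default):

initial_learning_rate * decay_rate ** (step / decay_steps)

The plain-Python version below uses decay_steps=10000 and decay_rate=0.9 (the Keras documentation example values) with an illustrative initial rate. It shows how slowly the rate shrinks: it loses only 10% every 10,000 steps, so a short run of a few thousand steps stays close to its starting rate. exponential_decay is a hypothetical helper, not the Keras class.

```python
def exponential_decay(step, initial_learning_rate, decay_steps, decay_rate, staircase=False):
    """Learning rate after `step` optimizer steps, following the ExponentialDecay formula."""
    exponent = step / decay_steps
    if staircase:
        exponent = step // decay_steps  # decay in discrete jumps instead of smoothly
    return initial_learning_rate * decay_rate ** exponent


for step in (0, 5000, 10000, 20000):
    print(step, exponential_decay(step, 1e-3, 10000, 0.9))
```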

Building our model

In order to ensure that our model is trained on the TPU, we build it inside a with strategy.scope() block, so that its variables are created on, and replicated across, the TPU cores.

This model uses transfer learning. A pre-trained model (ResNet50) serves as our base model, and we stack our own customizable layers on top of it using tf.keras.Sequential. If you’re new to transfer learning, I recommend keeping base_model.trainable set to False. I do, however, encourage you to try different base models (more options are available in the tf.keras.applications module documentation) and to iterate on the custom layers.

Note that we’re using sparse_categorical_crossentropy as our loss function, because we did not one-hot encode our labels.

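with strategy.scope():
    # ResNet50's preprocess_input expects pixel values in [0, 255], so undo the /255 from decode_image() first
    img_adjust_layer = tf.keras.layers.Lambda(lambda x: tf.keras.applications.resnet50.preprocess_input(x * 255.0), input_shape=[*IMAGE_SIZE, 3])
    base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False)
    base_model.trainable = False  # freeze the pre-trained weights
    model = tf.keras.Sequential([
        img_adjust_layer,
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),  # one 2048-value feature vector per image
        tf.keras.layers.Dense(8, activation='relu'),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax')
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_scheduler, epsilon=0.001),
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])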
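Sparse categorical cross-entropy takes integer class labels (0 to 4 here). It computes the same loss that categorical cross-entropy computes on one-hot vectors: the negative log of the probability the model assigned to the true class.

The plain-Python check below uses made-up predicted probabilities. The function names are illustrative. Keras also clips probabilities slightly away from 0 and 1, which this sketch skips.

```python
import math


def sparse_categorical_crossentropy(label, probs):
    # label is an integer class index; probs is the softmax output for one image
    return -math.log(probs[label])


def categorical_crossentropy(one_hot, probs):
    # one_hot is a vector with a single 1 at the true class
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


probs = [0.1, 0.6, 0.1, 0.1, 0.1]  # made-up softmax output over 5 classes
label = 1
one_hot = [1.0 if i == label else 0.0 for i in range(len(probs))]
print(sparse_categorical_crossentropy(label, probs))  # -log(0.6), about 0.511
print(categorical_crossentropy(one_hot, probs))       # the same value
```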

Train the model

As our model is training you’ll see a printout for each epoch, and can also monitor TPU usage by clicking on the TPU metrics in the toolbar at the top right of your notebook.

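# load data
train_dataset = get_training_dataset()
valid_dataset = get_validation_dataset()

# the training dataset repeats forever, so fit() needs to know how many batches make up one epoch
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
VALID_STEPS = NUM_VALIDATION_IMAGES // BATCH_SIZE
history = model.fit(train_dataset, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS,
                    validation_data=valid_dataset, validation_steps=VALID_STEPS)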
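Because get_training_dataset() repeats forever, model.fit() can’t tell where an epoch ends on its own, and its steps_per_epoch argument supplies that. Integer division means one epoch covers slightly fewer images than the full training set. With a repeated, shuffled dataset, the leftover images show up in later epochs rather than being discarded.

The plain-Python arithmetic below uses hypothetical numbers: 13,000 training images on 8 replicas. steps_per_epoch is a hypothetical helper, not the Keras argument.

```python
def steps_per_epoch(num_images, per_replica_batch, num_replicas):
    global_batch = per_replica_batch * num_replicas
    steps = num_images // global_batch  # the final partial batch doesn't count toward this epoch
    return global_batch, steps, steps * global_batch


# hypothetical dataset size; in the notebook use NUM_TRAINING_IMAGES
print(steps_per_epoch(13000, 16, 8))  # (128, 101, 12928)
```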

With model.summary() we’ll see a printout of each of our layers, their output shapes, and the number of parameters in each. At the bottom of the printout we’ll see the total, trainable, and non-trainable parameter counts. Because we froze the pre-trained base model (base_model.trainable = False), we expect a large number of non-trainable parameters. Those weights were already learned on ImageNet and won’t be updated during training.

model.summary()

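You can sanity-check the trainable-parameter count by hand. A Dense layer has inputs * units weights plus one bias per unit.

The calculation below assumes that the head sits on a globally average-pooled ResNet50, which produces 2048 features per image. It also assumes that every other layer is either frozen or has no parameters. dense_params is a hypothetical helper.

```python
def dense_params(inputs, units):
    """Weights (inputs x units) plus biases (units) for a fully connected layer."""
    return inputs * units + units


RESNET50_FEATURES = 2048  # channels in ResNet50's final feature map
NUM_CLASSES = 5

hidden = dense_params(RESNET50_FEATURES, 8)
output = dense_params(8, NUM_CLASSES)
print(hidden, output, hidden + output)  # 16392 45 16437
```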
Evaluating our model

The first chunk of code is provided to show you where the variables in the second chunk of code came from. As you can see, there’s a lot of room for improvement in this model, but because we’re using TPUs and have a relatively short training time, we’re able to iterate on our model fairly rapidly.

# print out variables available to us
print(history.history.keys())
# create learning curves to evaluate model performance
history_frame = pd.DataFrame(history.history)
history_frame.loc[:, ['loss', 'val_loss']].plot()
history_frame.loc[:, ['sparse_categorical_accuracy', 'val_sparse_categorical_accuracy']].plot();
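A quick numerical complement to the learning curves is to find the epoch with the lowest validation loss. If training loss keeps falling after that point while validation loss rises, the model is likely starting to overfit. The dictionary below is made-up data in the same format as history.history, and best_epoch is a hypothetical helper.

```python
def best_epoch(history_dict, metric="val_loss", mode="min"):
    """Return (1-based epoch, value) of the best score for `metric`."""
    values = history_dict[metric]
    pick = min if mode == "min" else max
    best_value = pick(values)
    return values.index(best_value) + 1, best_value


# made-up numbers in the same format as history.history
fake_history = {
    "loss": [1.20, 0.95, 0.80, 0.70],
    "val_loss": [1.10, 0.90, 0.92, 0.97],
    "sparse_categorical_accuracy": [0.55, 0.65, 0.71, 0.75],
    "val_sparse_categorical_accuracy": [0.58, 0.66, 0.65, 0.64],
}
print(best_epoch(fake_history))  # (2, 0.9)
print(best_epoch(fake_history, "val_sparse_categorical_accuracy", mode="max"))  # (2, 0.66)
```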

Making predictions

Now that we’ve trained our model we can use it to make predictions!

# this code will convert our test image data to a float32
def to_float32(image, label):
    return tf.cast(image, tf.float32), label
# ordered=True keeps a fixed order, so predictions line up with the image IDs we read later
test_ds = get_test_dataset(ordered=True)
test_ds = test_ds.map(to_float32)

print('Computing predictions...')
test_images_ds = test_ds.map(lambda image, idnum: image)
probabilities = model.predict(test_images_ds)
predictions = np.argmax(probabilities, axis=-1)
print(predictions)
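np.argmax(probabilities, axis=-1) turns each row of five softmax probabilities into a single predicted class: the index of the largest value. The plain-Python version below does the same on made-up probabilities, and argmax_rows is a hypothetical helper.

```python
def argmax_rows(probabilities):
    """Index of the largest value in each row (ties go to the first index, like np.argmax)."""
    return [max(range(len(row)), key=row.__getitem__) for row in probabilities]


probabilities = [
    [0.05, 0.10, 0.10, 0.70, 0.05],  # most confident about class 3
    [0.40, 0.20, 0.20, 0.10, 0.10],  # class 0
]
print(argmax_rows(probabilities))  # [3, 0]
```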

Creating a submission file

Now that we’ve trained a model and made predictions we’re ready to submit to the competition! You can run the following code below to get your submission file.

print('Generating submission.csv file...')
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='image_id,label', comments='')  # header must match sample_submission.csv
!head submission.csv
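The np.savetxt() call writes a header row followed by one line per test image containing its image ID and predicted label. If you prefer the standard library, the sketch below writes the same kind of file with Python’s csv module.

The column names are an assumption, so check them against the competition’s sample_submission.csv. The image IDs and labels are made up, and write_submission is a hypothetical helper.

```python
import csv
import io


def write_submission(stream, image_ids, labels, header=("image_id", "label")):
    """Write a header row followed by one (image ID, predicted label) row per image."""
    if len(image_ids) != len(labels):
        raise ValueError("every image needs exactly one prediction")
    writer = csv.writer(stream, lineterminator="\n")
    writer.writerow(header)
    writer.writerows(zip(image_ids, labels))


buffer = io.StringIO()  # in a notebook, use open('submission.csv', 'w', newline='') instead
write_submission(buffer, ["example_a.jpg", "example_b.jpg"], [3, 0])
print(buffer.getvalue())
```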

Be aware that because this is a code competition with a hidden test set, internet and TPUs cannot be enabled on your submission notebook. Therefore TPUs will only be available for training models. For a walk-through on how to train on TPUs and run inference/submit on GPUs, see our TPU Docs.
