1. Overview

This guide focuses on developing a main.py script. At a high level, its structure looks like this:

1. Prerequisites

To successfully follow this guide, you need to:

  • Import at least one Dataset into Picsellia.
  • Create a Project and attach a Dataset to it.
  • Create an Experiment that uses the attached Dataset.
  • Have a Model and a ModelVersion. If you don't have them yet, follow this tutorial to create them.

2. Starting point

Figure: An example of a folder containing your scripts.

Let's consider a folder containing three scripts: model.py, utils.py, and train.py.

For the purpose of this guide, we will focus on train.py and put all the code and functions in it. However, to follow clean-code conventions, we recommend creating a dedicated file for the Picsellia utilities.

Suppose we want to integrate Picsellia into an EfficientNet fine-tuning script similar to this one from Keras. Your train.py file would then look like this:

from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
import tensorflow as tf
import tensorflow_datasets as tfds

IMG_SIZE = 224
size = (IMG_SIZE, IMG_SIZE)

# Use a TPU if one is available; otherwise fall back to CPU/GPU.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    print("Not connected to a TPU runtime. Using CPU/GPU strategy")
    strategy = tf.distribute.MirroredStrategy()

batch_size = 64

dataset_name = "stanford_dogs"
(ds_train, ds_test), ds_info = tfds.load(
    dataset_name, split=["train", "test"], with_info=True, as_supervised=True
)

NUM_CLASSES = ds_info.features["label"].num_classes

ds_train = ds_train.map(lambda image, label: (tf.image.resize(image, size), label))
ds_test = ds_test.map(lambda image, label: (tf.image.resize(image, size), label))

def format_label(label):
    string_label = label_info.int2str(label)
    return string_label.split("-")[1]

label_info = ds_info.features["label"]

img_augmentation = Sequential(
    [
        layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
    ],
    name="img_augmentation",
)

# One-hot / categorical encoding
def input_preprocess(image, label):
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label

ds_train = ds_train.map(
    input_preprocess, num_parallel_calls=tf.data.AUTOTUNE
)
ds_train = ds_train.batch(batch_size=batch_size, drop_remainder=True)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(input_preprocess)
ds_test = ds_test.batch(batch_size=batch_size, drop_remainder=True)

def build_model(num_classes):
    inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    x = img_augmentation(inputs)
    model = EfficientNetB0(include_top=False, input_tensor=x, weights="imagenet")

    # Freeze the pretrained weights
    model.trainable = False

    # Rebuild top
    x = layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
    x = layers.BatchNormalization()(x)

    top_dropout_rate = 0.2
    x = layers.Dropout(top_dropout_rate, name="top_dropout")(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="pred")(x)

    # Compile
    model = tf.keras.Model(inputs, outputs, name="EfficientNet")
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)
    model.compile(
        optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
    )
    return model

with strategy.scope():
    model = build_model(num_classes=NUM_CLASSES)

# Phase 1: train only the new top layers while the EfficientNetB0 backbone is frozen
epochs = 25
hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test, verbose=2)

def unfreeze_model(model):
    # We unfreeze the top 20 layers while leaving BatchNorm layers frozen
    for layer in model.layers[-20:]:
        if not isinstance(layer, layers.BatchNormalization):
            layer.trainable = True

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    model.compile(
        optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
    )


unfreeze_model(model)

# Phase 2: fine-tune the partially unfrozen model
epochs = 10
hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test, verbose=2)
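The fine-tuning phase hinges on the selection rule inside unfreeze_model: take the last 20 entries of model.layers and set trainable to True on every one that is not a BatchNormalization layer. Since TensorFlow may not be at hand, here is a minimal pure-Python sketch of that rule only. FakeLayer, its is_batch_norm flag (standing in for the isinstance check) and unfreeze_top_layers are hypothetical names, not Keras APIs.

```python
from dataclasses import dataclass


@dataclass
class FakeLayer:
    """Hypothetical stand-in for a Keras layer (name, batch-norm flag, trainable flag)."""

    name: str
    is_batch_norm: bool = False
    trainable: bool = False


def unfreeze_top_layers(model_layers, n=20):
    """Set trainable=True on the last `n` layers, skipping batch-norm layers.

    Returns the names of the layers that were unfrozen.
    """
    if n <= 0:
        # model_layers[-0:] would select *every* layer, so handle 0 explicitly.
        return []
    unfrozen = []
    for layer in model_layers[-n:]:
        if not layer.is_batch_norm:
            layer.trainable = True
            unfrozen.append(layer.name)
    return unfrozen


# Toy model: 30 frozen layers; every fifth one (0, 5, 10, ...) is a batch-norm layer.
toy_layers = [FakeLayer(f"layer_{i}", is_batch_norm=(i % 5 == 0)) for i in range(30)]
unfrozen = unfreeze_top_layers(toy_layers, n=20)
print(len(unfrozen))  # 16: layers 10-29, minus batch-norm layers 10, 15, 20 and 25
```

One detail the sketch makes explicit: in Python, a slice such as layers[-0:] returns the whole list, so if the number of unfrozen layers ever becomes a configurable value, a count of 0 would unfreeze the entire backbone unless it is guarded as above.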

Now, let's learn how to turn it into a fully configurable training script connected to Picsellia. 🚀
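As a first step toward a configurable script, and assuming nothing about the Picsellia SDK, the values that train.py hard-codes (the dataset name, IMG_SIZE, batch_size, the dropout rate, the two epochs values, the two learning rates and the 20 unfrozen layers) could be gathered into a single object with overridable defaults. This is a hypothetical sketch: TrainingConfig, its field names and from_dict are illustrative names, not part of Keras or Picsellia.

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class TrainingConfig:
    """Hypothetical container for the values train.py currently hard-codes."""

    dataset_name: str = "stanford_dogs"
    img_size: int = 224
    batch_size: int = 64
    top_dropout_rate: float = 0.2
    head_epochs: int = 25
    head_learning_rate: float = 1e-2
    finetune_epochs: int = 10
    finetune_learning_rate: float = 1e-4
    unfrozen_layers: int = 20

    @classmethod
    def from_dict(cls, params):
        """Override the defaults with `params`, casting each value to its default's type.

        Unknown keys raise KeyError so that a typo is not silently ignored.
        """
        defaults = cls()
        known = {f.name for f in fields(cls)}
        unknown = set(params) - known
        if unknown:
            raise KeyError(f"Unknown parameters: {sorted(unknown)}")
        overrides = {
            name: type(getattr(defaults, name))(value) for name, value in params.items()
        }
        return cls(**overrides)


config = TrainingConfig.from_dict({"batch_size": "32", "finetune_epochs": 4})
print(config.batch_size, config.finetune_epochs, config.img_size)  # 32 4 224
```

Casting against each default's type means a value that arrives as a string (for example from a command line or a form) still ends up as an int or a float; this is one possible design choice, not something Picsellia requires.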