Oh fast.ai, how I love thee.

How easy it is to get started and train a model.

But sometimes, you are a pain in the neck.

The documentation often isn’t really there, and the source code is hard to read.

I hear the forums are very helpful, but that’s not quite a format I have time for.

Ultimately, fast.ai will only get you started, but I have immense respect and admiration for Jeremy Howard, Sylvain Gugger, and the other people behind it.

Sometimes there is a true spark of genius in the way they do things, going beyond what many academics do and what other libraries provide (e.g. lr_find).

Let’s check out one of my favourite tricks: using a ResNet as an encoder in an autoencoder.

Sources:

  • https://colab.research.google.com/drive/1t9dn6qIdKc6rdF-A02KMdJ8UVGYPFh4v#scrollTo=e9dnMvm7q5q4
from pathlib import Path

import numpy as np
import imageio
import torch
import torchvision
from fastai.vision.all import (
    L,
    DataBlock,
    ImageBlock,
    CategoryBlock,
    aug_transforms,
    vision_learner,
    create_body,
    Resize,
    PixelShuffle_ICNR,
    ConvLayer,
    nn,
    Module,
    SigmoidRange,
    Tensor,
    xresnet18,
    Learner,
    MSELossFlat
)

def get_class(x):
    # Parse the label out of a filename like "3-a.jpg".
    # (Not used below, but handy for a classification variant.)
    return x.stem.split("-")[-1]


def get_self(x):
    # For an autoencoder, the target is the input itself.
    return x

# Make dummy dataset
dataset_dir = Path("datasets") / "dummy_resize"
dataset_dir.mkdir(exist_ok=True, parents=True)
(dataset_dir / "train").mkdir(exist_ok=True)
(dataset_dir / "valid").mkdir(exist_ok=True)

for i in range(10):
    for c in ["a", "b"]:
        img = np.random.randint(0, 256, (32, 32, 3), dtype="uint8")
        imageio.imwrite(dataset_dir / "train" / f"{i}-{c}.jpg", img)
        imageio.imwrite(dataset_dir / "valid" / f"{i}-{c}.jpg", img)
files = L(dataset_dir.glob("*/*.jpg"))
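# `files` now holds all 40 images; the glob spans both train/ and valid/,
# and the actual split happens later, in the DataBlock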

# Define model architecture
class UpsampleBlock(Module):
    "Upsample by 2 with PixelShuffle, then refine with two conv layers."

    def __init__(
        self,
        up_in_c: int,
        final_div: bool = True,
        blur: bool = False,
        **kwargs,
    ):
        # fastai's Module calls nn.Module.__init__ for us
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c // 2, blur=blur, **kwargs)
        ni = up_in_c // 2
        nf = ni if final_div else ni // 2
        self.conv1 = ConvLayer(ni, nf, **kwargs)
        self.conv2 = ConvLayer(nf, nf, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, up_in: Tensor) -> Tensor:
        up_out = self.shuf(up_in)  # (bs, c, h, w) -> (bs, c // 2, 2h, 2w)
        return self.conv2(self.conv1(self.relu(up_out)))
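
# (Optional) sanity check: one block halves the channels and doubles
# the spatial size, e.g. (2, 512, 1, 1) -> (2, 256, 2, 2)
assert UpsampleBlock(512)(torch.randn(2, 512, 1, 1)).shape == (2, 256, 2, 2)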

def decoder_resnet(y_range, n_out=3):
    # Mirror the encoder: five blocks take the 512-channel bottleneck back up,
    # halving channels and doubling resolution each time
    # (512 -> 256 -> 128 -> 64 -> 32 -> 16, and 1x1 -> 32x32 for 32px inputs)
    return nn.Sequential(
        UpsampleBlock(512),
        UpsampleBlock(256),
        UpsampleBlock(128),
        UpsampleBlock(64),
        UpsampleBlock(32),
        nn.Conv2d(16, n_out, 1),
        SigmoidRange(*y_range),
    )

def autoencoder(encoder, y_range):
    return nn.Sequential(encoder, decoder_resnet(y_range))

# Make dataloaders: both blocks are ImageBlock, and get_y=get_self
# makes the target the input image itself
btfms = aug_transforms()
block = DataBlock(
    blocks=[ImageBlock(), ImageBlock()],
    get_y=get_self,
    batch_tfms=btfms,
    item_tfms=Resize(32),
)
# No splitter is passed, so DataBlock falls back to a random split (80/20 by default)
dls = block.dataloaders(files, path=dataset_dir, bs=12)
x, y = dls.one_batch()
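# x and y are the same batch of (identically augmented) images,
# so the model is trained to reconstruct its input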

# Build model and check shapes
encoder = create_body(xresnet18(), n_in=3).cuda()  # cut off the head, keep the conv trunk
encoder(x).shape  # torch.Size([12, 512, 1, 1]): the 32px input is downsampled 32x
y_range = (-3.0, 3.0)  # SigmoidRange squashes decoder outputs into this range
ac_resnet = autoencoder(encoder, y_range).cuda()
assert x.shape == ac_resnet(x).shape
decoder = decoder_resnet(y_range).cuda()
assert x.shape == decoder(encoder(x)).shape

# Train
learn = Learner(dls, ac_resnet, loss_func=MSELossFlat())  # pixel-wise MSE between input and reconstruction
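
From here it trains like any other fastai model, and this is also where lr_find earns its keep. A minimal sketch of a run (the epoch count and learning rate below are placeholders, not tuned values):

learn.lr_find()  # the learning-rate finder praised above
learn.fit_one_cycle(5, 1e-3)

# show_results should render target and reconstruction side by side,
# since both blocks are images
learn.show_results()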