We’ve got a pretty powerful Lambda workstation with GPUs that has served us very well.

Let’s see how we can use them together with the Accelerator API from Hugging Face:

#!/usr/bin/env python

"""
A simple example of how to use the Accelerator API
to train a torchvision model (AlexNet by default) on a dummy dataset.

Accelerator enables training on a single GPU or on multiple GPUs.

Run `accelerate config` once to set up your configuration file.

Run with `accelerate launch test_gpus_accelerate.py` to run on all GPUs.
"""

import fire
from tqdm import tqdm
import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from accelerate import Accelerator


class DummyDataset(Dataset):
    """Synthetic dataset that fabricates a random sample on every access.

    Nothing is stored: each lookup draws a fresh random image tensor and a
    fresh random integer label, so the requested index is ignored.
    """

    def __len__(self) -> int:
        """Return the fixed nominal size of the dataset (one million)."""
        return 1_000_000

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return a random ``(image, label)`` pair; ``idx`` is not used.

        The image is a ``3 x 224 x 224`` tensor drawn from ``torch.rand`` and
        the label is a 0-dim integer tensor drawn from ``[0, 1000)``.
        """
        pixels = torch.rand(3, 224, 224)
        # Draw a one-element tensor, then index it to get a 0-dim label.
        class_ids = torch.randint(0, 1000, (1,))
        return pixels, class_ids[0]


def test(
    model_name: str = "alexnet",
    batch_size: int = 2048,
    epochs: int = 3,
    num_workers: int = 8,
):
    """Train a torchvision model on random data through the Accelerator API.

    Intended as a smoke test / load generator for the available GPUs; no
    validation is performed and nothing is returned or saved.

    Args:
        model_name: Name of a model constructor in ``torchvision.models``.
        batch_size: Batch size handed to the ``DataLoader``.
        epochs: Number of passes over the dummy dataset.
        num_workers: Number of ``DataLoader`` worker processes.
    """
    accelerator = Accelerator()
    device = accelerator.device
    # Only the local main process draws progress bars; the other processes
    # iterate silently so the terminal is not flooded with duplicate bars.
    hide_bars = not accelerator.is_local_main_process

    model = getattr(torchvision.models, model_name)(weights="DEFAULT").to(device)
    optimizer = torch.optim.Adam(model.parameters())

    dataset = DummyDataset()
    data = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    model, optimizer, data = accelerator.prepare(model, optimizer, data)

    model.train()
    epoch_bar = tqdm(
        range(epochs), position=0, leave=True, desc="Epochs", disable=hide_bars
    )
    for epoch in epoch_bar:
        # A tqdm bar closes itself once its iteration finishes, so a new
        # "Batches" bar is created per epoch; reusing a single bar would show
        # progress for the first epoch only.
        batch_bar = tqdm(
            data, position=1, leave=False, desc="Batches", disable=hide_bars
        )
        for source, targets in batch_bar:
            # Harmless if the batch is already on `device`.
            source = source.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            output = model(source)
            loss = F.cross_entropy(output, targets)
            if not hide_bars:
                # loss.item() is only evaluated where it is actually shown.
                batch_bar.set_postfix(
                    {"images": source.shape[0], "loss": loss.item()}
                )

            accelerator.backward(loss)

            optimizer.step()
        # Valid evaluation


if __name__ == "__main__":
    # Expose `test` as a command-line interface: its keyword arguments
    # become CLI flags (e.g. --model_name, --batch_size).
    fire.Fire(component=test)
blog comments powered by Disqus