We’ve got a pretty powerful Lambda workstation with GPUs that has served us very well.

Let’s see how we can use them together with the Accelerator API from Hugging Face:

#!/usr/bin/env python

"""
A simple example of how to use the Accelerator API
to train a ResNet-50 model on a dummy dataset.

Accelerator enables training on a single GPU or on multiple GPUs.

Run `accelerate config` once to set up your configuration file.

Run with `accelerate launch test_gpus_accelerate.py` to run on all GPUs.
"""

import fire
from tqdm import tqdm
import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from accelerate import Accelerator

class DummyDataset(Dataset):
    """Synthetic image-classification dataset that fabricates samples on demand.

    Nothing is stored: every ``__getitem__`` call draws a fresh random image
    and a fresh random label, so asking for the same index twice yields
    different samples. The defaults reproduce the original fixed setup
    (1,000,000 samples, 1000 classes, 3x224x224 images).

    Args:
        length: Number of samples reported by ``__len__``.
        num_classes: Labels are drawn from ``[0, num_classes)``.
        image_shape: Shape of each generated image tensor.
    """

    def __init__(
        self,
        length: int = 1_000_000,
        num_classes: int = 1000,
        image_shape: tuple[int, ...] = (3, 224, 224),
    ) -> None:
        self.length = length
        self.num_classes = num_classes
        self.image_shape = tuple(image_shape)

    def __len__(self) -> int:
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return a random ``(image, label)`` pair.

        ``idx`` is ignored on purpose: samples are generated, not looked up.
        The label is a 0-dim integer tensor.
        """
        image = torch.rand(*self.image_shape)
        # Draw a 1-element tensor and index it to get a 0-dim label.
        label = torch.randint(0, self.num_classes, (1,))[0]
        return image, label

def test(
    model_name: str = "alexnet",
    batch_size: int = 2048,
    epochs: int = 3,
    num_workers: int = 8,
) -> None:
    """Train a torchvision model on random data to exercise the GPUs.

    The model, optimizer and data loader are handed to ``Accelerator.prepare``
    so the same script can be started with ``accelerate launch``. Progress
    bars are only drawn by the local main process.

    Args:
        model_name: Name of a model constructor in ``torchvision.models``.
        batch_size: Batch size passed to the ``DataLoader``.
        epochs: Number of passes over the dummy dataset.
        num_workers: Number of ``DataLoader`` worker processes.
    """
    accelerator = Accelerator()
    device = accelerator.device

    model = getattr(torchvision.models, model_name)(weights="DEFAULT").to(device)
    optimizer = torch.optim.Adam(model.parameters())

    dataset = DummyDataset()
    data = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    model, optimizer, data = accelerator.prepare(model, optimizer, data)

    # Only one process per machine draws progress bars; the others iterate
    # over the plain range / loader.
    show_progress = accelerator.is_local_main_process

    if show_progress:
        epoch_iter = tqdm(range(epochs), position=0, leave=True, desc="Epochs")
    else:
        epoch_iter = range(epochs)

    for _ in epoch_iter:
        # Build a fresh batch bar for every epoch so each pass over the data
        # gets its own progress display.
        if show_progress:
            batches = tqdm(data, position=1, leave=False, desc="Batches")
        else:
            batches = data

        for source, targets in batches:
            source = source.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            output = model(source)
            loss = F.cross_entropy(output, targets)

            # Backpropagate through the accelerator, then update the weights.
            accelerator.backward(loss)
            optimizer.step()

            if show_progress:
                batches.set_postfix({"images": source.shape[0], "loss": loss.item()})

        # Valid evaluation

if __name__ == "__main__":
    # Expose `test` as a command-line interface; its keyword arguments
    # become command-line flags.
    fire.Fire(test)
blog comments powered by Disqus