""" image_train.py – training script for Image models Author: Petr Kaška, Martin Hemza Date: 2025-28-1 """ from __future__ import annotations import argparse, math, random from pathlib import Path import numpy as np, torch, torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader, WeightedRandomSampler from tqdm import tqdm from image_augment import get_transforms from image_datasets import get_dataset from image_models import get_model # ---------- mix-up --------------------------------------------------------- def _mixup(x, y, alpha: float = 0.4): if alpha <= 0: return x, y, None, None, 1.0 lam = np.random.beta(alpha, alpha) idx = torch.randperm(x.size(0), device=x.device) return lam * x + (1 - lam) * x[idx], y, y[idx], idx, lam # ---------- trénovací/val epochy ------------------------------------------- def _train_epoch(model, loader, opt, device): model.train() tot, loss_sum = 0, 0.0 for x, y in tqdm(loader, desc="train", leave=False): x, y = x.to(device), y.to(device) x_mix, yA, yB, _, lam = _mixup(x, y) logits = model(x_mix, yA) loss = lam * F.cross_entropy(logits, yA) + (1 - lam) * F.cross_entropy(logits, yB) opt.zero_grad(set_to_none=True) loss.backward() opt.step() tot += y.size(0) loss_sum += loss.item() * y.size(0) return loss_sum / tot @torch.no_grad() def _validate(model, loader, device): model.eval() tot, loss_sum, correct = 0, 0.0, 0 for x, y in loader: x, y = x.to(device), y.to(device) logits = model(x) loss_sum += F.cross_entropy(logits, y, reduction="sum").item() correct += (logits.argmax(1) == y).sum().item() tot += y.size(0) return loss_sum / tot, correct / tot def _balanced_sampler(labels): cnt = np.bincount(labels) w = 1. / cnt[labels] return WeightedRandomSampler(w, num_samples=len(labels) * 40, replacement=True) # ---------- hlavní ---------------------------------------------------------- def main(): ap = argparse.ArgumentParser() ap.add_argument("--root", required=True) ap.add_argument("--model", default="face_resnet") ap.add_argument("--fold", default="fold1") ap.add_argument("--epochs", type=int, default=80) ap.add_argument("--out", default="checkpoints") args = ap.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" # -- data -------------------------------------------------------------- train_ds = get_dataset(args.root, transform=get_transforms("train"), fold=args.fold, split="train") val_ds = get_dataset(args.root, transform=get_transforms("val"), fold=args.fold, split="test") train_ld = DataLoader(train_ds, batch_size=64, sampler=_balanced_sampler([l for _,l in train_ds.samples]), num_workers=4, pin_memory=True) val_ld = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=4, pin_memory=True) # -- model + optim ----------------------------------------------------- model = get_model(args.model, num_classes=31).to(device) opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4) warm, total = 5, args.epochs sched = optim.lr_scheduler.LambdaLR( opt, lambda e: (e+1)/warm if e best + 1e-4: best, bad = val_acc, 0 torch.save(model.state_dict(), ckpt/"best.pth") else: bad += 1 if bad >= patience: print("Early-stop.") break print(f"Best val acc {best*100:.2f}%") if __name__ == "__main__": main()