from pathlib import Path
from typing import Any, Iterator

import torch
from torchaudio import load as audio_load
from torch.utils.data import Dataset as TorchDataset, DataLoader as TorchDataLoader, default_collate
import torchvision.transforms.functional as F
from PIL import Image


class DataTransform:
    """Decode, transform, and center-crop images (80x80 by default). Decode sounds.

    If called with a list of dicts, collates the results with
    ``default_collate``; if called with a single dict, returns the
    transformed dict.
    """

    def __init__(self, crop_size: tuple[int, int] = (80, 80)):
        """Store the output size handed to ``center_crop``.

        Args:
            crop_size: Two-element size passed to ``center_crop``.
                Defaults to the original fixed 80x80 crop.
        """
        self.crop_size = tuple(crop_size)

    def __call__(self, data):
        """Transform one sample dict, or transform and collate a list of them."""
        if isinstance(data, list):
            return default_collate([self.transform_one(x) for x in data])
        return self.transform_one(data)

    def transform_one(self, data: dict) -> dict:
        """Load the image and sound referenced by ``data`` and attach them.

        Reads ``data["img_path"]`` and ``data["snd_path"]`` and adds the keys
        ``"img"``, ``"snd"`` and ``"snd_sr"``. The input dict is updated in
        place and the same object is returned.
        """
        # Use a context manager so the underlying file handle is released
        # right away instead of whenever the image object is collected;
        # convert() returns a new image that no longer needs the open file.
        with Image.open(data["img_path"]) as img:
            rgb = img.convert("RGB")
        img_tensor = F.pil_to_tensor(rgb)
        img_tensor = F.center_crop(img_tensor, list(self.crop_size))
        data["img"] = img_tensor

        snd_tensor, sample_rate = audio_load(data["snd_path"])
        data["snd"] = snd_tensor
        data["snd_sr"] = sample_rate
        return data


class DatasetLister:
    """Access data through either __getitem__, or an iterator.

    If using an iterator, will loop once, in order.

    The constructor walks ``path/<label>/<id>_<session>_<rest>.png``; the
    directory name is parsed as the integer label, and every ``.png`` must
    have a sibling ``.wav`` with the same stem. Entries are visited in sorted
    order so indices map to the same samples on every run.
    """

    def __init__(self, path: str):
        """Index every image/sound pair found under ``path``.

        Raises:
            FileNotFoundError: A ``.png`` has no matching ``.wav``.
            ValueError: A file stem is not ``<id>_<session>_<rest>``, or a
                label directory name is not an integer.
            TypeError: A file with a suffix other than ``.png``/``.wav`` is
                present in a label directory.
        """
        self.path = Path(path)
        self.labels = []
        self.ids = []
        self.sessions = []
        self.img_paths = []
        self.snd_paths = []

        # iterdir() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> sample mapping reproducible.
        for label_path in sorted(self.path.iterdir()):
            for data_path in sorted(label_path.iterdir()):
                suffix = data_path.suffix
                if suffix == ".wav":
                    # Sounds are picked up through their matching image.
                    continue
                if suffix != ".png":
                    raise TypeError(f"Unsupported data-type: {data_path}")

                snd_path = data_path.with_suffix(".wav")
                if not snd_path.exists():
                    raise FileNotFoundError(
                        f"Missing sound file for image {data_path}: {snd_path}"
                    )

                parts = data_path.stem.split("_", 2)
                if len(parts) != 3:
                    raise ValueError(
                        f"Expected '<id>_<session>_<rest>' file name, got: {data_path}"
                    )
                sample_id, session, _rest = parts

                self.img_paths.append(str(data_path))
                self.snd_paths.append(str(snd_path))
                self.ids.append(sample_id)
                self.sessions.append(session)
                self.labels.append(int(label_path.name))

        # Internal invariant: the parallel lists are always appended together.
        assert len(self.labels) == len(self.img_paths) == len(self.snd_paths), (
            len(self.labels),
            len(self.img_paths),
            len(self.snd_paths),
        )

    def __getitem__(self, i: int) -> dict:
        """Return a fresh dict with label, id, session and both paths for item ``i``."""
        data = {
            "label": self.labels[i],
            "id": self.ids[i],
            "session": self.sessions[i],
            "img_path": self.img_paths[i],
            "snd_path": self.snd_paths[i],
        }
        return data

    def __len__(self):
        """Return the number of indexed image/sound pairs."""
        return len(self.labels)

    def __iter__(self) -> Iterator[dict]:
        """Yield one dict per sample, in index order.

        NOTE(review): unlike ``__getitem__``, the yielded dicts carry only
        ``img_path`` and ``snd_path`` (no label/id/session) - confirm this
        is intended before relying on it.
        """
        for img_path, snd_path in zip(self.img_paths, self.snd_paths):
            yield {"img_path": img_path, "snd_path": snd_path}


class Dataset(TorchDataset):
    """Classic DataLoader v1-style dataset (map style).

    Applies DataTransform when retrieving items.
    """

    def __init__(self, path: str):
        """Index the samples under ``path`` and set up the default transform."""
        self.data = DatasetLister(path)
        self.tx = DataTransform()

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, i: int) -> dict:
        """Return sample ``i`` with its image and sound decoded."""
        return self.tx(self.data[i])


def my_collate(batch, *, collate_fn_map=None):
    """Collate a list of sample dicts into one dict of per-key lists.

    Only ``"img"`` is stacked into a single tensor along a new leading
    dimension; every other key stays a plain Python list with one entry per
    sample (presumably because sounds can differ in length - confirm).

    Args:
        batch: Iterable of sample dicts. Each key must be one of the keys
            pre-declared below, otherwise a ``KeyError`` is raised.
        collate_fn_map: Accepted but not used by this function.

    Returns:
        Dict with the keys ``label``, ``img``, ``snd``, ``snd_sr``, ``id``,
        ``session``, ``img_path`` and ``snd_path``.
    """
    res = {
        key: []
        for key in (
            "label",
            "img",
            "snd",
            "snd_sr",
            "id",
            "session",
            "img_path",
            "snd_path",
        )
    }
    for sample in batch:
        for k, v in sample.items():
            res[k].append(v)
    # `dim` is the documented keyword of torch.stack.
    res["img"] = torch.stack(res["img"], dim=0)
    return res


def setup_loader(*, data_dir: str, batch_size: int, shuffle: bool) -> TorchDataLoader:
    """Build a DataLoader over ``Dataset(data_dir)`` that batches with ``my_collate``.

    Args:
        data_dir: Root directory handed to ``Dataset``.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to the DataLoader's ``shuffle`` argument.
    """
    dataset = Dataset(data_dir)
    loader = TorchDataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=my_collate,
    )
    return loader