数据集类(Data Set)与数据加载器(Data Loader)

数据集类(Data Set)是指存储和表示数据的类或接口。它通常用于封装数据,以便能够在机器学习任务中使用。数据集可以是任何形式的数据,比如图像、文本、音频等。数据集的主要目的是提供对数据的标准访问方法,以便可以轻松地将其用于模型训练、验证和测试。

数据加载器(Data Loader)是一个提供批量加载数据的工具。它通过将数据集分割成小批量,并按照一定的顺序加载到内存中,以提高训练效率。数据加载器常用于训练过程中的数据预处理、批量化操作和数据并行处理等。

​ PyTorch中的torch.utils.data.Datasettorch.utils.data.DataLoader是数据加载和处理的核心组件。它们将数据读取与模型训练解耦,提供高效、灵活的数据迭代方式。下面从基础概念、自定义加载器参数、多进程机制等方面进行详细介绍。

1.数据集(Data Set)

1.1 自定义数据集定义实现

Data Set是一个抽象类,表示一个数据集。任何自定义数据集都必须继承它,自定义DataSet类,必须实现它构造函数和两个方法:

  • __init__: 在 实例化DataSet 对象运行一次。我们初始化包含图像的目录、注释文件和transform与 target_transform.

  • __len__:返回数据集的总样本数。len(dataset)会调用它。

  • __getitem__(self, idx):根据整数索引idx会返回一个样本(通常为特征和标签)。dataset[idx]会调用它。

其作用就是实现通过索引访问对应的数据以及标签

from torch.utils.data import Dataset class MyDataset(Dataset): def __init__(self, data, labels): self.data = data self.labels = labels def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx], self.labels[idx]

使用自定义数据集时,可以用将其与torch.utils.data.DataLoader结合使用,以便进行数据的批量加载和处理和训练。

1.2 两种自定义数据集风格

​ 在PyTorch中,自定义数据集有两个核心设计模式:映射式(Map-Style)可迭代式(Iterable-style)。它们的差异不仅是实现接口不同,更反映了“随机访问”与“流式读取”两种数据消费范式的根本区别。下面从设计理念、实现细节、多进程交互、适用场景等方面深入解析。

  • Map-style datasets(映射式):就是上述需要实现__getitem____len__的数据集,它通过索引映射到数据样本。适用于所有数据能一次性放入索引结构(如列表、文件路径列表)的场景。
  • Iterable-style datasets(可迭代式):当数据集太大无法一次性加载,或数据是流式读取时(如实时日志、数据库流),可以继承IterableDataset,实现__iter__方法返回一个迭代器。这种数据集不能使用len(),也无法使用随机采样(shuffle)的 loader,需使用Sampler的特定变体。

在后续笔记我们将详细介绍。

1.3 内置数据集

​ PyTorch提供了一些常用数据集类,主要在torchvision.datasetstorchtext.datasetstorchaudio.datasets中。例如:

  • torchvision.datasets.MNISTCIFAR10ImageFolder(从文件夹结构加载图片,子文件夹为类别)
  • torchtext.datasets.IMDB
  • torchaudio.datasets.LIBRISPEECH

这些内置类都继承自Dataset,使用时可自动下载数据,并提供标准化访问方式。

​ 现在我们来展示一个如何从TorchVision加载了Fahion-MINIST由60000个训练样本和10000个测试样本组成。每个样本包含一个28×28 灰度图像和一个来自10个类别之一的关联标签。下面使用以下参数加载FashionMINIST数据集:

  • root:是存储路径、测试数据的路径。
  • train:指定训练集或测试数据集。
  • download=True:如果root路径下没有数据,则从网上下载数据。
  • transformtarget_transform是指定特征和标签转换。
import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt training_data = datasets.FashionMNIST( root="./data", train=True, download=True, transform=ToTensor() ) test_data = datasets.FashionMNIST( root="./data", train=False, download=True, transform=ToTensor() )

我们可以用索引来访问数据集中的样本,用matplotlib可视化图形样本。

labels_map = { 0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot", } figure = plt.figure(figsize=(8, 8)) cols, rows = 3, 3 for i in range(1, cols * rows + 1): sample_idx = torch.randint(len(training_data), size=(1,)).item() img, label = training_data[sample_idx] figure.add_subplot(rows, cols, i) plt.title(labels_map[label]) plt.axis("off") plt.imshow(img.squeeze(), cmap="gray") plt.show()

其运行结果如下:

2. 数据加载器(Data Loader)

数据加载器(Data Loader)DataSet封装为可迭代对象,负责批量加载、打乱数据、多进程并行加载等功能。其功能如下:

  • 批量加载数据:DataLoader可以从数据集中按照指定的批量大小加载数据。每个批次的数据可以作为一个张量或列表返回,便于进行后续的处理和训练。
  • 数据随机洗牌:通过设置shuffle=True,DataLoader可以在每个迭代周期中对数据进行随机洗牌,以减少模型对数据顺序的依赖性,提高训练效果。
  • 多线程数据加载:DataLoader支持使用多个线程来并行加载数据,加快数据加载的速度,提高训练效率。
  • 数据批次采样:除了按照批量大小加载数据外,DataLoader还支持自定义的数据批次采样方式。可以通过设置batch_sampler参数来指定自定义的批次采样器,例如按照指定的样本顺序或权重进行采样。

数据加载器的API形式核心参数

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, prefetch_factor=2, persistent_workers=False)
  • dataset:要加载的Dataset对象(映射式或可迭代式)。
  • batch_size:每个批次的样本数,默认为 1。
  • shuffle:是否在每个 epoch 开始时打乱数据顺序(仅对映射式有效)。打乱基于RandomSampler
  • sampler:自定义采样器,继承自torch.utils.data.Sampler。定义数据索引的抽取策略。如果指定,shuffle必须为False
  • batch_sampler:类似sampler,但每次返回一批索引,与batch_sizeshufflesampler互斥。
  • num_workers:用于数据加载的子进程数。0 表示在主进程中加载,通常设置大于 0 可以加速数据预处理,利用多核。
  • collate_fn:函数,定义如何将多个样本列表合并为一个批次。默认collate_fn会将所有样本沿第0维堆叠成张量,通常对于同型数据有效。如果样本结构不一致(如不同长度序列),需要自定义。
  • pin_memory:若为True,数据加载器在返回张量前将其复制到 CUDA 固定内存,加速数据传输到 GPU。仅适用于 CUDA。
  • drop_last:若为True,丢弃最后一个不完整批次(当总样本数不能被 batch_size 整除时)。在训练时如果要求严格固定批次大小(如 BatchNorm)应设为True
  • timeout:从 worker 进程获取一个 batch 的超时时间(秒)。如果超时会抛异常。
  • worker_init_fn:每个 worker 进程的初始化函数,参数为 worker id,可用于设置随机种子等。
  • generator:用于生成随机采样的伪随机数生成器,保证可复现性。
  • prefetch_factor:每个 worker 预先加载的 batch 数(默认 2),增加可以让 GPU 更少等待。
  • persistent_workers:若为True,在数据集被消费一次后不会关闭 worker 进程,可保持 worker 存活以加速后续 epoch。

数据调用案例Demo

import torch from torch.utils.data import Dataset, DataLoader # 自定义数据集类 class MyDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, index): return self.data[index] # 自定义数据加载器类 class MyDataLoader(DataLoader): def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0): super().__init__(dataset, batch_size, shuffle, num_workers=num_workers) def collate_fn(self, batch): # 自定义的数据预处理、合并等操作 # 这里只是简单地将样本转换为Tensor,并进行堆叠 return torch.stack(batch) # 自定义数据集类 data = [1, 2, 3, 4, 5] dataset = MyDataset(data) # 创建数据加载器实例 dataloader = MyDataLoader(dataset, batch_size=2, shuffle=True) # 遍历数据加载器 for batch in dataloader: # batch是一个包含多个样本的张量(或列表) # 这里可以对批次数据进行处理 print(batch)

3.实战案例

import torch from sklearn.datasets import load_iris from torch.utils.data import Dataset, DataLoader # 此函数用于加载鸢尾花数据集 def load_data(shuffle=True): x = torch.tensor(load_iris().data) y = torch.tensor(load_iris().target) # 数据归一化 x_min = torch.min(x, dim=0).values x_max = torch.max(x, dim=0).values x = (x - x_min) / (x_max - x_min) if shuffle: idx = torch.randperm(x.shape[0]) x = x[idx] y = y[idx] return x, y # 自定义鸢尾花数据类 class IrisDataset(Dataset): def __init__(self, mode='train', num_train=120, num_dev=15): super(IrisDataset, self).__init__() x, y = load_data(shuffle=True) if mode == 'train': self.x, self.y = x[:num_train], y[:num_train] elif mode == 'dev': self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev] else: self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:] def __getitem__(self, idx): return self.x[idx], self.y[idx] def __len__(self): return len(self.x) batch_size = 16 # 分别构建训练集、验证集和测试集 train_dataset = IrisDataset(mode='train') dev_dataset = IrisDataset(mode='dev') test_dataset = IrisDataset(mode='test') train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True) dev_loader = DataLoader(dev_dataset, batch_size=batch_size) test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

4.总 结

  • ataset定义数据源及其访问方式,映射式最常用,流式数据用IterableDataset
  • DataLoader封装采样、批处理、多进程加载和内存固定等功能,参数丰富。
  • 通过自定义samplercollate_fn可以灵活处理各种数据形式和不平衡问题。
  • 多进程加载是加速训练的关键,需注意内存复制和系统兼容性。