
10.5. DataLoader

			
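import torch

batch_size = 32
# train_ds and test_ds are assumed to be Dataset objects defined earlier
# shuffle=True reshuffles the training samples at the start of every epoch
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
# The test set needs no shuffling, so the default shuffle=False is kept
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)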
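To show what `batch_size` and `shuffle` do in practice, here is a minimal sketch. It uses a hypothetical in-memory `TensorDataset` (`toy_ds`, 100 samples) in place of `train_ds`. The sketch shows that the DataLoader yields batches of 32 with a smaller final batch, and that shuffling changes only the order of the samples, not which samples are included.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data standing in for train_ds: 100 samples, 3 features each
features = torch.arange(300, dtype=torch.float32).reshape(100, 3)
labels = torch.arange(100)  # label i belongs to sample i, so the order can be tracked
toy_ds = TensorDataset(features, labels)

batch_size = 32
toy_dl = DataLoader(toy_ds, batch_size=batch_size, shuffle=True)

batch_shapes = []
seen_labels = []
for x, y in toy_dl:  # one full pass over the DataLoader = one epoch
    batch_shapes.append((tuple(x.shape), tuple(y.shape)))
    seen_labels.extend(y.tolist())

# 100 = 32 + 32 + 32 + 4: the last batch is smaller because drop_last defaults to False
print(len(toy_dl), batch_shapes)
# Shuffling permutes the order, but every sample still appears exactly once per epoch
print(sorted(seen_labels) == list(range(100)))
```

Passing `drop_last=True` to the DataLoader would discard the incomplete final batch of 4 samples.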
			
		

10.5.1. Displaying images from the dataset

		
from torchvision import datasets
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt

if __name__ == '__main__':
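    # train=True loads the 50,000-image training split; train=False would load the 10,000-image test split
    train_datasets = datasets.CIFAR10(root='data', train=True, download=True)
    train_dataloader = DataLoader(train_datasets, batch_size=10, shuffle=False)

    fig, axes = plt.subplots(3, 3, figsize=(4, 4))

    for i, ax in enumerate(axes.flat):
        # dataset.data is a uint8 NumPy array of shape (N, 32, 32, 3), which imshow accepts directly
        ax.imshow(train_dataloader.dataset.data[i])
        ax.axis("off")
        # dataset.targets holds integer labels; dataset.classes maps them to class names
        ax.set_title(train_dataloader.dataset.classes[train_dataloader.dataset.targets[i]])
    plt.tight_layout()
    plt.show()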