from torchvision import datasets
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
def show_sample_grid(dataset, rows=3, cols=3, figsize=(4, 4)):
    """Plot the first ``rows * cols`` samples of *dataset* in a grid.

    Each cell shows one image with its class name as the title; cells beyond
    the end of the dataset are left blank.

    Args:
        dataset: An indexable, sized dataset whose items are
            ``(image, label)`` pairs and which exposes a ``classes`` sequence
            mapping integer labels to names (as torchvision's ``CIFAR10``
            does).
        rows: Number of grid rows.
        cols: Number of grid columns.
        figsize: Figure size in inches, forwarded to ``plt.subplots``.

    Returns:
        The ``(fig, axes)`` pair created by ``plt.subplots``; ``axes`` is
        always a 2-D array of Axes.
    """
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so
    # `axes.flat` is always available.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i >= len(dataset):
            continue  # fewer samples than grid cells: leave the cell empty
        # With no transform configured, the image is whatever the dataset
        # returns as-is (a PIL image for torchvision's CIFAR10), which
        # imshow accepts directly.
        image, label = dataset[i]
        ax.imshow(image)
        ax.set_title(dataset.classes[label])
    # Prevent the per-cell titles from overlapping in a small figure.
    fig.tight_layout()
    return fig, axes


if __name__ == '__main__':
    # train=False selects the CIFAR-10 *test* split; pass train=True to
    # preview the training split instead.
    dataset = datasets.CIFAR10(root='data', train=False, download=True)
    show_sample_grid(dataset)
    plt.show()