知乎专栏 |
Compose(): 用来管理所有的transforms操作。
ToTensor(): 把图片数据(PIL Image或numpy.ndarray)转换成张量,并将取值范围缩放到[0,1]区间内。
Normalize(mean, std): 按通道用给定的均值mean和标准差std对张量做标准化,即 output = (input - mean) / std。
Resize(size): 输入的PIL图像调整为指定的大小,参数可以为int或int元组。
CenterCrop(size): 将给定的PIL Image进行中心裁剪,得到指定size的图像;size可以为int或(高, 宽)元组。
RandomCrop(size, padding=0):随机中心点切割。
RandomResizedCrop(size, interpolation=2):将给定的PIL Image按随机大小和宽高比裁剪,再resize到指定size。
RandomHorizontalFlip():随机水平翻转给定的PIL Image。
RandomVerticalFlip(): 随机垂直翻转给定的PIL Image。
ToPILImage(): 将Tensor或numpy.ndarray转换为PIL Image。
FiveCrop(size): 将给定的PIL图像裁剪成4个角落区域和中心区域。
Pad(padding, fill=0, padding_mode='constant'):对PIL Image的边缘进行填充。
RandomAffine(degrees, translate=None, scale=None):对图片进行保持中心不变的随机仿射变换。
RandomApply(transforms, p=0.5):以概率p应用给定的一组变换(transforms列表),否则保持输入不变。
from PIL import Image
from torchvision import transforms

# 数据预处理
transform = transforms.Compose([transforms.RandomCrop(32), transforms.Resize(256)])

# 对图像进行转换
img = Image.open('example.jpg')
img_transformed = transform(img)
将 PIL Image 或 numpy.ndarray 转为 tensor
import numpy as np from torchvision import transforms data = np.array([ [0, 5, 10, 20, 0], [255, 125, 180, 255, 196] ], dtype=np.uint8) tensor = transforms.ToTensor()(data) print(tensor) """ tensor([[[0.0000, 0.0196, 0.0392, 0.0784, 0.0000], [1.0000, 0.4902, 0.7059, 1.0000, 0.7686]]]) """
from torchvision import transforms
PIL_img = transforms.ToPILImage()(tensor_img)
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
densenet = models.densenet161()
from torchvision import models # 加载预训练的ResNet-50模型 model = models.resnet50(pretrained=True)
from torchvision import datasets, transforms

# 数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 加载数据集
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
datasets.ImageFolder方法可以实现数据导入。 ImageFolder(root,transform=None,target_transform=None,loader=default_loader)
from torchvision import datasets train_datasets = datasets.CIFAR10(root='data',train=False,download=True) print(train_datasets)
/Users/neo/PycharmProjects/netkiller/venv/bin/python /Users/neo/PycharmProjects/netkiller/test.py Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz 100.0% Extracting data/cifar-10-python.tar.gz to data Dataset CIFAR10 Number of datapoints: 10000 Root location: data Split: Test
from torchvision import datasets train_datasets = datasets.CIFAR10(root='data',train=False,download=True) print(train_datasets.data)
from torchvision import datasets train_datasets = datasets.CIFAR10(root='data',train=False,download=True) print(train_datasets.data.shape)
Files already downloaded and verified (10000, 32, 32, 3)
(图片数量, 高度, 宽度, 通道数)
from torchvision import datasets train_datasets = datasets.CIFAR10(root='data',train=False,download=True) print(train_datasets.data.min()) print(train_datasets.data.max())
Files already downloaded and verified 0 255
from torchvision import datasets train_datasets = datasets.CIFAR10(root='data',train=False,download=True) print(train_datasets.classes) print(train_datasets.class_to_idx)
Files already downloaded and verified ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
单张图片
from torchvision import datasets from matplotlib import pyplot as plt train_datasets = datasets.CIFAR10(root='data',train=False,download=True) plt.imshow(train_datasets.data[0]) plt.show()
显示3x3九张图片
from torchvision import datasets from matplotlib import pyplot as plt test_datasets = datasets.CIFAR10(root='data',train=False,download=True) fig,axes=plt.subplots(3,3,figsize=(4,4)) for i,ax in enumerate(axes.flat): ax.imshow(test_datasets.data[i]) ax.axis("off") ax.set_title(test_datasets.classes[test_datasets.targets[i]]) plt.show()