
10.8. torchvision

10.8.1. Installing torchvision

		
pip install torchvision		
		
		

10.8.2. transforms: data transformation



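Compose(transforms): chains several transforms and applies them in order.
ToTensor(): converts a PIL Image or numpy.ndarray to a tensor; uint8 input is scaled to the range [0, 1].
Normalize(mean, std): normalizes a tensor image channel by channel, computing (x - mean) / std.
Resize(size): resizes the input image to the given size; size can be an int (the shorter side is matched to it) or a (height, width) tuple.
CenterCrop(size): crops the center of the given image to the given size.
RandomCrop(size, padding=None): crops the given image at a random location.
RandomResizedCrop(size): crops a random region of the given image, then resizes it to size.
RandomHorizontalFlip(p=0.5): flips the given image horizontally with probability p.
RandomVerticalFlip(p=0.5): flips the given image vertically with probability p.
ToPILImage(): converts a Tensor or numpy.ndarray to a PIL Image.
FiveCrop(size): crops the given image into its four corner regions and the center region.
Pad(padding, fill=0, padding_mode='constant'): pads the borders of the image.
RandomAffine(degrees, translate=None, scale=None): applies a random affine transformation that keeps the image center fixed.
RandomApply(transforms, p=0.5): applies the given list of transforms with probability p.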

		
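from PIL import Image
from torchvision import transforms

# Preprocessing pipeline: random 32x32 crop, then resize to 256x256
transform = transforms.Compose([transforms.RandomCrop(32), transforms.Resize(256)])

# Apply the pipeline to an image
img = Image.open('example.jpg')
img_transformed = transform(img)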
		
		

10.8.2.1. Compose: chaining transforms

10.8.2.2. ToTensor: converting to a tensor

Converts a PIL Image or numpy.ndarray to a tensor.

Converting an np.uint8 array to a tensor
				
import numpy as np
from torchvision import transforms

data = np.array([
    [0, 5, 10, 20, 0],
    [255, 125, 180, 255, 196]
], dtype=np.uint8)

tensor = transforms.ToTensor()(data)
print(tensor)
"""
tensor([[[0.0000, 0.0196, 0.0392, 0.0784, 0.0000],
         [1.0000, 0.4902, 0.7059, 1.0000, 0.7686]]])
"""
				
				
				
Converting a non-uint8 array to a tensor
				
import numpy as np
from torchvision import transforms

data = np.array([
    [0, 5, 10, 20, 0],
    [255, 125, 180, 255, 196]
], dtype=np.int32)      # not uint8, so ToTensor does not rescale the values

tensor = transforms.ToTensor()(data)
print(tensor)
"""
tensor([[[  0,   5,  10,  20,   0],
         [255, 125, 180, 255, 196]]], dtype=torch.int32)
"""				
				
				

10.8.2.3. Normalize

			
			
			
			

10.8.2.4. Converting a tensor to a PIL Image

			
import torch
from torchvision import transforms
tensor_img = torch.rand(3, 32, 32)  # example (C, H, W) float tensor in [0, 1]
PIL_img = transforms.ToPILImage()(tensor_img)
			
			
			

10.8.2.5. Converting a color image to grayscale

			
from torchvision import transforms
  
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1), # color to grayscale; num_output_channels defaults to 1
    transforms.ToTensor()
])
			
			
			

10.8.3. models: loading models



import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
densenet = models.densenet161()

10.8.3.1. Loading the ResNet-50 model

			
from torchvision import models
# Load the pretrained ResNet-50 model
model = models.resnet50(pretrained=True)		
			
			

10.8.3.2. Loading Fashion-MNIST

			
from torchvision import datasets
dataset = datasets.FashionMNIST('data/', download=True, train=False, transform=None)			
			
			

10.8.4. datasets: loading data

10.8.4.1. CIFAR10

			
from torchvision import datasets, transforms

# Preprocessing: convert to a tensor, then normalize each of the 3 channels
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load the datasets
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
			
			

10.8.4.2. ImageFolder: importing data

			
datasets.ImageFolder loads images from a directory that contains one subfolder per class.

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
			
			

10.8.4.3. Inspecting a Dataset

			
from torchvision import datasets

train_datasets = datasets.CIFAR10(root='data',train=False,download=True)
print(train_datasets)			
			
			
			
/Users/neo/PycharmProjects/netkiller/venv/bin/python /Users/neo/PycharmProjects/netkiller/test.py 
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
100.0%
Extracting data/cifar-10-python.tar.gz to data
Dataset CIFAR10
    Number of datapoints: 10000
    Root location: data
    Split: Test			
			
			

10.8.4.4. Inspecting the data

			
from torchvision import datasets

train_datasets = datasets.CIFAR10(root='data',train=False,download=True)
print(train_datasets.data)			
			
			
Checking the array shape
				
from torchvision import datasets

train_datasets = datasets.CIFAR10(root='data',train=False,download=True)
print(train_datasets.data.shape)
				
				
				
Files already downloaded and verified
(10000, 32, 32, 3)				
				
				

(number of images, height, width, channels)
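
PyTorch models usually expect (N, C, H, W) rather than this (N, H, W, C) layout, and Normalize needs one mean and std per channel. The sketch below shows both steps on a random uint8 array that only imitates the layout of the CIFAR10 data attribute (an assumption made so nothing has to be downloaded); the statistics it prints are therefore not CIFAR10's.

```python
import numpy as np
import torch

# Random stand-in with the same layout and dtype as the dataset array: (N, H, W, C), uint8.
data = np.random.randint(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)

scaled = data.astype(np.float32) / 255.0

# Average over images, rows and columns; one value remains per channel.
mean = scaled.mean(axis=(0, 1, 2))
std = scaled.std(axis=(0, 1, 2))
print(mean.shape, std.shape)

# Move the channel axis forward: (N, H, W, C) -> (N, C, H, W).
batch = torch.from_numpy(scaled).permute(0, 3, 1, 2)
print(batch.shape)
```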

Minimum / maximum values
				
from torchvision import datasets

train_datasets = datasets.CIFAR10(root='data',train=False,download=True)
print(train_datasets.data.min())
print(train_datasets.data.max())				
				
				
				
Files already downloaded and verified
0
255				
				
				
Displaying an image from the dataset
				
from torchvision import datasets
from matplotlib import pyplot as plt

if __name__ == '__main__':
    train_datasets = datasets.CIFAR10(root='data', train=False, download=True)
    
    plt.figure(figsize=(4, 4))
    plt.imshow(train_datasets.data[0])
    plt.show()				
				
				

10.8.4.5. Inspecting the labels

			
from torchvision import datasets

train_datasets = datasets.CIFAR10(root='data',train=False,download=True)
print(train_datasets.classes)
print(train_datasets.class_to_idx)			
			
			
			
Files already downloaded and verified
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}			
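
The class_to_idx mapping can be inverted to turn integer labels back into names. A small sketch in plain Python; the targets list is illustrative and not read from the dataset.

```python
from collections import Counter

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
class_to_idx = {name: i for i, name in enumerate(classes)}

# Invert the mapping: integer label -> class name.
idx_to_class = {i: name for name, i in class_to_idx.items()}

targets = [3, 8, 8, 0, 6]   # illustrative labels (assumption)
names = [idx_to_class[t] for t in targets]
print(names)                           # ['cat', 'ship', 'ship', 'airplane', 'frog']
print(Counter(names).most_common(1))   # [('ship', 2)]
```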
			
			

10.8.4.6. Viewing images in the dataset

A single image

			
from torchvision import datasets
from matplotlib import pyplot as plt

train_datasets = datasets.CIFAR10(root='data',train=False,download=True)
plt.imshow(train_datasets.data[0])	
plt.show()
			
			

Displaying nine images in a 3x3 grid

	 		
from torchvision import datasets
from matplotlib import pyplot as plt

test_datasets = datasets.CIFAR10(root='data',train=False,download=True)
fig,axes=plt.subplots(3,3,figsize=(4,4))

for i,ax in enumerate(axes.flat):
    ax.imshow(test_datasets.data[i])
    ax.axis("off")
    ax.set_title(test_datasets.classes[test_datasets.targets[i]])

plt.show()
			
			

10.8.5. Opening an image

		
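import torchvision
from torchvision.transforms import functional as F
image = torchvision.io.read_image('test.png')  # uint8 tensor, shape (C, H, W)
image = F.resize(image, [28, 28])              # resize the image to 28x28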