Home | 简体中文 | 繁体中文 | 杂文 | Github | 知乎专栏 | Facebook | Linkedin | Youtube | 打赏(Donations) | About
知乎专栏

10.4. Dataset

			
import torch
from torch.utils.data import Dataset, DataLoader

# Toy features: the integers 0..14 laid out as 5 samples with 3 values each.
x = torch.arange(15).view(5, 3)
# Toy targets: one integer label per sample, kept as a 5x1 column.
y = torch.arange(5).unsqueeze(1)


# print(y)
class MyDataset(Dataset):
    """Map-style dataset that pairs each entry of ``x`` with the entry of ``y``
    at the same position.

    Args:
        x: Sized, indexable container of samples.
        y: Sized, indexable container of targets, aligned with ``x``.

    Raises:
        ValueError: If ``x`` and ``y`` do not contain the same number of items.
    """

    def __init__(self, x, y):
        super().__init__()
        # Fail fast: a shorter ``y`` would otherwise surface as an IndexError
        # in the middle of iteration, and a longer ``y`` would be silently
        # truncated because ``__len__`` only looks at ``x``.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples (the length of ``x``)."""
        return len(self.x)

    def __getitem__(self, item):
        """Return the ``(x[item], y[item])`` pair for the given index."""
        return self.x[item], self.y[item]


# Wrap the toy tensors and show, in order: the dataset size, the stored
# features, the stored targets, and the single sample at index 1.
dataset = MyDataset(x=x, y=y)
for shown in (len(dataset), dataset.x, dataset.y, dataset[1]):
    print(shown)

# Draw shuffled mini-batches of two samples and print each batch with its
# running index; batch order changes from run to run because of the shuffle.
loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)
for i, data in enumerate(loader):
    print(i, data)
			
		

输出结果（由于 DataLoader 使用了 shuffle=True，各批次的顺序每次运行可能不同）

		
5
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14]])
tensor([[0],
        [1],
        [2],
        [3],
        [4]])
(tensor([3, 4, 5]), tensor([1]))
0 [tensor([[ 3,  4,  5],
        [12, 13, 14]]), tensor([[1],
        [4]])]
1 [tensor([[6, 7, 8],
        [0, 1, 2]]), tensor([[2],
        [0]])]
2 [tensor([[ 9, 10, 11]]), tensor([[3]])]