知乎专栏 |
import torch
from torch.utils.data import Dataset, DataLoader

# Toy data: 5 samples with 3 features each, and one label per sample.
x = torch.arange(15).reshape(5, 3)
y = torch.arange(5).reshape(5, 1)


class MyDataset(Dataset):
    """Map-style dataset pairing each row of ``x`` with the matching row of ``y``.

    Args:
        x: Indexable collection of samples (a tensor in this demo).
        y: Indexable collection of targets, one per sample.

    Raises:
        ValueError: If ``x`` and ``y`` do not contain the same number of items.
    """

    def __init__(self, x, y):
        # Fail fast: a mismatch would otherwise surface later as an
        # IndexError in the middle of iteration, or silently ignore the
        # surplus targets when ``y`` is the longer one.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, item):
        """Return the ``(sample, target)`` pair stored at position ``item``."""
        return self.x[item], self.y[item]


dataset = MyDataset(x, y)
print(len(dataset))
print(dataset.x)
print(dataset.y)
print(dataset[1])

# shuffle=True draws a new sample order on every pass, so the composition of
# the batches printed below differs from run to run.
loader = DataLoader(dataset, batch_size=2, shuffle=True)
for i, data in enumerate(loader):
    print(i, data)
输出结果（由于 shuffle=True，每次运行时各 batch 的样本顺序可能不同）
5
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14]])
tensor([[0],
        [1],
        [2],
        [3],
        [4]])
(tensor([3, 4, 5]), tensor([1]))
0 [tensor([[ 3,  4,  5],
        [12, 13, 14]]), tensor([[1],
        [4]])]
1 [tensor([[6, 7, 8],
        [0, 1, 2]]), tensor([[2],
        [0]])]
2 [tensor([[ 9, 10, 11]]), tensor([[3]])]