| 知乎专栏 |
import torch
from torch.utils.data import Dataset, DataLoader
# Toy data: 5 samples with 3 features each, plus one integer label per sample.
x = torch.arange(15).view(5, 3)   # [[0, 1, 2], [3, 4, 5], ..., [12, 13, 14]]
y = torch.arange(5).unsqueeze(1)  # [[0], [1], [2], [3], [4]] -> shape (5, 1)
class MyDataset(Dataset):
    """Map-style dataset pairing each element of ``x`` with the same-index element of ``y``.

    Args:
        x: Samples; any object supporting ``len()`` and integer indexing
            (in this file, a 2-D tensor whose first dimension is the sample axis).
        y: Labels aligned with ``x`` element-for-element.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
    """

    def __init__(self, x, y):
        # Fail fast on misaligned inputs: otherwise surplus labels are silently
        # ignored, or a shorter ``y`` raises IndexError only when the loader
        # eventually requests the missing index.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples (the length of ``x``)."""
        return len(self.x)

    def __getitem__(self, item):
        """Return the ``(sample, label)`` pair stored at index ``item``."""
        return self.x[item], self.y[item]
# Wrap the toy tensors, then inspect the dataset through its protocol methods.
dataset = MyDataset(x, y)
print(len(dataset))   # sample count, via __len__
print(dataset.x)      # every feature row
print(dataset.y)      # every label
print(dataset[1])     # one (features, label) pair, via __getitem__

# Draw shuffled mini-batches of two; drop_last keeps its default (False),
# so the final batch holds whatever samples remain.
loader = DataLoader(dataset, batch_size=2, shuffle=True)
for batch_no, batch in enumerate(loader):
    print(batch_no, batch)
输出结果（由于设置了 shuffle=True，批次中样本的顺序是随机的，每次运行得到的批次顺序可能不同）：
5
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14]])
tensor([[0],
        [1],
        [2],
        [3],
        [4]])
(tensor([3, 4, 5]), tensor([1]))
0 [tensor([[ 3,  4,  5],
        [12, 13, 14]]), tensor([[1],
        [4]])]
1 [tensor([[6, 7, 8],
        [0, 1, 2]]), tensor([[2],
        [0]])]
2 [tensor([[ 9, 10, 11]]), tensor([[3]])]