Home | 简体中文 | 繁体中文 | 杂文 | Github | 知乎专栏 | Facebook | Linkedin | Youtube | 打赏(Donations) | About
知乎专栏

10.10. tensorboard

		
import os

from PIL import Image
from torch.utils.data import Dataset
from torchvision  import transforms
from torch.utils.tensorboard import SummaryWriter

import common
class mydatasets(Dataset):
    """Dataset of image files whose label is encoded in the file name.

    Every entry of ``root_dir`` is treated as an image. The text before the
    first underscore of the file name is the label text, e.g. a file named
    ``AbCd_1234.png`` carries the label text ``AbCd``.
    """

    def __init__(self, root_dir):
        """Index the entries of ``root_dir`` and build the transform pipeline.

        Args:
            root_dir: Directory whose entries are the image files to serve.
        """
        super().__init__()
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        self.images = [
            os.path.join(root_dir, image_name)
            for image_name in sorted(os.listdir(root_dir))
        ]

        self.transforms = transforms.Compose([
            transforms.Resize((60, 160)),
            transforms.ToTensor(),
            # Runs after ToTensor, so it receives a tensor rather than a
            # PIL image.
            transforms.Grayscale(),
        ])

    @staticmethod
    def _label_text(image_path):
        """Return the part of the file name that precedes the first underscore."""
        # os.path.basename honours the platform's path separators; splitting
        # on "/" would keep the directory part for Windows-style paths.
        return os.path.basename(image_path).split("_")[0]

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position in ``self.images``.

        Returns:
            A ``(data, label)`` tuple: the transformed image and the
            flattened result of ``common.text2vec`` for the label text.
        """
        image_path = self.images[index]
        # Transform while the file is open and release the handle
        # deterministically once the transformed copy exists.
        with Image.open(image_path) as image:
            data = self.transforms(image)
        # NOTE(review): common.text2vec lives outside this file; it is assumed
        # to return a tensor because .view() is called on its result.
        label = common.text2vec(self._label_text(image_path))
        # view(1, -1)[0] collapses the encoding to a single dimension.
        return data, label.view(1, -1)[0]

    def __len__(self):
        """Return the number of indexed image paths."""
        return len(self.images)



if __name__ == '__main__':
    # Smoke test: write the first test-set image to TensorBoard and print its
    # shape. Point the dataset at "./data/train" to preview a training sample.
    d = mydatasets("./data/test")
    img, _ = d[0]
    # The context manager closes the writer even when add_image raises, so
    # the event file under "logs" is always finalized.
    with SummaryWriter("logs") as writer:
        writer.add_image("img", img, 1)
        print(img.shape)