HOME/Articles/

PIL example: data loader (snippet)

Article Outline

Python PIL example: 'data loader'

Classes and functions in program:

  • class myDataset(data.Dataset):
  • def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
  • def collate_fn(data):

Modules used in program:

  • import jieba
  • import numpy as np
  • import pickle
  • import os
  • import torch.utils.data as data
  • import torchvision.transforms as transforms
  • import torch
  • from PIL import Image
  • from build_vocab import Vocabulary

Python data loader

Python PIL example: data loader

import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import pickle
import numpy as np
from PIL import Image
from build_vocab import Vocabulary
import jieba

class myDataset(data.Dataset):
    """Image-captioning dataset pairing each image with its tokenized captions.

    Every caption string is segmented with ``jieba`` and each token is mapped
    through ``vocab``; the resulting id list is wrapped in the ids of the
    ``'<start>'`` and ``'<end>'`` tokens.
    """

    def __init__(self, root, json, vocab, transform=None):
        """
        Args:
            root: directory containing the image files; joined with each
                record's ``'image_id'`` to form the image path.
            json: indexable sequence of already-parsed annotation records
                (despite the name, not a file path). Each record must have an
                ``'image_id'`` (image file name relative to ``root``) and a
                ``'caption'`` entry (an iterable of caption strings).
            vocab: callable mapping a token string to its id (presumably an
                integer index — TODO confirm against ``build_vocab``).
            transform: optional callable applied to the RGB PIL image.
        """
        self.root = root
        self.json = json
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, captions)`` for the record at ``index``.

        ``image`` is the RGB PIL image, or whatever ``self.transform`` returns
        for it when a transform is set. ``captions`` holds one list per caption
        string, each ``[vocab('<start>'), *token_ids, vocab('<end>')]``,
        e.g. ``[[1, 2, 3, ...], [4, 5, 6, ...], ...]``.
        """
        record = self.json[index]
        img_path = os.path.join(self.root, record['image_id'])
        # The context manager closes the file handle deterministically.
        # Pillow only auto-closes after load() for single-frame images, so
        # multi-frame files (GIF, multipage TIFF) would otherwise stay open.
        # convert() returns an image that no longer depends on the file.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        vocab = self.vocab
        captions_jieba = [
            [vocab('<start>')]
            + [vocab(word) for word in jieba.cut(cap)]
            + [vocab('<end>')]
            for cap in record['caption']
        ]
        return image, captions_jieba

    def __len__(self):
        """Number of annotation records (one per image)."""
        return len(self.json)

def collate_fn(data):
    """Create mini-batch tensors from a list of ``(image, captions)`` tuples.

    Each dataset item carries one image and several captions. To give every
    caption its own training example, each image is repeated once per caption
    (e.g. five captions per image yields five consecutive copies). Row ``k``
    of ``images`` is therefore the image for row ``k`` of ``targets``.

    Args:
        data: list of items as returned by the dataset's ``__getitem__``:
            - image: torch tensor; all images in the batch must share one
              shape so they can be stacked.
            - captions: list of token-id lists, e.g. [[1, 2, 3, ...], [4, 5, ...], ...]

    Returns:
        images: tensor of shape (N, *image.shape), where N is the total number
            of captions in the batch.
        targets: LongTensor of shape (N, padded_length), zero-padded on the right.
        lengths: list of length N; the unpadded length of each caption.
    """
    images, captions = zip(*data)

    # Number of captions attached to each image, in batch order.
    counts = [len(caps) for caps in captions]
    # (batch_size, ...) -> (N, ...): image i repeated counts[i] times, consecutively,
    # matching the flattened caption order below.
    images = torch.stack(images, 0).repeat_interleave(torch.tensor(counts), dim=0)

    # Flatten captions image by image so row order matches the repeated images.
    flat_caps = [cap for caps in captions for cap in caps]
    lengths = [len(cap) for cap in flat_caps]

    targets = torch.zeros(len(flat_caps), max(lengths, default=0), dtype=torch.long)
    for row, cap in enumerate(flat_caps):
        targets[row, :len(cap)] = torch.tensor(cap, dtype=torch.long)
    return images, targets, lengths



def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Build a ``torch.utils.data.DataLoader`` over ``myDataset``.

    ``root``, ``json``, ``vocab`` and ``transform`` are forwarded to
    ``myDataset``. ``batch_size``, ``shuffle`` and ``num_workers`` are
    forwarded to the DataLoader. Batches are assembled by ``collate_fn``, so
    each iteration yields the ``(images, targets, lengths)`` tuple it returns.
    """
    dataset = myDataset(root=root, json=json, vocab=vocab, transform=transform)
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': collate_fn,
    }
    return torch.utils.data.DataLoader(dataset=dataset, **loader_options)