HOME/Articles/

pil example imagezipdataset (snippet)

Article Outline

Python pil example 'imagezipdataset'

Functions in program:

  • def test(cache=True, shuffle=True, num_workers=2):

Modules used in program:

  • import os
  • import numpy as np
  • import torch.multiprocessing as mp
  • import mmap
  • import zipfile
  • import tarfile
  • import torch

python imagezipdataset

Python pil example: imagezipdataset

import torch
from torch.utils.data import Dataset, DataLoader
import tarfile
import zipfile
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
import mmap
import torch.multiprocessing as mp
import numpy as np
import os

# mp.set_start_method('forkserver', force=True) # On Linux, you may want to use the fork method (this is the default with PyTorch 0.4)

class _SeekableMmap(mmap.mmap):
    """Memory map that always exposes ``seekable()``.

    ``zipfile`` calls ``seekable()`` on the underlying file object whenever
    an archive member is opened, but ``mmap.mmap`` only gained that method in
    Python 3.13. Defining it here keeps a memory-mapped archive readable on
    older interpreters as well.
    """

    def seekable(self):
        return True


class ImageZipDataset(Dataset):
    """Dataset that serves images stored inside a zip archive as tensors.

    The archive is opened lazily, on the first item access, and the resulting
    ``ZipFile`` is kept on the instance for later accesses. With ``cache``
    set to ``True`` the archive is read through a read-only memory map of the
    file instead of a regular file object.

    Args:
        path: Location of the zip archive.
        extension: Only archive members whose name ends with this suffix are
            exposed as dataset items.
        cache: When exactly ``True``, read the archive through a memory map.
    """

    def __init__(self, path, extension=".jpg", cache=False):
        # Only the member names are needed here; close the archive right away
        # so that no open handle is stored on the instance at this point.
        with zipfile.ZipFile(path, mode="r") as archive:
            self._files = [name for name in archive.namelist() if name.endswith(extension)]
        self._path = path
        self._cache = cache
        self._archive = None
        self._pil_to_tensor = transforms.ToTensor()
        self._memmap_handle = None
        # With the 'fork' start method the mapping is created up front and
        # stored; with any other start method it is created on demand.
        if mp.get_start_method() == 'fork' and self._cache is True:
            self._memmap_handle = self.get_memmap_handle()

    def get_memmap_handle(self):
        """Return a read-only memory map of the archive file.

        Returns the mapping stored on the instance when there is one.
        Otherwise a new mapping is created and returned *without* being
        stored, so the caller decides how long it stays alive.
        """
        if self._memmap_handle is not None:
            return self._memmap_handle
        # mmap keeps its own duplicate of the descriptor, so the file object
        # is only needed while the mapping is being created.
        with open(self._path, "rb") as handle:
            if os.name == 'nt':
                # On Windows the mapping is tagged with the archive path.
                return _SeekableMmap(handle.fileno(), 0, self._path, access=mmap.ACCESS_READ)
            return _SeekableMmap(handle.fileno(), 0, access=mmap.ACCESS_READ)

    def get_zip_handle(self):
        """Return the ``ZipFile`` for the archive, opening it on first use."""
        if self._archive is None:
            if self._cache is True:
                self._archive = zipfile.ZipFile(self.get_memmap_handle(), mode="r")
            else:
                # Hand ZipFile the path so that it owns, and eventually
                # closes, the file it opens.
                self._archive = zipfile.ZipFile(self._path, mode="r")
        return self._archive

    def __getitem__(self, index):
        """Decode the image at ``index`` and return it as a tensor."""
        # To measure throughput without image decoding, the raw member bytes
        # can be returned instead:
        #   torch.Tensor(list(self.get_zip_handle().read(self._files[index])))
        # Use of jpeg-turbo is recommended for the decoding step.
        member_name = self._files[index]
        # The conversion reads the pixel data, so both the member stream and
        # the PIL image can be closed before returning.
        with self.get_zip_handle().open(member_name) as member, \
                Image.open(member, mode="r") as img:
            return self._pil_to_tensor(img)

    def __len__(self):
        """Return the number of archive members matching the extension."""
        return len(self._files)

def test(cache=True, shuffle=True, num_workers=2):
    """Run three passes over ``myfiles.zip`` through a ``DataLoader``.

    Args:
        cache: Passed to ``ImageZipDataset``; when exactly ``True`` the
            archive is read through a memory map.
        shuffle: Passed to the ``DataLoader``.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        1 plus the sum of every batch seen over the three passes.
    """
    dataset = ImageZipDataset("myfiles.zip", cache=cache)
    # ``is True`` mirrors the check the dataset itself applies to ``cache``.
    if mp.get_start_method() != 'fork' and cache is True:
        # Force having at least one handle of the memory-mapped file alive in
        # this process when workers are not forked (the original author was
        # not sure this is done right). The name is bound only to hold the
        # reference for the duration of the run.
        _keepalive = dataset.get_memmap_handle()
    loader = DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=num_workers)

    res = 1
    for _ in range(3):
        for x in tqdm(loader):
            res += x.sum()
    return res


# When executed as a script, run the DataLoader loop with its default arguments.
if __name__ == "__main__":
    test()