HOME/Articles/

pil example scaleai5 (snippet)

Article Outline

Python pil example 'scaleai5'

Functions in program:

  • def get_async_dataloaders(gcs_data_root: str, batch_size: int = 8, transform: Callable = None, test_ratio: float = 0.1, num_workers: int = 8) -> Dict[str, torch.utils.data.DataLoader]:
  • def generate_stream(items: List[str]) -> Generator[str, None, None]:

Modules used in program:

  • import aiohttp
  • import asyncio
  • import random

python scaleai5

Python pil example: scaleai5

import random
import asyncio

import aiohttp
from janus import Queue
from gcloud.aio.storage import Storage


def generate_stream(items: List[str]) -> Generator[str, None, None]:
    """Yield uniformly random elements of *items* forever (sampling with replacement).

    The length of *items* is re-read on every draw, so the generator follows
    the list as it currently is rather than a snapshot taken at creation time.
    An empty *items* makes the draw raise ``ValueError`` on the next ``next()``.
    """
    while True:
        # randint() includes both endpoints, so the highest valid position is len - 1.
        last_position = len(items) - 1
        yield items[random.randint(0, last_position)]


class AsyncImageDataset(torch.utils.data.IterableDataset):
    """Endless iterable dataset that streams ``(image, label)`` pairs from remote storage.

    Item names are drawn by ``generate_stream(items)``. For each name the blobs
    ``<data_root>/images/<item>.jpg`` and ``<data_root>/labels/<item>.png`` are
    downloaded concurrently by ``concurrency`` coroutines running on an event loop
    in a background daemon thread. Decoded pairs reach the consuming thread through
    a bounded ``janus`` queue (async side for producers, sync side for ``__iter__``).
    """

    def __init__(self, data_root: str, items: List[str], transform: Callable = None, concurrency: int = 64):
        """Store the configuration; no network activity happens until iteration starts.

        Args:
            data_root: Root URL of the dataset. Presumably a ``gs://bucket/prefix``
                URL, since its netloc/path are used as bucket/object name -- confirm
                against callers.
            items: Item names (without extension) to sample from.
            transform: Optional callable invoked with one ``(image, label)`` tuple and
                expected to return an ``(image, label)`` pair.
            concurrency: Number of download coroutines and the queue capacity.

        Raises:
            KeyError: If ``GOOGLE_APPLICATION_CREDENTIALS`` is not set in the environment.
        """
        super().__init__()
        self.data_root = data_root
        self.items = items
        self.transform = transform
        self.worker_initialized = False
        self.loop_thread = None
        self.q = None
        self.creds = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
        self.concurrency = concurrency
        self.stream = generate_stream(self.items)
        # Strong references to the download tasks; the event loop itself only keeps weak ones.
        self._tasks = []

    async def run(self, loop, session):
        """Download, decode and enqueue pairs for as long as ``self.stream`` yields items.

        Failures are handled per item and never stop the worker: aiohttp client
        errors are logged at debug level, timeouts are skipped silently and
        anything else is logged with a traceback.

        Args:
            loop: Event loop this coroutine is scheduled on. Kept for interface
                compatibility only; ``asyncio.gather`` no longer accepts a loop.
            session: aiohttp client session shared by every worker coroutine.
        """
        # Built once per worker instead of once per item; nothing in it depends on the item.
        aio_storage = Storage(service_file=self.creds, session=session)
        for item in self.stream:
            try:
                image_gs = urlparse(os.path.join(self.data_root, "images", item + ".jpg"))
                label_gs = urlparse(os.path.join(self.data_root, "labels", item + ".png"))
                # netloc is passed as the bucket; path[1:] drops the leading "/" to form the object name.
                # No ``loop=`` here: the argument was removed in Python 3.10 and raised
                # TypeError before the first await, so nothing was ever downloaded.
                blobs = await asyncio.gather(
                    aio_storage.download(image_gs.netloc, image_gs.path[1:]),
                    aio_storage.download(label_gs.netloc, label_gs.path[1:]),
                )
                image = Image.open(io.BytesIO(blobs[0]))
                label = Image.open(io.BytesIO(blobs[1])).convert("RGB")
                await self.q.async_q.put((image, label))
            except aiohttp.ClientError as e:
                logging.debug(e)
            except (asyncio.TimeoutError, TimeoutError):
                # asyncio.TimeoutError is a separate class from the builtin before Python 3.11.
                pass
            except Exception as e:
                logging.exception(e)

    def init_worker(self):
        """Start the background event loop thread and schedule the download workers.

        The HTTP session and the queue are created from a coroutine running on
        the new loop, so neither needs the deprecated/removed ``loop=`` keyword
        and both bind to the loop that will actually drive them.
        """
        loop = asyncio.new_event_loop()

        def loop_in_thread(loop):
            asyncio.set_event_loop(loop)
            loop.run_forever()

        self.loop_thread = Thread(target=loop_in_thread, args=(loop,), daemon=True)
        self.loop_thread.start()

        async def start_workers():
            # limit=0 is carried over from the original connector configuration.
            session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=0), raise_for_status=True)
            self.q = Queue(self.concurrency)

            # Spin up workers
            self._tasks = [loop.create_task(self.run(loop, session)) for _ in range(self.concurrency)]

        # Block until the loop thread has finished the setup; setup errors surface here.
        asyncio.run_coroutine_threadsafe(start_workers(), loop).result()
        self.worker_initialized = True

    def __iter__(self):
        """Yield ``(image, label)`` pairs forever, starting the workers on first use."""
        if not self.worker_initialized:
            self.init_worker()

        while True:
            # Blocks the calling thread until a worker has enqueued a pair.
            image, label = self.q.sync_q.get()
            if self.transform is not None:
                image, label = self.transform((image, label))

            yield image, label


def get_async_dataloaders(gcs_data_root: str, batch_size: int = 8, transform: Callable = None,
                          test_ratio: float = 0.1, num_workers: int = 8) -> Dict[str, torch.utils.data.DataLoader]:
    """Build the streaming ``"train"`` and ``"test"`` data loaders for one dataset root.

    Item names are listed with ``load_items_gcs`` under ``<gcs_data_root>/images``
    and divided by ``consistent_train_test_split``. Each part is wrapped in an
    ``AsyncImageDataset`` (concurrency 128) and then in a ``DataLoader``.

    Args:
        gcs_data_root: Root location of the dataset, passed to both datasets.
        batch_size: Batch size given to both loaders.
        transform: Optional transform forwarded to both datasets.
        test_ratio: Value forwarded to ``consistent_train_test_split``.
        num_workers: Number of ``DataLoader`` worker processes per loader.

    Returns:
        Dict with the keys ``"train"`` and ``"test"``, each mapping to a ``DataLoader``.
    """
    # Async Streaming
    all_items = load_items_gcs(os.path.join(gcs_data_root, "images"))
    train_indices, test_indices = consistent_train_test_split(all_items, test_ratio)

    datasets = {}
    for split_name, split_indices in (("train", train_indices), ("test", test_indices)):
        # Each index exposes .item() -- presumably a tensor/array scalar; confirm against the splitter.
        split_items = [all_items[position.item()] for position in split_indices]
        datasets[split_name] = AsyncImageDataset(gcs_data_root, split_items, transform=transform, concurrency=128)

    return {
        split_name: torch.utils.data.DataLoader(dataset, batch_size=batch_size, worker_init_fn=worker_init_fn,
                                                num_workers=num_workers)
        for split_name, dataset in datasets.items()
    }