
Python PIL example: Xception (snippet)


Modules used in the program:

  • import PIL
  • import abc
  • import base64
  • import numpy as np
  • import tempfile
  • import json
  • import os
  • import tensorflow as tf


Python PIL example: Xception

import tensorflow as tf
from tensorflow import keras
import os
import json
import tempfile
import numpy as np
import base64
import abc
from urllib.parse import urlparse, parse_qs
import PIL
from PIL import Image as pil_image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Use tf.keras's preprocessing utilities so the generators match the tf.keras
# model built below (mixing standalone Keras with tf.keras sessions can fail)
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array

from rafiki.model import BaseModel, InvalidModelParamsException, test_model_class, \
                        IntegerKnob, FixedKnob, FloatKnob, CategoricalKnob, dataset_utils
from rafiki.constants import TaskType, ModelDependency


class TfXception(BaseModel):

    @staticmethod
    def get_knob_config():
        return {
            'epochs': FixedKnob(20),
            'learning_rate': FloatKnob(1e-4, 1e-3, is_exp=True),
            'batch_size': CategoricalKnob([8]),
        }
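
    # Knob semantics (an assumption from rafiki's knob types): 'epochs' and
    # 'batch_size' are held fixed, while 'learning_rate' is searched over
    # [1e-4, 1e-3], with is_exp=True suggesting log-scale sampling.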

    def __init__(self, **knobs):
        super().__init__(**knobs)
        self._knobs = knobs
        # TF 1.x-style graph/session management; allow_growth stops TensorFlow
        # from reserving all GPU memory up front
        self._graph = tf.Graph()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self._sess = tf.Session(graph=self._graph, config=config)


    def crop_generator(self, batches, crop_length):
        '''
        Take as input a Keras ImageGen (Iterator) and generate random
        crops from the image batches generated by the original iterator
        '''
        def random_crop(img, random_crop_size):
            # Note: image_data_format is 'channel_last'
            assert img.shape[2] == 3
            height, width = img.shape[0], img.shape[1]
            dy, dx = random_crop_size
            x = np.random.randint(0, width - dx + 1)
            y = np.random.randint(0, height - dy + 1)
            return img[y:(y+dy), x:(x+dx), :]

        while True:
            batch_x, batch_y = next(batches)
            batch_crops = np.zeros((batch_x.shape[0], crop_length, crop_length, 3))
            for i in range(batch_x.shape[0]):
                batch_crops[i] = random_crop(batch_x[i], (crop_length, crop_length))
            yield (batch_crops, batch_y)
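
    # Illustrative use of crop_generator (the names below are assumptions for
    # this sketch, not part of the snippet):
    #   batches = train_datagen.flow_from_directory(path, target_size=(320, 320), batch_size=8)
    #   crops, labels = next(self.crop_generator(batches, 299))
    #   # crops.shape == (8, 299, 299, 3); labels pass through unchanged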

    def random_scale_jitter(self, minimum, maximum):
        # Pick one square size in [minimum, maximum); note this is sampled once
        # when the generator is built, not per batch. Scaling that preserves
        # aspect ratio is still a TODO here.
        random_scale = np.random.randint(minimum, maximum)
        return (random_scale, random_scale)

    def train(self, dataset_uri):
        ep = self._knobs.get('epochs')
        bs = self._knobs.get('batch_size')
        # Assumes dataset_uri is a local directory with one subdirectory per class
        num_classes = len(os.listdir(dataset_uri))
        total_number_of_images = 0
        for category in os.listdir(dataset_uri):
            total_number_of_images += len(os.listdir(os.path.join(dataset_uri, category)))

        train_datagen = ImageDataGenerator(
            rescale=1. / 255,
            width_shift_range=0.2,
            height_shift_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True)

        train_generator = train_datagen.flow_from_directory(
            dataset_uri,
            # Resize to a randomly chosen square size (one size for the whole
            # run); crop_generator then takes random 299x299 crops per image
            target_size=self.random_scale_jitter(300, 400),
            batch_size=bs,
            class_mode='categorical')
        train_crops = self.crop_generator(train_generator, 299)

        with self._graph.as_default():
            # Persist the class-to-index mapping so predict() can recover labels
            np.save('class_indices.npy', train_generator.class_indices)
            self._model = self._build_model(num_classes)
            with self._sess.as_default():
                self._model.fit_generator(
                    train_crops,
                    steps_per_epoch=total_number_of_images // bs,
                    epochs=ep
                )

    def evaluate(self, dataset_uri):

        test_datagen = ImageDataGenerator(rescale=1. / 255)

        # Evaluate the whole validation set as one batch: count every file
        # under the per-class subdirectories of dataset_uri
        num_val_images = sum(len(files) for _, _, files in list(os.walk(dataset_uri))[1:])
        validation_generator = test_datagen.flow_from_directory(
            dataset_uri,
            target_size=(299, 299),
            batch_size=num_val_images,
            class_mode='categorical')

        images, classes = next(validation_generator)

        with self._graph.as_default():
            with self._sess.as_default():
                (loss, accuracy) = self._model.evaluate(images, classes)
        return accuracy


    def predict(self, queries):
        # Each query is an image as a (H, W, 3) array; resize and rescale to
        # match the network's 299x299 input
        images = [pil_image.fromarray(np.asarray(x, dtype=np.uint8)) for x in queries]
        images = np.asarray([np.asarray(x.resize((299, 299)), dtype=np.float32) / 255 for x in images])

        with self._graph.as_default():
            with self._sess.as_default():
                probs = self._model.predict(images)

        # Recover the class-name mapping saved during train(); allow_pickle is
        # needed on newer numpy to load a pickled dict
        class_indices = np.load('class_indices.npy', allow_pickle=True).item()
        class_indices = {v: k for k, v in class_indices.items()}

        # For each query, return (class name, probability) pairs sorted by
        # descending probability
        predictions = []
        for query_probs in probs:
            top_indexes = np.argsort(query_probs)[::-1]
            predictions.append([(class_indices[idx], float(query_probs[idx])) for idx in top_indexes])
        return predictions
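
    # Expected shapes (illustrative; the class names are assumptions):
    #   queries: list of (H, W, 3) uint8 image arrays
    #   result:  one sorted list per query, e.g.
    #            [[('cat', 0.93), ('dog', 0.07)], [('dog', 0.88), ('cat', 0.12)]]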

    def destroy(self):
        self._sess.close()

    def dump_parameters(self):
        params = {}

        # Save model parameters
        with tempfile.NamedTemporaryFile() as tmp:
            # Save whole model to temp h5 file
            with self._graph.as_default():
                with self._sess.as_default():
                    self._model.save(tmp.name)

            # Read from temp h5 file & encode it to base64 string
            with open(tmp.name, 'rb') as f:
                h5_model_bytes = f.read()

            params['h5_model_base64'] = base64.b64encode(h5_model_bytes).decode('utf-8')

        return params

    def load_parameters(self, params):
        # Load model parameters
        h5_model_base64 = params.get('h5_model_base64', None)
        if h5_model_base64 is None:
            raise InvalidModelParamsException()

        with tempfile.NamedTemporaryFile() as tmp:
            # Convert back to bytes & write to temp file
            h5_model_bytes = base64.b64decode(h5_model_base64.encode('utf-8'))
            with open(tmp.name, 'wb') as f:
                f.write(h5_model_bytes)

            # Load model from temp file
            with self._graph.as_default():
                with self._sess.as_default():
                    self._model = keras.models.load_model(tmp.name)
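
    # Illustrative dump/load round trip (assumed usage, not from this snippet):
    #   params = trained.dump_parameters()   # {'h5_model_base64': '...'}
    #   fresh = TfXception(**knobs)
    #   fresh.load_parameters(params)        # restores the trained Keras model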

    def _build_model(self, num_classes):
        learning_rate = self._knobs.get('learning_rate')
        model = keras.applications.Xception(
            include_top=True,
            input_shape=(299, 299, 3),
            weights=None,
            classes=num_classes
        )

        # To fine-tune a pretrained backbone instead, swap the classification
        # head along these lines:
        #   model.layers.pop()
        #   predictions = keras.layers.Dense(num_classes, activation='softmax')(model.layers[-1].output)
        #   model = keras.models.Model(inputs=[model.input], outputs=[predictions])

        model.compile(
            #optimizer=keras.optimizers.Adam(lr=learning_rate),
            optimizer=keras.optimizers.SGD(lr=learning_rate, momentum=0.9, nesterov=True),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        return model
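
To try the model end to end, rafiki provides a test harness, test_model_class, which the snippet already imports. The block below is a minimal sketch: the dataset paths, the TensorFlow version pin, and the exact keyword arguments are assumptions for illustration, so check them against the rafiki version you have installed.

if __name__ == '__main__':
    # All paths and the version pin below are placeholders, not from the snippet
    test_model_class(
        model_file_path=__file__,
        model_class='TfXception',
        task=TaskType.IMAGE_CLASSIFICATION,
        dependencies={ModelDependency.TENSORFLOW: '1.12.0'},
        train_dataset_uri='data/train',   # one subdirectory per class
        test_dataset_uri='data/val',      # same layout as the training set
        queries=[np.random.randint(0, 256, (299, 299, 3), dtype=np.uint8)]
    )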