HOME/Articles/

pil example unshredder (snippet)

Article Outline

Python PIL example: 'unshredder'

Functions in program:

  • def main():

Modules used in program:

  • import PIL.Image
  • import itertools
  • import colorsys
  • import argparse

python unshredder

Python pil example: unshredder

# instagram engineering challenge: unshredder 
# http://instagram-engineering.tumblr.com/post/12651721845/instagram-engineering-challenge-the-unshredder
# checked with Python 2.7.2 and PIL 1.1.7 on Ubuntu Oneiric

import argparse
import colorsys
import itertools
import PIL.Image


def main():
    """Command-line entry point: parse options, unshred the image, save it.

    The input image named on the command line is opened with PIL, the shred
    width is either guessed by ``ShredSize`` (``--guess-width``) or taken
    from ``--shred-width``, the shreds are reordered by ``Unshredder`` and
    the result is saved to ``--out-file`` using ``--out-format``.

    An unknown ``--closeness-method`` or a shred width that is not greater
    than 1 is reported through ``parser.error``, which prints the usage
    message and exits with status 2.
    """
    # Every attribute of Closeness whose name starts with "by_" is a
    # selectable closeness method.
    methods = [n for n in dir(Closeness) if n.startswith("by_")]
    parser = argparse.ArgumentParser(
        description="Image Unshredder",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-o", "--out-file", default="unshredded.png",
                        help="output image file name")
    parser.add_argument("-f", "--out-format", default="PNG",
                        help="output image format")
    parser.add_argument("-s", "--shred-width", type=int, default=32,
                        help="width of each shredded of input image")
    parser.add_argument("-g", "--guess-width", action="store_true",
                        default=False,
                        help="automatically guess width of each shred: " +
                        "SHRED_WIDTH will be ignored")
    parser.add_argument("-m", "--closeness-method", default="by_hsv_hist",
                        help="method for evaluate closeness: select one of " +
                        str(methods))
    parser.add_argument("imagefile", help="input shredded image file name")

    args = parser.parse_args()

    # Reject unknown names up front instead of letting getattr() raise an
    # AttributeError (or silently pick up a non-"by_" attribute).
    if args.closeness_method not in methods:
        parser.error("unknown closeness method %r: select one of %s"
                     % (args.closeness_method, methods))
    evaluator = getattr(Closeness, args.closeness_method)

    in_image = PIL.Image.open(args.imagefile)
    if args.guess_width:
        shred_width = ShredSize(in_image, evaluator).guess()
    else:
        shred_width = args.shred_width

    # Explicit check rather than `assert`: asserts are removed under
    # `python -O`, and this value comes from user input or from guessing.
    if shred_width <= 1:
        parser.error("shred width must be greater than 1 (got %d)"
                     % shred_width)

    out_image = Unshredder(in_image, shred_width, evaluator).run()
    out_image.save(args.out_file, args.out_format)


class Part(object):
    """A run of one or more equally wide image crops placed side by side.

    ``crops`` is the left-to-right list of crops, ``crop_width`` the width
    every crop must have.  ``lefts`` holds the pixels of column 0 of the
    first crop and ``rights`` the pixels of the last column of the last
    crop; both are cached so that joining parts never re-reads pixels.
    """

    def __init__(self, crops, crop_width, lefts=None, rights=None):
        assert len(crops) > 0
        first_crop = crops[0]
        last_crop = crops[-1]
        # All crops must share the given width and the first crop's height.
        for crop in crops:
            assert crop.size[0] == crop_width
        for crop in crops:
            assert crop.size[1] == first_crop.size[1]

        self.crops = crops
        self.crop_width = crop_width

        # Read the outer edge columns only when the caller did not hand
        # over already cached (non-empty) ones.
        if not lefts:
            lefts = [first_crop.getpixel((0, y))
                     for y in range(first_crop.size[1])]
        if not rights:
            last_x = crop_width - 1
            rights = [last_crop.getpixel((last_x, y))
                      for y in range(last_crop.size[1])]
        self.lefts = lefts
        self.rights = rights

    def join(self, right):
        """Return a new Part made of this part followed by ``right``."""
        assert self.crop_width == right.crop_width
        combined = self.crops + right.crops
        # The outer edges of the result are our left edge and right's
        # right edge, so the cached columns are passed straight through.
        return Part(combined, self.crop_width, self.lefts, right.rights)

    def to_image(self):
        """Paste all crops left to right into one new PIL image."""
        first_crop = self.crops[0]
        size = (sum(crop.size[0] for crop in self.crops), first_crop.size[1])
        image = PIL.Image.new(first_crop.mode, size)
        for index, crop in enumerate(self.crops):
            image.paste(crop, (self.crop_width * index, 0))
        return image


class Unshredder(object):
    """Reassemble an image that was cut into equally wide vertical shreds.

    ``evaluator`` is a callable ``(left_colors, right_colors, height)``
    returning a number; smaller values are treated as a better match.
    When it is omitted, ``Closeness.by_hsv_hist`` is used.
    """

    def __init__(self, in_image, shred_width, evaluator=None):
        self.in_image = in_image
        self.shred_width = shred_width
        self.height = in_image.size[1]
        # Integer division: a trailing strip narrower than shred_width is
        # not counted as a shred.
        self.shred_count = in_image.size[0] // shred_width
        self.evaluator = evaluator or Closeness.by_hsv_hist

    def run(self):
        """Greedily join the closest pair of parts until one is left.

        Returns the resulting part rendered as an image.
        """
        step = self.shred_width
        # PIL crop boxes are (left, top, right, bottom) with the right and
        # bottom edges exclusive.
        parts = [
            Part([self.in_image.crop((step * index, 0,
                                      step * (index + 1), self.height))],
                 step)
            for index in range(self.shred_count)
        ]
        while len(parts) > 1:
            left, right = self.pick_nearest(parts)
            remaining = [part for part in parts
                         if part != left and part != right]
            remaining.append(left.join(right))
            parts = remaining
        assert len(parts) == 1
        return parts[0].to_image()

    def pick_nearest(self, parts):
        """Return the ordered pair ``(left, right)`` with the lowest score.

        Every unordered pair is scored in both orientations; on ties the
        candidate generated first wins.
        """
        assert len(parts) > 1
        candidates = [
            (self.evaluate(lhs, rhs), lhs, rhs)
            for a, b in itertools.combinations(parts, 2)
            for lhs, rhs in ((a, b), (b, a))
        ]
        assert len(candidates) >= 2
        best = min(candidates, key=lambda entry: entry[0])
        return (best[1], best[2])

    def evaluate(self, left, right):
        """Score how well ``left``'s right edge meets ``right``'s left edge."""
        return self.evaluator(left.rights, right.lefts, self.height)


class Closeness(object):
    """Collection of closeness methods.

    Every ``by_*`` method takes two equally indexed sequences of pixel
    colors (``lcolors``, ``rcolors``) plus ``size`` and returns a number;
    callers in this file treat smaller values as a closer match.  ``size``
    is accepted for a uniform signature but not used by any method here.
    """

    @classmethod
    def by_rgb_ave(cls, lcolors, rcolors, size):
        """Sum of squared RGB differences (author's note: bad result)."""
        total = 0
        for lp, rp in zip(lcolors, rcolors):
            total += Math.vect_dist(lp, rp, 3)
        return total

    @classmethod
    def by_hsv_ave(cls, lcolors, rcolors, size):
        """Sum of squared HSV differences (author's note: good result).

        Author's remark: lighting value similarity may be important.
        """
        total = 0
        for lp, rp in zip(lcolors, rcolors):
            total += Math.vect_dist(Math.hsv(lp), Math.hsv(rp), 3)
        return total

    @classmethod
    def by_rgb_med(cls, lcolors, rcolors, size):
        """Median squared RGB difference (author's note: good result).

        Author's remark: may ignore untypical points.
        """
        ordered = sorted(Math.vect_dist(lp, rp, 3)
                         for lp, rp in zip(lcolors, rcolors))
        # Upper median for even lengths.
        return ordered[len(ordered) // 2]

    @classmethod
    def by_rgb_hist(cls, lcolors, rcolors, size):
        """Distance between 64-bin RGB histograms (author's note: bad result)."""
        left_vectors = list(map(Math.bs2fs, lcolors))
        right_vectors = list(map(Math.bs2fs, rcolors))
        return Math.hist_dist(left_vectors, right_vectors, 64, 3)

    @classmethod
    def by_hsv_hist(cls, lcolors, rcolors, size):
        """Distance between 64-bin HSV histograms.

        Author's note: good result (16, 64).
        """
        left_vectors = list(map(Math.hsv, lcolors))
        right_vectors = list(map(Math.hsv, rcolors))
        return Math.hist_dist(left_vectors, right_vectors, 64, 3)


class ShredSize(object):
    """Guess the shred width of a shredded image.

    All pixel columns of the image are read once in the constructor.
    ``evaluator`` scores two columns; when omitted,
    ``Closeness.by_hsv_hist`` is used.
    """

    def __init__(self, image, evaluator=None):
        self.image = image
        self.width = image.size[0]
        self.height = image.size[1]
        # lines[x] is the list of pixels of column x, top to bottom.
        columns = []
        for x in range(image.size[0]):
            columns.append([self.image.getpixel((x, y))
                            for y in range(image.size[1])])
        self.lines = columns
        self.evaluator = evaluator or Closeness.by_hsv_hist

    def guess(self):
        """Return the guessed shred width.

        Column boundaries are ranked by the evaluator score of the two
        adjacent columns, highest first.  Starting from the top boundary's
        x position, the GCD with each following boundary position is taken
        until it would fall below the minimum shred size.
        """
        minshred = max(self.width // 100, 2)
        ranked = sorted(
            ((self.evaluate(rx - 1, rx), rx) for rx in range(1, self.width)),
            key=lambda entry: entry[0], reverse=True)
        assert len(ranked) > 0
        size = ranked[0][1]
        for _, boundary in ranked[1:]:
            candidate = Math.gcd(size, boundary)
            if candidate < minshred:
                break
            size = candidate
        return size

    def evaluate(self, lx, rx):
        """Score column ``lx`` against column ``rx`` with the evaluator."""
        return self.evaluator(self.lines[lx], self.lines[rx], self.height)


class Math(object):
    """Miscellaneous math helpers used by the closeness and size guessing."""

    @classmethod
    def bs2fs(cls, bs):
        """Convert a tuple of byte values (0..255) to floats in 0.0..1.0."""
        assert all(0 <= v and v < 256 for v in bs)
        return tuple(v / 255.0 for v in bs)

    @classmethod
    def hsv(cls, rgba):
        """Return the HSV triple of a pixel given as byte values.

        Only the first three components are used, so a fourth (alpha)
        component is ignored.
        """
        assert all(0 <= v and v < 256 for v in rgba)
        frgb = cls.bs2fs(rgba)
        return colorsys.rgb_to_hsv(frgb[0], frgb[1], frgb[2])

    @classmethod
    def log2(cls, n):
        """Return the smallest ``r >= 0`` such that ``2 ** r >= n``.

        Both 0 and 1 therefore map to 0.
        """
        assert n >= 0
        r = 0
        p2 = 1
        while p2 < n:
            r += 1
            p2 = p2 * 2
        return r

    @classmethod
    def gcd(cls, a, b):
        """Return the greatest common divisor of two non-negative integers.

        ``gcd(0, n)`` and ``gcd(n, 0)`` return ``n``; ``gcd(0, 0)`` is 0.
        The modulo form of Euclid's algorithm is used: the earlier
        subtraction loop never terminated when exactly one argument was 0
        and needed O(max(a, b)) steps otherwise.
        """
        assert a >= 0 and b >= 0
        while b:
            a, b = b, a % b
        return a

    @classmethod
    def hist(cls, vects, histsize, paramsize):
        """Build one histogram per vector component.

        Each of the first ``paramsize`` components of every vector is
        scaled by ``histsize - 1`` and truncated to pick its bin, so
        components are expected to lie in 0.0..1.0.  Returns a list of
        ``paramsize`` lists with ``histsize`` counts each.
        """
        m = histsize - 1
        hist = [[0] * histsize for k in range(paramsize)]
        for v in vects:
            for k in range(paramsize):
                hist[k][int(v[k] * m)] += 1
        return hist

    @classmethod
    def vect_dist(cls, va, vb, size):
        """Return the squared Euclidean distance of the first ``size`` items."""
        return sum((va[i] - vb[i]) ** 2 for i in range(size))

    @classmethod
    def vect_log_dist(cls, va, vb, size):
        """Like ``vect_dist`` but compares ``log2`` of each component."""
        return sum((cls.log2(va[i]) - cls.log2(vb[i])) ** 2
                   for i in range(size))

    @classmethod
    def hist_dist(cls, va, vb, histsize, paramsize):
        """Return the summed log-distance between per-component histograms.

        Histograms of both vector lists are built with ``hist`` and
        compared component by component with ``vect_log_dist``.
        """
        ha = cls.hist(va, histsize, paramsize)
        hb = cls.hist(vb, histsize, paramsize)
        return sum(cls.vect_log_dist(ha[k], hb[k], histsize)
                   for k in range(paramsize))


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()