HOME/Articles/

spline_demo

Article Outline

Example Python program spline_demo.py

Modules

  • import scipy.interpolate as sint
  • import scipy.interpolate as si
  • import numpy as np
  • from collections import namedtuple
  • from functools import reduce
  • import matplotlib

  • import matplotlib.pyplot as plt

Classes

  • class SplineFitter(object):
  • class SplineCurve(object):

Methods

  • def __init__(self, ax, pix_err=1):
  • def clear(self):
  • def connect_sf(self):
  • def disconnect_sf(self):
  • def click_event(self, event):
  • def remove_pt(self, loc):
  • def redraw(self):
  • def points(self):
  • def SplineCurve(self):
  • def _get_spline(cls, points, pix_err=2, need_sort=True, **kwargs):
  • def from_pts(cls, new_pts, **kwargs):
  • def __init__(self, tck):
  • def write_to_hdf(self, parent_group, name=None):
  • def circ(self):
  • def cntr(self):
  • def th_offset(self):
  • def tck0(self):
  • def tck1(self):
  • def tck2(self):
  • def q_phi_to_xy(self, q, phi, cross=None):

Code

Python example

import scipy.interpolate as sint
import scipy.interpolate as si
import numpy as np
from collections import namedtuple
from functools import reduce

# uncomment this to set the backend
# import matplotlib
# matplotlib.use('Qt4Agg')
import matplotlib.pyplot as plt


class SplineFitter(object):
    '''
    Collects mouse clicks on a matplotlib axes and draws a closed
    spline through them.

    Mouse button 1 adds a point, mouse button 3 removes the point
    nearest to the click, and any click made while the shift key is
    held clears every point.  The fitted curve is only drawn once more
    than five points have been collected.
    '''

    def __init__(self, ax, pix_err=1):
        '''
        :param ax: the matplotlib axes to draw on and listen to
        :param pix_err: error tolerance handed to
            `SplineCurve.from_pts` when fitting the curve
        '''
        self.canvas = ax.get_figure().canvas
        # id of the registered click callback, None while disconnected
        self.cid = None
        # clicked points as a list of (x, y) tuples in data coordinates
        self.pt_lst = []
        # artist showing the clicked points
        self.pt_plot = ax.plot([], [], marker='o',
                               linestyle='none', zorder=5)[0]
        # artist showing the fitted spline
        self.sp_plot = ax.plot([], [], lw=3, color='r')[0]
        self.pix_err = pix_err
        self.connect_sf()

    def clear(self):
        '''Clears the points'''
        self.pt_lst = []
        self.redraw()

    def connect_sf(self):
        '''Start listening for mouse clicks (no-op if already connected)'''
        if self.cid is None:
            self.cid = self.canvas.mpl_connect('button_press_event',
                                               self.click_event)

    def disconnect_sf(self):
        '''Stop listening for mouse clicks (no-op if not connected)'''
        if self.cid is not None:
            self.canvas.mpl_disconnect(self.cid)
            self.cid = None

    def click_event(self, event):
        ''' Extracts locations from the user'''
        if event.key == 'shift':
            self.clear()
            return
        # xdata / ydata are None when the click has no data coordinates
        if event.xdata is None or event.ydata is None:
            return
        if event.button == 1:
            self.pt_lst.append((event.xdata, event.ydata))
        elif event.button == 3:
            self.remove_pt((event.xdata, event.ydata))
        # the last handled event is kept on the instance; nothing in
        # this class reads it back
        self.ev = event
        self.redraw()

    def remove_pt(self, loc):
        '''
        Removes the stored point closest to `loc`.

        :param loc: an (x, y) pair in data coordinates

        Does nothing when no points are stored.
        '''
        if not self.pt_lst:
            return
        # squared distance has the same minimiser as the distance itself
        nearest = min(
            range(len(self.pt_lst)),
            key=lambda j: ((self.pt_lst[j][0] - loc[0]) ** 2 +
                           (self.pt_lst[j][1] - loc[1]) ** 2))
        self.pt_lst.pop(nearest)

    def redraw(self):
        '''Refreshes the point markers and the fitted curve'''
        if len(self.pt_lst) > 5:
            curve = SplineCurve.from_pts(self.pt_lst, pix_err=self.pix_err)
            xs, ys = curve.q_phi_to_xy(0, np.linspace(0, 2 * np.pi, 1000))
            center = curve.cntr
            self.sp_plot.set_xdata(xs)
            self.sp_plot.set_ydata(ys)
            # keep the stored points ordered by angle around the curve
            # center
            self.pt_lst.sort(key=lambda p:
                             np.arctan2(p[1] - center[1], p[0] - center[0]))
        else:
            self.sp_plot.set_xdata([])
            self.sp_plot.set_ydata([])
        if self.pt_lst:
            x, y = zip(*self.pt_lst)
        else:
            x, y = [], []
        self.pt_plot.set_xdata(x)
        self.pt_plot.set_ydata(y)

        self.canvas.draw_idle()

    @property
    def points(self):
        '''Returns the clicked points in the format the rest of the
        code expects: a 2xN array with x in row 0 and y in row 1.

        With no points stored an empty 2x0 array is returned
        (`np.vstack` refuses an empty list).'''
        if not self.pt_lst:
            return np.empty((2, 0))
        return np.vstack(self.pt_lst).T

    @property
    def SplineCurve(self):
        '''A `SplineCurve` freshly fitted to the current points'''
        curve = SplineCurve.from_pts(self.pt_lst, pix_err=self.pix_err)
        return curve


class TooFewPointsException(ValueError):
    '''Raised when fewer than five points are handed to the spline fit'''


class SplineCurve(object):
    '''
    A class that wraps the scipy.interpolate objects

    Holds the `tck` tuple of a closed (periodic) parametric spline and
    offers helpers to evaluate it.  Use `from_pts` to build an instance
    from raw points.
    '''
    @classmethod
    def _get_spline(cls, points, pix_err=2, need_sort=True, **kwargs):
        '''
        Returns a closed spline for the points handed in.
        Input is assumed to be a (2xN) array

        =====
        input
        =====

        :param points: the points to fit the spline to
        :type points: a 2xN ndarray or a list of len =2 tuples

        :param pix_err: the error is finding the spline in pixels
        :param need_sort: if the points need to be sorted
            or should be processed as-is

        Any further keyword arguments are accepted and ignored.

        =====
        output
        =====
        tck
           The return data from the spline fitting

        :raises TooFewPointsException: if fewer than 5 points are given
        '''
        if isinstance(points, np.ndarray):
            # a 2xN array: materialise the (x, y) columns as a real list
            # (zip alone is a one-shot iterator with no len/sort/append)
            pt_lst = list(zip(*points))
        else:
            # make a copy of the list so the caller's list is untouched
            pt_lst = list(points)
        if len(pt_lst) < 5:
            raise TooFewPointsException("not enough points")

        if need_sort:
            # centroid in floating point, so integer input divides cleanly
            cx, cy = np.mean(np.asarray(pt_lst, dtype=float), axis=0)
            # sort the list by angle around center
            pt_lst.sort(key=lambda p: np.arctan2(p[1] - cy, p[0] - cx))
        # add first point to end because it is periodic (makes the
        # interpolation code happy)
        pt_lst.append(pt_lst[0])
        # make array for handing in to spline fitting
        pt_array = np.vstack(pt_lst).T
        # do spline fitting; the smoothing factor scales with the number
        # of points (including the repeated one) and pix_err squared
        tck, u = si.splprep(pt_array, s=len(pt_lst) * (pix_err ** 2), per=True)
        return tck

    @classmethod
    def from_pts(cls, new_pts, **kwargs):
        '''
        Fits a spline to `new_pts` and wraps it in a new instance.

        :param new_pts: points in any form `_get_spline` accepts
        :param kwargs: forwarded to `_get_spline`

        The points handed in are kept on the instance as `raw_pts`.
        '''
        tck = cls._get_spline(new_pts, **kwargs)
        this = cls(tck)
        this.raw_pts = new_pts
        return this

    def __init__(self, tck):
        '''Use `from_pts` class method to construct instance
        '''
        self.tck = tck
        # lazily computed caches behind the cntr / circ / th_offset
        # properties
        self._cntr = None
        self._circ = None
        self._th_offset = None

    def write_to_hdf(self, parent_group, name=None):
        '''
        Writes out the essential data (spline of central curve) to hdf file.

        :param parent_group: object providing `create_group` and `attrs`
            (presumably an h5py group -- confirm against callers)
        :param name: if given, the data goes into a new sub-group of
            that name, otherwise onto `parent_group` itself
        '''
        if name is not None:
            curve_group = parent_group.create_group(name)
        else:
            curve_group = parent_group
        curve_group.attrs['tck0'] = self.tck[0]
        curve_group.attrs['tck1'] = np.vstack(self.tck[1])
        curve_group.attrs['tck2'] = self.tck[2]

    @property
    def circ(self):
        '''returns a rough estimate of the circumference

        Computed once as the length of a 1000 sample polyline along the
        spline and then cached.'''
        if self._circ is None:
            new_pts = si.splev(np.linspace(0, 1, 1000), self.tck, ext=2)
            self._circ = np.sum(np.sqrt(np.sum(np.diff(new_pts, axis=1) ** 2,
                                               axis=0)))
        return self._circ

    @property
    def cntr(self):
        '''returns a rough estimate of the center

        Computed once as the mean of 1000 samples along the spline and
        then cached.'''
        if self._cntr is None:
            new_pts = si.splev(np.linspace(0, 1, 1000), self.tck, ext=2)
            self._cntr = np.mean(new_pts, 1)
        return self._cntr

    @property
    def th_offset(self):
        """
        The angle, as given by `arctan2(y, x)`, of the curve point at
        `phi=0` relative to `cntr`.  Cached after the first access.
        """
        if self._th_offset is None:
            x, y = self.q_phi_to_xy(0, 0) - self.cntr.reshape(2, 1)
            self._th_offset = np.arctan2(y, x)
        return self._th_offset

    @property
    def tck0(self):
        '''First element of the `tck` tuple'''
        return self.tck[0]

    @property
    def tck1(self):
        '''Second element of the `tck` tuple'''
        return self.tck[1]

    @property
    def tck2(self):
        '''Third element of the `tck` tuple'''
        return self.tck[2]

    def q_phi_to_xy(self, q, phi, cross=None):
        '''Converts q, phi pairs -> x, y pairs.  All other code that
        does this should move to using this so that there is minimal
        breakage when we change over to using additive q instead of
        multiplicative

        Each point is the spline evaluated at `phi`, displaced by `q`
        along the unit normal `(dy, -dx) / |(dx, dy)|`.

        :param q: scalar or array of offsets along the normal
        :param phi: scalar or array of angles; taken modulo 2 pi and
            mapped onto the spline parameter range
        :param cross: `False` pairs q and phi element-wise, `True`
            evaluates every q against all of phi, `None` picks
            element-wise when the shapes match or q is a single value
            and the cross product otherwise

        :returns: element-wise: the x and y arrays stacked with
            `np.vstack`; cross: a two element list `[xs, ys]`, each a
            tuple holding one array per value of q

        :raises ValueError: if `cross` is `False` and the recorded
            shapes of q and phi differ
        '''
        # make sure data is arrays
        q = np.asarray(q)
        # convert real units -> interpolation units
        phi = np.mod(np.asarray(phi), 2 * np.pi) / (2 * np.pi)
        # get the shapes; scalars and length one inputs are recorded as
        # None
        q_shape, phi_shape = [
            arr.shape if (arr.ndim > 0 and len(arr) > 1) else None
            for arr in (q, phi)]

        # flatten everything
        q = q.ravel()
        phi = phi.ravel()
        # sanity checks on shapes
        if cross is False:
            if phi_shape != q_shape:
                raise ValueError("q and phi must have same" +
                                 " dimensions to broadcast")
        if cross is None:
            if ((phi_shape is not None) and (q_shape is not None)
                    and (phi_shape == q_shape)):
                cross = False
            elif q_shape is None:
                cross = False
                # single offset: use it as a plain scalar
                q = q[0]
            else:
                cross = True

        x, y = si.splev(phi, self.tck, ext=2)
        dx, dy = si.splev(phi, self.tck, der=1, ext=2)
        norm = np.sqrt(dx ** 2 + dy ** 2)
        nx, ny = dy / norm, -dx / norm

        def _restore(flat):
            # give the result the shape phi came in with; inputs with no
            # recorded shape stay flat
            if phi_shape is None:
                return flat
            return flat.reshape(phi_shape)

        if cross:
            # one (x, y) pair per q, transposed into [xs, ys]; built as
            # a list so the result can be indexed and re-iterated
            data_out = list(zip(*[(_restore(x + q_ * nx),
                                   _restore(y + q_ * ny))
                                  for q_ in q]))
        else:
            data_out = np.vstack([_restore(x + q * nx),
                                  _restore(y + q * ny)])

        return data_out

# Demo: open an empty figure and attach a click-driven fitter to its axes.
fig, ax = plt.subplots()
# The second argument is the fitter's pix_err tolerance.
sp = SplineFitter(ax, pix_err=0.001)
# Interactive mode is switched on before the window is shown.
plt.ion()
plt.show()