"""
This class is largely created as a more versatile alternative to the Linear ND
interpolator, allowing the user to pass the interpolator an array of classes and then
perform interpolation between the results of a member function of that class. Currently
the name of the target interpolation function is passed as a string at initialisation.
TODO:
- Figure out what to do when out of bounds of interpolation range
- Implement for higher dimensions
- Create function to append points to the interpolator
"""
import numpy as np
import numpy.ma as ma
from scipy.ndimage import map_coordinates
from scipy.spatial import Delaunay
class UnstructuredInterpolator:
    """
    Linear interpolation between an unstructured set of data points.

    The values attached to the grid points may be plain numbers, numpy arrays
    (e.g. 2D templates that are themselves interpolated at ``eval_points``) or
    arbitrary objects, in which case a named member function of each object is
    called and the interpolation is performed between the returned values. The
    primary use case of the latter is to interpolate between the predictions of
    a set of machine learning algorithms or regular grid interpolators.

    When plain numbers are passed as values and no ``eval_points`` are given,
    the weighting is the barycentric one used for linear interpolation on a
    Delaunay triangulation.

    NOTE: points outside the convex hull of ``keys`` are not treated specially.
    ``Delaunay.find_simplex`` reports them as ``-1`` and that index is used
    as-is, i.e. it selects the last simplex of the triangulation.
    """

    def __init__(
        self,
        keys,
        values,
        function_name=None,
        remember_last=False,
        bounds=None,
        dtype=None,
    ):
        """
        Parameters
        ----------
        keys: ndarray
            Interpolation grid points, one row per point
        values: ndarray
            Interpolation values, one entry per grid point
        function_name: str
            Name of class member function to call in the case we are interpolating
            between class predictions, for numpy arrays leave blank. If the values
            are not numpy data and no name is given, ``__call__`` is used.
        remember_last: bool
            If True, the simplices found in the previous call are reused whenever
            the new points are identical to, or lie within the hull of, the
            previously used grid points
        bounds: array-like
            ``(lower, upper)`` pair for each template axis, used to convert
            ``eval_points`` from real coordinates into template bin numbers
        dtype: data-type
            Data type used for the stored values, bounds and scale
        """
        self.keys = keys
        # Pass dtype straight through (None lets numpy infer the type). A
        # truthiness test must not be used here: ``np.dtype`` instances of
        # non-structured types are falsy and would be silently ignored.
        self.values = np.array(values, dtype=dtype)
        self._n_dimensions = len(self.keys[0])

        # create an object with triangulation
        self._tri = Delaunay(self.keys)
        self._function_name = function_name

        # Values that are numpy arrays or numpy scalars are interpolated
        # directly, anything else is treated as an object to be called
        self._numpy_input = isinstance(
            self.values[0], (np.ndarray, np.floating, np.integer)
        )
        if not self._numpy_input and function_name is None:
            self._function_name = "__call__"

        self._remember = remember_last

        self._bounds = None
        self.scale = None
        if bounds is not None:
            bounds = np.array(bounds, dtype=dtype)
            self._bounds = bounds
            # Calculate the scaling factor to convert from bin number to real
            # coordinates for all axes
            table_shape = self.values[0].shape
            scale = [
                (bounds[i][1] - bounds[i][0]) / float(table_shape[i] - 1)
                for i in range(bounds.shape[0])
            ]
            self.scale = np.array(scale, dtype=dtype)

        self.reset()

    def reset(self):
        """
        Function used to reset some class values stored after previous event,
        also used as their initialisation in the init function
        """
        self._previous_v = None
        self._previous_m = None
        self._previous_shape = None
        self._previous_hull = None
        self._previous_points = None

    def __call__(self, points, eval_points=None):
        """
        Interpolate the stored values at the requested points.

        Parameters
        ----------
        points: array-like
            Position(s) in the space of ``keys`` to interpolate at. A single
            point may be given as a 1D sequence. Converted to ``float32``.
        eval_points: array-like or None
            Inputs at which each value is evaluated before interpolating
            (template coordinates or member function arguments). Converted to
            ``float32``; masked arrays keep their mask.

        Returns
        -------
        ndarray: interpolated values, first axis runs over ``points``
        """
        # Convert to a numpy array here in case we get a list
        points = np.array(points, dtype=np.float32)
        if eval_points is not None:
            # asanyarray accepts lists while leaving masked arrays intact
            eval_points = np.asanyarray(eval_points).astype(np.float32)
        if points.ndim == 1:
            points = points[np.newaxis, :]

        eval_shape = None if eval_points is None else eval_points.shape

        # First find simplexes that contain interpolated points
        if self._remember and self._previous_v is not None:
            # array_equal also compares shapes, so a different number of points
            # can neither broadcast to a false match nor raise
            if np.array_equal(points, self._previous_points):
                v, m = self._previous_v, self._previous_m
            else:
                self._previous_points = points
                if self._can_reuse_previous(points, eval_shape):
                    v, m = self._previous_v, self._previous_m
                else:
                    v, m = self._locate(points)
                    self._previous_hull = None
                    if eval_shape is not None:
                        self._previous_shape = eval_shape
        # If remember last is disabled we search our point space for every attempt
        else:
            v, m = self._locate(points)
            self._previous_points = points
            if eval_shape is not None:
                self._previous_shape = eval_shape

        # For each interpolated point, take the transform matrix and multiply it
        # by the vector p-r, where r=m[:,n,:] is one of the simplex vertices to
        # which the matrix m is related
        n_dim = self._n_dimensions
        b = np.einsum("ijk,ik->ij", m[:, :n_dim, :n_dim], points - m[:, n_dim, :])

        # `b` holds the weights for all but the last vertex of each simplex (an
        # n-D simplex has n+1 vertices); the weight of the last vertex follows
        # from the condition that the weights sum to 1
        w = np.c_[b, 1 - b.sum(axis=1)]

        if not self._numpy_input:
            selected_points = self._call_class_function(v, eval_points)
        elif eval_points is None:
            selected_points = self.values[v]
        else:
            selected_points = self._numpy_interpolation(v, eval_points)

        # Multiply point values by weight
        return np.einsum("ij...,ij...->i...", selected_points, w)

    def _locate(self, points):
        """Search the full triangulation for ``points`` and cache the result."""
        simplex = self._tri.find_simplex(points)
        # vertices and transform matrices of the simplex found for each point
        vertices = self._tri.simplices[simplex]
        transforms = self._tri.transform[simplex]
        self._previous_v = vertices
        self._previous_m = transforms
        return vertices, transforms

    def _can_reuse_previous(self, points, eval_shape):
        """Check whether the cached simplices may be applied to ``points``."""
        # The cached vertices/transforms hold one row per point, so they can
        # only be applied to the same number of points
        if len(points) != len(self._previous_v):
            return False
        if eval_shape is not None and eval_shape != self._previous_shape:
            return False
        if self._previous_hull is None:
            # Index the triangulation's own coordinate array so that this also
            # works when `keys` was given as a list
            previous_keys = self._tri.points[self._previous_v.ravel()]
            self._previous_hull = Delaunay(previous_keys)
        # NOTE: the cached simplex of each row is reused as long as every point
        # lies inside the hull of all previously used vertices; a point is not
        # checked against its own row's simplex
        return bool(np.all(self._previous_hull.find_simplex(points) >= 0))

    def _call_class_function(self, point_num, eval_points):
        """
        Function to loop over class function and return array of outputs

        Parameters
        ----------
        point_num: ndarray
            Indices into the values list, shape (n_points, n_vertices)
        eval_points: ndarray
            Inputs used to evaluate class member function. With more than two
            dimensions the first axis is taken to run over the interpolation
            points and each point gets its own slice.

        Returns
        -------
        ndarray: output from member function, leading shape matches ``point_num``

        Raises
        ------
        ValueError
            If ``eval_points`` is None
        """
        if eval_points is None:
            raise ValueError(
                "eval_points must be provided when interpolating between the "
                "results of a class member function"
            )

        per_point_inputs = eval_points.ndim > 2
        n_vertices = point_num.shape[1]

        outputs = []
        for flat_index, value_index in enumerate(point_num.ravel()):
            member = getattr(self.values[value_index], self._function_name)
            if per_point_inputs:
                # row of point_num (i.e. interpolation point) this vertex belongs to
                inputs = eval_points[flat_index // n_vertices]
            else:
                inputs = eval_points
            outputs.append(member(inputs))

        outputs = np.array(outputs)
        return outputs.reshape((*point_num.shape, *outputs.shape[1:]))

    def _numpy_interpolation(self, point_num, eval_points):
        """
        Perform 2D interpolation of numpy array

        Parameters
        ----------
        point_num: ndarray
            Indices into the values list, shape (n_points, n_vertices)
        eval_points: ndarray
            Real coordinates at which the templates are evaluated, shape
            (n_points, n_eval, 2). May be a masked array. Not modified.

        Returns
        -------
        masked array: interpolated template values of shape
            (n_points, n_vertices, n_eval); entries whose input coordinates
            were masked are masked

        Raises
        ------
        ValueError
            If no ``bounds`` were given at initialisation
        """
        if self._bounds is None:
            raise ValueError(
                "bounds must be given at initialisation to interpolate "
                "templates at eval_points"
            )

        # Check if our array is masked and remember its shape
        is_masked = ma.is_masked(eval_points)
        shape = point_num.shape
        ev_shape = eval_points.shape

        # Get the list of templates that we want to interpolate between
        vals = self.values[point_num.ravel()]

        # For each point in the phase space repeat the x-y points by the number
        # of points which define the simplex. np.repeat returns a new array, so
        # the scaling below does not touch the caller's data.
        tiled = np.repeat(eval_points, shape[1], axis=0)

        # Scale the template x and y axes to convert into bin coordinates
        tiled[..., 0] = (tiled[..., 0] - self._bounds[0][0]) / self.scale[0]
        tiled[..., 1] = (tiled[..., 1] - self._bounds[1][0]) / self.scale[1]

        # Index of the template each evaluation point belongs to
        template_index = np.repeat(np.arange(tiled.shape[0]), tiled.shape[1])
        flat = tiled.reshape(tiled.shape[0] * tiled.shape[1], tiled.shape[-1])

        # True where an input coordinate was masked, i.e. where the result is
        # not meaningful
        if is_masked:
            invalid = ma.getmaskarray(flat).any(axis=1)
        else:
            invalid = np.zeros(flat.shape[0], dtype=bool)

        # Then stack up all the templates and do a 3d interpolation of points in
        # all templates. Coordinates are kept floating point even for integer
        # templates so that they are not truncated.
        coord_dtype = np.result_type(self.values.dtype, np.float32)
        coordinates = np.vstack((template_index, ma.getdata(flat).T))
        coordinates = coordinates.astype(coord_dtype)

        output = map_coordinates(
            vals, coordinates, order=1, output=self.values.dtype
        )

        new_shape = (*shape, ev_shape[-2])
        output = output.reshape(new_shape)

        # Return a masked array of interpolated values
        return ma.masked_array(
            output, mask=invalid.reshape(new_shape), dtype=self.values.dtype
        )