# distutils: libraries = STD_LIBS
"""
This file contains coordinate mappings between physical coordinates and those
defined on unit elements, as well as doing the corresponding intracell
interpolation on finite element data.


"""


cimport cython
cimport numpy as np
from numpy cimport ndarray

import numpy as np

from libc.math cimport fabs

from yt.utilities.lib.autogenerated_element_samplers cimport (
    Q1Function2D,
    Q1Function3D,
    Q1Jacobian2D,
    Q1Jacobian3D,
    Q2Function2D,
    Q2Jacobian2D,
    T2Function2D,
    T2Jacobian2D,
    Tet2Function3D,
    Tet2Jacobian3D,
    W1Function3D,
    W1Jacobian3D,
)


# fmax is declared from the project's platform header rather than cimported
# from libc.math (presumably for compilers lacking a C99 fmax -- TODO confirm).
# It is used by maxnorm below.
cdef extern from "platform_dep.h":
    double fmax(double x, double y) noexcept nogil

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef double determinant_3x3(double* col0,
                            double* col1,
                            double* col2) noexcept nogil:
    # Determinant of the 3x3 matrix whose columns are col0, col1 and col2.
    # Each pointer must reference at least three doubles. The six signed
    # triple products are formed first and then combined left to right.
    cdef double p0, p1, p2, p3, p4, p5
    p0 = col0[0]*col1[1]*col2[2]
    p1 = col0[0]*col1[2]*col2[1]
    p2 = col0[1]*col1[0]*col2[2]
    p3 = col0[1]*col1[2]*col2[0]
    p4 = col0[2]*col1[0]*col2[1]
    p5 = col0[2]*col1[1]*col2[0]
    return p0 - p1 - p2 + p3 + p4 - p5


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef double maxnorm(double* f, int dim) noexcept nogil:
    # Infinity norm of f: the largest absolute value among its first
    # `dim` entries. f[0] is always read, so f must hold at least one value.
    cdef double largest = fabs(f[0])
    cdef int k = 1
    while k < dim:
        largest = fmax(largest, fabs(f[k]))
        k += 1
    return largest


cdef class ElementSampler:
    '''

    This is a base class for sampling the value of a finite element solution
    at an arbitrary point inside a mesh element. In general, this will be done
    by transforming the requested physical coordinate into a mapped coordinate
    system, sampling the solution in mapped coordinates, and returning the result.
    This is not to be used directly; use one of the subclasses instead.

    Subclasses override map_real_to_unit, sample_at_unit_point, check_inside
    and, where needed, check_mesh_lines. The scratch buffers used by the
    Python-facing helpers hold four doubles because the widest mapped
    coordinate vector written by a sampler in this file (P1Sampler3D) has
    four components.

    '''

    def __init__(self):
        # tolerance used by the check_inside implementations of subclasses
        self.inclusion_tol = 1.0e-8

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef void map_real_to_unit(self,
                               double* mapped_x,
                               double* vertices,
                               double* physical_x) noexcept nogil:
        # Overridden by subclasses: write the mapped coordinates of
        # physical_x (for the element described by vertices) into mapped_x.
        pass

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self,
                                     double* coord,
                                     double* vals) noexcept nogil:
        # Overridden by subclasses: interpolate the nodal values vals at the
        # mapped coordinate coord.
        pass

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # Overridden by subclasses: return 1 if mapped_coord lies inside the
        # reference element (to within inclusion_tol), 0 otherwise.
        pass

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    def check_contains(self, np.float64_t[:,::1] vertices,
                             np.float64_t[:,::1] positions):
        '''

        Return a uint8 mask with one entry per row of positions; each entry
        is the result of check_inside for that position mapped into the
        element described by vertices.

        '''
        cdef np.ndarray[np.uint8_t, ndim=1] mask
        cdef double[4] mapped_coord
        cdef Py_ssize_t i
        cdef int j
        mask = np.zeros(positions.shape[0], dtype=np.uint8)
        # Zero the scratch buffer so that slots a sampler does not write
        # hold a defined value instead of stack garbage.
        for j in range(4):
            mapped_coord[j] = 0.0
        for i in range(positions.shape[0]):
            # Map straight into the 4-wide buffer: going through
            # map_reals_to_unit would drop any mapped coordinate beyond
            # positions.shape[1] (e.g. the fourth barycentric coordinate
            # that P1Sampler3D.check_inside reads).
            self.map_real_to_unit(mapped_coord, &vertices[0,0], &positions[i, 0])
            mask[i] = self.check_inside(mapped_coord)
        return mask


    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_mesh_lines(self, double* mapped_coord) noexcept nogil:
        # Overridden by subclasses: return 1 if mapped_coord is close to an
        # element edge, -1 otherwise.
        pass

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_real_point(self,
                                     double* vertices,
                                     double* field_values,
                                     double* physical_x) noexcept nogil:
        # Map physical_x into the element's mapped coordinates, then
        # interpolate field_values there.
        cdef double val
        cdef double mapped_coord[4]

        self.map_real_to_unit(mapped_coord, vertices, physical_x)
        val = self.sample_at_unit_point(mapped_coord, field_values)
        return val

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    def map_reals_to_unit(self,
                          np.float64_t[:,::1] vertices,
                          np.float64_t[:,::1] positions):
        '''

        Map every row of positions into mapped coordinates. The returned
        float64 array has the same shape as positions; row n holds the
        leading positions.shape[1] mapped coordinates of position n.

        '''
        # Four slots: map_real_to_unit of P1Sampler3D writes mapped_x[3],
        # which would overrun a three-element buffer.
        cdef double[4] mapped_x
        cdef Py_ssize_t i, n, ncopy
        # We have N vertices, which each have three components.
        cdef np.ndarray[np.float64_t, ndim=2] output_coords
        output_coords = np.zeros((positions.shape[0], positions.shape[1]), dtype="float64")
        for i in range(4):
            mapped_x[i] = 0.0
        # never read past the end of the scratch buffer
        ncopy = positions.shape[1]
        if ncopy > 4:
            ncopy = 4
        # Now for each position, we map
        for n in range(positions.shape[0]):
            self.map_real_to_unit(mapped_x, &vertices[0,0], &positions[n, 0])
            for i in range(ncopy):
                output_coords[n, i] = mapped_x[i]
        return output_coords

    def sample_at_real_points(self,
                              np.float64_t[:,::1] vertices,
                              np.float64_t[::1] field_values,
                              np.float64_t[:,::1] positions):
        '''

        Interpolate field_values at every row of positions and return the
        results as a 1D float64 array with one entry per position.

        '''
        cdef Py_ssize_t n
        cdef np.ndarray[np.float64_t, ndim=1] output_values
        output_values = np.zeros(positions.shape[0], dtype="float64")
        for n in range(positions.shape[0]):
            output_values[n] = self.sample_at_real_point(
                &vertices[0,0], &field_values[0], &positions[n,0])
        return output_values

cdef class P1Sampler1D(ElementSampler):
    '''

    Sampler for a linear, 1D (two-node) element. The mapped coordinate
    runs from -1 at the first vertex to +1 at the second.

    '''

    def __init__(self):
        super(P1Sampler1D, self).__init__()
        self.num_mapped_coords = 1

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef void map_real_to_unit(self, double* mapped_x,
                               double* vertices, double* physical_x) noexcept nogil:
        # Affine map of the segment [vertices[0], vertices[1]] onto [-1, 1].
        cdef double offset = physical_x[0] - vertices[0]
        cdef double length = vertices[1] - vertices[0]
        mapped_x[0] = -1.0 + 2.0*offset / length

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self,
                                     double* coord,
                                     double* vals) noexcept nogil:
        # Linear interpolation between the two nodal values.
        cdef double xi = coord[0]
        cdef double left_part = vals[0] * (1 - xi) / 2.0
        cdef double right_part = vals[1] * (1.0 + xi) / 2.0
        return left_part + right_part

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # Outside only when |xi| exceeds 1 by more than the inclusion tolerance.
        cdef double excess = fabs(mapped_coord[0]) - 1.0
        return 0 if excess > self.inclusion_tol else 1


cdef class P1Sampler2D(ElementSampler):
    '''

    Sampler for a linear, triangular mesh element. The mapping is linear
    and is inverted in closed form. The mapped coordinates are the three
    triangular (barycentric) coordinates of the point.

    '''

    def __init__(self):
        super(P1Sampler2D, self).__init__()
        self.num_mapped_coords = 3


    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef void map_real_to_unit(self, double* mapped_x,
                               double* vertices, double* physical_x) noexcept nogil:

        cdef double x0, y0, x1, y1, x2, y2, px, py, det
        cdef double[3] col0
        cdef double[3] col1
        cdef double[3] col2

        # vertices holds (x, y) pairs for the three corners
        x0 = vertices[0]
        y0 = vertices[1]
        x1 = vertices[2]
        y1 = vertices[3]
        x2 = vertices[4]
        y2 = vertices[5]
        px = physical_x[0]
        py = physical_x[1]

        # columns (x_k, y_k, 1) of the matrix whose determinant
        # normalizes the barycentric coordinates
        col0[0] = x0
        col1[0] = x1
        col2[0] = x2
        col0[1] = y0
        col1[1] = y1
        col2[1] = y2
        col0[2] = 1.0
        col1[2] = 1.0
        col2[2] = 1.0

        det = determinant_3x3(col0, col1, col2)

        mapped_x[0] = ((y1 - y2)*px + (x2 - x1)*py + (x1*y2 - x2*y1)) / det
        mapped_x[1] = ((y2 - y0)*px + (x0 - x2)*py + (x2*y0 - x0*y2)) / det

        # the three coordinates sum to one
        mapped_x[2] = 1.0 - mapped_x[1] - mapped_x[0]


    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self,
                                     double* coord,
                                     double* vals) noexcept nogil:
        # Barycentric-weighted sum of the three nodal values.
        cdef double total = vals[0]*coord[0]
        total = total + vals[1]*coord[1]
        total = total + vals[2]*coord[2]
        return total

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # Inside when each of the three coordinates lies in [0, 1],
        # to within the inclusion tolerance.
        cdef int k
        cdef double lam
        for k in range(3):
            lam = mapped_coord[k]
            if lam < -self.inclusion_tol:
                return 0
            if lam - 1.0 > self.inclusion_tol:
                return 0
        return 1


cdef class P1Sampler3D(ElementSampler):
    '''

    Sampler for a linear, tetrahedral mesh element. The mapping is linear
    and is inverted in closed form with Cramer's rule; the four mapped
    coordinates sum to one.

    '''


    def __init__(self):
        super(P1Sampler3D, self).__init__()
        self.num_mapped_coords = 4

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef void map_real_to_unit(self, double* mapped_x,
                               double* vertices, double* physical_x) noexcept nogil:

        cdef int k
        cdef double denom
        cdef double* origin = vertices + 9
        cdef double[3] rhs
        cdef double[3] edge0
        cdef double[3] edge1
        cdef double[3] edge2

        # everything is expressed relative to the fourth vertex,
        # which starts at vertices[9]
        for k in range(3):
            rhs[k] = physical_x[k] - origin[k]
            edge0[k] = vertices[0 + k] - origin[k]
            edge1[k] = vertices[3 + k] - origin[k]
            edge2[k] = vertices[6 + k] - origin[k]

        # Cramer's rule: swap the right-hand side into each column in turn
        denom = determinant_3x3(edge0, edge1, edge2)
        mapped_x[0] = determinant_3x3(rhs, edge1, edge2)/denom
        mapped_x[1] = determinant_3x3(edge0, rhs, edge2)/denom
        mapped_x[2] = determinant_3x3(edge0, edge1, rhs)/denom
        mapped_x[3] = 1.0 - mapped_x[0] - mapped_x[1] - mapped_x[2]

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self,
                                     double* coord,
                                     double* vals) noexcept nogil:
        # Weighted sum of the four nodal values.
        cdef double total = vals[0]*coord[0]
        total = total + vals[1]*coord[1]
        total = total + vals[2]*coord[2]
        total = total + vals[3]*coord[3]
        return total

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # Inside when each of the four coordinates lies in [0, 1],
        # to within the inclusion tolerance.
        cdef int k
        cdef double lam
        for k in range(4):
            lam = mapped_coord[k]
            if lam < -self.inclusion_tol:
                return 0
            if lam - 1.0 > self.inclusion_tol:
                return 0
        return 1

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_mesh_lines(self, double* mapped_coord) noexcept nogil:
        # Returns 1 when one of three selected coordinates is below thresh
        # or within thresh of 1, and -1 otherwise. The selection skips the
        # first of coordinates 0, 1, 2 that is exactly zero; when none of
        # them is zero, coordinates 1, 2 and 0 are used.
        cdef double u, v, w
        cdef double thresh = 2.0e-2
        cdef int iu, iv, iw

        iu = 1
        iv = 2
        iw = 0
        if mapped_coord[0] == 0:
            iw = 3
        elif mapped_coord[1] == 0:
            iu = 2
            iv = 3
        elif mapped_coord[2] == 0:
            iv = 3

        u = mapped_coord[iu]
        v = mapped_coord[iv]
        w = mapped_coord[iw]

        if u < thresh:
            return 1
        if v < thresh:
            return 1
        if w < thresh:
            return 1
        if fabs(u - 1) < thresh:
            return 1
        if fabs(v - 1) < thresh:
            return 1
        if fabs(w - 1) < thresh:
            return 1
        return -1


cdef class NonlinearSolveSampler3D(ElementSampler):

    '''

    This is a base class for handling element samplers that require
    a nonlinear solve to invert the mapping between coordinate systems.
    To do this, we perform Newton-Raphson iteration using a specified
    system of equations with an analytic Jacobian matrix. This solver
    is hard-coded for 3D for reasons of efficiency. This is not to be
    used directly, use one of the subclasses instead.

    Subclasses must set self.func (the residual function) and self.jac
    (the Jacobian function) before map_real_to_unit is called.

    '''

    def __init__(self):
        super(NonlinearSolveSampler3D, self).__init__()
        # residual max-norm below which the solve counts as converged
        self.tolerance = 1.0e-9
        # maximum number of Newton iterations
        self.max_iter = 10

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef void map_real_to_unit(self,
                               double* mapped_x,
                               double* vertices,
                               double* physical_x) noexcept nogil:
        '''

        Solve for the mapped coordinates of physical_x with a damped
        Newton iteration and store them in mapped_x[0:3]. If the solve
        does not converge (including the case where the residual becomes
        NaN, e.g. because the Jacobian is singular), all three entries are
        set to the bogus value -99.0.

        A thorough description of Newton's method and modifications for global
        convergence can be found in Dennis's text "Numerical Methods for
        Unconstrained Optimization and Nonlinear Equations."

        x: solution vector; holds unit/mapped coordinates
        xk: temporary vector for holding solution of current iteration
        f: residual vector
        r, s, t: three columns of Jacobian matrix corresponding to unit/mapped
                 coordinates r, s, and t
        d: Jacobian determinant
        s_n: Newton step vector
        lam: fraction of Newton step by which to change x
        alpha: constant proportional to how much residual required to decrease.
               1e-4 is value of alpha recommended by Dennis
        err_c: Error of current iteration
        err_plus: Error of next iteration
        min_lam: minimum fraction of Newton step that the line search is allowed
                 to take. General experience suggests that lambda values smaller
                 than 1e-3 will not significantly reduce the residual, but we
                 set to 1e-6 just to be safe
        '''

        cdef int i
        cdef double d, lam
        cdef double[3] f
        cdef double[3] r
        cdef double[3] s
        cdef double[3] t
        cdef double[3] x, xk, s_n
        cdef int iterations = 0
        cdef double err_c, err_plus
        cdef double alpha = 1e-4
        cdef double min_lam = 1e-6

        # initial guess
        for i in range(3):
            x[i] = 0.0

        # initial error norm
        self.func(f, x, vertices, physical_x)
        err_c = maxnorm(f, 3)

        # begin Newton iteration
        while (err_c > self.tolerance and iterations < self.max_iter):
            self.jac(r, s, t, x, vertices, physical_x)
            d = determinant_3x3(r, s, t)

            # Newton step from Cramer's rule: J * s_n = -f
            s_n[0] = - (determinant_3x3(f, s, t)/d)
            s_n[1] = - (determinant_3x3(r, f, t)/d)
            s_n[2] = - (determinant_3x3(r, s, f)/d)
            xk[0] = x[0] + s_n[0]
            xk[1] = x[1] + s_n[1]
            xk[2] = x[2] + s_n[2]
            self.func(f, xk, vertices, physical_x)
            err_plus = maxnorm(f, 3)

            # backtracking line search: halve the step until the residual
            # decreases sufficiently or the step fraction hits min_lam
            lam = 1
            while err_plus > err_c * (1. - alpha * lam) and lam > min_lam:
                lam = lam / 2
                xk[0] = x[0] + lam * s_n[0]
                xk[1] = x[1] + lam * s_n[1]
                xk[2] = x[2] + lam * s_n[2]
                self.func(f, xk, vertices, physical_x)
                err_plus = maxnorm(f, 3)

            x[0] = xk[0]
            x[1] = xk[1]
            x[2] = xk[2]
            err_c = err_plus
            iterations += 1

        # Written as "not (<=)" rather than ">" so that a NaN residual is
        # treated as a failure; every comparison with NaN is false, so
        # "err_c > tolerance" would report NaN coordinates as converged.
        if not (err_c <= self.tolerance):
            # we did not converge, set bogus value
            for i in range(3):
                mapped_x[i] = -99.0
        else:
            for i in range(3):
                mapped_x[i] = x[i]


cdef class Q1Sampler3D(NonlinearSolveSampler3D):

    '''

    Sampler for a 3D, linear (8-node), hexahedral mesh element. The
    inverse mapping is found with the Newton solver of the base class,
    using Q1Function3D and Q1Jacobian3D.

    '''

    def __init__(self):
        super(Q1Sampler3D, self).__init__()
        self.num_mapped_coords = 3
        self.dim = 3
        self.func = Q1Function3D
        self.jac = Q1Jacobian3D

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self, double* coord, double* vals) noexcept nogil:
        # Trilinear interpolation of the eight nodal values.
        cdef double acc
        cdef double r_lo, r_hi, s_lo, s_hi, t_lo, t_hi

        r_lo = 1.0 - coord[0]
        r_hi = 1.0 + coord[0]
        s_lo = 1.0 - coord[1]
        s_hi = 1.0 + coord[1]
        t_lo = 1.0 - coord[2]
        t_hi = 1.0 + coord[2]

        # nodes on the t = -1 face
        acc = vals[0]*r_lo*s_lo*t_lo
        acc = acc + vals[1]*r_hi*s_lo*t_lo
        acc = acc + vals[2]*r_hi*s_hi*t_lo
        acc = acc + vals[3]*r_lo*s_hi*t_lo
        # nodes on the t = +1 face
        acc = acc + vals[4]*r_lo*s_lo*t_hi
        acc = acc + vals[5]*r_hi*s_lo*t_hi
        acc = acc + vals[6]*r_hi*s_hi*t_hi
        acc = acc + vals[7]*r_lo*s_hi*t_hi
        return 0.125*acc

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # Inside a hex, every mapped coordinate lies in [-1, 1]
        # (to within the inclusion tolerance).
        cdef int axis
        for axis in range(3):
            if fabs(mapped_coord[axis]) - 1.0 > self.inclusion_tol:
                return 0
        return 1

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_mesh_lines(self, double* mapped_coord) noexcept nogil:
        # Returns 1 when at least two of the three coordinates are within
        # 0.1 of +/-1, and -1 otherwise.
        cdef int axis
        cdef int near_count = 0
        for axis in range(3):
            if fabs(fabs(mapped_coord[axis]) - 1.0) < 1e-1:
                near_count += 1
        if near_count >= 2:
            return 1
        return -1


cdef class S2Sampler3D(NonlinearSolveSampler3D):

    '''

    Sampler for a 3D, 20-node hexahedral mesh element. The inverse
    mapping is found with the Newton solver of the base class, using
    S2Function3D and S2Jacobian3D.

    '''

    def __init__(self):
        super(S2Sampler3D, self).__init__()
        self.num_mapped_coords = 3
        self.dim = 3
        self.func = S2Function3D
        self.jac = S2Jacobian3D

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self, double* coord, double* vals) noexcept nogil:
        # Interpolates the 20 nodal values: terms 0-7 use the first eight
        # values, terms 8-19 the remaining twelve, each carrying one
        # (1 - x*x) factor.
        cdef double F, r, s, t, rm, rp, sm, sp, tm, tp

        r = coord[0]
        s = coord[1]
        t = coord[2]

        rm = 1.0 - r
        sm = 1.0 - s
        tm = 1.0 - t

        rp = 1.0 + r
        sp = 1.0 + s
        tp = 1.0 + t

        F = rm*sm*tm*(-r - s - t - 2.0)*vals[0] \
          + rp*sm*tm*( r - s - t - 2.0)*vals[1] \
          + rp*sp*tm*( r + s - t - 2.0)*vals[2] \
          + rm*sp*tm*(-r + s - t - 2.0)*vals[3] \
          + rm*sm*tp*(-r - s + t - 2.0)*vals[4] \
          + rp*sm*tp*( r - s + t - 2.0)*vals[5] \
          + rp*sp*tp*( r + s + t - 2.0)*vals[6] \
          + rm*sp*tp*(-r + s + t - 2.0)*vals[7] \
          + 2.0*(1.0 - r*r)*sm*tm*vals[8]  \
          + 2.0*rp*(1.0 - s*s)*tm*vals[9]  \
          + 2.0*(1.0 - r*r)*sp*tm*vals[10] \
          + 2.0*rm*(1.0 - s*s)*tm*vals[11] \
          + 2.0*rm*sm*(1.0 - t*t)*vals[12] \
          + 2.0*rp*sm*(1.0 - t*t)*vals[13] \
          + 2.0*rp*sp*(1.0 - t*t)*vals[14] \
          + 2.0*rm*sp*(1.0 - t*t)*vals[15] \
          + 2.0*(1.0 - r*r)*sm*tp*vals[16] \
          + 2.0*rp*(1.0 - s*s)*tp*vals[17] \
          + 2.0*(1.0 - r*r)*sp*tp*vals[18] \
          + 2.0*rm*(1.0 - s*s)*tp*vals[19]
        return 0.125*F

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # Inside when every mapped coordinate lies in [-1, 1]
        # (to within the inclusion tolerance).
        cdef int axis
        for axis in range(3):
            if fabs(mapped_coord[axis]) - 1.0 > self.inclusion_tol:
                return 0
        return 1

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_mesh_lines(self, double* mapped_coord) noexcept nogil:
        # Returns 1 when at least two of the three coordinates are within
        # 0.1 of +/-1, and -1 otherwise.
        cdef int axis
        cdef int near_count = 0
        for axis in range(3):
            if fabs(fabs(mapped_coord[axis]) - 1.0) < 1e-1:
                near_count += 1
        if near_count >= 2:
            return 1
        return -1


# Residual function used by the Newton solve of S2Sampler3D.
#
# fx:       output, three residual components
# x:        current mapped coordinates (r, s, t)
# vertices: 20 vertices stored as consecutive (x, y, z) triples
#           (indices 0..59 are read)
# phys_x:   physical point being inverted
#
# Each fx[i] is the weighted sum of the i-th vertex components minus
# 8.0*phys_x[i]; the weights are the same polynomials that
# S2Sampler3D.sample_at_unit_point applies to nodal values before its
# 0.125 scaling, so fx is eight times the mapping residual.
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef inline void S2Function3D(double* fx,
                              double* x,
                              double* vertices,
                              double* phys_x) noexcept nogil:
        cdef int i
        cdef double r, s, t, rm, rp, sm, sp, tm, tp

        # rm/rp, sm/sp, tm/tp are the (1 - x) and (1 + x) factors
        r = x[0]
        rm = 1.0 - r
        rp = 1.0 + r

        s = x[1]
        sm = 1.0 - s
        sp = 1.0 + s

        t = x[2]
        tm = 1.0 - t
        tp = 1.0 + t

        # first eight terms: vertices 0-7; remaining twelve: vertices 8-19,
        # each with one (1 - x*x) factor
        for i in range(3):
            fx[i] = rm*sm*tm*(-r - s - t - 2.0)*vertices[0 + i]  \
                  + rp*sm*tm*( r - s - t - 2.0)*vertices[3 + i]  \
                  + rp*sp*tm*( r + s - t - 2.0)*vertices[6 + i]  \
                  + rm*sp*tm*(-r + s - t - 2.0)*vertices[9 + i]  \
                  + rm*sm*tp*(-r - s + t - 2.0)*vertices[12 + i] \
                  + rp*sm*tp*( r - s + t - 2.0)*vertices[15 + i] \
                  + rp*sp*tp*( r + s + t - 2.0)*vertices[18 + i] \
                  + rm*sp*tp*(-r + s + t - 2.0)*vertices[21 + i] \
                  + 2.0*(1.0 - r*r)*sm*tm*vertices[24 + i] \
                  + 2.0*rp*(1.0 - s*s)*tm*vertices[27 + i] \
                  + 2.0*(1.0 - r*r)*sp*tm*vertices[30 + i] \
                  + 2.0*rm*(1.0 - s*s)*tm*vertices[33 + i] \
                  + 2.0*rm*sm*(1.0 - t*t)*vertices[36 + i] \
                  + 2.0*rp*sm*(1.0 - t*t)*vertices[39 + i] \
                  + 2.0*rp*sp*(1.0 - t*t)*vertices[42 + i] \
                  + 2.0*rm*sp*(1.0 - t*t)*vertices[45 + i] \
                  + 2.0*(1.0 - r*r)*sm*tp*vertices[48 + i] \
                  + 2.0*rp*(1.0 - s*s)*tp*vertices[51 + i] \
                  + 2.0*(1.0 - r*r)*sp*tp*vertices[54 + i] \
                  + 2.0*rm*(1.0 - s*s)*tp*vertices[57 + i] \
                  - 8.0*phys_x[i]


# Jacobian of S2Function3D with respect to the mapped coordinates.
#
# rcol, scol, tcol: outputs, three components each; they hold the partial
#                   derivatives of the residual with respect to r, s and t
# x:                current mapped coordinates (r, s, t)
# vertices:         20 vertices stored as consecutive (x, y, z) triples
#                   (indices 0..59 are read)
# phys_x:           not read here; the -8.0*phys_x[i] term of the residual
#                   does not depend on x
#
# Every coefficient is the derivative of the matching S2Function3D term,
# so the two functions must be changed together.
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef inline void S2Jacobian3D(double* rcol,
                              double* scol,
                              double* tcol,
                              double* x,
                              double* vertices,
                              double* phys_x) noexcept nogil:
        cdef int i
        cdef double r, s, t, rm, rp, sm, sp, tm, tp

        # rm/rp, sm/sp, tm/tp are the (1 - x) and (1 + x) factors
        r = x[0]
        rm = 1.0 - r
        rp = 1.0 + r

        s = x[1]
        sm = 1.0 - s
        sp = 1.0 + s

        t = x[2]
        tm = 1.0 - t
        tp = 1.0 + t

        for i in range(3):
            # derivative with respect to r
            rcol[i] = (sm*tm*(r + s + t + 2.0) - rm*sm*tm)*vertices[0  + i] \
                    + (sm*tm*(r - s - t - 2.0) + rp*sm*tm)*vertices[3  + i] \
                    + (sp*tm*(r + s - t - 2.0) + rp*sp*tm)*vertices[6  + i] \
                    + (sp*tm*(r - s + t + 2.0) - rm*sp*tm)*vertices[9  + i] \
                    + (sm*tp*(r + s - t + 2.0) - rm*sm*tp)*vertices[12 + i] \
                    + (sm*tp*(r - s + t - 2.0) + rp*sm*tp)*vertices[15 + i] \
                    + (sp*tp*(r + s + t - 2.0) + rp*sp*tp)*vertices[18 + i] \
                    + (sp*tp*(r - s - t + 2.0) - rm*sp*tp)*vertices[21 + i] \
                    - 4.0*r*sm*tm*vertices[24 + i] \
                    + 2.0*(1.0 - s*s)*tm*vertices[27 + i] \
                    - 4.0*r*sp*tm*vertices[30 + i] \
                    - 2.0*(1.0 - s*s)*tm*vertices[33 + i] \
                    - 2.0*sm*(1.0 - t*t)*vertices[36 + i] \
                    + 2.0*sm*(1.0 - t*t)*vertices[39 + i] \
                    + 2.0*sp*(1.0 - t*t)*vertices[42 + i] \
                    - 2.0*sp*(1.0 - t*t)*vertices[45 + i] \
                    - 4.0*r*sm*tp*vertices[48 + i] \
                    + 2.0*(1.0 - s*s)*tp*vertices[51 + i] \
                    - 4.0*r*sp*tp*vertices[54 + i] \
                    - 2.0*(1.0 - s*s)*tp*vertices[57 + i]
            # derivative with respect to s
            scol[i] = ( rm*tm*(r + s + t + 2.0) - rm*sm*tm)*vertices[0  + i] \
                    + (-rp*tm*(r - s - t - 2.0) - rp*sm*tm)*vertices[3  + i] \
                    + ( rp*tm*(r + s - t - 2.0) + rp*sp*tm)*vertices[6  + i] \
                    + (-rm*tm*(r - s + t + 2.0) + rm*sp*tm)*vertices[9  + i] \
                    + ( rm*tp*(r + s - t + 2.0) - rm*sm*tp)*vertices[12 + i] \
                    + (-rp*tp*(r - s + t - 2.0) - rp*sm*tp)*vertices[15 + i] \
                    + ( rp*tp*(r + s + t - 2.0) + rp*sp*tp)*vertices[18 + i] \
                    + (-rm*tp*(r - s - t + 2.0) + rm*sp*tp)*vertices[21 + i] \
                    - 2.0*(1.0 - r*r)*tm*vertices[24 + i] \
                    - 4.0*rp*s*tm*vertices[27 + i] \
                    + 2.0*(1.0 - r*r)*tm*vertices[30 + i] \
                    - 4.0*rm*s*tm*vertices[33 + i] \
                    - 2.0*rm*(1.0 - t*t)*vertices[36 + i] \
                    - 2.0*rp*(1.0 - t*t)*vertices[39 + i] \
                    + 2.0*rp*(1.0 - t*t)*vertices[42 + i] \
                    + 2.0*rm*(1.0 - t*t)*vertices[45 + i] \
                    - 2.0*(1.0 - r*r)*tp*vertices[48 + i] \
                    - 4.0*rp*s*tp*vertices[51 + i] \
                    + 2.0*(1.0 - r*r)*tp*vertices[54 + i] \
                    - 4.0*rm*s*tp*vertices[57 + i]
            # derivative with respect to t
            tcol[i] = ( rm*sm*(r + s + t + 2.0) - rm*sm*tm)*vertices[0  + i] \
                    + (-rp*sm*(r - s - t - 2.0) - rp*sm*tm)*vertices[3  + i] \
                    + (-rp*sp*(r + s - t - 2.0) - rp*sp*tm)*vertices[6  + i] \
                    + ( rm*sp*(r - s + t + 2.0) - rm*sp*tm)*vertices[9  + i] \
                    + (-rm*sm*(r + s - t + 2.0) + rm*sm*tp)*vertices[12 + i] \
                    + ( rp*sm*(r - s + t - 2.0) + rp*sm*tp)*vertices[15 + i] \
                    + ( rp*sp*(r + s + t - 2.0) + rp*sp*tp)*vertices[18 + i] \
                    + (-rm*sp*(r - s - t + 2.0) + rm*sp*tp)*vertices[21 + i] \
                    - 2.0*(1.0 - r*r)*sm*vertices[24 + i] \
                    - 2.0*rp*(1.0 - s*s)*vertices[27 + i] \
                    - 2.0*(1.0 - r*r)*sp*vertices[30 + i] \
                    - 2.0*rm*(1.0 - s*s)*vertices[33 + i] \
                    - 4.0*rm*sm*t*vertices[36 + i] \
                    - 4.0*rp*sm*t*vertices[39 + i] \
                    - 4.0*rp*sp*t*vertices[42 + i] \
                    - 4.0*rm*sp*t*vertices[45 + i] \
                    + 2.0*(1.0 - r*r)*sm*vertices[48 + i] \
                    + 2.0*rp*(1.0 - s*s)*vertices[51 + i] \
                    + 2.0*(1.0 - r*r)*sp*vertices[54 + i] \
                    + 2.0*rm*(1.0 - s*s)*vertices[57 + i]


cdef class W1Sampler3D(NonlinearSolveSampler3D):

    '''

    Sampler for a 3D, linear (6-node), wedge mesh element. The inverse
    mapping is found with the Newton solver of the base class, using
    W1Function3D and W1Jacobian3D.

    '''

    def __init__(self):
        super(W1Sampler3D, self).__init__()
        self.num_mapped_coords = 3
        self.dim = 3
        self.func = W1Function3D
        self.jac = W1Jacobian3D

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self, double* coord, double* vals) noexcept nogil:
        # Interpolates the six nodal values: a triangular weight in the
        # first two coordinates times a linear weight in the third.
        cdef double acc, tri0, lower, upper

        tri0 = 1.0 - coord[0] - coord[1]
        lower = 1.0 - coord[2]
        upper = 1.0 + coord[2]

        acc = vals[0]*tri0*lower
        acc = acc + vals[1]*coord[0]*lower
        acc = acc + vals[2]*coord[1]*lower
        acc = acc + vals[3]*tri0*upper
        acc = acc + vals[4]*coord[0]*upper
        acc = acc + vals[5]*coord[1]*upper

        return acc / 2.0

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # Bounds of the mapped coordinates inside a wedge
        # (each test is relaxed by the inclusion tolerance):
        #     0 <= mapped_coord[0] <= 1 - mapped_coord[1]
        #     0 <= mapped_coord[1]
        #    -1 <= mapped_coord[2] <= 1
        cdef double tol = self.inclusion_tol
        if mapped_coord[0] < -tol:
            return 0
        if mapped_coord[0] + mapped_coord[1] - 1.0 > tol:
            return 0
        if mapped_coord[1] < -tol:
            return 0
        if fabs(mapped_coord[2]) - 1.0 > tol:
            return 0
        return 1

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_mesh_lines(self, double* mapped_coord) noexcept nogil:
        # Returns 1 when the point is within thresh of at least two of the
        # three boundary groups tested below, and -1 otherwise. The return
        # value is what gets passed back (via ray.instID) to flag rays
        # near the element boundary when annotating mesh lines.
        cdef double thresh = 5.0e-2
        cdef double r = mapped_coord[0]
        cdef double s = mapped_coord[1]
        cdef int near_count = 0

        # r = 0 or r + s = 1
        if (r < thresh) or (fabs(r + s - 1.0) < thresh):
            near_count += 1
        # s = 0
        if s < thresh:
            near_count += 1
        # t = -1 or t = +1
        if fabs(fabs(mapped_coord[2]) - 1.0) < thresh:
            near_count += 1

        if near_count >= 2:
            return 1
        return -1


cdef class NonlinearSolveSampler2D(ElementSampler):

    '''

    This is a base class for handling element samplers that require
    a nonlinear solve to invert the mapping between coordinate systems.
    To do this, we perform Newton-Raphson iteration using a specified
    system of equations with an analytic Jacobian matrix. This solver
    is hard-coded for 2D for reasons of efficiency. This is not to be
    used directly, use one of the subclasses instead.

    Subclasses must set self.func (the residual function) and self.jac
    (the Jacobian function) before map_real_to_unit is called.

    '''

    def __init__(self):
        super(NonlinearSolveSampler2D, self).__init__()
        # residual max-norm below which the solve counts as converged
        self.tolerance = 1.0e-9
        # maximum number of Newton iterations
        self.max_iter = 10

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef void map_real_to_unit(self,
                               double* mapped_x,
                               double* vertices,
                               double* physical_x) noexcept nogil:
        '''

        Solve for the mapped coordinates of physical_x with a damped
        Newton iteration and store them in mapped_x[0:2]. If the solve
        does not converge (including the case where the residual becomes
        NaN, e.g. because the Jacobian is singular), both entries are set
        to the bogus value -99.0.

        A thorough description of Newton's method and modifications for global
        convergence can be found in Dennis's text "Numerical Methods for
        Unconstrained Optimization and Nonlinear Equations."

        x: solution vector; holds unit/mapped coordinates
        xk: temporary vector for holding solution of current iteration
        f: residual vector
        A: Jacobian matrix (derivative of residual vector wrt x); self.jac
           fills A[0:2] and A[2:4] as its two output columns
        d: Jacobian determinant
        s_n: Newton step vector
        lam: fraction of Newton step by which to change x
        alpha: constant proportional to how much residual required to decrease.
               1e-4 is value of alpha recommended by Dennis
        err_c: Error of current iteration
        err_plus: Error of next iteration
        min_lam: minimum fraction of Newton step that the line search is allowed
                 to take. General experience suggests that lambda values smaller
                 than 1e-3 will not significantly reduce the residual, but we
                 set to 1e-6 just to be safe
        '''

        cdef int i
        cdef double d, lam
        cdef double[2] f
        cdef double[2] x, xk, s_n
        cdef double[4] A
        cdef int iterations = 0
        cdef double err_c, err_plus
        cdef double alpha = 1e-4
        cdef double min_lam = 1e-6

        # initial guess
        for i in range(2):
            x[i] = 0.0

        # initial error norm
        self.func(f, x, vertices, physical_x)
        err_c = maxnorm(f, 2)

        # begin Newton iteration
        while (err_c > self.tolerance and iterations < self.max_iter):
            self.jac(&A[0], &A[2], x, vertices, physical_x)
            d = (A[0]*A[3] - A[1]*A[2])

            # Newton step s_n = -inverse(J) * f, using the closed-form
            # inverse of the 2x2 matrix with columns A[0:2] and A[2:4]
            s_n[0] = -( A[3]*f[0] - A[2]*f[1]) / d
            s_n[1] = -(-A[1]*f[0] + A[0]*f[1]) / d
            xk[0] = x[0] + s_n[0]
            xk[1] = x[1] + s_n[1]
            self.func(f, xk, vertices, physical_x)
            err_plus = maxnorm(f, 2)

            # backtracking line search: halve the step until the residual
            # decreases sufficiently or the step fraction hits min_lam
            lam = 1
            while err_plus > err_c * (1. - alpha * lam) and lam > min_lam:
                lam = lam / 2
                xk[0] = x[0] + lam * s_n[0]
                xk[1] = x[1] + lam * s_n[1]
                self.func(f, xk, vertices, physical_x)
                err_plus = maxnorm(f, 2)

            x[0] = xk[0]
            x[1] = xk[1]
            err_c = err_plus
            iterations += 1

        # Written as "not (<=)" rather than ">" so that a NaN residual is
        # treated as a failure; every comparison with NaN is false, so
        # "err_c > tolerance" would report NaN coordinates as converged.
        if not (err_c <= self.tolerance):
            # we did not converge, set bogus value
            for i in range(2):
                mapped_x[i] = -99.0
        else:
            for i in range(2):
                mapped_x[i] = x[i]


cdef class Q1Sampler2D(NonlinearSolveSampler2D):

    '''

    Sampler for a 2D, linear, quadrilateral mesh element (four nodal
    values). The residual and Jacobian used by the inherited solver are
    Q1Function2D and Q1Jacobian2D.

    '''

    def __init__(self):
        super(Q1Sampler2D, self).__init__()
        self.dim = 2
        self.num_mapped_coords = 2
        self.jac = Q1Jacobian2D
        self.func = Q1Function2D

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self, double* coord, double* vals) noexcept nogil:
        # Bilinear interpolation of the four nodal values at the mapped
        # point (coord[0], coord[1]); each node is weighted by the product
        # of (1 -/+ r) and (1 -/+ s), and the sum is scaled by 1/4.
        cdef double r_lo = 1.0 - coord[0]
        cdef double r_hi = 1.0 + coord[0]
        cdef double s_lo = 1.0 - coord[1]
        cdef double s_hi = 1.0 + coord[1]
        cdef double weighted

        weighted = (vals[0]*r_lo*s_lo + vals[1]*r_hi*s_lo +
                    vals[2]*r_hi*s_hi + vals[3]*r_lo*s_hi)
        return 0.25*weighted

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # A point is inside the quad when each mapped coordinate lies in
        # [-1, 1], up to self.inclusion_tol. Returns 1 if inside, else 0.
        cdef int axis
        for axis in range(2):
            if fabs(mapped_coord[axis]) - 1.0 > self.inclusion_tol:
                return 0
        return 1

cdef class Q2Sampler2D(NonlinearSolveSampler2D):

    '''

    Sampler for a 2D, quadratic, quadrilateral mesh element (nine nodal
    values). The residual and Jacobian used by the inherited solver are
    Q2Function2D and Q2Jacobian2D.

    '''

    def __init__(self):
        super(Q2Sampler2D, self).__init__()
        self.dim = 2
        self.num_mapped_coords = 2
        self.jac = Q2Jacobian2D
        self.func = Q2Function2D

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self, double* coord, double* vals) noexcept nogil:
        # Evaluate the nine biquadratic shape functions at the mapped point
        # and return the shape-function-weighted sum of the nodal values.
        cdef double[9] basis
        cdef double r, s, r_minus, r_plus, s_minus, s_plus
        cdef double total = 0
        cdef int k

        r = coord[0]
        s = coord[1]
        r_minus = coord[0] - 1.
        r_plus = coord[0] + 1.
        s_minus = coord[1] - 1.
        s_plus = coord[1] + 1.

        # corner nodes
        basis[0] = r * r_minus * s * s_minus / 4.
        basis[1] = r * r_plus * s * s_minus / 4.
        basis[2] = r * r_plus * s * s_plus / 4.
        basis[3] = r * r_minus * s * s_plus / 4.
        # mid-edge nodes
        basis[4] = r_plus * r_minus * s * s_minus / -2.
        basis[5] = r * r_plus * s_plus * s_minus / -2.
        basis[6] = r_plus * r_minus * s * s_plus / -2.
        basis[7] = r * r_minus * s_plus * s_minus / -2.
        # center node
        basis[8] = r_plus * r_minus * s_plus * s_minus

        for k in range(9):
            total += vals[k] * basis[k]

        return total

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # A point is inside the quad when each mapped coordinate lies in
        # [-1, 1], up to self.inclusion_tol. Returns 1 if inside, else 0.
        cdef int axis
        for axis in range(2):
            if fabs(mapped_coord[axis]) - 1.0 > self.inclusion_tol:
                return 0
        return 1


cdef class T2Sampler2D(NonlinearSolveSampler2D):

    '''

    Sampler for a 2D, quadratic, triangular mesh element (six nodal
    values), expressed in canonical coordinates. The residual and Jacobian
    used by the inherited solver are T2Function2D and T2Jacobian2D.

    '''

    def __init__(self):
        super(T2Sampler2D, self).__init__()
        self.dim = 2
        self.num_mapped_coords = 2
        self.jac = T2Jacobian2D
        self.func = T2Function2D

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self, double* coord, double* vals) noexcept nogil:
        # Evaluate the six quadratic shape functions at the canonical point
        # (u, v) and return the weighted sum of the nodal values.
        cdef double u = coord[0]
        cdef double v = coord[1]
        cdef double uu = u * u
        cdef double vv = v * v
        cdef double uv = u * v
        cdef double[6] shape
        cdef double result
        cdef int k

        # vertex nodes
        shape[0] = 1 - 3 * u + 2 * uu - 3 * v + \
                   2 * vv + 4 * uv
        shape[1] = -u + 2 * uu
        shape[2] = -v + 2 * vv
        # mid-edge nodes
        shape[3] = 4 * u - 4 * uu - 4 * uv
        shape[4] = 4 * uv
        shape[5] = 4 * v - 4 * vv - 4 * uv

        # accumulate in node order, starting from the first term
        result = vals[0] * shape[0]
        for k in range(1, 6):
            result += vals[k] * shape[k]
        return result

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # Inside the canonical triangle means both coordinates are
        # non-negative and their sum does not exceed 1, each test relaxed
        # by self.inclusion_tol. Returns 1 if inside, else 0.
        if mapped_coord[0] < -self.inclusion_tol:
            return 0
        if mapped_coord[1] < -self.inclusion_tol:
            return 0
        if mapped_coord[0] + mapped_coord[1] - 1.0 > self.inclusion_tol:
            return 0
        return 1

cdef class Tet2Sampler3D(NonlinearSolveSampler3D):

    '''

    Sampler for a 3D, quadratic, tetrahedral mesh element (ten nodal
    values), expressed in canonical coordinates. The residual and Jacobian
    functions assigned here are Tet2Function3D and Tet2Jacobian3D.

    '''

    def __init__(self):
        super(Tet2Sampler3D, self).__init__()
        self.dim = 3
        self.num_mapped_coords = 3
        self.jac = Tet2Jacobian3D
        self.func = Tet2Function3D

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef double sample_at_unit_point(self, double* coord, double* vals) noexcept nogil:
        # Evaluate the ten quadratic shape functions at the canonical point
        # (x, y, z) and return the weighted sum of the nodal values.
        cdef double[10] shape
        cdef double x = coord[0]
        cdef double y = coord[1]
        cdef double z = coord[2]
        cdef double xx = x * x
        cdef double yy = y * y
        cdef double zz = z * z
        cdef double total = 0
        cdef int k

        # vertex nodes
        shape[0] = 1 - 3 * x + 2 * xx - 3 * y + \
                   2 * yy - 3 * z + 2 * zz + \
                   4 * x * y + 4 * x * z + \
                   4 * y * z
        shape[1] = -x + 2 * xx
        shape[2] = -y + 2 * yy
        shape[3] = -z + 2 * zz
        # mid-edge nodes
        shape[4] = 4 * x - 4 * xx - 4 * x * y - \
                   4 * x * z
        shape[5] = 4 * x * y
        shape[6] = 4 * y - 4 * yy - 4 * x * y - \
                   4 * y * z
        shape[7] = 4 * z - 4 * zz - 4 * z * x - \
                   4 * z * y
        shape[8] = 4 * x * z
        shape[9] = 4 * y * z

        for k in range(10):
            total += shape[k] * vals[k]

        return total

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_inside(self, double* mapped_coord) noexcept nogil:
        # Inside the canonical tetrahedron means all three coordinates are
        # non-negative and their sum does not exceed 1, each test relaxed
        # by self.inclusion_tol. Returns 1 if inside, else 0.
        cdef int axis
        for axis in range(3):
            if mapped_coord[axis] < -self.inclusion_tol:
                return 0
        if mapped_coord[0] + mapped_coord[1] + mapped_coord[2] - 1.0 > \
           self.inclusion_tol:
            return 0
        return 1

    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    cdef int check_mesh_lines(self, double* mapped_coord) noexcept nogil:
        # Pick a pair (u, v) of mapped coordinates based on which coordinate
        # is exactly zero (tested in order 0, 1, 2), then return 1 when
        # either u or v is below thresh or within thresh of 1, and -1
        # otherwise.
        cdef double thresh = 2.0e-2
        cdef double u, v
        cdef bint near_line

        # default pairing; also the result when coordinate 0 is zero or
        # when no coordinate is exactly zero
        u = mapped_coord[1]
        v = mapped_coord[2]
        if mapped_coord[0] != 0:
            if mapped_coord[1] == 0:
                u = mapped_coord[2]
                v = mapped_coord[0]
            elif mapped_coord[2] == 0:
                v = mapped_coord[0]

        near_line = ((u < thresh) or
                     (v < thresh) or
                     (fabs(u - 1) < thresh) or
                     (fabs(v - 1) < thresh))
        if near_line:
            return 1
        return -1

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def test_hex_sampler(np.ndarray[np.float64_t, ndim=2] vertices,
                     np.ndarray[np.float64_t, ndim=1] field_values,
                     np.ndarray[np.float64_t, ndim=1] physical_x):
    '''

    Sample field_values at physical_x with a Q1Sampler3D and return the
    result. The arrays' data buffers are handed to the sampler as raw
    double pointers; no size or contiguity checks are performed here.

    '''
    cdef Q1Sampler3D hex_sampler = Q1Sampler3D()
    cdef double* vertex_ptr = <double*> vertices.data
    cdef double* field_ptr = <double*> field_values.data
    cdef double* point_ptr = <double*> physical_x.data
    cdef double sampled

    sampled = hex_sampler.sample_at_real_point(vertex_ptr, field_ptr, point_ptr)
    return sampled


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def test_hex20_sampler(np.ndarray[np.float64_t, ndim=2] vertices,
                       np.ndarray[np.float64_t, ndim=1] field_values,
                       np.ndarray[np.float64_t, ndim=1] physical_x):
    '''

    Sample field_values at physical_x with an S2Sampler3D and return the
    result. The arrays' data buffers are handed to the sampler as raw
    double pointers; no size or contiguity checks are performed here.

    '''
    cdef S2Sampler3D hex20_sampler = S2Sampler3D()
    cdef double* vertex_ptr = <double*> vertices.data
    cdef double* field_ptr = <double*> field_values.data
    cdef double* point_ptr = <double*> physical_x.data
    cdef double sampled

    sampled = hex20_sampler.sample_at_real_point(vertex_ptr, field_ptr, point_ptr)
    return sampled


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def test_tetra_sampler(np.ndarray[np.float64_t, ndim=2] vertices,
                       np.ndarray[np.float64_t, ndim=1] field_values,
                       np.ndarray[np.float64_t, ndim=1] physical_x):
    '''

    Sample field_values at physical_x with a P1Sampler3D and return the
    result. The arrays' data buffers are handed to the sampler as raw
    double pointers; no size or contiguity checks are performed here.

    '''
    cdef P1Sampler3D tetra_sampler = P1Sampler3D()
    cdef double* vertex_ptr = <double*> vertices.data
    cdef double* field_ptr = <double*> field_values.data
    cdef double* point_ptr = <double*> physical_x.data
    cdef double sampled

    sampled = tetra_sampler.sample_at_real_point(vertex_ptr, field_ptr, point_ptr)
    return sampled


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def test_wedge_sampler(np.ndarray[np.float64_t, ndim=2] vertices,
                       np.ndarray[np.float64_t, ndim=1] field_values,
                       np.ndarray[np.float64_t, ndim=1] physical_x):
    '''

    Sample field_values at physical_x with a W1Sampler3D and return the
    result. The arrays' data buffers are handed to the sampler as raw
    double pointers; no size or contiguity checks are performed here.

    '''
    cdef W1Sampler3D wedge_sampler = W1Sampler3D()
    cdef double* vertex_ptr = <double*> vertices.data
    cdef double* field_ptr = <double*> field_values.data
    cdef double* point_ptr = <double*> physical_x.data
    cdef double sampled

    sampled = wedge_sampler.sample_at_real_point(vertex_ptr, field_ptr, point_ptr)
    return sampled


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def test_linear1D_sampler(np.ndarray[np.float64_t, ndim=2] vertices,
                          np.ndarray[np.float64_t, ndim=1] field_values,
                          np.ndarray[np.float64_t, ndim=1] physical_x):
    '''

    Sample field_values at physical_x with a P1Sampler1D and return the
    result. The arrays' data buffers are handed to the sampler as raw
    double pointers; no size or contiguity checks are performed here.

    '''
    cdef P1Sampler1D line_sampler = P1Sampler1D()
    cdef double* vertex_ptr = <double*> vertices.data
    cdef double* field_ptr = <double*> field_values.data
    cdef double* point_ptr = <double*> physical_x.data
    cdef double sampled

    sampled = line_sampler.sample_at_real_point(vertex_ptr, field_ptr, point_ptr)
    return sampled


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def test_tri_sampler(np.ndarray[np.float64_t, ndim=2] vertices,
                     np.ndarray[np.float64_t, ndim=1] field_values,
                     np.ndarray[np.float64_t, ndim=1] physical_x):
    '''

    Sample field_values at physical_x with a P1Sampler2D and return the
    result. The arrays' data buffers are handed to the sampler as raw
    double pointers; no size or contiguity checks are performed here.

    '''
    cdef P1Sampler2D tri_sampler = P1Sampler2D()
    cdef double* vertex_ptr = <double*> vertices.data
    cdef double* field_ptr = <double*> field_values.data
    cdef double* point_ptr = <double*> physical_x.data
    cdef double sampled

    sampled = tri_sampler.sample_at_real_point(vertex_ptr, field_ptr, point_ptr)
    return sampled


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def test_quad_sampler(np.ndarray[np.float64_t, ndim=2] vertices,
                      np.ndarray[np.float64_t, ndim=1] field_values,
                      np.ndarray[np.float64_t, ndim=1] physical_x):
    '''

    Sample field_values at physical_x with a Q1Sampler2D and return the
    result. The arrays' data buffers are handed to the sampler as raw
    double pointers; no size or contiguity checks are performed here.

    '''
    cdef Q1Sampler2D quad_sampler = Q1Sampler2D()
    cdef double* vertex_ptr = <double*> vertices.data
    cdef double* field_ptr = <double*> field_values.data
    cdef double* point_ptr = <double*> physical_x.data
    cdef double sampled

    sampled = quad_sampler.sample_at_real_point(vertex_ptr, field_ptr, point_ptr)
    return sampled

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def test_quad2_sampler(np.ndarray[np.float64_t, ndim=2] vertices,
                       np.ndarray[np.float64_t, ndim=1] field_values,
                       np.ndarray[np.float64_t, ndim=1] physical_x):
    '''

    Sample field_values at physical_x with a Q2Sampler2D and return the
    result. The arrays' data buffers are handed to the sampler as raw
    double pointers; no size or contiguity checks are performed here.

    '''
    cdef Q2Sampler2D quad2_sampler = Q2Sampler2D()
    cdef double* vertex_ptr = <double*> vertices.data
    cdef double* field_ptr = <double*> field_values.data
    cdef double* point_ptr = <double*> physical_x.data
    cdef double sampled

    sampled = quad2_sampler.sample_at_real_point(vertex_ptr, field_ptr, point_ptr)
    return sampled

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def test_tri2_sampler(np.ndarray[np.float64_t, ndim=2] vertices,
                      np.ndarray[np.float64_t, ndim=1] field_values,
                      np.ndarray[np.float64_t, ndim=1] physical_x):
    '''

    Sample field_values at physical_x with a T2Sampler2D and return the
    result. The arrays' data buffers are handed to the sampler as raw
    double pointers; no size or contiguity checks are performed here.

    '''
    cdef T2Sampler2D tri2_sampler = T2Sampler2D()
    cdef double* vertex_ptr = <double*> vertices.data
    cdef double* field_ptr = <double*> field_values.data
    cdef double* point_ptr = <double*> physical_x.data
    cdef double sampled

    sampled = tri2_sampler.sample_at_real_point(vertex_ptr, field_ptr, point_ptr)
    return sampled

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def test_tet2_sampler(np.ndarray[np.float64_t, ndim=2] vertices,
                      np.ndarray[np.float64_t, ndim=1] field_values,
                      np.ndarray[np.float64_t, ndim=1] physical_x):
    '''

    Sample field_values at physical_x with a Tet2Sampler3D and return the
    result. The arrays' data buffers are handed to the sampler as raw
    double pointers; no size or contiguity checks are performed here.

    '''
    cdef Tet2Sampler3D tet2_sampler = Tet2Sampler3D()
    cdef double* vertex_ptr = <double*> vertices.data
    cdef double* field_ptr = <double*> field_values.data
    cdef double* point_ptr = <double*> physical_x.data
    cdef double sampled

    sampled = tet2_sampler.sample_at_real_point(vertex_ptr, field_ptr, point_ptr)
    return sampled
