# -*- coding: utf-8 -*-
"""This module defines a jit function that wraps the jit function for
the chosen form compiler."""

# Copyright (C) 2008-2016 Johan Hake
#
# This file is part of DOLFIN.
#
# DOLFIN is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# DOLFIN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with DOLFIN. If not, see <http://www.gnu.org/licenses/>.
#
# Modified by Anders Logg, 2008-2009.
# Modified by Garth N. Wells, 2011.
# Modified by Martin Sandve Alnæs, 2014-2016.

from __future__ import print_function

__all__ = ["jit"]

from functools import wraps
import sys
import traceback
import six

import ufl
import ffc

# Import SWIG-generated extension module (DOLFIN C++)
import dolfin.cpp as cpp
from dolfin.cpp import MPI, info, parameters, warning, info
from dolfin.common.globalparameters import ffc_default_parameters


def mpi_jit_decorator(local_jit, *args, **kwargs):
    """Wrap a jit compiler function so that it coordinates over MPI.

    Apply this function as a decorator to a jit compiler function.
    When the communicator holds a single process, the wrapped function
    is simply called. Otherwise the process with rank 0 calls it
    first; once all processes have learned that this call succeeded,
    the remaining processes call it too (the intent being that they
    then pick the module up from the cache). If the call on rank 0
    raises, every process reports the failure together through
    ``cpp.dolfin_error``.

    The communicator is read from the ``mpi_comm`` keyword argument of
    the call; if that is missing or None, ``cpp.mpi_comm_world()`` is
    used. A communicator passed positionally is not seen by the
    wrapper.

    Extra positional or keyword arguments given to the decorator
    itself are accepted but not used.

    *Example*
        .. code-block:: python

            @mpi_jit_decorator
            def jit_something(something):
                ....

    """
    @wraps(local_jit)
    def mpi_jit(*args, **kwargs):

        # Pick the communicator: explicit keyword, else MPI_COMM_WORLD
        comm = kwargs.get("mpi_comm")
        if comm is None:
            comm = cpp.mpi_comm_world()

        # Serial run: nothing to coordinate, compile directly
        if MPI.size(comm) == 1:
            return local_jit(*args, **kwargs)

        # Local failure flag (0 == ok, 1 == fail); only rank 0 can set it
        failed = 0

        # Rank 0 compiles first
        is_root = MPI.rank(comm) == 0
        if is_root:
            try:
                compiled = local_jit(*args, **kwargs)
            except Exception as exc:
                failed = 1
                reason = str(exc)

        # TODO: Lower overhead could be had by using the dijitso.jit
        # features to inject a waiting callback rather than waiting
        # out here. That would let all processes look in the cache
        # first and introduce a barrier only on a cache miss.
        # dijitso also has a sketch of how to let only one process
        # per physical cache directory do the compilation.

        # Collective reduction: doubles as the wait for rank 0 and as
        # the way every process learns whether rank 0 failed.
        # TODO: Broadcasting the status from root would be better,
        # but this works.
        any_failed = MPI.max(comm, failed)

        if any_failed != 0:
            # Fail on all processes at the same time, so the error can
            # be caught without deadlock
            if not is_root:
                reason = "Compilation failed on root node."
            cpp.dolfin_error("jit.py",
                             "perform just-in-time compilation of form",
                             reason)
        elif not is_root:
            # Rank 0 succeeded; the other processes now call the jit
            # function themselves (this should just read the cache)
            compiled = local_jit(*args, **kwargs)

        return compiled

    # Hand back the MPI-aware wrapper
    return mpi_jit


@mpi_jit_decorator
def jit(ufl_object, form_compiler_parameters=None, mpi_comm=None):
    """Just-in-time compile any provided UFL Form, FiniteElement or
    Mesh using FFC.

    Default parameters from FFC are overridden by parameters["form_compiler"]
    first and then the form_compiler_parameters argument to this function.

    *Arguments*
        ufl_object
            The object to compile: a ufl.Form, a ufl.FiniteElementBase
            or a ufl.Mesh.
        form_compiler_parameters
            Optional mapping of parameter overrides, applied last.
            None is treated as an empty mapping.
        mpi_comm
            Not referenced in the body of this function; the
            mpi_jit_decorator wrapper reads it from the keyword
            arguments of the call.

    *Returns*
        For a ufl.Form, the tuple (compiled_form, module, prefix) with
        compiled_form passed through cpp.make_ufc_form. For a
        ufl.FiniteElementBase, the tuple (ufc_element, ufc_dofmap),
        each passed through the matching cpp.make_ufc_* function. For
        a ufl.Mesh, the result of cpp.make_ufc_coordinate_mapping. No
        other type is unpacked here, so if ffc.jit accepts such an
        object the function falls through and returns None.

    Errors (empty form, UFC signature mismatch, failure inside
    ffc.jit) are reported through cpp.dolfin_error.
    """
    # Check that form is not empty (workaround for bug in UFL where
    # bilinear form 0*u*v*dx becomes the functional 0*dx)
    if isinstance(ufl_object, ufl.Form) and ufl_object.empty():
        cpp.dolfin_error("jit.py",
                         "perform just-in-time compilation of form",
                         "Form is empty. Cannot pass to JIT compiler")

    # Check DOLFIN build time UFC matches currently loaded UFC
    if cpp.__ufcsignature__ != ffc.ufc_signature():
        # The two literals are concatenated by the parser, so the
        # first one must end with a space.
        cpp.dolfin_error("jit.py",
                         "perform just-in-time compilation of form",
                         "DOLFIN was not compiled against matching "
                         "UFC from current FFC installation.")

    # Prepare form compiler parameters with overrides from dolfin and kwargs
    p = ffc_default_parameters()
    p.update(parameters["form_compiler"])
    p.update(form_compiler_parameters or {})

    # Execute!
    try:
        result = ffc.jit(ufl_object, parameters=p)
    except Exception:
        # Forward the full traceback text so the reported error shows
        # where inside ffc.jit the failure happened
        tb_text = traceback.format_exc()
        cpp.dolfin_error("jit.py",
                         "perform just-in-time compilation of form",
                         "ffc.jit failed with message:\n%s" % (tb_text,))

    # Unpack result (ffc.jit returns different tuples based on input type)
    if isinstance(ufl_object, ufl.Form):
        compiled_form, module, prefix = result
        compiled_form = cpp.make_ufc_form(compiled_form)
        return compiled_form, module, prefix
    elif isinstance(ufl_object, ufl.FiniteElementBase):
        ufc_element, ufc_dofmap = result
        ufc_element = cpp.make_ufc_finite_element(ufc_element)
        ufc_dofmap = cpp.make_ufc_dofmap(ufc_dofmap)
        return ufc_element, ufc_dofmap
    elif isinstance(ufl_object, ufl.Mesh):
        ufc_coordinate_mapping = result
        ufc_coordinate_mapping = cpp.make_ufc_coordinate_mapping(ufc_coordinate_mapping)
        return ufc_coordinate_mapping
