/*
  This file is part of CDO. CDO is a collection of Operators to manipulate and analyse Climate model Data.

  Author: Uwe Schulzweida

*/

/*
   This module contains the following operators:

      Trend      trend           Trend
*/

#include <cdi.h>

#include "process_int.h"
#include "cdo_vlist.h"
#include "cdo_options.h"
#include "datetime.h"
#include "pmlist.h"
#include "param_conversion.h"
#include "field_functions.h"
#include "arithmetic.h"

static void
trendGetParameter(bool &tstepIsEqual)
{
  auto pargc = cdo_operator_argc();
  if (pargc)
    {
      const auto &pargv = cdo_get_oper_argv();

      KVList kvlist;
      kvlist.name = cdo_module_name();
      if (kvlist.parse_arguments(pargv) != 0) cdo_abort("Parse error!");
      if (Options::cdoVerbose) kvlist.print();

      for (const auto &kv : kvlist)
        {
          const auto &key = kv.key;
          if (kv.nvalues > 1) cdo_abort("Too many values for parameter key >%s<!", key);
          if (kv.nvalues < 1) cdo_abort("Missing value for parameter key >%s<!", key);
          const auto &value = kv.values[0];

          // clang-format off
          if (key == "equal") tstepIsEqual = parameter_to_bool(value);
          else cdo_abort("Invalid parameter key >%s<!", key);
          // clang-format on
        }
    }
}

class Trend : public Process
{
public:
  using Process::Process;
  inline static CdoModule module = {
    .name = "Trend",
    .operators = { { "trend", TrendHelp } },
    .aliases = {},
    .mode = EXPOSED,     // Module mode: 0:intern 1:extern
    .number = CDI_REAL,  // Allowed number type
    .constraints = { 1, 2, OnlyFirst },
  };
  inline static RegisterEntry<Trend> registration = RegisterEntry<Trend>(module);

  static const int nwork = 5;
  FieldVector2D work[nwork];

  CdoStreamID streamID1;
  int taxisID1;

  CdoStreamID streamID2;
  int taxisID2;

  CdoStreamID streamID3;

  int maxrecs;

  bool tstepIsEqual = true;

  int calendar;

  VarList varList1;
  Field field1, field2;
  std::vector<RecordInfo> recList;

public:
  void
  init()
  {
    trendGetParameter(tstepIsEqual);

    streamID1 = cdo_open_read(0);

    auto vlistID1 = cdo_stream_inq_vlist(streamID1);
    auto vlistID2 = vlistDuplicate(vlistID1);

    vlist_unpack(vlistID2);

    vlistDefNtsteps(vlistID2, 1);

    taxisID1 = vlistInqTaxis(vlistID1);
    taxisID2 = taxisDuplicate(taxisID1);
    vlistDefTaxis(vlistID2, taxisID2);

    varList_init(varList1, vlistID1);

    auto nvars = vlistNvars(vlistID1);

    maxrecs = vlistNrecs(vlistID1);
    recList = std::vector<RecordInfo>(maxrecs);

    for (int varID = 0; varID < nvars; ++varID) vlistDefVarDatatype(vlistID2, varID, CDI_DATATYPE_FLT64);

    streamID2 = cdo_open_write(1);
    streamID3 = cdo_open_write(2);

    cdo_def_vlist(streamID2, vlistID2);
    cdo_def_vlist(streamID3, vlistID2);

    auto gridsizemax = vlistGridsizeMax(vlistID1);

    field1.resize(gridsizemax);
    field2.resize(gridsizemax);

    for (auto &w : work) fields_from_vlist(vlistID1, w, FIELD_VEC, 0);

    calendar = taxisInqCalendar(taxisID1);
  }

  void
  run()
  {
    CheckTimeIncr checkTimeIncr;
    JulianDate julianDate0;
    double deltat1 = 0.0;
    CdiDateTime vDateTime{};

    int tsID = 0;

    while (true)
      {
        auto nrecs = cdo_stream_inq_timestep(streamID1, tsID);
        if (nrecs == 0) break;

        vDateTime = taxisInqVdatetime(taxisID1);

        if (tstepIsEqual) check_time_increment(tsID, calendar, vDateTime, checkTimeIncr);
        auto zj = tstepIsEqual ? (double) tsID : delta_time_step_0(tsID, calendar, vDateTime, julianDate0, deltat1);

        for (int recID = 0; recID < nrecs; ++recID)
          {
            int varID, levelID;
            cdo_inq_record(streamID1, &varID, &levelID);

            recList[recID].set(varID, levelID);

            cdo_read_record(streamID1, field1.vec_d.data(), &field1.numMissVals);

            auto gridsize = varList1[varID].gridsize;
            auto missval = varList1[varID].missval;

            auto &sumj = work[0][varID][levelID].vec_d;
            auto &sumjj = work[1][varID][levelID].vec_d;
            auto &sumjx = work[2][varID][levelID].vec_d;
            auto &sumx = work[3][varID][levelID].vec_d;
            auto &zn = work[4][varID][levelID].vec_d;

            auto trend_sum = [&](auto i, auto value, auto is_EQ) {
              if (!is_EQ(value, missval))
                {
                  sumj[i] += zj;
                  sumjj[i] += zj * zj;
                  sumjx[i] += zj * value;
                  sumx[i] += value;
                  zn[i]++;
                }
            };

            if (std::isnan(missval))
              for (size_t i = 0; i < gridsize; ++i) trend_sum(i, field1.vec_d[i], dbl_is_equal);
            else
              for (size_t i = 0; i < gridsize; ++i) trend_sum(i, field1.vec_d[i], is_equal);
          }

        tsID++;
      }

    taxisDefVdatetime(taxisID2, vDateTime);
    cdo_def_timestep(streamID2, 0);
    cdo_def_timestep(streamID3, 0);

    for (int recID = 0; recID < maxrecs; ++recID)
      {
        auto [varID, levelID] = recList[recID].get();

        auto gridsize = varList1[varID].gridsize;
        auto missval = varList1[varID].missval;
        auto missval1 = missval;
        auto missval2 = missval;
        field1.size = gridsize;
        field1.missval = missval;
        field2.size = gridsize;
        field2.missval = missval;

        const auto &sumj = work[0][varID][levelID].vec_d;
        const auto &sumjj = work[1][varID][levelID].vec_d;
        const auto &sumjx = work[2][varID][levelID].vec_d;
        const auto &sumx = work[3][varID][levelID].vec_d;
        const auto &zn = work[4][varID][levelID].vec_d;

        auto trend_kernel = [&](auto i, auto is_EQ) {
          auto temp1 = SUBM(sumjx[i], DIVM(MULM(sumj[i], sumx[i]), zn[i]));
          auto temp2 = SUBM(sumjj[i], DIVM(MULM(sumj[i], sumj[i]), zn[i]));

          field2.vec_d[i] = DIVM(temp1, temp2);
          field1.vec_d[i] = SUBM(DIVM(sumx[i], zn[i]), MULM(DIVM(sumj[i], zn[i]), field2.vec_d[i]));
        };

        if (std::isnan(missval))
          for (size_t i = 0; i < gridsize; ++i) trend_kernel(i, dbl_is_equal);
        else
          for (size_t i = 0; i < gridsize; ++i) trend_kernel(i, is_equal);

        cdo_def_record(streamID2, varID, levelID);
        cdo_write_record(streamID2, field1.vec_d.data(), field_num_miss(field1));

        cdo_def_record(streamID3, varID, levelID);
        cdo_write_record(streamID3, field2.vec_d.data(), field_num_miss(field2));
      }
  }

  void
  close()
  {
    cdo_stream_close(streamID3);
    cdo_stream_close(streamID2);
    cdo_stream_close(streamID1);
  }
};
