#include "colmap/sfm/incremental_mapper.h"

#include "colmap/controllers/incremental_pipeline.h"

#include "pycolmap/helpers.h"
#include "pycolmap/pybind11_extension.h"

#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

using namespace colmap;
using namespace pybind11::literals;
namespace py = pybind11;

// Registers the incremental-pipeline bindings on module `m`:
//   - IncrementalPipelineOptions (made a dataclass via MakeDataclass),
//   - the IncrementalMapperCallback and IncrementalMapperStatus enums (each
//     extended via AddStringToEnumConstructor, presumably so they can also be
//     constructed from their name string — see that helper),
//   - the IncrementalPipeline controller class.
void BindIncrementalPipeline(py::module& m) {
  using Opts = IncrementalPipelineOptions;
  auto PyOpts = py::classh<Opts>(m, "IncrementalPipelineOptions");
  PyOpts.def(py::init<>())
      .def_readwrite(
          "min_num_matches",
          &Opts::min_num_matches,
          "The minimum number of matches for inlier matches to be considered.")
      .def_readwrite(
          "ignore_watermarks",
          &Opts::ignore_watermarks,
          "Whether to ignore the inlier matches of watermark image pairs.")
      .def_readwrite("multiple_models",
                     &Opts::multiple_models,
                     "Whether to reconstruct multiple sub-models.")
      .def_readwrite("max_num_models",
                     &Opts::max_num_models,
                     "The maximum number of sub-models to reconstruct.")
      .def_readwrite(
          "max_model_overlap",
          &Opts::max_model_overlap,
          "The maximum number of overlapping images between sub-models. If the "
          "current sub-model shares more than this number of images with "
          "another model, then the reconstruction is stopped.")
      .def_readwrite("min_model_size",
                     &Opts::min_model_size,
                     "The minimum number of registered images of a sub-model, "
                     "otherwise the sub-model is discarded. Note that the "
                     "first sub-model is always kept independent of size.")
      .def_readwrite("init_image_id1",
                     &Opts::init_image_id1,
                     "The image identifier of the first image used to "
                     "initialize the reconstruction.")
      .def_readwrite(
          "init_image_id2",
          &Opts::init_image_id2,
          "The image identifier of the second image used to initialize the "
          "reconstruction. Determined automatically if left unspecified.")
      .def_readwrite("init_num_trials",
                     &Opts::init_num_trials,
                     "The number of trials to initialize the reconstruction.")
      .def_readwrite("extract_colors",
                     &Opts::extract_colors,
                     "Whether to extract colors for reconstructed points.")
      .def_readwrite("num_threads",
                     &Opts::num_threads,
                     "The number of threads to use during reconstruction.")
      .def_readwrite(
          "random_seed",
          &Opts::random_seed,
          "PRNG seed for all stochastic methods during reconstruction.")
      .def_readwrite("min_focal_length_ratio",
                     &Opts::min_focal_length_ratio,
                     "The threshold used to filter and ignore images with "
                     "degenerate intrinsics.")
      .def_readwrite("max_focal_length_ratio",
                     &Opts::max_focal_length_ratio,
                     "The threshold used to filter and ignore images with "
                     "degenerate intrinsics.")
      .def_readwrite("max_extra_param",
                     &Opts::max_extra_param,
                     "The threshold used to filter and ignore images with "
                     "degenerate intrinsics.")
      .def_readwrite(
          "ba_refine_focal_length",
          &Opts::ba_refine_focal_length,
          "Whether to refine the focal length during the reconstruction.")
      .def_readwrite(
          "ba_refine_principal_point",
          &Opts::ba_refine_principal_point,
          "Whether to refine the principal point during the reconstruction.")
      .def_readwrite(
          "ba_refine_extra_params",
          &Opts::ba_refine_extra_params,
          "Whether to refine extra parameters during the reconstruction.")
      .def_readwrite("ba_refine_sensor_from_rig",
                     &Opts::ba_refine_sensor_from_rig,
                     "Whether to refine the sensor_from_rig transformations "
                     "during the reconstruction.")
      .def_readwrite(
          "ba_min_num_residuals_for_cpu_multi_threading",
          &Opts::ba_min_num_residuals_for_cpu_multi_threading,
          "The minimum number of residuals per bundle adjustment problem to "
          "enable multi-threading solving of the problems.")
      .def_readwrite(
          "ba_local_function_tolerance",
          &Opts::ba_local_function_tolerance,
          "Ceres solver function tolerance for local bundle adjustment.")
      .def_readwrite(
          "ba_local_max_num_iterations",
          &Opts::ba_local_max_num_iterations,
          "The maximum number of local bundle adjustment iterations.")
      .def_readwrite(
          "ba_global_frames_ratio",
          &Opts::ba_global_frames_ratio,
          "The growth rates after which to perform global bundle adjustment.")
      .def_readwrite(
          "ba_global_points_ratio",
          &Opts::ba_global_points_ratio,
          "The growth rates after which to perform global bundle adjustment.")
      .def_readwrite(
          "ba_global_frames_freq",
          &Opts::ba_global_frames_freq,
          "The growth rates after which to perform global bundle adjustment.")
      .def_readwrite(
          "ba_global_points_freq",
          &Opts::ba_global_points_freq,
          "The growth rates after which to perform global bundle adjustment.")
      .def_readwrite(
          "ba_global_function_tolerance",
          &Opts::ba_global_function_tolerance,
          "Ceres solver function tolerance for global bundle adjustment.")
      .def_readwrite(
          "ba_global_max_num_iterations",
          &Opts::ba_global_max_num_iterations,
          "The maximum number of global bundle adjustment iterations.")
      .def_readwrite(
          "ba_local_max_refinements",
          &Opts::ba_local_max_refinements,
          "The thresholds for iterative bundle adjustment refinements.")
      .def_readwrite(
          "ba_local_max_refinement_change",
          &Opts::ba_local_max_refinement_change,
          "The thresholds for iterative bundle adjustment refinements.")
      .def_readwrite(
          "ba_global_max_refinements",
          &Opts::ba_global_max_refinements,
          "The thresholds for iterative bundle adjustment refinements.")
      .def_readwrite(
          "ba_global_max_refinement_change",
          &Opts::ba_global_max_refinement_change,
          "The thresholds for iterative bundle adjustment refinements.")
      .def_readwrite("ba_use_gpu",
                     &Opts::ba_use_gpu,
                     "Whether to use Ceres' CUDA sparse linear algebra "
                     "library, if available.")
      .def_readwrite("ba_gpu_index",
                     &Opts::ba_gpu_index,
                     "Index of CUDA GPU to use for BA, if available.")
      .def_readwrite("use_prior_position",
                     &Opts::use_prior_position,
                     "Whether to use priors on the camera positions.")
      .def_readwrite("use_robust_loss_on_prior_position",
                     &Opts::use_robust_loss_on_prior_position,
                     "Whether to use a robust loss on prior camera positions.")
      .def_readwrite("prior_position_loss_scale",
                     &Opts::prior_position_loss_scale,
                     "Threshold on the residual for the robust position prior "
                     "loss (chi2 for 3DOF at 95% = 7.815).")
      .def_readwrite("snapshot_path",
                     &Opts::snapshot_path,
                     "Path to a folder in which reconstruction snapshots will "
                     "be saved during incremental reconstruction.")
      .def_readwrite("snapshot_frames_freq",
                     &Opts::snapshot_frames_freq,
                     "Frequency of registered images according to which "
                     "reconstruction snapshots will be saved.")
      .def_readwrite(
          "image_names",
          &Opts::image_names,
          "Optional list of image names to reconstruct. If no images are "
          "specified, all images will be reconstructed by default.")
      .def_readwrite("fix_existing_frames",
                     &Opts::fix_existing_frames,
                     "If reconstruction is provided as input, fix the existing "
                     "frame poses.")
      .def_readwrite(
          "constant_rigs",
          &Opts::constant_rigs,
          "List of rigs for which to fix the sensor_from_rig transformation, "
          "independent of ba_refine_sensor_from_rig.")
      // The refinement flags referenced here are exposed on this class under
      // their ba_-prefixed names (see the ba_refine_* bindings above).
      .def_readwrite("constant_cameras",
                     &Opts::constant_cameras,
                     "List of cameras for which to fix the camera parameters "
                     "independent of ba_refine_focal_length, "
                     "ba_refine_principal_point, and ba_refine_extra_params.")
      .def_readwrite(
          "max_runtime_seconds",
          &Opts::max_runtime_seconds,
          "Maximum runtime in seconds for the reconstruction process. If set "
          "to a non-positive value, the process will run until completion.")
      .def_readwrite(
          "mapper", &Opts::mapper, "Options of the IncrementalMapper.")
      .def_readwrite("triangulation",
                     &Opts::triangulation,
                     "Options of the IncrementalTriangulator.")
      .def("get_mapper", &Opts::Mapper)
      .def("get_triangulation", &Opts::Triangulation)
      .def("get_local_bundle_adjustment", &Opts::LocalBundleAdjustment)
      .def("get_global_bundle_adjustment", &Opts::GlobalBundleAdjustment)
      .def("is_initial_pair_provided", &Opts::IsInitialPairProvided)
      .def("check", &Opts::Check);
  MakeDataclass(PyOpts);

  using CallbackType = IncrementalPipeline::CallbackType;
  auto PyCallbackType =
      py::enum_<CallbackType>(m, "IncrementalMapperCallback")
          .value("INITIAL_IMAGE_PAIR_REG_CALLBACK",
                 CallbackType::INITIAL_IMAGE_PAIR_REG_CALLBACK)
          .value("NEXT_IMAGE_REG_CALLBACK",
                 CallbackType::NEXT_IMAGE_REG_CALLBACK)
          .value("LAST_IMAGE_REG_CALLBACK",
                 CallbackType::LAST_IMAGE_REG_CALLBACK);
  AddStringToEnumConstructor(PyCallbackType);

  using Status = IncrementalPipeline::Status;
  auto PyStatus = py::enum_<Status>(m, "IncrementalMapperStatus")
                      .value("NO_INITIAL_PAIR", Status::NO_INITIAL_PAIR)
                      .value("BAD_INITIAL_PAIR", Status::BAD_INITIAL_PAIR)
                      .value("SUCCESS", Status::SUCCESS)
                      .value("INTERRUPTED", Status::INTERRUPTED);
  AddStringToEnumConstructor(PyStatus);

  py::classh<IncrementalPipeline>(m, "IncrementalPipeline")
      .def(py::init<std::shared_ptr<const IncrementalPipelineOptions>,
                    const std::string&,
                    const std::string&,
                    std::shared_ptr<ReconstructionManager>>(),
           "options"_a,
           "image_path"_a,
           "database_path"_a,
           "reconstruction_manager"_a)
      .def_property_readonly("options", &IncrementalPipeline::Options)
      .def_property_readonly("image_path", &IncrementalPipeline::ImagePath)
      .def_property_readonly("database_path",
                             &IncrementalPipeline::DatabasePath)
      .def_property_readonly("reconstruction_manager",
                             &IncrementalPipeline::ReconstructionManager)
      .def_property_readonly("database_cache",
                             &IncrementalPipeline::DatabaseCache)
      .def("add_callback", &IncrementalPipeline::AddCallback, "id"_a, "func"_a)
      .def("callback", &IncrementalPipeline::Callback, "id"_a)
      .def("load_database", &IncrementalPipeline::LoadDatabase)
      .def("check_run_global_refinement",
           &IncrementalPipeline::CheckRunGlobalRefinement,
           "reconstruction"_a,
           "ba_prev_num_reg_images"_a,
           "ba_prev_num_points"_a)
      .def("reconstruct",
           &IncrementalPipeline::Reconstruct,
           "mapper"_a,
           "mapper_options"_a,
           "continue_reconstruction"_a)
      .def("reconstruct_sub_model",
           &IncrementalPipeline::ReconstructSubModel,
           "mapper"_a,
           "mapper_options"_a,
           "reconstruction"_a)
      .def("initialize_reconstruction",
           &IncrementalPipeline::InitializeReconstruction,
           "mapper"_a,
           "mapper_options"_a,
           "reconstruction"_a)
      .def("run", &IncrementalPipeline::Run);
}

// Registers the ImageSelectionMethod enum and IncrementalMapperOptions
// (IncrementalMapper::Options, made a dataclass via MakeDataclass) on `m`.
void BindIncrementalMapperOptions(py::module& m) {
  using ImageSelection = IncrementalMapper::Options::ImageSelectionMethod;
  auto PyImageSelectionMethod =
      py::enum_<ImageSelection>(m, "ImageSelectionMethod")
          .value("MAX_VISIBLE_POINTS_NUM",
                 ImageSelection::MAX_VISIBLE_POINTS_NUM)
          .value("MAX_VISIBLE_POINTS_RATIO",
                 ImageSelection::MAX_VISIBLE_POINTS_RATIO)
          .value("MIN_UNCERTAINTY", ImageSelection::MIN_UNCERTAINTY);
  AddStringToEnumConstructor(PyImageSelectionMethod);

  using Opts = IncrementalMapper::Options;
  auto PyOpts = py::classh<Opts>(m, "IncrementalMapperOptions");
  PyOpts.def(py::init<>())
      .def_readwrite("init_min_num_inliers",
                     &Opts::init_min_num_inliers,
                     "Minimum number of inliers for initial image pair.")
      .def_readwrite("init_max_error",
                     &Opts::init_max_error,
                     "Maximum error in pixels for two-view geometry estimation "
                     "for initial image pair.")
      .def_readwrite("init_max_forward_motion",
                     &Opts::init_max_forward_motion,
                     "Maximum forward motion for initial image pair.")
      .def_readwrite("init_min_tri_angle",
                     &Opts::init_min_tri_angle,
                     "Minimum triangulation angle for initial image pair.")
      .def_readwrite(
          "init_max_reg_trials",
          &Opts::init_max_reg_trials,
          "Maximum number of trials to use an image for initialization.")
      .def_readwrite("abs_pose_max_error",
                     &Opts::abs_pose_max_error,
                     "Maximum reprojection error in absolute pose estimation.")
      .def_readwrite("abs_pose_min_num_inliers",
                     &Opts::abs_pose_min_num_inliers,
                     "Minimum number of inliers in absolute pose estimation.")
      .def_readwrite("abs_pose_min_inlier_ratio",
                     &Opts::abs_pose_min_inlier_ratio,
                     "Minimum inlier ratio in absolute pose estimation.")
      .def_readwrite(
          "abs_pose_refine_focal_length",
          &Opts::abs_pose_refine_focal_length,
          "Whether to estimate the focal length in absolute pose estimation.")
      .def_readwrite("abs_pose_refine_extra_params",
                     &Opts::abs_pose_refine_extra_params,
                     "Whether to estimate the extra parameters in absolute "
                     "pose estimation.")
      .def_readwrite("ba_local_num_images",
                     &Opts::ba_local_num_images,
                     "Number of images to optimize in local bundle adjustment.")
      .def_readwrite("ba_local_min_tri_angle",
                     &Opts::ba_local_min_tri_angle,
                     "Minimum triangulation angle for images to be chosen in "
                     "local bundle adjustment.")
      .def_readwrite(
          "ba_global_ignore_redundant_points3D",
          &Opts::ba_global_ignore_redundant_points3D,
          "Whether to ignore redundant 3D points in bundle adjustment when "
          "jointly optimizing all parameters. If this is enabled, then the "
          "bundle adjustment problem is first solved with a reduced set of 3D "
          "points and then the remaining 3D points are optimized in a second "
          "step with all other parameters fixed. Points explicitly configured "
          "as constant or variable are not ignored. This is only activated "
          "when the reconstruction has reached sufficient size with at least "
          "10 registered frames.")
      // NOTE: The Python attribute name intentionally differs from the C++
      // field (ba_global_ignore_redundant_points3D_min_coverage_gain); it is
      // kept as-is so existing Python code keeps working.
      .def_readwrite(
          "ba_global_prune_points_min_coverage_gain",
          &Opts::ba_global_ignore_redundant_points3D_min_coverage_gain,
          "The minimum coverage gain for any 3D point to be "
          "included in the optimization. A larger value means "
          "more 3D points are ignored.")
      .def_readwrite("min_focal_length_ratio",
                     &Opts::min_focal_length_ratio,
                     "The threshold used to filter and ignore images with "
                     "degenerate intrinsics.")
      .def_readwrite("max_focal_length_ratio",
                     &Opts::max_focal_length_ratio,
                     "The threshold used to filter and ignore images with "
                     "degenerate intrinsics.")
      .def_readwrite("max_extra_param",
                     &Opts::max_extra_param,
                     "The threshold used to filter and ignore images with "
                     "degenerate intrinsics.")
      .def_readwrite("filter_max_reproj_error",
                     &Opts::filter_max_reproj_error,
                     "Maximum reprojection error in pixels for observations.")
      .def_readwrite(
          "filter_min_tri_angle",
          &Opts::filter_min_tri_angle,
          "Minimum triangulation angle in degrees for stable 3D points.")
      .def_readwrite("max_reg_trials",
                     &Opts::max_reg_trials,
                     "Maximum number of trials to register an image.")
      .def_readwrite("fix_existing_frames",
                     &Opts::fix_existing_frames,
                     "If reconstruction is provided as input, fix the existing "
                     "frame poses.")
      .def_readwrite(
          "constant_rigs",
          &Opts::constant_rigs,
          "List of rigs for which to fix the sensor_from_rig transformation, "
          "independent of ba_refine_sensor_from_rig.")
      .def_readwrite("constant_cameras",
                     &Opts::constant_cameras,
                     "List of cameras for which to fix the camera parameters "
                     "independent of refine_focal_length, "
                     "refine_principal_point, and refine_extra_params.")
      .def_readwrite("num_threads", &Opts::num_threads, "Number of threads.")
      .def_readwrite(
          "random_seed",
          &Opts::random_seed,
          "PRNG seed for all stochastic methods during reconstruction.")
      .def_readwrite("image_selection_method",
                     &Opts::image_selection_method,
                     "Method to find and select next best image to register.")
      .def("check", &Opts::Check);
  MakeDataclass(PyOpts);
}

// Registers the IncrementalMapper-related bindings on `m`:
//   - IncrementalMapperOptions and ImageSelectionMethod (via
//     BindIncrementalMapperOptions),
//   - LocalBundleAdjustmentReport (made a dataclass via MakeDataclass),
//   - the IncrementalMapper class itself.
void BindIncrementalMapperImpl(py::module& m) {
  BindIncrementalMapperOptions(m);

  // bind local bundle adjustment report
  using LocalBAReport = IncrementalMapper::LocalBundleAdjustmentReport;
  auto PyLocalBAReport =
      py::classh<LocalBAReport>(m, "LocalBundleAdjustmentReport");
  PyLocalBAReport.def(py::init<>())
      .def_readwrite("num_merged_observations",
                     &LocalBAReport::num_merged_observations)
      .def_readwrite("num_completed_observations",
                     &LocalBAReport::num_completed_observations)
      .def_readwrite("num_filtered_observations",
                     &LocalBAReport::num_filtered_observations)
      .def_readwrite("num_adjusted_observations",
                     &LocalBAReport::num_adjusted_observations);
  MakeDataclass(PyLocalBAReport);

  // bind incremental mapper
  // TODO: migrate comments. improve formatting
  py::classh<IncrementalMapper>(m, "IncrementalMapper")
      .def(py::init<std::shared_ptr<const DatabaseCache>>(), "database_cache"_a)
      .def("begin_reconstruction",
           &IncrementalMapper::BeginReconstruction,
           "reconstruction"_a)
      .def("end_reconstruction",
           &IncrementalMapper::EndReconstruction,
           "discard"_a)
      // Returns ((image_id1, image_id2), cam2_from_cam1) on success and None
      // otherwise. The ids are passed to FindInitialImagePair by reference
      // and read back afterwards, so the returned ids may differ from the
      // ones passed in.
      .def(
          "find_initial_image_pair",
          [](IncrementalMapper& self,
             const IncrementalMapper::Options& options,
             int image_id1,
             int image_id2)
              -> py::typing::Optional<
                  py::typing::Tuple<py::typing::Tuple<image_t, image_t>,
                                    Rigid3d>> {
            // The ids are taken as signed ints so Python callers can pass -1.
            // The explicit cast documents the intended wrap-around of -1 (int)
            // to kInvalidImageId (uint32_t) instead of relying on an implicit
            // signed -> unsigned conversion.
            image_t image_id1_cast = static_cast<image_t>(image_id1);
            image_t image_id2_cast = static_cast<image_t>(image_id2);
            Rigid3d cam2_from_cam1;
            const bool success = self.FindInitialImagePair(
                options, image_id1_cast, image_id2_cast, cam2_from_cam1);
            if (success) {
              const auto pair = std::make_pair(image_id1_cast, image_id2_cast);
              return py::cast(std::make_pair(pair, cam2_from_cam1));
            } else {
              return py::none();
            }
          },
          "options"_a,
          "image_id1"_a,
          "image_id2"_a)
      // Returns cam2_from_cam1 on success and None otherwise.
      .def(
          "estimate_initial_two_view_geometry",
          [](IncrementalMapper& self,
             const IncrementalMapper::Options& options,
             const image_t image_id1,
             const image_t image_id2) -> py::typing::Optional<Rigid3d> {
            Rigid3d cam2_from_cam1;
            const bool success = self.EstimateInitialTwoViewGeometry(
                options, image_id1, image_id2, cam2_from_cam1);
            if (success) {
              return py::cast(cam2_from_cam1);
            } else {
              return py::none();
            }
          },
          "options"_a,
          "image_id1"_a,
          "image_id2"_a)
      .def("register_initial_image_pair",
           &IncrementalMapper::RegisterInitialImagePair,
           "options"_a,
           "two_view_geometry"_a,
           "image_id1"_a,
           "image_id2"_a)
      .def("find_next_images", &IncrementalMapper::FindNextImages, "options"_a)
      .def("register_next_image",
           &IncrementalMapper::RegisterNextImage,
           "options"_a,
           "image_id"_a)
      .def("triangulate_image",
           &IncrementalMapper::TriangulateImage,
           "tri_options"_a,
           "image_id"_a)
      .def("retriangulate", &IncrementalMapper::Retriangulate, "tri_options"_a)
      .def("complete_tracks",
           &IncrementalMapper::CompleteTracks,
           "tri_options"_a)
      .def("merge_tracks", &IncrementalMapper::MergeTracks, "tri_options"_a)
      .def("complete_and_merge_tracks",
           &IncrementalMapper::CompleteAndMergeTracks,
           "tri_options"_a)
      .def("adjust_local_bundle",
           &IncrementalMapper::AdjustLocalBundle,
           "options"_a,
           "ba_options"_a,
           "tri_options"_a,
           "image_id"_a,
           "point3D_ids"_a)
      .def("iterative_local_refinement",
           &IncrementalMapper::IterativeLocalRefinement,
           "max_num_refinements"_a,
           "max_refinement_change"_a,
           "options"_a,
           "ba_options"_a,
           "tri_options"_a,
           "image_id"_a)
      .def("find_local_bundle",
           &IncrementalMapper::FindLocalBundle,
           "options"_a,
           "image_id"_a)
      .def("adjust_global_bundle",
           &IncrementalMapper::AdjustGlobalBundle,
           "options"_a,
           "ba_options"_a)
      .def("iterative_global_refinement",
           &IncrementalMapper::IterativeGlobalRefinement,
           "max_num_refinements"_a,
           "max_refinement_change"_a,
           "options"_a,
           "ba_options"_a,
           "tri_options"_a,
           "normalize_reconstruction"_a = true)
      .def("filter_frames", &IncrementalMapper::FilterFrames, "options"_a)
      .def("filter_points", &IncrementalMapper::FilterPoints, "options"_a)
      .def_property_readonly("reconstruction",
                             &IncrementalMapper::Reconstruction)
      .def_property_readonly("observation_manager",
                             &IncrementalMapper::ObservationManager)
      .def_property_readonly("triangulator", &IncrementalMapper::Triangulator)
      .def_property_readonly("filtered_frames",
                             &IncrementalMapper::FilteredFrames)
      .def_property_readonly("existing_frame_ids",
                             &IncrementalMapper::ExistingFrameIds)
      .def("reset_initialization_stats",
           &IncrementalMapper::ResetInitializationStats)
      .def_property_readonly("num_reg_frames_per_rig",
                             &IncrementalMapper::NumRegFramesPerRig)
      .def_property_readonly("num_reg_images_per_camera",
                             &IncrementalMapper::NumRegImagesPerCamera)
      .def("num_total_reg_images", &IncrementalMapper::NumTotalRegImages)
      .def("num_shared_reg_images", &IncrementalMapper::NumSharedRegImages)
      .def("get_modified_points3D", &IncrementalMapper::GetModifiedPoints3D)
      .def("clear_modified_points3D",
           &IncrementalMapper::ClearModifiedPoints3D);
}

// Entry point that registers all incremental-mapper related bindings on `m`.
// The mapper bindings (which include IncrementalMapperOptions) are registered
// before the pipeline bindings. IncrementalPipelineOptions exposes a `mapper`
// field and a `get_mapper` method, so keep this order. NOTE(review):
// presumably this is so those types are already known to pybind11 when the
// pipeline signatures are generated — confirm before reordering.
void BindIncrementalMapper(py::module& m) {
  BindIncrementalMapperImpl(m);
  BindIncrementalPipeline(m);
}
