/*
 * Project: MoleCuilder
 * Description: creates and alters molecular systems
 * Copyright (C)  2014 Frederik Heber. All rights reserved.
 *
 *
 *   This file is part of MoleCuilder.
 *
 *    MoleCuilder is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    MoleCuilder is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoleCuilder.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * SphericalPointDistribution.cpp
 *
 *  Created on: May 30, 2014
 *      Author: heber
 */

// include config.h
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "CodePatterns/MemDebug.hpp"

#include "SphericalPointDistribution.hpp"

#include "CodePatterns/Assert.hpp"
#include "CodePatterns/IteratorAdaptors.hpp"
#include "CodePatterns/Log.hpp"
#include "CodePatterns/toString.hpp"

#include <algorithm>
#include <boost/math/quaternion.hpp>
#include <cmath>
#include <functional>
#include <iterator>
#include <limits>
#include <list>
#include <vector>
#include <map>

#include "LinearAlgebra/Line.hpp"
#include "LinearAlgebra/Plane.hpp"
#include "LinearAlgebra/RealSpaceMatrix.hpp"
#include "LinearAlgebra/Vector.hpp"

// static entities
const double SphericalPointDistribution::SQRT_3(sqrt(3.0));
const double SphericalPointDistribution::warn_amplitude = 1e-2;

typedef std::vector<double> DistanceArray_t;

inline
DistanceArray_t calculatePairwiseDistances(
    const std::vector<Vector> &_points,
    const SphericalPointDistribution::IndexList_t &_indices
    )
{
  DistanceArray_t result;
  for (SphericalPointDistribution::IndexList_t::const_iterator firstiter = _indices.begin();
      firstiter != _indices.end(); ++firstiter) {
    for (SphericalPointDistribution::IndexList_t::const_iterator seconditer = firstiter;
        seconditer != _indices.end(); ++seconditer) {
      if (firstiter == seconditer)
        continue;
      const double distance = (_points[*firstiter] - _points[*seconditer]).NormSquared();
      result.push_back(distance);
    }
  }
  return result;
}

// class generator: taken from www.cplusplus.com example std::generate
struct c_unique {
  int current;
  c_unique() {current=0;}
  int operator()() {return current++;}
} UniqueNumber;

/** Returns squared L2 error of the given \a _Matching.
 *
 * We compare the pair-wise distances of each associated matching
 * and check whether these distances each match between \a _old and
 * \a _new.
 *
 * \param _old first set of returnpolygon (fewer or equal to \a _new)
 * \param _new second set of returnpolygon
 * \param _Matching matching between the two sets
 * \return pair with L1 and squared L2 error
 */
std::pair<double, double> SphericalPointDistribution::calculateErrorOfMatching(
    const std::vector<Vector> &_old,
    const std::vector<Vector> &_new,
    const IndexList_t &_Matching)
{
  std::pair<double, double> errors( std::make_pair( 0., 0. ) );

  if (_Matching.size() > 1) {
    LOG(3, "INFO: Matching is " << _Matching);

    // calculate all pair-wise distances
    IndexList_t keys(_Matching.size());
    std::generate (keys.begin(), keys.end(), UniqueNumber);
    const DistanceArray_t firstdistances = calculatePairwiseDistances(_old, keys);
    const DistanceArray_t seconddistances = calculatePairwiseDistances(_new, _Matching);

    ASSERT( firstdistances.size() == seconddistances.size(),
        "calculateL2ErrorOfMatching() - mismatch in pair-wise distance array sizes.");
    DistanceArray_t::const_iterator firstiter = firstdistances.begin();
    DistanceArray_t::const_iterator seconditer = seconddistances.begin();
    for (;(firstiter != firstdistances.end()) && (seconditer != seconddistances.end());
        ++firstiter, ++seconditer) {
      const double gap = *firstiter - *seconditer;
      // L1 error
      if (errors.first < gap)
        errors.first = gap;
      // L2 error
      errors.second += gap*gap;
    }
  } else
    ELOG(3, "calculateErrorOfMatching() - Given matching's size is less than 2.");
  LOG(3, "INFO: Resulting errors for matching (L1, L2): "
      << errors.first << "," << errors.second << ".");

  return errors;
}

SphericalPointDistribution::Polygon_t SphericalPointDistribution::removeMatchingPoints(
    const VectorArray_t &_points,
    const IndexList_t &_matchingindices
    )
{
  SphericalPointDistribution::Polygon_t remainingreturnpolygon;
  IndexArray_t indices(_matchingindices.begin(), _matchingindices.end());
  std::sort(indices.begin(), indices.end());
  LOG(4, "DEBUG: sorted matching is " << indices);
  IndexArray_t remainingindices(_points.size(), -1);
  std::generate(remainingindices.begin(), remainingindices.end(), UniqueNumber);
  IndexArray_t::iterator remainiter = std::set_difference(
      remainingindices.begin(), remainingindices.end(),
      indices.begin(), indices.end(),
      remainingindices.begin());
  remainingindices.erase(remainiter, remainingindices.end());
  LOG(4, "DEBUG: remaining indices are " << remainingindices);
  for (IndexArray_t::const_iterator iter = remainingindices.begin();
      iter != remainingindices.end(); ++iter) {
    remainingreturnpolygon.push_back(_points[*iter]);
  }

  return remainingreturnpolygon;
}

/** Recursive function to go through all possible matchings.
 *
 * \param _MCS structure holding global information to the recursion
 * \param _matching current matching being build up
 * \param _indices contains still available indices
 * \param _matchingsize
 */
void SphericalPointDistribution::recurseMatchings(
    MatchingControlStructure &_MCS,
    IndexList_t &_matching,
    IndexList_t _indices,
    unsigned int _matchingsize)
{
  LOG(4, "DEBUG: Recursing with current matching " << _matching
      << ", remaining indices " << _indices
      << ", and sought size " << _matching.size()+_matchingsize);
  //!> threshold for L1 error below which matching is immediately acceptable
  const double L1THRESHOLD = 1e-2;
  if (!_MCS.foundflag) {
    LOG(4, "DEBUG: Current matching has size " << _matching.size() << ", places left " << _matchingsize);
    if (_matchingsize > 0) {
      // go through all indices
      for (IndexList_t::iterator iter = _indices.begin();
          (iter != _indices.end()) && (!_MCS.foundflag);) {
        // add index to matching
        _matching.push_back(*iter);
        LOG(5, "DEBUG: Adding " << *iter << " to matching.");
        // remove index but keep iterator to position (is the next to erase element)
        IndexList_t::iterator backupiter = _indices.erase(iter);
        // recurse with decreased _matchingsize
        recurseMatchings(_MCS, _matching, _indices, _matchingsize-1);
        // re-add chosen index and reset index to new position
        _indices.insert(backupiter, _matching.back());
        iter = backupiter;
        // remove index from _matching to make space for the next one
        _matching.pop_back();
      }
      // gone through all indices then exit recursion
      if (_matching.empty())
        _MCS.foundflag = true;
    } else {
      LOG(3, "INFO: Found matching " << _matching);
      // calculate errors
      std::pair<double, double> errors = calculateErrorOfMatching(
          _MCS.oldpoints, _MCS.newpoints, _matching);
      if (errors.first < L1THRESHOLD) {
        _MCS.bestmatching = _matching;
        _MCS.foundflag = true;
      } else if (_MCS.bestL2 > errors.second) {
        _MCS.bestmatching = _matching;
        _MCS.bestL2 = errors.second;
      }
    }
  }
}

/** Decides by an orthonormal third vector whether the sign of the rotation
 * angle should be negative or positive.
 *
 * \return -1 or 1
 */
inline
double determineSignOfRotation(
    const Vector &_oldPosition,
    const Vector &_newPosition,
    const Vector &_RotationAxis
    )
{
  Vector dreiBein(_oldPosition);
  dreiBein.VectorProduct(_RotationAxis);
  dreiBein.Normalize();
  const double sign =
      (dreiBein.ScalarProduct(_newPosition) < 0.) ? -1. : +1.;
  LOG(6, "DEBUG: oldCenter on plane is " << _oldPosition
      << ", newCenter in plane is " << _newPosition
      << ", and dreiBein is " << dreiBein);
  return sign;
}

/** Finds combinatorially the best matching between points in \a _polygon
 * and \a _newpolygon.
 *
 * We find the matching with the smallest L2 error, where we break when we stumble
 * upon a matching with zero error.
 *
 * \sa recurseMatchings() for going through all matchings
 *
 * \param _polygon here, we have indices 0,1,2,...
 * \param _newpolygon and here we need to find the correct indices
 * \return list of indices: first in \a _polygon goes to first index for \a _newpolygon
 */
SphericalPointDistribution::IndexList_t SphericalPointDistribution::findBestMatching(
    const SphericalPointDistribution::Polygon_t &_polygon,
    const SphericalPointDistribution::Polygon_t &_newpolygon
    )
{
  MatchingControlStructure MCS;
  MCS.foundflag = false;
  MCS.bestL2 = std::numeric_limits<double>::max();
  MCS.oldpoints.insert(MCS.oldpoints.begin(), _polygon.begin(),_polygon.end() );
  MCS.newpoints.insert(MCS.newpoints.begin(), _newpolygon.begin(),_newpolygon.end() );

  // search for bestmatching combinatorially
  {
    // translate polygon into vector to enable index addressing
    IndexList_t indices(_newpolygon.size());
    std::generate(indices.begin(), indices.end(), UniqueNumber);
    IndexList_t matching;

    // walk through all matchings
    const unsigned int matchingsize = _polygon.size();
    ASSERT( matchingsize <= indices.size(),
        "SphericalPointDistribution::matchSphericalPointDistributions() - not enough new points to choose for matching to old ones.");
    recurseMatchings(MCS, matching, indices, matchingsize);
  }
  return MCS.bestmatching;
}

inline
Vector calculateCenter(
    const SphericalPointDistribution::VectorArray_t &_positions,
    const SphericalPointDistribution::IndexList_t &_indices)
{
  Vector Center;
  Center.Zero();
  for (SphericalPointDistribution::IndexList_t::const_iterator iter = _indices.begin();
      iter != _indices.end(); ++iter)
    Center += _positions[*iter];
  if (!_indices.empty())
    Center *= 1./(double)_indices.size();

  return Center;
}

inline
void calculateOldAndNewCenters(
    Vector &_oldCenter,
    Vector &_newCenter,
    const SphericalPointDistribution::VectorArray_t &_referencepositions,
    const SphericalPointDistribution::VectorArray_t &_currentpositions,
    const SphericalPointDistribution::IndexList_t &_bestmatching)
{
  const size_t NumberIds = std::min(_bestmatching.size(), (size_t)3);
  SphericalPointDistribution::IndexList_t continuousIds(NumberIds, -1);
  std::generate(continuousIds.begin(), continuousIds.end(), UniqueNumber);
  _oldCenter = calculateCenter(_referencepositions, continuousIds);
  // C++11 defines a copy_n function ...
  SphericalPointDistribution::IndexList_t::const_iterator enditer = _bestmatching.begin();
  std::advance(enditer, NumberIds);
  SphericalPointDistribution::IndexList_t firstbestmatchingIds(NumberIds, -1);
  std::copy(_bestmatching.begin(), enditer, firstbestmatchingIds.begin());
  _newCenter = calculateCenter( _currentpositions, firstbestmatchingIds);
}

SphericalPointDistribution::Rotation_t SphericalPointDistribution::findPlaneAligningRotation(
    const VectorArray_t &_referencepositions,
    const VectorArray_t &_currentpositions,
    const IndexList_t &_bestmatching
    )
{
#ifndef NDEBUG
  bool dontcheck = false;
#endif
  // initialize to no rotation
  Rotation_t Rotation;
  Rotation.first.Zero();
  Rotation.first[0] = 1.;
  Rotation.second = 0.;

  // calculate center of triangle/line/point consisting of first points of matching
  Vector oldCenter;
  Vector newCenter;
  calculateOldAndNewCenters(
      oldCenter, newCenter,
      _referencepositions, _currentpositions, _bestmatching);

  if ((!oldCenter.IsZero()) && (!newCenter.IsZero())) {
    LOG(4, "DEBUG: oldCenter is " << oldCenter << ", newCenter is " << newCenter);
    oldCenter.Normalize();
    newCenter.Normalize();
    if (!oldCenter.IsEqualTo(newCenter)) {
      // calculate rotation axis and angle
      Rotation.first = oldCenter;
      Rotation.first.VectorProduct(newCenter);
      Rotation.second = oldCenter.Angle(newCenter); // /(M_PI/2.);
    } else {
      // no rotation required anymore
    }
  } else {
    LOG(4, "DEBUG: oldCenter is " << oldCenter << ", newCenter is " << newCenter);
    if ((oldCenter.IsZero()) && (newCenter.IsZero())) {
      // either oldCenter or newCenter (or both) is directly at origin
      if (_bestmatching.size() == 2) {
        // line case
        Vector oldPosition = _currentpositions[*_bestmatching.begin()];
        Vector newPosition = _referencepositions[0];
        // check whether we need to rotate at all
        if (!oldPosition.IsEqualTo(newPosition)) {
          Rotation.first = oldPosition;
          Rotation.first.VectorProduct(newPosition);
          // orientation will fix the sign here eventually
          Rotation.second = oldPosition.Angle(newPosition);
        } else {
          // no rotation required anymore
        }
      } else {
        // triangle case
        // both triangles/planes have same center, hence get axis by
        // VectorProduct of Normals
        Plane newplane(_referencepositions[0], _referencepositions[1], _referencepositions[2]);
        VectorArray_t vectors;
        for (IndexList_t::const_iterator iter = _bestmatching.begin();
            iter != _bestmatching.end(); ++iter)
          vectors.push_back(_currentpositions[*iter]);
        Plane oldplane(vectors[0], vectors[1], vectors[2]);
        Vector oldPosition = oldplane.getNormal();
        Vector newPosition = newplane.getNormal();
        // check whether we need to rotate at all
        if (!oldPosition.IsEqualTo(newPosition)) {
          Rotation.first = oldPosition;
          Rotation.first.VectorProduct(newPosition);
          Rotation.first.Normalize();

          // construct reference vector to determine direction of rotation
          const double sign = determineSignOfRotation(oldPosition, newPosition, Rotation.first);
          Rotation.second = sign * oldPosition.Angle(newPosition);
          LOG(5, "DEBUG: Rotating plane normals by " << Rotation.second
              << " around axis " << Rotation.first);
        } else {
          // else do nothing
        }
      }
    } else {
      // TODO: we can't do anything here, but this case needs to be dealt with when
      // we have no ideal geometries anymore
      if ((oldCenter-newCenter).Norm() > warn_amplitude)
        ELOG(2, "oldCenter is " << oldCenter << ", yet newCenter is " << newCenter);
#ifndef NDEBUG
      // else they are considered close enough
      dontcheck = true;
#endif
    }
  }

#ifndef NDEBUG
  // check: rotation brings newCenter onto oldCenter position
  if (!dontcheck) {
    Line Axis(zeroVec, Rotation.first);
    Vector test = Axis.rotateVector(newCenter, Rotation.second);
    LOG(4, "CHECK: rotated newCenter is " << test
        << ", oldCenter is " << oldCenter);
    ASSERT( (test - oldCenter).NormSquared() < std::numeric_limits<double>::epsilon()*1e4,
        "matchSphericalPointDistributions() - rotation does not work as expected by "
        +toString((test - oldCenter).NormSquared())+".");
  }
#endif

  return Rotation;
}

SphericalPointDistribution::Rotation_t SphericalPointDistribution::findPointAligningRotation(
    const VectorArray_t &remainingold,
    const VectorArray_t &remainingnew,
    const IndexList_t &_bestmatching)
{
  // initialize rotation to zero
  Rotation_t Rotation;
  Rotation.first.Zero();
  Rotation.first[0] = 1.;
  Rotation.second = 0.;

  // recalculate center
  Vector oldCenter;
  Vector newCenter;
  calculateOldAndNewCenters(
      oldCenter, newCenter,
      remainingold, remainingnew, _bestmatching);

  Vector oldPosition = remainingnew[*_bestmatching.begin()];
  Vector newPosition = remainingold[0];
  LOG(6, "DEBUG: oldPosition is " << oldPosition << " and newPosition is " << newPosition);
  if (!oldPosition.IsEqualTo(newPosition)) {
    if ((!oldCenter.IsZero()) && (!newCenter.IsZero())) {
      oldCenter.Normalize();  // note weighted sum of normalized weight is not normalized
      Rotation.first = oldCenter;
      LOG(6, "DEBUG: Picking normalized oldCenter as Rotation.first " << oldCenter);
      oldPosition.ProjectOntoPlane(Rotation.first);
      newPosition.ProjectOntoPlane(Rotation.first);
      LOG(6, "DEBUG: Positions after projection are " << oldPosition << " and " << newPosition);
    } else {
      if (_bestmatching.size() == 2) {
        // line situation
        try {
          Plane oldplane(oldPosition, oldCenter, newPosition);
          Rotation.first = oldplane.getNormal();
          LOG(6, "DEBUG: Plane is " << oldplane << " and normal is " << Rotation.first);
        } catch (LinearDependenceException &e) {
          LOG(6, "DEBUG: Vectors defining plane are linearly dependent.");
          // oldPosition and newPosition are on a line, just flip when not equal
          if (!oldPosition.IsEqualTo(newPosition)) {
            Rotation.first.Zero();
            Rotation.first.GetOneNormalVector(oldPosition);
            LOG(6, "DEBUG: For flipping we use Rotation.first " << Rotation.first);
            assert( Rotation.first.ScalarProduct(oldPosition) < std::numeric_limits<double>::epsilon()*1e4);
  //              Rotation.second = M_PI;
          } else {
            LOG(6, "DEBUG: oldPosition and newPosition are equivalent.");
          }
        }
      } else {
        // triangle situation
        Plane oldplane(remainingold[0], remainingold[1], remainingold[2]);
        Rotation.first = oldplane.getNormal();
        LOG(6, "DEBUG: oldPlane is " << oldplane << " and normal is " << Rotation.first);
        oldPosition.ProjectOntoPlane(Rotation.first);
        LOG(6, "DEBUG: Positions after projection are " << oldPosition << " and " << newPosition);
      }
    }
    // construct reference vector to determine direction of rotation
    const double sign = determineSignOfRotation(oldPosition, newPosition, Rotation.first);
    Rotation.second = sign * oldPosition.Angle(newPosition);
  } else {
    LOG(6, "DEBUG: oldPosition and newPosition are equivalent, hence no orientating rotation.");
  }

  return Rotation;
}


SphericalPointDistribution::Polygon_t
SphericalPointDistribution::matchSphericalPointDistributions(
    const SphericalPointDistribution::Polygon_t &_polygon,
    const SphericalPointDistribution::Polygon_t &_newpolygon
    )
{
  SphericalPointDistribution::Polygon_t remainingreturnpolygon;
  VectorArray_t remainingold(_polygon.begin(), _polygon.end());
  VectorArray_t remainingnew(_newpolygon.begin(), _newpolygon.end());
  LOG(2, "INFO: Matching old polygon " << _polygon
      << " with new polygon " << _newpolygon);

  if (_polygon.size() == _newpolygon.size()) {
    // same number of points desired as are present? Do nothing
    LOG(2, "INFO: There are no vacant points to return.");
    return remainingreturnpolygon;
  }

  if (_polygon.size() > 0) {
    IndexList_t bestmatching = findBestMatching(_polygon, _newpolygon);
    LOG(2, "INFO: Best matching is " << bestmatching);

    // determine rotation angles to align the two point distributions with
    // respect to bestmatching:
    // we use the center between the three first matching points
    /// the first rotation brings these two centers to coincide
    VectorArray_t rotated_newpolygon = remainingnew;
    {
      Rotation_t Rotation = findPlaneAligningRotation(
          remainingold,
          remainingnew,
          bestmatching);
      LOG(5, "DEBUG: Rotating coordinate system by " << Rotation.second
          << " around axis " << Rotation.first);
      Line Axis(zeroVec, Rotation.first);

      // apply rotation angle to bring newCenter to oldCenter
      for (VectorArray_t::iterator iter = rotated_newpolygon.begin();
          iter != rotated_newpolygon.end(); ++iter) {
        Vector &current = *iter;
        LOG(6, "DEBUG: Original point is " << current);
        current =  Axis.rotateVector(current, Rotation.second);
        LOG(6, "DEBUG: Rotated point is " << current);
      }

#ifndef NDEBUG
      // check: rotated "newCenter" should now equal oldCenter
      {
        Vector oldCenter;
        Vector rotatednewCenter;
        calculateOldAndNewCenters(
            oldCenter, rotatednewCenter,
            remainingold, rotated_newpolygon, bestmatching);
        // NOTE: Center must not necessarily lie on the sphere with norm 1, hence, we
        // have to normalize it just as before, as oldCenter and newCenter lengths may differ.
        if ((!oldCenter.IsZero()) && (!rotatednewCenter.IsZero())) {
          oldCenter.Normalize();
          rotatednewCenter.Normalize();
          LOG(4, "CHECK: rotatednewCenter is " << rotatednewCenter
              << ", oldCenter is " << oldCenter);
          ASSERT( (rotatednewCenter - oldCenter).NormSquared() < std::numeric_limits<double>::epsilon()*1e4,
              "matchSphericalPointDistributions() - rotation does not work as expected by "
              +toString((rotatednewCenter - oldCenter).NormSquared())+".");
        }
      }
#endif
    }
    /// the second (orientation) rotation aligns the planes such that the
    /// points themselves coincide
    if (bestmatching.size() > 1) {
      Rotation_t Rotation = findPointAligningRotation(
          remainingold,
          rotated_newpolygon,
          bestmatching);

      // construct RotationAxis and two points on its plane, defining the angle
      Rotation.first.Normalize();
      const Line RotationAxis(zeroVec, Rotation.first);

      LOG(5, "DEBUG: Rotating around self is " << Rotation.second
          << " around axis " << RotationAxis);

#ifndef NDEBUG
      // check: first bestmatching in rotated_newpolygon and remainingnew
      // should now equal
      {
        const IndexList_t::const_iterator iter = bestmatching.begin();
        Vector rotatednew = RotationAxis.rotateVector(
            rotated_newpolygon[*iter],
            Rotation.second);
        LOG(4, "CHECK: rotated first new bestmatching is " << rotatednew
            << " while old was " << remainingold[0]);
        ASSERT( (rotatednew - remainingold[0]).Norm() < warn_amplitude,
            "matchSphericalPointDistributions() - orientation rotation ends up off by more than "
            +toString(warn_amplitude)+".");
      }
#endif

      for (VectorArray_t::iterator iter = rotated_newpolygon.begin();
          iter != rotated_newpolygon.end(); ++iter) {
        Vector &current = *iter;
        LOG(6, "DEBUG: Original point is " << current);
        current = RotationAxis.rotateVector(current, Rotation.second);
        LOG(6, "DEBUG: Rotated point is " << current);
      }
    }

    // remove all points in matching and return remaining ones
    SphericalPointDistribution::Polygon_t remainingpoints =
        removeMatchingPoints(rotated_newpolygon, bestmatching);
    LOG(2, "INFO: Remaining points are " << remainingpoints);
    return remainingpoints;
  } else
    return _newpolygon;
}
