/*
 * Project: MoleCuilder
 * Description: creates and alters molecular systems
 * Copyright (C)  2017 Frederik Heber. All rights reserved.
 * 
 *
 *   This file is part of MoleCuilder.
 *
 *    MoleCuilder is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    MoleCuilder is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoleCuilder.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * BondVectorsUnitTest.cpp
 *
 *  Created on: Jun 29, 2017
 *      Author: heber
 */

// include config.h
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <cppunit/CompilerOutputter.h>
#include <cppunit/extensions/TestFactoryRegistry.h>
#include <cppunit/ui/text/TestRunner.h>

#include "CodePatterns/Assert.hpp"
#include "CodePatterns/Log.hpp"

#include <boost/assign.hpp>

#include "BondVectorsUnitTest.hpp"

#include "Atom/atom.hpp"
#include "Bond/bond.hpp"
#include "Dynamics/BondVectors.hpp"
#include "Element/periodentafel.hpp"
#include "World.hpp"
#include "WorldTime.hpp"

#ifdef HAVE_TESTRUNNER
#include "UnitTestMain.hpp"
#endif /*HAVE_TESTRUNNER*/

using namespace boost::assign;

/********************************************** Test classes **************************************/

// Registers the fixture into the 'registry'
CPPUNIT_TEST_SUITE_REGISTRATION( BondVectorsTest );


void BondVectorsTest::setUp()
{
  // failing asserts should be thrown
  ASSERT_DO(Assert::Throw);

  setVerbosity(4);

  // create an atom
  carbon = World::getInstance().getPeriode()->FindElement(6);
  CPPUNIT_ASSERT(carbon != NULL);

  _atom = World::getInstance().createAtom();
  _atom->setType(carbon);
  _atom->setPosition( zeroVec );
  atoms.push_back(_atom);
  _atom = World::getInstance().createAtom();
  _atom->setType(carbon);
  _atom->setPosition( Vector(1.6,0.,0.) );
  atoms.push_back(_atom);
  _atom = World::getInstance().createAtom();
  _atom->setType(carbon);
  _atom->setPosition( Vector(3.2,0.,0.) );
  atoms.push_back(_atom);
  _atom = World::getInstance().createAtom();
  _atom->setType(carbon);
  _atom->setPosition( Vector(1.6,1.6,0.) );
  atoms.push_back(_atom);
  _atom = World::getInstance().createAtom();
  _atom->setType(carbon);
  _atom->setPosition( Vector(2.8,2.8,0.) );
  atoms.push_back(_atom);
  _atom = World::getInstance().createAtom();
  _atom->setType(carbon);
  _atom->setPosition( Vector(1.6,-1.6,0.) );
  atoms.push_back(_atom);
  _atom = World::getInstance().createAtom();
  _atom->setType(carbon);
  _atom->setPosition( Vector(2.8,-1.2,0.) );
  atoms.push_back(_atom);

  bv = new BondVectors;
}

static void clearbondvector(
    std::vector<bond::ptr> &_bondvector)
{
  // remove bonds
  for (std::vector<bond::ptr>::iterator iter = _bondvector.begin();
      !_bondvector.empty(); iter = _bondvector.begin()) {
    (*iter)->leftatom->removeBond((*iter)->rightatom);
    _bondvector.erase(iter);
  }
}


void BondVectorsTest::tearDown()
{
  delete bv;

  atoms.clear();
  atomvector.clear();
  clearbondvector(bondvector);
  carbon = NULL;

  World::purgeInstance();
  WorldTime::purgeInstance();
}

/** Test whether current_mapped is kept up-to-date
 *
 */
void BondVectorsTest::current_mappedTest()
{
  {
    // gather atoms
    atomvector += atoms[center], atoms[left];
    // create bonds
    bondvector += atoms[center]->addBond(atoms[left]);
    // prepare bondvectors
    bv->setFromAtomRange< std::vector<atom *> >(atomvector.begin(), atomvector.end(), WorldTime::getTime());
    // get bond vectors
    const std::vector<Vector> Bondvectors =
        bv->getAtomsBondVectorsAtStep(*atoms[center], WorldTime::getTime());
    // check number of bond vectors
    CPPUNIT_ASSERT_EQUAL( Bondvectors.size(), (size_t)1 );
    // check norm of bond vector
    CPPUNIT_ASSERT( fabs(Bondvectors[0].Norm() - 1.) < MYEPSILON );

    // clear set of atoms and use a different one
    clearbondvector(bondvector);
    atomvector.clear();
  }

  {
    // gather atoms
    atomvector += atoms[center], atoms[left], atoms[right], atoms[top];
    // create bonds
    bondvector +=
        atoms[center]->addBond(atoms[left]),
        atoms[center]->addBond(atoms[right]),
        atoms[center]->addBond(atoms[top]);
    // prepare bondvectors
    bv->setFromAtomRange< std::vector<atom *> >(atomvector.begin(), atomvector.end(), WorldTime::getTime());
    // get bond vectors
    const std::vector<Vector> Bondvectors =
        bv->getAtomsBondVectorsAtStep(*atoms[center], WorldTime::getTime());
    // check number of bond vectors
    CPPUNIT_ASSERT_EQUAL( Bondvectors.size(), (size_t)3 );
    // check norm of bond vector
    for (size_t i=0;i<3;++i)
      CPPUNIT_ASSERT( fabs(Bondvectors[i].Norm() - 1.) < MYEPSILON );
  }
}

/** Test whether calculating weights works on single bond
 *
 */
void BondVectorsTest::weights_singlebondTest()
{
  // gather atoms
  atomvector += atoms[center], atoms[left];
  // create bonds
  bondvector += atoms[center]->addBond(atoms[left]);
  // prepare bondvectors
  bv->setFromAtomRange< std::vector<atom *> >(atomvector.begin(), atomvector.end(), WorldTime::getTime());
  // calculate weights
  BondVectors::weights_t weights = bv->getWeightsForAtomAtStep(*atoms[center], WorldTime::getTime());
  LOG(2, "DEBUG: Single bond weights are " << weights);
  // check number of weights
  CPPUNIT_ASSERT_EQUAL( weights.size(), (size_t)1 );
  // check sum of weights
  const double weight_sum = std::accumulate(weights.begin(), weights.end(), 0.);
  CPPUNIT_ASSERT( fabs(weight_sum - 1.) < MYEPSILON );
  // check weight
  CPPUNIT_ASSERT( fabs(weight_sum - 1.) < 1e-10 );
}

/** Test whether calculating weights works on linear chain config
 *
 */
void BondVectorsTest::weights_linearchainTest()
{
  // gather atoms
  atomvector += atoms[center], atoms[left], atoms[right];
  // create bonds
  bondvector += atoms[center]->addBond(atoms[left]), atoms[center]->addBond(atoms[right]);
  // prepare bondvectors
  bv->setFromAtomRange< std::vector<atom *> >(atomvector.begin(), atomvector.end(), WorldTime::getTime());
  // calculate weights
  BondVectors::weights_t weights = bv->getWeightsForAtomAtStep(*atoms[center], WorldTime::getTime());
  LOG(2, "DEBUG: Linear chain weights are " << weights);
  // check number of weights
  CPPUNIT_ASSERT_EQUAL( weights.size(), (size_t)2 );
  // check sum of weights
  const double weight_sum = std::accumulate(weights.begin(), weights.end(), 0.);
  CPPUNIT_ASSERT( fabs(weight_sum - 1.) < 1e-10 );
}

/** Test whether calculating weights works on right angle config
 *
 */
void BondVectorsTest::weights_rightangleTest()
{
  // gather atoms
  atomvector += atoms[center], atoms[left], atoms[top];
  // create bonds
  bondvector +=
      atoms[center]->addBond(atoms[left]),
      atoms[center]->addBond(atoms[top]);
  // prepare bondvectors
  bv->setFromAtomRange< std::vector<atom *> >(atomvector.begin(), atomvector.end(), WorldTime::getTime());
  // calculate weights
  BondVectors::weights_t weights = bv->getWeightsForAtomAtStep(*atoms[center], WorldTime::getTime());
  LOG(2, "DEBUG: Right angle weights are " << weights);
  // check number of weights
  CPPUNIT_ASSERT_EQUAL( weights.size(), (size_t)2 );
  // check sum of weights: two independent vectors == 1+1
  const double weight_sum = std::accumulate(weights.begin(), weights.end(), 0.);
  CPPUNIT_ASSERT( fabs(weight_sum - 2.) < 1e-10 );
}

/** Test whether calculating weights works on triangle config
 *
 */
void BondVectorsTest::weights_triangleTest()
{
  // gather atoms
  atomvector += atoms[center], atoms[left], atoms[right], atoms[top];
  // create bonds
  bondvector +=
      atoms[center]->addBond(atoms[left]),
      atoms[center]->addBond(atoms[right]),
      atoms[center]->addBond(atoms[top]);
  // prepare bondvectors
  bv->setFromAtomRange< std::vector<atom *> >(atomvector.begin(), atomvector.end(), WorldTime::getTime());
  // calculate weights
  BondVectors::weights_t weights = bv->getWeightsForAtomAtStep(*atoms[center], WorldTime::getTime());
  LOG(2, "DEBUG: Triangle weights are " << weights);
  // check number of weights
  CPPUNIT_ASSERT_EQUAL( weights.size(), (size_t)3 );
  // check sum of weights: one linear independent, two dependent vectors = 1 + 2*0.5
  const double weight_sum = std::accumulate(weights.begin(), weights.end(), 0.);
  CPPUNIT_ASSERT( fabs(weight_sum - 2.) < 1e-10 );
}

/** Test whether calculating weights works on complex config
 *
 */
void BondVectorsTest::weights_complexTest()
{
  // gather atoms
  atomvector += atoms[center], atoms[left], atoms[right], atoms[top], atoms[topright], atoms [bottomright];
  // create bonds
  bondvector +=
      atoms[center]->addBond(atoms[left]),
      atoms[center]->addBond(atoms[right]),
      atoms[center]->addBond(atoms[top]),
      atoms[center]->addBond(atoms[topright]),
      atoms[center]->addBond(atoms[bottomright]);
  // prepare bondvectors
  bv->setFromAtomRange< std::vector<atom *> >(atomvector.begin(), atomvector.end(), WorldTime::getTime());
  // calculate weights
  BondVectors::weights_t weights = bv->getWeightsForAtomAtStep(*atoms[center], WorldTime::getTime());
  LOG(2, "DEBUG: Complex weights are " << weights);
  // check number of weights
  CPPUNIT_ASSERT_EQUAL( weights.size(), (size_t)5 );
  // check sum of weights
  const double weight_sum = std::accumulate(weights.begin(), weights.end(), 0.);
  CPPUNIT_ASSERT( fabs(weights[0] - .372244) < 1e-6 );
  CPPUNIT_ASSERT( fabs(weights[1] - .529694) < 1e-6 );
  CPPUNIT_ASSERT( fabs(weights[2] - .2) < 1e-6 );
  CPPUNIT_ASSERT( fabs(weights[3] - .248464) < 1e-6 );
  CPPUNIT_ASSERT( fabs(weights[4] - .248464) < 1e-6 );
}
