# (C) 2020 Frederik Heber
#
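# This script "grows" molecules atom by atom in order to explore the space of
# chemical graphs. It uses the following pyMoleCuilder actions:
# - AtomSaturate (saturate all atoms with hydrogens)
# - AtomRemove/AtomAdd/BondAdd (replace bonded hydrogens by a new
#   non-hydrogen atom)
# - AtomChangeElement (available through change_element())
#
# Starting from each non-hydrogen seed element, all such extensions are
# enumerated recursively up to the requested number of non-hydrogen atoms.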
#
# NOTE: For this to work, we need to point PYTHONPATH to a molecuilder installation
#       directory.
import os, sys, random
from itertools import chain, combinations

import pyMoleCuilder as mol


# Action counters maintained by callCounter(): `counter` is incremented for
# upper-case (side-effecting) actions and decremented for "Undo";
# `counter_all` counts every call.
counter = 0
counter_all = 0


# typical distance per element (units presumably Angstrom -- TODO confirm)
typical_distance = {
    "H": 0.6,
    "C": 1.,
    "N": 1.,
    "O": 1.,
    "P": 1.,
    "S": 1.
}
# valencies of the elements molecules are grown from; P and S are currently
# disabled
valencies = {
    "H": 1,
    "C": 4,
    "N": 3,
    "O": 2,
#    "P": 3,
#    "S": 2
}
# atomic number -> element symbol, used by get_element_of_selected_atom()
atomicNumbers = {
    1: "H",
    6: "C",
    7: "N",
    8: "O",
    15: "P",
    16: "S",
}
elements = valencies.keys()
# Heavy elements used for seeding and growing. Defined at module level so
# that functions reading it (recurse(), get_all_hydrogen_bond_neighbors())
# also work when this file is imported instead of run as a script.
non_hydrogen_elements = [element for element in valencies if element != "H"]

# maximum number of hydrogens replaced at once, i.e. the maximum bond degree
# of a newly added atom (capped further by the element's valency)
MAX_DEGREE = 3

# quantum-chemistry settings for the mpqc parser; also part of the container
# file names below
THEORY = "MBPT2"
BASISNAME = "6-311G"

homology_container_filename = "homologies_"+THEORY+"_"+BASISNAME+".dat"
atomfragments_filename = "atomfragments_"+THEORY+"_"+BASISNAME+".dat"

#random.seed(426)
# Seed for the molecuilder random number engine (see prepare_run()).
# The bounds must be ints: random.randint(0, 1e8) raises TypeError on
# Python >= 3.12, which no longer accepts float arguments.
SEED = random.randint(0, 10**8)


def callCounter(function_name, *args, **kwargs):
    """Call the pyMoleCuilder function named `function_name` and count the call.

    ``counter_all`` is incremented for every call. ``counter`` is decremented
    for "Undo" and incremented for names whose first character is an ASCII
    upper-case letter; lower-case names are pure python queries without side
    effects and do not change it.

    :return: whatever the pyMoleCuilder function returns
    """
    global counter, counter_all
    counter_all += 1
    is_undo = function_name == "Undo"
    if is_undo:
        counter -= 1
    elif "A" <= function_name[0] <= "Z":
        counter += 1
    target = getattr(mol, function_name)
    return target(*args, **kwargs)


def add_seed_atom(element):
    """Add one atom of `element` at domain position "10,10,10".

    Has side effects: all molecules are selected before AtomAdd is run.
    """
    callCounter("SelectionAllMolecules")
    atom_settings = dict(add_atom=element, domain_position="10,10,10")
    callCounter("AtomAdd", **atom_settings)


def change_element(element):
    """Run AtomChangeElement with `element` (presumably on the current selection).

    Has side effects: the fragmentation state is cleared afterwards and the
    action queue is waited on.
    """
    callCounter("AtomChangeElement", change_element=element)
    for follow_up in ("FragmentationClearFragmentationState", "wait"):
        callCounter(follow_up)


def saturate_all_atoms():
    """Run AtomSaturate (use_outer_shell=1) on all atoms, then clear the selection.

    Has side effects.
    """
    callCounter("SelectionAllAtoms")
    callCounter("AtomSaturate", use_outer_shell="1")
    for cleanup in ("SelectionClearAllAtoms", "wait"):
        callCounter(cleanup)


def select_atoms_randomly(number):
    """Replace the atom selection by `number` atoms picked by SelectionAtomByRandom.

    Has side effects (changes the selection).
    """
    count_as_text = str(number)
    callCounter("SelectionClearAllAtoms")
    callCounter("SelectionAtomByRandom", select_atom_by_random=count_as_text)
    callCounter("wait")

def check_for_present_hydrogens():
    """Return True if at least one hydrogen atom is present.

    Delegates to get_all_hydrogen_atoms(), which issues the same action
    sequence (push selection, select hydrogens, query ids, pop selection),
    so the selection is left unchanged.
    """
    return len(get_all_hydrogen_atoms()) != 0


def read_lastline(f):
    """Return the last line of the seekable binary file object `f` as bytes.

    The returned line keeps its trailing newline, if any. A file with only a
    single line is returned whole; an empty file yields b''.
    """
    f.seek(0, 2)
    if f.tell() < 2:
        # zero or one byte: the whole content is the last line
        f.seek(0)
        return f.read()
    # start at the second-to-last byte so that a trailing newline is skipped
    f.seek(-2, 2)
    while f.read(1) != b'\n':
        if f.tell() < 2:
            # just read the very first byte without finding a newline, i.e.
            # the file consists of one line; seeking further back would fail
            f.seek(0)
            break
        f.seek(-2, 1)
    return f.read()


def get_element_of_selected_atom():
    """Return the element symbol of the first selected atom (no side effects).

    Raises KeyError for atomic numbers missing from `atomicNumbers`.
    """
    atomic_number = int(callCounter("getSelectedAtomElements")[0])
    return atomicNumbers[atomic_number]


# Output files of AnalysisPairCorrelation in check_pair_correlation_distance();
# the first is read back by parse_pair_correlation_distance().
pair_correlation_filename, pair_correlation_bin_filename = (
    "paircorrelation.dat",
    "paircorrelation_bins.dat",
)


def check_pair_correlation_distance(typical_distance, element):
    """Return whether the closest distance to `element` exceeds `typical_distance`.

    Runs AnalysisPairCorrelation for `element` with bins of width 0.1 from 0
    up to ``typical_distance + 0.1`` and reads the bin file back. The first
    column of the first bin line with a non-zero count (column 4) is taken as
    the closest distance (presumably the bin start); if all bins are empty,
    ``typical_distance + 0.1`` is used.

    The analysis writes pair_correlation_filename and
    pair_correlation_bin_filename into the current directory. Both are left
    on disk because parse_pair_correlation_distance() reads the former.

    :param typical_distance: distance threshold (float)
    :param element: element symbol passed to the analysis
    :return: True if the first populated bin lies beyond `typical_distance`
    """
    bin_end = typical_distance + 0.1
    # go through callCounter like every other action so that it is counted
    callCounter(
        "AnalysisPairCorrelation",
        elements=element,
        bin_start="0.", bin_width="0.1", bin_end=str(bin_end),
        output_file=pair_correlation_filename,
        bin_output_file=pair_correlation_bin_filename,
        periodic="0"
    )
    callCounter("wait")
    # default when every bin is empty; avoids shadowing the builtin `bin`
    first_bin = bin_end
    with open(pair_correlation_bin_filename, "r") as f:
        header = f.readline()
        assert( header.split()[3] == "Count")
        for line in f:
            fields = line.split()
            # skip blank or truncated lines instead of failing on them
            if len(fields) < 4:
                continue
            if int(fields[3]) != 0:
                first_bin = float(fields[0])
                print("First non-zero bin to "+element+" is at "+str(first_bin))
                break
    callCounter("wait")
    return float(first_bin) > typical_distance


def extract_atom_name(text):
    """Return the part of `text` before the first comma (all of it if there is none)."""
    head, _separator, _rest = text.partition(",")
    return head


def parse_pair_correlation_distance(min_distance):
    """Return atom names of pair-correlation entries closer than `min_distance`.

    Reads pair_correlation_filename, whose header's first column must be
    "BinStart". For every following line whose first column is below
    `min_distance`, the atom name is taken from the fourth column (the text
    before its first comma, see extract_atom_name()).
    """
    with open(pair_correlation_filename, "r") as handle:
        first_line = handle.readline()
        assert first_line.split()[0] == "BinStart"
        rows = [line.split() for line in handle.readlines()]
    return [extract_atom_name(row[3]) for row in rows if float(row[0]) < min_distance]


def get_mean_position_of_selected_atoms():
    """Return the mean [x, y, z] of the selected atoms' positions.

    Returns [0., 0., 0.] when nothing is selected. No side effects.
    """
    callCounter("wait")
    positions = callCounter("getSelectedAtomPositions")
    total = len(positions)
    if total == 0:
        return [0., 0., 0.]
    sums = [0., 0., 0.]
    for pos in positions:
        for axis in range(3):
            sums[axis] += pos[axis]
    return [component / total for component in sums]


def get_all_hydrogen_bond_neighbors(current_id):
    """Return the ids of the hydrogen atoms bonded to atom `current_id`.

    Selects `current_id` and its bond neighbours, then deselects every
    element in the module-level `non_hydrogen_elements`. That also drops
    `current_id` itself and leaves the hydrogens (assuming no other elements
    are present). The previous selection is restored. No lasting side effects.
    """
    id_text = str(current_id)
    callCounter("SelectionPushAtoms")
    callCounter("SelectionClearAllAtoms")
    callCounter("SelectionAtomById", id_text)
    callCounter("SelectionAtomBondNeighbors")
    for heavy_element in non_hydrogen_elements:
        callCounter("SelectionNotAtomByElement", heavy_element)
    callCounter("wait")
    hydrogen_ids = callCounter("getSelectedAtomIds")
    listing = [str(h) for h in hydrogen_ids]
    print("Hydrogens are " + str(listing))
    callCounter("SelectionPopAtoms")
    return hydrogen_ids


def prepare_run(bond_table, filename = ""):
    """Configure pyMoleCuilder for a run and return the output prefix.

    Has side effects:
    - If `filename` is given, it is loaded via WorldInput.
    - The tremolo output format and mpqc parser parameters (THEORY/BASISNAME)
      are set.
    - The mt19937 random number engine is seeded with SEED.
    - `bond_table` is loaded and the box is set to 20x20x20.

    :param bond_table: bond length table passed to CommandBondLengthTable
    :param filename: optional input file. Its name without extension becomes
        the prefix; without a file the prefix is "ChemicalGraph".
    :return: the prefix string
    """
    # default prefix, also used when no input file is given (previously this
    # case raised UnboundLocalError)
    prefix = "ChemicalGraph"
    if filename:
        print("Using file "+filename)
        # splitext only strips an extension of the final path component,
        # whereas rfind('.') also matched dots in directories such as "./mol"
        prefix = os.path.splitext(filename)[0]
        callCounter("WorldInput", filename)
    print("Using prefix " + prefix)
    callCounter("ParserSetOutputFormats", "tremolo")
    callCounter("ParserSetTremoloAtomdata", "Id x=3 u=3 type neighbors=8")
    callCounter("ParserSetParserParameters", "mpqc", "theory="+THEORY+";basis="+BASISNAME+";")
    callCounter("CommandSetRandomNumbersEngine", set_random_number_engine="mt19937", random_number_engine_parameters='seed='+str(SEED))
    callCounter("CommandBondLengthTable", bond_table)
    callCounter("WorldChangeBox", "20,0,0,20,0,20")
    callCounter("wait")
    return prefix


def parse_containers():
    """Load the homology and atom-fragment containers from their files, if present.

    Has side effects. A file that does not exist is silently skipped.
    """
    loaders = (
        ("PotentialParseHomologies", homology_container_filename),
        ("PotentialParseAtomFragments", atomfragments_filename),
    )
    for action, path in loaders:
        if os.path.exists(path):
            callCounter(action, path)
    callCounter("wait")


def store_containers():
    """Save the homology and atom-fragment containers to their files (side effects)."""
    savers = (
        ("PotentialSaveHomologies", homology_container_filename),
        ("PotentialSaveAtomFragments", atomfragments_filename),
    )
    for action, path in savers:
        callCounter(action, path)
    callCounter("wait")


class FolderScope(object):
    '''
    Scope to step into a folder and automatically step back on exit.

    The folder (including missing parents) is created if it does not exist.
    On exit, the scope returns to the working directory that was current
    when it was entered, even if the body raised; exceptions are not
    suppressed.
    '''
    def __init__(self, folder_name):
        self.folder_name = folder_name
        # refreshed in __enter__; initialised here so the attribute always exists
        self.pwd = os.getcwd()

    def __enter__(self):
        # remember where we are *now*, not where the scope was constructed
        self.pwd = os.getcwd()
        if not os.path.exists(self.folder_name):
            os.makedirs(self.folder_name)
        os.chdir(self.folder_name)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        os.chdir(self.pwd)
        return False


class UndoScope(object):
    '''
    Scope to automatically undo all molecuilder actions executed within
    its scope.

    On entry an undo mark "1" is set. On exit, also when the body raised,
    all actions are undone back to that mark; exceptions are not suppressed.
    '''
    def __enter__(self):
        # set the mark on entry (not at construction) so that an instance can
        # be entered more than once
        callCounter("CommandUndoMark", "1")
        # wait through callCounter like the rest of the script, so it is counted
        callCounter("wait")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        callCounter("Undo", till_mark="1")
        callCounter("wait")
        return False


def get_mean_position_of_hydrogens(current_position):
    """Return where a replacement atom for the selected hydrogens should go.

    This is the mean of the selected atoms' positions. If that mean is
    closer than squared distance 0.1 to `current_position` (the non-hydrogen
    atom), the position of the first selected atom is used instead.
    No side effects.
    """
    mean_pos = get_mean_position_of_selected_atoms()
    squared_distance = sum((current_position[axis] - mean_pos[axis]) ** 2 for axis in range(3))
    if squared_distance < 0.1:
        # the mean collapses onto the non-hydrogen, e.g. for symmetric hydrogens
        print("Mean position " + str(mean_pos) + " is too close to non-hydrogen " + str(
           current_position) + ", using one of the hydrogens.")
        callCounter("wait")
        return callCounter("getSelectedAtomPositions")[0]
    return mean_pos


def replace_hydrogens_by_atom_rescale_and_saturate(pick_element, mean_pos, bond_degree, current_id):
    """Replace the currently selected hydrogens by one atom of `pick_element`.

    Has side effects. Expects the hydrogens to be replaced to be the current
    atom selection (extend_graph_by_element sets it up). Steps:
    - Remove them and add a `pick_element` atom at `mean_pos`.
    - Bond it to `current_id` with degree `bond_degree`.
    - Apply MoleculeStretchBond (stretch_bond="0."), which per the original
      comment scales the bond to its typical distance.
    - Bondify and saturate the new atom, then clear the selection.
    """
    # remove hydrogens and add new atom; AtomAdd goes through callCounter
    # (like add_seed_atom) so that it is counted as an action
    callCounter("SelectionAllMolecules")
    callCounter("AtomRemove")
    callCounter("AtomAdd", add_atom=pick_element,
                domain_position='{},{},{}'.format(mean_pos[0], mean_pos[1], mean_pos[2]))

    # select the new atom (order -1, presumably the last added one) together
    # with current_id, bond them, set bond degree and rescale the bond
    callCounter("SelectionAtomByOrder", '-1')
    callCounter("SelectionAtomById", str(current_id))
    callCounter("BondAdd")
    callCounter("BondSetDegree", str(bond_degree))
    callCounter("MoleculeStretchBond", stretch_bond="0.")

    # deselect current_id (presumably leaving only the new atom), then
    # bondify and saturate
    callCounter("SelectionNotAtomById", str(current_id))
    callCounter("AtomBondify", max_hydrogens="1")
    callCounter("AtomSaturate", use_outer_shell = "1")
    callCounter("SelectionClearAllAtoms")
    callCounter("wait")


def calculate_energy(loop_prefix, server_address="eumaios", server_port="20024"):
    """Fragment the current system and run its fragment jobs on a job server.

    Has side effects:
    - The atom selection is pushed, and popped at the end.
    - All atoms are randomly perturbed by 0.01.
    - The molecule is fragmented as `loop_prefix` (order 3, saturated,
      hydrogens excluded).
    - FragmentationFragmentationAutomation is run against
      `server_address`:`server_port`; presumably this performs the energy
      calculations.
    - The analysed results are saved to a not yet existing file derived
      from "<loop_prefix>FragmentResults.dat".

    :param loop_prefix: name of the fragmented molecule and result file prefix
    :param server_address: host of the job server (default "eumaios")
    :param server_port: port of the job server, as string or int
        (default "20024")
    """
    callCounter("SelectionPushAtoms")
    callCounter("SelectionAllAtoms")
    callCounter("AtomRandomPerturbation", random_perturbation="0.01")
    callCounter("FragmentationClearFragmentationState")
    callCounter("FragmentationFragmentation",
        fragment_molecule=loop_prefix,
        DoSaturate="1",
        ExcludeHydrogen="1",
        order="3",
        output_types="",
        parse_state_files="0"
    )
    callCounter("FragmentationFragmentationAutomation",
        server_address=server_address,
        server_port=str(server_port),
        DoLongrange="0"
    )
    callCounter("FragmentationAnalyseFragmentationResults")
    unique_filename = ensure_unique_filename(loop_prefix, 'FragmentResults.dat')
    callCounter("FragmentationSaveFragmentResults", save_fragment_results=unique_filename)
    callCounter("SelectionPopAtoms")
    callCounter("wait")


def get_unique_filename(num_atoms):
    """Build a file name identifying the current heavy-atom graph.

    The name is "<graph6>_-_<element list>", with spaces in the element list
    replaced by underscores. The queries run while only the non-hydrogen
    atoms are selected (presumably graph6 string and element list refer to
    that selection -- confirm). Asserts that exactly `num_atoms` non-hydrogen
    atoms are present. The previous selection is restored; no lasting side
    effects.
    """
    callCounter("SelectionPushAtoms")
    for action in ("SelectionClearAllAtoms", "SelectionAllAtoms"):
        callCounter(action)
    callCounter("SelectionNotAtomByElement", "H")
    callCounter("wait")
    heavy_ids = callCounter("getSelectedAtomIds")
    assert len(heavy_ids) == num_atoms
    graph6_string = callCounter("getGraph6String")
    # hard-coded canonical form for N=3: "Bo" and "Bg" are the other two
    # labellings of the 3-vertex path, mapped onto "BW"
    if graph6_string in ("Bo", "Bg"):
        graph6_string = "BW"
    element_part = callCounter("getElementListAsString").replace(" ", "_")
    unique_filename = graph6_string + "_-_" + element_part
    print("Picking "+unique_filename+" as filename.")
    callCounter("SelectionPopAtoms")
    return unique_filename


def ensure_unique_filename(prefix, suffix):
    """Return a file name built from `prefix` and `suffix` that does not exist yet.

    Candidates are tried in this order and the first non-existing one is
    returned:
    - ``prefix + suffix``
    - ``prefix + "-0" + suffix``
    - ``prefix + "-1" + suffix``, and so on.

    There is no upper bound on the index, so an existing file is never
    returned and therefore never overwritten.
    """
    unique_filename = prefix + suffix
    index = 0
    while os.path.exists(unique_filename):
        unique_filename = prefix + "-" + str(index) + suffix
        index += 1
    return unique_filename


def extend_graph_by_element(loop_prefix, pick_element, current_id, current_position, hydrogen_ids):
    """Replace the hydrogens `hydrogen_ids` of atom `current_id` by one `pick_element` atom.

    Has side effects. The new atom is placed by get_mean_position_of_hydrogens()
    and bonded to `current_id` with a bond degree equal to the number of
    replaced hydrogens, which must be 1..3. `loop_prefix` is not used.
    """
    id_list = " ".join(str(item) for item in hydrogen_ids)
    callCounter("SelectionClearAllAtoms")
    callCounter("SelectionAtomById", id_list)
    callCounter("wait")

    bond_degree = len(hydrogen_ids)
    assert 1 <= bond_degree <= 3
    replace_hydrogens_by_atom_rescale_and_saturate(
        pick_element,
        get_mean_position_of_hydrogens(current_position),
        bond_degree,
        current_id,
    )


def powerset(iterable):
    """Return an iterator over all subsets of `iterable` as tuples, by increasing size.

    `iterable` is consumed immediately. Example:
    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    pool = tuple(iterable)

    def _subsets():
        for size in range(len(pool) + 1):
            for subset in combinations(pool, size):
                yield subset

    return _subsets()


def get_all_non_hydrogen_atoms():
    """Return the ids of all non-hydrogen atoms; the previous selection is restored."""
    for action, args in (("SelectionPushAtoms", ()),
                         ("SelectionClearAllAtoms", ()),
                         ("SelectionAllAtoms", ()),
                         ("SelectionNotAtomByElement", ("H",)),
                         ("wait", ())):
        callCounter(action, *args)
    heavy_atom_ids = callCounter("getSelectedAtomIds")
    callCounter("SelectionPopAtoms")
    return heavy_atom_ids


def get_all_hydrogen_atoms():
    """Return the ids of all hydrogen atoms; the previous selection is restored."""
    for action, args in (("SelectionPushAtoms", ()),
                         ("SelectionClearAllAtoms", ()),
                         ("SelectionAtomByElement", ("H",)),
                         ("wait", ())):
        callCounter(action, *args)
    hydrogen_atom_ids = callCounter("getSelectedAtomIds")
    callCounter("SelectionPopAtoms")
    return hydrogen_atom_ids


def get_position_of_id(current_id):
    """Return the position of the atom with id `current_id`; the selection is restored."""
    id_text = str(current_id)
    callCounter("SelectionPushAtoms")
    callCounter("SelectionClearAllAtoms")
    callCounter("SelectionAtomById", id_text)
    callCounter("wait")
    positions = callCounter("getSelectedAtomPositions")
    first_position = positions[0]
    callCounter("SelectionPopAtoms")
    return first_position


def _extend_and_continue(num_atoms, left_atoms, loop_prefix, pick_element,
                         current_id, current_position, pick_hydrogen_ids):
    # Apply one extension inside an UndoScope, then save the result or recurse.
    with UndoScope():
        extend_graph_by_element(loop_prefix, pick_element, current_id, current_position,
                                pick_hydrogen_ids)
        atom_ids_check = get_all_non_hydrogen_atoms()
        hydrogen_ids_check = get_all_hydrogen_atoms()
        print("Non-hydrogens " +str(atom_ids_check)+", hydrogens "+str(hydrogen_ids_check))
        # the replaced hydrogens must be gone
        assert all(i not in hydrogen_ids_check for i in pick_hydrogen_ids)

        if left_atoms == 1:
            # last atom added: save current state
            unique_filename = ensure_unique_filename(loop_prefix + "_-_" + get_unique_filename(num_atoms), ".data")
            callCounter("WorldOutputAs", unique_filename)
            callCounter("wait")

            # calculate energy
            # calculate_energy(loop_prefix)
        else:
            recurse(num_atoms, left_atoms-1)


def recurse(num_atoms, left_atoms):
    """Recursively enumerate all ways to grow the current molecule.

    Stops when `left_atoms` <= 0 or no hydrogens are left. Otherwise, for
    every element in `non_hydrogen_elements`, every non-hydrogen atom and
    every non-empty set of at most min(MAX_DEGREE, valency) hydrogens bonded
    to that atom, the hydrogens are replaced by one atom of the element inside
    an UndoScope. When `left_atoms` == 1 the result is written to a unique
    ".data" file; otherwise the recursion continues with one atom fewer.

    Relies on the module-level names `prefix` and `non_hydrogen_elements`.
    """
    if left_atoms <= 0:
        return

    loop_prefix = prefix + '-' + str(num_atoms-left_atoms)

    # check if we can continue due to present hydrogens
    if not check_for_present_hydrogens():
        print("Stopping because there are no more hydrogens in the fully saturated system.")
        return

    atom_ids = get_all_non_hydrogen_atoms()
    assert( left_atoms == num_atoms-len(atom_ids) )

    # go through every element
    for pick_element in non_hydrogen_elements:
        print("Current element is " + pick_element)
        # maximum number of hydrogens replaced at once; depends only on the
        # element, so it is computed once per element
        max_degree = min(MAX_DEGREE, valencies[pick_element])

        # go through all atoms
        for current_id in atom_ids:
            print("Current id is " + str(current_id))
            current_position = get_position_of_id(current_id)

            # get all bonded hydrogen ids of this atom
            hydrogen_ids = list(get_all_hydrogen_bond_neighbors(current_id))
            print("Non-hydrogen atom has " + str(len(hydrogen_ids)) + " hydrogens")

            # every non-empty combination of at most max_degree hydrogens, in
            # the same order as filtering powerset(hydrogen_ids) by size
            for degree in range(1, max_degree + 1):
                for pick_hydrogen_ids in combinations(hydrogen_ids, degree):
                    print("Current set in powerset is " + str(pick_hydrogen_ids))
                    _extend_and_continue(num_atoms, left_atoms, loop_prefix, pick_element,
                                         current_id, current_position, pick_hydrogen_ids)


# main function
if __name__ == '__main__':
    # The third argument is required: prepare_run() loads it as the world
    # input file and derives the prefix (the output folder name) from it.
    if len(sys.argv) <= 3 :
        print("Usage: "+sys.argv[0]+" <number of final non-hydrogen atoms> <bond_table> <input file>")
        sys.exit(1)

    try:
        num_atoms = int(sys.argv[1])
    except ValueError:
        print("Number of final non-hydrogen atoms must be an integer, got '"+sys.argv[1]+"'.")
        sys.exit(1)
    if num_atoms < 1:
        print("Number of final non-hydrogen atoms must be at least 1.")
        sys.exit(1)
    bond_table = sys.argv[2]
    filename = sys.argv[3]
    # `prefix` is a module-level name read by recurse()
    prefix = prepare_run(bond_table = bond_table, filename = filename)

    callCounter("CommandVerbose", "1")

    # parsing homologies works as deserialization, i.e. it does not append
    # hence, we parse the current state here and later update it
    # parse_containers()

    # start from an empty world: remove all atoms and molecules
    callCounter("SelectionAllAtoms")
    callCounter("AtomRemove")
    callCounter("SelectionAllMolecules")
    callCounter("MoleculeRemove")
    # this may fail. Hence, wait here
    callCounter("wait")

    # create the output folder named after the prefix and step into it
    with FolderScope(prefix):

        # loop init; `non_hydrogen_elements` is a module-level name read by
        # recurse() and get_all_hydrogen_bond_neighbors()
        non_hydrogen_elements = list(filter(lambda x: x != "H", valencies.keys()))
        for initial_element in non_hydrogen_elements:
            print("Using "+initial_element+" as initial element")
            add_seed_atom(initial_element)
            saturate_all_atoms()

            # save current state and calculate energy if N=1
            if num_atoms==1:
                unique_filename = ensure_unique_filename(prefix+'-1_-_' + get_unique_filename(num_atoms), ".data")
                callCounter("WorldOutputAs",  unique_filename)
                callCounter("wait")

                # calculate energy
                # calculate_energy(prefix+'-1')

            # recursion
            recurse(num_atoms, num_atoms-1)

            # tabula rasa for next seed atom
            callCounter("SelectionAllAtoms")
            callCounter("AtomRemove")

    # exit
    #callCounter("WorldOutput")
    callCounter("wait")
    #store_containers()

    print("finished, "+str(counter_all)+" actions called.")