#!/usr/bin/python

from numpy import *

import cPickle
import sys
import math
import random
import tpsinput
import takeinp
import abstracts


def calcVel(xi, xj):
    """Return the difference xi - xj folded into the interval (-0.5, 0.5].

    The element-wise difference of the two arrays is formed and the first
    three components of every row are shifted by whole units until they
    lie in the half-open interval (-0.5, 0.5].  The result is a new array
    produced by the subtraction.
    """
    delta = xi-xj
    for row in delta:
        for axis in range(3):
            comp = row[axis]
            # Fold down first, then up, one unit at a time.
            while comp > 0.5:
                comp -= 1.0
            while comp <= -0.5:
                comp += 1.0
            row[axis] = comp
    return delta


def calcMomentum(v, m):
    """Return the momentum vector for a flat velocity vector.

    v -- flat sequence of velocity components, three consecutive entries
         per atom (x, y, z of atom 0, then atom 1, ...).
    m -- per-atom masses; entry i//3 is applied to component i.

    Returns a float array of len(v) with p[i] = v[i] * m[i//3].
    Raises IndexError when m holds fewer masses than v needs.
    """
    v = asarray(v, dtype=float)
    # Floor division keeps the atom index an integer regardless of whether
    # the interpreter uses classic or true division for "/".
    atom_of_component = arange(len(v)) // 3
    return v * asarray(m, dtype=float)[atom_of_component]


def newShootingPath(nsw, xdat, lattmat, allmasses, shootfactor):
    """Shooting operation: pick a random frame and perturb its velocity.

    nsw         -- number of frames to choose from (the caller passes
                   len(xdat)).
    xdat        -- sequence of frames; each frame is converted with array().
    lattmat, allmasses, shootfactor -- passed through to newVelocity().

    Returns (x, v_new, cros): the coordinates of the chosen frame, the
    modified velocity returned by newVelocity(), and the chosen frame index.

    The velocity is the wrapped difference between frame ``cros`` and the
    frame before it, so frame 0 is never chosen: it has no predecessor and
    ``xdat[cros-1]`` would silently wrap around to the last frame.

    Raises ValueError when fewer than two frames are available.
    """
    if nsw < 2:
        raise ValueError(
            'shooting needs at least two frames, got nsw=%r' % (nsw,))
    # random.randint includes both end points.
    cros = random.randint(1, nsw-1)
    x = array(xdat[cros])
    v = calcVel(x, array(xdat[cros-1]))
    v_new = newVelocity(v, lattmat, allmasses, shootfactor)
    return x, v_new, cros


def cmVelocity(vel, masses):
    """Calculate the velocity of the center of mass.

    vel is a flat sequence holding three consecutive components per atom;
    masses holds one mass per atom.  Entries of vel beyond 3*len(masses)
    are ignored.  Returns a length-3 float array: the mass-weighted sum of
    the per-atom velocities divided by the total mass.
    """
    natoms = len(masses)
    triples = asarray(vel)[:3*natoms].reshape(natoms, 3)
    weighted = zeros(3, float)
    for mass, triple in zip(masses, triples):
        weighted += triple*mass
    return weighted/sum(masses)


def newVelocity(vel, lattmat, allmasses, shootfactor):
    """Generate the modified velocity for a shooting operation.

    vel         -- (natoms, 3) velocities; multiplied by lattmat on entry
                   and by its inverse on exit (presumably fractional ->
                   Cartesian and back -- confirm against the caller).
    lattmat     -- 3x3 matrix used for that conversion.
    allmasses   -- one mass per atom.
    shootfactor -- number in [-1, 1]; the old direction keeps the weight
                   factor = sqrt(1 - shootfactor**2).

    Steps, all carried out on flat vectors of length 3*natoms:
      1. meanT = sum(m * v**2) of the incoming velocity is recorded.
      2. A random velocity is drawn component by component with
         abstracts.boltzmannRandom, its center-of-mass velocity is
         removed, and it is rescaled to reproduce meanT.
      3. The random velocity is split into parts parallel and
         perpendicular to the old velocity.  The new velocity is
         factor * old + perpendicular part scaled to the length
         |old| * sqrt(1 - factor**2).
      4. The result is rescaled once more to reproduce meanT.

    Diagnostic values (meanT/newT before and after the final rescaling,
    the cosine between old and new velocity next to factor, and the final
    center-of-mass velocity) are printed to stdout.

    Returns the new velocities as an (natoms, 3) array in the same
    representation as the input.

    Raises ValueError when |shootfactor| > 1, for which factor would not
    be a real number.
    """
    if abs(shootfactor) > 1.0:
        raise ValueError(
            'shootfactor must lie in [-1, 1], got %r' % (shootfactor,))

    vel = dot(vel, lattmat)
    natoms = len(vel)
    factor = (1.0-shootfactor**2)**0.5

    # Old velocity as a flat vector: component 3*i+j belongs to atom i.
    velo = array(vel, float).reshape(3*natoms)
    norm = (sum(velo**2))**0.5
    meanT = sum(velo*calcMomentum(velo, allmasses))

    # Random velocity; floor division keeps the atom index an integer
    # independent of the interpreter's "/" semantics.
    velo_ = zeros(3*natoms, float)
    for i in range(len(velo_)):
        velo_[i] = abstracts.boltzmannRandom(
            (meanT/(3*allmasses[i//3]*len(allmasses)))**0.5)

    # Remove the net drift of the random velocity.
    cmvel = cmVelocity(velo_, allmasses)
    for i in range(natoms):
        velo_[3*i:3*i+3] -= cmvel

    # Rescale the random velocity to the original meanT.
    newT = sum(velo_*calcMomentum(velo_, allmasses))
    velo_ = velo_*(meanT/newT)**0.5

    # Keep only the part of the random velocity perpendicular to the old one.
    velo_para = sum(velo_*velo)/norm**2*velo
    velo_perp = velo_-velo_para
    norm_perp = (sum(velo_perp**2))**0.5

    # Mix: weight `factor` along the old direction, the remainder across it.
    velo_para = factor*velo
    velo_perp = norm/norm_perp*(1.0-factor**2)**0.5*velo_perp
    velo_ = velo_perp+velo_para

    newT = sum(velo_*calcMomentum(velo_, allmasses))
    print('%s %s' % (meanT, newT))
    velo_ = velo_*(meanT/newT)**0.5
    newT = sum(velo_*calcMomentum(velo_, allmasses))
    print('%s %s' % (meanT, newT))
    print('%s %s' % (
        sum(velo_*velo)/sum(velo**2)**0.5/sum(velo_**2)**0.5, factor))

    cmvel = cmVelocity(velo_, allmasses)
    print('cmvel_: %s' % (cmvel,))

    # Back to one row per atom and to the caller's representation.
    vel_ = velo_.reshape(natoms, 3)
    return array(dot(vel_, linalg.inv(lattmat)))


def newTimeshiftPath(inpt, xdat, nsw):
    """Time shifting operation: find a frame that coord.detect() accepts.

    inpt -- input object; its ``coords_c`` attribute is overwritten with
            the candidate frame multiplied by ``inpt.lattmat`` before every
            detection attempt.
    xdat -- sequence of frames.
    nsw  -- unused; kept for interface compatibility.

    Frames 1 .. len(xdat)-1 are tried in random order, each at most once.
    For every candidate a fresh abstracts.Coordinates object is set up
    from "ICONST" / "INPFILE" and asked to classify the frame:

      detect() == 1 -> continue from the last frame, with the velocity
                       taken as last frame minus the one before it;
      detect() == 2 -> continue from the first frame, with the velocity
                       taken as first frame minus the second;
      anything else -> the candidate is rejected.

    Returns (x, v, cros, test): start coordinates, velocity, the accepted
    frame index and the detect() result (1 or 2).

    Raises RuntimeError when no candidate frame is accepted.
    """
    lattmat = inpt.lattmat
    # Valid candidate indices only; random.randint(1, len(xdat)) could
    # return len(xdat), which is one past the last frame.
    candidates = list(range(1, len(xdat)))
    # Drawing without replacement guarantees termination; assuming
    # detect() depends only on the frame, the accepted index is still
    # uniform over the acceptable frames.
    random.shuffle(candidates)
    for cros in candidates:
        inpt.coords_c = dot(array(xdat[cros]), lattmat)
        coord = abstracts.Coordinates()
        coord.getDefinition("ICONST", inpt)
        coord.readInput("INPFILE")
        test = coord.detect(inpt)
        if test == 1:
            x = array(xdat[-1])
            v = calcVel(x, array(xdat[-2]))
            print(2)
            return x, v, cros, test
        if test == 2:
            x = array(xdat[0])
            v = calcVel(array(xdat[0]), array(xdat[1]))
            print(3)
            return x, v, cros, test
    raise RuntimeError(
        'time shift failed: no frame was classified as state 1 or 2')


# ---------------------------------------------------------------------------
# Script body: read the current trajectory and write the input files for
# the next move (shooting or time shift).
# ---------------------------------------------------------------------------
tps = tpsinput.Tpsinput()
shootfactor = tps.SHOOTFACTOR
timefactor = tps.TIMEFACTOR

cfile = 'trajectory.dat'
# NOTE(review): unpickling can execute arbitrary code; trajectory.dat must
# come from a trusted source.
# Pickles are byte streams, so the file is opened in binary mode; the
# with-block closes it even when loading fails.
with open(cfile, 'rb') as ufile:
    xdat = cPickle.load(ufile)

nsw = len(xdat)
inpt = takeinp.TakeInput('POSCAR')
lattmat = inpt.lattmat
atoms = inpt.atoms

# Expand the per-entry masses into one mass per atom: entry i of `masses`
# is repeated atoms[i] times.
masses = abstracts.readMasses('MASSES')
allmasses = []
for i, count in enumerate(atoms):
    allmasses.extend(count*[masses[i]])
allmasses = array(allmasses)

# A uniform random number above timefactor selects a shooting move,
# otherwise a time shift is attempted.
rnd = random.random()
if rnd > timefactor:
    x_f, v_f, cros = newShootingPath(
        nsw, xdat, lattmat, allmasses, shootfactor)
    x_f = x_f % 1
    # The backward branch starts one velocity step behind the forward
    # start point and runs with the reversed velocity.
    x_b = x_f-v_f
    x_b = x_b % 1
    v_b = -v_f
    abstracts.write_poscar('POSCAR.f', x_f, v_f, lattmat)
    abstracts.write_poscar('POSCAR.b', x_b, v_b, lattmat)
    abstracts.write_incar('INCAR.f', cros)
    abstracts.write_incar('INCAR.b', nsw-cros)
    print(1)
else:
    x, v, cros, test = newTimeshiftPath(inpt, xdat, nsw)
    x = x % 1
    if test == 1:
        abstracts.write_poscar('POSCAR.f', x, v, lattmat)
        abstracts.write_incar('INCAR.f', cros)
        abstracts.write_tmpXdat(xdat, cros, test)
    elif test == 2:
        abstracts.write_poscar('POSCAR.b', x, v, lattmat)
        abstracts.write_incar('INCAR.b', nsw-cros)
        abstracts.write_tmpXdat(xdat, cros, test)
