from numpy import *

from math import *
from mymath import *
import string
import random
import int_values


class Def_Complextype:
    """Definition of one (possibly composite) collective coordinate.

    Constructor arguments are stored unchanged as attributes:
      dtyp  -- kind of coordinate: 'simple', or a composite type (sum/norm)
      coefs -- coefficients combining the primitive coordinates
      tag   -- primitive coordinate tag (a list of tags for composites)
      what  -- atom indices involved
      where -- periodic-image offsets of those atoms
      value -- current value of the coordinate

    labels starts as None.  a and b start as [None, None] and receive the
    bounds of states A and B.
    """

    def __init__(self, dtyp, coefs, tag, what, where, value):
        self.dtyp, self.coefs, self.tag = dtyp, coefs, tag
        self.what, self.where, self.value = what, where, value
        self.labels = None
        # Two separate lists so that editing one bound pair never touches
        # the other.
        self.a = [None] * 2
        self.b = [None] * 2

#######################################################################


class Xdatcarreader:
    """Reader for VASP XDATCAR/POSCAR-style coordinate files.

    Expected layout: line 1 comment, line 2 scale factor (a negative value
    is interpreted as the required cell volume), lines 3-5 lattice vectors,
    line 6 element tags, line 7 number of atoms per type, line 8 either a
    word starting with S/s (selective dynamics) or the coordinate type (a
    word starting with C/c means Cartesian, anything else direct), line 9
    the coordinate type when selective dynamics is on.  After that every
    configuration is one header line followed by one line per atom.

    Attributes set by readFile():
      data       -- list of (numofatoms, 3) arrays of direct coordinates,
                    one per configuration, unwrapped frame to frame
      lattmat    -- list holding the scaled 3x3 lattice matrix (rows)
      lattinv    -- inverse of lattmat[0]
      scale      -- lattice scale factor actually applied
      volume     -- volume of the scaled cell
      tags_atoms -- tokens of line 6
      types      -- atom counts per type; ntypes is their number
      numofatoms -- total number of atoms
      numconfig  -- number of configurations read
    """

    def __init__(self):
        self.data = []
        self.lattmat = []

    def readFile(self, f, dynamic=1):
        """Read lattice and configurations from *f*.

        f       -- file name, or an already opened file-like object; the
                   file is closed when reading is finished
        dynamic -- if 0, stop after the first complete configuration
                   (static structure); otherwise read all of them
        """
        if isinstance(f, str):
            f = open(f, 'r')
        try:
            coordtype = self._parse_lines(f.readlines(), dynamic)
        finally:
            f.close()

        self.lattinv = linalg.inv(self.lattmat[0])

        if coordtype == 'cart':
            # Cartesian input is multiplied by the scale factor.  Convert to
            # direct coordinates *before* unwrapping: the +-1 shifts in
            # _unwrap are only meaningful in units of the lattice vectors.
            for n, frame in enumerate(self.data):
                self.data[n] = dot(frame*self.scale, self.lattinv)

        self._unwrap()

    def _parse_lines(self, lines, dynamic):
        """Parse the header and coordinate lines; return 'direct' or 'cart'."""
        self.numconfig = 0
        coordtype = 'direct'
        selective = 0                  # 1 when line 8 announces selective dynamics
        lattmat_tmp = zeros((3, 3), float)
        coords = None

        for i, line in enumerate(lines, 1):     # i is the 1-based line number
            line = line.split()
            if i == 2:
                self.scale = float(line[0])
            elif 2 < i < 6:
                lattmat_tmp[i-3] = [float(line[0]), float(line[1]),
                                    float(line[2])]
            elif i == 6:
                if self.scale < 0:
                    # negative scale is the required volume of the cell
                    volume = abs(linalg.det(lattmat_tmp))
                    self.scale = (abs(self.scale)/volume)**(1.0/3.0)
                lattmat_tmp = self.scale*lattmat_tmp
                self.lattmat.append(lattmat_tmp)
                # |det| equals the triple product |(a x b) . c|
                self.volume = abs(linalg.det(lattmat_tmp))
                self.tags_atoms = line
            elif i == 7:
                self.types = [int(count) for count in line]
                self.ntypes = len(self.types)
                self.numofatoms = sum(self.types)
                coords = zeros((self.numofatoms, 3), float)
            elif i == 8:
                if line[0][0] in 'sS':
                    selective = 1
                elif line[0][0] in 'cC':
                    coordtype = 'cart'
            elif i == 9 and selective:
                if line[0][0] in 'cC':
                    coordtype = 'cart'

            # Coordinate section: blocks of one header line + numofatoms lines.
            # Not an elif: the first header line (i == 8 or 9) lands here with
            # indx == 1 and is skipped.
            if i > 7 + selective:
                indx = (i - (7 + selective)) % (self.numofatoms + 1)
                if indx == 0:
                    indx = self.numofatoms + 1
                if indx > 1:
                    coords[indx-2] = [float(line[0]), float(line[1]),
                                      float(line[2])]
                    if indx == self.numofatoms + 1:
                        self.data.append(coords)
                        coords = zeros((self.numofatoms, 3), float)
                        self.numconfig += 1
                        if dynamic == 0:
                            # static structural data - stop here
                            break
        return coordtype

    def _unwrap(self):
        """Undo jumps across the cell boundary between consecutive frames."""
        for n in range(1, len(self.data)):
            frame = self.data[n]
            delta = frame - self.data[n-1]
            frame[delta > 0.5] -= 1
            frame[delta < -0.5] += 1

########################################################################


def shortest_dist(cartesians, lattmat, atom1, atom2):
    """Find the periodic image of atom2 that lies closest to atom1.

    Every translation of atom2 by -1, 0 or +1 lattice vectors along each
    axis (27 candidates) is tried.  Returns [[0, 0, 0], [i, j, k]]: the
    unchanged image of atom1, and the lattice-vector multiples that bring
    atom2 closest to it.  On ties, the first candidate in (i, j, k)
    loop order wins.
    """
    origin = cartesians[atom1]
    other = cartesians[atom2]
    steps = (-1, 0, 1)
    candidates = [(i, j, k) for i in steps for j in steps for k in steps]

    best_image = None
    best_dist = None
    for i, j, k in candidates:
        shift = i*lattmat[0] + j*lattmat[1] + k*lattmat[2]
        dist = sum((origin - (other + shift))**2)**0.5
        # strict comparison keeps the earliest candidate on ties
        if best_image is None or dist < best_dist:
            best_image, best_dist = [i, j, k], dist
    return [[0, 0, 0], best_image]


def read_definition(f, crt, lattmat):
    """Read the definitions of primitive and composite coordinates.

    The file has two sections.  Each line of the first section has the form
    ``TAG: i j ...``, a coordinate tag followed by integer atom indices.
    A line whose first character is 'C' or 'c' switches to the second
    section.  Each line there has the form ``c1 c2 ... : TYPE``:
    coefficients followed by the composite type (e.g. sum/norm).  Blank
    lines are ignored.
    NOTE(review): a tag starting with 'C' would also switch sections.

    For the tags R, M, A, T and RatioR, the periodic-image offsets of the
    atoms involved are chosen with shortest_dist() (minimum image
    convention).  Any other tag gets array([0, 0, 0]).

    Arguments:
      f       -- path, or an open file-like object (closed on return)
      crt     -- Cartesian coordinates indexed by atom
      lattmat -- lattice vectors (rows)

    Returns (inttags, intwhat, intwhere, intcoefs, complextype).
    """
    if isinstance(f, str):
        f = open(f, "r")
    inttags = []
    intwhat = []
    intwhere = []
    intcoefs = []
    complextype = []
    in_complex = False
    try:
        for line in f.readlines():
            if not line.strip():
                continue
            fields = line.split(':')
            if in_complex:
                complextype.append(fields[1].split()[0])
                intcoefs.append([float(c) for c in fields[0].split()])
                continue
            if fields[0][0] in 'Cc':
                in_complex = True
                continue

            tagline = fields[0]
            inttags.append(tagline)
            whatline = [int(w) for w in fields[1].split()]
            intwhat.append(whatline)

            delement = array([0, 0, 0])
            #! minimal image convention
            if tagline == 'R':
                delement = array(shortest_dist(
                    crt, lattmat, whatline[0], whatline[1]))
            elif tagline == 'M':
                delement1 = array(shortest_dist(
                    crt, lattmat, whatline[0], whatline[1]))
                delement2 = shortest_dist(
                    crt, lattmat, whatline[1], whatline[2])
                delement = [delement1[0], delement1[1],
                            delement1[1]+delement2[1]]
            elif tagline == 'A':
                # images are taken relative to the central atom whatline[1]
                delement1 = shortest_dist(
                    crt, lattmat, whatline[1], whatline[0])
                delement2 = shortest_dist(
                    crt, lattmat, whatline[1], whatline[2])
                delement = [delement1[1], delement1[0], delement2[1]]
            elif tagline == 'T':
                delement1 = shortest_dist(
                    crt, lattmat, whatline[1], whatline[0])
                delement2 = shortest_dist(
                    crt, lattmat, whatline[1], whatline[2])
                delement3 = shortest_dist(
                    crt, lattmat, whatline[2], whatline[3])
                # the 4th atom is placed relative to the (shifted) 3rd one
                delement = [delement1[1], delement1[0], delement2[1],
                            array(delement3[1])+array(delement2[1])]
            elif tagline == 'RatioR':
                delement1 = array(shortest_dist(
                    crt, lattmat, whatline[0], whatline[1]))
                delement2 = array(shortest_dist(
                    crt, lattmat, whatline[2], whatline[3]))
                delement = [delement1[0], delement1[1],
                            delement2[0], delement2[1]]
            intwhere.append(delement)
    finally:
        f.close()
    return inttags, intwhat, intwhere, intcoefs, complextype

################################################################################################


class Coordinates:
    """Collection of coordinate definitions (Def_Complextype objects) in self.data."""

    def __init__(self):
        self.data = []

    def getDefinition(self, f, crt, lattmat):
        """Build self.data from the definition file *f* (see read_definition).

        crt and lattmat are passed through to read_definition for the
        minimum-image choice.  If read_definition raises IOError, a message
        is printed and nothing is added.

        Without composite coefficient lines, each primitive coordinate
        becomes a 'simple' definition.  Otherwise, each coefficient line
        whose length matches the number of primitive coordinates becomes one
        definition.  With a single primitive coordinate it is 'simple';
        otherwise it uses the given composite type over all primitives.
        Mismatching lines are reported and skipped.

        Returns self.
        """
        try:
            consttags, constwhat, constwhere, constcoefs, complextype = \
                read_definition(f, crt, lattmat)
        except IOError:
            print('check the definition file (ICOORD)')
            return self

        numconst = len(constwhat)
        if numconst == 0:
            return self

        if not constcoefs:
            for i in range(numconst):
                self.data.append(Def_Complextype('simple', [1], consttags[i], constwhat[i],
                                                 constwhere[i], 0.0))
            return self

        for i, coefs in enumerate(constcoefs):
            if len(coefs) != len(consttags):
                print(
                    'incorect definition of complex constrained coordinate!!!')
            elif len(coefs) == 1:
                # exactly one primitive coordinate exists, so its index is 0
                # (indexing with i failed for the second such line)
                self.data.append(Def_Complextype('simple', coefs, consttags[0], constwhat[0],
                                                 constwhere[0], 0.0))
            else:
                self.data.append(Def_Complextype(complextype[i], coefs, consttags, constwhat,
                                                 constwhere, 0.0))
        return self

    def readInput(self, f):
        """Read the bounds of states A and B for the order parameters.

        Each line with at least three fields whose first field starts with
        A/a (or B/b) assigns the sorted pair of the next two numbers to the
        .a (or .b) attribute of the next coordinate in self.data.  Surplus
        lines produce a warning.  *f* is a path or an open file-like object;
        it is closed on return.
        """
        coords = self.data
        if isinstance(f, str):
            f = open(f, 'r')
        i = 0   # next coordinate to receive A bounds
        j = 0   # next coordinate to receive B bounds
        try:
            for line in f.readlines():
                fields = line.split()
                if len(fields) <= 2:
                    continue
                state = fields[0][0]
                if state in 'aA':
                    if i >= len(coords):
                        print('Warning: State A bounds specified for a nonexisting coordinate (%d):\n%s' % (
                            i, fields))
                    else:
                        coords[i].a = sorted([float(fields[1]), float(fields[2])])
                        i += 1
                elif state in 'bB':
                    if j >= len(coords):
                        print('Warning: State B bounds specified for a nonexisting coordinate (%d):\n%s' % (
                            j, fields))
                    else:
                        coords[j].b = sorted([float(fields[1]), float(fields[2])])
                        j += 1
        finally:
            f.close()

    def detect2(self, carts, lattmat):
        """Evaluate the coordinates for Cartesian positions *carts* and print them.

        int_values.Int_values is called for its effect on the coordinate
        objects (presumably it fills each .value -- TODO confirm).
        Returns 0, or None after printing an error when no coordinates are
        defined.
        """
        coords = self.data
        if len(coords) == 0:
            print('Error: Undefined coordinates')
            return None
        int_values.Int_values(carts, coords, lattmat)
        for c in coords:
            print(c.value)
        return 0


####################################################
def write_poscar(filename, x, v, lattmat):
    """Write positions and velocities to *filename* in POSCAR layout.

    The first (comment) line and the sixth line of ./POSCAR are copied.
    The sixth line is written right after the lattice, which is written with
    a scale factor of 1.0.

    x       -- direct coordinates, one row per atom (written as-is)
    v       -- velocities, presumably in direct units; converted with
               v . lattmat because VASP reads velocities in Cartesians
    lattmat -- 3x3 lattice matrix (rows are lattice vectors)

    After the velocities, this block is appended: "1", "1.0",
    "0.0 0.0 0.0 0.0", the positions once more, and 2*len(x) zero rows
    (presumably VASP's predictor-corrector section -- TODO confirm).
    """
    v = dot(v, lattmat)  # vasp reads velocities in cartesians!!!
    with open('POSCAR', 'r') as template:
        header = [template.readline() for _ in range(6)]
    comment, atoms = header[0], header[5]

    def row(values):
        # space-separated values on one line
        return ' '.join(map(str, values)) + '\n'

    with open(filename, 'w') as f:
        f.write(comment)
        f.write(str(1.000) + '\n')
        for i in range(3):
            f.write(row(lattmat[i]))
        f.write(atoms)
        f.write('direct \n')
        for pos in x:
            f.write(row(pos))
        f.write("\n")
        for vel in v:
            f.write(row(vel))
        f.write("\n1\n1.0\n0.0 0.0 0.0 0.0\n")
        for i in range(len(v)):
            f.write(row(x[i]))
        f.write("0.0 0.0 0.0\n"*(len(x)*2))


def write_incar(filename, nsw):
    """Copy ./INCAR to *filename* with NSW set to *nsw* and TEBEG removed.

    Only lines containing '=' are copied; other lines (blank lines, lines
    without a tag) are dropped.  Tag names are compared case-insensitively.
    The rewritten NSW line has the form "NSW = <nsw>".
    """
    with open('INCAR', 'r') as src, open(filename, 'w') as dst:
        for line in src.readlines():
            tags = line.split('=')
            if len(tags) < 2:
                continue
            tag = tags[0].strip().upper()
            if tag == 'NSW':
                dst.write(' '.join([tag, '=', str(nsw)]) + '\n')
            elif tag != 'TEBEG':
                dst.write(line)


def write_tmpXdat(xdat, cros, test):
    """Write part of the trajectory *xdat* to XDATCAR.tmp.

    The file always starts with six empty lines.  Then:
      test == 1 -- frames len(xdat)-2 down to cros-1 (reversed order)
      test == 2 -- frames 1 .. cros (forward order)
      otherwise -- nothing more is written
    Each frame is written one atom per line (space-separated values),
    followed by an empty line.
    """
    if test == 1:
        frames = [len(xdat) - 1 - i for i in range(1, len(xdat) - cros + 1)]
    elif test == 2:
        frames = list(range(1, cros + 1))
    else:
        frames = []
    with open('XDATCAR.tmp', 'w') as f:
        f.write('\n\n\n\n\n\n')
        for indx in frames:
            for j in range(len(xdat[0])):
                f.write(' '.join(map(str, xdat[indx][j])) + '\n')
            f.write('\n')


def readMasses(f):
    """Return the first number on each non-blank line of file *f* as floats.

    f -- path of the masses file (the file is closed before returning)
    """
    with open(f, 'r') as src:
        rows = (line.split() for line in src)
        return [float(fields[0]) for fields in rows if fields]


def boltzmannRandom(width):
    """Draw one normally distributed number with mean 0 and std *width*.

    Uses the Box-Muller transform.  The second uniform variate is taken
    from (0, 1] as 1 - random.random(): random.random() lies in [0, 1) and
    can return 0.0, for which log() would raise ValueError.
    """
    x = random.random()
    y = 1.0 - random.random()
    return width*cos(2*pi*x)*(-2*log(y))**0.5

###################################################


def calcVel(xi, xj):
    """Return the displacement xi - xj folded to the minimum image.

    Every component of the (natoms, 3) difference array is shifted by whole
    units until it lies in (-0.5, 0.5].  A new array is returned; xi and xj
    are not modified.
    """
    delta = xi - xj
    for row in delta:              # rows are views, so edits land in delta
        for axis in range(3):
            while row[axis] > 0.5:
                row[axis] -= 1.0
            while row[axis] <= -0.5:
                row[axis] += 1.0
    return delta


def calcMomentum(v, m):
    """Return the momenta for a flat velocity vector.

    v -- flat sequence [vx0, vy0, vz0, vx1, ...] of length 3*len(m)
    m -- per-atom masses; component i belongs to atom i // 3

    Returns a float array of the same length as v.
    """
    p = zeros(len(v), float)
    for i in range(len(v)):
        # integer division: with '/' the index is a float under Python 3
        p[i] = v[i]*m[i // 3]
    return p


def newShootingPath(nsw, xdat, lattmat, allmasses, shootfactor):
    """Shooting move: perturb the velocity at a randomly chosen frame.

    A frame index cros is drawn uniformly from 1 .. nsw-1, so nsw must be
    at least 2 (random.randint raises ValueError otherwise).  The velocity
    is the minimum-image displacement from frame cros-1 to frame cros
    (calcVel), perturbed with newVelocity.

    Returns (x, v_new, cros), where x is the coordinate array of frame cros.
    """
    # Frame 0 has no predecessor: with cros == 0, xdat[cros-1] would
    # silently pick the LAST frame and give a meaningless velocity.
    cros = random.randint(1, nsw-1)
    x = array(xdat[cros])
    v = calcVel(x, array(xdat[cros-1]))
    v_new = newVelocity(v, lattmat, allmasses, shootfactor)
    return x, v_new, cros


def cmVelocity(vel, masses):
    """Return the centre-of-mass velocity as a length-3 array.

    vel    -- flat velocity sequence [vx0, vy0, vz0, vx1, ...]; only the
              first 3*len(masses) entries are used
    masses -- per-atom masses
    """
    weighted = zeros(3, float)
    for atom, mass in enumerate(masses):
        base = 3*atom
        for axis in range(3):
            weighted[axis] += vel[base + axis]*mass
    return weighted/sum(masses)


def brandNewVelocity(meanT, lattmat, allmasses):
    """Draw fresh random velocities for the temperature *meanT*.

    Steps:
      1. Draw each Cartesian component of atom i from boltzmannRandom with
         width sqrt(meanT*kb/m_i).
      2. Remove the centre-of-mass velocity.
      3. Rescale so that sum(m*v**2) / (3*(N-1)*kb) equals meanT.  The
         3*(N-1) is presumably the number of degrees of freedom left after
         removing centre-of-mass motion.  The temperature before and after
         rescaling is printed.
      4. Multiply by 1e-5 (presumably m/s -> Angstrom/fs, assuming masses
         in kg -- TODO confirm).
      5. Convert to direct units with inv(lattmat).

    NOTE(review): with a single atom the temperature divisor is zero.

    Returns an (N, 3) array.
    """
    kb = 1.3807e-23                     # Boltzmann constant, J/K
    natoms = len(allmasses)
    velo_ = zeros(3*natoms, float)
    for i in range(len(velo_)):
        # integer division: with '/' the index is a float under Python 3
        velo_[i] = boltzmannRandom((meanT*kb/allmasses[i // 3])**0.5)

    # remove the centre-of-mass drift
    cmvel = cmVelocity(velo_, allmasses)
    for i in range(natoms):
        velo_[3*i] = velo_[3*i]-cmvel[0]
        velo_[3*i+1] = velo_[3*i+1]-cmvel[1]
        velo_[3*i+2] = velo_[3*i+2]-cmvel[2]

    momentum_ = calcMomentum(velo_, allmasses)
    newT = sum(velo_*momentum_)/(3*(natoms-1)*kb)
    print(newT)
    velo_ = velo_*(meanT/newT)**0.5
    momentum_ = calcMomentum(velo_, allmasses)
    newT = sum(velo_*momentum_)/(3*(natoms-1)*kb)
    print(newT)
    velo_ *= 1e-5

    # one row per atom, then Cartesian -> direct
    vel_ = velo_.reshape(natoms, 3)
    return dot(vel_, linalg.inv(lattmat))


def newVelocity(vel, lattmat, allmasses, shootfactor):
    """
    Generate a modified velocity for a shooting operation.

    vel         -- (N, 3) velocities in direct units
    lattmat     -- lattice vectors (rows); converts to Cartesian and back
    allmasses   -- per-atom masses
    shootfactor -- perturbation size, expected in [0, 1] so that
                   sqrt(1 - shootfactor**2) is real

    Steps:
      1. Draw a random Gaussian velocity, remove its centre-of-mass motion,
         and scale it to the old sum(m*v**2).
      2. Keep only its part perpendicular to the old velocity v.
      3. Form sqrt(1 - shootfactor**2)*v + |v|*sqrt(1 - factor**2)*u_perp,
         where u_perp is that perpendicular part normalised to unit length.
      4. Rescale so that sum(m*v**2) equals the original value.

    Returns an (N, 3) array in direct units.
    """
    vel = dot(vel, lattmat)             # direct -> Cartesian
    natoms = len(vel)
    factor = (1.0-shootfactor**2)**0.5

    # flatten the old velocity and compute its momenta
    velo = zeros(3*natoms, float)
    momentum = zeros(3*natoms, float)
    for i in range(natoms):
        for j in range(3):
            velo[3*i+j] = vel[i][j]
            momentum[3*i+j] = vel[i][j]*allmasses[i]
    norm = (sum(velo**2))**0.5
    meanT = (sum(velo*momentum))        # sum of m*v**2 of the old velocity

    # random trial velocity with the same sum(m*v**2)
    velo_ = zeros(3*natoms, float)
    for i in range(len(velo_)):
        # integer division: with '/' the index is a float under Python 3
        velo_[i] = boltzmannRandom(
            (meanT/(3*allmasses[i // 3]*len(allmasses)))**0.5)
    cmvel = cmVelocity(velo_, allmasses)
    for i in range(natoms):
        velo_[3*i] = velo_[3*i]-cmvel[0]
        velo_[3*i+1] = velo_[3*i+1]-cmvel[1]
        velo_[3*i+2] = velo_[3*i+2]-cmvel[2]
    momentum_ = calcMomentum(velo_, allmasses)
    newT = sum(velo_*momentum_)
    velo_ = velo_*(meanT/newT)**0.5

    # part of the trial velocity perpendicular to the old velocity
    projection = sum(velo_*velo)/norm**2*velo
    velo_perp = velo_-projection
    norm_perp = (sum(velo_perp**2))**0.5

    # mix old velocity and perpendicular direction
    velo_para = factor*velo
    velo_perp = norm/norm_perp*(1.0-factor**2)**0.5*velo_perp
    velo_ = velo_perp+velo_para

    momentum_ = calcMomentum(velo_, allmasses)
    newT = sum(velo_*momentum_)
    velo_ = velo_*(meanT/newT)**0.5

    # one row per atom, then Cartesian -> direct
    vel_ = velo_.reshape(natoms, 3)
    return dot(vel_, linalg.inv(lattmat))


def newTimeshiftPath(inpt, xdat, nsw):
    """Time-shifting operation.

    Repeatedly draws a frame index cros from 1 .. len(xdat)-1 and stores
    that frame's Cartesian coordinates in inpt.coords_c.  For each draw it
    reads the coordinate definitions ("ICONST") and state bounds
    ("INPFILE") and classifies the frame:
      test == 1 -- continue from the last frame (prints 2)
      test == 2 -- continue from the first frame (prints 3)
      otherwise -- draw another frame
    nsw is accepted for interface compatibility and is not used.

    Returns (x, v, cros, test).

    NOTE(review): Coordinates in this file defines no ``detect`` method
    (only ``detect2``, which returns 0 or None).  The classification call
    below therefore cannot succeed as written; confirm the intended
    classifier.
    """
    while True:
        # randint is inclusive: len(xdat) would be out of range
        cros = random.randint(1, len(xdat) - 1)
        x = array(xdat[cros])
        inpt.coords_c = dot(x, inpt.lattmat)
        coord = Coordinates()
        coord.getDefinition("ICONST", inpt.coords_c, inpt.lattmat)
        coord.readInput("INPFILE")
        test = coord.detect(inpt)
        if test == 1:
            x = array(xdat[-1])
            v = calcVel(x, array(xdat[-2]))
            print(2)
            return x, v, cros, test
        if test == 2:
            x = array(xdat[0])
            v = calcVel(array(xdat[0]), array(xdat[1]))
            print(3)
            return x, v, cros, test
