import h5py
import numpy as np

# Physical constants (SI) and unit-conversion factors used by this module.
PI = np.pi
# Elementary charge in coulomb, i.e. the eV -> J conversion factor.
# NOTE(review): 1.60217733e-19 appears to be the CODATA-1986 value while
# PLANCK_SI below looks like CODATA-2014; mixed vintages, kept as-is so
# results stay reproducible.
EVTOJ = 1.60217733E-19
# Planck constant h in J*s.
PLANCK_SI = 6.626070040e-34
# 1e12 * h / e / (2*pi) == hbar * 1e12 / e: energy in eV per unit of
# "1e12 rad/s" (angular frequency). from_file multiplies the stored phonon
# eigenvalues by this factor.
# NOTE(review): if the file stores ordinary frequencies in THz, the extra
# 1/(2*pi) would be wrong -- confirm against the vaspelph.h5 format.
THZTOEV = 1e12 * PLANCK_SI / EVTOJ / 2 / PI

def abs2(x):
    """
    Return the squared modulus |x|**2 as real**2 + imag**2.

    Accepts anything exposing ``.real`` and ``.imag`` (Python numbers, NumPy
    scalars and arrays) and avoids taking a square root.
    """
    re_part = x.real
    im_part = x.imag
    return re_part**2 + im_part**2

def get_cluster(eigk, eps=1e-6):
    """
    Label runs of (nearly) degenerate eigenvalues with consecutive integers.

    Neighbouring entries whose absolute difference is at most ``eps`` share a
    label; the label starts at 0 and increases by one at every larger gap.
    Only *adjacent* entries are compared, so degenerate values are grouped
    only if they are next to each other (e.g. sorted input).

    Example: ``[1.2, 2.5, 2.5, 6.5]`` produces ``[0 1 1 2]``.

    Parameters
    ----------
    eigk : 1-D array-like of numbers
        Eigenvalues to cluster.
    eps : float
        Tolerance below which two neighbouring values count as degenerate.

    Returns
    -------
    numpy.ndarray of int
        One label per input entry (empty for empty input).
    """
    eigk = np.asarray(eigk)
    if len(eigk) == 0:
        # No eigenvalues -> no labels (keeps len(result) == len(eigk)).
        return np.zeros(0, dtype=int)
    # True wherever a new cluster starts after position i.
    new_cluster = np.abs(np.diff(eigk)) > eps
    return np.concatenate(([0], np.cumsum(new_cluster)))

def get_clusters(eig):
    """
    Apply ``get_cluster`` to every (spin, k-point) row of an eigenvalue array.

    Parameters
    ----------
    eig : array-like, shape (nspin, nkpoints, nbands)
        Eigenvalues indexed as [ispin, ikpoint, iband].

    Returns
    -------
    numpy.ndarray of int, shape (nspin, nkpoints, nbands)
        Degeneracy labels of each row, as produced by ``get_cluster``.
    """
    eig = np.asarray(eig)
    nspin, nkpoints, nbands = eig.shape
    # Cluster labels are integers, so store them as such.
    clusters = np.zeros((nspin, nkpoints, nbands), dtype=int)
    for ispin in range(nspin):
        for ik in range(nkpoints):
            clusters[ispin, ik, :] = get_cluster(eig[ispin, ik, :])
    return clusters

class VaspElphG:
    """
    Container for electron-phonon matrix elements read from a VASP HDF5 file.

    Attributes set by ``__init__``:
      mels            -- complex 6-D matrix elements, unpacked as
                         (nspin, nkpoints_kp, nkpoints_k, nmodes,
                          nbands_k, nbands_kp)
      vkpt_k, vkpt_kp -- k-point coordinates of the two k-point sets
      fbz2ibz         -- full-BZ -> irreducible-BZ index map (as given)
      eig_k, eig_kp   -- electronic eigenvalues indexed [ispin, ikpt, iband]
      pheigs          -- phonon eigenvalues indexed [ikp, ik, imode]
      path_len_k, path_len_kp -- cumulative k-path lengths
    """

    def __init__(self, mels, vkpt_k, vkpt_kp, fbz2ibz, eig_k, eig_kp, pheigs):
        self.mels = mels
        self.vkpt_k = vkpt_k
        self.vkpt_kp = vkpt_kp
        self.fbz2ibz = fbz2ibz
        self.eig_k = eig_k
        self.eig_kp = eig_kp
        self.pheigs = pheigs
        (self.nspin, self.nkpoints_kp, self.nkpoints_k,
         self.nmodes, self.nbands_k, self.nbands_kp) = self.mels.shape
        # k-path lengths, filled in by compute_k_path_lengths().
        self.path_len_k = None
        self.path_len_kp = None
        self.compute_k_path_lengths()

    @classmethod
    def from_file(cls, filename='vaspelph.h5'):
        """
        Build an instance from a ``vaspelph.h5`` file.

        The matrix elements are stored as a 7-D real array whose last axis
        holds (real, imag); they are combined into a complex 6-D array.
        ``indx_fbz2ibz`` is shifted by -1 (presumably Fortran 1-based ->
        Python 0-based indices) and phonon eigenvalues are multiplied by
        ``THZTOEV`` to convert them to eV.
        """
        # The context manager guarantees the HDF5 handle is closed; ``[:]``
        # copies every dataset into memory, so the arrays outlive the file.
        with h5py.File(filename, 'r') as elph:
            mels_re = elph['matrix_elements/elph'][:]
            vkpt_k = elph['/kpoints/vkpt_k'][:]
            vkpt_kp = elph['/kpoints/vkpt_kp'][:]
            fbz2ibz = elph['kpoints/indx_fbz2ibz'][:] - 1
            eig_k = elph['matrix_elements/eigenvalues_k'][:]
            eig_kp = elph['matrix_elements/eigenvalues_kp'][:]
            pheigs = elph['matrix_elements/phonon_eigenvalues'][:] * THZTOEV
        mels = (mels_re[:, :, :, :, :, :, 0]
                + 1j * mels_re[:, :, :, :, :, :, 1])
        return cls(mels, vkpt_k, vkpt_kp, fbz2ibz, eig_k, eig_kp, pheigs)

    def compute_k_path_lengths(self, eps=1e-8):
        """
        Compute cumulative k-path lengths for ``vkpt_k`` and ``vkpt_kp``.

        Distances are Euclidean norms of consecutive coordinate differences.
        Rules:
          - At the start of a segment, the first distance larger than ``eps``
            defines the segment step size and is added to the length
            (zero distances before it are ignored).
          - A distance equal (within ``eps``) to the current step size is
            added to the cumulative length.
          - Any other distance (a repeated point or a jump) marks a segment
            boundary: it is not added, and the next non-zero distance starts
            a new segment.

        Returns
        -------
        (path_len_k, path_len_kp) : tuple of 1-D float arrays
            Also stored on the instance.
        """
        def _compute(kpts):
            kpts = np.asarray(kpts, dtype=float)
            n = len(kpts)
            out = np.zeros(n, dtype=float)
            cum = 0.0
            step_size = None
            for i in range(1, n):
                d = np.linalg.norm(kpts[i] - kpts[i - 1])
                if step_size is None:
                    # Looking for the first real step of a segment.
                    if d > eps:
                        step_size = d
                        cum += d
                elif abs(d - step_size) < eps:
                    cum += d
                else:
                    # Boundary (duplicate point or jump): not added.
                    step_size = None
                out[i] = cum
            return out

        self.path_len_k = _compute(self.vkpt_k)
        self.path_len_kp = _compute(self.vkpt_kp)
        return self.path_len_k, self.path_len_kp

    @staticmethod
    def _average_degenerate(m2, cluster, axis):
        """Replace, in place, the slices of ``m2`` along ``axis`` that share a
        label in ``cluster`` by their mean (singleton clusters are left alone)."""
        for label in np.unique(cluster):
            idx = np.flatnonzero(cluster == label)
            if len(idx) == 1:
                continue
            sel = [slice(None)] * m2.ndim
            sel[axis] = idx
            sel = tuple(sel)
            m2[sel] = m2[sel].mean(axis=axis, keepdims=True)

    def get_averaged_g(self):
        """
        Return gauge-averaged magnitudes of the electron-phonon matrix elements.

        The matrix elements are gauge-dependent: for degenerate states two
        different runs may return different linear combinations of orbitals,
        so the raw elements cannot be compared directly. For testing, |g|**2
        is averaged over every group of degenerate phonon modes and
        degenerate bands, and the square root of that average is returned.

        Returns
        -------
        numpy.ndarray of float, same shape as ``self.mels``
        """
        clusters_k = get_clusters(self.eig_k)
        clusters_kp = get_clusters(self.eig_kp)
        # Same shape as mels so both band axes fit even when
        # nbands_k != nbands_kp.
        averaged_g = np.zeros(self.mels.shape, dtype=float)

        for ispin in range(self.nspin):
            for ik in range(self.nkpoints_k):
                ik_cluster = clusters_k[ispin, ik]
                for ikp in range(self.nkpoints_kp):
                    ikp_cluster = clusters_kp[ispin, ikp]
                    # Degenerate phonon modes; get_cluster only groups
                    # adjacent values, so this assumes pheigs are sorted.
                    iq_cluster = get_cluster(self.pheigs[ikp, ik, :])

                    # abs2 returns a fresh array, so mels is not modified.
                    m2 = abs2(self.mels[ispin, ikp, ik, :, :, :])

                    self._average_degenerate(m2, iq_cluster, axis=0)
                    # NOTE(review): __init__ names mels axis 4 ``nbands_k``
                    # and axis 5 ``nbands_kp``, yet here axis 1 of m2 (mels
                    # axis 4) is grouped by eig_kp clusters and axis 2 by
                    # eig_k clusters. One of the two is mislabelled --
                    # confirm the band order of 'matrix_elements/elph'.
                    self._average_degenerate(m2, ikp_cluster, axis=1)
                    self._average_degenerate(m2, ik_cluster, axis=2)

                    averaged_g[ispin, ikp, ik, :, :, :] = np.sqrt(m2)
        return averaged_g