import astropy.io.fits as fits


from gdt.core.headers import Header,FileHeaders
from gdt.core.phaii import Phaii
from gdt.core.data_primitives import Gti, TimeEnergyBins


from pathlib import Path

import numpy as np


def header_class_fromfile(name, filename, class_name=None):
    """Create a :class:`Header` subclass whose keywords come from a text file.

    Args:
        name (str): Extension name, stored on the new class as ``name``.
        filename (str or Path): FITS header text file, read with
            :meth:`astropy.io.fits.Header.fromtextfile`.
        class_name (str, optional): Name given to the generated class.
            Defaults to *name*.

    Returns:
        type: A new subclass of :class:`Header` whose ``keywords`` attribute
        is a list of ``(keyword, value, comment)`` tuples, one per card in
        the text file.
    """
    template = fits.Header.fromtextfile(filename)
    cards = [(c.keyword, c.value, c.comment) for c in template.cards]
    new_name = name if class_name is None else class_name
    return type(new_name, (Header,), {'name': name, 'keywords': cards})
    
def PhaiiFileHeaders(refdata, class_name = None):
    """Build the :class:`FileHeaders` object for a PHAII file from templates.

    One header template is read per extension (PRIMARY, ARRAY_PHA, EBOUNDS,
    STDGTI) from a FITS header text file located through *refdata*.

    Args:
        refdata (str or Path): Either a string prefix that is concatenated
            directly with each template file name (so it normally ends with a
            path separator), or a :class:`~pathlib.Path` pointing at the
            directory that holds the template files.
        class_name (str, optional): Name of the generated FileHeaders
            subclass. Defaults to the extension names joined by ``'_'``.

    Returns:
        An instance of the dynamically created FileHeaders subclass.

    Raises:
        FileNotFoundError: If any header template file does not exist.
    """
    hdr_dict = {'PRIMARY': 'primary.txt',
                'ARRAY_PHA': 'array_data.txt',
                'EBOUNDS': 'ebounds.txt',
                'STDGTI': 'gti.txt'}

    _header_templates = []
    for name, filename in hdr_dict.items():
        if isinstance(refdata, Path):
            reffile = refdata / filename
        else:
            # String prefix: keep the original plain-concatenation semantics.
            reffile = Path(refdata + filename)
        # Fail early with a clear message instead of printing and then
        # letting the template reader fail on the missing file.
        if not reffile.exists():
            raise FileNotFoundError(f"File {reffile} not found")
        _header_templates.append(header_class_fromfile(name, reffile)())

    if class_name is None:
        class_name = '_'.join(hdr_dict)

    file_headers_cls = type(class_name, (FileHeaders,),
                            {'_header_templates': _header_templates})
    return file_headers_cls()

class MyPhaii(Phaii):
    """An example to read and write PHAII files for xxx instrument"""

    # FITS commentary keywords. In astropy, assigning to one of these appends
    # a new card rather than replacing an existing one, so they must not be
    # re-assigned when copying header values (that would duplicate them).
    _COMMENTARY_KEYWORDS = ('COMMENT', 'HISTORY', '')

    @property
    def detector(self):
        """(str): The detector name (``DETNAM``), or ``INSTRUME`` when the
        primary header has no ``DETNAM`` value."""
        primary = self.headers['PRIMARY']
        if primary.get('DETNAM') is not None:
            return primary['DETNAM']
        return primary['INSTRUME']

    @classmethod
    def open(cls, file_path, refdata, **kwargs):
        """Open a PHAII FITS file.

        Times (bin starts and GTI bounds) are made relative to ``TRIGTIME``
        when the ARRAY_PHA header provides one.

        Args:
            file_path (str): Path to the FITS file.
            refdata (str or Path): Location of the header template text files;
                passed to :func:`PhaiiFileHeaders`.
            **kwargs: Passed on to :meth:`Phaii.open`.

        Returns:
            (:class:`MyPhaii`): The object, with ``timedel`` set from the
            ``TIMEDEL`` keyword so it can be written back out.
        """
        with super().open(file_path, **kwargs) as obj:

            # an example of how to set the headers
            hdrs = [hdu.header for hdu in obj.hdulist]
            # Ensure TRIGTIME exists so the lookup below does not fail; a
            # missing trigger time is treated as None.
            if 'TRIGTIME' not in hdrs[1]:
                hdrs[1]['TRIGTIME'] = None
            phaii_headers = PhaiiFileHeaders(refdata)
            headers = phaii_headers.from_headers(hdrs)

            pha = obj.hdulist[1].data
            counts = pha['COUNTS']
            trigtime = hdrs[1]['TRIGTIME']
            tdel = hdrs[1]['TIMEDEL']

            # Work on a copy so the FITS table data is never modified in place.
            times = np.array(pha['TIME'])
            if trigtime is not None:
                times = times - trigtime

            # Bins are contiguous: each bin stops where the next one starts,
            # and the last bin stops one TIMEDEL after its start. Using
            # tstart[-1] (not tstop[-1]) also works for a single-bin file.
            tstart = times
            tstop = np.append(tstart[1:], tstart[-1] + tdel)

            exposure = np.full_like(times, tdel)
            ebounds = obj.hdulist[2].data
            emin = ebounds['E_MIN']
            emax = ebounds['E_MAX']
            # an example of how to set the data
            data = TimeEnergyBins(counts, tstart, tstop, exposure, emin,
                                  emax)

            gti_table = obj.hdulist[3].data
            # Copies again, so the GTI table is not modified in place.
            gti_start = np.array(gti_table['START'])
            gti_end = np.array(gti_table['STOP'])
            if trigtime is not None:
                gti_start = gti_start - trigtime
                gti_end = gti_end - trigtime
            # an example of how to set the GTI
            gti = Gti.from_bounds(gti_start, gti_end)

        phaii = cls.from_data(data, gti=gti, trigger_time=trigtime,
                              filename=obj.filename, headers=headers)
        # Remember the bin width; _phaii_table needs it when writing.
        phaii.set_timedel(tdel)
        return phaii

    def set_timedel(self, timedel):
        """Set the time bin width written as ``TIMEDEL`` in ARRAY_PHA.

        Args:
            timedel (float): The bin width.
        """
        self.timedel = timedel

    def _build_hdulist(self):
        """Assemble the HDUList (PRIMARY, ARRAY_PHA, EBOUNDS, STDGTI).

        After the extensions are built, the keywords from
        :meth:`_build_headers` are written both into the HDUs and back into
        ``self.headers``. DETNAM and commentary keywords are not copied.
        """
        # create FITS and primary header
        hdulist = fits.HDUList()
        primary_hdu = fits.PrimaryHDU(header=self.headers['PRIMARY'])
        for key, val in self.headers['PRIMARY'].items():
            if key == 'DETNAM':
                # place DETNAM right after INSTRUME
                primary_hdu.header.set('DETNAM', self.detector,
                                       'Detector name', after='INSTRUME')
            elif key not in self._COMMENTARY_KEYWORDS:
                # Commentary cards were already copied by the constructor;
                # re-assigning them would append duplicates.
                primary_hdu.header[key] = val
        hdulist.append(primary_hdu)

        # the phaii extension
        hdulist.append(self._phaii_table())

        # the ebounds extension
        hdulist.append(self._ebounds_table())

        # the GTI extension
        hdulist.append(self._gti_table())

        # Make sure keywords are set
        headers = self._build_headers(self.trigtime, *self.time_range,
                                      self.num_chans)
        for hdu in headers:
            for key, val in hdu.items():
                if key == 'DETNAM' or key in self._COMMENTARY_KEYWORDS:
                    continue
                self.headers[hdu.name][key] = val
                hdulist[hdu.name].header[key] = val

        return hdulist

    def _build_headers(self, trigtime, tstart, tstop, num_chans):
        """Return a copy of the headers with writer keywords filled in.

        ``CREATOR`` and ``ORIGIN`` are set on every extension header, and
        ``DETCHANS`` is set to *num_chans* wherever a header already has a
        ``DETCHANS`` value.

        Args:
            trigtime: Unused; kept for interface compatibility.
            tstart: Unused; kept for interface compatibility.
            tstop: Unused; kept for interface compatibility.
            num_chans (int): Number of energy channels.

        Returns:
            The result of ``self.headers.copy()`` with the keywords above set.
        """
        headers = self.headers.copy()

        for hdu in headers:
            hdu['CREATOR'] = 'BurstCube'
            hdu['ORIGIN'] = 'GSFC'
            if hdu.get('DETCHANS') is not None:
                hdu['DETCHANS'] = num_chans

        return headers

    def _ebounds_table(self):
        """Build the EBOUNDS extension (CHANNEL, E_MIN, E_MAX columns)."""
        chan_col = fits.Column(name='CHANNEL', format='1I', unit='chan',
                               array=np.arange(self.num_chans, dtype=int))
        emin_col = fits.Column(name='E_MIN', format='1E', unit='keV',
                               array=self.ebounds.low_edges())
        emax_col = fits.Column(name='E_MAX', format='1E', unit='keV',
                               array=self.ebounds.high_edges())

        hdu = fits.BinTableHDU.from_columns([chan_col, emin_col, emax_col])

        hdu.header.update(self.headers['EBOUNDS'])
        hdu.header['EXTNAME'] = 'EBOUNDS'
        # NOTE(review): the FITS standard names these per column (TLMIN1 /
        # TLMAX1 for CHANNEL); the bare TLMIN/TLMAX keywords are kept here
        # for compatibility with existing files -- confirm before changing.
        hdu.header.set('TLMIN', np.min(hdu.data['CHANNEL']), before='TTYPE2')
        hdu.header.set('TLMAX', np.max(hdu.data['CHANNEL']), before='TTYPE2')
        return hdu

    def _phaii_table(self):
        """Build the ARRAY_PHA extension (TIME, COUNTS, SUMCOUNTS columns).

        TIME is written as absolute time (bin start plus ``trigtime`` when
        one is set). COUNTS and SUMCOUNTS are written as 16-bit integers
        (``'I'``) when the data fits, and widened otherwise so large counts
        are not silently truncated.
        """
        times = np.copy(self.data.tstart)
        if self.trigtime is not None:
            times += self.trigtime

        counts = self.data.counts
        sum_counts = np.sum(counts, axis=1)

        pha_format = str(counts.shape[1]) + self._int_tform(counts)
        time_col = fits.Column(name='TIME', format='1D',
                               unit='s', array=times)

        phaii_col = fits.Column(name='COUNTS', format=pha_format,
                                array=counts, unit='counts')

        lc_col = fits.Column(name='SUMCOUNTS',
                             format='1' + self._int_tform(sum_counts),
                             array=sum_counts, unit='counts')

        hdu = fits.BinTableHDU.from_columns([time_col, phaii_col, lc_col])
        hdu.header.update(self.headers['ARRAY_PHA'])
        hdu.header['EXTNAME'] = 'ARRAY_PHA'
        hdu.header['TIMEDEL'] = self._resolve_timedel()
        return hdu

    def _resolve_timedel(self):
        """Return the bin width to write as ``TIMEDEL``.

        Uses the value given to :meth:`set_timedel` when there is one, and
        otherwise the ``TIMEDEL`` keyword of the ARRAY_PHA header.

        Raises:
            AttributeError: If neither source provides a value.
        """
        try:
            return self.timedel
        except AttributeError:
            pass
        timedel = self.headers['ARRAY_PHA'].get('TIMEDEL')
        if timedel is None:
            raise AttributeError(
                'TIMEDEL is unknown; call set_timedel() before writing')
        return timedel

    @staticmethod
    def _int_tform(values):
        """Return the narrowest FITS integer TFORM code that holds *values*.

        Args:
            values (array-like): Integer data for a binary-table column.

        Returns:
            str: ``'I'`` (16-bit), ``'J'`` (32-bit) or ``'K'`` (64-bit).
            Empty input gives ``'I'``.
        """
        values = np.asarray(values)
        if values.size == 0:
            return 'I'
        lo, hi = values.min(), values.max()
        for code, dtype in (('I', np.int16), ('J', np.int32)):
            limits = np.iinfo(dtype)
            if limits.min <= lo and hi <= limits.max:
                return code
        return 'K'

    def _gti_table(self):
        """Build the STDGTI extension (START, STOP columns).

        GTI bounds are written as absolute time (shifted by ``trigtime`` when
        one is set).
        """
        tstart = np.array(self.gti.low_edges())
        tstop = np.array(self.gti.high_edges())
        if self.trigtime is not None:
            tstart += self.trigtime
            tstop += self.trigtime

        start_col = fits.Column(name='START', format='1D', unit='s',
                                array=tstart)
        stop_col = fits.Column(name='STOP', format='1D', unit='s',
                               array=tstop)
        hdu = fits.BinTableHDU.from_columns([start_col, stop_col])

        hdu.header.update(self.headers['STDGTI'])
        hdu.header['EXTNAME'] = 'STDGTI'
        return hdu
        
