import astropy.io.fits as fits

import numpy as np

from pathlib import Path

def SpecPrimHeader(REFDATA: str) -> "fits.Header":
    """Load the primary-HDU header template from ``REFDATA + 'primary.txt'``.

    Parameters
    ----------
    REFDATA : str
        Path prefix of the reference files. It is concatenated directly with
        the file name, so a directory must end with a path separator.

    Returns
    -------
    astropy.io.fits.Header
        Header parsed from the text file by ``fits.Header.fromtextfile``.

    Raises
    ------
    FileNotFoundError
        If the reference file does not exist.
    """
    filename = REFDATA + 'primary.txt'
    reffile = Path(filename)
    if not reffile.exists():
        # Fail loudly with a clear message instead of printing and then
        # letting fromtextfile fail on the very same missing path.
        raise FileNotFoundError(f"File {filename} not found")
    return fits.Header.fromtextfile(reffile)

def Spec1Header(REFDATA: str) -> "fits.Header":
    """Load the type-I spectrum header template from ``REFDATA + 'spec_type1.txt'``.

    Parameters
    ----------
    REFDATA : str
        Path prefix of the reference files. It is concatenated directly with
        the file name, so a directory must end with a path separator.

    Returns
    -------
    astropy.io.fits.Header
        Header parsed from the text file by ``fits.Header.fromtextfile``.

    Raises
    ------
    FileNotFoundError
        If the reference file does not exist.
    """
    filename = REFDATA + 'spec_type1.txt'
    reffile = Path(filename)
    if not reffile.exists():
        # Fail loudly with a clear message instead of printing and then
        # letting fromtextfile fail on the very same missing path.
        raise FileNotFoundError(f"File {filename} not found")
    return fits.Header.fromtextfile(reffile)
    
def Spec2SRCHeader(REFDATA: str) -> "fits.Header":
    """Load the type-II source-spectrum header template from ``REFDATA + 'spec_type2src.txt'``.

    Parameters
    ----------
    REFDATA : str
        Path prefix of the reference files. It is concatenated directly with
        the file name, so a directory must end with a path separator.

    Returns
    -------
    astropy.io.fits.Header
        Header parsed from the text file by ``fits.Header.fromtextfile``.

    Raises
    ------
    FileNotFoundError
        If the reference file does not exist.
    """
    filename = REFDATA + 'spec_type2src.txt'
    reffile = Path(filename)
    if not reffile.exists():
        # Fail loudly with a clear message instead of printing and then
        # letting fromtextfile fail on the very same missing path.
        raise FileNotFoundError(f"File {filename} not found")
    return fits.Header.fromtextfile(reffile)
    
def Spec2BGHeader(REFDATA: str) -> "fits.Header":
    """Load the type-II background-spectrum header template from ``REFDATA + 'spec_type2bg.txt'``.

    Parameters
    ----------
    REFDATA : str
        Path prefix of the reference files. It is concatenated directly with
        the file name, so a directory must end with a path separator.

    Returns
    -------
    astropy.io.fits.Header
        Header parsed from the text file by ``fits.Header.fromtextfile``.

    Raises
    ------
    FileNotFoundError
        If the reference file does not exist.
    """
    filename = REFDATA + 'spec_type2bg.txt'
    reffile = Path(filename)
    if not reffile.exists():
        # Fail loudly with a clear message instead of printing and then
        # letting fromtextfile fail on the very same missing path.
        raise FileNotFoundError(f"File {filename} not found")
    return fits.Header.fromtextfile(reffile)

def spec_pha2_table(counts,exposure,quality,tstart,tstop,numchans,trigtime,refdata,type='src',bkgfile=None):
    """Build a type-II PHA binary table HDU with one spectrum per row.

    Parameters
    ----------
    counts : array_like
        Channel counts for each spectrum. Written to a ``{numchans}J`` column,
        so each row is expected to hold ``numchans`` values.
    exposure : array_like
        Exposure of each spectrum, in seconds.
    quality : array_like
        One quality value per spectrum (scalar ``I`` column).
    tstart, tstop : array_like
        Start / end time of each spectrum; ``trigtime`` is added to both.
    numchans : int
        Number of channels per spectrum.
    trigtime : float
        Offset added to ``tstart``/``tstop`` to form TIME/ENDTIME.
    refdata : str
        Path prefix of the reference header text files.
    type : str, optional
        ``'src'`` builds a source table (adds a BACKFILE column and uses the
        source header template). Any other value builds a background table.
    bkgfile : str, optional
        Background file referenced from row ``i`` (1-based) of BACKFILE as
        ``'<bkgfile>{i}'``. Only used when ``type == 'src'``.

    Returns
    -------
    astropy.io.fits.BinTableHDU
    """
    specnum = np.arange(len(counts))+1
    # Convert to float64 *before* adding trigtime. The previous in-place add on
    # a copy raised a casting error for integer inputs. A float32 input would
    # also silently lose precision once a large trigtime is added, even though
    # the columns are written as double ('1D').
    times = np.asarray(tstart, dtype=np.float64) + trigtime
    endtimes = np.asarray(tstop, dtype=np.float64) + trigtime
    # Every row gets the same 0-based channel numbers; full_like broadcasts the
    # 1-D channel vector across the rows of ``counts``.
    channels = np.arange(numchans)
    chans = np.full_like(counts,channels)

    specnum_col = fits.Column(name='SPEC_NUM', format='1I',
                             array=specnum)
    channels_col = fits.Column(name='CHANNELS', format=f"{numchans}J",
                             unit='chan', array=chans)
    counts_col = fits.Column(name='COUNTS', format=f"{numchans}J",
                             unit='count', array=counts)
    exp_col  = fits.Column(name='EXPOSURE', format='1E',
                             unit='s', array=exposure)

    qual_col  = fits.Column(name='QUALITY', format='I',
                             array=quality)

    time_col = fits.Column(name='TIME', format='1D',
                           unit='s', array=times)

    end_col = fits.Column(name='ENDTIME', format='1D',
                           unit='s', array=endtimes)
    if type=='src':
        # Row i points at spectrum i of the background file: "<bkgfile>{i}".
        # NOTE(review): when bkgfile is None this writes "None{i}"; and entries
        # longer than 40 characters are truncated by the '40A' format — confirm
        # callers never hit either case.
        backfile = np.array([f'{bkgfile}{{{i+1}}}' for i in range(len(specnum))])
        back_col  = fits.Column(name='BACKFILE', format='40A',
                             array=backfile)
        hdu = fits.BinTableHDU.from_columns([specnum_col,channels_col,counts_col,exp_col,back_col,qual_col,time_col,end_col])
        hdu.header.update(Spec2SRCHeader(refdata))
    else:
        hdu = fits.BinTableHDU.from_columns([specnum_col,channels_col,counts_col,exp_col,qual_col,time_col,end_col])
        hdu.header.update(Spec2BGHeader(refdata))
    # Channel-axis keywords for column 3 (COUNTS in both layouts). They are
    # inserted before TTYPE4 so they sit with column 3's keywords.
    hdu.header.set('1CDLT3',1,before='TTYPE4')
    hdu.header.set('1CDPX3',1,before='TTYPE4')
    hdu.header.set('1CRVL3',0,before='TTYPE4')
    hdu.header.set('1CUNI3','chan     ',before='TTYPE4')

    return hdu
    
def spec_table(counts,numchans,refdata):
    """Build a type-I PHA binary table HDU for a single spectrum.

    Parameters
    ----------
    counts : array_like
        Counts per channel, written to the ``J`` COUNTS column; expected to
        have ``numchans`` entries.
    numchans : int
        Number of channels. Channels are numbered ``1..numchans``.
    refdata : str
        Path prefix of the reference header text files.

    Returns
    -------
    astropy.io.fits.BinTableHDU
        Table with CHANNEL, COUNTS and QUALITY (all zero) columns, header
        merged from the type-I template, and TLMIN1/TLMAX1 set to the channel
        range actually written.
    """
    channels = np.arange(1, numchans + 1)
    quality = np.zeros_like(channels)

    channel_col = fits.Column(name='CHANNEL', format='I',
                             unit='chan', array=channels)

    counts_col  = fits.Column(name='COUNTS', format='J',
                             unit='count', array=counts)

    qual_col  = fits.Column(name='QUALITY', format='I',
                             array=quality)

    hdu = fits.BinTableHDU.from_columns([channel_col,counts_col,qual_col])
    hdu.header.update(Spec1Header(refdata))
    # Place the channel limits just before TUNIT1. TLMAX1 must match the last
    # channel actually written (numchans), not a fixed 455.
    hdu.header.set('TLMIN1',1,before='TUNIT1')
    hdu.header.set('TLMAX1',int(numchans),before='TUNIT1')
    return hdu
