#
#  Requires the definition of a number of parameters used
#  for the fitting of the background, e.g. fprob, poly, niter,
#  whose default values are optimized for a typical flat background.
#  The parameter bnum sets the number of blocks before and after the burst
#  that the task attempts to use to define the background
#  The background fitting
#
# Step 1 Check and load input files and parameters
# Step 2 Create combined lightcurve data using all detectors
# Step 3 Use Bayesian block analysis to find burst time range as well as T90, T50,
#        Peak, and background through an iterative fitting.
# Step 4 Create lightcurve plot identifying where burst is located. Also
#        output the lightcurve data in text and FITS format.
# Step 5 Create FITS file containing Bayesian Block data
# Step 6 Create a multiextension GTI FITS file that contains the following times:
#        TOTAL(burst time),T90,T50,PEAK,and BKG
#
# The Bayesian block iterative fitting is from "BurstCube Analysis and Simulation Package"  
# The Background fitting is from the gdt library
#
#        James Runge and Lorella Angelini Feb 2024 
#

import sys
import os
import argparse
from argparse import RawTextHelpFormatter
import datetime
import logging

import astropy.io.fits as fits
from astropy.time import Time
from astropy.stats import bayesian_blocks


from gdt.core.data_primitives import Gti, TimeBins
from gdt.core.collection import DataCollection
from gdt.core.plot.lightcurve import Lightcurve

from gdt.core.binning.binned import combine_into_one

from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

from heasoftpy.burstcube.lib.phaii import MyPhaii
from heasoftpy.burstcube.lib.bbint import Bayes
from heasoftpy.burstcube.lib import bsttime_files as bbfile
from heasoftpy.burstcube.lib import lc
from heasoftpy.burstcube.lib.io import get_filelist, clob_check
from heasoftpy.burstcube.lib.bayesian_lc import BayesianBlocksLightcurve
from heasoftpy.burstcube.bcversion_task import bcversion

import warnings

# Suppress user warnings (intended for Matplotlib). NOTE: the filter below has
# no module restriction, so it silences every UserWarning raised in the
# process, not only those coming from Matplotlib.
warnings.filterwarnings("ignore", category=UserWarning)



def bctimebst(infile, outprefix, timecol, cntscol, emin=45.3, emax=316,
              fprob=0.05, poly=3, niter=100, bnum=5, refdata='REFDATA',
              chatter=1, clobber=False, phaii_dict=None):
    """Locate the burst interval in the combined lightcurve and write products.

    The lightcurves of all detectors are summed over the requested energy
    range, a Bayesian-block analysis is run on the result, and the following
    files are written:

    * ``{outprefix}_signal_id.png``  lightcurve plot
    * ``{outprefix}_total_lc.txt``   lightcurve as comma-separated text
    * ``{outprefix}_burst.lc``       lightcurve FITS file
    * ``{outprefix}_bblock.fits``    Bayesian-block FITS file
    * ``{outprefix}.gti``            GTI FITS file (TOTAL, plus T90, T50,
      PEAK and BKG extensions when a burst is identified)

    Parameters
    ----------
    infile : str
        Input file specification handed to ``get_filelist``. Only used when
        ``phaii_dict`` is None.
    outprefix : str
        Prefix of every output file name.
    timecol, cntscol : str
        Accepted for interface compatibility; not used by this function.
    emin, emax : float
        Lower and upper bound of the energy range used for the lightcurve.
        ``emin`` must be strictly smaller than ``emax``.
    fprob, poly, bnum, niter :
        Passed to ``BayesianBlocksLightcurve.compute_bayesian_blocks`` as
        ``(poly, fprob, bnum, max_iter=niter)``.
    refdata : str
        Reference data directory. The default ``'REFDATA'`` selects the
        directory named by the ``LHEA_DATA`` environment variable.
    chatter : int
        Verbosity, 0-5. 0 maps to logging level 40, 1-4 to 20 and 5 to 10.
    clobber : bool
        Passed to ``clob_check`` and used as ``overwrite`` when writing FITS.
    phaii_dict : DataCollection, optional
        Pre-loaded PHAII data. When given, ``infile`` is not read.

    Returns
    -------
    int or None
        1 when a check fails; None when all products were written.
    """

    # ----STEP 1 START -------#
    # Set logger parameters
    logger = logging.getLogger('bctimebst')
    # Check that chatter is reasonable
    if chatter < 0 or chatter > 5:
        print(f'Chatter value {chatter} is not valid. Must be 0-5.')
        return 1
    # Create dictionary to set level according to chatter value
    chat_level = {0: 40, 1: 20, 5: 10}
    # Chatter values 1 through 4 are all the same for now
    if 1 <= chatter <= 4:
        chatter = 1
    # Set logger level
    logger.setLevel(chat_level[chatter])
    version = bcversion()
    logger.info(f'VERSION: {version}')

    # Begin checking all input parameters
    logger.info("Verifying input parameters")
    # Check refdata
    if refdata.strip() == '':
        logger.error('Input refdata is empty.')
        return 1
    # If default
    if refdata == 'REFDATA':
        # Check environmental setting
        if os.getenv('LHEA_DATA'):
            refdata = os.environ['LHEA_DATA'] + '/'
        else:
            logger.error("LHEA_DATA is not defined in current environment.")
            return 1
    else:
        # Check that directory exists
        if Path(refdata).exists() and Path(refdata).is_dir():
            # Catch if path does not end with /
            if refdata[-1] != '/':
                refdata = refdata + '/'
        else:
            logger.error(f"Directory {refdata} does not exist.")
            return 1
    # End refdata check

    # If phaii_dict is not given, load files
    if phaii_dict is None:
        # Start checking if input files exist
        if infile.strip() == '':
            logger.error('Input infile name(s) is empty.')
            return 1
        else:
            file_list = get_filelist(infile, logger)
            # If file not found
            if not isinstance(file_list, np.ndarray):
                return 1

        # Initialize array to hold instrument names
        detectors = []
        # Initialize dictionary to hold PHAII data
        phaii_dict = DataCollection()
        # Check that files exist and get instrument name
        for file in file_list:
            if not Path(file).exists():
                logger.error(f"File {file} not found.")
                return 1
            else:
                phaii = MyPhaii.open(file, refdata)
                det = phaii.detector
                detectors.append(det)
                phaii_dict.include(phaii, det)

        logger.info(f"Using files:{file_list}")
        # End checking input files
    # End if phaii_dict=None
    else:
        # phaii_dict is given, retrieve detectors
        detectors = phaii_dict.items
    # Sort detectors alphabetically. sorted() builds a new list, so the
    # sequence obtained from the collection is never modified in place.
    detectors = sorted(detectors)
    # End checking input files

    # Start checking Erange
    # Convert once and validate the converted values, so that inputs given
    # as strings are compared numerically rather than lexicographically.
    try:
        erange = (float(emin), float(emax))
    except (TypeError, ValueError):
        logger.error(f"{emin},{emax} contains non-float values.")
        return 1
    # if high end is lower than low end, throw error
    if erange[0] >= erange[1]:
        logger.error(
            f"Lower bound for energy range {emin} is larger than or equal to higher bound {emax}.")
        return 1
    # End checking Erange

    # ----STEP 1 STOP -------#

    # ----STEP 2 START -------#
    # Set up array for T-ranges
    T_ranges = [0.9, 0.5]

    logger.info("Combining lightcurves")
    # Begin process of combining lightcurves for a given energy range
    counts = None

    # Loop over all detectors
    for det in detectors:

        phaii = phaii_dict.get_item(det)
        trigtime = phaii.trigtime

        # Workaround for bug in GBM data tools.
        # The first and last bin have an incorrect exposure
        phaii = phaii.slice_time(
            [(phaii.data.tstart[1], phaii.data.tstop[-2])])

        # Get lightcurve for a given energy range
        phaii = phaii.slice_energy(erange)
        phaii = phaii.rebin_energy(combine_into_one)

        lightcurve = phaii.to_lightcurve()

        if counts is None:
            counts = lightcurve.counts
        else:
            counts += lightcurve.counts

    # Nothing was accumulated: there is no detector to build a lightcurve
    # from, and the code below would fail on undefined names.
    if counts is None:
        logger.error("No detector data available to combine.")
        return 1

    # Create lightcurve based on combined counts
    # (edges and exposure are taken from the last detector processed)
    lightcurve = TimeBins(counts, lightcurve.lo_edges,
                          lightcurve.hi_edges, lightcurve.exposure)

    # Strip leading zeros
    first_nonzero_idx = np.argmax(lightcurve.counts != 0)
    lightcurve = lightcurve.slice(lightcurve.lo_edges[first_nonzero_idx],lightcurve.hi_edges[-1])

    # Strip tailing zeros
    reversed_idx = np.argmax(lightcurve.counts[::-1] != 0)
    original_idx = len(lightcurve.counts) - 1 - reversed_idx
    lightcurve = lightcurve.slice(lightcurve.lo_edges[0],lightcurve.hi_edges[original_idx])

    # ----STEP 2 STOP -------#

    # ----STEP 3 START -------#
    logger.info("Doing Bayesian Block analysis")
    # Initialize bayesian block class
    burst = BayesianBlocksLightcurve(lightcurve)

    # Do bayesian block analysis and find times for burst
    burst.compute_bayesian_blocks(poly,fprob,bnum,max_iter=niter)

    durations = burst.signal_range(T_ranges)

    if durations is not None and isinstance(durations, (list, tuple, np.ndarray)):

        Total = Gti.from_bounds(
            (burst.signal_range().tstart), (burst.signal_range().tstop))
        # Background = everything before the burst start and after its stop
        Bkg = Gti.from_bounds(
            (lightcurve.lo_edges[0], Total[0].tstop), (Total[0].tstart, lightcurve.hi_edges[-1]))

        # durations follows the order of T_ranges: index 0 -> 0.9, 1 -> 0.5
        T90 = Gti.from_bounds((durations[0].tstart), (durations[0].tstop))
        T50 = Gti.from_bounds((durations[1].tstart), (durations[1].tstop))

        BSTTOT = Total[0].duration
        BSTT90 = T90[0].duration
        BSTT50 = T50[0].duration
        BSTPEAK = burst._peak

        logger.info("\tStart\t\tStop\t\tDuration")
        logger.info(
            f"Total:\t{Total[0].tstart:.3f}\t\t{Total[0].tstop:.3f}\t\t{BSTTOT:.3f}")
        logger.info(
            f"T90:\t{T90[0].tstart:.3f}\t\t{T90[0].tstop:.3f}\t\t{BSTT90:.3f}")
        logger.info(
            f"T50:\t{T50[0].tstart:.3f}\t\t{T50[0].tstop:.3f}\t\t{BSTT50:.3f}")
        peak_found = 1
    else:
        # logger.warning is the supported spelling; logger.warn is deprecated
        logger.warning("WARNING: Could not identify a burst.")
        logger.warning("WARNING: GTI set to default (total time).")
        Total = Gti.from_bounds(*lightcurve.range)
        # -99 marks the duration/peak keywords as undetermined
        BSTTOT = -99.00
        BSTT90 = -99.00
        BSTT50 = -99.00
        BSTPEAK = -99.00
        peak_found = 0
        # End of bayesian analysis
    # ----STEP 3 STOP -------#

    # ----STEP 4 START -------#
    logger.info("Creating lightcurve plot")
    fig, ax = plt.subplots()

    lightcurve = burst.lightcurve

    # Define background rate and bkg_err
    bkg = burst._lc_bkg_counts/lightcurve.exposure
    # Negative fitted rates are clipped to zero before taking the square root
    bkg = np.where(bkg < 0, 0, bkg)
    bkg_err = np.sqrt(bkg)

    plot = Lightcurve(data=lightcurve, ax=ax)
    plot.lightcurve.color = 'gray'

    plot_bayes = Lightcurve(data=burst._lc_bayes, ax=ax)
    plot_bayes.lightcurve.color = 'navy'

    # Plot background manually
    ax.plot(lightcurve.centroids, bkg,
            color='red', ls=':', label="Fitted background")

    if peak_found:
        # Vertical lines showing the start and stop of the identified signal
        ax.axvline(burst.signal_range().tstart, ls="--",
                   color='olive', label="Signal start/stop")
        ax.axvline(burst.signal_range().tstop, ls="--", color='olive')

    if trigtime is None or trigtime == 0.0:
        ax.set_xlabel("Time (s)")
    else:
        ax.set_xlabel("Time - Trigger Time (s)")

    ax.set_ylim(bottom=np.amin(lightcurve.rates*0.9), top=np.amax(lightcurve.rates*1.2))

    # Plot legend
    if peak_found:
        ax.legend(["Lightcurve", "Bblocks",
                  "Fitted Background", "Signal start/stop"])
    else:
        ax.legend(["Lightcurve", "Bblocks", "Fitted Background"])

    fig.tight_layout()
    # Save figure
    # Existence check
    fig_filename = f'{outprefix}_signal_id.png'
    status = clob_check(clobber, fig_filename, logger)
    if not status:
        fig.savefig(fig_filename)
    # Release the figure whether or not it was saved, so that repeated calls
    # do not accumulate open pyplot figures.
    plt.close(fig)
    if status:
        return 1

    # Write to lightcurve data text file
    lc_filename = f"{outprefix}_total_lc.txt"
    logger.info(f"Writing lightcurve text file {lc_filename}")
    # Existence check
    status = clob_check(clobber, lc_filename, logger)
    if status == 1:
        return 1
    else:
        lc_hdr = "Time, Rate, Bkg"
        lc_data = np.stack(
            (lightcurve.centroids, lightcurve.rates, bkg), axis=1)
        np.savetxt(lc_filename, lc_data, fmt='%f',
                   header=lc_hdr, delimiter=',', newline='\n')

    ## -----Write FITS files-----##

    logger.info(f"Writing {outprefix}_burst.lc")
    # Begin preparing lightcurve file
    prim_hdu = fits.PrimaryHDU(header=lc.LCPrimHeader(refdata))
    lc_hdu = lc.lc_table(lightcurve, bkg, bkg_err, trigtime, refdata)

    # Copy header info from PHAII file
    # (phaii is the last detector processed in the combining loop)
    prim_hdu.header = copyheader(phaii.headers[0], prim_hdu.header)
    lc_hdu.header = copyheader(phaii.headers['ARRAY_PHA'], lc_hdu.header)
    prim_hdu.header['INSTRUME'] = 'CSA'
    lc_hdu.header['INSTRUME'] = 'CSA'
    # Assign values to keywords
    lc_hdu.header['E_MIN'] = erange[0]
    lc_hdu.header['E_MAX'] = erange[1]

    # Add keywords
    bstdur_keys(lc_hdu.header, 'E_UNIT', BSTTOT, BSTT90, BSTT50, BSTPEAK)

    # Put extensions into list
    hdul = fits.HDUList([prim_hdu, lc_hdu])

    # Update DATE
    updateDATE(hdul)

    # Write lightcurve to file
    # Existence check
    lcfits_filename = f'{outprefix}_burst.lc'
    status = clob_check(clobber, lcfits_filename, logger)
    if status == 1:
        return 1
    else:
        hdul.writeto(lcfits_filename, overwrite=clobber, checksum=True)

    # End lightcurve file
    # ----STEP 4 STOP -------#

    # ----STEP 5 START -------#
    # Begin preparing bb file
    logger.info(f"Writing {outprefix}_bblock.fits")
    prim_hdu = fits.PrimaryHDU(header=lc.LCPrimHeader(refdata))
    bb_hdu = bbfile.bb_table(
        burst._lc_bayes, burst._lc_bb_index, trigtime, refdata)
    # Copy header info from PHAII file
    prim_hdu.header = copyheader(phaii.headers[0], prim_hdu.header)
    bb_hdu.header = copyheader(phaii.headers['ARRAY_PHA'], bb_hdu.header)
    prim_hdu.header['INSTRUME'] = 'CSA'
    bb_hdu.header['INSTRUME'] = 'CSA'

    # Put extensions into list
    hdul = fits.HDUList([prim_hdu, bb_hdu])

    # Update DATE
    updateDATE(hdul)

    # Write bb data to file
    # Existence check
    bblock_fits = f'{outprefix}_bblock.fits'
    status = clob_check(clobber, bblock_fits, logger)
    if status == 1:
        return 1
    else:
        hdul.writeto(bblock_fits, overwrite=clobber, checksum=True)

    # End bb file
    # ----STEP 5 STOP -------#

    # ----STEP 6 START -------#
    ## Begin GTI file ##

    logger.info(f"Writing {outprefix}.gti")
    # Create extensions
    tot_hdu = bbfile.gti_table(
        Total.low_edges(), Total.high_edges(), trigtime, refdata)
    # If peak is found
    if peak_found == 1:
        # Calculate peak time interval: index of the bin holding the peak
        peak_index = np.searchsorted(lightcurve.hi_edges, burst._peak)
        # Create extensions
        t90_hdu = bbfile.gti_table(
            T90.low_edges(), T90.high_edges(), trigtime, refdata)
        t50_hdu = bbfile.gti_table(
            T50.low_edges(), T50.high_edges(), trigtime, refdata)
        peak_hdu = bbfile.gti_table(np.array([lightcurve.lo_edges[peak_index]]),
                                    np.array(
                                        [lightcurve.hi_edges[peak_index]]),
                                    trigtime, refdata)
        bkg_hdu = bbfile.gti_table(
            Bkg.low_edges(), Bkg.high_edges(), trigtime, refdata)

    # Copy header info and set extname
    tot_hdu.header = copyheader(phaii.headers['ARRAY_PHA'], tot_hdu.header)
    bstdur_keys(tot_hdu.header, 'TSTOP', BSTTOT, BSTT90, BSTT50, BSTPEAK)
    tot_hdu.header['INSTRUME'] = 'CSA'
    tot_hdu.header['EXTNAME'] = 'TOTAL'
    if peak_found == 1:
        hdu_list = [t90_hdu, t50_hdu, peak_hdu, bkg_hdu]
        ext_names = ['T90', 'T50', 'PEAK', 'BKG']
        for ext_hdu, name in zip(hdu_list, ext_names):
            ext_hdu.header = copyheader(tot_hdu.header, ext_hdu.header)
            bstdur_keys(ext_hdu.header, 'TSTOP',
                        BSTTOT, BSTT90, BSTT50, BSTPEAK)
            ext_hdu.header['INSTRUME'] = 'CSA'
            ext_hdu.header['EXTNAME'] = name

    # Put extensions into list
    if peak_found == 1:
        hdul = fits.HDUList([prim_hdu, tot_hdu, t90_hdu,
                            t50_hdu, peak_hdu, bkg_hdu])
    else:
        hdul = fits.HDUList([prim_hdu, tot_hdu])

    # Update DATE
    updateDATE(hdul)

    # Write GTIs to file
    # Existence check
    gti_filename = f'{outprefix}.gti'
    status = clob_check(clobber, gti_filename, logger)
    if status == 1:
        return 1
    else:
        hdul.writeto(gti_filename, overwrite=clobber, checksum=True)

    ## End GTI file ##
    # ----STEP 6 STOP -------#


# Copy basic info from one header to another
def copyheader(fromheader, toheader):
    """Overwrite matching keywords of ``toheader`` with values from ``fromheader``.

    Only keywords that already exist in ``toheader`` are updated; no keyword
    is added. Keywords starting with TFORM, TTYPE, TUNIT or HDUCL are
    skipped, and EXTNAME is never overwritten.

    Returns ``toheader`` (the same object, modified in place).
    """
    # Keyword prefixes that must keep the destination's own values
    skipped_prefixes = ('TFORM', 'TTYPE', 'TUNIT', 'HDUCL')

    for keyword, value in fromheader.items():
        if keyword.startswith(skipped_prefixes):
            continue
        if keyword == 'EXTNAME' or keyword not in toheader.keys():
            continue
        toheader[keyword] = value

    return toheader

# Update DATE keyword in headers


def updateDATE(hdu_list):
    """Set the DATE keyword of each HDU header to the current UTC time.

    Headers that do not already contain a DATE keyword are left untouched.
    The value is written as ``YYYY-MM-DDThh:mm:ss`` (no fractional seconds).

    Parameters
    ----------
    hdu_list : iterable
        Objects exposing a ``header`` mapping (e.g. an astropy ``HDUList``).
    """
    # Take the current time in UTC explicitly. A naive datetime.now() is
    # local time, and stamping it as UTC is wrong on any non-UTC host.
    timestamp = datetime.datetime.now(datetime.timezone.utc).strftime(
        '%Y-%m-%dT%H:%M:%S')
    for hdu in hdu_list:
        header = hdu.header
        if 'DATE' in header:
            header['DATE'] = timestamp


# Set time duration keywords
def bstdur_keys(header, after_key, bsttot, bstt90, bstt50, bstpeak):
    """Write the burst duration keywords into ``header``.

    BSTTOT, BSTT90, BSTT50 and BSTPEAK are set, in that order, each rounded
    to 3 decimals. BSTTOT is placed after ``after_key`` and every following
    keyword after the one written just before it.

    after_key: Keyword in header to place time duration keywords after
    """
    entries = (
        ('BSTTOT', bsttot, 'Total duration in seconds'),
        ('BSTT90', bstt90, 'Duration of T90 in seconds'),
        ('BSTT50', bstt50, 'Duration of T50 in seconds'),
        ('BSTPEAK', bstpeak, 'Peak position since TRIGTIME in seconds'),
    )

    anchor = after_key
    for keyword, value, comment in entries:
        header.set(keyword, round(value, 3), comment, after=anchor)
        # The next keyword is chained directly behind this one
        anchor = keyword


def main():
    """Command-line entry point for the bctimebst task.

    Parses the command-line options, interactively prompts for every option
    that was not supplied and has no default, optionally attaches a log-file
    handler to the ``bctimebst`` logger, and runs :func:`bctimebst`.

    Exits the process with a non-zero status when the chatter value is
    invalid or when :func:`bctimebst` reports a failure.
    """
    desc = """
    Program to find best burst time based upon combined lightcurve.
    """

    parser = argparse.ArgumentParser(
        description=desc, formatter_class=RawTextHelpFormatter)
    parser.add_argument('--infile', action="store", type=str,
                        help="List of input FITS files or @file containing list")
    parser.add_argument('--outprefix', type=str,
                        help='Prefix of output files')
    parser.add_argument('--timecol', type=str,
                        help='Time column of input files')
    parser.add_argument('--cntscol', type=str,
                        help='Counts column of input files')
    parser.add_argument('--emin', type=float,
                        help='Minimum energy (keV) to consider',
                        default=45.3)
    parser.add_argument('--emax', type=float,
                        help='Maximum energy (keV) to consider',
                        default=316)
    parser.add_argument('--fprob', type=float,
                        help='False alarm probability for Bayesian analysis',
                        default=0.05)
    parser.add_argument('--poly', type=int,
                        help='Order of polynomial for background fitting',
                        default=3)
    parser.add_argument('--niter', type=int,
                        help='Maximum number of iterations for fitting',
                        default=100)
    parser.add_argument('--bnum', type=int,
                        help='Number of buffer blocks to exclude for background fitting',
                        default=5)
    parser.add_argument('--refdata', type=str,
                        help='Reference data directory',
                        default='REFDATA')
    parser.add_argument('--chatter', type=int,
                        default=1, help='Level of verboseness.')
    parser.add_argument('--log', type=str,
                        default='no', help='Create log file?')
    parser.add_argument('--clobber', type=str,
                        default='no', help='Overwrite existing output file(s)?')

    args = parser.parse_args()
    # Loop through input parameters and prompt user if a required value is not provided
    for param, val in vars(args).items():
        if val is not None:
            continue
        for action in parser._actions:
            if not action.option_strings or action.option_strings[0] != f'--{param}':
                continue
            # An action declared without a type has type=None; treat as str
            arg_type = action.type or str
            # Keep prompting until the input converts to the expected type
            while True:
                new_value = input(f"{action.help}: ")
                try:
                    new_value = arg_type(new_value)
                except (TypeError, ValueError):
                    type_str = getattr(arg_type, '__name__', str(arg_type))
                    print(
                        f"{param} must be a valid {type_str}. Try again.")
                else:
                    setattr(args, param, new_value)
                    break

    chatter = args.chatter
    # Yes/no switches: accept any letter case and surrounding blanks
    yes_values = ('yes', 'y', 'true')
    log = str(args.log).strip().lower() in yes_values
    clob = str(args.clobber).strip().lower() in yes_values

    logger = logging.getLogger('bctimebst')
    # Set level based upon chatter level
    # Check that chatter is reasonable
    if chatter < 0 or chatter > 5:
        print(f'Chatter value {chatter} is not valid. Must be 0-5.')
        # Invalid input is a failure: report it through the exit status
        sys.exit(1)
    # Create dictionary to set level according to chatter value
    chat_level = {0: 40, 1: 20, 5: 10}
    # Chatter values 1 through 4 are all the same
    if 1 <= chatter <= 4:
        chatter = 1

    # If log file is wanted
    if log:
        # Create a file handler and set the logging level
        # Get current date and time
        now = datetime.datetime.now()
        date_time = now.strftime("%d%m%Y_%H%M%S")
        log_filename = f'bctimebst_{date_time}.log'
        file_handler = logging.FileHandler(log_filename)
        # Create a formatter and add it to the handlers
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(chat_level[chatter])
        logger.addHandler(file_handler)

    status = bctimebst(args.infile, args.outprefix, args.timecol,
                       args.cntscol, args.emin, args.emax, args.fprob,
                       args.poly, args.niter, args.bnum, args.refdata,
                       chatter, clob)
    # Propagate a task failure to the calling shell
    if status:
        sys.exit(status)


# Run the command-line interface only when this file is executed as a script
if __name__ == '__main__':
    main()
