# Requirements:
#   a) Array pha files from bcrebevt
#   b) output prefix for files
#   c) Good time interval file from bctimebst
#
# Optional (i.e. have default values):
#   a) Three energy ranges in keV
#   b) Two timebin values in seconds
#   c) Order of polynomial for background fitting.
#
# Step 1: Check input parameters and files. Load data from files.
#
# Step 2: Setup prefix for output files and plotting parameters
#
# Step 3: Query CALDB for response file corresponding to ra,dec
#
# Step 4: Create FITS files containing only Primary for lightcurve data: one
#         file for each detector and one for the sum of all detectors.
#         Extensions will be appended to these files in the following steps.
#
# Begin loop over timebins
#   Begin loop over energy ranges
#       Begin loop over detectors
#
# Step 5: Cut data by energy range and rebin the data by timebin value.
#         Using background times from input good time interval, fit background
#         using background fitting routine with a polynomial which has an order
#         given by the parameter poly. The fitting is performed channel
#         by channel, and background values are interpolated over the
#         burst time region based on the background fit.
#
# Step 6: Extension for given energy range and timebin is written to lightcurve
#         FITS file.
#
# Step 7: Plot for this timebin and energy range is created. Plot shows
#         lightcurve, fitted background, and the burst region is highlighted.
#
#       End loop over detectors
#
# Step 8: Sum lightcurve data for all detectors for this energy range and timebin.
#         Write extension to total lightcurve fits file.
#
#   End loop over energy ranges
# End loop over timebins
#
# Step 9: Write lightcurve figure to file. This figure contains all individual
#         plots for each timebin and energy range.
#
# Begin loop over detectors
#
# Step 10: Using background times from input good time interval, fit background
#         using background fitting routine with a polynomial which has an order
#         given by the parameter poly. Background values are interpolated for
#         burst time region based on background fit. Data is cut to only include
#         values within the burst time.
#
# Step 11: PHA type 2 FITS files for source and background are written.
#
# Step 12: Data is integrated over burst time and exposure is calculated.
#          PHA type 1 FITS files for source and background are written.
#
# Step 13: A figure showing the source and background spectra is created and
#          written to file for this detector.
#
# End loop over detectors
#
# End task
#
#
# July 2024 - Added a check for energy channels with zero counts: such
#             channels are skipped when fitting, and the background spectrum
#             is interpolated for those channels.
#
#
#
import sys
import os
import argparse
from argparse import RawTextHelpFormatter
import copy
import datetime
import logging

import astropy.io.fits as fits
from astropy.time import Time
import astropy.units as u
from astropy.coordinates import SkyCoord


from gdt.core.coords import SpacecraftFrame, Quaternion
from gdt.core.data_primitives import TimeBins
from gdt.core.collection import DataCollection
from gdt.core.plot.lightcurve import Lightcurve
from gdt.core.plot.spectrum import Spectrum
from gdt.core.background.fitter import BackgroundFitter
from gdt.core.background.binned import Polynomial
from gdt.core.background.primitives import BackgroundRates

from gdt.core.binning.binned import rebin_by_time


from pathlib import Path

import numpy as np
import healpy as hp
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

from heasoftpy.burstcube.lib.phaii import MyPhaii
from heasoftpy.burstcube.lib import lc
from heasoftpy.burstcube.lib import spec
from heasoftpy.burstcube.lib.io import get_filelist, file_extn, clob_check
from heasoftpy.burstcube.lib.caldbpy import caldbpy
from heasoftpy.burstcube.bcversion_task import bcversion

import warnings

# Suppress Matplotlib user warnings
warnings.filterwarnings("ignore", category=UserWarning)


def bcprod(infile, outprefix, burstgti, radec='NONE', attitude='NONE', erange1='50,100', erange2='100,300', erange3='300,1000',
           timebin1=0.064, timebin2=1.0, poly=3, refdata='REFDATA', chatter=1, clobber=False, phaii_dict=None):
    """Create BurstCube lightcurve and spectral products for a burst.

    Implements Steps 1-13 described in the header of this module: validate
    the inputs, write a lightcurve FITS file and plot for every detector
    (one extension per timebin/energy range) plus a summed "CSA"
    lightcurve, then write PHA type II and type I source/background
    spectra and a spectrum plot for every detector.

    Parameters
    ----------
    infile : str
        Input PHAII file name(s), resolved with ``get_filelist``. Not used
        when ``phaii_dict`` is supplied.
    outprefix : str
        Prefix prepended to every output file name.
    burstgti : str
        GTI file name (optionally with an extension). The source extension
        must carry the ``BSTTOT`` keyword; when ``BSTTOT >= 0`` a ``BKG``
        extension with background time ranges is also read.
    radec : str
        ``'ra,dec'`` in degrees, a localization map file, or ``'NONE'``.
    attitude : str
        Attitude file with TIME and QPARAM columns; required when ``radec``
        is given.
    erange1, erange2, erange3 : str
        Energy ranges written as ``'emin,emax'`` (keV per the module header).
    timebin1, timebin2 : float
        Lightcurve bin widths in seconds.
    poly : int
        Polynomial order for the background fit.
    refdata : str
        Reference-data directory, or ``'REFDATA'`` to use ``$LHEA_DATA``.
    chatter : int
        Verbosity 0-5 (0: errors only, 1-4: info, 5: debug).
    clobber : bool
        Overwrite existing output files.
    phaii_dict : DataCollection, optional
        Pre-loaded PHAII data keyed by detector name; bypasses ``infile``.

    Returns
    -------
    int
        0 on success, 1 on any error (the reason is logged).
    """

    # ----STEP 1 START------#
    # Set logger parameters
    logger = logging.getLogger('bcprod')
    # Check that chatter is reasonable
    if chatter < 0 or chatter > 5:
        print(f'Chatter value {chatter} is not valid. Must be 0-5.')
        return 1
    # Create dictionary to set level according to chatter value
    # (40 = ERROR, 20 = INFO, 10 = DEBUG)
    chat_level = {0: 40, 1: 20, 5: 10}
    # Chatter values 1 through 4 are all the same for now
    if 1 <= chatter <= 4:
        chatter = 1
    # Set logger level
    logger.setLevel(chat_level[chatter])
    version = bcversion()
    logger.info(f'VERSION: {version}')
    # Begin checking all input parameters
    logger.info("Verifying input parameters")

    # Check CALDB is set
    if not os.getenv('CALDB'):
        logger.error("CALDB environment is not set.")
        return 1
    else:
        CALDB = os.environ['CALDB']

    # Check refdata
    # If default
    if refdata == 'REFDATA':
        # Check environmental setting
        if os.getenv('LHEA_DATA'):
            refdata = os.environ['LHEA_DATA'] + '/'
        else:
            logger.error("LHEA_DATA is not defined in current environment.")
            return 1
    else:
        # Check that directory exists
        if Path(refdata).exists() and Path(refdata).is_dir():
            # Catch if path does not end with /
            if refdata[-1] != '/':
                refdata = refdata + '/'
        else:
            logger.error(f"Directory {refdata} does not exist.")
            return 1
    # End refdata check

    # Only populated when data are read from input files below; stays None
    # when the caller supplies phaii_dict directly.
    file_list = None

    # If phaii_dict is not given, load files
    if phaii_dict is None:
        # Start checking if input files exist
        if infile.strip() == '':
            logger.error('Input infile name(s) is empty.')
            return 1
        else:
            file_list = get_filelist(infile, logger)
            # If file not found
            if not isinstance(file_list, np.ndarray):
                return 1

        # Initialize array to hold instrument names
        detectors = []
        # Initialize dictionary to hold PHAII data
        phaii_dict = DataCollection()
        # Check that files exist and get instrument name
        for file in file_list:
            if not Path(file).exists():
                logger.error(f"File {file} not found.")
                return 1
            else:
                phaii = MyPhaii.open(file, refdata)
                det = phaii.detector
                detectors.append(det)
                trigtime = phaii.trigtime
                phaii_dict.include(phaii, det)

        logger.info(f"Using files:{file_list}")
        # End checking input files
    # End if phaii_dict=None
    else:
        # phaii_dict is given, retrieve detectors
        detectors = phaii_dict.items
        trigtime = phaii_dict.get_item(detectors[0]).trigtime
    # Sort detectors alphabetically
    detectors.sort()
    datamode = phaii_dict.get_item(detectors[0]).headers[0]['DATAMODE']

    # Begin loading start and stop times from GTI
    # Check for file existence
    burstgti = burstgti.strip()
    burstgti_file, extn = file_extn(burstgti)
    # Set variable peak_found.
    peak_found = True
    if not Path(burstgti_file).exists():
        logger.error(f"GTI file {burstgti_file} not found.")
        return 1

    # trigtime may be None; it must be numeric before it is subtracted from
    # the GTI times below (previously this fallback ran too late).
    if trigtime is None:
        trigtime = 0.

    # The context manager closes the GTI file; only the derived time lists
    # (src_times, bkg_times) are used after this block.
    with fits.open(Path(burstgti_file)) as gti_hdu:
        # Test that extension exists
        if len(gti_hdu) < (int(extn) + 1):
            logger.warning(
                f"WARNING: No extension {extn} found for file {burstgti_file}")
            logger.warning("WARNING: Will default to first extension.")
            extn = 1
        source_data = gti_hdu[extn].data
        # Check header keywords to see if peak was found
        try:
            bsttot = gti_hdu[extn].header['BSTTOT']
        except KeyError:
            logger.error(
                f"Header keyword BSTTOT not found in {burstgti_file}.")
            return 1
        if bsttot < 0:
            logger.warning("WARNING: No peak was previously found.")
            logger.warning(
                "WARNING: Total time will be used and no background fitting will be done.")
            peak_found = False

        # Store source time ranges relative to the trigger time
        src_times = [(tstart - trigtime, tstop - trigtime)
                     for tstart, tstop in zip(source_data['START'], source_data['STOP'])]

        # Background time ranges exist only when a peak was found; bkg_times
        # is therefore only used under `if peak_found` below.
        if peak_found:
            bkg_data = gti_hdu['BKG'].data
            bkg_times = [(tstart - trigtime, tstop - trigtime)
                         for tstart, tstop in zip(bkg_data['START'], bkg_data['STOP'])]

    # Stop loading start and stop times from GTI

    # Check that GTI are reasonable
    # Load time range for data
    time_range = phaii_dict.get_item(detectors[0]).time_range
    # Check source
    for times in src_times:
        if (times[1] < time_range[0]) or (times[0] > time_range[1]):
            logger.error("GTI for source not within data time range.")
            return 1
    # Check background
    if peak_found:
        for times in bkg_times:
            if (times[1] < time_range[0]) or (times[0] > time_range[1]):
                logger.error("GTI for background not within data time range.")
                return 1

    # End checking GTI

    # Start checking Eranges
    eranges = []
    erange1 = erange1.split(',')
    erange2 = erange2.split(',')
    erange3 = erange3.split(',')

    erange_test = [erange1, erange2, erange3]
    for erange in erange_test:
        if len(erange) == 2:
            try:
                emin = float(erange[0])
                emax = float(erange[1])
            except ValueError:
                logger.error(
                    f"{erange[0]},{erange[1]} does not contain proper float values.")
                return 1
            # if high end is lower than low end, throw error
            if emin > emax:
                logger.error(
                    f"Lower bound {emin} for energy range is larger than higher bound {emax}.")
                return 1
            else:
                # Add erange to eranges list
                eranges += [(emin, emax)]
        else:
            logger.error(
                f"{erange} is not a proper energy range (i.e. emin,emax)")
            return 1
    # Add range to cover total; (-999, 999) is a sentinel meaning "no energy
    # slicing" and is always the last entry (index 3).
    eranges += [(-999, 999)]

    # Stop checking Eranges

    # Check that ra,dec values are valid
    if radec.upper() == 'NONE':
        radec = None
    if radec is not None:
        # Check that attitude is not None
        if attitude.upper() == 'NONE':
            logger.error('Attitude file must be supplied if RA/DEC is given.')
            return 1
        else:
            # Begin checking attitude file and load attitude data
            # Check if extension given
            attitude = attitude.strip()
            file, extn = file_extn(attitude)
            # Check that file exists
            if not Path(file).exists():
                logger.error(f"File {file} not found.")
                return 1

            # Open file
            att_fits = fits.open(Path(file))[extn]
            # Load column info
            att_cols = att_fits.columns
            # Check that the columns TIME and QPARAM
            cols_needed = np.array(['TIME', 'QPARAM'])

            not_present = np.setdiff1d(cols_needed, att_cols.names)

            # If column missing, return error
            if not_present.size != 0:
                # Format explicitly: passing not_present as a logging arg
                # without a placeholder makes the record fail to format.
                logger.error(f"Needed columns not found: {not_present}")
                return 1

            # Begin loading column data
            att_data = att_fits.data
            att_time = att_data['TIME']-trigtime
            att_quat = att_data['QPARAM']

            # Check that attitude time covers src_time (i.e. time of burst)
            for times in src_times:
                if not (att_time[0] <= times[0] <= att_time[-1]) or not (att_time[0] <= times[1] <= att_time[-1]):
                    logger.error("Attitude file does not cover burst time.")
                    return 1
            # End checking attitude file

        radec_vals = radec.split(',')
        # If RA/DEC is not comma separated, it is a loc_map file
        if len(radec_vals) != 2:
            # Load locmap routine
            ra, dec = best_loc(radec_vals[0], logger)
            if ra == -99 and dec == -99:
                return 1
        else:
            # Check that ra/dec values are floats
            for deg in radec_vals:
                if not is_valid_float(deg):
                    logger.error(f"{deg} is not a valid float value.")
                    return 1
            # Check and make sure values fall within range
            ra = float(radec_vals[0])
            dec = float(radec_vals[1])
            if (ra < 0.0) or (ra > 360.0):
                logger.error(f"RA:{ra} is not a valid right ascension value.")
                return 1
            if (dec < -90.0) or (dec > 90.0):
                logger.error(f"DEC:{dec} is not a valid declination value.")
                return 1
    # End of checking ra,dec

    # Put time bins into array
    timebins = []
    # Check that timebins are floats
    for time in (timebin1, timebin2):
        if is_valid_float(time):
            timebins += [float(time)]
        else:
            logger.error(f"{time} is not a valid timebin value.")
            return 1
    # End checking timebins

    # Check poly
    try:
        poly = int(poly)
    except (TypeError, ValueError):
        logger.error(f"Polynomial order {poly} is not a valid integer.")
        return 1
    # End poly check

    # New parameter to create lightcurve/spectra plots for all detectors in
    # a single figure
    # Set to False for now
    combined = False

    # End of checking input parameters

    # ----STEP 1 STOP------#

    # ----STEP 2 START------#
    # Calculate date for output filenames based on DATE-OBS
    if file_list is not None:
        with fits.open(Path(file_list[0])) as fits_test:
            date_obs = fits_test[0].header['DATE-OBS']
    else:
        # No input files when phaii_dict is supplied: read DATE-OBS from the
        # first detector's headers[0], which this function already treats as
        # the primary header (DATAMODE lookup, header copy in Step 4).
        date_obs = phaii_dict.get_item(detectors[0]).headers[0]['DATE-OBS']

    time = Time(date_obs, format='fits')
    # Extract the date components
    year = time.datetime.year % 100  # YY
    month = time.datetime.month       # MM
    day = time.datetime.day           # DD

    # Calculate the fraction of the day
    fraction_of_day = int(np.ceil((time.mjd - int(time.mjd))*1000))

    # Format the date and time components into YYMMDDXXX
    formatted_date_time = f'{year:02d}{month:02d}{day:02d}{fraction_of_day:03d}'
    # End creating output date for filenames

    # Create figure and axes for lc plots
    lcplots = {}
    fig_ax_det = {}
    for det in detectors:
        # Rows are energy ranges (4 entries in eranges), columns are timebins
        fig_lc, ax_lc = plt.subplots(4, 2, dpi=150, sharex='col', figsize=(
            20, 10), gridspec_kw={'hspace': 0.0})
        if trigtime is None or trigtime == 0.0:
            ax_lc[0, 0].set_title(f'Time Binning {timebins[0]} s')
            ax_lc[0, 1].set_title(f'Time Binning {timebins[1]} s')
        else:
            reference_time = "2021-01-01T00:00:00"
            reference = Time(reference_time, format='isot')
            trig_utc = reference + trigtime * u.s
            ax_lc[0, 0].set_title(
                f'Time Binning {timebins[0]} s\nTrigger Time in UTC: {trig_utc.isot[:19]}')
            ax_lc[0, 1].set_title(
                f'Time Binning {timebins[1]} s\nTrigger Time in UTC: {trig_utc.isot[:19]}')

        fig_ax_det[det] = {'fig': fig_lc, 'ax': ax_lc}

    # If param set to True create additional figures
    num_files = len(detectors)
    rows = int(np.ceil(np.sqrt(num_files)))
    cols = int(np.ceil(num_files/rows))

    # Set up figures for spec and lightcurves
    # Each figure has the columns and rows calculated above, dpi(resolution) set to 150.
    # squeeze=False keeps the axes a 2-D array so [i, j] indexing works for
    # every grid shape (e.g. 2x1 or 1x1).
    fig_totspec, ax_totspec = plt.subplots(rows, cols, dpi=150, squeeze=False)
    fig_totlc, ax_totlc = plt.subplots(rows, cols, dpi=150, squeeze=False)
    # Initialize the arrays to hold the plots for each individual instrument
    tot_lcplots = {}

    # Set up location of individual plots on figure
    # Location is set as starting in upper-left of grid, filling up row, and moving to next row
    det_to_ax_index = {}
    count = 0
    for i in range(rows):
        for j in range(cols):
            # If count greater than number of input files, remove plot space from figure
            if count >= num_files:
                fig_totspec.delaxes(ax_totspec[i, j])
                fig_totlc.delaxes(ax_totlc[i, j])
            else:
                # Assign instrument to specific location in figure
                det_to_ax_index[detectors[count]] = (i, j)
            count += 1
    # Stop set-up plotting

    # ----STEP 2 STOP------#

    # ----STEP 3 START------#
    if radec is not None:
        # Create SkyCoord from ra dec
        coords = SkyCoord(ra,dec,unit='deg',frame='icrs')

        # Retrieve quaternion that is closest of midpoint of source time
        mid = (src_times[0][0] + src_times[0][1])/2.0
        idx = np.abs(att_time - mid).argmin()
        quat = Quaternion(att_quat[idx])
        print(f"Time: {att_time[idx]+trigtime}")
        print(f"Quaternion: {quat}")
        # Rotate to instrument frame and transform into theta and phi
        coords = coords.transform_to(SpacecraftFrame(quaternion=quat))
        coords = coords.represent_as('physicsspherical')

        # Search CALDB for response file corresponding to theta and phi
        theta_rad = coords.theta.rad
        phi_rad = coords.phi.rad

        print('Calculated Theta and Phi in Instrument frame')
        print(f'Theta:{coords.theta.degree} Phi:{coords.phi.degree}')

        # For now nside is defaulted to 16
        nside = 16
        near_pix = hp.ang2pix(nside, theta_rad, phi_rad)[0]
        print(f'Pixel:{near_pix}')
        near_theta, near_phi = hp.pix2ang(nside,near_pix)
        print('Midpoint Theta and Phi for Pixel')
        print(f'Theta:{np.degrees(near_theta)} Phi:{np.degrees(near_phi)}')
        # Create CALDB expression based on nearest pixel
        expr = f'PIXEL={near_pix:04d}'

        # Do loop for detectors
        for det in detectors:
            # Query CALDB
            filepath, extension = caldbpy(
                'BURSTCUBE', det, '-', '-', 'ebounds', 'now', 'now', expr)

            # Check if file was found
            if isinstance(filepath, np.ndarray):
                full_file = CALDB + "/" + filepath[0]
                # Check that file exists
                if not Path(full_file).exists():
                    logger.error(f"Could not find file {full_file}")
                    return 1
                else:
                    response_file = full_file
                    logger.info(f'Using RESPONSE file {response_file} for {det}')

            else:
                # If no file found
                logger.warning("No RESPONSE file found.")

    # ----STEP 3 STOP------#

    # ----STEP 4 START------#
    # ----Begin lightcurve creation-----#

    # Assign suffix based upon datamode
    if datamode == 'EVENT':
        mode_suff = '_tte'
    else:
        mode_suff = '_atd'
    # Create FITS lightcurve files for all detectors, just Primary for now
    for det in detectors:
        filename = "bc" + formatted_date_time + det.lower()+mode_suff+".lc"
        prim_hdu = fits.PrimaryHDU(header=lc.LCPrimHeader(refdata))
        prim_hdu.header = copyheader(
            phaii_dict.get_item(det).headers[0], prim_hdu.header)
        prim_hdu.header['DATE'] = Time(
            datetime.datetime.now(), scale='utc').isot[:19]
        hdulist = fits.HDUList([prim_hdu])
        # Existence check
        filename = f'{outprefix}{filename}'
        status = clob_check(clobber, filename, logger)
        if status:
            return 1
        else:
            hdulist.writeto(filename,
                            overwrite=clobber, checksum=True)
            modify = True
    # Create FITS file for sum of all detectors
    filename = "bc" + formatted_date_time + "csa" + mode_suff+".lc"
    prim_hdu = fits.PrimaryHDU(header=lc.LCPrimHeader(refdata))
    prim_hdu.header = copyheader(phaii_dict.get_item(
        detectors[0]).headers[0], prim_hdu.header)
    prim_hdu.header['INSTRUME'] = 'CSA'
    prim_hdu.header['DATE'] = Time(
        datetime.datetime.now(), scale='utc').isot[:19]
    hdulist = fits.HDUList([prim_hdu])
    # Existence check
    filename = f'{outprefix}{filename}'
    status = clob_check(clobber, filename, logger)
    if status:
        return 1
    else:
        hdulist.writeto(filename, overwrite=clobber, checksum=True)

    # ----STEP 4 STOP------#

    # ----STEP 5 START------#
    # Loop over each timebin, energy range, and detector
    # Loop for creating FITS files
    logger.info("Writing lightcurve files")
    for i, timebin in enumerate(timebins):

        for j, erange in enumerate(eranges):

            # Set up dictionaries to hold lc and background for totals
            lc_dict = DataCollection()
            bkg_dict = DataCollection()

            for det in detectors:

                phaii = phaii_dict.get_item(det)

                # Slice by appropriate energy range
                if erange[0] != -999:
                    phaii = phaii.slice_energy(erange)
                else:
                    # NOTE(review): erange is reassigned here, so for the
                    # full-range pass the following detectors are sliced with
                    # the first detector's energy range.
                    erange = phaii.energy_range

                # Check that no energy bins have zero counts
                zeros, zero_indices, non_zero_indices, newerange = remove_zero_energybins(phaii)

                # Slice data by erange if needed
                if zeros:
                    phaii = phaii.slice_energy(newerange)

                # Rebin by appropriate timebin
                phaii = phaii.rebin_time(rebin_by_time, timebin)

                lightcurve = phaii.to_lightcurve()

                lc_dict.include(lightcurve, name=det)

                # Fit the background to background times if peak was found.
                # bkg_times only exists in that case, so the buffer is built here.
                if peak_found:
                    buff_bkg = background_buffer(bkg_times, phaii.data)
                    backfitter = BackgroundFitter.from_phaii(phaii,
                                                             Polynomial,
                                                             time_ranges=buff_bkg)

                    try:
                        backfitter.fit(order=poly)
                    except Exception:
                        logger.error("Unable to fit background properly. Please change polynomial order")
                        return 1

                    back_rates = backfitter.interpolate_bins(
                        phaii.data.tstart, phaii.data.tstop)

                else:
                    # Create empty back_rates for no peak found
                    rates = np.zeros_like(phaii.data.rates)
                    rate_uncert = np.zeros_like(phaii.data.rate_uncertainty)
                    back_rates = BackgroundRates(rates, rate_uncert, phaii.data.tstart, phaii.data.tstop,
                                                 phaii.data.emin, phaii.data.emax, phaii.data.exposure)

                # Integrate by energy and save to dictionary
                if (back_rates.num_chans > 1):
                    back_rates = back_rates.integrate_energy()

                bkg_dict.include(back_rates, name=det)

                # ----STEP 5 STOP------#

                # ----STEP 6 START------#
                # Create lightcurve extension for FITS file
                lc_hdu = lc.lc_table(
                    lightcurve, back_rates.rates, back_rates.rate_uncertainty, phaii.trigtime, refdata)
                # Copy header info
                lc_hdu.header = copyheader(phaii.headers[1], lc_hdu.header)
                # Assign values to keywords (j == 3 is the full-range entry)
                if j == 3:
                    lightcurve_hdrupdate(lc_hdu, version, timebin, erange, True)
                else:
                    lightcurve_hdrupdate(lc_hdu, version, timebin, erange)

                # Open FITS file to append to
                if modify:
                    filename = "bc" + formatted_date_time + det.lower()+mode_suff+".lc"
                    existing_file = fits.open(outprefix + filename)
                    existing_file.append(lc_hdu)
                    existing_file.writeto(
                        outprefix + filename, overwrite=True, checksum=True)
                    existing_file.close()

                # ----STEP 6 STOP------#

                # ----STEP 7 START------#
                # Create subplot for timebin and erange
                # Set correct axes
                lc_ax = fig_ax_det[det]['ax'][j, i]
                lcplots[(det, timebin, erange)] = create_lcplot(lc_ax, det,
                                                                timebin, erange,
                                                                phaii, lightcurve,
                                                                back_rates, src_times[0],
                                                                trigtime, legend=True)

                if timebin == 1.0 and j == 3 and combined is True:
                    lc_axes = ax_totlc[det_to_ax_index[det]]
                    tot_lcplots[(det, timebin, erange)] = create_lcplot(lc_axes, det,
                                                                        timebin, erange,
                                                                        phaii, lightcurve,
                                                                        back_rates, src_times[0],
                                                                        trigtime)
                # ----STEP 7 STOP------#
                # End loop for detectors

            # ----STEP 8 START------#
            # Begin process for summing all detectors
            rates = np.zeros_like(lc_dict.get_item(detectors[0]).rates)
            rates_var = np.zeros_like(rates)
            bkg_rates = np.zeros_like(rates)
            bkg_var = np.zeros_like(rates)
            for det in detectors:
                rates += lc_dict.get_item(det).rates
                rates_var += lc_dict.get_item(det).rate_uncertainty ** 2
                bkg_rates += bkg_dict.get_item(det).rates
                bkg_var += bkg_dict.get_item(det).rate_uncertainty ** 2

            # averaged exposure, sampling times
            exposure = np.mean(
                [light.exposure for light in lc_dict.to_list()], axis=0)
            tstart = np.mean(
                [light.lo_edges for light in lc_dict.to_list()], axis=0)
            tstop = np.mean(
                [light.hi_edges for light in lc_dict.to_list()], axis=0)

            sum_lc = TimeBins(rates*exposure, tstart, tstop, exposure)
            # Create lightcurve extension for FITS file
            lc_hdu = lc.lc_table(sum_lc, bkg_rates, np.sqrt(
                bkg_var), phaii.trigtime, refdata)
            # Copy header info
            lc_hdu.header = copyheader(phaii.headers[1], lc_hdu.header)
            # Assign values to keywords
            lightcurve_hdrupdate(lc_hdu, version, timebin, erange)
            lc_hdu.header['INSTRUME'] = 'CSA'

            # Open FITS file to append to
            if modify:
                filename = "bc" + formatted_date_time + "csa" + mode_suff+".lc"
                existing_file = fits.open(outprefix + filename)
                existing_file.append(lc_hdu)
                existing_file.writeto(
                    outprefix + filename, overwrite=True, checksum=True)
                existing_file.close()
    # End loop for FITS files
    # ----STEP 8 STOP------#

    # ----STEP 9 START------#
    # Create lightcurve plots
    logger.info("Writing lightcurve plots")
    for det in detectors:
        filename = "bc" + formatted_date_time + det.lower()+mode_suff+"lc.png"
        fig_ax_det[det]['fig'].tight_layout()
        filename = f'{outprefix}{filename}'
        status = clob_check(clobber, filename, logger)
        if status:
            return 1
        else:
            fig_ax_det[det]['fig'].savefig(filename)
    if combined:
        filename = "bc" + formatted_date_time + mode_suff+"totlc.png"
        fig_totlc.tight_layout()
        filename = f'{outprefix}{filename}'
        status = clob_check(clobber, filename, logger)
        if status:
            return 1
        else:
            fig_totlc.savefig(filename)
    # End loop for creating plots
    # ----STEP 9 STOP------#
    # ---- End lightcurve creation-----#

    # ----STEP 10 START------#
    # --- Begin spectra creation ------#
    logger.info("Calculating spectra")
    for det in detectors:

        # Retrieve phaii data
        phaii = phaii_dict.get_item(det)

        # Check for bins with zero counts
        zeros, zero_indices, non_zero_indices, newerange = remove_zero_energybins(phaii)

        # Make backup of original phaii
        # Slice data by erange
        if zeros:
            orig_phaii = copy.deepcopy(phaii)
            phaii = phaii.slice_energy(newerange)

        # Fit the background to background times if peak found.
        # bkg_times only exists in that case, so the buffer is built here.
        if peak_found:
            buff_bkg = background_buffer(bkg_times, phaii.data)
            backfitter = BackgroundFitter.from_phaii(phaii,
                                                     Polynomial,
                                                     time_ranges=buff_bkg)
            # Same failure handling as the lightcurve background fit
            try:
                backfitter.fit(order=poly)
            except Exception:
                logger.error("Unable to fit background properly. Please change polynomial order")
                return 1

            back_rates = backfitter.interpolate_bins(
                phaii.data.tstart, phaii.data.tstop)
        else:
            # Create empty back_rates for no peak found
            rates = np.zeros_like(phaii.data.rates)
            rate_uncert = np.zeros_like(phaii.data.rate_uncertainty)
            back_rates = BackgroundRates(rates, rate_uncert, phaii.data.tstart, phaii.data.tstop,
                                         phaii.data.emin, phaii.data.emax, phaii.data.exposure)


        # Do time cuts to only include source region
        phaii = phaii.slice_time(src_times)
        # Currently assumes single time range for source
        back_rates = back_rates.slice_time(src_times[0][0], src_times[0][1])

        # If there were bins with zero counts, interpolate values in bins based
        # on background fitting
        if zeros:
            phaii = orig_phaii.slice_time(src_times)
            back_rates = backrates_interp(phaii, back_rates, zero_indices, non_zero_indices)

        # ----STEP 10 STOP------#

        # ----STEP 11 START------#
        # Create PHA type 2 files
        # Source FITS file creation
        # Set filenames
        source_file = "bc" + formatted_date_time + det.lower()+mode_suff+".phaii"
        bkg_file = "bc" + formatted_date_time + det.lower()+mode_suff+"_bg.phaii"
        # Create hdulist
        srclist = create_spec2_hdul(phaii, phaii.data.counts, phaii.data.exposure, phaii.data.quality,
                                    phaii.data.tstart, phaii.data.tstop, phaii.trigtime, refdata, 'src', bkg_file)
        srclist[1].header['HDUCLAS2'] = 'TOTAL'
        srclist[1].header['SOFTVER'] = version
        # Write source spec file
        logger.info(f"Writing source spectrum TYPE 2 file for {det}")
        # Existence check
        filename = f'{outprefix}{source_file}'
        status = clob_check(clobber, filename, logger)
        if status == 1:
            return 1
        else:
            srclist.writeto(filename,
                            overwrite=clobber, checksum=True)

        # Close the HDUList just written (previously the unrelated
        # lightcurve `hdulist` was closed instead).
        srclist.close()

        # Background FITS file creation
        # Create hdulist
        bkglist = create_spec2_hdul(phaii, np.array(back_rates.counts, dtype=int), back_rates.exposure, back_rates.quality,
                                    back_rates.tstart, back_rates.tstop, phaii.trigtime, refdata, 'bkg')
        bkglist[1].header['HDUCLAS2'] = 'BKG'
        bkglist[1].header['SOFTVER'] = version
        # Write to file
        logger.info(f"Writing bkg spectrum TYPE 2 file for {det}")
        # Existence check
        filename = f'{outprefix}{bkg_file}'
        status = clob_check(clobber, filename, logger)
        if status == 1:
            return 1
        else:
            bkglist.writeto(filename,
                            overwrite=clobber, checksum=True)
        bkglist.close()

        # End PHA Type 2 Creation
        # ----STEP 11 STOP------#

        # ----STEP 12 START------#
        # Start PHA Type 1 Creation
        # Create spec by integrating over time
        phaii_rates = phaii.data.integrate_time()
        back_spec = back_rates.integrate_time()
        # Sum exposure
        exposure = np.sum(phaii.data.exposure)

        # Source FITS file creation
        # Set filename
        source_file = "bc" + formatted_date_time + det.lower()+mode_suff+".pha"
        # Create hdulist
        srclist = create_spec1_hdul(
            phaii, phaii_rates.counts, exposure, refdata)
        srclist[1].header['HDUCLAS2'] = 'TOTAL'
        srclist[1].header['SOFTVER'] = version
        # Write source spec file
        logger.info(f"Writing source spectrum TYPE 1 file for {det}")
        # Existence check
        filename = f'{outprefix}{source_file}'
        status = clob_check(clobber, filename, logger)
        if status == 1:
            return 1
        else:
            srclist.writeto(filename,
                            overwrite=clobber, checksum=True)

        srclist.close()

        # Background FITS file creation
        # Set filename
        bkg_file = "bc" + formatted_date_time + det.lower()+mode_suff+"_bg.pha"
        # Create hdulist
        bkglist = create_spec1_hdul(phaii, back_spec.counts, exposure, refdata)
        bkglist[1].header['HDUCLAS2'] = 'BKG'
        bkglist[1].header['SOFTVER'] = version
        # Write to file
        logger.info(f"Writing bkg spectrum TYPE 1 file for {det}")
        # Existence check
        filename = f'{outprefix}{bkg_file}'
        status = clob_check(clobber, filename, logger)
        if status == 1:
            return 1
        else:
            bkglist.writeto(filename,
                            overwrite=clobber, checksum=True)
        bkglist.close()
        # End writing PHA Type 1

        # ----STEP 12 STOP------#

        # ----STEP 13 START------#
        # Create spectra plots
        specplot = Spectrum(data=phaii.to_spectrum(),
                            background=back_spec)
        # Add detector label
        specplot.ax.text(0.9, 0.9, det,
                         horizontalalignment='center',
                         verticalalignment='center',
                         transform=specplot.ax.transAxes)
        # Put filename as title
        specplot.ax.legend(["Source", "Fitted Background"], loc='lower left')
        # Add trigger time in upper left (trig_utc is set in Step 2 under the
        # same condition)
        if trigtime is not None and trigtime != 0:
            plt.figtext(0.02, 0.98, f'Trigger Time in UTC: {trig_utc.isot[:19]}',
                        horizontalalignment='left', verticalalignment='top', fontsize=12)
        filename = "bc" + formatted_date_time + det.lower()+mode_suff+"pha.png"
        logger.info(f"Writing spectrum plot for {det}")
        filename = f'{outprefix}{filename}'
        status = clob_check(clobber, filename, logger)
        if status == 1:
            return 1
        else:
            plt.savefig(filename)
    # ----End of spectra section----#
    # ----STEP 13 STOP------#
    return 0


def find_consecutive_nonzeros(arr):
    """Locate runs of consecutive non-zero elements in a 1-D array.

    Args:
        arr: 1-D array-like of numbers (e.g. spectrum counts per channel).

    Returns:
        tuple ``(start_indices, end_indices)`` of integer arrays of equal
        length. Run ``k`` covers ``arr[start_indices[k]:end_indices[k] + 1]``,
        i.e. both indices are inclusive. Both arrays are empty when ``arr`` is
        empty or contains only zeros.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # Nothing to scan; also avoids indexing is_zero[0] on an empty array
        empty = np.array([], dtype=int)
        return empty, empty.copy()

    # True where the element is zero
    is_zero = (arr == 0)

    # +1 marks a non-zero -> zero transition, -1 a zero -> non-zero transition
    diff = np.diff(is_zero.astype(int))

    # A run ends on the last non-zero element before a zero
    end_indices = np.where(diff == 1)[0]

    # A run starts on the first non-zero element after a zero
    start_indices = np.where(diff == -1)[0] + 1

    # Runs touching either end of the array produce no transition, add them
    if not is_zero[0]:
        start_indices = np.insert(start_indices, 0, 0)
    if not is_zero[-1]:
        end_indices = np.append(end_indices, len(arr) - 1)
    return start_indices, end_indices

def remove_zero_energybins(phaii):
    """Find energy channels with zero total counts and the ranges to keep.

    The PHAII data are integrated over time (``phaii.to_spectrum()``) and
    every channel with zero counts is flagged. Each contiguous run of channels
    that do have counts becomes one ``(emin, emax)`` pair, which can be used
    to slice the data so the empty channels are ignored.

    Args:
        phaii: PHAII object providing ``to_spectrum()`` and the channel edges
            ``data.emin`` / ``data.emax``.

    Returns:
        ``(False, 0, 0, 0)`` if no channel is empty. Otherwise
        ``(True, zero_indices, non_zero_indices, erange)`` where the index
        arrays list the empty / populated channels and ``erange`` is a list
        of ``(emin, emax)`` tuples, one per contiguous populated run.

    Raises:
        ValueError: if every channel has zero counts, so no energy range can
            be kept.
    """
    counts = phaii.to_spectrum().counts
    is_zero = (counts == 0)

    if not np.any(is_zero):
        return False, 0, 0, 0

    zero_indices = np.where(is_zero)[0]
    non_zero_indices = np.where(~is_zero)[0]

    if non_zero_indices.size == 0:
        # Previously surfaced as an opaque IndexError further down
        raise ValueError("All energy channels have zero counts; "
                         "no energy range can be selected.")

    # Start/stop channel indices (inclusive) of each populated run
    start_indices, stop_indices = find_consecutive_nonzeros(counts)

    emin = phaii.data.emin
    emax = phaii.data.emax
    erange = [(emin[start], emax[stop])
              for start, stop in zip(start_indices, stop_indices)]

    return True, zero_indices, non_zero_indices, erange


def backrates_interp(phaii, back_rates, zero_indices, non_zero_indices):
    """Expand fitted background rates to every channel, filling empty ones.

    The background was fitted only on the channels in ``non_zero_indices``.
    This builds full-size rate/uncertainty arrays on the channel grid of
    ``phaii``:

    * fitted channels are copied over unchanged;
    * channels in ``zero_indices`` are estimated by interpolating (or
      extrapolating) the time-integrated fitted background counts in energy,
      then spreading those counts over time with the time profile of the
      total fitted background.

    Args:
        phaii: PHAII object supplying the full channel grid, time bins and
            exposure.
        back_rates: BackgroundRates fitted on the non-zero channels only.
        zero_indices: channel indices (into ``phaii``) that had zero counts.
        non_zero_indices: channel indices (into ``phaii``) that were fitted.

    Returns:
        BackgroundRates spanning all channels of ``phaii``.
    """
    new_backrates = np.zeros_like(phaii.data.rates)
    new_backunc = np.zeros_like(phaii.data.rates)

    # Copy the channels that were actually fitted
    new_backrates[:, non_zero_indices] = back_rates.rates
    new_backunc[:, non_zero_indices] = back_rates.rate_uncertainty

    # Time-integrated fitted background counts per fitted channel
    back_spec = back_rates.integrate_time()

    # Interpolate/extrapolate those counts onto every channel centroid
    interp_function = interp1d(back_spec.centroids, back_spec.counts,
                               axis=0, fill_value="extrapolate")
    new_spec_counts = interp_function(phaii.to_spectrum().centroids)
    # Linear extrapolation past the fitted channels (typically the empty edge
    # channels) can go negative: unphysical for counts, and the sqrt below
    # would turn it into NaN uncertainties.
    new_spec_counts = np.clip(new_spec_counts, 0.0, None)

    # Time profile of the total fitted background (fraction of counts per bin)
    back_lc = back_rates.integrate_energy()
    time_fraction = back_lc.counts / back_lc.counts.sum()

    # Spread interpolated counts over time and convert to rates.
    # NOTE(review): assumes uniform exposure across bins (exposure[0]).
    exposure = phaii.data.exposure[0]
    new_spec_rates = time_fraction * new_spec_counts[zero_indices] / exposure
    new_spec_unc = np.sqrt(new_spec_rates / exposure)

    # Fill all empty channels in one vectorized assignment
    new_backrates[:, zero_indices] = new_spec_rates
    new_backunc[:, zero_indices] = new_spec_unc

    return BackgroundRates(new_backrates, new_backunc,
                           phaii.data.tstart, phaii.data.tstop,
                           phaii.data.emin, phaii.data.emax,
                           phaii.data.exposure)


# Copy basic info from one header to another


def copyheader(fromheader, toheader):
    """Copy values for keywords that already exist in the destination header.

    Keywords starting with TFORM, TTYPE, TUNIT or HDUCL are never copied,
    and neither is EXTNAME. Keywords missing from ``toheader`` are not added.

    Args:
        fromheader: source header (anything providing ``items()``).
        toheader: destination header; modified in place.

    Returns:
        ``toheader`` itself.
    """
    skipped_prefixes = ('TFORM', 'TTYPE', 'TUNIT', 'HDUCL')
    for keyword, value in fromheader.items():
        # Column-description / HDU-class keywords belong to the target HDU
        if keyword.startswith(skipped_prefixes):
            continue
        if keyword == 'EXTNAME' or keyword not in toheader.keys():
            continue
        toheader[keyword] = value
    return toheader

# Update DATE keyword in headers


def updateDATE(hdu_list):
    """Set DATE in every HDU that has it to the current UTC time.

    The FITS standard defines DATE as UTC. A naive ``datetime.now()`` returns
    local wall-clock time, so the clock is read timezone-aware in UTC.

    Args:
        hdu_list: iterable of HDUs (objects with a ``header`` mapping).
            Headers are modified in place; HDUs without DATE are untouched.
    """
    # One timestamp for all HDUs, formatted as 'YYYY-MM-DDThh:mm:ss'
    stamp = datetime.datetime.now(datetime.timezone.utc).strftime(
        '%Y-%m-%dT%H:%M:%S')
    for hdu in hdu_list:
        if 'DATE' in hdu.header:
            hdu.header['DATE'] = stamp

# Update lightcurve headers


def lightcurve_hdrupdate(lc_hdu, version, timebin, erange, total=False):
    """Fill in the header keywords of a lightcurve extension.

    Sets EXTNAME, TIMEDEL, E_MIN, E_MAX, SOFTVER and DATE (current UTC time,
    ``YYYY-MM-DDThh:mm:ss``).

    EXTNAME is ``LC_<emin>_<emax>keV_<bin>``, or ``LC_TOTAL_<bin>`` when
    ``total`` is set. ``<bin>`` is derived from ``timebin``: milliseconds for
    sub-second bins (0.064 -> ``64ms``), seconds otherwise (1.0 -> ``1s``).
    Deriving it means non-default bin widths get correct, distinct names.

    Args:
        lc_hdu: HDU whose ``header`` is updated in place.
        version: software version written to SOFTVER.
        timebin: bin width in seconds, written to TIMEDEL.
        erange: ``(emin, emax)`` in keV; truncated to int inside EXTNAME.
        total: use the ``LC_TOTAL_<bin>`` name instead of the energy-range one.
    """
    # ':g' drops trailing zeros and hides float noise (0.064*1000 -> '64')
    if timebin < 1:
        bin_label = f"{timebin * 1000:g}ms"
    else:
        bin_label = f"{timebin:g}s"

    if total:
        extname = f"LC_TOTAL_{bin_label}"
    else:
        extname = f"LC_{int(erange[0])}_{int(erange[1])}keV_{bin_label}"

    header = lc_hdu.header
    header['EXTNAME'] = extname
    header['TIMEDEL'] = timebin
    header['E_MIN'] = erange[0]
    header['E_MAX'] = erange[1]
    header['SOFTVER'] = version
    # FITS DATE is UTC; a naive datetime.now() would be local time
    header['DATE'] = datetime.datetime.now(datetime.timezone.utc).strftime(
        '%Y-%m-%dT%H:%M:%S')

# Check that parameter can be treated as a float


def is_valid_float(value):
    """Return True if ``value`` is a float or a string that parses as one.

    Any other type (including int and None) yields False.
    """
    if isinstance(value, float):
        return True
    if not isinstance(value, str):
        return False
    try:
        float(value)
    except ValueError:
        return False
    return True


# Get RA and Dec (RA_OBJ/DEC_OBJ) from localization map header
def best_loc(locmap, logger):
    """Read the best-fit RA/Dec from the header of a localization map.

    Args:
        locmap: filename of the localization map, optionally carrying an
            extension specifier that ``file_extn`` splits off. Surrounding
            whitespace is ignored.
        logger: logger used to report problems.

    Returns:
        ``(ra, dec)`` from the RA_OBJ / DEC_OBJ header keywords, or
        ``(-99, -99)`` if the file is missing or either keyword is absent.
    """
    locmap_file, extn = file_extn(locmap.strip())

    if not Path(locmap_file).exists():
        logger.error(f"Localization map file {locmap_file} not found.")
        return -99, -99

    # Normalise once. The bounds check needs an int anyway, and indexing an
    # HDUList with a numeric *string* would look it up as an EXTNAME instead.
    extn = int(extn)

    # Context manager releases the file handle on every exit path
    with fits.open(Path(locmap_file)) as loc_hdu:
        if len(loc_hdu) < extn + 1:
            logger.warning(
                f"WARNING: No extension {extn} found for file {locmap_file}")
            logger.warning("WARNING: Will default to first extension.")
            extn = 1
        loc_header = loc_hdu[extn].header
        ra = loc_header.get('RA_OBJ')
        dec = loc_header.get('DEC_OBJ')

    if ra is None or dec is None:
        logger.error(
            f'RA_OBJ and DEC_OBJ not found in header of {locmap_file}.')
        return -99, -99
    return ra, dec


# Create individual lightcurve plots
def create_lcplot(axes, detector, timebin, erange, phaii, lightcurve, backrates,
                  srctime, trigtime=None, legend=False):
    """Draw one lightcurve panel with its fitted background and source region.

    Args:
        axes: matplotlib axes to draw on.
        detector: detector name, shown as a label inside the panel.
        timebin: time bin width (used only as part of the returned key).
        erange: energy range tuple; part of the key and of the legend label.
        phaii: PHAII data used to build the highlighted source lightcurve.
        lightcurve: lightcurve to plot; its rates also set the y-limits.
        backrates: fitted background passed to the Lightcurve plot.
        srctime: time range of the source region to highlight.
        trigtime: trigger time; selects the x-axis label when non-zero.
        legend: whether to add a legend to this panel.

    Returns:
        dict with one entry mapping ``(detector, timebin, erange)`` to the
        Lightcurve plot object.
    """
    key = (detector, timebin, erange)
    plot = Lightcurve(data=lightcurve, background=backrates, ax=axes)
    lcplots = {key: plot}

    # Highlight the source time range in green
    source_lc = phaii.to_lightcurve(time_range=srctime)
    plot.add_selection(source_lc)
    plot.selections[0].color = 'green'

    # Detector label near the upper-right corner (axes coordinates)
    plot.ax.text(0.9, 0.9, detector,
                 horizontalalignment='center',
                 verticalalignment='center',
                 transform=plot.ax.transAxes)

    if legend:
        # NOTE(review): picks artists by child position, which depends on the
        # order Lightcurve adds them; fragile if that plotting code changes.
        children = plot.ax.get_children()
        handles = [children[0], children[2], children[4]]
        labels = [f"{erange[0]}-{erange[1]:.1f}keV", "Fitted Background",
                  "Source Region"]
        plot.ax.legend(handles, labels, loc='upper left')

    no_trigger = trigtime is None or trigtime == 0.0
    plot.ax.set_xlabel("Time (s)" if no_trigger else "Time - Trigger Time (s)")

    # y-range: 80% of the minimum rate up to 120% of the maximum rate
    axes.set_ylim(0.8*np.amin(lightcurve.rates),
                  1.2*np.amax(lightcurve.rates))

    return lcplots


# Create PHA 2 hdul


def create_spec2_hdul(phaii, counts, exposure, quality, tstart, tstop, trigtime, refdata, type, bkgfile=None):
    """Build a PHA Type II HDUList: primary HDU plus a spectrum table.

    Matching keyword values from ``phaii.headers[0]`` / ``phaii.headers[1]``
    are copied into the primary / spectrum headers, and DATE is refreshed in
    every HDU.

    Args:
        phaii: PHAII object supplying source headers and ``num_chans``.
        counts, exposure, quality, tstart, tstop, trigtime: per-spectrum data
            passed straight to ``spec.spec_pha2_table``.
        refdata: reference-data directory for the header templates.
        type: passed through to ``spec.spec_pha2_table`` (presumably the
            spectrum kind, e.g. source vs background; confirm there).
        bkgfile: optional background filename, also passed through.

    Returns:
        fits.HDUList ``[primary, spectrum]``.
    """
    primary = fits.PrimaryHDU(header=spec.SpecPrimHeader(refdata))
    primary.header = copyheader(phaii.headers[0], primary.header)

    table = spec.spec_pha2_table(counts, exposure, quality, tstart, tstop,
                                 phaii.num_chans, trigtime, refdata, type,
                                 bkgfile)
    table.header = copyheader(phaii.headers[1], table.header)

    hdul = fits.HDUList([primary, table])
    updateDATE(hdul)
    return hdul

# Create PHA 1 hdul


def create_spec1_hdul(phaii, counts, exposure, refdata):
    """Build a PHA Type I HDUList: primary HDU plus a single-spectrum table.

    Matching keyword values from ``phaii.headers[0]`` / ``phaii.headers[1]``
    are copied into the primary / spectrum headers, EXPOSURE is set on the
    spectrum HDU, and DATE is refreshed in every HDU.

    Args:
        phaii: PHAII object supplying source headers and ``num_chans``.
        counts: counts per channel, passed to ``spec.spec_table``.
        exposure: value written to the spectrum HDU's EXPOSURE keyword.
        refdata: reference-data directory for the header templates.

    Returns:
        fits.HDUList ``[primary, spectrum]``.
    """
    primary = fits.PrimaryHDU(header=spec.SpecPrimHeader(refdata))
    primary.header = copyheader(phaii.headers[0], primary.header)

    table = spec.spec_table(counts, phaii.num_chans, refdata)
    table.header = copyheader(phaii.headers[1], table.header)
    table.header['EXPOSURE'] = exposure

    hdul = fits.HDUList([primary, table])
    updateDATE(hdul)
    return hdul

# Function to calculate background times taking into account rebinning
def background_buffer(bkg_times, data):
    """Align background time intervals with the bin edges of ``data``.

    For each ``(start, stop)`` pair, the start is snapped to the 'high'
    nearest edge and the stop to the 'low' nearest edge. If that would
    invert the interval (start > stop), the start is snapped 'low' and the
    stop 'high' instead.

    Assumes ``closest_time_edge(t, 'high')`` returns the nearest edge at or
    above ``t`` and ``'low'`` the nearest at or below (confirm), so intervals
    are shrunk onto whole bins, or widened when narrower than one bin.

    Args:
        bkg_times: iterable of ``(start, stop)`` time pairs.
        data: object providing ``closest_time_edge(time, which)``.

    Returns:
        list of ``(start, stop)`` tuples aligned to bin edges.
    """
    def _snap(t_begin, t_end):
        inner_start = data.closest_time_edge(t_begin, 'high')
        inner_stop = data.closest_time_edge(t_end, 'low')
        if inner_start > inner_stop:
            return (data.closest_time_edge(t_begin, 'low'),
                    data.closest_time_edge(t_end, 'high'))
        return (inner_start, inner_stop)

    return [_snap(interval[0], interval[1]) for interval in bkg_times]

# Function to return index of smallest positive value in list
def find_smallest_positive_index(list):
    """Return the index of the smallest strictly positive value.

    Args:
        list: sequence of numbers. (The name shadows the builtin but is kept
            so keyword callers keep working.)

    Returns:
        int index of the first occurrence of the smallest value > 0, or -1
        if there is no positive value, including for an empty input.
    """
    arr = np.asarray(list)
    if arr.size == 0:
        # np.argmin raises on empty input; there is no positive value anyway
        return -1
    # Non-positive (and NaN) entries are pushed out of contention with +inf
    positive_arr = np.where(arr > 0, arr, np.inf)
    min_index = int(np.argmin(positive_arr))
    if np.isinf(positive_arr[min_index]):
        return -1
    return min_index

def _prompt_for_missing_args(parser, args):
    """Interactively ask for every command-line argument left as None.

    Re-prompts until the entered text converts to the argument's declared
    type, then stores the converted value on ``args``.
    """
    for param, val in vars(args).items():
        if val is not None:
            continue
        for action in parser._actions:
            if action.option_strings and action.option_strings[0] == f'--{param}':
                # Arguments declared without a type are plain strings
                arg_type = getattr(action, 'type', None) or str
                type_str = getattr(arg_type, '__name__', str(arg_type))
                while True:
                    new_value = input(f"{action.help}: ")
                    try:
                        setattr(args, param, arg_type(new_value))
                        break
                    except (ValueError, TypeError):
                        print(
                            f"{param} must be a valid {type_str}. Try again.")


def main():
    """Command-line entry point for bcprod.

    Parses the arguments (prompting for any required ones not supplied),
    configures the 'bcprod' logger from --chatter / --log, and runs bcprod().
    Exits with a non-zero status for an invalid chatter value or when bcprod
    returns a non-zero integer status.
    """
    desc = """
    Program to produce BurstCube data products
    """

    parser = argparse.ArgumentParser(
        description=desc, formatter_class=RawTextHelpFormatter)
    parser.add_argument('--infile', action="store", type=str,
                        help="List of input FITS files or @file containing list")
    parser.add_argument('--outprefix', type=str,
                        help='Prefix for output files')
    parser.add_argument('--burstgti', type=str,
                        help='GTI file containing time ranges for source and background')
    parser.add_argument('--radec', type=str,
                        help="'RA,DEC' or the name of the localization map from bcfindloc",
                        default='NONE')
    parser.add_argument('--attitude', type=str,
                        help='FITS file containing attitude information',
                        default='NONE')
    parser.add_argument('--erange1', type=str,
                        help='First energy range of the format "min,max"',
                        default='50,100')
    parser.add_argument('--erange2', type=str,
                        help='Second energy range of the format "min,max"',
                        default='100,300')
    parser.add_argument('--erange3', type=str,
                        help='Third energy range of the format "min,max"',
                        default='300,1000')
    parser.add_argument('--timebin1', type=float,
                        help='First time binning value in seconds',
                        default=0.064)
    parser.add_argument('--timebin2', type=float,
                        help='Second time binning value in seconds',
                        default=1.0)
    parser.add_argument('--poly', type=int,
                        help='Polynomial order for background fitting',
                        default=3)
    parser.add_argument('--refdata', type=str,
                        help='Reference data directory',
                        default='REFDATA')
    parser.add_argument('--chatter', type=int,
                        default=1, help='Level of verboseness.')
    parser.add_argument('--log', type=str,
                        default='no', help='Create log file?')
    parser.add_argument('--clobber', type=str,
                        default='no', help='Overwrite existing output file(s)?')

    args = parser.parse_args()
    # Required parameters have no default; ask for them interactively
    _prompt_for_missing_args(parser, args)

    infile = args.infile
    outprefix = args.outprefix + '_'
    burstgti = args.burstgti
    radec = args.radec
    attitude = args.attitude
    erange1 = args.erange1
    erange2 = args.erange2
    erange3 = args.erange3
    timebin1 = args.timebin1
    timebin2 = args.timebin2
    poly = int(args.poly)
    refdata = args.refdata
    chatter = int(args.chatter)
    # Case-insensitive yes/no flags ('true' also covers a literal True)
    log = str(args.log).strip().lower() in ('yes', 'y', 'true')
    clob = str(args.clobber).strip().lower() in ('yes', 'y', 'true')

    logger = logging.getLogger('bcprod')
    # Check that chatter is reasonable
    if chatter < 0 or chatter > 5:
        print(f'Chatter value {chatter} is not valid. Must be 0-5.')
        sys.exit(1)
    # Logging level per chatter value (ERROR=40, INFO=20, DEBUG=10)
    chat_level = {0: logging.ERROR, 1: logging.INFO, 5: logging.DEBUG}
    # Chatter values 1 through 4 are all the same
    if 1 <= chatter <= 4:
        chatter = 1
    # Without an explicit level the logger inherits root's WARNING level and
    # INFO/DEBUG records are discarded before any handler sees them.
    logger.setLevel(chat_level[chatter])

    # If log file is wanted
    if log:
        # Timestamped log file name in the current directory
        now = datetime.datetime.now()
        date_time = now.strftime("%d%m%Y_%H%M%S")
        log_filename = f'bcprod_{date_time}.log'
        file_handler = logging.FileHandler(log_filename)
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(chat_level[chatter])
        logger.addHandler(file_handler)

    status = bcprod(infile, outprefix, burstgti, radec, attitude, erange1,
                    erange2, erange3, timebin1, timebin2, poly, refdata,
                    chatter, clob)
    # Functions in this module return 1 on failure; presumably bcprod follows
    # the same 0/1 convention (confirm). Surface failures as the exit status.
    if isinstance(status, int) and status != 0:
        sys.exit(status)


# Run the command-line interface when executed as a script
if __name__ == '__main__':
    main()
