#
# Requirement
#  The task takes the final number of channels from the parameter chanbin, which
#  must be one of the channel binnings allowed by the rebin and energy bounds
#  files stored in CALDB. Using these preset channel binnings requires CALDB to
#  be set up. It is also possible to select a chanbin different from those
#  allowed, but the user then has to supply rebin and energy bounds files that
#  are consistent with each other and with the format of the equivalent files
#  stored in CALDB. This feature is useful to test a different binning, or for a
#  different mission that has an identical science data format but a different
#  rebinning and energy boundaries scheme.
#
#  The task queries CALDB to retrieve the correct file, using the parameter
#  calexpr, which specifies the CALDB boundary keywords defined by CDB1001 and CDB2001.
#
#  The task needs the location of the REFDATA directory to be set; it contains
#  the headers of the extensions in the output file.
#
#   The steps of the tasks are the following
#   Step 1 Check input parameters, set up logs and read header keywords
#   Step 2 Check the CALDB files and or the user calibration input files
#   Loop on single detectors
#   Step 3  Build the output file
#   Step 4  Build the lightcurve text file and plot
#   Step 5  Build the spectrum text file and plot
#   end loop
#   James Runge (ADNET) , Lorella Angelini (HEASARC)
#


import sys
import os
import argparse
from argparse import RawTextHelpFormatter
import datetime
import logging

import astropy.io.fits as fits
from astropy.time import Time

from gdt.core.tte import PhotonList
from gdt.core.data_primitives import Ebounds, Gti, EventList, TimeEnergyBins
from gdt.core.collection import DataCollection
from gdt.core.plot.lightcurve import Lightcurve
from gdt.core.plot.spectrum import Spectrum

from gdt.core.binning.binned import rebin_by_time
from gdt.core.binning.unbinned import bin_by_time

from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

from heasoftpy.burstcube.lib.phaii import MyPhaii, PhaiiFileHeaders
from heasoftpy.burstcube.lib.caldbpy import caldbpy
from heasoftpy.burstcube.lib.io import get_filelist, file_extn, clob_check
from heasoftpy.burstcube.bcversion_task import bcversion

import warnings

# Suppress Matplotlib user warnings
warnings.filterwarnings("ignore", category=UserWarning)


def bcrebevt(infile, outprefix, timecol, rebincol, chanbin, timebin, ebounds='CALDB', rebin='CALDB',
             plotbin=0.128, refdata='REFDATA', calexpr='DATAMODE,CHAN', chatter=1, clobber=False, phaii_dict=None):
    """Rebin science files into PHAII files and produce lightcurve/spectrum products.

    For every input file (one per detector) the task builds a PHAII object,
    writes it to ``<prefix>_<date>_phaii.fits`` and writes the lightcurve and
    spectrum as text files.  Two PNG figures (all lightcurves, all spectra)
    are written at the end.  All outputs go to the current directory.

    Parameters
    ----------
    infile : str
        Input file specification handed to ``get_filelist``.
    outprefix : str
        Prefix of every output file name; a blank value falls back to "temp".
    timecol, rebincol : str
        Names of the time column and of the column to rebin.  Both are
        upper-cased before they are looked up.
    chanbin : int or int-convertible
        Number of output channels.
    timebin : float or float-convertible
        Time bin width used when binning EVENT data.
    ebounds, rebin : str
        Either 'CALDB' or a file specification handed to ``get_filelist``.
        ``rebin`` is only used when DATAMODE is 'EVENT'.
    plotbin : float or float-convertible
        Bin width used to rebin the lightcurve before it is written/plotted.
    refdata : str
        'REFDATA' (use ``$LHEA_DATA``) or an existing directory; handed to
        ``PhaiiFileHeaders``.
    calexpr : str
        Comma-separated keywords used to build the CALDB EBOUNDS query.
    chatter : int
        Verbosity, 0-5.
    clobber : bool
        Whether existing output files may be overwritten.
    phaii_dict : DataCollection, optional
        Collection that receives the PHAII object of every detector; a new
        one is created when omitted.

    Returns
    -------
    int
        0 on success, 1 as soon as any check or step fails (the reason is
        logged, or printed for an invalid chatter value).
    """

    # ----STEP 1 START -------#
    # Set logger parameters
    logger = logging.getLogger('bcrebevt')
    # Check that chatter is reasonable
    if chatter < 0 or chatter > 5:
        print(f'Chatter value {chatter} is not valid. Must be 0-5.')
        return 1
    # Create dictionary to set level according to chatter value
    chat_level = {0: 40, 1: 20, 5: 10}
    # Chatter values 1 through 4 are all the same for now
    if 1 <= chatter <= 4:
        chatter = 1
    # Set logger level
    logger.setLevel(chat_level[chatter])
    version = bcversion()
    logger.info(f'VERSION: {version}')
    # Check params
    # ----Simple checks first----
    # Make sure infile is not empty
    if infile.strip() == '':
        logger.error('Input file name(s) is empty.')
        return 1

    # Default outprefix if none given
    if outprefix.strip() == '':
        logger.warning('WARNING: Outprefix was blank. Defaulting to "temp"')
        outprefix = 'temp'

    # Check refdata
    if refdata.strip() == '':
        logger.error('Input refdata is empty.')
        return 1
    if refdata == 'REFDATA':
        # Default: take the directory from the environment
        if os.getenv('LHEA_DATA'):
            refdata = os.environ['LHEA_DATA'] + '/'
        else:
            logger.error("LHEA_DATA is not defined in current environment.")
            return 1
    elif Path(refdata).is_dir():
        # Catch if path does not end with /
        if refdata[-1] != '/':
            refdata = refdata + '/'
    else:
        logger.error(f"Directory {refdata} does not exist.")
        return 1
    # End refdata check

    # Check chanbin
    try:
        chanbin = int(chanbin)
    except (TypeError, ValueError, OverflowError):
        logger.error(f"Channel binning {chanbin} is not an integer.")
        return 1
    # End chanbin check

    # Check timebin
    try:
        timebin = float(timebin)
    except (TypeError, ValueError, OverflowError):
        logger.error(f"{timebin} is not a proper float value.")
        return 1
    # End timebin check

    # Check plotbin
    try:
        plotbin = float(plotbin)
    except (TypeError, ValueError, OverflowError):
        logger.error(f"{plotbin} is not a proper float value.")
        return 1
    # End plotbin check

    # ----More involved checks----
    # Start checking input data files
    # Load input files
    input_files = get_filelist(infile, logger)
    # If file not found
    if not isinstance(input_files, np.ndarray):
        return 1
    # Per-detector file/extension lookup and the list of detector names
    evt_lst = {}
    detectors = []
    # Boolean to check whether DETNAM or INSTRUME is used
    usedet = False
    # Check that files exist
    for file in input_files:
        # Pass to function to check for extension at end of filename
        file, extn = file_extn(file)
        if not Path(file).exists():
            logger.error(f"File {file} not found.")
            return 1
        # Load header from primary and retrieve keywords; the header object
        # stays usable after the file is closed
        with fits.open(file) as hdul:
            header = hdul[0].header
        mission = header['TELESCOP']
        instrument = header['INSTRUME']
        # If DETNAM exists, use that as detector name
        if header.get('DETNAM') is not None:
            det = header['DETNAM']
            usedet = True
            detstr = 'DETNAM'  # String to hold detector keyword
        else:
            det = instrument
            detstr = 'INSTRUME'  # String to hold detector keyword
        datamode = header['DATAMODE']
        # Populate arrays
        evt_lst[det] = {'file': Path(file), 'extn': extn}
        detectors.append(det)

    # Without at least one file the keywords above were never read
    if not detectors:
        logger.error("No input files to process.")
        return 1

    # Output list of files being used
    logger.info(f"Using files:{input_files}")
    # Sort detectors alphabetically
    detectors.sort()

    # Check for instances of duplicate detectors
    if has_duplicate(detectors):
        # Output error if a detector shows up more than once
        logger.error("Input files contain duplicate detector/instrument.")
        return 1
    # Stop checking science data

    # ----STEP 1 STOP -------#

    # ----STEP 2 START -------#
    # Start checking ebounds files
    # Set up collection to hold ebounds info from file
    ebounds_dict = DataCollection()
    # Strip quote marks if present
    ebounds = ebounds.strip('\'"')
    # If CALDB, check that CALDB environment is set
    if ebounds == 'CALDB':
        if not os.getenv('CALDB'):
            logger.error("CALDB environment is not set.")
            return 1
        CALDB = os.environ['CALDB']
    # If @file or comma-separated list
    else:
        # Strip any whitespace
        ebounds = ebounds.strip()
        # Retrieve list of files
        ebounds_files = get_filelist(ebounds, logger)
        # If file not found
        if not isinstance(ebounds_files, np.ndarray):
            return 1
        # Simple check to make sure number of ebounds matches input files
        if len(detectors) != ebounds_files.size:
            logger.error(
                "Number of EBOUNDS files does not match number of input files.")
            return 1

    # Begin looping through number of detectors
    for i, det in enumerate(detectors):
        # If using CALDB, query using detector name
        if ebounds == 'CALDB':
            # Query CALDB for file location and extension
            # Note: Currently only viable for BurstCube

            # Get header for detector
            with fits.open(evt_lst[det]['file']) as hdul:
                header = hdul[evt_lst[det]['extn']].header

            # Function to create expression based upon calexpr
            expr = create_query(calexpr, header, chanbin, logger)

            # If Header keyword was not found during expression creation
            # exit gracefully
            if not isinstance(expr, str):
                return 1

            # Search CALDB
            if usedet:
                filepath, extension = caldbpy(
                    mission, instrument, det, '-', 'energy ebounds', 'now', 'now', expr)
            else:
                filepath, extension = caldbpy(
                    mission, det, '-', '-', 'energy ebounds', 'now', 'now', expr)
            # Check if file was found
            if not isinstance(filepath, np.ndarray):
                logger.error(f"No matching EBOUNDS file found for {det}")
                logger.error(
                    f"No binning by {chanbin} channels in datamode {datamode}")
                return 1
            full_file = CALDB + "/" + filepath[0]
            # Check that file exists
            if not Path(full_file).exists():
                logger.error(f"Could not find file {full_file}")
                return 1
            with fits.open(full_file) as hdul:
                emin, emax = _read_ebounds_columns(hdul[extension[0]].data)
            detector = det

        # If list of files, load each file then check header for detname
        else:
            file, extn = file_extn(ebounds_files[i])
            if not Path(file).exists():
                logger.error(f"File {file} not found.")
                return 1
            with fits.open(file) as hdul:
                ebounds_hdu = hdul[extn]
                header = ebounds_hdu.header
                # Check that channel numbers match expected
                if header['NAXIS2'] != chanbin:
                    logger.error(
                        f"{file} does not contain correct number of channel bins {chanbin}")
                    return 1
                # Check header for detector name; report a missing keyword
                # separately so the message itself cannot raise KeyError
                if detstr not in header:
                    logger.error(
                        f"Detector keyword {detstr} not found in EBOUNDS file.")
                    return 1
                if header[detstr] not in detectors:
                    # Instrument mismatch
                    logger.error(
                        f"Detector in ebounds file ({header[detstr]}) does not match any input detectors:{detectors}")
                    return 1
                detector = header[detstr]
                emin, emax = _read_ebounds_columns(ebounds_hdu.data)

        # Create Ebounds data class based on loaded data and save into collection
        ebounds_dict.include(Ebounds.from_bounds(emin, emax), name=detector)

    # Check that all detectors are accounted for
    all_present = all(value in ebounds_dict.items for value in detectors)
    if not all_present:
        logger.error(f"Not all detectors({detectors}) found in EBOUNDS files.")
        return 1
    # Stop checking ebounds file

    # Start checking rebin file
    if datamode == 'EVENT':
        # Set up dict to hold rebin data
        rebin_dict = {}
        # Strip quote marks if present
        rebin = rebin.strip('\'"')
        # If CALDB, check that CALDB environment is set
        if rebin == 'CALDB':
            if not os.getenv('CALDB'):
                logger.error("CALDB environment is not set.")
                return 1
            CALDB = os.environ['CALDB']
        # If @file or comma-separated list
        else:
            # Strip any whitespace
            rebin = rebin.strip()
            # Retrieve list of files
            rebin_files = get_filelist(rebin, logger)
            # If file not found
            if not isinstance(rebin_files, np.ndarray):
                return 1
            # Simple check to make sure number of rebins matches input files
            if len(detectors) != rebin_files.size:
                logger.error(
                    "Number of REBIN files does not match number of input files.")
                return 1

        # Begin looping through number of detectors
        for i, det in enumerate(detectors):
            # Open event file and retrieve number of channels
            with fits.open(evt_lst[det]['file']) as hdul:
                evt_header = hdul[evt_lst[det]['extn']].header
            colidx = get_colidx(evt_header, rebincol.upper())
            if colidx is None:
                logger.error(f"Column {rebincol.upper()} not found.")
                return 1
            # Calculate input number of channels from header keywords
            if evt_header.get('DETCHANS') is not None:
                inchan = evt_header['DETCHANS']
            elif evt_header.get(f'TLMAX{colidx}') is not None and evt_header.get(f'TLMIN{colidx}') is not None:
                inchan = evt_header[f'TLMAX{colidx}'] - \
                    evt_header[f'TLMIN{colidx}'] + 1
            else:
                logger.error(
                    f"Could not determine number of channels in file {evt_lst[det]['file']}")
                logger.error(
                    "Check that DETCHANS or TLMIN/TLMAX are present in header.")
                return 1
            # If using CALDB, query using detector name
            if rebin == 'CALDB':
                # Query CALDB for file location and extension

                expr = f"datamode={datamode}&chan={chanbin}"
                if usedet:
                    filepath, extension = caldbpy(
                        mission, instrument, det, '-', 'rebin', 'now', 'now', expr)
                else:
                    filepath, extension = caldbpy(
                        mission, det, '-', '-', 'rebin', 'now', 'now', expr)
                # Check if file was found
                if not isinstance(filepath, np.ndarray):
                    logger.error(
                        f"No matching REBIN file found for {det} and channel binning of {chanbin}")
                    return 1
                full_file = CALDB + "/" + filepath[0]
                # Check that file exists
                if not Path(full_file).exists():
                    logger.error(f"Could not find file {full_file}")
                    return 1
                with fits.open(full_file) as hdul:
                    chanmin, chanmax, rebinning = _read_rebin_columns(
                        hdul[extension[0]].data)
                detector = det

            # If list of files, load each file then check header for detname
            else:
                file, extn = file_extn(rebin_files[i])
                if not Path(file).exists():
                    logger.error(f"File {file} not found.")
                    return 1
                with fits.open(file) as hdul:
                    rebin_hdu = hdul[extn]
                    header = rebin_hdu.header
                    # Check header for detector name
                    if detstr not in header:
                        # Header keyword not found
                        logger.error(
                            f"Detector keyword {detstr} not found in REBIN file.")
                        return 1
                    if header[detstr] not in detectors:
                        # Instrument mismatch
                        logger.error(
                            f"Detector in REBIN file ({header[detstr]}) does not match any input detectors:{detectors}")
                        return 1
                    detector = header[detstr]
                    chanmin, chanmax, rebinning = _read_rebin_columns(
                        rebin_hdu.data)

            # Check that input channels from file is consistent with rebin file
            if inchan != chanmax.max() + 1:
                # Report error and quit
                logger.error(
                    "Mismatch of rebinning file channels and input files")
                return 1
            # Calculate channel edges
            channel_edges = get_edges(chanmin, chanmax, rebinning)
            # Check and make sure number of channels is consistent with chanbin
            if chanbin != channel_edges.size - 1:
                # Report error and quit
                logger.error(
                    f"Number of channels in rebin file ({channel_edges.size - 1}) does not match desired output of {chanbin}")
                return 1
            # Store channel edges in rebin dict
            rebin_dict[detector] = channel_edges
        # Check that all detectors are accounted for
        all_present = all(value in rebin_dict for value in detectors)
        if not all_present:
            logger.error(
                f"Not all detectors({detectors}) found in rebin files.")
            return 1
    # Stop checking rebin file
    # ----STEP 2 STOP -------#

    # Loop on each detector
    # ----STEP 3 START -------#
    # Collection that receives the PHAII object of every detector
    if phaii_dict is None:
        phaii_dict = DataCollection()
    # Stop initialize

    # Start set-up plotting
    # PLOTS : lightcurves and spectra summing all data for each detector
    # Calculate columns and rows based on number of files
    num_files = len(evt_lst)
    rows = int(np.ceil(np.sqrt(num_files)))
    cols = int(np.ceil(num_files/rows))

    # Set up figures for spec and lightcurves, dpi(resolution) set to 150.
    # squeeze=False keeps the axes arrays 2-D for every grid shape; without it
    # a 2x1 grid (two input files) comes back 1-D and [i, j] indexing fails.
    fig_spec, ax_spec = plt.subplots(rows, cols, dpi=150, squeeze=False)
    fig_lc, ax_lc = plt.subplots(rows, cols, dpi=150, squeeze=False)
    # Initialize the dicts to hold the plots for each individual instrument
    lcplot = {}
    specplot = {}

    # Set up location of individual plots on figure
    # Location is set as starting in upper-left of grid, filling up row, and moving to next row
    det_to_ax_index = {}
    count = 0
    for i in range(rows):
        for j in range(cols):
            # If count greater than number of input files, remove plot space from figure
            if count >= num_files:
                fig_spec.delaxes(ax_spec[i, j])
                fig_lc.delaxes(ax_lc[i, j])
            else:
                # Assign instrument to specific location in figure
                det_to_ax_index[detectors[count]] = (i, j)
            count += 1
    # Stop set-up plotting

    # Calculate date for output filenames based on DATE-OBS
    with fits.open(evt_lst[detectors[0]]['file']) as hdul:
        date_obs = hdul[0].header['DATE-OBS']

    time = Time(date_obs, format='fits')
    # Extract the date components
    year = time.datetime.year % 100  # YY
    month = time.datetime.month       # MM
    day = time.datetime.day           # DD

    # Calculate the fraction of the day
    fraction_of_day = int(np.ceil((time.mjd - int(time.mjd))*1000))

    # Format the date and time components into YYMMDDXXX
    formatted_date_time = f'{year:02d}{month:02d}{day:02d}{fraction_of_day:03d}'
    # End creating output date for filenames

    # Start main loop for each detector and create spectrum and lightcurve for each
    for det, items in evt_lst.items():
        file = items['file']
        extn = items['extn']
        # Open file and build the PHAII object; the file is closed as soon as
        # the object exists, whichever branch (or error return) is taken
        with fits.open(file) as obj:

            hdrs = [hdu.header for hdu in obj]

            # Routine if using EVENT file (i.e. unbinned data)
            if datamode == 'EVENT':

                fileprefix = f'{outprefix}_{det.lower()}_evt'

                events = obj[extn].data

                try:
                    times = events[timecol.upper()]
                except (KeyError, ValueError):
                    logger.error(f"Column {timecol.upper()} not found.")
                    return 1

                trigtime = obj[extn].header.get('TRIGTIME')

                # Express times relative to the trigger when one is given
                if trigtime is not None:
                    times -= trigtime

                try:
                    channels = events[rebincol.upper()]
                except (KeyError, ValueError):
                    logger.error(f"Column {rebincol.upper()} not found.")
                    return 1

                # Rebin channels based upon rebinning scheme
                rebinned_channels = np.digitize(
                    channels, bins=rebin_dict[det]) - 1

                # New rebinned EventList with proper ebounds
                data = EventList(times=times,
                                 channels=rebinned_channels,
                                 ebounds=ebounds_dict.get_item(det))

                # the good time intervals
                gti = obj['STDGTI'].data

                gti_start = gti['START']
                gti_stop = gti['STOP']

                if trigtime is not None:
                    gti_start -= trigtime
                    gti_stop -= trigtime

                gti = Gti.from_bounds(gti_start, gti_stop)

                # Pass the file name as a string (the original handed over the
                # bound HDUList.filename method instead of its value)
                tte = PhotonList.from_data(data,
                                           gti=gti,
                                           trigger_time=trigtime,
                                           filename=file.name)

                # Load header for output file
                headers = PhaiiFileHeaders(refdata)

                # Copy header info from input file
                for header in headers:
                    for key, val in header.items():
                        if key in hdrs[extn].keys() and key[:5] != 'HDUCL':
                            header[key] = hdrs[extn][key]
                    if usedet:
                        header.set(
                            'DETNAM', hdrs[extn]['DETNAM'], 'Detector name', after='INSTRUME')

                # Bin by time and save to phaii data class
                phaii1 = tte.to_phaii(
                    bin_by_time, timebin, time_ref=tte.time_range[0], phaii_class=MyPhaii, headers=headers)
                phaii1.set_timedel(timebin)
            # End EVENT data routine

            # Start binned data routine
            else:

                fileprefix = f'{outprefix}_{det.lower()}_{datamode.lower()}'

                atd_data = obj[extn].data

                # File extension should have columns TIME and PHA (array)
                try:
                    time_stop = atd_data[timecol.upper()]
                except (KeyError, ValueError):
                    logger.error(f"Column {timecol.upper()} not found.")
                    return 1

                # The time column is used as the bin stop time; each start is
                # the previous stop, the first start is one TIMEDEL earlier
                timedel = obj[extn].header['TIMEDEL']
                time_start = np.insert(
                    time_stop[:-1], 0, time_stop[0]-timedel)

                try:
                    counts = atd_data[rebincol.upper()]
                except (KeyError, ValueError):
                    logger.error(f"Column {rebincol.upper()} not found.")
                    return 1

                trigtime = obj[extn].header.get('TRIGTIME')

                if trigtime is not None:
                    time_start -= trigtime
                    time_stop -= trigtime

                exposure = np.full_like(time_stop, timedel)
                emin = ebounds_dict.get_item(det).low_edges()
                emax = ebounds_dict.get_item(det).high_edges()

                data = TimeEnergyBins(counts, time_start,
                                      time_stop, exposure, emin, emax)

                # Load header for output file
                headers = PhaiiFileHeaders(refdata)

                # Copy header info from input file
                for key, val in headers['PRIMARY'].items():
                    if key in hdrs[0].keys():
                        headers['PRIMARY'][key] = hdrs[0][key]

                for key, val in headers['ARRAY_PHA'].items():
                    if key in obj[extn].header.keys():
                        headers['ARRAY_PHA'][key] = obj[extn].header[key]

                phaii1 = MyPhaii.from_data(
                    data, headers=headers, trigger_time=trigtime)
                phaii1.set_timedel(timedel)

            # End binned data routine

        # Write phaii data to file
        phaii_filename = f'{fileprefix}_{formatted_date_time}_phaii.fits'
        logger.info(f"Writing output file {phaii_filename}")
        # Existence check
        status = clob_check(clobber, phaii_filename, logger)
        if status == 1:
            return 1
        phaii1.write(".", phaii_filename, overwrite=clobber)

        phaii_dict.include(phaii1, name=det)

        # ----STEP 3 STOP FOR SINGLE DETECTOR -------#

        # ----STEP 4 START FOR SINGLE DETECTOR -------#
        # Lightcurve creation
        logger.info(f"Creating lightcurve for instrument {det}")
        lc = phaii1.to_lightcurve()

        # Rebin for plotting, then keep the span from the second bin's start
        # to the second-to-last bin's stop
        rebin_lc = lc.rebin(rebin_by_time, plotbin)
        rebin_lc = rebin_lc.slice(
                    phaii1.data.tstart[1], phaii1.data.tstop[-2])
        # Write to text file
        lc_filename = f"{fileprefix}_{formatted_date_time}_lc.txt"
        # Existence check
        status = clob_check(clobber, lc_filename, logger)
        if status:
            return 1
        lc_hdr = "Lo_edge, Hi_edge, Rate, Rate_uncertainty"
        lc_data = np.stack((rebin_lc.lo_edges, rebin_lc.hi_edges,
                           rebin_lc.rates, rebin_lc.rate_uncertainty), axis=1)
        np.savetxt(lc_filename, lc_data, fmt='%f',
                   header=lc_hdr, delimiter=',', newline='\n')

        # Pick this detector's cell in each (always 2-D) axes grid
        lc_axes = ax_lc[det_to_ax_index[det]]
        spec_axes = ax_spec[det_to_ax_index[det]]

        lcplot[det] = Lightcurve(data=rebin_lc,
                                 ax=lc_axes)
        # Add detector label
        lcplot[det].ax.text(0.9, 0.9, det,
                            horizontalalignment='center',
                            verticalalignment='center',
                            transform=lcplot[det].ax.transAxes)

        # Put filename as title
        lcplot[det].ax.set_title(file.name, fontdict={
                                 'fontsize': 8, 'fontweight': 'medium'})
        if phaii1.trigtime is not None:
            lcplot[det].ax.set_xlabel("Time Since Trigger Time (s)")

        lc_axes.set_ylim(0.9*np.amin(rebin_lc.rates),
                         1.1*np.amax(rebin_lc.rates))
        # ----STEP 4 STOP FOR SINGLE DETECTOR -------#

        # ----STEP 5 START FOR SINGLE DETECTOR -------#
        # Spectrum Creation
        logger.info(f"Creating spectrum for instrument {det}")
        spectrum = phaii1.to_spectrum()

        # Write to text file
        spec_filename = f"{fileprefix}_{formatted_date_time}_spec.txt"
        # Existence check
        status = clob_check(clobber, spec_filename, logger)
        if status:
            return 1
        spec_hdr = "Lo_edge, Hi_edge, Rate, Rate_uncertainty"
        spec_data = np.stack((spectrum.lo_edges, spectrum.hi_edges,
                             spectrum.rates, spectrum.rate_uncertainty), axis=1)
        np.savetxt(spec_filename, spec_data, fmt='%.3f',
                   header=spec_hdr, delimiter=',', newline='\n')

        # Create plot
        specplot[det] = Spectrum(data=spectrum,
                                 ax=spec_axes)
        # Add detector label
        specplot[det].ax.text(0.9, 0.9, det,
                              horizontalalignment='center',
                              verticalalignment='center',
                              transform=specplot[det].ax.transAxes)
        # Put filename as title
        specplot[det].ax.set_title(
            file.name, fontdict={'fontsize': 8, 'fontweight': 'medium'})

        spec_axes.set_xlim(0.9*np.amin(spectrum.lo_edges),
                           1.1*np.amax(spectrum.hi_edges))

        # ----STEP 5 STOP FOR SINGLE DETECTOR -------#
    # Stop loop through detectors

    # Make figures look better by ensuring no overlap
    fig_lc.tight_layout()
    fig_spec.tight_layout()

    # Save figures to file
    logger.info("Saving graphs")
    # Existence check
    lcgraph_filename = f'{outprefix}_{formatted_date_time}_lightcurves.png'
    status = clob_check(clobber, lcgraph_filename, logger)
    if status == 1:
        return 1
    fig_lc.savefig(lcgraph_filename)
    # Existence check
    specgraph_filename = f'{outprefix}_{formatted_date_time}_binned_spectra.png'
    status = clob_check(clobber, specgraph_filename, logger)
    if status == 1:
        return 1
    fig_spec.savefig(specgraph_filename)

    return 0


def _read_ebounds_columns(table):
    """Return in-memory copies of the E_MIN and E_MAX columns of *table*."""
    return np.array(table['E_MIN']), np.array(table['E_MAX'])


def _read_rebin_columns(table):
    """Return the CHANMIN, CHANMAX and REBINNING columns of *table* as int arrays."""
    return (table['CHANMIN'].astype(int),
            table['CHANMAX'].astype(int),
            table['REBINNING'].astype(int))

# Retrieve column index number for a given column name


def get_colidx(header, column_name):
    """Return the column number N of the first ``TTYPEN`` keyword equal to *column_name*.

    *header* only needs ``keys()`` and item access.  The number is taken from
    the characters after the ``TTYPE`` prefix.  ``None`` is returned when no
    ``TTYPE`` keyword holds *column_name*.
    """
    matching_numbers = (
        int(keyword[5:])
        for keyword in header.keys()
        if keyword.startswith('TTYPE') and header[keyword] == column_name
    )
    # Only the first match is ever evaluated, exactly like a loop with break
    return next(matching_numbers, None)

# Simple function to check for duplicates in an array
# Returns boolean: True if duplicate present


def has_duplicate(arr):
    """Return True as soon as an item of *arr* repeats an earlier one, else False.

    Items must be hashable; iteration stops at the first repeat.
    """
    unique_items = set()
    for element in arr:
        size_before = len(unique_items)
        unique_items.add(element)
        # The set did not grow, so this element was already present
        if len(unique_items) == size_before:
            return True
    return False


# Return the CALDB search expression based on CAL_CBD keywords and values
def create_query(calexpr, header, chans, logger):
    """Build the '&'-joined CALDB search expression described by *calexpr*.

    Parameters
    ----------
    calexpr : str
        Comma-separated keywords.  Whitespace around a keyword is ignored and
        empty entries are skipped.  The keyword ``CHAN`` (any case) is not
        looked up in the header; it contributes ``chan=<chans>``.
    header : mapping with ``get``
        Source of the value of every other keyword.
    chans : int
        Number of channels used for the ``CHAN`` keyword.
    logger : logging.Logger
        Receives the error message when the expression cannot be built.

    Returns
    -------
    str or int
        The expression, e.g. ``'datamode=EVENT&chan=16'``.  The integer 1 is
        returned (after logging) when a keyword is missing from *header* or
        when *calexpr* contains no keyword at all; callers detect failure by
        checking that the result is not a string.
    """
    terms = []
    for raw_key in calexpr.split(','):
        key = raw_key.strip()
        if not key:
            # Tolerate stray commas such as "DATAMODE,,CHAN" or a trailing one
            continue
        if key.upper() == 'CHAN':
            terms.append(f'chan={chans}')
            continue
        val = header.get(key)
        if val is None:
            logger.error(f'Header keyword {key} not found.')
            return 1
        terms.append(f'{key.lower()}={val}')

    if not terms:
        logger.error('No CALDB query keywords given in calexpr.')
        return 1
    return '&'.join(terms)

# Function to return edges based upon rebinning file data


def get_edges(chanmin, chanmax, rebin):
    """Return the channel edges described by the rows of a rebinning table.

    Each row ``(chanmin[k], chanmax[k], rebin[k])`` contributes
    ``np.arange(chanmin[k], chanmax[k], rebin[k])`` when its rebin factor is
    not 1, and only ``chanmin[k]`` when the factor is 1.  One closing edge,
    ``chanmax[-1] + 1``, is appended after the last row.

    NOTE(review): a row with factor 1 that spans several channels yields a
    single edge (one bin) -- confirm this is the intended table convention.
    """
    # Start from an empty integer array so the result dtype follows the
    # same promotion as repeated np.append calls would give
    pieces = [np.array([], dtype='int')]
    for lower, upper, step in zip(chanmin, chanmax, rebin):
        if step == 1:
            pieces.append(np.ravel(lower))
        else:
            pieces.append(np.arange(lower, upper, step))

    pieces.append(np.ravel(chanmax[-1] + 1))
    return np.concatenate(pieces)


def main():
    """Command-line entry point: parse parameters, set up logging, run bcrebevt.

    Parameters without a default that are missing from the command line are
    prompted for interactively until a value of the right type is entered.
    When ``--log`` is yes, a file handler writing ``bcrebevt_<date>.log`` is
    attached to the 'bcrebevt' logger.

    Returns
    -------
    int
        The status returned by ``bcrebevt`` (0 on success, 1 on failure).

    Raises
    ------
    SystemExit
        With status 1 when the chatter value is outside 0-5 (argparse may also
        exit on malformed command lines).
    """
    desc = """
    Program to rebin data based upon an input time bin value.
    """

    parser = argparse.ArgumentParser(
        description=desc, formatter_class=RawTextHelpFormatter)
    parser.add_argument('--infile', action="store", type=str,
                        help='Input EVENT or ATD science file or @file')
    parser.add_argument('--outprefix', type=str,
                        help='Prefix for output files')
    parser.add_argument('--timecol', type=str,
                        help='Time column of input file')
    parser.add_argument('--rebincol', type=str,
                        help='Column to be rebinned (e.g. COUNTS, PHA, etc...)')
    parser.add_argument('--chanbin', type=int,
                        help='Number of output channels [EVENT 16 or 64, ATD 16]')
    parser.add_argument('--ebounds', type=str,
                        help='FITS file containing EBOUNDS or CALDB',
                        default='CALDB')
    parser.add_argument('--rebin', type=str,
                        help='FITS file containg rebinning or CALDB',
                        default='CALDB')
    parser.add_argument('--timebin', type=float,
                        help='Time bin width in seconds',
                        default=0.032)
    parser.add_argument('--plotbin', type=float,
                        help='Bin width in seconds used for plotting',
                        default=0.128)
    parser.add_argument('--refdata', type=str,
                        help='Reference data directory',
                        default='REFDATA')
    parser.add_argument('--calexpr', type=str,
                        help='CALDB EBOUNDS query keywords',
                        default='DATAMODE,CHAN')
    parser.add_argument('--chatter', type=int,
                        default=1, help='Level of verboseness.')
    parser.add_argument('--log', type=str,
                        default='no', help='Create log file?')
    parser.add_argument('--clobber', type=str,
                        default='no', help='Overwrite existing output file(s)?')

    args = parser.parse_args()

    # Look up the argparse action of each option by its "--name" spelling
    actions_by_option = {
        action.option_strings[0]: action
        for action in parser._actions if action.option_strings
    }
    # Prompt the user for every parameter that was not provided
    for param, val in list(vars(args).items()):
        if val is not None:
            continue
        action = actions_by_option.get(f'--{param}')
        if action is None:
            continue
        # Options declared without a type are read as plain strings
        arg_type = action.type or str
        while True:
            answer = input(f"{action.help}: ")
            try:
                converted = arg_type(answer)
            except (TypeError, ValueError):
                print(
                    f"{param} must be a valid {arg_type.__name__}. Try again.")
            else:
                setattr(args, param, converted)
                break

    # Assign values to be passed to main routine
    infile = args.infile
    outprefix = args.outprefix
    chanbin = args.chanbin
    timecol = args.timecol
    rebincol = args.rebincol
    ebounds = args.ebounds
    rebin = args.rebin
    timebin = args.timebin
    plotbin = args.plotbin
    refdata = args.refdata
    calexpr = args.calexpr
    chatter = args.chatter
    # yes/no flags, accepted in any letter case
    log = str(args.log).strip().lower() in ('yes', 'y')
    clob = str(args.clobber).strip().lower() in ('yes', 'y')

    # Set up logger
    logger = logging.getLogger('bcrebevt')
    # Check that chatter is reasonable
    if chatter < 0 or chatter > 5:
        print(f'Chatter value {chatter} is not valid. Must be 0-5.')
        # Invalid input is a failure: exit with a non-zero status
        sys.exit(1)
    # Create dictionary to set level according to chatter value
    chat_level = {0: 40, 1: 20, 5: 10}
    # Chatter values 1 through 4 are all the same
    if 1 <= chatter <= 4:
        chatter = 1

    # If log file is wanted
    if log:
        # Name the log file after the current date and time
        now = datetime.datetime.now()
        date_time = now.strftime("%d%m%Y_%H%M%S")
        log_filename = f'bcrebevt_{date_time}.log'
        file_handler = logging.FileHandler(log_filename)
        # Create a formatter and add it to the handler
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(chat_level[chatter])
        logger.addHandler(file_handler)

    # Pass parameters to main routine and hand its status back to the caller
    return bcrebevt(infile, outprefix, timecol, rebincol, chanbin, timebin,
                    ebounds, rebin, plotbin, refdata, calexpr, chatter, clob)

if __name__ == '__main__':
    # Propagate the task status to the shell (a None result still exits with 0)
    sys.exit(main())
