#
#
#  Requirement
#   The task requires as input files
#       a) the map file output of bcoccult
#       b) the calibration file stored in CALDB, which contains the counts simulated for different
#       spectral models for each of the detectors and the 3072 all-sky pixels corresponding to a
#       HEALPix representation with nside = 16.
#       c) the array science file output of bcevtreb
#       d) the gti file output of bctimebst
#       e) the attitude file appropriate to the observation
#   and the energy range over which the position is calculated.
#
#   Using the orbit file and gti it creates an output file with a map in the primary
#   header in AIT projection and a first extension with a HEALPix representation. The values in the map
#   represent the part of the sky occulted by the earth (pixel set to 0) or not
#   occulted (pixel set to 1). The size of the map is internally calculated from
#   the parameter nside, that is 12 x (nside)**2.
#   The task uses the simulated value for nside in CALDB; if the size of the occultation map is different,
#   the task makes the occultation map match the simulated number of pixels.
#   The task needs to have set the location of the REFDATA directory that contains
#   the headers of the output file.
#
# Step 1: Check input parameters and load data from input files.
#
# Step 2: Divide the data into source and background counts using the gti. Using the background interval,
#         fit the data to find the background counts that apply to the burst interval (fitted interval)
#
# Step 3: Select from the attitude all quaternions that are within the burst time interval 
#
# Begin loop over simulation files and using all detectors
# Step 4: Load individual simulation data file from CALDB 
#
# Step 5: Using the CALDB simulation file and attitude.  From a HEALPix map with nside equal to
#         that of the CALDB file, take the associated theta and phi for a given HEALPix pixel
#         and use the quaternion in the attitude to calculate a new theta and phi (ntheta nphi).
#         Read from the CALDB the rates corresponding to the ntheta and nphi and the nearest 4 pixels to
#         calculate a weighted average. This average is placed in the pixel number corresponding
#         to the unrotated HEALPix map. The procedure is done for each quaternion included in the burst
#         interval. The maps for each quaternion are multiplied by the duration of the quaternion and
#         all quaternion maps are summed together so that each pixel contains counts.
#         This procedure is done for each detector.
#         time theta phi index ntheta nphi  value1 value2 value3 value4 valueaver detector
#
# Step 6: Using the total source counts, background counts for the burst interval for each detector
#         within a specific energy range and the expected counts from CALDB adjusted for the attitude
#        (Step 5), calculate
#         the log-likelihood at each HEALPix pixel position considering all detectors. For example, if
#         the number of sky pixels is 3072, the likelihood is calculated for each pixel using the source and
#         background counts from the 4 detectors in a given pixel. The array calculated contains the
#         likelihood probability in each pixel considering all detectors
# 
# Step 7: Sum the expected counts in the maps calculated in Step 5 and pick the detector with the second
#         highest expected counts.
#
# Step 8: Apply the occultation map to the log-likelihood map by multiplying pixel by pixel. The
#         occultation map is resized to the likelihood map if larger.
#         The highest value within this map is the "best position location" of the event for that
#         spectral simulated model.
#
# End loop over simulation files
#
# Step 9: Choose log-likelihood map out of Step 8 with highest likelihood and write a FITS
#         file with two extensions:
#         Primary: Image in Aitoff-galactic projection 
#         HEALPIX: HEALPix data map 
#
#         Dump all into one file x N extensions where each corresponds to a model
#         
#
# Step 10:Create figure displaying highest log-likelihood map showing 1-,2-, and
#         3-sigma contour levels. Write figure to file.
#


import sys
import os
import argparse
from argparse import RawTextHelpFormatter
import logging

import astropy.io.fits as fits
from astropy.time import Time
import astropy.units as u
import datetime


from astropy.coordinates.representation import PhysicsSphericalRepresentation
from astropy.coordinates import SkyCoord
from astropy_healpix import HEALPix
from astropy.utils.exceptions import AstropyWarning

from gdt.core.coords import SpacecraftFrame, Quaternion
from gdt.core.background.fitter import BackgroundFitter
from gdt.core.background.binned import Polynomial
from gdt.core.collection import DataCollection
from gdt.core.plot.sky import EquatorialPlot, get_lonlat
from gdt.core.healpix import HealPixLocalization
from gdt.core.plot.plot import SkyLine, SkyPolygon
from gdt.core.binning.binned import combine_into_one


from pathlib import Path

import numpy as np
from scipy.special import gammaln, logsumexp

import matplotlib.pyplot as plt
import healpy as hp

from heasoftpy.burstcube.lib.phaii import MyPhaii
from heasoftpy.burstcube.lib.fast_norm_fit import FastNormFit
from heasoftpy.burstcube.lib.io import get_filelist, file_extn, clob_check
from heasoftpy.burstcube.lib.caldbpy import caldbpy
from heasoftpy.burstcube.lib.pixmaps import pix_img, pix_data, HealpixPrimHeader
from heasoftpy.burstcube.bcversion_task import bcversion

import warnings

# Silence every UserWarning for the whole process. The intent is to hide
# Matplotlib's user-facing warnings, but the filter is not limited to Matplotlib:
# any UserWarning raised by any library is ignored as well.
warnings.filterwarnings("ignore", category=UserWarning)
# Likewise ignore everything raised as an AstropyWarning.
warnings.simplefilter('ignore', category=AstropyWarning)



def bcfindloc(infile, outprefix, burstgti, attitude, occultation, emin=50.0, emax=300.0,
              simloc='CALDB', extension='all', poly=3, refdata='REFDATA',
              chatter=1, clobber=False, phaii_dict=None):
    """Create localization maps for a burst and write them to FITS and PNG files.

    For every requested extension of the simulation file, the simulated
    per-detector rate maps are rotated with the attitude quaternions covering the
    burst interval, weighted by the duration of each attitude sample, and compared
    with the measured source/background counts to build a probability map and a
    test-statistic (TS) map. Both maps are masked with the occultation map. All
    maps are written to ``{outprefix}_locmap.fits`` (ordered by decreasing maximum
    TS value) and the probability map of the best model is plotted to
    ``{outprefix}_locmap.png``.

    Parameters
    ----------
    infile : str
        List of input FITS files or @file containing the list. Only read when
        ``phaii_dict`` is None, but it must be a non-empty string in either case.
    outprefix : str
        Prefix of the output files; a blank value falls back to ``'temp'``.
    burstgti : str
        GTI FITS file (optionally with extension) holding the source time ranges,
        the ``BSTTOT`` keyword and, when a peak was found, a ``BKG`` extension.
    attitude : str
        FITS file (optionally with extension) with ``TIME`` and ``QPARAM`` columns.
    occultation : str
        Occultation HEALPix map in FITS format (column ``MAP_VALUE``).
    emin, emax : float
        Energy bounds (keV) used to select the counts.
    simloc : str
        FITS file with the simulated data for each detector, or ``'CALDB'`` to
        query CALDB for it.
    extension : str
        Comma separated extension numbers of the simulation file, or ``'all'``.
    poly : int
        Order of the polynomial used for the background fit.
    refdata : str
        Reference data directory, or ``'REFDATA'`` to use ``$LHEA_DATA``.
    chatter : int
        Verbosity, 0-5. Values 1 through 4 are treated alike.
    clobber : bool
        Passed to ``clob_check`` and used as the ``overwrite`` flag on write.
    phaii_dict : DataCollection, optional
        Already loaded PHAII objects keyed by detector; when given, ``infile`` is
        not read.

    Returns
    -------
    int
        0 on success, 1 when a parameter check, file check or clobber check fails.
    """

    # ----STEP 1 START------#
    # ----Outline----
    # check parameters and load files
    # Set logger parameters
    logger = logging.getLogger('bcfindloc')
    # Check that chatter is reasonable
    if chatter < 0 or chatter > 5:
        print(f'Chatter value {chatter} is not valid. Must be 0-5.')
        return 1
    # Create dictionary to set level according to chatter value
    chat_level = {0: 40, 1: 20, 5: 10}
    # Chatter values 1 through 4 are all the same for now
    if 1 <= chatter <= 4:
        chatter = 1
    # Set logger level
    logger.setLevel(chat_level[chatter])
    version = bcversion()
    logger.info(f'VERSION: {version}')
    # Begin checking all input parameters
    logger.info("Verifying input parameters")

    # Check params
    # ----Simple checks first----
    # Default outprefix if none given
    if outprefix.strip() == '':
        logger.warning('WARNING: Outprefix was blank. Defaulting to "temp"')
        outprefix = 'temp'

    # Check for empty strings
    filenames = [infile, burstgti, attitude, occultation]
    varnames = ['infile', 'burstgti', 'attitude', 'occultation']
    for file, var in zip(filenames, varnames):
        if file.strip() == '':
            logger.error(f'Input {var} filename is empty string')
            return 1

    # If default
    if refdata == 'REFDATA':
        # Check environmental setting
        if os.getenv('LHEA_DATA'):
            refdata = os.environ['LHEA_DATA'] + '/'
        else:
            logger.error("LHEA_DATA is not defined in current environment.")
            return 1
    else:
        # Check that directory exists
        if Path(refdata).exists() and Path(refdata).is_dir():
            # Catch if path does not end with /
            if refdata[-1] != '/':
                refdata = refdata + '/'
        else:
            logger.error(f"Directory {refdata} does not exist.")
            return 1
    # End refdata check

    # Start checking Erange
    try:
        erange = (float(emin), float(emax))
    except (TypeError, ValueError):
        logger.error(f"{emin},{emax} contains non-float values.")
        return 1
    # Compare the converted values so that numeric strings are ordered numerically
    if erange[0] >= erange[1]:
        logger.error(
            f"Lower bound for energy range {emin} is larger than or equal to higher bound {emax}.")
        return 1
    # End checking Erange

    # ---End Simple Checks----

    # ---More involved checks / Loading data---

    # If phaii_dict is not given, load files
    if phaii_dict is None:
        # infile was already verified to be non-empty above
        file_list = get_filelist(infile, logger)
        # If file not found
        if not isinstance(file_list, np.ndarray):
            return 1

        # Initialize array to hold instrument names
        detectors = []
        # Initialize dictionary to hold PHAII data
        phaii_dict = DataCollection()
        # Check that files exist and get instrument name
        for file in file_list:
            if not Path(file).exists():
                logger.error(f"File {file} not found.")
                return 1
            phaii = MyPhaii.open(file, refdata)
            det = phaii.detector
            detectors.append(det)
            trigtime = phaii.trigtime
            phaii_dict.include(phaii, det)
            telescope = phaii.headers[0]['TELESCOP']
        logger.info(f"Using files:{file_list}")
        # End checking input files
    else:
        # phaii_dict is given, retrieve detectors
        detectors = phaii_dict.items
        trigtime = phaii_dict.get_item(detectors[0]).trigtime
        telescope = phaii_dict.get_item(detectors[0]).headers[0]['TELESCOP']
    # Sort detectors alphabetically
    detectors.sort()

    # Begin loading start and stop times from GTI
    burstgti = burstgti.strip()
    burstgti_file, extn = file_extn(burstgti)
    # Becomes False when the GTI header reports that no peak was found
    peak_found = True
    if not Path(burstgti_file).exists():
        logger.error(f"GTI file {burstgti_file} not found.")
        return 1

    if trigtime is None:
        trigtime = 0.

    # Everything needed from the GTI file is extracted inside the context
    # manager so the file handle is released on every exit path.
    with fits.open(Path(burstgti_file)) as gti_hdu:
        # Test that extension exists
        if len(gti_hdu) < (int(extn) + 1):
            logger.warning(
                f"WARNING: No extension {extn} found for file {burstgti_file}")
            logger.warning("WARNING: Will default to first extension.")
            extn = 1
        source_data = gti_hdu[extn].data
        # Check header keywords to see if peak was found
        try:
            bsttot = gti_hdu[extn].header['BSTTOT']
        except KeyError:
            logger.error(
                f"Header keyword BSTTOT not found in {burstgti_file}.")
            return 1
        if bsttot < 0:
            logger.warning("WARNING: No peak was previously found.")
            logger.warning(
                "WARNING: Total time will be used and no background fitting will be done.")
            peak_found = False

        # Store source time range relative to the trigger time
        src_times = [(tstart - trigtime, tstop - trigtime)
                     for tstart, tstop in zip(source_data['START'], source_data['STOP'])]

        # Background intervals are only read when a peak was found
        bkg_times = []
        if peak_found:
            bkg_data = gti_hdu['BKG'].data
            bkg_times = [(tstart - trigtime, tstop - trigtime)
                         for tstart, tstop in zip(bkg_data['START'], bkg_data['STOP'])]
    # Stop loading start and stop times from GTI

    # Check that GTI are reasonable
    # Load time range for data
    time_range = phaii_dict.get_item(detectors[0]).time_range
    # Check source
    for times in src_times:
        if (times[1] < time_range[0]) or (times[0] > time_range[1]):
            logger.error("GTI for source not within data time range.")
            return 1
    # Check background
    if peak_found:
        for times in bkg_times:
            if (times[1] < time_range[0]) or (times[0] > time_range[1]):
                logger.error("GTI for background not within data time range.")
                return 1
    # End checking GTI

    # Begin checking attitude file and load attitude data
    attitude = attitude.strip()
    file, extn = file_extn(attitude)
    # Check that file exists
    if not Path(file).exists():
        logger.error(f"File {file} not found.")
        return 1

    with fits.open(Path(file)) as att_hdul:
        att_fits = att_hdul[extn]
        # Check that the columns TIME and QPARAM are present
        cols_needed = np.array(['TIME', 'QPARAM'])
        not_present = np.setdiff1d(cols_needed, att_fits.columns.names)
        if not_present.size != 0:
            logger.error(f"Needed columns not found: {not_present}")
            return 1

        # Load column data; both arrays are copies, so they do not depend on
        # the file staying open.
        att_data = att_fits.data
        att_time = att_data['TIME'] - trigtime
        att_quat = np.array(att_data['QPARAM'])

    # Check that attitude time covers src_time (i.e. time of burst)
    for times in src_times:
        if not (att_time[0] <= times[0] <= att_time[-1]) or not (att_time[0] <= times[1] <= att_time[-1]):
            logger.error("Attitude file does not cover burst time.")
            return 1
    # End checking attitude file

    # Begin checking/loading occultation file
    occultation = occultation.strip()
    file, extn = file_extn(occultation)
    # Check that file exists
    if not Path(file).exists():
        logger.error(f"File {file} not found.")
        return 1

    with fits.open(Path(file)) as occ_hdul:
        occ_fits = occ_hdul[extn]
        order = occ_fits.header['ORDERING']
        # Copy the healpix map data out of the file
        occ_map = np.array(occ_fits.data['MAP_VALUE'])
    # Convert to ring if needed
    if 'nest' in order.lower():
        occ_map = hp.reorder(occ_map, n2r=True)
    # End occultation loading

    # Begin checking simloc file
    if simloc == 'CALDB':
        if not os.getenv('CALDB'):
            logger.error("CALDB environment is not set.")
            return 1
        CALDB = os.environ['CALDB']

        # Query CALDB
        filepath, _ = caldbpy(
            telescope, 'CSA', '-', '-', 'simulation', 'now', 'now', '-')
        # Check if file was found
        if isinstance(filepath, np.ndarray):
            full_file = CALDB + "/" + filepath[0]
            # Check that file exists
            if not Path(full_file).exists():
                logger.error(f"Could not find file {full_file}")
                return 1
            simloc_file = full_file
        else:
            # If no file found
            logger.error("No simulation file found in CALDB")
            return 1
    else:
        simloc = simloc.strip()
        file, extn = file_extn(simloc)
        # Check that file exists
        if not Path(file).exists():
            logger.error(f"File {file} not found.")
            return 1
        simloc_file = Path(file)

    # Get number of extensions
    with fits.open(simloc_file) as hdulist:
        num_ext = len(hdulist)
    # End simloc check

    # Extension use check
    if extension == 'all':
        sim_extension = np.arange(1, num_ext)
    else:
        exten_str = extension.split(',')
        # Integer array: the numbers are used as HDU indices further down
        sim_extension = np.zeros(len(exten_str), dtype=int)
        # Check that extensions are reasonable
        for i, ext in enumerate(exten_str):
            try:
                ext_num = int(ext)
            except ValueError:
                logger.error(f'Extension {ext} is not an integer')
                return 1
            # Extension 0 is the primary HDU; valid choices are 1..num_ext-1
            if ext_num < 1 or ext_num > num_ext - 1:
                logger.error(f'Extension {ext_num} does not exist')
                return 1
            sim_extension[i] = ext_num
    if sim_extension.size == 0:
        logger.error(f'No simulation extension available in {simloc_file}')
        return 1
    # End extension check

    # End checking input files
    # ----STEP 1 STOP------#

    # ----STEP 2 START------#
    # For each detector, get source and background counts
    src_cnts = []
    bkg_cnts = []
    # Loop over all detectors
    # Get counts for source and background within signal range for each detector
    for det in detectors:

        phaii = phaii_dict.get_item(det)

        # Workaround for bug in GBM data tools.
        # The first and last bin have an incorrect exposure
        phaii = phaii.slice_time(
            [(phaii.data.tstart[1], phaii.data.tstop[-1])])

        # Slice by appropriate energy range
        phaii = phaii.slice_energy(erange)
        phaii = phaii.rebin_energy(combine_into_one)

        lightcurve = phaii.to_lightcurve(energy_range=erange)
        # Only the first source interval is used
        lc_signal = lightcurve.slice(*src_times[0])

        # Get total number of source counts in signal range
        src_cnts.append(np.sum(lc_signal.counts))

        # Fit the background to background times if peak was found
        if peak_found:
            backfitter = BackgroundFitter.from_phaii(phaii,
                                                     Polynomial,
                                                     time_ranges=bkg_times)

            backfitter.fit(order=poly)
            back_rates = backfitter.interpolate_bins(lc_signal.lo_edges,
                                                     lc_signal.hi_edges)

            bkg_cnts.append(int(back_rates.counts.sum()))

        else:
            bkg_cnts.append(0)

    # End getting source and background counts for each detector

    # ----STEP 2 STOP------#

    # ----STEP 3 START------#
    # Choose attitude quaternions based on burst time
    # Get indices within burst time
    start_ind = np.searchsorted(att_time, src_times[0][0])
    stop_ind = np.searchsorted(att_time, src_times[0][1])

    # Include the sample just before the burst start. Clamp at 0: when the burst
    # starts exactly on the first attitude sample, start_ind is 0 and a -1 lower
    # bound would select from the end of the array instead.
    first_ind = max(start_ind - 1, 0)

    # Select only values with those indices
    att_times = att_time[first_ind:stop_ind+1]
    quat = Quaternion(att_quat[first_ind:stop_ind+1])
    # Set start and end time appropriately
    att_times[0] = src_times[0][0]
    att_times[-1] = src_times[0][1]

    # Create durations array
    durations = att_times[1:] - att_times[:-1]
    # End setting up quaternions

    # ----STEP 3 STOP------#

    # Null log likelihood from src and bkg counts; it does not depend on the
    # simulated model, so it is computed once.
    # NOTE(review): when no peak was found every background count is 0 and
    # np.log(0) is -inf here -- confirm this is the intended behaviour.
    src_arr = np.asarray(src_cnts)
    bkg_arr = np.asarray(bkg_cnts)
    null_log_like = np.sum(
        src_arr * np.log(bkg_arr) - bkg_arr - gammaln(src_arr))

    # ----STEP 4 START------#
    # Create likelihood maps for each simulated rates map
    loc_maps = dict()
    max_TS_val = []
    for sim in sim_extension:
        # Open simulated data file extension once and copy out what is needed
        with fits.open(simloc_file) as sim_hdul:
            sim_hdu = sim_hdul[int(sim)]
            map_model = sim_hdu.header.get('CBD10001')
            # One rate map per detector, keyed (and ordered) by detector name
            rates = {det: np.array(sim_hdu.data[det]) for det in detectors}
        if map_model is None:
            map_model = f'{simloc_file}+{sim}'

        # ----STEP 4 STOP------#

        # ----STEP 5 START------#
        npix = rates[detectors[0]].size
        nside = hp.npix2nside(npix)

        # Location for all pixel centers
        theta, phi = hp.pix2ang(nside, range(npix))

        coords = PhysicsSphericalRepresentation(theta=theta * u.rad,
                                                phi=phi * u.rad,
                                                r=1)
        coords = SkyCoord(coords, frame='icrs')

        # Get expectation rates for each detector based on attitude and duration
        expectation_arr = np.zeros((len(rates), npix))
        for duration, sample_quat in zip(durations, quat[1:]):
            exprates = get_rates(rates, coords, sample_quat)
            for i, expect in enumerate(exprates.values()):
                expectation_arr[i] += expect*duration

        # Interpolate values of pixel centers and get expectation maps
        hpx_grid = HEALPix(nside=nside, order='ring', frame='icrs')
        coords = hpx_grid.healpix_to_skycoord(range(npix))
        coords = coords.transform_to('icrs')
        coords = coords.represent_as(PhysicsSphericalRepresentation)

        expectation = hp.get_interp_val(
            expectation_arr, coords.theta.rad, coords.phi.rad)

        # ----STEP 5 STOP------#

        # ----STEP 6 START------#
        # Do normalized fitting to get probability map
        log_like = get_loglike(
            src_cnts, bkg_cnts, expectation, coords, null_log_like)

        # ----STEP 6 STOP------#

        TS_map = 2*(log_like - null_log_like)

        # ----STEP 7 START------#
        # Set up prior based on effective area of second best detector
        # Use the effective area (proportional to the expectation) of the
        # second best detector. This simulated the trigger condition of at least 2
        # detectors above threshold.
        # Rows are sorted in ascending order, so the second best is row -2
        # (the same row as index 2 when there are 4 detectors).
        aeff = np.sort(np.array(expectation),
                       axis=0)[-2, :]

        prior = np.power(aeff, 3/2)

        # ----STEP 7 STOP------#

        # ----STEP 8 START------#
        # Make sure maps are the same size. The regraded map is kept in its own
        # variable so every extension resamples the map read from file.
        occ_sim = hp.ud_grade(occ_map, nside, dtype='int')

        # Apply occultation
        prior *= occ_sim
        TS_map *= occ_sim

        # Prevent error message for log(0)
        prior[prior == 0] = np.finfo(prior.dtype).tiny

        # Sum log of prior
        log_like += np.log(prior)

        prob_map = np.exp(log_like - logsumexp(log_like))

        TS_map = np.exp(TS_map - logsumexp(TS_map))

        loc_maps[sim] = {'prob_map': prob_map, 'max_prob': np.amax(prob_map),
                         'ts_map': TS_map, 'max_ts': np.amax(TS_map),
                         'model': map_model}

        max_TS_val.append(np.amax(TS_map))

    # ----STEP 8 STOP------#
    # End creating localization maps for each simulated data

    # ----STEP 9 START------#
    # Loop through each map used and save each to fits file

    # HDU list that collects every map
    temp_hdul = fits.HDUList()

    # Create primary header
    primhdr = HealpixPrimHeader(refdata)

    # Scrub input Primary header for any keyword values.
    # `phaii` is the (sliced) PHAII of the last detector handled in STEP 2.
    primhdr = copyheader(phaii.headers['PRIMARY'], primhdr)
    primhdr['INSTRUME'] = 'CSA'
    # Set up primary hdu
    primary = fits.PrimaryHDU(header=primhdr)

    temp_hdul.append(primary)

    # Order the extensions by decreasing maximum TS_map value
    sort = np.flip(np.argsort(max_TS_val))
    sim_extension = sim_extension[sort]

    # Duration of the (first) source interval
    exposure = src_times[0][1] - src_times[0][0]

    for sim in sim_extension:

        # Get map and model
        prob_map = loc_maps[sim]['prob_map']
        ts_map = loc_maps[sim]['ts_map']
        model = loc_maps[sim]['model']

        logger.info(f"Creating FITS file for {model}")
        # Index 0 is the probability map, index 1 the TS map
        for i, sky_map in enumerate([prob_map, ts_map]):
            # Get RA, DEC of max for this map likelihood
            max_pix = np.argmax(sky_map)
            npix = sky_map.size
            nside = hp.npix2nside(npix)

            # Theta and phi in radians
            theta, phi = hp.pix2ang(nside, max_pix)

            # Convert theta and phi to SkyCoord object in 'ICRS' frame
            sky_coord = SkyCoord(phi, 0.5 * np.pi - theta,
                                 frame='icrs', unit='rad')

            # Extract the RA and DEC in degrees
            ra = sky_coord.ra.deg
            dec = sky_coord.dec.deg
            # Extract the GLAT and GLON in degrees
            glat = sky_coord.galactic.b.degree
            glon = sky_coord.galactic.l.degree

            # Save map as healpix FITS file
            logger.info("Creating image extension")
            img_hdu = pix_img(sky_map, nside, False, refdata)

            # Scrub input ARRAY_PHA header for any keyword values
            img_hdu.header = copyheader(phaii.headers['ARRAY_PHA'], img_hdu.header)

            # Edit header keywords
            update_maphdr(img_hdu.header, version, sky_map, ra, dec, np.rad2deg(theta),
                          np.rad2deg(phi), exposure, emin, emax, model, glat, glon)

            # Create Healpix map data extension
            logger.info("Creating Healpix map extension")
            hpixmap = pix_data(sky_map, nside, 'ring', 'icrs', refdata)

            # Scrub input ARRAY_PHA header for any keyword values
            hpixmap.header = copyheader(phaii.headers['ARRAY_PHA'], hpixmap.header)

            # Edit header keywords
            update_maphdr(hpixmap.header, version, sky_map, ra, dec, np.rad2deg(theta),
                          np.rad2deg(phi), exposure, emin, emax, model, None, None, max_pix)

            if i == 0:
                img_hdu.header['EXTNAME'] = 'SKY_PROB'
                hpixmap.header['EXTNAME'] = 'HEALPIX_PROB'
                img_hdu.header['MAPTYPE'] = 'PROB'
                hpixmap.header['MAPTYPE'] = 'PROB'
            else:
                img_hdu.header['EXTNAME'] = 'SKY_TS'
                hpixmap.header['EXTNAME'] = 'HEALPIX_TS'
                img_hdu.header['MAPTYPE'] = 'LIKE'
                hpixmap.header['MAPTYPE'] = 'LIKE'

            temp_hdul.append(img_hdu)
            temp_hdul.append(hpixmap)

    # Write file containing all maps if any available
    if len(temp_hdul) != 0:
        map_fits_name = f'{outprefix}_locmap.fits'
        logger.info(f'Writing file {map_fits_name}')
        status = clob_check(clobber, map_fits_name, logger)
        if status == 1:
            return 1
        temp_hdul.writeto(map_fits_name, overwrite=clobber, checksum=True)

    # ----STEP 9 STOP------#

    # ----STEP 10 START------#
    # Save figure of the probability map with the highest maximum TS value
    logger.info('Creating localization plot')
    get_locplot(loc_maps[sim_extension[0]]['prob_map'])
    plot_name = f'{outprefix}_locmap.png'
    logger.info(f'Writing file {plot_name}')
    status = clob_check(clobber, plot_name, logger)
    if status == 1:
        return 1
    # savefig writes the current pyplot figure
    plt.savefig(plot_name)

    # ----STEP 10 STOP------#

    return 0

# Check that header keywords for time are present and
# Return epoch as astropy Time object given a header


def return_epoch(header):
    """Return the reference epoch described by a header as an astropy Time.

    The header must provide ``MJDREFI``, ``MJDREFF`` and ``TIMESYS``. The epoch
    is ``MJDREFI + MJDREFF`` in MJD format, on the time scale named by
    ``TIMESYS`` (lower-cased).

    Returns the ``Time`` object, or the integer 1 (after printing the missing
    keywords) when any of the three keywords is absent.
    """
    required = np.array(['MJDREFI', 'MJDREFF', 'TIMESYS'])
    available = np.array(list(header.keys()))

    # Required keywords that the header does not carry
    missing = np.setdiff1d(required, available)
    if missing.size:
        print("Keywords not found:", missing)
        return 1

    reference_mjd = header['MJDREFI'] + header['MJDREFF']
    time_scale = header['TIMESYS'].lower()
    return Time(reference_mjd, format='mjd', scale=time_scale)


# Edit header keywords that pertain to best location for burst
def update_maphdr(header, version, locmap, ra, dec, theta, phi, exposure, emin, emax,
                  model, glat, glon, maxpix=None):
    """Write the best-location keywords of a localization map into ``header``.

    ``header`` is modified in place; nothing is returned. The angular values
    (``ra``, ``dec``, ``theta``, ``phi``, ``glat``, ``glon``) are rounded to one
    decimal. ``GLAT_OBJ``/``GLON_OBJ`` are written only when both ``glat`` and
    ``glon`` are given, and ``PIXELOBJ`` only when ``maxpix`` is given.
    """
    # Keywords that are always written, in output order
    cards = [
        ('INSTRUME', 'CSA'),
        ('DATA_MIN', np.amin(locmap)),
        ('DATA_MAX', np.amax(locmap)),
        ('MODELSIM', model),
        ('RA_OBJ', round(ra, 1)),
        ('DEC_OBJ', round(dec, 1)),
        ('THET_OBJ', round(theta, 1)),
        ('PHI_OBJ', round(phi, 1)),
        ('EXPOSURE', exposure),
        ('E_MIN', emin),
        ('E_MAX', emax),
        ('SOFTVER', version),
    ]

    # Galactic coordinates need both components
    if not (glat is None or glon is None):
        cards.append(('GLAT_OBJ', round(glat, 1)))
        cards.append(('GLON_OBJ', round(glon, 1)))

    if maxpix is not None:
        cards.append(('PIXELOBJ', maxpix))

    for keyword, value in cards:
        header[keyword] = value


# Copy basic info from one header to another


def copyheader(fromheader, toheader):
    """Copy values for keywords that both headers share and return ``toheader``.

    Only keywords that already exist in ``toheader`` are overwritten; no new
    keyword is added. ``EXTNAME`` is never copied, and neither is any keyword
    whose first five characters are TFORM, TTYPE, TUNIT or HDUCL.
    ``toheader`` is modified in place.
    """
    # Five-character keyword prefixes that must not be carried over
    skipped_prefixes = ('TFORM', 'TTYPE', 'TUNIT', 'HDUCL')

    for keyword, value in fromheader.items():
        if keyword[:5] in skipped_prefixes:
            continue
        if keyword == 'EXTNAME' or keyword not in toheader.keys():
            continue
        toheader[keyword] = value

    return toheader

# Update DATE keyword in headers


def updateDATE(hdu_list):
    """Set the DATE keyword of every HDU that already has one to the current UTC time.

    The value is written as ``YYYY-MM-DDTHH:MM:SS``. HDUs whose header has no
    DATE keyword are left untouched. Headers are modified in place and nothing
    is returned.
    """
    # Take the timestamp from a timezone-aware UTC clock. A naive
    # datetime.now() is local wall-clock time, which is not UTC on hosts
    # configured with another time zone.
    stamp = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S')
    for hdu in hdu_list:
        header = hdu.header
        if 'DATE' in header:
            header['DATE'] = stamp

# Get expected rates given dictionary holding Healpixmap arrays, coords, and attitude


def get_rates(hmap_dict, coords, attitude):
    """Interpolate every detector map at ``coords`` as seen in the spacecraft frame.

    ``coords`` is transformed to a ``SpacecraftFrame`` built from the
    ``attitude`` quaternion and expressed in physics-spherical form; each map in
    ``hmap_dict`` is then sampled with ``hp.get_interp_val`` at the resulting
    theta/phi (radians).

    Returns a dict with the same keys as ``hmap_dict`` holding the interpolated
    values.
    """
    # Sky positions expressed in the spacecraft frame for this attitude
    spacecraft_frame = SpacecraftFrame(quaternion=attitude)
    in_spacecraft = coords.transform_to(spacecraft_frame)
    spherical = in_spacecraft.represent_as('physicsspherical')

    theta_rad = spherical.theta.rad
    phi_rad = spherical.phi.rad

    # One interpolated expectation map per detector
    expectation_map = {}
    for det, rate_map in hmap_dict.items():
        expectation_map[det] = hp.get_interp_val(rate_map, theta_rad, phi_rad)

    return expectation_map

# Return log likelihood probability map


def get_loglike(src_cnts, bkg_cnts, expectation, coords, null_log_like):
    """Return the log-likelihood map obtained from a normalization fit per position.

    For every index of ``coords`` a ``FastNormFit`` is solved with the source
    counts, the background counts and the column of ``expectation`` belonging to
    that index (first axis of ``expectation`` is kept whole). The returned array
    has the shape of ``coords`` and holds ``ts / 2 + null_log_like``.

    A failed fit is reported with ``print`` and its value is still stored.
    """
    fitter = FastNormFit()
    grid_shape = coords.shape

    ts = np.empty(grid_shape)

    for idx in np.ndindex(grid_shape):
        # All entries along the first axis, at this sky position
        expected_here = expectation[(slice(None),) + idx]

        ts_i, _norm, _norm_err, status = fitter.solve(src_cnts, bkg_cnts, expected_here)

        if status:
            print(f"Fit failed at coordinate {coords[idx]}")

        ts[idx] = ts_i

    return ts/2 + null_log_like

# Return localization plot figure


def get_locplot(local_map, gradient=False):
    """Return an EquatorialPlot showing ``local_map`` with confidence contours.

    The map is drawn as a heat map on a grid whose spacing is close to the pixel
    size of the map, and confidence regions are added for the levels 0.997,
    0.955 and 0.687. With ``gradient`` True the regions are drawn as black
    outlines (``SkyLine``); otherwise as translucent blue polygons
    (``SkyPolygon``).
    """
    loc_map = HealPixLocalization.from_data(local_map)

    # Approximate pixel side in degrees, used to size the plotting grid
    pixel_side_deg = np.sqrt(
        hp.nside2pixarea(hp.npix2nside(local_map.size), degrees=True))
    numpts_az = int(np.floor(360.0 / pixel_side_deg))
    numpts_zen = int(np.floor(180.0 / pixel_side_deg))

    prob_grid, ra_grid, dec_grid = loc_map.prob_array(numpts_ra=numpts_az,
                                                      numpts_dec=numpts_zen)

    eqplot = EquatorialPlot()
    eqplot._posterior = eqplot.plot_heatmap(prob_grid, ra_grid, dec_grid)

    # NOTE(review): 0.687 is kept as found; the usual 1-sigma level is 0.683 -- confirm.
    for clevel in (0.997, 0.955, 0.687):
        paths = loc_map.confidence_region_path(clevel, numpts_ra=numpts_az,
                                               numpts_dec=numpts_zen)

        # Longitude/latitude of every path in the plot's frame
        lonlats = [
            get_lonlat(
                SkyCoord(path[:, 0], path[:, 1], frame='gcrs', unit='deg')
                .transform_to(eqplot._astropy_frame))
            for path in paths
        ]

        # if we plotted the gradient, then only plot the unfilled contours,
        # otherwise, plot the stacked, filled contours
        for i, (lon, lat) in enumerate(lonlats):
            if gradient:
                contour = SkyLine(lon.value, lat.value, eqplot.ax,
                                  frame=eqplot._frame, color='black',
                                  alpha=0.7, linewidth=2,
                                  flipped=eqplot._flipped)
            else:
                contour = SkyPolygon(lon.value, lat.value, eqplot.ax,
                                     frame=eqplot._frame, color='blue',
                                     alpha=None, face_alpha=0.3,
                                     flipped=eqplot._flipped)
            eqplot._clevels.include(contour, f'{clevel}_{i}')

    return eqplot


def _is_yes(value):
    """Return True when ``value`` is an affirmative yes/no answer.

    Strings are compared case-insensitively, ignoring surrounding
    whitespace, against 'yes' and 'y'. Non-string values are compared
    as-is, so a literal ``True`` is also accepted.
    """
    if isinstance(value, str):
        value = value.strip().lower()
    return value in ('yes', 'y', True)


def _prompt_for_missing(parser, args):
    """Interactively ask for every parsed parameter that is still None.

    For each missing parameter the matching ``--name`` option is looked up
    on ``parser``; its help text is used as the prompt and its ``type`` as
    the converter. The prompt repeats until the converter accepts the
    answer, and the converted value is stored back on ``args``.
    """
    # Snapshot the items: setattr() below writes into the same namespace.
    for param, val in list(vars(args).items()):
        if val is not None:
            continue
        flag = f'--{param}'
        for action in parser._actions:
            if not action.option_strings or action.option_strings[0] != flag:
                continue
            # argparse stores type=None when no converter was given;
            # treat that as a plain string instead of calling None.
            arg_type = getattr(action, 'type', None) or str
            type_name = getattr(arg_type, '__name__', str(arg_type))
            while True:
                answer = input(f"{action.help}: ")
                try:
                    converted = arg_type(answer)
                except (ValueError, TypeError, argparse.ArgumentTypeError):
                    # Only conversion failures trigger a retry; anything
                    # else propagates instead of being silently swallowed.
                    print(f"{param} must be a valid {type_name}. Try again.")
                    continue
                setattr(args, param, converted)
                break


def _build_parser():
    """Create the command-line parser for the localization tool."""
    desc = """
    Program to create localization maps
    """

    parser = argparse.ArgumentParser(
        description=desc, formatter_class=RawTextHelpFormatter)
    parser.add_argument('--infile', action="store", type=str,
                        help="List of input FITS files or @file containing list")
    parser.add_argument('--outprefix', type=str,
                        help='Prefix of output files')
    parser.add_argument('--burstgti', type=str,
                        help='GTI FITS file containing time ranges for source and background ')
    parser.add_argument('--attitude', type=str,
                        help='FITS file containing attitude information')
    parser.add_argument('--occultation', type=str,
                        help='Occultation Healpix map in FITS format')
    parser.add_argument('--emin', type=float,
                        help='Minimum energy bound (keV)',
                        default=50.)
    parser.add_argument('--emax', type=float,
                        help='Maximum energy bound (keV)',
                        default=300.)
    parser.add_argument('--simloc', type=str,
                        help='FITS file containing simulated data for each detector or CALDB',
                        default='CALDB')
    parser.add_argument('--extension', type=str,
                        help="Extension(s) of simloc file to use ('1,2,3...') or 'all'",
                        default='all')
    parser.add_argument('--poly', type=int,
                        help='Order of polynomial for background fitting',
                        default=3)
    parser.add_argument('--refdata', type=str,
                        help='Reference data directory',
                        default='REFDATA')
    parser.add_argument('--chatter', type=int,
                        default=1, help='Level of verboseness.')
    parser.add_argument('--log', type=str,
                        default='no', help='Create log file?')
    parser.add_argument('--clobber', type=str,
                        default='no', help='Overwrite existing output file(s)?')
    return parser


def main():
    """Command-line entry point: parse parameters and run ``bcfindloc``.

    Parameters without a default that were not given on the command line
    are requested interactively. ``--chatter`` must lie in 0-5; values 1-4
    are collapsed to 1 before being used and passed on. When ``--log`` is
    affirmative, a timestamped ``bcfindloc_<date>_<time>.log`` file handler
    is attached to the logger.

    Raises:
        SystemExit: With status 1 when ``--chatter`` is outside 0-5 (and as
            raised by argparse itself on malformed command lines).
    """
    parser = _build_parser()
    args = parser.parse_args()
    # Prompt the user for any required value that was not provided.
    _prompt_for_missing(parser, args)

    chatter = args.chatter
    log = _is_yes(args.log)
    clob = _is_yes(args.clobber)

    # Set up logger
    # NOTE(review): the logger is named 'bcrebevt' while the log file is
    # named 'bcfindloc_*'; confirm this matches the logger that bcfindloc()
    # actually writes to.
    logger = logging.getLogger('bcrebevt')

    # Check that chatter is reasonable
    if chatter < 0 or chatter > 5:
        print(f'Chatter value {chatter} is not valid. Must be 0-5.')
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)
    # Map chatter value to logging level.
    chat_level = {0: logging.ERROR, 1: logging.INFO, 5: logging.DEBUG}
    # Chatter values 1 through 4 are all the same
    if 1 <= chatter <= 4:
        chatter = 1

    # If log file is wanted
    if log:
        # Name the log file after the current date and time.
        date_time = datetime.datetime.now().strftime("%d%m%Y_%H%M%S")
        file_handler = logging.FileHandler(f'bcfindloc_{date_time}.log')
        file_handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s'))
        file_handler.setLevel(chat_level[chatter])
        logger.addHandler(file_handler)

    bcfindloc(args.infile, args.outprefix, args.burstgti, args.attitude,
              args.occultation, args.emin, args.emax, args.simloc,
              args.extension, args.poly, args.refdata, chatter, clob)


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
