#
#  Requirement 
#   The task takes an orbit file and a GTI file as input and creates an output
#   file containing an occultation map: the primary HDU holds the map as an
#   image in Aitoff (AIT) projection and the first extension holds the same
#   map in HEALPix representation.
#   Each pixel of the map is set to 0 if that part of the sky is occulted by
#   the Earth, and to 1 if it is not occulted. The number of pixels in the map
#   is calculated internally from the nside parameter as 12 x nside**2.
#
#    The task needs the location of the REFDATA directory, which contains the
#    headers used for the output file. It is taken from the refdata parameter
#    or, by default, from the LHEA_DATA environment variable.
#
# Step 1 Check and load all input parameters
# Step 2 Create occultation HEALPix array (1=visible,0=occulted) based on
#        input GTI time and orbital data
# Step 3 Write occultation map to a FITS file with the following extensions:
#        Primary: Image of occultation map using Aitoff projection
#        HEALPIX: HEALPix data array
#
#  The gdt library (SpacecraftFrame) is used to map the orbit into a software
#  structure that sets the occultation flags.
#
#     James Runge and Lorella Angelini Feb 2024    
#

import sys
import os
import datetime
import argparse
from argparse import RawTextHelpFormatter
import logging

import astropy.io.fits as fits
import astropy.units as u
from astropy.time import Time
from astropy.coordinates.representation import CartesianRepresentation
from astropy.coordinates import SkyCoord
from astropy.utils.exceptions import AstropyWarning

from gdt.core.coords import SpacecraftFrame

from pathlib import Path

import healpy as hp
import numpy as np
from heasoftpy.burstcube.lib.pixmaps import pix_img, pix_data
from heasoftpy.burstcube.lib.io import clob_check, file_extn
from heasoftpy.burstcube.bcversion_task import bcversion

import warnings

# Suppress warnings
# Silence every RuntimeWarning and AstropyWarning. These filters are
# process-wide and are installed as soon as this module is imported.
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.simplefilter('ignore', category=AstropyWarning)


def _read_table(filename, cols_needed, logger):
    """Read selected columns and the header of a FITS table extension.

    ``filename`` may carry an extension specifier understood by
    ``file_extn``. The requested columns are copied into plain numpy arrays
    and the header is copied, so the file can be closed before returning.

    Returns ``(header, data, units)``, where ``data`` and ``units`` map each
    requested column name to its array and TUNIT value. Returns None after
    logging an error if the file, extension or columns are missing, the
    extension is not a table, or the table has no rows.
    """
    fname, extn = file_extn(filename)
    if not Path(fname).exists():
        logger.error(f"File {fname} not found.")
        return None

    with fits.open(Path(fname)) as hdul:
        try:
            hdu = hdul[extn]
        except (KeyError, IndexError):
            logger.error(f"Extension {extn} not found in {fname}.")
            return None
        if not isinstance(hdu, (fits.BinTableHDU, fits.TableHDU)):
            logger.error(f"Extension {extn} of {fname} is not a table.")
            return None

        not_present = np.setdiff1d(cols_needed, hdu.columns.names)
        if not_present.size != 0:
            logger.error("Needed columns not found: %s", not_present)
            return None

        if hdu.data is None or len(hdu.data) == 0:
            logger.error(f"Extension {extn} of {fname} contains no rows.")
            return None

        header = hdu.header.copy()
        # np.array() copies, detaching the values from the (possibly
        # memory-mapped) file that is closed when the with-block ends.
        data = {name: np.array(hdu.data[name]) for name in cols_needed}
        units = {name: hdu.columns[name].unit for name in cols_needed}

    return header, data, units


def bcoccult(orbit_file, gti_file, outprefix, nside, scheme='ring',
             refdata='REFDATA', chatter=1, clobber=False):
    """Create an Earth-occultation HEALPix map and write it to a FITS file.

    Parameters
    ----------
    orbit_file : str
        FITS file (optionally with an extension specifier) containing
        orbital information; the TIME, X, Y and Z columns are read from it.
    gti_file : str
        GTI file (optionally with an extension specifier). The map is made
        at the midpoint of the first START and the last STOP value.
    outprefix : str
        Prefix of the output file, written as ``<outprefix>_occult.fits``.
    nside : int
        HEALPix nside of the map (the map has 12 * nside**2 pixels).
    scheme : str
        HEALPix ordering scheme: 'ring' (default) or 'nest'/'nested'.
    refdata : str
        Reference data directory. The default value 'REFDATA' means the
        directory is taken from the LHEA_DATA environment variable.
    chatter : int
        Verbosity 0-5 (0 -> errors only, 1-4 -> info, 5 -> debug).
    clobber : bool
        Overwrite an existing output file.

    Returns
    -------
    int
        0 on success, 1 on any error (the reason is logged).
    """
    # ----STEP 1 START------#
    logger = logging.getLogger('bcoccult')
    # Check that chatter is reasonable
    if chatter < 0 or chatter > 5:
        print(f'Chatter value {chatter} is not valid. Must be 0-5.')
        return 1
    # Logging levels per chatter value (40=ERROR, 20=INFO, 10=DEBUG).
    # Chatter values 1 through 4 are all the same for now.
    chat_level = {0: 40, 1: 20, 5: 10}
    if 1 <= chatter <= 4:
        chatter = 1
    # Apply the level; otherwise the logger inherits the root logger's
    # level and INFO messages are dropped.
    logger.setLevel(chat_level[chatter])

    version = bcversion()
    logger.info(f'VERSION: {version}')

    logger.info("Verifying input parameters")

    # Load TIME, X, Y and Z from the orbit file
    orbit = _read_table(orbit_file, ['TIME', 'X', 'Y', 'Z'], logger)
    if orbit is None:
        return 1
    orb_header, orb_data, orb_units = orbit

    # A missing TUNITn may be reported as None rather than '', so test
    # falsiness to catch both; the X unit is used for X, Y and Z.
    if not orb_units['X']:
        logger.error("No units found for X, Y, Z columns.")
        return 1
    try:
        xyz_unit = u.Unit(orb_units['X'])
    except ValueError:
        logger.error(
            f"Unit {orb_units['X']} of X, Y, Z columns is not recognized.")
        return 1

    # Check that the time keywords are present and get the epoch
    epoch = return_epoch(orb_header)
    if not isinstance(epoch, Time):
        logger.error("Error in retrieving header keywords.")
        return 1

    # Absolute times of the orbit samples
    times = epoch + orb_data['TIME'] * u.second

    # Resolve REFDATA
    if refdata == 'REFDATA':
        lhea_data = os.getenv('LHEA_DATA')
        if not lhea_data:
            logger.error("LHEA_DATA is not defined in current environment.")
            return 1
        refdata = lhea_data
    if not Path(refdata).is_dir():
        logger.error(f"Directory {refdata} does not exist.")
        return 1
    # Ensure a trailing separator, as the original code did; presumably
    # pix_img/pix_data append file names directly to refdata.
    refdata = os.path.join(refdata, '')

    # Load START and STOP from the GTI file
    gti = _read_table(gti_file, ['START', 'STOP'], logger)
    if gti is None:
        return 1
    gti_header, gti_data, _ = gti

    gti_epoch = return_epoch(gti_header)
    if not isinstance(gti_epoch, Time):
        logger.error("Error in retrieving header keywords.")
        return 1

    start_col = gti_data['START']
    stop_col = gti_data['STOP']

    # Set boolean based upon scheme
    if 'ring' in scheme.lower():
        scheme = 'ring'
        nest = False
    elif 'nest' in scheme.lower():
        scheme = 'nested'
        nest = True
    else:
        logger.error(f"Scheme {scheme} is not 'ring' or 'nested'.")
        return 1

    # Check that nside is a valid value
    if not hp.isnsideok(nside, nest=nest):
        logger.error(f"{nside} is not a valid value.")
        return 1

    # ----STEP 1 STOP------#

    # ----STEP 2 START------#
    logger.info("Creating occultation map")
    # Expecting peak time interval only but will calculate midpoint
    # of first START and last STOP value
    midpoint = (start_col[0] + stop_col[-1]) / 2.
    time = gti_epoch + midpoint * u.second

    # The frame is interpolated from the orbit, so the requested time must lie
    # within the orbit's first and last sample (assumes TIME is ascending).
    if time < times[0] or time > times[-1]:
        logger.error(
            f"Given time {time.isot} falls outside range of input time column.")
        return 1

    pos = CartesianRepresentation(orb_data['X'], orb_data['Y'], orb_data['Z'],
                                  unit=xyz_unit)
    sc_frame = SpacecraftFrame(obstime=times, obsgeoloc=pos)
    sc_frame_wanted = sc_frame.at(time)

    # HEALPix pixel centres: theta is colatitude, phi is longitude (radians)
    pixel_indices = np.arange(hp.nside2npix(nside))
    theta, phi = hp.pix2ang(nside, pixel_indices, nest=nest)
    theta_deg = np.degrees(theta)
    phi_deg = np.degrees(phi)

    # RA = phi, Dec = 90 - colatitude
    sky_coords = SkyCoord(phi_deg, 90 - theta_deg, unit='deg', frame='icrs')

    # 1.0 where the sky position is visible, 0.0 where occulted
    occultation = sc_frame_wanted.location_visible(sky_coords).astype(float)
    # ----STEP 2 STOP------#

    # ----STEP 3 START------#
    logger.info("Creating primary image extension")
    img_hdu = pix_img(occultation, nside, nest, refdata, occultation=True)

    # Since occultation map, min = 0 and max = 1
    img_hdu.header['DATA_MIN'] = 0.0
    img_hdu.header['DATA_MAX'] = 1.0

    # Copy descriptive keywords from the input GTI header
    get_keyvals(gti_header, img_hdu.header)
    primary = fits.PrimaryHDU(data=img_hdu.data, header=img_hdu.header)

    logger.info("Creating Healpix map extension")
    hpixmap = pix_data(occultation, nside, scheme, 'icrs', refdata)
    get_keyvals(gti_header, hpixmap.header)

    hdul = fits.HDUList([primary, hpixmap])

    # Existence check, then write
    logger.info("Writing output FITS file")
    occ_fits_name = f'{outprefix}_occult.fits'
    status = clob_check(clobber, occ_fits_name, logger)
    if status == 1:
        return 1
    hdul.writeto(occ_fits_name, overwrite=clobber, checksum=True)
    # ----STEP 3 STOP------#

    return 0

# Check that the time-reference header keywords are present and
# return the epoch as an astropy Time object.


def return_epoch(header):
    """Return the time-reference epoch of a FITS header as an astropy Time.

    The epoch is built from MJDREFI (integer part) and MJDREFF (fractional
    part). It uses the time scale named by the TIMESYS keyword.

    Parameters
    ----------
    header : astropy.io.fits.Header or mapping
        Header providing ``keys()`` and item access.

    Returns
    -------
    astropy.time.Time or int
        The epoch on success. Returns 1 if a keyword is missing or if TIMESYS
        does not name a time scale known to astropy. Callers test the result
        with ``isinstance(result, Time)``.
    """
    header_keys = list(header.keys())
    keys_needed = np.array(['MJDREFI', 'MJDREFF', 'TIMESYS'])

    # If not all keywords are found, return the error value
    not_present = np.setdiff1d(keys_needed, header_keys)
    if not_present.size != 0:
        print("Keywords not found:", not_present)
        return 1

    mjd_refi = header['MJDREFI']
    mjd_reff = header['MJDREFF']
    timesys = header['TIMESYS']

    # Check the scale up front so an unknown TIMESYS follows the same
    # error-return path instead of raising from the Time constructor.
    scale = str(timesys).strip().lower()
    if scale not in Time.SCALES:
        print(f"TIMESYS value {timesys} is not a recognized time scale.")
        return 1

    # Pass the integer and fractional parts as val/val2: astropy keeps them
    # as a two-double value, whereas adding them first loses precision.
    epoch = Time(mjd_refi, mjd_reff, format='mjd', scale=scale)

    return epoch

# Copy a fixed list of keyword values from one header to another


def get_keyvals(input_hdr, out_hdr):
    """Copy selected descriptive keywords from ``input_hdr`` into ``out_hdr``.

    Only keywords whose value is not None are copied. Keywords absent from
    ``input_hdr`` are skipped. ``out_hdr`` is modified in place and nothing
    is returned.
    """
    wanted_keywords = (
        'TELESCOP', 'INSTRUME', 'DATAMODE',
        'OBJECT', 'RA_OBJ', 'DEC_OBJ', 'POS_FLAG',
        'OBS_ID', 'TRIGGER', 'TRIGTIME', 'TRIGUTC',
        'DATE-OBS', 'DATE-END', 'CREATOR', 'PROCVER',
        'CALDBVER',
    )

    for keyword in wanted_keywords:
        found = input_hdr.get(keyword)
        # Missing keywords come back as None; skip them so out_hdr is not
        # given empty entries.
        if found is None:
            continue
        out_hdr[keyword] = found


def main():
    """Command-line entry point for bcoccult.

    The steps are:

    - Parse the command-line options.
    - Prompt interactively for any option that is still None after parsing.
    - Configure the 'bcoccult' logger from --chatter and --log.
    - Run bcoccult() and exit the process with its return status (0 on
      success, 1 on error).
    """
    desc = """
    Task to create an occultation HEALPix map based upon orbital data and
    a GTI.
    """

    parser = argparse.ArgumentParser(
        description=desc, formatter_class=RawTextHelpFormatter)
    parser.add_argument('--orbit_file', action="store", type=str,
                        help='FITS file containing orbit information')
    parser.add_argument('--outprefix', type=str,
                        help='Prefix of output files')
    parser.add_argument('--gti_file', action="store", type=str,
                        help='GTI file containing time range')
    parser.add_argument('--nside', action="store", type=int,
                        help='Size of Healpix map (nside)',
                        default=16)
    parser.add_argument('--scheme', action="store", type=str,
                        help='Healpix ordering scheme (RING or NESTED)',
                        default='RING')
    parser.add_argument('--refdata', type=str,
                        help='Reference data directory',
                        default='REFDATA')
    parser.add_argument('--chatter', type=int,
                        default=1, help='Level of verboseness.')
    parser.add_argument('--log', type=str,
                        default='no', help='Create log file?')
    parser.add_argument('--clobber', type=str,
                        default='no', help='Overwrite existing output file(s)?')
    args = parser.parse_args()

    # Options without a default (orbit_file, outprefix, gti_file) are None
    # when not given on the command line; prompt the user for them.
    for param, val in list(vars(args).items()):
        if val is not None:
            continue
        action = next((act for act in parser._actions if act.dest == param),
                      None)
        if action is None:
            continue
        # action.type is None when add_argument() was called without type=
        arg_type = action.type or str
        type_str = getattr(arg_type, '__name__', str(arg_type))
        while True:
            new_value = input(f"{action.help}: ")
            try:
                setattr(args, param, arg_type(new_value))
                break
            except ValueError:
                print(f"{param} must be a valid {type_str}. Try again.")

    # Assign values to be passed to main routine
    orbit_file = args.orbit_file
    outprefix = args.outprefix
    gti_file = args.gti_file
    nside = args.nside
    scheme = args.scheme
    refdata = args.refdata
    chatter = int(args.chatter)
    # Accept yes/y/true in any letter case
    log = str(args.log).strip().lower() in ('yes', 'y', 'true')
    clob = str(args.clobber).strip().lower() in ('yes', 'y', 'true')

    logger = logging.getLogger('bcoccult')
    if chatter < 0 or chatter > 5:
        print(f'Chatter value {chatter} is not valid. Must be 0-5.')
        sys.exit(1)
    # Logging levels per chatter value (40=ERROR, 20=INFO, 10=DEBUG);
    # chatter values 1 through 4 are all the same.
    chat_level = {0: 40, 1: 20, 5: 10}
    if 1 <= chatter <= 4:
        chatter = 1
    # Without this the logger inherits the root level and INFO records never
    # reach the file handler below.
    logger.setLevel(chat_level[chatter])

    if log:
        # Log file name carries the current date and time
        now = datetime.datetime.now()
        date_time = now.strftime("%d%m%Y_%H%M%S")
        log_filename = f'bcoccult_{date_time}.log'
        file_handler = logging.FileHandler(log_filename)
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(chat_level[chatter])
        logger.addHandler(file_handler)

    # Propagate the task status as the process exit code
    status = bcoccult(orbit_file, gti_file, outprefix,
                      nside, scheme, refdata, chatter, clob)
    sys.exit(status)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()