import astropy.io.fits as fits
from astropy.wcs import WCS
from matplotlib.colors import LogNorm
from astropy.visualization.wcsaxes.frame import EllipticalFrame
from reproject import reproject_from_healpix

import numpy as np

from pathlib import Path

import healpy as hp
from healpy.newvisufunc import projview, newprojplot
import matplotlib.pyplot as plt

import warnings
# NOTE(review): this silences *every* RuntimeWarning process-wide as soon as
# this module is imported, not only warnings raised by the functions below.
# Presumably it is meant to hide numpy warnings from NaN pixels in reprojected
# images; consider scoping it with warnings.catch_warnings() — confirm intent.
warnings.filterwarnings("ignore", category=RuntimeWarning)

def _load_ref_header(refdata, filename):
    """Read a FITS header stored as a text file in the reference-data location.

    Parameters
    ----------
    refdata : str or pathlib.Path
        Reference-data location. A ``str`` is used as a prefix and
        concatenated directly with *filename*, so it normally needs a
        trailing path separator. A ``Path`` is treated as a directory.
    filename : str
        Name of the text header file inside *refdata*.

    Returns
    -------
    astropy.io.fits.Header
        Header parsed with ``fits.Header.fromtextfile``.

    Raises
    ------
    FileNotFoundError
        If the header file does not exist.
    """
    if isinstance(refdata, Path):
        reffile = refdata / filename
    else:
        # Plain string concatenation, exactly as the original callers expect.
        reffile = refdata + filename
    if not Path(reffile).exists():
        # Fail immediately with a clear message instead of printing and then
        # letting the header parser fail on the missing file.
        raise FileNotFoundError(f"File {reffile} not found")
    return fits.Header.fromtextfile(reffile)


def HealpixPrimHeader(REFDATA):
    """Load the primary-HDU reference header from ``REFDATA + 'primary.txt'``.

    Raises FileNotFoundError if the file is missing.
    """
    return _load_ref_header(REFDATA, 'primary.txt')


def HealpixImgHeader(REFDATA):
    """Load the image-HDU reference header from ``REFDATA + 'pixmap_img.txt'``.

    Raises FileNotFoundError if the file is missing.
    """
    return _load_ref_header(REFDATA, 'pixmap_img.txt')


def HealpixDataHeader(REFDATA):
    """Load the BinTable reference header from ``REFDATA + 'pixmap_heap.txt'``.

    Raises FileNotFoundError if the file is missing.
    """
    return _load_ref_header(REFDATA, 'pixmap_heap.txt')
    
    
def pix_img(healpix_arr, nside, nest, refdata, occultation=False, coordsys='icrs'):
    """Reproject a HEALPix map onto a Galactic Aitoff image and wrap it in an ImageHDU.

    Parameters
    ----------
    healpix_arr : array-like
        HEALPix map values (presumably of length ``12 * nside**2``; not
        checked here).
    nside : int
        HEALPix NSIDE of *healpix_arr*. It sets the output pixel size.
    nest : bool
        Passed to ``reproject_from_healpix`` as ``nested``. True for NESTED
        ordering, False for RING.
    refdata
        Reference-data location passed to ``HealpixImgHeader`` to load the
        base image header.
    occultation : bool, optional
        If True, every positive pixel is set to 1.0 after reprojection, so
        fractional values created by interpolation become 1.
    coordsys : str or astropy coordinate frame, optional
        Frame the input HEALPix map is defined in, passed to
        ``reproject_from_healpix``. Defaults to ``'icrs'``, which was
        previously hard-coded. The output image is always Galactic
        (``GLON/GLAT-AIT``).

    Returns
    -------
    astropy.io.fits.ImageHDU
        Reprojected image with the WCS-updated reference header.
    """
    # Pixel size (deg): side of a square with the area of one HEALPix pixel.
    approx_res = np.sqrt(hp.nside2pixarea(nside, degrees=True))

    # Enough pixels to cover 360 deg of longitude and 180 deg of latitude.
    numpts_az = int(np.floor(360.0 / approx_res))
    numpts_zen = int(np.floor(180.0 / approx_res))

    # Start from the reference image header and overwrite the WCS keywords.
    header = HealpixImgHeader(refdata)
    header['NAXIS'] = 2
    header['NAXIS1'] = numpts_az
    header['NAXIS2'] = numpts_zen
    header['CTYPE1'] = 'GLON-AIT'
    # Reference pixel at the image centre (FITS pixels are 1-based).
    header['CRPIX1'] = header['NAXIS1'] / 2 + 0.5
    header['CRVAL1'] = 180.0
    # Negative step so longitude increases to the left (usual sky orientation).
    header['CDELT1'] = -approx_res
    header['CUNIT1'] = 'deg'
    header['CTYPE2'] = 'GLAT-AIT'
    header['CRPIX2'] = header['NAXIS2'] / 2 + 0.5
    header['CRVAL2'] = 0.0
    header['CDELT2'] = approx_res
    header['CUNIT2'] = 'deg'

    wcs = WCS(header)
    # The coverage footprint returned alongside the image is not used.
    image, _footprint = reproject_from_healpix((healpix_arr, coordsys), wcs,
                                               shape_out=(numpts_zen, numpts_az),
                                               nested=nest)

    # Occultation maps should only contain 0/1. Reprojection interpolates, so
    # any positive (fractional) value is promoted to 1. Pixels that are NaN
    # (presumably those outside the Aitoff ellipse) compare False and are
    # left unchanged.
    if occultation:
        image[image > 0] = 1.0

    return fits.ImageHDU(data=image, header=header)

def pix_data(healpix_arr, nside, scheme, coordsys, refdata):
    """Build a BinTable HDU holding a HEALPix map with explicit pixel indices.

    The table has two columns:
    - ``PIXEL``: index 0..N-1.
    - ``MAP_VALUE``: the map value, stored as a 32-bit float (FITS ``'E'``).

    Parameters
    ----------
    healpix_arr : numpy.ndarray
        HEALPix map values. ``healpix_arr.size`` is used as the pixel count.
    nside : int
        Written to the ``NSIDE`` keyword.
    scheme : str
        Pixel ordering, written upper-cased to ``ORDERING``
        (presumably ``'ring'`` or ``'nested'``).
    coordsys : str
        Coordinate system, written upper-cased to ``COORDSYS``.
    refdata
        Reference-data location passed to ``HealpixDataHeader``. Its header
        cards are merged into the table header.

    Returns
    -------
    astropy.io.fits.BinTableHDU
    """
    npix = healpix_arr.size
    healpix_pix = np.arange(npix)

    # FITS 'I' is a 16-bit integer and cannot hold indices above 32767, but a
    # HEALPix map has 12*nside**2 pixels (49152 already at NSIDE=64).
    # Use 32-bit 'J', and 64-bit 'K' only when int32 is not enough.
    pix_format = '1J' if npix <= np.iinfo(np.int32).max + 1 else '1K'
    pix_col = fits.Column(name='PIXEL', format=pix_format, array=healpix_pix)
    val_col = fits.Column(name='MAP_VALUE', format='1E', array=healpix_arr)
    hdu = fits.BinTableHDU.from_columns([pix_col, val_col])

    # NOTE(review): if the reference header defines table-structure keywords
    # (TFORMn, NAXISn, ...), update() may overwrite the values astropy derived
    # from the columns above — confirm the reference file's contents.
    hdu.header.update(HealpixDataHeader(refdata))

    hdu.header['ORDERING'] = scheme.upper()
    hdu.header['COORDSYS'] = coordsys.upper()
    hdu.header['NSIDE'] = nside
    hdu.header['FIRSTPIX'] = 0
    hdu.header['LASTPIX'] = npix - 1

    return hdu