from gdt.core.data_primitives import TimeBins, TimeRange
from gdt.core.background.binned import Polynomial
from gdt.core.binning.binned import rebin_by_edge_index

from scipy.signal import peak_prominences, argrelmax

import numpy as np

# Much of this code was copied from the bayesian blocks implementation
# in astropy https://github.com/astropy/astropy/blob/2db2f820eb51c95fcb3e187328c8cbd99ecd24df/astropy/stats/bayesian_blocks.py
# With the following differences:
# 1. The following issue was fixed https://github.com/astropy/astropy/issues/14017
# 2. It can handle bins with 0 events (Scargle's convention is to always have at least 1)
# 3. The code was simplified to handle the specific case of binned data
# 4. It returns change point indices, instead of edges.

class Bayes:
    
    def __init__(self, p0=0.05, poly=3, buffer_blocks=5, max_iter=100):
        """
        Store the analysis settings and start with an empty cache.

        The settings are the ones forwarded to ``get_bayesian_blocks()`` when
        ``get_signal_range()`` or ``get_duration()`` have to trigger the
        analysis themselves.

        Args:
            p0 (float): False alarm probability for bayesian blocks algorithm
            poly (int): Order of polynomial to fit background
            buffer_blocks (int): Size of the exclusion zone around the signal
            max_iter (int): Maximum number of iterations
        """

        # Analysis settings
        self._p0, self._poly = p0, poly
        self._buffer_blocks, self._max_iter = buffer_blocks, max_iter

        # Input light curve and cached results; nothing is available yet
        for attribute in ('_lc',
                          '_signal_range', '_bb_args',
                          '_lc_bayes', '_bkg_times', '_bkg_model',
                          '_lc_bkg_model', '_lc_bkg_counts', '_lc_bb_index',
                          '_peak'):
            setattr(self, attribute, None)
        
    def load_lightcurve(self, lc):
        """
        Set the light curve to analyze and discard any cached results.

        Args:
            lc (TimeBins): Light curve
        """
        self._lc = lc

        # Everything below was derived from the previously loaded light
        # curve. Keeping it would make get_bayesian_blocks(),
        # get_signal_range(), get_duration() and get_peak() return results
        # that do not belong to ``lc``.
        self._signal_range = None
        self._bb_args = None

        self._lc_bayes = None
        self._bkg_times = None
        self._bkg_model = None

        self._lc_bkg_model = None
        self._lc_bkg_counts = None
        self._lc_bb_index = None

        self._peak = None
    
    def get_lightcurve(self):
        """
        Return the light curve currently held by this object.

        Return:
            The object last passed to ``load_lightcurve()``, or None if
            nothing has been loaded yet.
        """
        lightcurve = self._lc
        return lightcurve
    
    def get_peak(self):
        """
        Return the cached peak time.

        Return:
            The value stored by the last bayesian blocks run (centroid of the
            highest-rate light curve bin inside the signal range), or None if
            no value has been stored.
        """
        peak_time = self._peak
        return peak_time
        
    def bayesian_blocks(self,lc, p0=0.05, gamma=None, ncp_prior=None):
        """
        Perform a binned bayesian blocks identification (Scargle, 2013),
        taking the exposure into account.

        Args:
            lc (TimeBins): Lighcurve
            p0 (float): False positive rate
            gamma (float): Alternatively, provide the gamma parameter 
                (slope in the prior). Take precedence over ``p0``
            ncp_prior (float): Specify the prior in the number of bins,
                taking precedence over ``p0`` and ``gamma``.

        Return:
            bb_indices (array): Bin indices of change points. A new lightcurve can be 
                computed as lc.rebin(rebin_by_edge_index, bb_indices)
        """
    
        # Compute prior
        if ncp_prior is None:
            if gamma is not None:
                ncp_prior = -np.log(gamma)
            elif p0 is not None:
                #Eq. 21 in Scargle (2013) (log missing)
                ncp_prior = 4 - np.log(73.53 * p0 * (lc.size**-0.478))
            else:
                raise RuntimeError("Specify either p0, gamma or ncp_prior")
        
        # ----------------------------------------------------------------
        # Start with first data cell; add one cell at each iteration
        # ----------------------------------------------------------------
        exposure_cumsum = np.append(np.cumsum(lc.exposure[::-1])[::-1], 0)
    
        counts_cumsum = np.append(np.cumsum(lc.counts[::-1])[::-1], 0)
        best = np.zeros(lc.size, dtype=float)
        last = np.zeros(lc.size, dtype=int)

        for R in range(lc.size):

            # evaluate fitness function. Eq. 19 from Scargle 2013
            T_k = exposure_cumsum[: R + 1] - exposure_cumsum[R + 1]
        
            N_k = counts_cumsum[: R + 1] - counts_cumsum[R+1] 

            # When N_k = 0, fit_vec is nan, but it should be 0
            fit_vec_log = np.zeros(N_k.size) #Prevent uninitialized values
            np.log(N_k / T_k, out = fit_vec_log, where = N_k != 0)
        
            fit_vec = N_k * fit_vec_log

            A_R = fit_vec - ncp_prior
            A_R[1:] += best[:R]

            i_max = np.argmax(A_R)
            last[R] = i_max
            best[R] = A_R[i_max]

        # ----------------------------------------------------------------
        # Now find changepoints by iteratively peeling off the last block
        # ----------------------------------------------------------------
        change_points = np.zeros(lc.size, dtype=int)
        i_cp = lc.size
        ind = lc.size
        while i_cp > 0:
            i_cp -= 1
            change_points[i_cp] = ind
            if ind == 0:
                break
            ind = last[ind - 1]
        if i_cp == 0:
            change_points[i_cp] = 0
        change_points = change_points[i_cp:]

        return change_points


    def get_bayesian_blocks(self,poly=3,
                            p0 = 0.05, 
                            buffer_blocks = 5,
                            signal_range = None,
                            max_iter = 100):
        """
        Bayesian blocks with iterative background fitting
        
        Args:
            p0 (float): False alarm probability for bayesian blocks algorithm
            buffer_blocks (int): Define the exclusion zone around the signal
                to compute the background in multiple of the size of the first and last
                bayesian block that are part of the signal
            bkg_times (array-like): Optional: provide a first guess for the time range
              that contained the signal. Must have shape (2,).
           max_iter (int): Maximum number of iterations. 

        Return:
           bb_index (array): Identified bayesian blocks, corresponding to indices of the ligh curve bins
           lc_bayes (TimeBins): Lightcurve bins in resulting bayesian blocks 
        """

        args = {'poly':poly,
                'p0':p0,
                'buffer_blocks':buffer_blocks,
                'signal_range':signal_range,
                'max_iter':max_iter}

        if self._lc_bayes is None or args != self._bb_args:

            # Not cached or need to recompute
            signal_range, bkg_times, bkg_model, bkg_counts, bb_index, lc_bayes = self._get_bayesian_blocks(self._lc, **args)

            # Cache
            self._bb_args = args
            self._signal_range = signal_range
            self._bkg_times = bkg_times
            self._lc_bkg_model = bkg_model
            self._lc_bkg_counts = bkg_counts
            self._lc_bb_index = bb_index
            self._lc_bayes = lc_bayes

        return self._lc_bayes, self._lc_bb_index

    def _get_bayesian_blocks(self,lc,poly=3,
                             p0 = 0.05, 
                             buffer_blocks = 5,
                             signal_range = None,
                             max_iter = 100):
        """
        Bayesian blocks with iterative background fitting

        Args:
            lc (TimeBins): Light curve
            poly (int): Order of polynomial to fit background
            p0 (float): False alarm probability for bayesian blocks algorithm
            buffer_blocks (int): Define the exclusion zone around the signal
                to compute the background in multiple of the size of the first and last
                bayesian block that are part of the signal
            bkg_times (array-like): Optional: provide a first guess for the time range
              that contained the signal. Must have shape (2,).
           max_iter (int): Maximum number of iterations. 

        Return:
           signal_range (tuple):  Signal start and stop time
           bkg_times (array): List of tuples containing the time ranges used to compute the backgrund 
           bkg_model (funtion): Polynomial fit to the background
           bb_index (array): Identified bayesian blocks, corresponding to indices of the ligh curve bins
           lc_bayes (TimeBins): Lightcurve bins in resulting bayesian blocks 
        """
        
        # Standarize input
        if signal_range is None:
            bkg_times = [(lc.range[0], lc.range[-1])]
        else:
            if len(signal_range) != 2:
                raise ValueError("Wrong signal_range shape.")

            signal_range = np.sort(signal_range)

            bkg_times = [(lc.range[0], signal_range[0]),
                         (signal_range[-1], lc.range[-1])]

        # Iterative method
        previous_bkg_times = []
        
        for i in range(max_iter):

            # ---- Fit bkg ------
            lc_bkg_times = TimeBins.merge([lc.slice(ti,tf) for ti,tf in bkg_times])

            bkg_model = Polynomial(counts = lc_bkg_times.counts[:,np.newaxis], 
                                   tstart = lc_bkg_times.lo_edges, 
                                   tstop = lc_bkg_times.hi_edges, 
                                   exposure = lc_bkg_times.exposure)

            bkg_model.fit(order= min(poly,i)) # Remove first mean and linear component

            bkg_rate, bkg_rate_err = bkg_model.interpolate(lc.lo_edges, lc.hi_edges)

            bkg_counts = bkg_rate.flatten() * lc.exposure

            # ---- Bayesian blocks --------

            # Make effective time bin ~bkg rate so the effective rate is
            # constant (homogeneous poisson process)
            # Same as Giacomo's "trick" in 3ML
            # https://github.com/threeML/threeML/blob/e31db70daf8777ce12be7aa694b21efc3f15dae0/threeML/utils/bayesian_blocks.py#L171  

            lc_eff = TimeBins(counts = lc.counts,
                              lo_edges = lc.lo_edges,
                              hi_edges = lc.hi_edges,
                              exposure = bkg_counts) 

            bb_index = self.bayesian_blocks(lc_eff, p0 = p0)

            lc_bayes = lc.rebin(rebin_by_edge_index, bb_index)

            # ----- Find peaks ----
            peaks = argrelmax(lc_bayes.rates)[0]

            if len(peaks) == 0:
                # Need least 1 peak
                print("Could not identify peak.")
                return None, None, bkg_model, bkg_counts, bb_index, lc_bayes

            prominence,left_base,right_base = peak_prominences(lc_bayes.rates, peaks)

            # Start and end of signal based on peaks, in histogram bins
            leftmost_base = np.min(left_base)
            rightmost_base = np.max(right_base)

            new_start_signal = bb_index[leftmost_base + 1]
            new_stop_signal = bb_index[rightmost_base]

            signal_tstart = lc.lo_edges[new_start_signal]    
            signal_tstop  = lc.hi_edges[new_stop_signal-1] 

            # Remove an extra chunk the size of the end blocks
            # Do not remove more than half the distance to the ends
            block_widths = lc_bayes.widths

            left_buffer = min(buffer_blocks*block_widths[left_base[0] + 1], (lc.lo_edges[new_start_signal] - lc.range[0])/2)
            right_buffer = min(buffer_blocks*block_widths[right_base[-1] - 1], (lc.range[1] - lc.hi_edges[new_stop_signal])/2)

            new_start,new_stop = np.digitize([lc.lo_edges[new_start_signal] - left_buffer,
                                              lc.hi_edges[new_stop_signal-1] + right_buffer], 
                                             lc.lo_edges) - 1

            # ----- Update background times ------
            # If they didn't change, then it has converge
            bkgex_tstart = lc.lo_edges[new_start]    
            bkgex_tstop  = lc.hi_edges[new_stop-1] 

            # Assume that the signal is fully contained. i.e. there is some background on both sides
            bkgex_tstart = max(bkgex_tstart, lc.lo_edges[bb_index[1]])
            bkgex_tstop = min(bkgex_tstop, lc.hi_edges[bb_index[-2]-1])
            
            bkg_times = [(lc.lo_edges[0], bkgex_tstart), (bkgex_tstop, lc.hi_edges[-1])]

            # Check if we are repeating a pattern.
            # Usually, when the method converges, 2 consecutive iteration have the same bayesian
            # blocks, but sometimes there is a 2-3 pattern that repeats with almost identical
            # block representations
            if bkg_times in previous_bkg_times and i >= 2:
                break

            previous_bkg_times += [bkg_times]
            if i == max_iter - 1:
                print("Maximum number of iterations reached without converging.")

        signal_range = (signal_tstart, signal_tstop)
        
        # Slice lightcurve by signal range
        lc_signal = lc.slice(*signal_range)
        maxpeak_idx = np.argmax(lc_signal.rates)
        self._peak = lc_signal.centroids[maxpeak_idx]  
        return signal_range, bkg_times, bkg_model, bkg_counts, bb_index, lc_bayes 

    def get_signal_range(self):
        """
        Return the time range of the identified signal.

        .. note::
            If no bayesian blocks result is cached, ``get_bayesian_blocks()``
            is called with the parameters given to the constructor.

        Return:
           TimeRange: (signal_start_time, signal_stop_time), or None if no
           signal was identified.
        """

        # Identify signal or used cached values
        if self._lc_bayes is None:
            # No cache, need to identify signal first
            self.get_bayesian_blocks(poly = self._poly,
                                     p0 = self._p0,
                                     buffer_blocks = self._buffer_blocks,
                                     signal_range = self._signal_range,
                                     max_iter = self._max_iter)

        # Catch if no signal is detected, same as get_duration()
        if self._signal_range is None:
            return None

        return TimeRange(*self._signal_range)
            
    def get_duration(self, quantile):
        """
        Compute the symmetric time range that contains a given quantile of bkg-subtracted counts

        .. note::
            If ``get_bayesian_blocks()`` has not been called already, it will be called 
            with default parameters

        Args:
            quantile (float or array): Quantile(s) e.g. .9 for T90, [1,.9,.5] for [T100, T90, T50]

        Return:
           float or array
        """

        # Identify signal or used cached values
        if self._lc_bayes is None:
            # No cache, need to identify signal first
            args = {'poly':self._poly,
                    'p0':self._p0,
                    'buffer_blocks':self._buffer_blocks,
                    'signal_range':self._signal_range,
                    'max_iter':self._max_iter}
            self.get_bayesian_blocks(**args)
        #Catch if no signal is detected 
        if self._signal_range == None:
            return None
        else:
            return self._get_duration(lc = self._lc,
                                  bkg_counts = self._lc_bkg_counts,
                                  signal_range = self._signal_range,
                                  quantile = quantile)
            
    @staticmethod
    def _get_duration(lc,
                      bkg_counts,
                      signal_range,
                      quantile):
        """
        Compute the symmetric time range that contains a given quantile of bkg-subtracted counts

        Args:
            lc (TimeBins): Light curve
            bkg_model (array): Estimated background counts for light curve bin
            signal_range (tuple): Start and stop of signal
            quantile (float or array): Quantile(s) e.g. .9 for T90, [1,.9,.5] for [T100, T90, T50]

        Return:
           float or array: Same shape as quantile
        """
    
        # Percentile calculation
        start_signal,stop_signal = np.digitize(signal_range, lc.lo_edges)-1
        offset = 2*lc.lo_edges[start_signal] - lc.hi_edges[start_signal]

        cumtime = lc.hi_edges[start_signal:stop_signal] - lc.lo_edges[start_signal]
        # Background subtraction 
        cumcounts = np.cumsum(lc.counts[start_signal:stop_signal] -
                              bkg_counts[start_signal:stop_signal])


        half_inv_quant = (1-np.array(quantile))/2

        #print(np.interp([half_inv_quant, 1-half_inv_quant],
        #                           cumcounts/cumcounts[-1],
        #                           cumtime))
        #tquant = np.diff(np.interp([half_inv_quant, 1-half_inv_quant],
        #                            cumcounts/cumcounts[-1],
        #                            cumtime), axis = 0)
        
        tquant = np.interp([half_inv_quant, 1-half_inv_quant],
                                   cumcounts/cumcounts[-1],
                                   cumtime) + np.array([offset,lc.lo_edges[start_signal]])[:,np.newaxis]

        '''
        if np.isscalar(quantile):
            tquant = tquant.item()
        else:
            tquant = tquant.reshape(np.shape(quantile))
        '''    
        return tquant