import pyfits as pf
import numpy as np
import scipy.optimize as opt
from scipy.stats.distributions import chi2
# Unit conversion factor applied to the reported fit results: fitted times
# (flare and eclipse) are divided by s2u and the linear slope is multiplied
# by it. Set to 1000 for results in ksec, 1 for results in sec (presumably
# assuming the input times are in seconds -- verify against the caller).
s2u = 1.0
def heaviside(x):
  """Return the step function of *x*: 0 where x < 0, 0.5 where x == 0, 1 where x > 0.

  Built on ``np.sign``, so it works element-wise on NumPy arrays.
  """
  shifted_sign = np.sign(x) + 1
  return shifted_sign * 0.5

def compute_variability_parameters(rate, rate_err, time, tstart):
  """Fit the linear, flare and eclipse models to a binned light curve.

  Parameters
  ----------
  rate, rate_err : sequence of float
      Binned rates and their uncertainties.
  time : sequence of float
      Bin times, measured from ``tstart``.
  tstart : float
      Reference start time. Not used by this function; kept for
      interface compatibility.

  Returns
  -------
  BestfitParameters
      Object whose ``Linear``, ``Flare`` and ``Eclipse`` attributes hold the
      lists returned by ``compute_chi2_linear``, ``compute_chi2_flare`` and
      ``compute_chi2_eclipse`` (chi2 statistic, P-value, dof, parameters).
  """
  fitter = BestfitParameters(time, rate, rate_err)
  # Fits run in a fixed order: linear trend, then flare, then eclipse.
  for run_fit in (fitter.FitLinear, fitter.FitFlare, fitter.FitEclipse):
    run_fit()
  return fitter

class BestfitParameters:
  """Container for a light curve and the results of the variability fits.

  Each ``Fit*`` method stores, in the matching attribute (``Linear``,
  ``Flare``, ``Eclipse``), the list returned by the corresponding
  ``compute_chi2_*`` function.
  """

  def __init__(self, X=None, Y=None, Y_err=None):
    """Store the light curve to be fitted.

    Parameters
    ----------
    X : sequence, optional
        Bin times. Defaults to a new empty list.
    Y : sequence, optional
        Rate in each bin. Defaults to a new empty list.
    Y_err : sequence, optional
        Rate uncertainties. Defaults to a new empty list.
    """
    # Use None sentinels so every instance gets its own lists; a literal
    # ``[]`` default would be one object shared by all default instances.
    X = [] if X is None else X
    Y = [] if Y is None else Y
    Y_err = [] if Y_err is None else Y_err

    # number of bins
    self.N_points = len(X)

    self.__X = X
    self.__Y = Y
    self.__Y_err = Y_err

    # fit results, filled in by FitLinear / FitFlare / FitEclipse
    self.Linear = []
    self.Flare = []
    self.Eclipse = []

    # NOTE(review): not used in this file section; presumably FITS HDU /
    # header placeholders -- confirm with the callers.
    self.Hdu = 0
    self.Hdr = []

  def FitLinear(self):
    """Fit a constant-plus-slope model; store the result in ``Linear``."""
    self.Linear = compute_chi2_linear(self.__Y, self.__Y_err, self.__X)

  def FitFlare(self):
    """Fit the flare model; store the result in ``Flare``."""
    self.Flare = compute_chi2_flare(self.__Y, self.__Y_err, self.__X)

  def FitEclipse(self):
    """Fit the eclipse model; store the result in ``Eclipse``."""
    self.Eclipse = compute_chi2_eclipse(self.__Y, self.__Y_err, self.__X)


def compute_chi2_linear(rate, rate_err, time):
  """Fit rate = cons + lin * time by weighted least squares.

  Returns
  -------
  list
      ``[chi2_TS, chi2_prob, chi2_dof, cons, cons_err, lin, lin_err]``.
      The slope and its error are multiplied by ``s2u``. chi2_TS,
      chi2_prob and the errors stay at -1 (the parameters at 0) when the
      fit is skipped (fewer than 3 bins) or fails.
  """
  dof = np.size(rate) - 2
  tiny = 1.e-8
  stat, pval = -1., -1.
  cons, cons_err = 0., -1.
  lin, lin_err = 0., -1.

  if dof <= 0:
    print('linear fit notice: not enough bins, skip')
    return [stat, pval, dof, cons, cons_err, lin, lin_err]

  # add a tiny term in quadrature so zero-error bins do not divide by zero
  sigma = [np.sqrt(tiny + err ** 2.) for err in rate_err]
  best, cov, _info, _msg, ier = opt.leastsq(
      residual_linear, x0=[0., 0.], args=(rate, sigma, time),
      full_output=True)

  # leastsq reports success with ier in 1..4 and a usable covariance
  if cov is None or not 0 < ier < 5:
    print('linear fit warning: leastsq did not converge')
  elif not (cov[0][0] >= 0. and cov[1][1] >= 0.):
    print('linear fit error: negative values for errors squared')
  else:
    cons = best[0]
    cons_err = np.sqrt(cov[0][0])
    lin = best[1] * s2u
    lin_err = np.sqrt(cov[1][1]) * s2u
    pulls = residual_linear(best, rate, sigma, time)
    stat = np.inner(pulls, pulls)
    pval = chi2.sf(stat, dof)

  if stat >= 0.:
    print('linear fit: Slope=%.3g+/-%.2g Const=%.3g+/-%.2g chi2(%d)=%.1f' %
          (lin, lin_err, cons, cons_err, dof, stat))
  return [stat, pval, dof, cons, cons_err, lin, lin_err]



def compute_chi2_flare(rate, rate_err, time):
  """Fit the flare model of ``residual_flare`` to a binned light curve.

  A coarse grid search (``locate_flare``) gives the starting point for
  ``scipy.optimize.leastsq``. The refinement runs only if the grid search
  found a positive normalisation.

  Returns
  -------
  list
      ``[chi2_TS, chi2_prob, chi2_dof, cons, cons_err, norm, norm_err,
      t_burst, t_burst_err, t_decay, t_dec_err]``. Times are divided by
      ``s2u``. chi2_TS, chi2_prob and the errors stay at -1 (the parameters
      at 0) when the fit is skipped (fewer than 5 bins) or fails.
  """
  dof = np.size(rate) - 4
  tiny = 1.e-8
  stat, pval = -1., -1.
  norm, norm_err = 0., -1.
  t_burst, t_burst_err = 0., -1.
  t_decay, t_dec_err = 0., -1.
  cons, cons_err = 0., -1.

  if dof <= 0:
    print('flare fit notice: not enough bins, skip')
    return [stat, pval, dof, cons, cons_err, norm, norm_err,
            t_burst, t_burst_err, t_decay, t_dec_err]

  # add a tiny term in quadrature so zero-error bins do not divide by zero
  sigma = [np.sqrt(tiny + err ** 2.) for err in rate_err]
  constr = get_flare_constraints(time)
  fit_args = (rate, sigma, time, constr)
  # brute-force grid search gives the starting point for leastsq
  guess = locate_flare(rate, sigma, time, constr)

  if guess[1] is None or not guess[0][0] > 0.:
    print('flare fit warning: brute did not converge')
  else:
    best, cov, _info, _msg, ier = opt.leastsq(
        residual_flare, x0=guess[0], args=fit_args, full_output=True)
    if cov is None or not 0 < ier < 5:
      print('flare fit warning: leastsq did not converge')
    elif not all(cov[k][k] >= 0. for k in range(4)):
      print('flare fit error: negative values for errors squared')
    else:
      norm, norm_err = best[0], np.sqrt(cov[0][0])
      t_burst, t_burst_err = best[1] / s2u, np.sqrt(cov[1][1]) / s2u
      t_decay, t_dec_err = best[2] / s2u, np.sqrt(cov[2][2]) / s2u
      cons, cons_err = best[3], np.sqrt(cov[3][3])
      stat = res_square_flare(best, *fit_args)
      pval = chi2.sf(stat, dof)

  if stat >= 0.:
    print('flare fit: Norm=%.3g+/-%.2g Time_Burst=%.3g+/-%.2g Time_Decay=%.3g+/-%.2g Const=%.3g+/-%.2g chi2(%d)=%.1f' %
          (norm, norm_err, t_burst, t_burst_err, t_decay, t_dec_err, cons, cons_err, dof, stat))
  return [stat, pval, dof, cons, cons_err, norm, norm_err,
          t_burst, t_burst_err, t_decay, t_dec_err]

def compute_chi2_eclipse(rate, rate_err, time):
  """Fit the eclipse model of ``residual_eclipse`` to a binned light curve.

  The model is a constant rate with a drop between t_init and t_exit. A
  coarse grid search (``locate_eclipse``) gives the starting point for
  ``scipy.optimize.leastsq``. The refinement runs only if the grid search
  found a positive drop.

  Parameters
  ----------
  rate, rate_err : sequence of float
      Binned rates and their uncertainties.
  time : sequence of float
      Bin times.

  Returns
  -------
  list
      ``[chi2_TS, chi2_prob, chi2_dof, cons, cons_err, drop, drop_err,
      t_init, t_init_err, t_exit, t_exit_err]``. Times are divided by
      ``s2u``. chi2_TS, chi2_prob and the errors stay at -1 (the parameters
      at 0) when the fit is skipped (fewer than 5 bins) or fails.
  """
  n_points = np.size(rate)
  epsilon = 1.e-8
  chi2_TS = -1.; chi2_prob = -1.
  t_init = 0.; t_init_err = -1.
  t_exit = 0.; t_exit_err = -1.
  drop = 0.; drop_err = -1.
  cons = 0.; cons_err = -1.
  chi2_dof = n_points - 4
  if chi2_dof <= 0:
    print('eclipse fit notice: not enough bins, skip')
    return [chi2_TS, chi2_prob, chi2_dof,
            cons, cons_err, drop, drop_err, t_init, t_init_err, t_exit, t_exit_err]
  # Add epsilon in quadrature so zero-error bins do not divide by zero in
  # the residuals; one vectorized expression instead of a Python loop.
  drate = np.sqrt(epsilon + np.asarray(rate_err, dtype=float) ** 2.)
  constr = get_eclipse_constraints(time)
  # get a first guess of the global minimum by brute force
  params = locate_eclipse(rate, drate, time, constr)
  # follow up only if the grid search found an eclipse (positive drop)
  if params[1] is not None and params[0][0] > 0.:
    best, cov, _info, _msg, ier = opt.leastsq(
        residual_eclipse, x0=params[0],
        args=(rate, drate, time, constr), full_output=True)
    # get parameters and errors only if leastsq converged (ier in 1..4)
    if cov is not None and 0 < ier < 5:
      if cov[0][0] >= 0. and cov[1][1] >= 0. \
          and cov[2][2] >= 0. and cov[3][3] >= 0.:
        drop = best[0]
        drop_err = np.sqrt(cov[0][0])
        t_init = best[1] / s2u
        t_init_err = np.sqrt(cov[1][1]) / s2u
        t_exit = best[2] / s2u
        t_exit_err = np.sqrt(cov[2][2]) / s2u
        cons = best[3]
        cons_err = np.sqrt(cov[3][3])
        chi2_TS = res_square_eclipse(best, rate, drate, time, constr)
        chi2_prob = chi2.sf(chi2_TS, chi2_dof)
      else:
        print('eclipse fit error: negative values for errors squared')
    else:
      print('eclipse fit warning: leastsq did not converge')
  else:
    print('eclipse fit warning: brute did not converge')
  if chi2_TS >= 0.:
    print('eclipse fit: Drop=%.3g+/-%.2g Time_Init=%.3g+/-%.2g Time_Exit=%.3g+/-%.2g Const=%.3g+/-%.2g chi2(%d)=%.1f' %
          (drop, drop_err, t_init, t_init_err, t_exit, t_exit_err, cons, cons_err, chi2_dof, chi2_TS))
  return [chi2_TS, chi2_prob, chi2_dof,
          cons, cons_err, drop, drop_err, t_init, t_init_err, t_exit, t_exit_err]
def locate_eclipse(rate, drate, time, constr):
  """Coarse grid search for the eclipse parameters.

  Evaluates ``res_square_eclipse`` with ``scipy.optimize.brute`` on a grid
  over (drop, t_init, t_exit, const) and returns brute's full output
  ``(x0, fval, grid, Jout)``. Grid ranges (half-open, as built by
  ``np.mgrid``):

  * drop from 0 towards max(rate) - min(rate), in ``n_steps`` steps;
  * t_init from ti_min towards te_max - dtime, and t_exit from
    ti_min + dtime towards te_max, both with step
    2 * (te_max - ti_min) / n_points;
  * const from mean(rate) towards max(rate), in ``n_steps`` steps.

  Any range whose step is zero (e.g. a light curve with all rates equal)
  is collapsed to a single grid point. Otherwise the grid construction
  would divide by zero and raise.
  """
  n_points = np.size(rate)
  n_steps = 12
  ti_min, te_max, dtime = constr

  def span(start, stop, step):
    # np.mgrid (used by brute) divides the range by the step; a complex
    # step of 1j instead requests exactly one point, located at ``start``.
    if step != 0:
      return slice(start, stop, step)
    return slice(start, start, 1j)

  t_step = 2. * (te_max - ti_min) / n_points
  tirange = span(ti_min, te_max - dtime, t_step)
  terange = span(ti_min + dtime, te_max, t_step)
  r_min = np.min(rate)
  r_max = np.max(rate)
  r_ave = np.mean(rate)
  crange = span(r_ave, r_max, (r_max - r_ave) / float(n_steps))
  drange = span(0., r_max - r_min, (r_max - r_min) / float(n_steps))
  return opt.brute(res_square_eclipse,
                   ranges=(drange, tirange, terange, crange),
                   args=(rate, drate, time, constr), full_output=True)

def get_eclipse_constraints(time):
  """Return ``(ti_min, te_max, dtime)`` bounding the eclipse start/end times.

  ``dtime`` is half the spacing of the first two bins. Normally the eclipse
  may start no earlier than ``time[0] + dtime`` and end no later than
  ``time[-1] - dtime``. If the total span is not longer than ``dtime``, a
  window of width ``dtime`` centred on the span midpoint is used instead.
  """
  half_bin = (time[1] - time[0]) / 2.
  first, last = time[0], time[-1]
  if not last - first > half_bin:
    middle = (last + first) / 2.
    return middle - half_bin / 2., middle + half_bin / 2., half_bin
  return first + half_bin, last - half_bin, half_bin

def bind_eclipse_params(params, constr):
  """Clamp eclipse parameters ``(drop, t_init, t_exit, const)``.

  ``constr`` is ``(ti_min, te_max, dt)``. The drop is made non-negative.
  t_init is kept within ``[ti_min, te_max - dt]``. t_exit is then kept
  within ``[t_init + dt, te_max]``, using the already-clamped t_init. The
  constant is returned unchanged.
  """
  lo, hi, width = constr
  depth, t_in, t_out, level = params
  depth = 0. if depth < 0. else depth
  # ingress first: the egress lower bound depends on it
  if t_in < lo:
    t_in = lo
  elif t_in > hi - width:
    t_in = hi - width
  if t_out < t_in + width:
    t_out = t_in + width
  elif t_out > hi:
    t_out = hi
  return depth, t_in, t_out, level

def residual_eclipse(params, y, dy, x, constr):
  """Return the normalised residuals ``(y - model) / dy`` of the eclipse model.

  Parameters are first clamped with ``bind_eclipse_params``. Bins have width
  ``2 * constr[2]``. The model is the constant minus the drop times the
  fraction of each bin covered by the eclipse interval ``[t_init, t_exit]``,
  so partially eclipsed bins get a partial drop.
  """
  depth, t_in, t_out, level = bind_eclipse_params(params, constr)
  width = 2. * constr[2]
  center = (t_in + t_out) / 2.
  reach = (width + (t_out - t_in)) / 2.
  covered = np.clip((reach - np.fabs(x - center)) / width, 0., 1.)
  model = level - depth * covered
  return (y - model) / dy

def res_square_eclipse(params, y, dy, x, constr):
  """Return the chi-square of the eclipse model.

  This is the sum of squared normalised residuals from ``residual_eclipse``.
  """
  pulls = residual_eclipse(params, y, dy, x, constr)
  chisq = np.inner(pulls, pulls)
  return chisq

def locate_flare(rate, drate, time, constr):
  """Coarse grid search for the flare parameters.

  Evaluates ``res_square_flare`` with ``scipy.optimize.brute`` on a grid over
  (norm, t_burst, t_decay, const) and returns brute's full output
  ``(x0, fval, grid, Jout)``. Grid ranges (half-open, as built by
  ``np.mgrid``):

  * norm from 0 towards max(rate) - min(rate), in ``n_steps`` steps;
  * t_burst from tb_min towards tb_max, step 2 * (tb_max - tb_min) / n_points;
  * t_decay from tb_min towards (tb_min + tb_max) / 2, in ``n_steps`` steps;
  * const from min(rate) towards mean(rate), in ``n_steps`` steps.

  Any range whose step is zero (e.g. a light curve with all rates equal)
  is collapsed to a single grid point. Otherwise the grid construction
  would divide by zero and raise.
  """
  n_points = np.size(rate)
  n_steps = 12
  tb_min, tb_max, dtime = constr

  def span(start, stop, step):
    # np.mgrid (used by brute) divides the range by the step; a complex
    # step of 1j instead requests exactly one point, located at ``start``.
    if step != 0:
      return slice(start, stop, step)
    return slice(start, start, 1j)

  tbrange = span(tb_min, tb_max, 2. * (tb_max - tb_min) / n_points)
  tdrange = span(tb_min, (tb_min + tb_max) / 2.,
                 (tb_max - tb_min) / float(n_steps))
  r_min = np.min(rate)
  r_max = np.max(rate)
  r_ave = np.mean(rate)
  crange = span(r_min, r_ave, (r_ave - r_min) / float(n_steps))
  nrange = span(0., r_max - r_min, (r_max - r_min) / float(n_steps))
  return opt.brute(res_square_flare,
                   ranges=(nrange, tbrange, tdrange, crange),
                   args=(rate, drate, time, constr), full_output=True)

def get_flare_constraints(time):
  """Return ``(tb_min, tb_max, dtime)`` bounding the flare start time.

  ``tb_min`` is the midpoint of the first two bin times. ``tb_max`` is the
  midpoint of the third-to-last and second-to-last bin times. ``dtime`` is
  half the spacing of the first two bins. Needs at least three bins.
  """
  first, second = time[0], time[1]
  half_bin = (second - first) / 2.
  earliest = (first + second) / 2.
  latest = (time[-3] + time[-2]) / 2.
  return earliest, latest, half_bin

def bind_flare_params(params, constr):
  """Clamp flare parameters ``(norm, t_burst, t_decay, const)``.

  ``constr`` is ``(tb_min, tb_max, dt)``. The normalisation is made
  non-negative. t_burst is kept within ``[tb_min, tb_max]`` and t_decay
  within ``[dt, 1e6]``. The constant is returned unchanged.
  """
  lo, hi, width = constr
  amp, t_b, tau, level = params

  def clamp(value, low, high):
    # lower bound checked first, matching an if/elif chain
    if value < low:
      return low
    if value > high:
      return high
    return value

  amp = 0. if amp < 0. else amp
  t_b = clamp(t_b, lo, hi)
  tau = clamp(tau, width, 1.e6)
  return amp, t_b, tau, level

def residual_flare(params, y, dy, x, constr):
  """Return the normalised residuals ``(y - model) / dy`` of the flare model.

  Parameters are first clamped with ``bind_flare_params``. The model is
  ``norm * rise * decay + const``, where:

  * ``rise`` ramps from 0 to 1 across one bin width (``2 * constr[2]``)
    centred on t_burst;
  * ``decay`` is ``exp(-(x - t_burst) / t_decay)``, with the exponent's
    argument clipped to [-2, 10].
  """
  amp, t_b, tau, level = bind_flare_params(params, constr)
  width = 2. * constr[2]
  lag = x - t_b
  rise = np.clip(lag / width + .5, 0., 1.)
  decay = np.exp(-1. * np.clip(lag / tau, -2., 10.))
  model = amp * rise * decay + level
  return (y - model) / dy

def res_square_flare(params, y, dy, x, constr):
  """Return the chi-square of the flare model.

  This is the sum of squared normalised residuals from ``residual_flare``.
  """
  pulls = residual_flare(params, y, dy, x, constr)
  chisq = np.inner(pulls, pulls)
  return chisq


def residual_linear(params, y, dy, x):
  """Return ``(y - (c + b * x)) / dy`` for ``params = (c, b)``."""
  intercept, slope = params
  model = slope * x + intercept
  return (y - model) / dy


