# created by Dmitrey
from numpy import copy, isnan, array, argmax, abs, zeros, any, isfinite, where, asscalar
# Markup format of the docstrings in this module (reStructuredText).
__docformat__ = "restructuredtext en"
# Shared empty 1-D array; Point hands out copies of it (never the array itself)
# as constraint values / derivatives when the problem has no user-provided c or h.
empty_arr = array(())

class Point:
    """
    Lazily evaluated, cached values of a problem ``p`` at a fixed point ``x``.

    The class is used to prevent calling non-linear functions more than once
    for the same x. f, c, h obtain the objective function and the non-linear
    inequality / equality constraint values; df, dc, dh obtain their 1st
    derivatives and d2f the 2nd derivative of f. Each value is computed on
    first request through ``p`` and cached in a ``_<name>`` attribute; the
    accessors return copies so callers cannot modify the cache.

    Already known values can be supplied at construction time, either
    positionally after ``x`` (f, then mr) or by keyword (``name=value`` is
    cached as ``_name``, e.g. ``f=...`` or ``mr=...``).
    """
    # 'x' is the explicit constructor parameter; the remaining names give the
    # meaning of the extra positional constructor arguments, in order.
    __expectedArgs__ = ['x', 'f', 'mr']

    def __init__(self, p, x, *args, **kwargs):
        self.p = p
        self.x = copy(x)
        # Extra positional values start at 'f' ('x' was already consumed).
        cacheNames = self.__expectedArgs__[1:]
        if len(args) > len(cacheNames):
            raise TypeError('Point() accepts at most %d precomputed positional values (%s), got %d'
                            % (len(cacheNames), ', '.join(cacheNames), len(args)))
        for name, val in zip(cacheNames, args):
            setattr(self, '_' + name, val)
        # items() works on both Python 2 and 3 (iteritems is Python 2 only).
        for name, val in kwargs.items():
            setattr(self, '_' + name, val)

    def f(self):
        """Objective value at x (computed once via p.f)."""
        if not hasattr(self, '_f'): self._f = self.p.f(self.x)
        return copy(self._f)

    def df(self):
        """Gradient of the objective at x (computed once via p.df)."""
        if not hasattr(self, '_df'): self._df = self.p.df(self.x)
        return copy(self._df)

    def c(self, ind=None):
        """
        Non-linear inequality constraint values at x, all or those selected by ind.

        Returns an empty array if the problem has no user-provided c. A request
        for a subset (ind given) is served from the cached full vector when it
        exists; otherwise p.c is called for the subset and nothing is cached.
        """
        if not self.p.userProvided.c: return empty_arr.copy()
        if ind is None:
            if not hasattr(self, '_c'): self._c = self.p.c(self.x)
            return copy(self._c)
        if hasattr(self, '_c'): return copy(self._c[ind])
        return copy(self.p.c(self.x, ind))

    def dc(self, ind=None):
        """Derivatives of the non-linear inequality constraints (same caching rules as c)."""
        if not self.p.userProvided.c: return empty_arr.copy().reshape(0, self.p.n)
        if ind is None:
            if not hasattr(self, '_dc'): self._dc = self.p.dc(self.x)
            return copy(self._dc)
        if hasattr(self, '_dc'): return copy(self._dc[ind])
        return copy(self.p.dc(self.x, ind))

    def h(self, ind=None):
        """Non-linear equality constraint values at x (same caching rules as c)."""
        if not self.p.userProvided.h: return empty_arr.copy()
        if ind is None:
            if not hasattr(self, '_h'): self._h = self.p.h(self.x)
            return copy(self._h)
        if hasattr(self, '_h'): return copy(self._h[ind])
        return copy(self.p.h(self.x, ind))

    def dh(self, ind=None):
        """Derivatives of the non-linear equality constraints (same caching rules as c)."""
        if not self.p.userProvided.h: return empty_arr.copy().reshape(0, self.p.n)
        if ind is None:
            if not hasattr(self, '_dh'): self._dh = self.p.dh(self.x)
            return copy(self._dh)
        if hasattr(self, '_dh'): return copy(self._dh[ind])
        return copy(self.p.dh(self.x, ind))

    def d2f(self):
        """Second derivative of the objective at x (computed once via p.d2f)."""
        if not hasattr(self, '_d2f'): self._d2f = self.p.d2f(self.x)
        return copy(self._d2f)

    def lin_ineq(self):
        """Linear inequality residuals at x, as returned by p.__get_AX_Less_B_Residuals__ (cached)."""
        if not hasattr(self, '_lin_ineq'): self._lin_ineq = self.p.__get_AX_Less_B_Residuals__(self.x)
        return copy(self._lin_ineq)

    def lin_eq(self):
        """Linear equality residuals at x, as returned by p.__get_AeqX_eq_Beq_Residuals__ (cached)."""
        if not hasattr(self, '_lin_eq'): self._lin_eq = self.p.__get_AeqX_eq_Beq_Residuals__(self.x)
        return copy(self._lin_eq)

    def lb(self):
        """Lower-bound residuals p.lb - x (positive where x is below its lower bound)."""
        if not hasattr(self, '_lb'): self._lb = self.p.lb - self.x
        return copy(self._lb)

    def ub(self):
        """Upper-bound residuals x - p.ub (positive where x is above its upper bound)."""
        if not hasattr(self, '_ub'): self._ub = self.x - self.p.ub
        return copy(self._ub)

    def mr(self, retAll = False):
        """
        Return the maximal constraint residual at x (never below 0 when computed here).

        Inequality residuals (c, lin_ineq, lb, ub) are taken as is, equality
        residuals (h, lin_eq) by absolute value; NaN entries are ignored.
        With retAll=True return (value, fieldName, index) where fieldName is
        one of 'c', 'lin_ineq', 'lb', 'ub', 'h', 'lin_eq', or None (and index
        None) when no residual is positive.
        """
        return self.__mr(retAll)

    def _maxNonNaN(self, values):
        """
        Return (index, value) of the largest non-NaN entry of the 1-D array
        values, or None if all entries are NaN. Ties go to the first occurrence.
        """
        # Plain argmax would return the index of a NaN and thereby hide any
        # finite violation in the same vector.
        candidates = where(~isnan(values))[0]
        if candidates.size == 0:
            return None
        best = candidates[argmax(values[candidates])]
        return best, values[best]

    def __mr(self, retAll = False):
        if not hasattr(self, '_mr') or (retAll and not hasattr(self, '_mrInd')):
            r, fname, ind = 0, None, None
            # Scan order matters on ties: only a strictly larger value replaces r.
            for field, isEquality in (('c', False), ('lin_ineq', False), ('lb', False),
                                      ('ub', False), ('h', True), ('lin_eq', True)):
                fv = array(getattr(self, field)()).flatten()
                if fv.size == 0:
                    continue
                if isEquality:
                    fv = abs(fv)
                best = self._maxNonNaN(fv)
                if best is not None and r < best[1]:
                    (ind, r), fname = best, field
            self._mr, self._mrName, self._mrInd = r, fname, ind
        # ndarray.item() is what the removed numpy.asscalar did.
        if retAll:
            return copy(self._mr).item(), self._mrName, copy(self._mrInd).item()
        return copy(self._mr).item()

    def dmr(self, retAll = False):
        """
        Return the gradient w.r.t. x of the maximal residual found by mr().

        For equality constraints it is the gradient of |residual| (the sign is
        flipped where the residual is negative), so -dmr() points towards a
        decrease of the maximal residual. A zero vector is returned when no
        residual is positive. With retAll=True return (gradient, fieldName, index).
        """
        return self.__dmr(retAll)

    def __dmr(self, retAll = False):
        if not hasattr(self, '_dmr') or (retAll and not hasattr(self, '_dmrInd')):
            g = zeros(self.p.n)
            resType, ind = self.mr(retAll=True)[1:]
            if resType == 'lb':
                g[ind] -= 1 # N * (-1), -1 = dConstr/dx = d(lb-x)/dx
            elif resType == 'ub':
                g[ind] += 1 # N * (+1), +1 = dConstr/dx = d(x-ub)/dx
            elif resType == 'lin_ineq':
                g += self.p.A[ind]
            elif resType == 'lin_eq':
                # gradient of |Aeq[ind] x - beq[ind]|
                rr = self.p.matmult(self.p.Aeq[ind], self.x)-self.p.beq[ind]
                if rr < 0:  g -= self.p.Aeq[ind]
                else:  g += self.p.Aeq[ind]
            elif resType == 'c':
                g += self.dc(ind=ind).flatten()
            elif resType == 'h':
                # gradient of |h_ind|; the h value is served from this point's
                # cache when available instead of calling p.h again. CHECKME!!
                dh = self.dh(ind=ind).flatten()
                if self.h(ind=ind) < 0:  g -= dh
                else: g += dh
            self._dmr, self._dmrName, self._dmrInd = g, resType, ind
        if retAll:
            return copy(self._dmr), self._dmrName, copy(self._dmrInd)
        return copy(self._dmr)

    def betterThan(self, *args, **kwargs):
        """
        usage: result = involvedPoint.betterThan(pointToCompare)

        returns True if the involvedPoint is better than pointToCompare
        and False otherwise
        (if NOT better, mb same fval and same residuals or residuals less than desired contol)
        """
        return self.__betterThan__(*args, **kwargs)

    def _residualsExceed(self, threshold):
        """
        Return True if some constraint residual of this point is above threshold.

        Uses the cached max residual when available; otherwise checks the
        residual vectors one by one (absolute values for equalities) and stops
        at the first violation. NaN residuals never count as exceeding.
        """
        if hasattr(self, '_mr'):
            return self._mr > threshold
        for getResidual, isEquality in ((self.lb, False), (self.ub, False),
                                        (self.lin_eq, True), (self.lin_ineq, False),
                                        (self.h, True), (self.c, False)):
            residual = getResidual()
            if isEquality:
                residual = abs(residual)
            if any(residual > threshold):
                return True
        return False

    def __betterThan__(self, oldPoint):
        if self.p.isUC:
            return self.f() < oldPoint.f()

        contol = self.p.contol
        oldPointResidual = oldPoint.mr()
        # The new point may not violate constraints by more than this.
        criticalResidualValue = max(contol, oldPointResidual)
        if self._residualsExceed(criticalResidualValue):
            return False

        mr = self.mr()

        if not self.p.isNaNInConstraintsAllowed:
            # Fewer NaNs in the non-linear constraints wins outright.
            oldNaNs, newNaNs = oldPoint.__nNaNs__(), self.__nNaNs__()
            if oldNaNs > newNaNs: return True
            if oldNaNs < newNaNs: return False
            # TODO: check me
            if mr <= contol and oldPointResidual <= contol and newNaNs != 0:
                return mr < oldPointResidual

        if mr < oldPointResidual and contol < oldPointResidual: return True

        oldF, newF = oldPoint.f(), self.f()
        oldF_is_NaN, newF_is_NaN = isnan(oldF), isnan(newF)

        if not oldF_is_NaN:
            # a NaN objective at the new point never beats a finite one
            if newF_is_NaN: return False
            return newF < oldF
        # f(oldPoint) is NaN: any non-NaN objective wins, else compare residuals
        if newF_is_NaN: return mr < oldPointResidual
        return True

    def isFeas(self):
        """Return True if no constraint residual exceeds p.contol."""
        return self.__isFeas__()

    def __isFeas__(self):
        return not self._residualsExceed(self.p.contol)

    def __nNaNs__(self):
        """Return the number of NaN entries in the non-linear constraints c and h."""
        c, h = self.c(), self.h()
        return len(where(isnan(c))[0]) + len(where(isnan(h))[0])

