#! /usr/bin/env python
"""
Protein-Ligand Interaction Profiler - Analyze and visualize protein-ligand interactions in PDB files.
plipcmd - Main script for PLIP command line execution.
Copyright 2014-2015 Sebastian Salentin

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Compatibility
from __future__ import print_function

# Own modules
try:
    from plip.modules.preparation import *
    from plip.modules.visualize import visualize_in_pymol
    from plip.modules.plipremote import VisualizerData
    from plip.modules.report import TextOutput
    from plip.modules import config
    from plip.modules.mp import parallel_fn
except ImportError:
    from modules.preparation import *
    from modules.visualize import visualize_in_pymol, PyMOLComplex
    from modules.plipremote import VisualizerData
    from modules.report import TextOutput
    from modules import config
    from modules.mp import parallel_fn

# Python standard library
import sys
import argparse
from argparse import ArgumentParser
import urllib2
import time
import multiprocessing
import json

# External libraries
import lxml.etree as et

__version__ = '1.3.1'
# Description text handed to the argument parser; the adjacent literals are joined first,
# then the version number is filled into the single %s placeholder.
descript = ("Protein-Ligand Interaction Profiler (PLIP) v%s "
            "is a command-line based tool to analyze interactions in a protein-ligand complex. "
            "If you are using PLIP in your work, please cite: "
            "Salentin,S. et al. PLIP: fully automated protein-ligand interaction profiler. "
            "Nucl. Acids Res. (1 July 2015) 43 (W1): W443-W447. doi: 10.1093/nar/gkv315") % __version__


def threshold_limiter(aparser, arg):
    """Convert a threshold given on the command line to float and validate it.

    :param aparser: Parser object; its ``error`` method is called with a message if the value is rejected
    :param arg: Raw threshold value, anything accepted by ``float``
    :returns: The threshold as float
    """
    value = float(arg)
    # The chained comparison is also False for NaN and infinity, which a plain `value <= 0` test lets through
    if not 0 < value < float('inf'):
        aparser.error("All thresholds have to be finite values larger than zero.")
    return value


def check_pdb_status(pdbid):
    """Returns the status and up-to-date entry in the PDB for a given PDB ID.

    Queries the idStatus service of the RCSB server and reads the 'record' elements of the XML answer.

    :param pdbid: PDB ID to look up
    :returns: List [status, current_pdbid]. status is the 'status' attribute of the last record (None if the
              answer contains no record). current_pdbid is the lowercased ID, taken from the 'replacedBy'
              attribute for obsolete entries and from the input otherwise.
    """
    url = 'http://www.rcsb.org/pdb/rest/idStatus?structureId=%s' % pdbid
    xmlf = urllib2.urlopen(url)
    try:
        xml = et.parse(xmlf)
    finally:
        xmlf.close()  # Release the connection even if parsing the answer fails
    status = None
    current_pdbid = pdbid
    for df in xml.xpath('//record'):
        status = df.attrib['status']  # Status of an entry can be either 'UNKNOWN', 'OBSOLETE', or 'CURRENT'
        if status == 'OBSOLETE':
            current_pdbid = df.attrib['replacedBy']  # Contains the up-to-date PDB ID for obsolete entries
    return [status, current_pdbid.lower()]


def fetch_pdb(pdbid):
    """Get the newest entry from the RCSB server for the given PDB ID.

    Calls sysexit with code 3 if the server reports the ID as unknown and with code 5 if no file in
    PDB format is delivered for the entry.

    :param pdbid: PDB ID to download (case-insensitive)
    :returns: List [pdbfile, current_entry] with the downloaded file content and the PDB ID it belongs to
    """
    pdbid = pdbid.lower()
    write_message('\nChecking status of PDB ID %s ... ' % pdbid)
    state, current_entry = check_pdb_status(pdbid)  # Get state and current PDB ID

    if state == 'OBSOLETE':
        write_message('entry is obsolete, getting %s instead.\n' % current_entry)
    elif state == 'CURRENT':
        write_message('entry is up to date.\n')
    elif state == 'UNKNOWN':
        sysexit(3, 'Invalid PDB ID (Entry does not exist on PDB server)\n')
    write_message('Downloading file from PDB ... ')
    pdburl = 'http://www.rcsb.org/pdb/files/%s.pdb' % current_entry  # Get URL for current entry
    response = urllib2.urlopen(pdburl)
    try:
        pdbfile = response.read()
    finally:
        response.close()  # Do not leave the connection open, also if reading fails
    # If no PDB file is available, a text is now shown with "We're sorry, but ..."
    # Could previously be distinguished by an HTTP error
    if 'sorry' in pdbfile:
        sysexit(5, "No file in PDB format available from wwPDB for the given PDB ID.\n")
    return [pdbfile, current_entry]


def _visualizer_data_for_sites(mol):
    """Build a VisualizerData object for each binding site of mol with at least one interacting residue."""
    return [VisualizerData(mol, site) for site in sorted(mol.interaction_sets)
            if len(mol.interaction_sets[site].interacting_res) != 0]


def process_pdb(pdbfile, outpath):
    """Analysis of a single PDB file. Can generate textual reports, XML, PyMOL session files and images as output.

    Which outputs are written is controlled by the config flags XML, TXT, PYMOL, PICS and VISJSON.

    :param pdbfile: Path to the structure file to analyze
    :param outpath: Folder for all result files, created if it does not exist
    """
    startmessage = '\nStarting analysis of %s\n' % pdbfile.split('/')[-1]
    write_message(startmessage)
    write_message('='*len(startmessage)+'\n')
    mol = PDBComplex()
    mol.output_path = outpath
    mol.load_pdb(pdbfile)
    # #@todo Offers possibility for filter function from command line (by ligand chain, position, hetid)
    for ligand in mol.ligands:
        mol.characterize_complex(ligand)
    excluded = mol.excluded
    create_folder_if_not_exists(outpath)

    # Begin constructing the XML tree
    report = et.Element('report')
    header = (('plipversion', __version__),
              ('pdbid', mol.pymol_name.upper()),
              ('filetype', mol.filetype.upper()),
              ('pdbfile', mol.sourcefiles['pdbcomplex']),
              ('pdbfixes', str(mol.information['pdbfixes'])),
              ('filename', str(mol.sourcefiles['filename'])))
    for tag, text in header:
        et.SubElement(report, tag).text = text
    exligs = et.SubElement(report, 'excluded_ligands')
    for i, exlig in enumerate(excluded):
        e = et.SubElement(exligs, 'excluded_ligand', id=str(i + 1))
        e.text = exlig

    # Write header of rST file
    title = 'Prediction of noncovalent interactions for PDB structure %s' % mol.pymol_name.upper()
    textlines = [title,
                 "=" * len(title),
                 'Created on %s using PLIP v%s\n' % (time.strftime("%Y/%m/%d"), __version__)]
    if len(excluded) != 0:
        textlines.append('Excluded molecules as ligands: %s\n' % ','.join(excluded))

    # Never use more workers than there are binding sites. The limit is kept local on purpose: writing it back
    # to config.MAXTHREADS would permanently lower the value for all following structures in batch mode.
    maxthreads = min(config.MAXTHREADS, len(mol.interaction_sets))

    ####
    # Create JSON dump of PyMOL helper class
    ###

    if config.VISJSON:
        complexes = _visualizer_data_for_sites(mol)

        with open('%s/plipvis.json' % outpath, 'w') as f:
            cplx_dict = {}
            for cplx in complexes:
                cplx_dict[cplx.uid] = cplx.to_json()
            json.dump(cplx_dict, f)

    ######################################
    # PyMOL Visualization (parallelized) #
    ######################################

    if config.PYMOL or config.PICS:
        complexes = _visualizer_data_for_sites(mol)
        if maxthreads > 1:
            write_message('\nGenerating visualizations in parallel on %i cores ...' % maxthreads)
            parfn = parallel_fn(visualize_in_pymol)
            parfn(complexes, processes=maxthreads)
        else:
            for plcomplex in complexes:
                visualize_in_pymol(plcomplex)

    ##################################################################
    # Generate XML- and rST-formatted reports for each binding site. #
    ##################################################################

    for i, site in enumerate(sorted(mol.interaction_sets)):
        s = mol.interaction_sets[site]
        bindingsite = TextOutput(s).generate_xml()
        bindingsite.set('id', str(i + 1))
        bindingsite.set('has_interactions', 'False')
        report.insert(i + 1, bindingsite)
        for itype in TextOutput(s).generate_txt():
            textlines.append(itype)
        if not s.no_interactions:
            bindingsite.set('has_interactions', 'True')
        else:
            textlines.append('No interactions detected.')
        sys.stdout = sys.__stdout__  # Change back to original stdout, gets changed when PyMOL has been used before

    tree = et.ElementTree(report)
    if config.XML:  # Generate report in xml format
        tree.write('%s/report.xml' % outpath, pretty_print=True, xml_declaration=True)

    if config.TXT:  # Generate report in txt (rst) format
        with open('%s/report.txt' % outpath, 'w') as f:
            f.writelines(textline + '\n' for textline in textlines)

def download_structure(inputpdbid):
    """Download the PDB structure belonging to a PDB ID into the base output folder.

    The ID is checked for the correct format first; format problems and a ValueError raised
    during the download are both reported via sysexit with code 3.
    Returns the path of the written file together with the PDB ID that was actually downloaded.
    """
    try:
        wrong_format = len(inputpdbid) != 4 or extract_pdbid(inputpdbid.lower()) == 'UnknownProtein'
        if wrong_format:
            sysexit(3, 'Invalid PDB ID (Wrong format)\n')
        content, entry = fetch_pdb(inputpdbid.lower())
        target = tilde_expansion('%s/%s.pdb' % (config.BASEPATH.rstrip('/'), entry))
        create_folder_if_not_exists(config.BASEPATH)
        with open(target, 'w') as handle:
            handle.write(content)
        write_message('file downloaded as %s\n\n' % target)
    except ValueError:  # Invalid PDB ID, cannot fetch from RCSB server
        sysexit(3, 'Invalid PDB ID (Entry does not exist)\n')
    else:
        return target, entry

def remove_duplicates(slist):
    """Checks input lists for duplicates and returns a list with unique entries.

    The first occurrence of each entry is kept and the input order is preserved, so that
    structures are processed in the order they were given. Entries have to be hashable.
    A message is written if at least one duplicate was dropped.
    """
    seen = set()
    unique = []
    for entry in slist:
        if entry not in seen:
            seen.add(entry)
            unique.append(entry)
    difference = len(slist) - len(unique)
    if difference == 1:
        write_message("Removed one duplicate entry from input list.\n")
    if difference > 1:
        write_message("Removed %i duplicate entries from input list.\n" % difference)
    return unique


def main(inputstructs, inputpdbids):
    """Main function. Calls functions for processing, report generation and visualization.

    :param inputstructs: List of paths to structure files, or None to work on PDB IDs instead
    :param inputpdbids: List of PDB IDs to download and analyze, only read if inputstructs is None

    If more than one structure is given, each result is written to its own subfolder of config.BASEPATH.
    """
    pdbid, pdbpath = None, None
    # #@todo For multiprocessing, implement better stacktracing for errors
    # Print title and version
    title = "* Protein-Ligand Interaction Profiler v%s *" % __version__
    write_message('\n' + '*' * len(title) + '\n')
    write_message(title)
    write_message('\n' + '*' * len(title) + '\n\n')

    if inputstructs is not None:  # Process PDB file(s)
        num_structures = len(inputstructs)
        inputstructs = remove_duplicates(inputstructs)
        for inputstruct in inputstructs:
            if os.path.getsize(inputstruct) == 0:
                sysexit(2, 'Empty PDB file\n')  # Exit if input file is empty
            if num_structures > 1:
                # Take the file name first and strip the extension afterwards. Splitting the whole path at the
                # first dot gives an empty or wrong folder name for './x/1abc.pdb' or '/home/j.doe/1abc.pdb'.
                basename = inputstruct.split('/')[-1].split('.')[0]
                config.OUTPATH = '/'.join([config.BASEPATH, basename])
            process_pdb(inputstruct, config.OUTPATH)
    else:  # Try to fetch the current PDB structure(s) directly from the RCSB server
        num_pdbids = len(inputpdbids)
        inputpdbids = remove_duplicates(inputpdbids)
        for inputpdbid in inputpdbids:
            pdbpath, pdbid = download_structure(inputpdbid)
            if num_pdbids > 1:
                # Sort results into folders named after the two middle characters of the PDB ID
                config.OUTPATH = '/'.join([config.BASEPATH, pdbid[1:3].upper(), pdbid.upper()])
            process_pdb(pdbpath, config.OUTPATH)

    if (pdbid is not None or inputstructs is not None) and config.BASEPATH is not None:
        if config.BASEPATH in ['.', './']:
            write_message('\nFinished analysis. Find the result files in the working directory.\n\n')
        else:
            write_message('\nFinished analysis. Find the result files in %s\n\n' % config.BASEPATH)

if __name__ == '__main__':

    ##############################
    # Parse command line arguments
    ##############################

    parser = ArgumentParser(prog="PLIP", description=descript)
    pdbstructure = parser.add_mutually_exclusive_group(required=True)  # Needs either PDB ID or file
    pdbstructure.add_argument("-f", "--file", dest="input", nargs="+")
    pdbstructure.add_argument("-i", "--input", dest="pdbid", nargs="+")
    parser.add_argument("-o", "--out", dest="outpath", default="./")
    parser.add_argument("-v", "--verbose", dest="verbose", default=False, help="Set verbose mode", action="store_true")
    parser.add_argument("-p", "--pics", dest="pics", default=False, help="Additional pictures", action="store_true")
    parser.add_argument("-x", "--xml", dest="xml", default=False, help="Generate report file in XML format",
                        action="store_true")
    parser.add_argument("-t", "--txt", dest="txt", default=False, help="Generate report file in TXT (RST) format",
                        action="store_true")
    parser.add_argument("-y", "--pymol", dest="pymol", default=False, help="Additional PyMOL session files",
                        action="store_true")
    parser.add_argument("--maxthreads", dest="maxthreads", default=multiprocessing.cpu_count(),
                        help="Set maximum number of main threads (number of binding sites processed simultaneously). "
                             "If not set, PLIP uses all available CPUs if possible.",
                        type=int)
    parser.add_argument("--breakcomposite", dest="breakcomposite", default=False,
                        help="Don't combine ligand fragments with covalent bonds but treat them as single ligands "
                             "for the analysis.",
                        action="store_true")
    parser.add_argument("--altlocation", dest="altlocation", default=False,
                        help="Also consider alternate locations for atoms (e.g. alternate conformations).",
                        action="store_true")
    parser.add_argument("--debug", dest="debug", default=False,
                        help="Turn on DEBUG mode with extended log.",
                        action="store_true")
    parser.add_argument("--nofix", dest="nofix", default=False,
                        help="Turns off fixing of PDB files.",
                        action="store_true")
    # Optional threshold arguments, not shown in help
    # Each threshold carries the name of its option (and config attribute) and a type selecting the range check
    thr = namedtuple('threshold', 'name type')
    thresholds = [thr(name='aromatic_planarity', type='angle'),
                  thr(name='hydroph_dist_max', type='distance'), thr(name='hbond_dist_max', type='distance'),
                  thr(name='hbond_don_angle_min', type='angle'), thr(name='pistack_dist_max', type='distance'),
                  thr(name='pistack_ang_dev', type='other'), thr(name='pistack_offset_max', type='distance'),
                  thr(name='pication_dist_max', type='distance'), thr(name='saltbridge_dist_max', type='distance'),
                  thr(name='halogen_dist_max', type='distance'), thr(name='halogen_acc_angle', type='angle'),
                  thr(name='halogen_don_angle', type='angle'), thr(name='halogen_angle_dev', type='other'),
                  thr(name='water_bridge_mindist', type='distance'), thr(name='water_bridge_maxdist', type='distance'),
                  thr(name='water_bridge_omega_min', type='angle'), thr(name='water_bridge_omega_max', type='angle'),
                  thr(name='water_bridge_theta_min', type='angle')]
    for t in thresholds:
        parser.add_argument('--%s' % t.name, dest=t.name, type=lambda val: threshold_limiter(parser, val),
                            help=argparse.SUPPRESS)

    arguments = parser.parse_args()
    config.VERBOSE = bool(arguments.verbose or arguments.debug)  # Debug mode implies verbose output
    config.DEBUG = bool(arguments.debug)
    config.MAXTHREADS = arguments.maxthreads
    config.XML = arguments.xml
    config.TXT = arguments.txt
    config.PICS = arguments.pics
    config.PYMOL = arguments.pymol
    config.OUTPATH = arguments.outpath
    # Make sure the output path ends with a slash before expanding a leading tilde
    config.OUTPATH = tilde_expansion("".join([config.OUTPATH, '/'])
                                     if not config.OUTPATH.endswith('/') else config.OUTPATH)
    config.BASEPATH = config.OUTPATH  # Used for batch processing
    config.BREAKCOMPOSITE = arguments.breakcomposite
    config.ALTLOC = arguments.altlocation
    config.NOFIX = arguments.nofix
    # Assign values to global thresholds
    for t in thresholds:
        tvalue = getattr(arguments, t.name)
        if tvalue is not None:
            if t.type == 'angle' and not 0 < tvalue < 180:  # Check value for angle thresholds
                parser.error("Threshold for angles need to have values within 0 and 180.")
            if t.type == 'distance':
                if tvalue > 10:  # Check value for distance thresholds
                    parser.error("Threshold for distances must not be larger than 10 Angstrom.")
                elif tvalue > config.BS_DIST + 1:  # Dynamically adapt the search space for binding site residues
                    config.BS_DIST = tvalue + 1
            setattr(config, t.name.upper(), tvalue)
    # Check additional conditions for interdependent thresholds
    if not config.HALOGEN_ACC_ANGLE > config.HALOGEN_ANGLE_DEV:
        parser.error("The halogen acceptor angle has to be larger than the halogen angle deviation.")
    if not config.HALOGEN_DON_ANGLE > config.HALOGEN_ANGLE_DEV:
        parser.error("The halogen donor angle has to be larger than the halogen angle deviation.")
    if not config.WATER_BRIDGE_MINDIST < config.WATER_BRIDGE_MAXDIST:
        parser.error("The water bridge minimum distance has to be smaller than the water bridge maximum distance.")
    if not config.WATER_BRIDGE_OMEGA_MIN < config.WATER_BRIDGE_OMEGA_MAX:
        parser.error("The water bridge omega minimum angle has to be smaller than the water bridge omega maximum angle")
    # --file takes one or more paths, so the tilde expansion has to be applied to each path of the list
    if arguments.input is not None:
        expanded_paths = [tilde_expansion(path) for path in arguments.input]
    else:
        expanded_paths = None
    main(expanded_paths, arguments.pdbid)  # Start main script
