#
# Copyright (C) 2004 Mekensleep
#
# Mekensleep
# 24 rue vieille du temple
# 75004 Paris
#       licensing@mekensleep.com
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#
# Authors:
#  Loic Dachary <loic@gnu.org>
#
import sys
from string import replace
from xml.dom.pulldom import PullDOM
#
# libxml2 would do a faster/better job but can't be used until
# http://bugzilla.gnome.org/show_bug.cgi?id=142775 is resolved
#
from xml.sax import make_parser
import xml.sax.handler

from twisted.internet import reactor, protocol
from twisted.python import dispatch

# Names of the events published through the peer factory's publishEvent()
# by JabberProxyClient.elementReceived.
# AUTH_SUCCESS: a SASL <success/> element was received while authenticating.
AUTH_SUCCESS = "//event/poker3d/jabberproxy/auth_success"
# AUTH_FAILURE: a <failure/> element was received while authenticating.
AUTH_FAILURE = "//event/poker3d/jabberproxy/auth_failure"
# SESSION_START: an element carrying the id of the recorded session request
# was received after authentication.
SESSION_START = "//event/poker3d/jabberproxy/session_start"

class _StreamHandler(PullDOM):
    """Incremental SAX handler that turns an XML stream into DOM stanzas.

    PullDOM keeps the currently open nodes in self.elementStack, with the
    document at the bottom. A stack length of 2 therefore means "the stream
    root is the innermost open element" and a length of 3 means "inside a
    first level child of the root".

    The caller is notified through three callbacks:

      streamOpened(node)     the root element was opened and is a stream
      elementReceived(node)  a first level child of the root is complete
      closeStream()          the root is not a stream, or it was closed
    """

    def __init__(self, caller):
        """ caller: the object handling stream events """
        PullDOM.__init__(self)
        self._caller = caller

    def _attachLastNode(self):
        """Append the node of the latest event to the innermost open element.

        Nothing is attached while the stack holds only the document and
        the stream root (or nothing at all): the root stays open for the
        whole connection and would otherwise grow without bound, and before
        the root exists there is no event node to attach.
        """
        if len(self.elementStack) > 2:
            node = self.lastEvent[0][1]
            self.elementStack[-1].appendChild(node)

    def startElementNS(self, name, tagName , attrs):
        """Report the opening of the stream root and link nested elements
        to their parent."""
        PullDOM.startElementNS(self, name, tagName, attrs)
        depth = len(self.elementStack)
        if depth == 2:
            # The root element has just been opened
            if name == ('http://etherx.jabber.org/streams', 'stream'):
                self._caller.streamOpened(self.lastEvent[0][1])
            else:
                self._caller.closeStream()
        elif depth > 3:
            # Element nested inside a first level child: link it to its
            # parent. First level children (depth 3) are deliberately not
            # linked to the stream root.
            self.elementStack[-2].appendChild(self.elementStack[-1])

    def endElementNS(self, name, tagName):
        """Report complete first level children and the end of the stream."""
        PullDOM.endElementNS(self, name, tagName)
        depth = len(self.elementStack)
        if depth == 2:
            # A first level child of the root has just been closed
            self._caller.elementReceived(self.lastEvent[0][1])
        elif depth == 1:
            # The root itself has just been closed
            self._caller.closeStream()
        # Drop the event list accumulated by PullDOM, nothing reads it
        self.firstEvent = [None, None]
        self.lastEvent = self.firstEvent

    def processingInstruction(self, target, data):
        """Attach a processing instruction found inside a stanza.

        A processing instruction seen before the root element or directly
        inside the root is ignored, exactly like character data at these
        places, instead of raising or piling up on the root element.
        """
        PullDOM.processingInstruction(self, target, data)
        self._attachLastNode()

    def ignorableWhitespace(self, chars):
        """Attach ignorable whitespace found inside a stanza."""
        PullDOM.ignorableWhitespace(self, chars)
        self._attachLastNode()

    def characters(self, chars):
        """Attach character data found inside a stanza."""
        PullDOM.characters(self, chars)
        self._attachLastNode()

def _escapedata(data):
    """Return data with the characters & < " > replaced by the matching
    XML entity references, so that it can be written inside a double
    quoted attribute value.
    """
    # "&" has to be replaced first, otherwise the ampersand of the
    # entities introduced by the other replacements would be escaped again.
    # String methods are used rather than the deprecated string.replace().
    for raw, entity in (("&", "&amp;"),
                        ("<", "&lt;"),
                        ("\"", "&quot;"),
                        (">", "&gt;")):
        data = data.replace(raw, entity)
    return data

class JabberProxy(protocol.Protocol):
    """Behaviour shared by both ends of the Jabber proxy.

    The bytes received are parsed as an XML stream. The stream header and
    every complete first level element are serialized again and written to
    the transport of the peer protocol, or kept in self.buf until a peer
    is attached with setPeer().
    """

    def __init__(self):
        self.side = "Unknown"   # label used in verbose output
        self.peer = None        # protocol instance of the other end
        self.buf = ''           # data waiting for a peer
        self.resetParser()

    def resetParser(self):
        """Replace the XML parser with a new, namespace aware one."""
        self._parser = make_parser()
        self._parser.setContentHandler(_StreamHandler(self))
        self._parser.setFeature(xml.sax.handler.feature_namespaces, 1)

    def setPeer(self, peer):
        """Attach the other end of the proxy and send it the data buffered
        so far."""
        self.peer = peer
        self.peer.transport.write(self.buf)
        self.buf = ''

    def connectionLost(self, reason):
        """Close the connection of the peer when this connection is lost."""
        if self.peer:
            self.peer.transport.loseConnection()
        # Reset the attribute instead of deleting it: the "if self.peer"
        # tests must not raise AttributeError once the connection is gone.
        self.peer = None

    def elementReceived(self, tag):
        """Forward a complete first level element of the stream.

        tag: the DOM element built by the stream handler
        """
        self.dataForward(tag.toxml().encode('UTF-8'))

    def streamOpened(self, tag):
        """ Received the opened stream event
        Clients must override this method if they want to start the pre 1.0
        authentication process
        """
        self.version = tag.getAttribute('version')
        self._from = tag.getAttribute('from')
        self._id = tag.getAttribute('id')
        #
        # Do not rely on tag.toxml. It will end the opening tag with
        # /> instead of > because the closing tag has not been seen
        # yet. The opening tag is rebuilt by hand, attributes sorted
        # by name.
        #
        attrs = tag._get_attributes()
        parts = ["<", tag.tagName]
        for a_name in sorted(attrs.keys()):
            parts.append(" %s=\"" % a_name)
            parts.append(_escapedata(attrs[a_name].value))
            parts.append("\"")
        parts.append(">")
        buf = "".join(parts)
        self.dataForward("<?xml version='1.0' encoding='UTF-8'?>" + buf.encode('UTF-8'))

    def closeStream(self):
        """Send the closing tag of the stream and drop the connection."""
        self.transport.write("</stream:stream>")
        self.transport.loseConnection()

    def dataForward(self, data):
        """Write data to the peer, or buffer it while there is no peer."""
        if self.peer:
            self.peer.transport.write(data)
        else:
            self.buf += data

    def dataReceived(self, data):
        """Feed the bytes received to the XML parser."""
        if self.factory.verbose > 3:
            print("%s: %s" % ( self.side, data ))
        # A chunk made of this single whitespace character triggers the
        # nullMessage hook (presumably a keep-alive); it is still handed
        # to the parser like any other data.
        if data == "	":
            self.nullMessage()
        self._parser.feed(data)

    def nullMessage(self):
        """Hook called for the single character chunk tested in
        dataReceived. Does nothing by default."""
        pass


# Values of JabberProxyClient.state.
# STARTING: initial value, set on creation and again in connectionMade.
STARTING = 1
# AUTHENTICATED: set once a SASL <success/> element has been received.
AUTHENTICATED = 2
# SESSION: meant to be entered when the session request has been answered.
SESSION = 3

class JabberProxyClient(JabberProxy):
    """Jabber proxy client side

    Handles the connection to the Jabber server. In addition to forwarding
    what the server sends, it follows the progress of the login
    (STARTING, then AUTHENTICATED, then SESSION) and publishes the matching
    events through the factory of the peer.
    """

    def __init__(self, *args, **kwargs):
        JabberProxy.__init__(self, *args, **kwargs)
        self.side = "client <= server"
        self.state = STARTING

    def connectionMade(self):
        """Register with the peer, which flushes the data it buffered
        while the connection was being established."""
        self.state = STARTING
        self.peer.setPeer(self)

    def elementReceived(self, tag):
        """Forward the element, then update the login state from it.

        tag: the DOM element built by the stream handler
        """
        JabberProxy.elementReceived(self, tag)
        if self.state == STARTING:
            if ( tag.tagName == "success" and
                 tag.getAttribute("xmlns") == "urn:ietf:params:xml:ns:xmpp-sasl" ):
                # Both parsers are replaced after a successful SASL
                # authentication so that each side can parse a new
                # stream header.
                self.resetParser()
                self.peer.resetParser()
                self.state = AUTHENTICATED
                self.peer.factory.publishEvent(AUTH_SUCCESS)
            if ( tag.tagName == "failure" ):
                self.authFailed()
                self.peer.factory.publishEvent(AUTH_FAILURE)

        elif ( self.state == AUTHENTICATED and
               hasattr(self.peer, "session_id") and
               tag.getAttribute("id") == self.peer.session_id ):
            # Answer to the session request recorded by the peer.
            # This must be an assignment: a comparison here would leave
            # the state at AUTHENTICATED forever.
            self.state = SESSION
            if self.factory.verbose > 3:
                print("Session %s initiated" % self.peer.session_id)
            del self.peer.session_id
            self.sessionStart()
            self.peer.factory.publishEvent(SESSION_START, self.peer.factory)

    def authFailed(self):
        """Hook called when a failure element is received while
        authenticating. Does nothing by default."""
        pass

    def sessionStart(self):
        """Hook called when the session request has been answered.
        Does nothing by default."""
        pass

class JabberProxyClientFactory(protocol.ClientFactory):
    """Jabber proxy client side (factory)

    Builds the protocol talking to the Jabber server and links it to the
    protocol instance registered with setServer().
    """

    protocol = JabberProxyClient

    def __init__(self, *args, **kwargs):
        self.verbose = 0
        self.server = None
        # Address of the Jabber server, only used in the messages below.
        # Expected to be set by the code that initiates the connection;
        # defined here so that the messages never fail on a missing
        # attribute.
        self.host = None
        self.port = None

    def setServer(self, server):
        """Remember the protocol instance that becomes the peer of the
        protocols built by this factory."""
        self.server = server

    def buildProtocol(self, *args, **kw):
        """Build the protocol and give it the registered peer."""
        prot = protocol.ClientFactory.buildProtocol(self, *args, **kw)
        prot.setPeer(self.server)
        return prot

    def clientConnectionLost(self, connector, reason):
        """Report that the connection to the Jabber server was lost."""
        # %s instead of %d: same output for an integer port, and no
        # TypeError if the port was never set
        print("JabberProxyClientFactory: %s:%s: %s" % ( self.host, self.port, reason ))

    def clientConnectionFailed(self, connector, reason):
        """Report that the connection to the Jabber server failed."""
        print("JabberProxyClientFactory: unable to connect to %s:%s: %s" % ( self.host, self.port, reason ))

class JabberProxyServer(JabberProxy):
    """Jabber proxy server side

    Handles the connection coming from the Jabber client. Each incoming
    connection opens its own connection to the host and port configured
    on the factory.
    """

    clientProtocolFactory = JabberProxyClientFactory

    def __init__(self, *args, **kwargs):
        JabberProxy.__init__(self, *args, **kwargs)
        self.side = "client => server"

    def connectionMade(self):
        """Start connecting to the Jabber server on behalf of the client."""
        client = self.clientProtocolFactory()
        client.setServer(self)
        client.host = self.factory.host
        client.port = self.factory.port
        client.verbose = self.factory.verbose
        reactor.connectTCP(client.host, client.port,
                           client)

    def elementReceived(self, tag):
        """Record the id of the session request, then forward the element.

        The id of an <iq type="set"> element containing a <session/>
        element in the urn:ietf:params:xml:ns:xmpp-session namespace is
        stored in self.session_id while the peer is in the AUTHENTICATED
        state.

        tag: the DOM element built by the stream handler
        """
        # There may be no peer yet when the client sends elements before
        # the connection to the Jabber server is established; the element
        # is then only buffered by the forwarding code.
        peer = getattr(self, "peer", None)
        if ( peer is not None and
             peer.state == AUTHENTICATED and
             tag.tagName == "iq" and
             tag.getAttribute("type") == "set" ):
            sessions = tag.getElementsByTagName("session")
            if sessions and sessions[0].getAttribute("xmlns") == "urn:ietf:params:xml:ns:xmpp-session":
                self.session_id = tag.getAttribute("id")
        JabberProxy.elementReceived(self, tag)

    def connectionLost(self, reason):
        """Log the loss of the client connection and close the peer."""
        if self.factory.verbose:
            print("JabberProxyServer: connection from jabber client is lost")
        JabberProxy.connectionLost(self, reason)
        
class JabberProxyServerFactory(protocol.Factory, dispatch.EventDispatcher):
    """Jabber proxy server side (factory)

    Holds the address of the Jabber server the proxy connects to and
    dispatches the proxy events to the registered callbacks.
    """

    protocol = JabberProxyServer

    def __init__(self, host, port, verbose = 0):
        """host, port: address of the Jabber server
        verbose: verbosity level, 0 means quiet
        """
        self.verbose = verbose
        self.host = host
        self.port = port
        dispatch.EventDispatcher.__init__(self)

    def publishEvent(self, name, *args, **kwargs):
        """Call every callback registered for the event name with the
        given arguments.

        This works around a bug in python.dispatch: the original
        publishEvent raises an exception when no listener is registered
        for the event. Here such an event is silently ignored.
        """
        listeners = self.callbacks.get(name, [])
        for listener in listeners:
            listener(*args, **kwargs)

if __name__ == '__main__':
    # Stand alone proxy: listen on the local interface only and relay
    # to mekensleep.org. Each command line argument raises the verbosity
    # by one.
    verbosity = len(sys.argv) - 1
    port = 5222
    factory = JabberProxyServerFactory('mekensleep.org', port, verbosity)
    reactor.listenTCP(port, factory, backlog=5, interface='127.0.0.1')
    reactor.run()

