#
# Copyright (C) 2004 Mekensleep
#
# Mekensleep
# 24 rue vieille du temple
# 75004 Paris
#       licensing@mekensleep.com
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#
# Authors:
#  Loic Dachary <loic@gnu.org>
#  Henry Precheur <henry@precheur.org>
#
# 
from struct import pack, unpack
from twisted.internet import reactor, protocol
import MySQLdb

from underware.packets import *
from underware.protocol import UGAMEProtocol
from underware.user import User

# Delay, in seconds (passed to reactor.callLater), that a client has to answer
# a PacketAuthRequest before the server sends PacketAuthExpires and gives up.
AUTH_TIMEOUT = 180

class UGAMEServer(UGAMEProtocol):
    """Server side of the UGAME protocol for one client connection.

    Keeps track of the connected ``User`` and implements an on-demand
    authentication handshake: ``askAuth`` (its caller is not in this file;
    presumably invoked when a packet needs a privilege the user lacks)
    sends PacketAuthRequest, queues the triggering packet and routes
    incoming packets to ``auth`` until PACKET_LOGIN, PACKET_AUTH_CANCEL or
    the AUTH_TIMEOUT expiry ends the exchange.
    """

    def __init__(self):
        self.user = User()
        # reactor.callLater handle for the pending authentication timeout,
        # or None when no authentication request is outstanding.
        self.__login_timer = None
        UGAMEProtocol.__init__(self)

    def isAuthorized(self, type):
        """Return whether the current user holds the privilege required for
        packets of ``type``, as configured in the factory (authGetLevel)."""
        return self.user.hasPrivilege(self.factory.authGetLevel(type))

    def askAuth(self, packet):
        """Ask the client to authenticate before ``packet`` is processed.

        Sends PacketAuthRequest, saves ``packet`` in a backlog together with
        the currently expected handler (restored once authentication ends),
        arms the AUTH_TIMEOUT timer and makes ``auth`` the expected handler.
        """
        self.sendPacketVerbose(PacketAuthRequest())
        context = {
            "packets": [ packet ],
            "handler": self._expected_handler,
            "args": self._expected_args,
            "kwargs": self._expected_kwargs
            }
        # An earlier timer must not survive: it would fire later with a
        # stale context and clobber the expected handler.
        self._cancelLoginTimer()
        self.__login_timer = reactor.callLater(AUTH_TIMEOUT, self.__auth_expires, context)
        self.expect(self.auth, context)

    def _cancelLoginTimer(self):
        """Cancel the pending authentication timeout, if any, and forget it."""
        if self.__login_timer and self.__login_timer.active():
            self.__login_timer.cancel()
        self.__login_timer = None

    def __auth_expires(self, context):
        """Timeout callback: notify the client and restore the handler that
        was expected before askAuth. Queued packets are not replayed."""
        self.sendPacketVerbose(PacketAuthExpires())
        self.__login_timer = None
        if self.factory.verbose > 1:
            print("context = %s" % (context,))
        self.expect(context["handler"], *context["args"], **context["kwargs"])

    def login(self, info):
        """Record the authenticated identity and send it to the client.

        ``info`` is the ``(serial, name, privilege)`` tuple returned by the
        factory's ``auth``.
        """
        (serial, name, privilege) = info
        self.user.serial = serial
        self.user.name = name
        self.user.privilege = privilege
        self.sendPacketVerbose(PacketSerial(serial = self.user.serial))
        if self.factory.verbose:
            print("user %s/%d logged in" % ( self.user.name, self.user.serial ))

    def logout(self):
        """Delegate to User.logout()."""
        self.user.logout()

    def auth(self, packet, context = None):
        """Handle a packet received while an authentication request is pending.

        - Any packet other than PACKET_LOGIN / PACKET_AUTH_CANCEL is inserted
          at the front of the backlog (or ignored if there is no backlog) and
          ``auth`` stays the expected handler.
        - PACKET_AUTH_CANCEL stops the timeout and restores the previous
          handler; the backlog is dropped.
        - PACKET_LOGIN is checked by User.checkNameAndPassword and then by the
          factory. PacketAuthOk or PacketAuthRefused is sent, backlog packets
          the user is now allowed to send are replayed, and the previous
          handler is restored.

        ``context`` is the dict built by askAuth; when it is None there is
        nothing to replay or restore.
        """
        if packet.type not in (PACKET_LOGIN, PACKET_AUTH_CANCEL):
            if context and "packets" in context:
                if self.factory.verbose:
                    print("packet prepended to backlog")
                context["packets"].insert(0, packet)
            else:
                if self.factory.verbose:
                    print("expected PACKET_LOGIN or PACKET_AUTH_CANCEL, ignored")
            self.expect(self.auth, context)
            return

        self._cancelLoginTimer()

        if packet.type == PACKET_AUTH_CANCEL:
            # Without a context there is no previous handler to restore.
            if context:
                self.expect(context["handler"], *context["args"], **context["kwargs"])
            return

        info = False
        if self.user.checkNameAndPassword(packet.name, packet.password):
            info = self.factory.auth(packet.name, packet.password)
        if info:
            self.sendPacketVerbose(PacketAuthOk())
            self.login(info)
        else:
            self.sendPacketVerbose(PacketAuthRefused())

        if not context:
            return
        # A context without a backlog is valid: there is simply nothing to replay.
        for pending in context.get("packets") or []:
            if self.factory.verbose > 1:
                print("PACKET %s " % (pending,))
            if pending and self.isAuthorized(pending.type):
                # Replayed packets act on behalf of the user that just
                # authenticated, whatever serial the client put in them.
                if hasattr(pending, "serial"):
                    pending.serial = self.getSerial()
                self._handleConnection(pending)
        self.expect(context["handler"], *context["args"], **context["kwargs"])

    def userRemove(self, packet):
        """Remove the connected user's account and drop the connection, but
        only when ``packet.serial`` is the connected user's own serial."""
        if self.getSerial() == packet.serial:
            self.factory.userRemove(self.user)
            self.transport.loseConnection()

    def getSerial(self):
        """Return the serial of the connected user."""
        return self.user.serial

    def getName(self):
        """Return the name of the connected user."""
        return self.user.name

    def sendPacket(self, packet):
        """Serialize ``packet`` with its pack() method and write it out."""
        self.transport.write(packet.pack())

    def sendPacketVerbose(self, packet):
        """sendPacket, also printing the packet when factory.verbose > 1."""
        if self.factory.verbose > 1:
            print("sendPacket: %s" % str(packet))
        self.sendPacket(packet)
        
class UGAMEServerFactory(protocol.ServerFactory):
    """Twisted server factory holding state shared by protocol instances.

    Protocol instances (e.g. UGAMEServer) reach it through ``self.factory``
    to read ``verbose``, look up per-packet-type privilege levels
    (``authGetLevel``) and authenticate, create or remove rows of the MySQL
    ``users`` table.
    """

    def __init__(self, *args, **kwargs):
        """Open the MySQL connection described by ``kwargs["database"]``.

        ``database`` must be a mapping with ``host``, ``user``, ``password``
        and ``name`` keys. Positional and other keyword arguments are ignored.
        """
        # Packet type -> privilege level required to send it.
        self.type2auth = {}
        self.database = kwargs["database"]
        self.client_count = 0
        database = self.database
        self.db = MySQLdb.connect(host = database["host"],
                                  user = database["user"],
                                  passwd = database["password"],
                                  db = database["name"])
        #
        # No explicit close: the connection is released when the
        # connection object is destroyed.
        #
        print("Database connection to %s/%s open" % ( database["host"], database["name"] ))
        self.verbose = 0

    def authSetLevel(self, type, level):
        """Require privilege ``level`` for packets of ``type``."""
        self.type2auth[type] = level

    def authGetLevel(self, type):
        """Return the level registered for ``type``, or False if none."""
        return self.type2auth.get(type, False)

    def auth(self, name, password):
        """Authenticate ``name``/``password`` against the ``users`` table.

        An unknown ``name`` is created on the fly with ``password`` and is
        reported with User.REGULAR privilege.

        Returns ``(serial, name, privilege)`` on success, or False when the
        password does not match or several rows share ``name``.
        """
        cursor = self.db.cursor()
        try:
            # Parameterized: ``name`` comes straight from the client's login
            # packet and must not be spliced into the SQL text.
            cursor.execute("select serial, password, privilege from users "
                           "where name = %s", (name,))
            rows = cursor.fetchall()
        finally:
            cursor.close()

        if not rows:
            if self.verbose > 1:
                print("user %s does not exist, create it" % name)
            serial = self.userCreate(name, password)
            return (serial, name, User.REGULAR)
        if len(rows) > 1:
            print("more than one row for %s" % name)
            return False

        (serial, password_sql, privilege) = rows[0]
        # NOTE(review): passwords are stored and compared in clear text.
        if password_sql != password:
            print("password mismatch for %s" % name)
            return False
        return (serial, name, privilege)

    def userCreate(self, name, password):
        """Insert a ``users`` row for ``name``/``password``; return its serial as an int."""
        if self.verbose:
            print("creating user %s" % name)
        cursor = self.db.cursor()
        try:
            cursor.execute("insert into users (name, password) values (%s, %s)",
                           (name, password))
            # NOTE(review): no commit() is issued; this relies on autocommit
            # or non-transactional tables -- confirm for the deployed schema.
            #
            # Accommodate MySQLdb versions < 1.1
            #
            if hasattr(cursor, "lastrowid"):
                serial = cursor.lastrowid
            else:
                serial = cursor.insert_id()
        finally:
            cursor.close()
        if self.verbose:
            print("create user with serial %s" % serial)
        return int(serial)

    def userRemove(self, user):
        """Delete the ``users`` row whose serial is ``user.serial``."""
        if self.verbose:
            print("removing %s" % user)
        cursor = self.db.cursor()
        try:
            cursor.execute("delete from users where serial = %s", (user.serial,))
        finally:
            cursor.close()
