from twisted.internet.protocol import Protocol from twisted.internet.protocol import Factory from twisted.internet.endpoints import TCP4ServerEndpoint from twisted.protocols.policies import TimeoutMixin from sqlalchemy import create_engine, and_ from sqlalchemy.orm import sessionmaker from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound from collections import deque import struct import os import logging from exc import * from model import * def get_hex(data): return "".join([hex(ord(c))[2:].zfill(2) for c in data]) db_path = "root:helloworld@localhost/piztor" #db_path = "piztor.sqlite" FORMAT = "%(asctime)-15s %(message)s" logging.basicConfig(format = FORMAT) logger = logging.getLogger('piztor_server') logger.setLevel(logging.INFO) engine = create_engine('mysql://' + db_path, echo = False, pool_size = 1024) class _SectionSize: LENGTH = 4 OPT_ID = 1 STATUS = 1 USER_ID = 4 USER_TOKEN = 32 GROUP_ID = 2 ENTRY_CNT = 4 LATITUDE = 8 LONGITUDE = 8 LOCATION_ENTRY = USER_ID + LATITUDE + LONGITUDE PADDING = 1 _MAX_AUTH_HEAD_SIZE = _SectionSize.USER_TOKEN + \ MAX_USERNAME_SIZE + \ _SectionSize.PADDING _HEADER_SIZE = _SectionSize.LENGTH + \ _SectionSize.OPT_ID _MAX_TEXT_MESG_SIZE = 1024 class _OptCode: user_auth = 0x00 location_update = 0x01 location_info= 0x02 user_info = 0x03 user_logout = 0x04 open_push_tunnel = 0x05 send_text_mesg = 0x06 class _StatusCode: sucess = 0x00 failure = 0x01 class PushData(object): from hashlib import sha256 def __init__(self, data): self.data = data self.finger_print = sha256(data).digest() class PushTextMesgData(PushData): def __init__(self, mesg): self.finger_print = sha256(mesg).digest() buff = struct.pack("!B32s", 0x00, self.finger_print) buff += mesg buff += chr(0) buff = struct.pack("!L", _SectionSize.LENGTH + len(buff)) + buff self.data = buff class PushTunnel(object): def __init__(self): self.pending = deque() self.conn = None def __del__(self): if self.conn: self.conn.loseConnection() def add(self, pdata): logger.info("-- Push data enqued --") logger.info("Data: %s", get_hex(pdata.data)) self.pending.append(pdata) def on_receive(self, data): front = self.pending.popleft() length, optcode, fingerprint = struct.unpack("!LB32s", data) if front.finger_print != fingerprint: raise PiztorError logger.info("-- Push data confirmed --") self.push() def push(self): if (self.conn is None) or len(self.pending) == 0: return print "Pushing" front = self.pending.popleft() self.pending.appendleft(front) self.conn.transport.write(front.data) def connect(self, conn): conn.tunnel = self self.conn = conn def on_connection_lost(self): self.conn = None class RequestHandler(object): push_tunnels = dict() def __init__(self): Session = sessionmaker(bind = engine) self.session = Session() def __del__(self): self.session.close() # self.engine.dispose() def check_size(self, tr_data): if len(tr_data) > self._max_tr_data_size: raise BadReqError("Authentication: Request size is too large") @classmethod def get_uauth(cls, token, username, session): try: uauth = session.query(UserAuth) \ .filter(UserAuth.token == token).one() if uauth.user.username != username: logger.warning("Toke and username mismatch") return None uid = uauth.uid if not cls.push_tunnels.has_key(uid): cls.push_tunnels[uid] = PushTunnel() return uauth except NoResultFound: logger.warning("Incorrect token") return None except MultipleResultsFound: raise DBCorruptionError() @classmethod def trunc_padding(cls, data): leading = bytes() for i in xrange(len(data)): ch = data[i] if ch == '\x00': return (leading, data[i + 1:]) else: leading += ch # padding not found return (None, data) class UserAuthHandler(RequestHandler): _max_tr_data_size = MAX_USERNAME_SIZE + \ _SectionSize.PADDING + \ MAX_PASSWORD_SIZE + \ _SectionSize.PADDING _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS + \ _SectionSize.USER_ID + \ _SectionSize.USER_TOKEN _failed_response = \ struct.pack("!LBBL32s", _response_size, _OptCode.user_auth, _StatusCode.failure, 0, bytes('\x00' * 32)) def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading auth data...") pos = -1 for i in xrange(0, len(tr_data)): if tr_data[i] == '\x00': pos = i break if pos == -1: raise BadReqError("Authentication: Malformed request body") username = tr_data[0:pos] password = tr_data[pos + 1:-1] logger.info("Trying to login with " \ "(username = {0}, password = {1})" \ .format(username, password)) try: user = self.session.query(UserModel) \ .filter(UserModel.username == username).one() except NoResultFound: logger.info("No such user: {0}".format(username)) return self._failed_response except MultipleResultsFound: raise DBCorruptionError() uauth = user.auth if uauth is None: raise DBCorruptionError() if not uauth.check_password(password): logger.info("Incorrect password: {0}".format(password)) return self._failed_response else: logger.info("Logged in sucessfully: {0}".format(username)) uauth.regen_token() #logger.info("New token generated: " + get_hex(uauth.token)) self.session.commit() return struct.pack("!LBBL32s", self._response_size, _OptCode.user_auth, _StatusCode.sucess, user.id, uauth.token) class LocationUpdateHandler(RequestHandler): _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ _SectionSize.LATITUDE + \ _SectionSize.LONGITUDE _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading location update data...") try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) if username is None: raise struct.error lat, lng = struct.unpack("!dd", tail) except struct.error: raise BadReqError("Location update: Malformed request body") logger.info("Trying to update location with " "(token = {0}, username = {1}, lat = {2}, lng = {3})"\ .format(get_hex(token), username, lat, lng)) uauth = RequestHandler.get_uauth(token, username, self.session) # Authentication failure if uauth is None: logger.warning("Authentication failure") return struct.pack("!LBB", self._response_size, _OptCode.location_update, _StatusCode.failure) ulocation = uauth.user.location ulocation.lat = lat ulocation.lng = lng logger.info("Location is updated sucessfully") self.session.commit() return struct.pack("!LBB", self._response_size, _OptCode.location_update, _StatusCode.sucess) class LocationInfoHandler(RequestHandler): _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ _SectionSize.GROUP_ID @classmethod def _response_size(cls, item_num): return _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS + \ _SectionSize.LOCATION_ENTRY * item_num def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading location request data..") try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) if username is None: raise struct.error comp_id, sec_id = struct.unpack("!BB", tail) except struct.error: raise BadReqError("Location request: Malformed request body") logger.info("Trying to request locatin with " \ "(token = {0}, comp_id = {1}, sec_id = {2})" \ .format(get_hex(token), comp_id, sec_id)) uauth = RequestHandler.get_uauth(token, username, self.session) # Auth failure if uauth is None: logger.warning("Authentication failure") return struct.pack("!LBB", self._response_size(0), _OptCode.location_info, _StatusCode.failure) if sec_id == 0xff: # All members in the company ulist = self.session.query(UserModel) \ .filter(UserModel.comp_id == comp_id).all() else: ulist = self.session.query(UserModel) \ .filter(and_(UserModel.comp_id == comp_id, UserModel.sec_id == sec_id)).all() reply = struct.pack( "!LBB", self._response_size(len(ulist)), _OptCode.location_info, _StatusCode.sucess) for user in ulist: loc = user.location reply += struct.pack("!Ldd", user.id, loc.lat, loc.lng) return reply def pack_gid(user): return struct.pack("!BB", user.comp_id, user.sec_id) def pack_sex(user): return struct.pack("!B", 0x01 if user.sex else 0x00) class UserInfoHandler(RequestHandler): _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ _SectionSize.USER_ID _failed_response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS _fail_response = \ struct.pack("!LBB", _failed_response_size, _OptCode.user_info, _StatusCode.failure) _code_map = {0x00 : pack_gid, 0x01 : pack_sex} @classmethod def pack_entry(cls, user, entry_code): pack_method = cls._code_map[entry_code] info_key = entry_code return struct.pack("!B", info_key) + pack_method(user) def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading user info request data...") try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) if username is None: raise struct.error uid, = struct.unpack("!L", tail) except struct.error: raise BadReqError("User info request: Malformed request body") logger.info("Trying to user info with " \ "(token = {0}, uid = {1})" \ .format(get_hex(token), uid)) uauth = RequestHandler.get_uauth(token, username, self.session) # Auth failure if uauth is None: logger.warning("Authentication failure") return self._fail_response # TODO: check the relationship between user and quser user = uauth.user reply = struct.pack("!BB", _OptCode.user_info, _StatusCode.sucess) try: quser = self.session.query(UserModel) \ .filter(UserModel.id == uid).one() except NoResultFound: logger.info("No such user: {0}".format(username)) return self._fail_response except MultipleResultsFound: raise DBCorruptionError() for code in self._code_map: reply += UserInfoHandler.pack_entry(quser, code) reply = struct.pack("!L", len(reply) + _SectionSize.LENGTH) + reply return reply class UserLogoutHandler(RequestHandler): _max_tr_data_size = _MAX_AUTH_HEAD_SIZE _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading user logout data...") try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) if username is None: raise struct.error except struct.error: raise BadReqError("User logout: Malformed request body") logger.info("Trying to logout with " "(token = {0}, username = {1})"\ .format(get_hex(token), username)) uauth = RequestHandler.get_uauth(token, username, self.session) # Authentication failure if uauth is None: logger.warning("Authentication failure") return struct.pack("!LBB", self._response_size, _OptCode.user_logout, _StatusCode.failure) del RequestHandler.push_tunnels[uauth.uid] uauth.regen_token() logger.info("User Logged out successfully!") self.session.commit() return struct.pack("!LBB", self._response_size, _OptCode.user_logout, _StatusCode.sucess) class OpenPushTunnelHandler(RequestHandler): _max_tr_data_size = _MAX_AUTH_HEAD_SIZE _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading open push tunnel data...") try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) if username is None: raise struct.error except struct.error: raise BadReqError("Open push tunnel: Malformed request body") logger.info("Trying to open push tunnel with " "(token = {0}, username = {1})"\ .format(get_hex(token), username)) uauth = RequestHandler.get_uauth(token, username, self.session) # Authentication failure if uauth is None: logger.warning("Authentication failure") return struct.pack("!LBB", self._response_size, _OptCode.open_push_tunnel, _StatusCode.failure) tunnel = RequestHandler.push_tunnels[uauth.uid] pt = RequestHandler.push_tunnels uid = uauth.uid if pt.has_key(uid): tunnel = pt[uid] tunnel.connect(conn) tunnel.push() logger.info("Push tunnel opened successfully!") return struct.pack("!LBB", self._response_size, _OptCode.open_push_tunnel, _StatusCode.sucess) class SendTextMessageHandler(RequestHandler): _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ _MAX_TEXT_MESG_SIZE + \ _SectionSize.PADDING _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading send text mesg data...") try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) mesg = tail[:-1] if username is None: raise struct.error except struct.error: raise BadReqError("Send text mesg: Malformed request body") logger.info("Trying to send text mesg with " "(token = {0}, username = {1})"\ .format(get_hex(token), username)) uauth = RequestHandler.get_uauth(token, username, self.session) # Authentication failure if uauth is None: logger.warning("Authentication failure") return struct.pack("!LBB", self._response_size, _OptCode.send_text_mesg, _StatusCode.failure) pt = RequestHandler.push_tunnels u = uauth.user ulist = self.session.query(UserModel) \ .filter(and_(UserModel.comp_id == u.comp_id, UserModel.sec_id == u.sec_id)).all() for user in ulist: uid = user.id if pt.has_key(uid): tunnel = pt[uid] tunnel.add(PushTextMesgData(mesg)) tunnel.push() logger.info("Sent text mesg successfully!") return struct.pack("!LBB", self._response_size, _OptCode.send_text_mesg, _StatusCode.sucess) class PTP(Protocol, TimeoutMixin): handlers = [UserAuthHandler, LocationUpdateHandler, LocationInfoHandler, UserInfoHandler, UserLogoutHandler, OpenPushTunnelHandler, SendTextMessageHandler] handler_num = len(handlers) _MAX_REQUEST_SIZE = _HEADER_SIZE + \ max([h._max_tr_data_size for h in handlers]) @classmethod def check_header(cls, header): return 0 <= header < cls.handler_num def __init__(self, factory): self.buff = bytes() self.length = -1 self.factory = factory self.tunnel = None def timeoutConnection(self): logger.info("The connection times out") self.transport.loseConnection() def connectionMade(self): logger.info("A new connection is made") self.setTimeout(self.factory.timeout) def dataReceived(self, data): self.buff += data self.resetTimeout() logger.info("Buffer length is now: %d", len(self.buff)) if len(self.buff) <= 4: return try: if self.length == -1: try: self.length, self.optcode = struct.unpack("!LB", self.buff[:5]) if not PTP.check_header(self.optcode): # invalid header raise struct.error except struct.error: raise BadReqError("Malformed request header") if self.length > PTP._MAX_REQUEST_SIZE: print self.length, PTP._MAX_REQUEST_SIZE raise BadReqError("The size of remaining part is too big") if len(self.buff) == self.length: if self.tunnel: # received push response self.tunnel.on_receive(self.buff) self.buff = bytes() self.length = -1 return h = PTP.handlers[self.optcode]() reply = h.handle(self.buff[5:], self) if self.tunnel: logger.info("Blocking the client...") self.buff = bytes() self.length = -1 self.setTimeout(None) return logger.info("Wrote: %s", get_hex(reply)) self.transport.write(reply) self.transport.loseConnection() elif len(self.buff) > self.length: raise BadReqError("The actual length is larger than promised") except BadReqError as e: logger.warn("Rejected a bad request: %s", str(e)) self.transport.loseConnection() except DBCorruptionError: logger.error("*** Database corruption ***") self.transport.loseConnection() if self.tunnel is None: self.transport.loseConnection() def connectionLost(self, reason): if self.tunnel: self.tunnel.on_connection_lost() logger.info("The connection is lost") self.setTimeout(None) class PTPFactory(Factory): def __init__(self, timeout = 10): self.timeout = timeout def buildProtocol(self, addr): return PTP(self) if os.name!='nt': from twisted.internet import epollreactor epollreactor.install() else: from twisted.internet import iocpreactor iocpreactor.install() from twisted.internet import reactor f = PTPFactory() f.protocol = PTP reactor.listenTCP(2222, f) logger.warning("The server is lanuched") reactor.run()