from twisted.internet.protocol import Protocol from twisted.internet.protocol import Factory from twisted.internet.endpoints import TCP4ServerEndpoint from twisted.protocols.policies import TimeoutMixin from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound import struct import os import logging from exc import * from model import * def get_hex(data): return "".join([hex(ord(c))[2:].zfill(2) for c in data]) db_path = "root:helloworld@localhost/piztor" #db_path = "piztor.sqlite" FORMAT = "%(asctime)-15s %(message)s" logging.basicConfig(format = FORMAT) logger = logging.getLogger('piztor_server') logger.setLevel(logging.INFO) class _SectionSize: LENGTH = 4 OPT_ID = 1 STATUS = 1 USER_ID = 4 USER_TOKEN = 32 GROUP_ID = 4 ENTRY_CNT = 4 LATITUDE = 8 LONGITUDE = 8 LOCATION_ENTRY = USER_ID + LATITUDE + LONGITUDE PADDING = 1 _MAX_AUTH_HEAD_SIZE = _SectionSize.USER_TOKEN + \ MAX_USERNAME_SIZE + \ _SectionSize.PADDING _HEADER_SIZE = _SectionSize.LENGTH + \ _SectionSize.OPT_ID class _OptCode: user_auth = 0x00 location_update = 0x01 location_request= 0x02 user_info_request = 0x03 class _StatusCode: sucess = 0x00 failure = 0x01 class RequestHandler(object): def __init__(self): self.engine = create_engine('mysql://' + db_path, echo = False) Session = sessionmaker(bind = self.engine) self.session = Session() def __del__(self): self.session.close() self.engine.dispose() def check_size(self, tr_data): if len(tr_data) > self._max_tr_data_size: raise BadReqError("Authentication: Request size is too large") @classmethod def get_uauth(cls, token, username, session): try: uauth = session.query(UserAuth) \ .filter(UserAuth.token == token).one() if uauth.user.username != username: logger.warning("Toke and username mismatch") return None return uauth except NoResultFound: logger.warning("Incorrect token") return None except MultipleResultsFound: raise DBCorruptedError() @classmethod def trunc_padding(cls, data): leading = bytes() for i in xrange(len(data)): ch = data[i] if ch == '\x00': return (leading, data[i + 1:]) else: leading += ch # padding not found return (None, data) class UserAuthHandler(RequestHandler): _max_tr_data_size = MAX_USERNAME_SIZE + \ _SectionSize.PADDING + \ MAX_PASSWORD_SIZE + \ _SectionSize.PADDING _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS + \ _SectionSize.USER_ID + \ _SectionSize.USER_TOKEN _failed_response = \ struct.pack("!LBBL32s", _response_size, _OptCode.user_auth, _StatusCode.failure, 0, bytes('\x00' * 32)) def handle(self, tr_data): self.check_size(tr_data) logger.info("Reading auth data...") pos = -1 for i in xrange(0, len(tr_data)): if tr_data[i] == '\x00': pos = i break if pos == -1: raise BadReqError("Authentication: Malformed request body") username = tr_data[0:pos] password = tr_data[pos + 1:-1] logger.info("Trying to login with " \ "(username = {0}, password = {1})" \ .format(username, password)) try: user = self.session.query(UserModel) \ .filter(UserModel.username == username).one() except NoResultFound: logger.info("No such user: {0}".format(username)) return UserAuthHandler._failed_response except MultipleResultsFound: raise DBCorruptedError() uauth = user.auth if uauth is None: raise DBCorruptedError() if not uauth.check_password(password): logger.info("Incorrect password: {0}".format(password)) return UserAuthHandler._failed_response else: logger.info("Logged in sucessfully: {0}".format(username)) uauth.regen_token() logger.info("New token generated: " + get_hex(uauth.token)) self.session.commit() return struct.pack("!LBBL32s", UserAuthHandler._response_size, _OptCode.user_auth, _StatusCode.sucess, user.id, uauth.token) class LocationUpdateHandler(RequestHandler): _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ _SectionSize.LATITUDE + \ _SectionSize.LONGITUDE _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS def handle(self, tr_data): self.check_size(tr_data) logger.info("Reading location update data...") try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) if username is None: raise struct.error lat, lng = struct.unpack("!dd", tail) except struct.error: raise BadReqError("Location update: Malformed request body") logger.info("Trying to update location with " "(token = {0}, username = {1}, lat = {2}, lng = {3})"\ .format(get_hex(token), username, lat, lng)) uauth = RequestHandler.get_uauth(token, username, self.session) # Authentication failure if uauth is None: logger.warning("Authentication failure") return struct.pack("!LBB", LocationUpdateHandler._response_size, _OptCode.location_update, _StatusCode.failure) ulocation = uauth.user.location ulocation.lat = lat ulocation.lng = lng logger.info("Location is updated sucessfully") self.session.commit() return struct.pack("!LBB", LocationUpdateHandler._response_size, _OptCode.location_update, _StatusCode.sucess) class LocationInfoHandler(RequestHandler): _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ _SectionSize.GROUP_ID @classmethod def _response_size(cls, item_num): return _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS + \ _SectionSize.LOCATION_ENTRY * item_num def handle(self, tr_data): self.check_size(tr_data) logger.info("Reading location request data..") try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) if username is None: raise struct.error gid, = struct.unpack("!L", tail) except struct.error: raise BadReqError("Location request: Malformed request body") logger.info("Trying to request locatin with " \ "(token = {0}, gid = {1})" \ .format(get_hex(token), gid)) uauth = RequestHandler.get_uauth(token, username, self.session) # Auth failure if uauth is None: logger.warning("Authentication failure") return struct.pack("!LBB", LocationInfoHandler._response_size(0), _OptCode.location_request, _StatusCode.failure) ulist = self.session.query(UserModel).filter(UserModel.gid == gid).all() reply = struct.pack( "!LBB", LocationInfoHandler._response_size(len(ulist)), _OptCode.location_request, _StatusCode.sucess) for user in ulist: loc = user.location reply += struct.pack("!Ldd", user.id, loc.lat, loc.lng) return reply def pack_int(val): return struct.pack("!L", val) def pack_bool(val): return struct.pack("!B", 0x01 if val else 0x00) class UserInfoHandler(RequestHandler): _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ _SectionSize.USER_ID _failed_response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS _fail_response = \ struct.pack("!LBB", _failed_response_size, _OptCode.user_info_request, _StatusCode.failure) _code_map = {0x00 : ('gid', pack_int), 0x01 : ('sex', pack_bool)} @classmethod def pack_entry(cls, user, entry_code): attr, pack_method = cls._code_map[entry_code] info_key = entry_code info_value = getattr(user, attr) return struct.pack("!B", info_key) + pack_method(info_value) def handle(self, tr_data): self.check_size(tr_data) logger.info("Reading user info request data...") try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) if username is None: raise struct.error uid, = struct.unpack("!L", tail) except struct.error: raise BadReqError("User info request: Malformed request body") logger.info("Trying to request locatin with " \ "(token = {0}, uid = {1})" \ .format(get_hex(token), uid)) uauth = RequestHandler.get_uauth(token, username, self.session) # Auth failure if uauth is None: logger.warning("Authentication failure") return UserInfoHandler._fail_response # TODO: check the relationship between user and quser user = uauth.user reply = struct.pack("!BB", _OptCode.user_info_request, _StatusCode.sucess) try: quser = self.session.query(UserModel) \ .filter(UserModel.id == uid).one() except NoResultFound: logger.info("No such user: {0}".format(username)) return UserInfoHandler._fail_response except MultipleResultsFound: raise DBCorruptedError() for code in UserInfoHandler._code_map: reply += UserInfoHandler.pack_entry(quser, code) reply = struct.pack("!L", len(reply) + _SectionSize.LENGTH) + reply return reply class PTP(Protocol, TimeoutMixin): handlers = [UserAuthHandler, LocationUpdateHandler, LocationInfoHandler, UserInfoHandler] handler_num = len(handlers) _MAX_REQUEST_SIZE = _HEADER_SIZE + \ max([h._max_tr_data_size for h in handlers]) @classmethod def check_header(cls, header): return 0 <= header < cls.handler_num def __init__(self, factory): self.buff = bytes() self.length = -1 self.factory = factory def timeoutConnection(self): logger.info("The connection times out") def connectionMade(self): logger.info("A new connection is made") self.setTimeout(self.factory.timeout) def dataReceived(self, data): self.buff += data self.resetTimeout() logger.info("Buffer length is now: %d", len(self.buff)) if len(self.buff) <= 4: return try: if self.length == -1: try: self.length, self.optcode = struct.unpack("!LB", self.buff[:5]) if not PTP.check_header(self.optcode): # invalid header raise struct.error except struct.error: raise BadReqError("Malformed request header") if self.length > PTP._MAX_REQUEST_SIZE: print self.length, PTP._MAX_REQUEST_SIZE raise BadReqError("The size of remaining part is too big") if len(self.buff) == self.length: h = PTP.handlers[self.optcode]() reply = h.handle(self.buff[5:]) logger.info("Wrote: %s", get_hex(reply)) self.transport.write(reply) self.transport.loseConnection() elif len(self.buff) > self.length: raise BadReqError("The actual length is larger than promised") except BadReqError as e: logger.warn("Rejected a bad request: %s", str(e)) except DBCorruptedError: logger.error("*** Database corruption ***") finally: self.transport.loseConnection() def connectionLost(self, reason): logger.info("The connection is lost") self.setTimeout(None) class PTPFactory(Factory): def __init__(self, timeout = 10): self.timeout = timeout def buildProtocol(self, addr): return PTP(self) #if os.name!='nt': # from twisted.internet import epollreactor # epollreactor.install() #else: # from twisted.internet import iocpreactor # iocpreactor.install() from twisted.internet import reactor f = PTPFactory() f.protocol = PTP reactor.listenTCP(2222, f) logger.warning("The server is lanuched") reactor.run()