diff options
Diffstat (limited to 'server')
-rw-r--r-- | server/piztor/model.py | 4 | ||||
-rw-r--r-- | server/piztor/prob.py | 3 | ||||
-rw-r--r-- | server/piztor/server.py | 120 |
3 files changed, 77 insertions, 50 deletions
diff --git a/server/piztor/model.py b/server/piztor/model.py index 9e3b132..8916e3a 100644 --- a/server/piztor/model.py +++ b/server/piztor/model.py @@ -6,6 +6,8 @@ Base = declarative_base() _SALT_LEN = 16 _TOKEN_LEN = 16 +MAX_USERNAME_SIZE = 20 +MAX_PASSWORD_SIZE = 20 class _TableName: # avoid typoes UserModel = 'users' @@ -17,7 +19,7 @@ class UserModel(Base): id = Column(Integer, primary_key = True) gid = Column(Integer) - username = Column(String(20)) + username = Column(String(MAX_USERNAME_SIZE)) sex = Column(Boolean) location = None auth = None diff --git a/server/piztor/prob.py b/server/piztor/prob.py index 6f7ef14..9798c18 100644 --- a/server/piztor/prob.py +++ b/server/piztor/prob.py @@ -56,7 +56,8 @@ def send(data): from sys import argv -username = "hello" +#username = "hello" +username = "12345678901234567890" password = "world" gid = 1 diff --git a/server/piztor/server.py b/server/piztor/server.py index 204496c..eb321fc 100644 --- a/server/piztor/server.py +++ b/server/piztor/server.py @@ -22,7 +22,7 @@ db_path = "root:helloworld@localhost/piztor" FORMAT = "%(asctime)-15s %(message)s" logging.basicConfig(format = FORMAT) logger = logging.getLogger('piztor_server') -logger.setLevel(logging.WARN) +logger.setLevel(logging.INFO) class _SectionSize: @@ -38,6 +38,12 @@ class _SectionSize: LOCATION_ENTRY = USER_ID + LATITUDE + LONGITUDE PADDING = 1 +_MAX_AUTH_HEAD_SIZE = _SectionSize.USER_TOKEN + \ + MAX_USERNAME_SIZE + \ + _SectionSize.PADDING +_HEADER_SIZE = _SectionSize.LENGTH + \ + _SectionSize.OPT_ID + class _OptCode: user_auth = 0x00 location_update = 0x01 @@ -51,11 +57,17 @@ class _StatusCode: class RequestHandler(object): def __init__(self): self.engine = create_engine('mysql://' + db_path, echo = False) - self.Session = sessionmaker(bind = self.engine) + Session = sessionmaker(bind = self.engine) + self.session = Session() def __del__(self): + self.session.close() self.engine.dispose() + def check_size(self, tr_data): + if len(tr_data) > self._max_tr_data_size: + raise BadReqError("Authentication: Request size is too large") + @classmethod def get_uauth(cls, token, username, session): try: @@ -73,7 +85,6 @@ class RequestHandler(object): return None except MultipleResultsFound: - session.close() raise DBCorruptedError() @classmethod @@ -82,7 +93,6 @@ class RequestHandler(object): for i in xrange(len(data)): ch = data[i] if ch == '\x00': - print get_hex(leading), get_hex(data[i + 1:]) return (leading, data[i + 1:]) else: leading += ch @@ -91,6 +101,11 @@ class RequestHandler(object): class UserAuthHandler(RequestHandler): + _max_tr_data_size = MAX_USERNAME_SIZE + \ + _SectionSize.PADDING + \ + MAX_PASSWORD_SIZE + \ + _SectionSize.PADDING + _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ @@ -107,6 +122,7 @@ class UserAuthHandler(RequestHandler): def handle(self, tr_data): + self.check_size(tr_data) logger.info("Reading auth data...") pos = -1 for i in xrange(0, len(tr_data)): @@ -122,32 +138,27 @@ class UserAuthHandler(RequestHandler): "(username = {0}, password = {1})" \ .format(username, password)) - session = self.Session() try: - user = session.query(UserModel) \ + user = self.session.query(UserModel) \ .filter(UserModel.username == username).one() except NoResultFound: logger.info("No such user: {0}".format(username)) - session.commit() return UserAuthHandler._failed_response except MultipleResultsFound: - session.close() raise DBCorruptedError() uauth = user.auth if uauth is None: - session.close() raise DBCorruptedError() if not uauth.check_password(password): logger.info("Incorrect password: {0}".format(password)) - session.commit() return UserAuthHandler._failed_response else: logger.info("Logged in sucessfully: {0}".format(username)) uauth.regen_token() logger.info("New token generated: " + get_hex(uauth.token)) - session.commit() + self.session.commit() return struct.pack("!LBBL32s", UserAuthHandler._response_size, _OptCode.user_auth, _StatusCode.sucess, @@ -157,14 +168,18 @@ class UserAuthHandler(RequestHandler): class LocationUpdateHandler(RequestHandler): + _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ + _SectionSize.LATITUDE + \ + _SectionSize.LONGITUDE + _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS def handle(self, tr_data): + self.check_size(tr_data) logger.info("Reading location update data...") - try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) @@ -178,12 +193,10 @@ class LocationUpdateHandler(RequestHandler): "(token = {0}, username = {1}, lat = {2}, lng = {3})"\ .format(get_hex(token), username, lat, lng)) - session = self.Session() - uauth = RequestHandler.get_uauth(token, username, session) + uauth = RequestHandler.get_uauth(token, username, self.session) # Authentication failure if uauth is None: logger.warning("Authentication failure") - session.commit() return struct.pack("!LBB", LocationUpdateHandler._response_size, _OptCode.location_update, _StatusCode.failure) @@ -193,12 +206,15 @@ class LocationUpdateHandler(RequestHandler): ulocation.lng = lng logger.info("Location is updated sucessfully") - session.commit() + self.session.commit() return struct.pack("!LBB", LocationUpdateHandler._response_size, _OptCode.location_update, _StatusCode.sucess) -class LocationRequestHandler(RequestHandler): +class LocationInfoHandler(RequestHandler): + + _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ + _SectionSize.GROUP_ID @classmethod def _response_size(cls, item_num): @@ -208,8 +224,8 @@ class LocationRequestHandler(RequestHandler): _SectionSize.LOCATION_ENTRY * item_num def handle(self, tr_data): + self.check_size(tr_data) logger.info("Reading location request data..") - try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) @@ -223,20 +239,18 @@ class LocationRequestHandler(RequestHandler): "(token = {0}, gid = {1})" \ .format(get_hex(token), gid)) - session = self.Session() - uauth = RequestHandler.get_uauth(token, username, session) + uauth = RequestHandler.get_uauth(token, username, self.session) # Auth failure if uauth is None: logger.warning("Authentication failure") - session.commit() - return struct.pack("!LBB", LocationRequestHandler._response_size(0), + return struct.pack("!LBB", LocationInfoHandler._response_size(0), _OptCode.location_request, _StatusCode.failure) - ulist = session.query(UserModel).filter(UserModel.gid == gid).all() + ulist = self.session.query(UserModel).filter(UserModel.gid == gid).all() reply = struct.pack( "!LBB", - LocationRequestHandler._response_size(len(ulist)), + LocationInfoHandler._response_size(len(ulist)), _OptCode.location_request, _StatusCode.sucess) @@ -244,7 +258,6 @@ class LocationRequestHandler(RequestHandler): loc = user.location reply += struct.pack("!Ldd", user.id, loc.lat, loc.lng) - session.commit() return reply def pack_int(val): @@ -254,7 +267,10 @@ def pack_bool(val): return struct.pack("!B", 0x01 if val else 0x00) -class UserInfoRequestHandler(RequestHandler): +class UserInfoHandler(RequestHandler): + + _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ + _SectionSize.USER_ID _failed_response_size = \ _SectionSize.LENGTH + \ @@ -277,8 +293,8 @@ class UserInfoRequestHandler(RequestHandler): return struct.pack("!B", info_key) + pack_method(info_value) def handle(self, tr_data): + self.check_size(tr_data) logger.info("Reading user info request data...") - try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) @@ -292,47 +308,49 @@ class UserInfoRequestHandler(RequestHandler): "(token = {0}, uid = {1})" \ .format(get_hex(token), uid)) - session = self.Session() - uauth = RequestHandler.get_uauth(token, username, session) + uauth = RequestHandler.get_uauth(token, username, self.session) # Auth failure if uauth is None: logger.warning("Authentication failure") - session.commit() - return UserInfoRequestHandler._fail_response + return UserInfoHandler._fail_response # TODO: check the relationship between user and quser user = uauth.user reply = struct.pack("!BB", _OptCode.user_info_request, _StatusCode.sucess) try: - quser = session.query(UserModel) \ + quser = self.session.query(UserModel) \ .filter(UserModel.id == uid).one() except NoResultFound: logger.info("No such user: {0}".format(username)) - session.commit() - return UserInfoRequestHandler._fail_response + return UserInfoHandler._fail_response except MultipleResultsFound: - session.close() raise DBCorruptedError() - for code in UserInfoRequestHandler._code_map: - reply += UserInfoRequestHandler.pack_entry(quser, code) + for code in UserInfoHandler._code_map: + reply += UserInfoHandler.pack_entry(quser, code) reply = struct.pack("!L", len(reply) + _SectionSize.LENGTH) + reply - session.commit() return reply -handlers = [UserAuthHandler, - LocationUpdateHandler, - LocationRequestHandler, - UserInfoRequestHandler] - -def check_header(header): - return 0 <= header < len(handlers) class PTP(Protocol, TimeoutMixin): + handlers = [UserAuthHandler, + LocationUpdateHandler, + LocationInfoHandler, + UserInfoHandler] + + handler_num = len(handlers) + + _MAX_REQUEST_SIZE = _HEADER_SIZE + \ + max([h._max_tr_data_size for h in handlers]) + + @classmethod + def check_header(cls, header): + return 0 <= header < cls.handler_num + def __init__(self, factory): self.buff = bytes() self.length = -1 @@ -352,14 +370,19 @@ class PTP(Protocol, TimeoutMixin): if len(self.buff) > 4: try: self.length, self.optcode = struct.unpack("!LB", self.buff[:5]) - if not check_header(self.optcode): # invalid header + if not PTP.check_header(self.optcode): # invalid header raise struct.error except struct.error: - logger.warning("Invalid request header") raise BadReqError("Malformed request header") + if self.length > PTP._MAX_REQUEST_SIZE: + print self.length, PTP._MAX_REQUEST_SIZE + raise BadReqError("The size of remaining part is too big") + + # Incomplete length info if self.length == -1: return + if len(self.buff) == self.length: - h = handlers[self.optcode]() + h = PTP.handlers[self.optcode]() reply = h.handle(self.buff[5:]) logger.info("Wrote: %s", get_hex(reply)) self.transport.write(reply) @@ -391,4 +414,5 @@ from twisted.internet import reactor f = PTPFactory() f.protocol = PTP reactor.listenTCP(2222, f) +logger.warning("The server is on") reactor.run() |