From 2c3891cf77538c57b172404c5c35b846ec6b0de6 Mon Sep 17 00:00:00 2001
From: Teddy <ted.sybil@gmail.com>
Date: Tue, 27 Aug 2013 09:54:58 +0800
Subject: server: packet size check

---
 server/piztor/model.py  |   4 +-
 server/piztor/prob.py   |   3 +-
 server/piztor/server.py | 120 +++++++++++++++++++++++++++++-------------------
 3 files changed, 77 insertions(+), 50 deletions(-)

(limited to 'server')

diff --git a/server/piztor/model.py b/server/piztor/model.py
index 9e3b132..8916e3a 100644
--- a/server/piztor/model.py
+++ b/server/piztor/model.py
@@ -6,6 +6,8 @@ Base = declarative_base()
 
 _SALT_LEN = 16
 _TOKEN_LEN = 16
+MAX_USERNAME_SIZE = 20
+MAX_PASSWORD_SIZE = 20
 
 class _TableName:   # avoid typoes
     UserModel = 'users'
@@ -17,7 +19,7 @@ class UserModel(Base):
 
     id = Column(Integer, primary_key = True)
     gid = Column(Integer)
-    username = Column(String(20))
+    username = Column(String(MAX_USERNAME_SIZE))
     sex = Column(Boolean)
     location = None
     auth = None
diff --git a/server/piztor/prob.py b/server/piztor/prob.py
index 6f7ef14..9798c18 100644
--- a/server/piztor/prob.py
+++ b/server/piztor/prob.py
@@ -56,7 +56,8 @@ def send(data):
 
 from sys import argv
 
-username = "hello"
+#username = "hello"
+username = "12345678901234567890"
 password = "world"
 gid = 1
 
diff --git a/server/piztor/server.py b/server/piztor/server.py
index 204496c..eb321fc 100644
--- a/server/piztor/server.py
+++ b/server/piztor/server.py
@@ -22,7 +22,7 @@ db_path = "root:helloworld@localhost/piztor"
 FORMAT = "%(asctime)-15s %(message)s"
 logging.basicConfig(format = FORMAT)
 logger = logging.getLogger('piztor_server')
-logger.setLevel(logging.WARN)
+logger.setLevel(logging.INFO)
 
 
 class _SectionSize:
@@ -38,6 +38,12 @@ class _SectionSize:
     LOCATION_ENTRY = USER_ID + LATITUDE + LONGITUDE
     PADDING = 1
 
+_MAX_AUTH_HEAD_SIZE = _SectionSize.USER_TOKEN + \
+                      MAX_USERNAME_SIZE + \
+                      _SectionSize.PADDING
+_HEADER_SIZE = _SectionSize.LENGTH + \
+                _SectionSize.OPT_ID
+
 class _OptCode:
     user_auth = 0x00
     location_update = 0x01
@@ -51,11 +57,17 @@ class _StatusCode:
 class RequestHandler(object):
     def __init__(self):
         self.engine = create_engine('mysql://' + db_path, echo = False)
-        self.Session = sessionmaker(bind = self.engine)
+        Session = sessionmaker(bind = self.engine)
+        self.session = Session()
 
     def __del__(self):
+        self.session.close()
         self.engine.dispose()
 
+    def check_size(self, tr_data):
+        if len(tr_data) > self._max_tr_data_size:
+            raise BadReqError("Authentication: Request size is too large")
+
     @classmethod
     def get_uauth(cls, token, username, session):
         try:
@@ -73,7 +85,6 @@ class RequestHandler(object):
             return None
 
         except MultipleResultsFound:
-            session.close()
             raise DBCorruptedError()
 
     @classmethod
@@ -82,7 +93,6 @@ class RequestHandler(object):
         for i in xrange(len(data)):
             ch = data[i]
             if ch == '\x00':
-                print get_hex(leading), get_hex(data[i + 1:])
                 return (leading, data[i + 1:])
             else:
                 leading += ch
@@ -91,6 +101,11 @@ class RequestHandler(object):
 
 class UserAuthHandler(RequestHandler):
 
+    _max_tr_data_size = MAX_USERNAME_SIZE + \
+                        _SectionSize.PADDING + \
+                        MAX_PASSWORD_SIZE + \
+                        _SectionSize.PADDING
+
     _response_size = \
             _SectionSize.LENGTH + \
             _SectionSize.OPT_ID + \
@@ -107,6 +122,7 @@ class UserAuthHandler(RequestHandler):
 
 
     def handle(self, tr_data):
+        self.check_size(tr_data)
         logger.info("Reading auth data...")
         pos = -1
         for i in xrange(0, len(tr_data)):
@@ -122,32 +138,27 @@ class UserAuthHandler(RequestHandler):
                     "(username = {0}, password = {1})" \
                 .format(username, password))
 
-        session = self.Session()
         try:
-            user = session.query(UserModel) \
+            user = self.session.query(UserModel) \
                 .filter(UserModel.username == username).one()
         except NoResultFound:
             logger.info("No such user: {0}".format(username))
-            session.commit()
             return UserAuthHandler._failed_response
 
         except MultipleResultsFound:
-            session.close()
             raise DBCorruptedError()
 
         uauth = user.auth
         if uauth is None:
-            session.close()
             raise DBCorruptedError()
         if not uauth.check_password(password):
             logger.info("Incorrect password: {0}".format(password))
-            session.commit()
             return UserAuthHandler._failed_response
         else:
             logger.info("Logged in sucessfully: {0}".format(username))
             uauth.regen_token()
             logger.info("New token generated: " + get_hex(uauth.token))
-            session.commit()
+            self.session.commit()
             return struct.pack("!LBBL32s", UserAuthHandler._response_size,
                                            _OptCode.user_auth,
                                            _StatusCode.sucess,
@@ -157,14 +168,18 @@ class UserAuthHandler(RequestHandler):
 
 class LocationUpdateHandler(RequestHandler):
 
+    _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \
+                        _SectionSize.LATITUDE + \
+                        _SectionSize.LONGITUDE
+
     _response_size = \
             _SectionSize.LENGTH + \
             _SectionSize.OPT_ID + \
             _SectionSize.STATUS
 
     def handle(self, tr_data):
+        self.check_size(tr_data)
         logger.info("Reading location update data...")
-
         try:
             token, = struct.unpack("!32s", tr_data[:32])
             username, tail = RequestHandler.trunc_padding(tr_data[32:])
@@ -178,12 +193,10 @@ class LocationUpdateHandler(RequestHandler):
                     "(token = {0}, username = {1}, lat = {2}, lng = {3})"\
                 .format(get_hex(token), username, lat, lng))
 
-        session = self.Session()
-        uauth = RequestHandler.get_uauth(token, username, session)
+        uauth = RequestHandler.get_uauth(token, username, self.session)
         # Authentication failure
         if uauth is None:
             logger.warning("Authentication failure")
-            session.commit()
             return struct.pack("!LBB",  LocationUpdateHandler._response_size,
                                         _OptCode.location_update,
                                         _StatusCode.failure)
@@ -193,12 +206,15 @@ class LocationUpdateHandler(RequestHandler):
         ulocation.lng = lng
 
         logger.info("Location is updated sucessfully")
-        session.commit()
+        self.session.commit()
         return struct.pack("!LBB",  LocationUpdateHandler._response_size,
                                     _OptCode.location_update,
                                     _StatusCode.sucess)
 
-class LocationRequestHandler(RequestHandler):
+class LocationInfoHandler(RequestHandler):
+
+    _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \
+                        _SectionSize.GROUP_ID
 
     @classmethod
     def _response_size(cls, item_num):
@@ -208,8 +224,8 @@ class LocationRequestHandler(RequestHandler):
                 _SectionSize.LOCATION_ENTRY * item_num
 
     def handle(self, tr_data):
+        self.check_size(tr_data)
         logger.info("Reading location request data..")
-
         try:
             token, = struct.unpack("!32s", tr_data[:32])
             username, tail = RequestHandler.trunc_padding(tr_data[32:])
@@ -223,20 +239,18 @@ class LocationRequestHandler(RequestHandler):
                     "(token = {0}, gid = {1})" \
             .format(get_hex(token), gid))
 
-        session = self.Session()
-        uauth = RequestHandler.get_uauth(token, username, session)
+        uauth = RequestHandler.get_uauth(token, username, self.session)
         # Auth failure
         if uauth is None:
             logger.warning("Authentication failure")
-            session.commit()
-            return struct.pack("!LBB", LocationRequestHandler._response_size(0),
+            return struct.pack("!LBB", LocationInfoHandler._response_size(0),
                                         _OptCode.location_request,
                                         _StatusCode.failure)
 
-        ulist = session.query(UserModel).filter(UserModel.gid == gid).all()
+        ulist = self.session.query(UserModel).filter(UserModel.gid == gid).all()
         reply = struct.pack(
                 "!LBB", 
-                LocationRequestHandler._response_size(len(ulist)),
+                LocationInfoHandler._response_size(len(ulist)),
                 _OptCode.location_request, 
                 _StatusCode.sucess)
 
@@ -244,7 +258,6 @@ class LocationRequestHandler(RequestHandler):
             loc = user.location
             reply += struct.pack("!Ldd", user.id, loc.lat, loc.lng)
 
-        session.commit()
         return reply
 
 def pack_int(val):
@@ -254,7 +267,10 @@ def pack_bool(val):
     return struct.pack("!B", 0x01 if val else 0x00)
 
 
-class UserInfoRequestHandler(RequestHandler):
+class UserInfoHandler(RequestHandler):
+
+    _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \
+                        _SectionSize.USER_ID
 
     _failed_response_size = \
             _SectionSize.LENGTH + \
@@ -277,8 +293,8 @@ class UserInfoRequestHandler(RequestHandler):
         return struct.pack("!B", info_key) + pack_method(info_value)
 
     def handle(self, tr_data):
+        self.check_size(tr_data)
         logger.info("Reading user info request data...")
-
         try:
             token, = struct.unpack("!32s", tr_data[:32])
             username, tail = RequestHandler.trunc_padding(tr_data[32:])
@@ -292,47 +308,49 @@ class UserInfoRequestHandler(RequestHandler):
                     "(token = {0}, uid = {1})" \
             .format(get_hex(token), uid))
 
-        session = self.Session()
-        uauth = RequestHandler.get_uauth(token, username, session)
+        uauth = RequestHandler.get_uauth(token, username, self.session)
         # Auth failure
         if uauth is None:
             logger.warning("Authentication failure")
-            session.commit()
-            return UserInfoRequestHandler._fail_response
+            return UserInfoHandler._fail_response
         # TODO: check the relationship between user and quser
         user = uauth.user 
 
         reply = struct.pack("!BB", _OptCode.user_info_request,
                                     _StatusCode.sucess)
         try:
-            quser = session.query(UserModel) \
+            quser = self.session.query(UserModel) \
                     .filter(UserModel.id == uid).one()
         except NoResultFound:
             logger.info("No such user: {0}".format(username))
-            session.commit()
-            return UserInfoRequestHandler._fail_response
+            return UserInfoHandler._fail_response
 
         except MultipleResultsFound:
-            session.close()
             raise DBCorruptedError()
 
-        for code in UserInfoRequestHandler._code_map:
-            reply += UserInfoRequestHandler.pack_entry(quser, code)
+        for code in UserInfoHandler._code_map:
+            reply += UserInfoHandler.pack_entry(quser, code)
         reply = struct.pack("!L", len(reply) + _SectionSize.LENGTH) + reply
-        session.commit()
         return reply
 
         
-handlers = [UserAuthHandler,
-            LocationUpdateHandler,
-            LocationRequestHandler,
-            UserInfoRequestHandler]
-
-def check_header(header):
-    return 0 <= header < len(handlers)
 
 class PTP(Protocol, TimeoutMixin):
 
+    handlers = [UserAuthHandler,
+                LocationUpdateHandler,
+                LocationInfoHandler,
+                UserInfoHandler]
+
+    handler_num = len(handlers)
+
+    _MAX_REQUEST_SIZE = _HEADER_SIZE + \
+                        max([h._max_tr_data_size for h in handlers])
+
+    @classmethod
+    def check_header(cls, header):
+        return 0 <= header < cls.handler_num
+
     def __init__(self, factory):
         self.buff = bytes()
         self.length = -1
@@ -352,14 +370,19 @@ class PTP(Protocol, TimeoutMixin):
         if len(self.buff) > 4:
             try:
                 self.length, self.optcode = struct.unpack("!LB", self.buff[:5])
-                if not check_header(self.optcode):    # invalid header
+                if not PTP.check_header(self.optcode):    # invalid header
                     raise struct.error
             except struct.error:
-                logger.warning("Invalid request header")
                 raise BadReqError("Malformed request header")
+            if self.length > PTP._MAX_REQUEST_SIZE:
+                print self.length, PTP._MAX_REQUEST_SIZE
+                raise BadReqError("The size of remaining part is too big")
+
+        # Incomplete length info
         if self.length == -1: return
+
         if len(self.buff) == self.length:
-            h = handlers[self.optcode]()
+            h = PTP.handlers[self.optcode]()
             reply = h.handle(self.buff[5:])
             logger.info("Wrote: %s", get_hex(reply))
             self.transport.write(reply)
@@ -391,4 +414,5 @@ from twisted.internet import reactor
 f = PTPFactory()
 f.protocol = PTP
 reactor.listenTCP(2222, f)
+logger.warning("The server is on")
 reactor.run()
-- 
cgit v1.2.3-70-g09d2