summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTeddy <ted.sybil@gmail.com>2013-08-27 09:54:58 +0800
committerTeddy <ted.sybil@gmail.com>2013-08-27 09:54:58 +0800
commit2c3891cf77538c57b172404c5c35b846ec6b0de6 (patch)
treeaac1caddbcae1fdf505965d022ef1a3cf3f5b5e0
parent3849201b28bc38bb8b42574bdf847dd7a5ce8aa7 (diff)
server: packet size check
-rw-r--r--server/piztor/model.py4
-rw-r--r--server/piztor/prob.py3
-rw-r--r--server/piztor/server.py120
3 files changed, 77 insertions, 50 deletions
diff --git a/server/piztor/model.py b/server/piztor/model.py
index 9e3b132..8916e3a 100644
--- a/server/piztor/model.py
+++ b/server/piztor/model.py
@@ -6,6 +6,8 @@ Base = declarative_base()
_SALT_LEN = 16
_TOKEN_LEN = 16
+MAX_USERNAME_SIZE = 20
+MAX_PASSWORD_SIZE = 20
class _TableName: # avoid typoes
UserModel = 'users'
@@ -17,7 +19,7 @@ class UserModel(Base):
id = Column(Integer, primary_key = True)
gid = Column(Integer)
- username = Column(String(20))
+ username = Column(String(MAX_USERNAME_SIZE))
sex = Column(Boolean)
location = None
auth = None
diff --git a/server/piztor/prob.py b/server/piztor/prob.py
index 6f7ef14..9798c18 100644
--- a/server/piztor/prob.py
+++ b/server/piztor/prob.py
@@ -56,7 +56,8 @@ def send(data):
from sys import argv
-username = "hello"
+#username = "hello"
+username = "12345678901234567890"
password = "world"
gid = 1
diff --git a/server/piztor/server.py b/server/piztor/server.py
index 204496c..eb321fc 100644
--- a/server/piztor/server.py
+++ b/server/piztor/server.py
@@ -22,7 +22,7 @@ db_path = "root:helloworld@localhost/piztor"
FORMAT = "%(asctime)-15s %(message)s"
logging.basicConfig(format = FORMAT)
logger = logging.getLogger('piztor_server')
-logger.setLevel(logging.WARN)
+logger.setLevel(logging.INFO)
class _SectionSize:
@@ -38,6 +38,12 @@ class _SectionSize:
LOCATION_ENTRY = USER_ID + LATITUDE + LONGITUDE
PADDING = 1
+_MAX_AUTH_HEAD_SIZE = _SectionSize.USER_TOKEN + \
+ MAX_USERNAME_SIZE + \
+ _SectionSize.PADDING
+_HEADER_SIZE = _SectionSize.LENGTH + \
+ _SectionSize.OPT_ID
+
class _OptCode:
user_auth = 0x00
location_update = 0x01
@@ -51,11 +57,17 @@ class _StatusCode:
class RequestHandler(object):
def __init__(self):
self.engine = create_engine('mysql://' + db_path, echo = False)
- self.Session = sessionmaker(bind = self.engine)
+ Session = sessionmaker(bind = self.engine)
+ self.session = Session()
def __del__(self):
+ self.session.close()
self.engine.dispose()
+ def check_size(self, tr_data):
+ if len(tr_data) > self._max_tr_data_size:
+ raise BadReqError("Authentication: Request size is too large")
+
@classmethod
def get_uauth(cls, token, username, session):
try:
@@ -73,7 +85,6 @@ class RequestHandler(object):
return None
except MultipleResultsFound:
- session.close()
raise DBCorruptedError()
@classmethod
@@ -82,7 +93,6 @@ class RequestHandler(object):
for i in xrange(len(data)):
ch = data[i]
if ch == '\x00':
- print get_hex(leading), get_hex(data[i + 1:])
return (leading, data[i + 1:])
else:
leading += ch
@@ -91,6 +101,11 @@ class RequestHandler(object):
class UserAuthHandler(RequestHandler):
+ _max_tr_data_size = MAX_USERNAME_SIZE + \
+ _SectionSize.PADDING + \
+ MAX_PASSWORD_SIZE + \
+ _SectionSize.PADDING
+
_response_size = \
_SectionSize.LENGTH + \
_SectionSize.OPT_ID + \
@@ -107,6 +122,7 @@ class UserAuthHandler(RequestHandler):
def handle(self, tr_data):
+ self.check_size(tr_data)
logger.info("Reading auth data...")
pos = -1
for i in xrange(0, len(tr_data)):
@@ -122,32 +138,27 @@ class UserAuthHandler(RequestHandler):
"(username = {0}, password = {1})" \
.format(username, password))
- session = self.Session()
try:
- user = session.query(UserModel) \
+ user = self.session.query(UserModel) \
.filter(UserModel.username == username).one()
except NoResultFound:
logger.info("No such user: {0}".format(username))
- session.commit()
return UserAuthHandler._failed_response
except MultipleResultsFound:
- session.close()
raise DBCorruptedError()
uauth = user.auth
if uauth is None:
- session.close()
raise DBCorruptedError()
if not uauth.check_password(password):
logger.info("Incorrect password: {0}".format(password))
- session.commit()
return UserAuthHandler._failed_response
else:
logger.info("Logged in sucessfully: {0}".format(username))
uauth.regen_token()
logger.info("New token generated: " + get_hex(uauth.token))
- session.commit()
+ self.session.commit()
return struct.pack("!LBBL32s", UserAuthHandler._response_size,
_OptCode.user_auth,
_StatusCode.sucess,
@@ -157,14 +168,18 @@ class UserAuthHandler(RequestHandler):
class LocationUpdateHandler(RequestHandler):
+ _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \
+ _SectionSize.LATITUDE + \
+ _SectionSize.LONGITUDE
+
_response_size = \
_SectionSize.LENGTH + \
_SectionSize.OPT_ID + \
_SectionSize.STATUS
def handle(self, tr_data):
+ self.check_size(tr_data)
logger.info("Reading location update data...")
-
try:
token, = struct.unpack("!32s", tr_data[:32])
username, tail = RequestHandler.trunc_padding(tr_data[32:])
@@ -178,12 +193,10 @@ class LocationUpdateHandler(RequestHandler):
"(token = {0}, username = {1}, lat = {2}, lng = {3})"\
.format(get_hex(token), username, lat, lng))
- session = self.Session()
- uauth = RequestHandler.get_uauth(token, username, session)
+ uauth = RequestHandler.get_uauth(token, username, self.session)
# Authentication failure
if uauth is None:
logger.warning("Authentication failure")
- session.commit()
return struct.pack("!LBB", LocationUpdateHandler._response_size,
_OptCode.location_update,
_StatusCode.failure)
@@ -193,12 +206,15 @@ class LocationUpdateHandler(RequestHandler):
ulocation.lng = lng
logger.info("Location is updated sucessfully")
- session.commit()
+ self.session.commit()
return struct.pack("!LBB", LocationUpdateHandler._response_size,
_OptCode.location_update,
_StatusCode.sucess)
-class LocationRequestHandler(RequestHandler):
+class LocationInfoHandler(RequestHandler):
+
+ _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \
+ _SectionSize.GROUP_ID
@classmethod
def _response_size(cls, item_num):
@@ -208,8 +224,8 @@ class LocationRequestHandler(RequestHandler):
_SectionSize.LOCATION_ENTRY * item_num
def handle(self, tr_data):
+ self.check_size(tr_data)
logger.info("Reading location request data..")
-
try:
token, = struct.unpack("!32s", tr_data[:32])
username, tail = RequestHandler.trunc_padding(tr_data[32:])
@@ -223,20 +239,18 @@ class LocationRequestHandler(RequestHandler):
"(token = {0}, gid = {1})" \
.format(get_hex(token), gid))
- session = self.Session()
- uauth = RequestHandler.get_uauth(token, username, session)
+ uauth = RequestHandler.get_uauth(token, username, self.session)
# Auth failure
if uauth is None:
logger.warning("Authentication failure")
- session.commit()
- return struct.pack("!LBB", LocationRequestHandler._response_size(0),
+ return struct.pack("!LBB", LocationInfoHandler._response_size(0),
_OptCode.location_request,
_StatusCode.failure)
- ulist = session.query(UserModel).filter(UserModel.gid == gid).all()
+ ulist = self.session.query(UserModel).filter(UserModel.gid == gid).all()
reply = struct.pack(
"!LBB",
- LocationRequestHandler._response_size(len(ulist)),
+ LocationInfoHandler._response_size(len(ulist)),
_OptCode.location_request,
_StatusCode.sucess)
@@ -244,7 +258,6 @@ class LocationRequestHandler(RequestHandler):
loc = user.location
reply += struct.pack("!Ldd", user.id, loc.lat, loc.lng)
- session.commit()
return reply
def pack_int(val):
@@ -254,7 +267,10 @@ def pack_bool(val):
return struct.pack("!B", 0x01 if val else 0x00)
-class UserInfoRequestHandler(RequestHandler):
+class UserInfoHandler(RequestHandler):
+
+ _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \
+ _SectionSize.USER_ID
_failed_response_size = \
_SectionSize.LENGTH + \
@@ -277,8 +293,8 @@ class UserInfoRequestHandler(RequestHandler):
return struct.pack("!B", info_key) + pack_method(info_value)
def handle(self, tr_data):
+ self.check_size(tr_data)
logger.info("Reading user info request data...")
-
try:
token, = struct.unpack("!32s", tr_data[:32])
username, tail = RequestHandler.trunc_padding(tr_data[32:])
@@ -292,47 +308,49 @@ class UserInfoRequestHandler(RequestHandler):
"(token = {0}, uid = {1})" \
.format(get_hex(token), uid))
- session = self.Session()
- uauth = RequestHandler.get_uauth(token, username, session)
+ uauth = RequestHandler.get_uauth(token, username, self.session)
# Auth failure
if uauth is None:
logger.warning("Authentication failure")
- session.commit()
- return UserInfoRequestHandler._fail_response
+ return UserInfoHandler._fail_response
# TODO: check the relationship between user and quser
user = uauth.user
reply = struct.pack("!BB", _OptCode.user_info_request,
_StatusCode.sucess)
try:
- quser = session.query(UserModel) \
+ quser = self.session.query(UserModel) \
.filter(UserModel.id == uid).one()
except NoResultFound:
logger.info("No such user: {0}".format(username))
- session.commit()
- return UserInfoRequestHandler._fail_response
+ return UserInfoHandler._fail_response
except MultipleResultsFound:
- session.close()
raise DBCorruptedError()
- for code in UserInfoRequestHandler._code_map:
- reply += UserInfoRequestHandler.pack_entry(quser, code)
+ for code in UserInfoHandler._code_map:
+ reply += UserInfoHandler.pack_entry(quser, code)
reply = struct.pack("!L", len(reply) + _SectionSize.LENGTH) + reply
- session.commit()
return reply
-handlers = [UserAuthHandler,
- LocationUpdateHandler,
- LocationRequestHandler,
- UserInfoRequestHandler]
-
-def check_header(header):
- return 0 <= header < len(handlers)
class PTP(Protocol, TimeoutMixin):
+ handlers = [UserAuthHandler,
+ LocationUpdateHandler,
+ LocationInfoHandler,
+ UserInfoHandler]
+
+ handler_num = len(handlers)
+
+ _MAX_REQUEST_SIZE = _HEADER_SIZE + \
+ max([h._max_tr_data_size for h in handlers])
+
+ @classmethod
+ def check_header(cls, header):
+ return 0 <= header < cls.handler_num
+
def __init__(self, factory):
self.buff = bytes()
self.length = -1
@@ -352,14 +370,19 @@ class PTP(Protocol, TimeoutMixin):
if len(self.buff) > 4:
try:
self.length, self.optcode = struct.unpack("!LB", self.buff[:5])
- if not check_header(self.optcode): # invalid header
+ if not PTP.check_header(self.optcode): # invalid header
raise struct.error
except struct.error:
- logger.warning("Invalid request header")
raise BadReqError("Malformed request header")
+ if self.length > PTP._MAX_REQUEST_SIZE:
+ print self.length, PTP._MAX_REQUEST_SIZE
+ raise BadReqError("The size of remaining part is too big")
+
+ # Incomplete length info
if self.length == -1: return
+
if len(self.buff) == self.length:
- h = handlers[self.optcode]()
+ h = PTP.handlers[self.optcode]()
reply = h.handle(self.buff[5:])
logger.info("Wrote: %s", get_hex(reply))
self.transport.write(reply)
@@ -391,4 +414,5 @@ from twisted.internet import reactor
f = PTPFactory()
f.protocol = PTP
reactor.listenTCP(2222, f)
+logger.warning("The server is on")
reactor.run()