diff options
Diffstat (limited to 'server')
-rw-r--r-- | server/piztor/import.py | 6 | ||||
-rw-r--r-- | server/piztor/model.py | 4 | ||||
-rw-r--r-- | server/piztor/prob.py | 5 | ||||
-rw-r--r-- | server/piztor/server.py | 129 |
4 files changed, 98 insertions, 46 deletions
diff --git a/server/piztor/import.py b/server/piztor/import.py index 84c990f..c91aae9 100644 --- a/server/piztor/import.py +++ b/server/piztor/import.py @@ -2,7 +2,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from model import * -path = "piztor.sqlite" +path = "root:helloworld@localhost/piztor" class UserData: def __init__(self, username, password, gid, sex): @@ -12,12 +12,12 @@ class UserData: self.sex = sex def create_database(): - engine = create_engine('sqlite:///' + path, echo = True) + engine = create_engine('mysql://' + path, echo = True) Base.metadata.drop_all(engine) Base.metadata.create_all(engine) def import_user_data(data): - engine = create_engine('sqlite:///' + path, echo = True) + engine = create_engine('mysql://' + path, echo = True) Session = sessionmaker(bind = engine) session = Session() for user in data: diff --git a/server/piztor/model.py b/server/piztor/model.py index 4621bbe..8916e3a 100644 --- a/server/piztor/model.py +++ b/server/piztor/model.py @@ -6,6 +6,8 @@ Base = declarative_base() _SALT_LEN = 16 _TOKEN_LEN = 16 +MAX_USERNAME_SIZE = 20 +MAX_PASSWORD_SIZE = 20 class _TableName: # avoid typoes UserModel = 'users' @@ -17,7 +19,7 @@ class UserModel(Base): id = Column(Integer, primary_key = True) gid = Column(Integer) - username = Column(String) + username = Column(String(MAX_USERNAME_SIZE)) sex = Column(Boolean) location = None auth = None diff --git a/server/piztor/prob.py b/server/piztor/prob.py index 20b6779..9798c18 100644 --- a/server/piztor/prob.py +++ b/server/piztor/prob.py @@ -56,7 +56,8 @@ def send(data): from sys import argv -username = "hello" +#username = "hello" +username = "12345678901234567890" password = "world" gid = 1 @@ -111,4 +112,4 @@ for i in xrange(10): idx += 1 print (info_key, info_value) from time import sleep - sleep(10) +# sleep(10) diff --git a/server/piztor/server.py b/server/piztor/server.py index d0cfc34..eb321fc 100644 --- a/server/piztor/server.py +++ b/server/piztor/server.py @@ -1,7 +1,6 @@ from twisted.internet.protocol import Protocol from twisted.internet.protocol import Factory from twisted.internet.endpoints import TCP4ServerEndpoint -from twisted.internet import reactor from twisted.protocols.policies import TimeoutMixin from sqlalchemy import create_engine @@ -18,7 +17,8 @@ from model import * def get_hex(data): return "".join([hex(ord(c))[2:].zfill(2) for c in data]) -db_path = "piztor.sqlite" +db_path = "root:helloworld@localhost/piztor" +#db_path = "piztor.sqlite" FORMAT = "%(asctime)-15s %(message)s" logging.basicConfig(format = FORMAT) logger = logging.getLogger('piztor_server') @@ -38,6 +38,12 @@ class _SectionSize: LOCATION_ENTRY = USER_ID + LATITUDE + LONGITUDE PADDING = 1 +_MAX_AUTH_HEAD_SIZE = _SectionSize.USER_TOKEN + \ + MAX_USERNAME_SIZE + \ + _SectionSize.PADDING +_HEADER_SIZE = _SectionSize.LENGTH + \ + _SectionSize.OPT_ID + class _OptCode: user_auth = 0x00 location_update = 0x01 @@ -50,8 +56,17 @@ class _StatusCode: class RequestHandler(object): def __init__(self): - self.engine = create_engine('sqlite:///' + db_path, echo = False) - self.Session = sessionmaker(bind = self.engine) + self.engine = create_engine('mysql://' + db_path, echo = False) + Session = sessionmaker(bind = self.engine) + self.session = Session() + + def __del__(self): + self.session.close() + self.engine.dispose() + + def check_size(self, tr_data): + if len(tr_data) > self._max_tr_data_size: + raise BadReqError("Authentication: Request size is too large") @classmethod def get_uauth(cls, token, username, session): @@ -78,7 +93,6 @@ class RequestHandler(object): for i in xrange(len(data)): ch = data[i] if ch == '\x00': - print get_hex(leading), get_hex(data[i + 1:]) return (leading, data[i + 1:]) else: leading += ch @@ -87,6 +101,11 @@ class RequestHandler(object): class UserAuthHandler(RequestHandler): + _max_tr_data_size = MAX_USERNAME_SIZE + \ + _SectionSize.PADDING + \ + MAX_PASSWORD_SIZE + \ + _SectionSize.PADDING + _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ @@ -103,6 +122,7 @@ class UserAuthHandler(RequestHandler): def handle(self, tr_data): + self.check_size(tr_data) logger.info("Reading auth data...") pos = -1 for i in xrange(0, len(tr_data)): @@ -118,9 +138,8 @@ class UserAuthHandler(RequestHandler): "(username = {0}, password = {1})" \ .format(username, password)) - session = self.Session() try: - user = session.query(UserModel) \ + user = self.session.query(UserModel) \ .filter(UserModel.username == username).one() except NoResultFound: logger.info("No such user: {0}".format(username)) @@ -138,8 +157,8 @@ class UserAuthHandler(RequestHandler): else: logger.info("Logged in sucessfully: {0}".format(username)) uauth.regen_token() - session.commit() logger.info("New token generated: " + get_hex(uauth.token)) + self.session.commit() return struct.pack("!LBBL32s", UserAuthHandler._response_size, _OptCode.user_auth, _StatusCode.sucess, @@ -149,14 +168,18 @@ class UserAuthHandler(RequestHandler): class LocationUpdateHandler(RequestHandler): + _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ + _SectionSize.LATITUDE + \ + _SectionSize.LONGITUDE + _response_size = \ _SectionSize.LENGTH + \ _SectionSize.OPT_ID + \ _SectionSize.STATUS def handle(self, tr_data): + self.check_size(tr_data) logger.info("Reading location update data...") - try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) @@ -170,8 +193,7 @@ class LocationUpdateHandler(RequestHandler): "(token = {0}, username = {1}, lat = {2}, lng = {3})"\ .format(get_hex(token), username, lat, lng)) - session = self.Session() - uauth = RequestHandler.get_uauth(token, username, session) + uauth = RequestHandler.get_uauth(token, username, self.session) # Authentication failure if uauth is None: logger.warning("Authentication failure") @@ -182,14 +204,17 @@ class LocationUpdateHandler(RequestHandler): ulocation = uauth.user.location ulocation.lat = lat ulocation.lng = lng - session.commit() logger.info("Location is updated sucessfully") + self.session.commit() return struct.pack("!LBB", LocationUpdateHandler._response_size, _OptCode.location_update, _StatusCode.sucess) -class LocationRequestHandler(RequestHandler): +class LocationInfoHandler(RequestHandler): + + _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ + _SectionSize.GROUP_ID @classmethod def _response_size(cls, item_num): @@ -199,8 +224,8 @@ class LocationRequestHandler(RequestHandler): _SectionSize.LOCATION_ENTRY * item_num def handle(self, tr_data): + self.check_size(tr_data) logger.info("Reading location request data..") - try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) @@ -214,19 +239,18 @@ class LocationRequestHandler(RequestHandler): "(token = {0}, gid = {1})" \ .format(get_hex(token), gid)) - session = self.Session() - uauth = RequestHandler.get_uauth(token, username, session) + uauth = RequestHandler.get_uauth(token, username, self.session) # Auth failure if uauth is None: logger.warning("Authentication failure") - return struct.pack("!LBB", LocationRequestHandler._response_size(0), + return struct.pack("!LBB", LocationInfoHandler._response_size(0), _OptCode.location_request, _StatusCode.failure) - ulist = session.query(UserModel).filter(UserModel.gid == gid).all() + ulist = self.session.query(UserModel).filter(UserModel.gid == gid).all() reply = struct.pack( "!LBB", - LocationRequestHandler._response_size(len(ulist)), + LocationInfoHandler._response_size(len(ulist)), _OptCode.location_request, _StatusCode.sucess) @@ -243,7 +267,10 @@ def pack_bool(val): return struct.pack("!B", 0x01 if val else 0x00) -class UserInfoRequestHandler(RequestHandler): +class UserInfoHandler(RequestHandler): + + _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ + _SectionSize.USER_ID _failed_response_size = \ _SectionSize.LENGTH + \ @@ -266,8 +293,8 @@ class UserInfoRequestHandler(RequestHandler): return struct.pack("!B", info_key) + pack_method(info_value) def handle(self, tr_data): + self.check_size(tr_data) logger.info("Reading user info request data...") - try: token, = struct.unpack("!32s", tr_data[:32]) username, tail = RequestHandler.trunc_padding(tr_data[32:]) @@ -281,43 +308,49 @@ class UserInfoRequestHandler(RequestHandler): "(token = {0}, uid = {1})" \ .format(get_hex(token), uid)) - session = self.Session() - uauth = RequestHandler.get_uauth(token, username, session) + uauth = RequestHandler.get_uauth(token, username, self.session) # Auth failure if uauth is None: logger.warning("Authentication failure") - return UserInfoRequestHandler._fail_response + return UserInfoHandler._fail_response # TODO: check the relationship between user and quser user = uauth.user reply = struct.pack("!BB", _OptCode.user_info_request, _StatusCode.sucess) try: - quser = session.query(UserModel) \ + quser = self.session.query(UserModel) \ .filter(UserModel.id == uid).one() except NoResultFound: logger.info("No such user: {0}".format(username)) - return UserInfoRequestHandler._fail_response + return UserInfoHandler._fail_response except MultipleResultsFound: raise DBCorruptedError() - for code in UserInfoRequestHandler._code_map: - reply += UserInfoRequestHandler.pack_entry(quser, code) + for code in UserInfoHandler._code_map: + reply += UserInfoHandler.pack_entry(quser, code) reply = struct.pack("!L", len(reply) + _SectionSize.LENGTH) + reply return reply -handlers = [UserAuthHandler, - LocationUpdateHandler, - LocationRequestHandler, - UserInfoRequestHandler] - -def check_header(header): - return 0 <= header < len(handlers) class PTP(Protocol, TimeoutMixin): + handlers = [UserAuthHandler, + LocationUpdateHandler, + LocationInfoHandler, + UserInfoHandler] + + handler_num = len(handlers) + + _MAX_REQUEST_SIZE = _HEADER_SIZE + \ + max([h._max_tr_data_size for h in handlers]) + + @classmethod + def check_header(cls, header): + return 0 <= header < cls.handler_num + def __init__(self, factory): self.buff = bytes() self.length = -1 @@ -337,14 +370,19 @@ class PTP(Protocol, TimeoutMixin): if len(self.buff) > 4: try: self.length, self.optcode = struct.unpack("!LB", self.buff[:5]) - if not check_header(self.optcode): # invalid header + if not PTP.check_header(self.optcode): # invalid header raise struct.error except struct.error: - logger.warning("Invalid request header") raise BadReqError("Malformed request header") + if self.length > PTP._MAX_REQUEST_SIZE: + print self.length, PTP._MAX_REQUEST_SIZE + raise BadReqError("The size of remaining part is too big") + + # Incomplete length info if self.length == -1: return + if len(self.buff) == self.length: - h = handlers[self.optcode]() + h = PTP.handlers[self.optcode]() reply = h.handle(self.buff[5:]) logger.info("Wrote: %s", get_hex(reply)) self.transport.write(reply) @@ -364,6 +402,17 @@ class PTPFactory(Factory): def buildProtocol(self, addr): return PTP(self) -endpoint = TCP4ServerEndpoint(reactor, 2222) -endpoint.listen(PTPFactory()) +#if os.name!='nt': +# from twisted.internet import epollreactor +# epollreactor.install() +#else: +# from twisted.internet import iocpreactor +# iocpreactor.install() + +from twisted.internet import reactor + +f = PTPFactory() +f.protocol = PTP +reactor.listenTCP(2222, f) +logger.warning("The server is on") reactor.run() |