summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
Diffstat (limited to 'server')
-rw-r--r--server/piztor/import.py6
-rw-r--r--server/piztor/model.py4
-rw-r--r--server/piztor/prob.py5
-rw-r--r--server/piztor/server.py129
4 files changed, 98 insertions, 46 deletions
diff --git a/server/piztor/import.py b/server/piztor/import.py
index 84c990f..c91aae9 100644
--- a/server/piztor/import.py
+++ b/server/piztor/import.py
@@ -2,7 +2,7 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from model import *
-path = "piztor.sqlite"
+path = "root:helloworld@localhost/piztor"
class UserData:
def __init__(self, username, password, gid, sex):
@@ -12,12 +12,12 @@ class UserData:
self.sex = sex
def create_database():
- engine = create_engine('sqlite:///' + path, echo = True)
+ engine = create_engine('mysql://' + path, echo = True)
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
def import_user_data(data):
- engine = create_engine('sqlite:///' + path, echo = True)
+ engine = create_engine('mysql://' + path, echo = True)
Session = sessionmaker(bind = engine)
session = Session()
for user in data:
diff --git a/server/piztor/model.py b/server/piztor/model.py
index 4621bbe..8916e3a 100644
--- a/server/piztor/model.py
+++ b/server/piztor/model.py
@@ -6,6 +6,8 @@ Base = declarative_base()
_SALT_LEN = 16
_TOKEN_LEN = 16
+MAX_USERNAME_SIZE = 20
+MAX_PASSWORD_SIZE = 20
class _TableName: # avoid typoes
UserModel = 'users'
@@ -17,7 +19,7 @@ class UserModel(Base):
id = Column(Integer, primary_key = True)
gid = Column(Integer)
- username = Column(String)
+ username = Column(String(MAX_USERNAME_SIZE))
sex = Column(Boolean)
location = None
auth = None
diff --git a/server/piztor/prob.py b/server/piztor/prob.py
index 20b6779..9798c18 100644
--- a/server/piztor/prob.py
+++ b/server/piztor/prob.py
@@ -56,7 +56,8 @@ def send(data):
from sys import argv
-username = "hello"
+#username = "hello"
+username = "12345678901234567890"
password = "world"
gid = 1
@@ -111,4 +112,4 @@ for i in xrange(10):
idx += 1
print (info_key, info_value)
from time import sleep
- sleep(10)
+# sleep(10)
diff --git a/server/piztor/server.py b/server/piztor/server.py
index d0cfc34..eb321fc 100644
--- a/server/piztor/server.py
+++ b/server/piztor/server.py
@@ -1,7 +1,6 @@
from twisted.internet.protocol import Protocol
from twisted.internet.protocol import Factory
from twisted.internet.endpoints import TCP4ServerEndpoint
-from twisted.internet import reactor
from twisted.protocols.policies import TimeoutMixin
from sqlalchemy import create_engine
@@ -18,7 +17,8 @@ from model import *
def get_hex(data):
return "".join([hex(ord(c))[2:].zfill(2) for c in data])
-db_path = "piztor.sqlite"
+db_path = "root:helloworld@localhost/piztor"
+#db_path = "piztor.sqlite"
FORMAT = "%(asctime)-15s %(message)s"
logging.basicConfig(format = FORMAT)
logger = logging.getLogger('piztor_server')
@@ -38,6 +38,12 @@ class _SectionSize:
LOCATION_ENTRY = USER_ID + LATITUDE + LONGITUDE
PADDING = 1
+_MAX_AUTH_HEAD_SIZE = _SectionSize.USER_TOKEN + \
+ MAX_USERNAME_SIZE + \
+ _SectionSize.PADDING
+_HEADER_SIZE = _SectionSize.LENGTH + \
+ _SectionSize.OPT_ID
+
class _OptCode:
user_auth = 0x00
location_update = 0x01
@@ -50,8 +56,17 @@ class _StatusCode:
class RequestHandler(object):
def __init__(self):
- self.engine = create_engine('sqlite:///' + db_path, echo = False)
- self.Session = sessionmaker(bind = self.engine)
+ self.engine = create_engine('mysql://' + db_path, echo = False)
+ Session = sessionmaker(bind = self.engine)
+ self.session = Session()
+
+ def __del__(self):
+ self.session.close()
+ self.engine.dispose()
+
+ def check_size(self, tr_data):
+ if len(tr_data) > self._max_tr_data_size:
+ raise BadReqError("Authentication: Request size is too large")
@classmethod
def get_uauth(cls, token, username, session):
@@ -78,7 +93,6 @@ class RequestHandler(object):
for i in xrange(len(data)):
ch = data[i]
if ch == '\x00':
- print get_hex(leading), get_hex(data[i + 1:])
return (leading, data[i + 1:])
else:
leading += ch
@@ -87,6 +101,11 @@ class RequestHandler(object):
class UserAuthHandler(RequestHandler):
+ _max_tr_data_size = MAX_USERNAME_SIZE + \
+ _SectionSize.PADDING + \
+ MAX_PASSWORD_SIZE + \
+ _SectionSize.PADDING
+
_response_size = \
_SectionSize.LENGTH + \
_SectionSize.OPT_ID + \
@@ -103,6 +122,7 @@ class UserAuthHandler(RequestHandler):
def handle(self, tr_data):
+ self.check_size(tr_data)
logger.info("Reading auth data...")
pos = -1
for i in xrange(0, len(tr_data)):
@@ -118,9 +138,8 @@ class UserAuthHandler(RequestHandler):
"(username = {0}, password = {1})" \
.format(username, password))
- session = self.Session()
try:
- user = session.query(UserModel) \
+ user = self.session.query(UserModel) \
.filter(UserModel.username == username).one()
except NoResultFound:
logger.info("No such user: {0}".format(username))
@@ -138,8 +157,8 @@ class UserAuthHandler(RequestHandler):
else:
logger.info("Logged in sucessfully: {0}".format(username))
uauth.regen_token()
- session.commit()
logger.info("New token generated: " + get_hex(uauth.token))
+ self.session.commit()
return struct.pack("!LBBL32s", UserAuthHandler._response_size,
_OptCode.user_auth,
_StatusCode.sucess,
@@ -149,14 +168,18 @@ class UserAuthHandler(RequestHandler):
class LocationUpdateHandler(RequestHandler):
+ _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \
+ _SectionSize.LATITUDE + \
+ _SectionSize.LONGITUDE
+
_response_size = \
_SectionSize.LENGTH + \
_SectionSize.OPT_ID + \
_SectionSize.STATUS
def handle(self, tr_data):
+ self.check_size(tr_data)
logger.info("Reading location update data...")
-
try:
token, = struct.unpack("!32s", tr_data[:32])
username, tail = RequestHandler.trunc_padding(tr_data[32:])
@@ -170,8 +193,7 @@ class LocationUpdateHandler(RequestHandler):
"(token = {0}, username = {1}, lat = {2}, lng = {3})"\
.format(get_hex(token), username, lat, lng))
- session = self.Session()
- uauth = RequestHandler.get_uauth(token, username, session)
+ uauth = RequestHandler.get_uauth(token, username, self.session)
# Authentication failure
if uauth is None:
logger.warning("Authentication failure")
@@ -182,14 +204,17 @@ class LocationUpdateHandler(RequestHandler):
ulocation = uauth.user.location
ulocation.lat = lat
ulocation.lng = lng
- session.commit()
logger.info("Location is updated sucessfully")
+ self.session.commit()
return struct.pack("!LBB", LocationUpdateHandler._response_size,
_OptCode.location_update,
_StatusCode.sucess)
-class LocationRequestHandler(RequestHandler):
+class LocationInfoHandler(RequestHandler):
+
+ _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \
+ _SectionSize.GROUP_ID
@classmethod
def _response_size(cls, item_num):
@@ -199,8 +224,8 @@ class LocationRequestHandler(RequestHandler):
_SectionSize.LOCATION_ENTRY * item_num
def handle(self, tr_data):
+ self.check_size(tr_data)
logger.info("Reading location request data..")
-
try:
token, = struct.unpack("!32s", tr_data[:32])
username, tail = RequestHandler.trunc_padding(tr_data[32:])
@@ -214,19 +239,18 @@ class LocationRequestHandler(RequestHandler):
"(token = {0}, gid = {1})" \
.format(get_hex(token), gid))
- session = self.Session()
- uauth = RequestHandler.get_uauth(token, username, session)
+ uauth = RequestHandler.get_uauth(token, username, self.session)
# Auth failure
if uauth is None:
logger.warning("Authentication failure")
- return struct.pack("!LBB", LocationRequestHandler._response_size(0),
+ return struct.pack("!LBB", LocationInfoHandler._response_size(0),
_OptCode.location_request,
_StatusCode.failure)
- ulist = session.query(UserModel).filter(UserModel.gid == gid).all()
+ ulist = self.session.query(UserModel).filter(UserModel.gid == gid).all()
reply = struct.pack(
"!LBB",
- LocationRequestHandler._response_size(len(ulist)),
+ LocationInfoHandler._response_size(len(ulist)),
_OptCode.location_request,
_StatusCode.sucess)
@@ -243,7 +267,10 @@ def pack_bool(val):
return struct.pack("!B", 0x01 if val else 0x00)
-class UserInfoRequestHandler(RequestHandler):
+class UserInfoHandler(RequestHandler):
+
+ _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \
+ _SectionSize.USER_ID
_failed_response_size = \
_SectionSize.LENGTH + \
@@ -266,8 +293,8 @@ class UserInfoRequestHandler(RequestHandler):
return struct.pack("!B", info_key) + pack_method(info_value)
def handle(self, tr_data):
+ self.check_size(tr_data)
logger.info("Reading user info request data...")
-
try:
token, = struct.unpack("!32s", tr_data[:32])
username, tail = RequestHandler.trunc_padding(tr_data[32:])
@@ -281,43 +308,49 @@ class UserInfoRequestHandler(RequestHandler):
"(token = {0}, uid = {1})" \
.format(get_hex(token), uid))
- session = self.Session()
- uauth = RequestHandler.get_uauth(token, username, session)
+ uauth = RequestHandler.get_uauth(token, username, self.session)
# Auth failure
if uauth is None:
logger.warning("Authentication failure")
- return UserInfoRequestHandler._fail_response
+ return UserInfoHandler._fail_response
# TODO: check the relationship between user and quser
user = uauth.user
reply = struct.pack("!BB", _OptCode.user_info_request,
_StatusCode.sucess)
try:
- quser = session.query(UserModel) \
+ quser = self.session.query(UserModel) \
.filter(UserModel.id == uid).one()
except NoResultFound:
logger.info("No such user: {0}".format(username))
- return UserInfoRequestHandler._fail_response
+ return UserInfoHandler._fail_response
except MultipleResultsFound:
raise DBCorruptedError()
- for code in UserInfoRequestHandler._code_map:
- reply += UserInfoRequestHandler.pack_entry(quser, code)
+ for code in UserInfoHandler._code_map:
+ reply += UserInfoHandler.pack_entry(quser, code)
reply = struct.pack("!L", len(reply) + _SectionSize.LENGTH) + reply
return reply
-handlers = [UserAuthHandler,
- LocationUpdateHandler,
- LocationRequestHandler,
- UserInfoRequestHandler]
-
-def check_header(header):
- return 0 <= header < len(handlers)
class PTP(Protocol, TimeoutMixin):
+ handlers = [UserAuthHandler,
+ LocationUpdateHandler,
+ LocationInfoHandler,
+ UserInfoHandler]
+
+ handler_num = len(handlers)
+
+ _MAX_REQUEST_SIZE = _HEADER_SIZE + \
+ max([h._max_tr_data_size for h in handlers])
+
+ @classmethod
+ def check_header(cls, header):
+ return 0 <= header < cls.handler_num
+
def __init__(self, factory):
self.buff = bytes()
self.length = -1
@@ -337,14 +370,19 @@ class PTP(Protocol, TimeoutMixin):
if len(self.buff) > 4:
try:
self.length, self.optcode = struct.unpack("!LB", self.buff[:5])
- if not check_header(self.optcode): # invalid header
+ if not PTP.check_header(self.optcode): # invalid header
raise struct.error
except struct.error:
- logger.warning("Invalid request header")
raise BadReqError("Malformed request header")
+ if self.length > PTP._MAX_REQUEST_SIZE:
+ print self.length, PTP._MAX_REQUEST_SIZE
+ raise BadReqError("The size of remaining part is too big")
+
+ # Incomplete length info
if self.length == -1: return
+
if len(self.buff) == self.length:
- h = handlers[self.optcode]()
+ h = PTP.handlers[self.optcode]()
reply = h.handle(self.buff[5:])
logger.info("Wrote: %s", get_hex(reply))
self.transport.write(reply)
@@ -364,6 +402,17 @@ class PTPFactory(Factory):
def buildProtocol(self, addr):
return PTP(self)
-endpoint = TCP4ServerEndpoint(reactor, 2222)
-endpoint.listen(PTPFactory())
+#if os.name!='nt':
+# from twisted.internet import epollreactor
+# epollreactor.install()
+#else:
+# from twisted.internet import iocpreactor
+# iocpreactor.install()
+
+from twisted.internet import reactor
+
+f = PTPFactory()
+f.protocol = PTP
+reactor.listenTCP(2222, f)
+logger.warning("The server is on")
reactor.run()