diff options
Diffstat (limited to 'server/piztor/server.py')
-rw-r--r-- | server/piztor/server.py | 191 |
1 files changed, 181 insertions, 10 deletions
diff --git a/server/piztor/server.py b/server/piztor/server.py index 3f1f2cb..70cc13a 100644 --- a/server/piztor/server.py +++ b/server/piztor/server.py @@ -7,6 +7,8 @@ from sqlalchemy import create_engine, and_ from sqlalchemy.orm import sessionmaker from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound +from collections import deque + import struct import os import logging @@ -45,18 +47,76 @@ _MAX_AUTH_HEAD_SIZE = _SectionSize.USER_TOKEN + \ _HEADER_SIZE = _SectionSize.LENGTH + \ _SectionSize.OPT_ID +_MAX_TEXT_MESG_SIZE = 1024 + class _OptCode: user_auth = 0x00 location_update = 0x01 location_info= 0x02 user_info = 0x03 user_logout = 0x04 + open_push_tunnel = 0x05 + send_text_mesg = 0x06 class _StatusCode: sucess = 0x00 failure = 0x01 +class PushData(object): + from hashlib import sha256 + def __init__(self, data): + self.data = data + self.finger_print = sha256(data).digest() + +class PushTextMesgData(PushData): + def __init__(self, mesg): + self.finger_print = sha256(mesg).digest() + buff = struct.pack("!B32s", 0x00, self.finger_print) + buff += mesg + buff += chr(0) + buff = struct.pack("!L", _SectionSize.LENGTH + len(buff)) + buff + self.data = buff + + +class PushTunnel(object): + def __init__(self): + self.pending = deque() + self.conn = None + + def __del__(self): + if self.conn: + self.conn.loseConnection() + + def add(self, pdata): + logger.info("-- Push data enqued --") + logger.info("Data: %s", get_hex(pdata.data)) + self.pending.append(pdata) + + def on_receive(self, data): + front = self.pending.popleft() + length, optcode, fingerprint = struct.unpack("!LB32s", data) + if front.finger_print != fingerprint: + raise PiztorError + logger.info("-- Push data confirmed --") + self.push() + + def push(self): + if (self.conn is None) or len(self.pending) == 0: + return + print "Pushing" + front = self.pending.popleft() + self.pending.appendleft(front) + self.conn.transport.write(front.data) + + def connect(self, conn): + conn.tunnel = self + self.conn = conn + + def on_connection_lost(self): + self.conn = None + class RequestHandler(object): + push_tunnels = dict() def __init__(self): Session = sessionmaker(bind = engine) self.session = Session() @@ -78,7 +138,9 @@ class RequestHandler(object): if uauth.user.username != username: logger.warning("Toke and username mismatch") return None - + uid = uauth.uid + if not cls.push_tunnels.has_key(uid): + cls.push_tunnels[uid] = PushTunnel() return uauth except NoResultFound: @@ -122,7 +184,7 @@ class UserAuthHandler(RequestHandler): bytes('\x00' * 32)) - def handle(self, tr_data): + def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading auth data...") pos = -1 @@ -178,7 +240,7 @@ class LocationUpdateHandler(RequestHandler): _SectionSize.OPT_ID + \ _SectionSize.STATUS - def handle(self, tr_data): + def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading location update data...") try: @@ -224,7 +286,7 @@ class LocationInfoHandler(RequestHandler): _SectionSize.STATUS + \ _SectionSize.LOCATION_ENTRY * item_num - def handle(self, tr_data): + def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading location request data..") try: @@ -298,7 +360,7 @@ class UserInfoHandler(RequestHandler): info_key = entry_code return struct.pack("!B", info_key) + pack_method(user) - def handle(self, tr_data): + def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading user info request data...") try: @@ -348,7 +410,7 @@ class UserLogoutHandler(RequestHandler): _SectionSize.OPT_ID + \ _SectionSize.STATUS - def handle(self, tr_data): + def handle(self, tr_data, conn): self.check_size(tr_data) logger.info("Reading user logout data...") try: @@ -370,13 +432,105 @@ class UserLogoutHandler(RequestHandler): return struct.pack("!LBB", self._response_size, _OptCode.location_update, _StatusCode.failure) + del RequestHandler.push_tunnels[uauth.uid] uauth.regen_token() logger.info("User Logged out successfully!") self.session.commit() return struct.pack("!LBB", self._response_size, _OptCode.user_logout, _StatusCode.sucess) - + +class OpenPushTunnelHandler(RequestHandler): + + _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + + _response_size = \ + _SectionSize.LENGTH + \ + _SectionSize.OPT_ID + \ + _SectionSize.STATUS + + def handle(self, tr_data, conn): + self.check_size(tr_data) + logger.info("Reading open push tunnel data...") + try: + token, = struct.unpack("!32s", tr_data[:32]) + username, tail = RequestHandler.trunc_padding(tr_data[32:]) + if username is None: + raise struct.error + except struct.error: + raise BadReqError("Open push tunnel: Malformed request body") + + logger.info("Trying to open push tunnel with " + "(token = {0}, username = {1})"\ + .format(get_hex(token), username)) + + uauth = RequestHandler.get_uauth(token, username, self.session) + # Authentication failure + if uauth is None: + logger.warning("Authentication failure") + return struct.pack("!LBB", self._response_size, + _OptCode.location_update, + _StatusCode.failure) + + tunnel = RequestHandler.push_tunnels[uauth.uid] + pt = RequestHandler.push_tunnels + uid = uauth.uid + if pt.has_key(uid): + tunnel = pt[uid] + tunnel.connect(conn) + tunnel.push() + + logger.info("Push tunnel opened successfully!") + return struct.pack("!LBB", self._response_size, + _OptCode.user_logout, + _StatusCode.sucess) + +class SendTextMessageHandler(RequestHandler): + + _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \ + _MAX_TEXT_MESG_SIZE + \ + _SectionSize.PADDING + + _response_size = \ + _SectionSize.LENGTH + \ + _SectionSize.OPT_ID + \ + _SectionSize.STATUS + + def handle(self, tr_data, conn): + self.check_size(tr_data) + logger.info("Reading send text mesg data...") + try: + token, = struct.unpack("!32s", tr_data[:32]) + username, tail = RequestHandler.trunc_padding(tr_data[32:]) + mesg = tail[:-1] + if username is None: + raise struct.error + except struct.error: + raise BadReqError("Send text mesg: Malformed request body") + + logger.info("Trying to send text mesg with " + "(token = {0}, username = {1})"\ + .format(get_hex(token), username)) + + uauth = RequestHandler.get_uauth(token, username, self.session) + # Authentication failure + if uauth is None: + logger.warning("Authentication failure") + return struct.pack("!LBB", self._response_size, + _OptCode.location_update, + _StatusCode.failure) + + pt = RequestHandler.push_tunnels + uid = uauth.uid + if pt.has_key(uid): + tunnel = pt[uid] + tunnel.add(PushTextMesgData(mesg)) + tunnel.push() + logger.info("Sent text mesg successfully!") + return struct.pack("!LBB", self._response_size, + _OptCode.user_logout, + _StatusCode.sucess) + class PTP(Protocol, TimeoutMixin): @@ -384,7 +538,9 @@ class PTP(Protocol, TimeoutMixin): LocationUpdateHandler, LocationInfoHandler, UserInfoHandler, - UserLogoutHandler] + UserLogoutHandler, + OpenPushTunnelHandler, + SendTextMessageHandler] handler_num = len(handlers) @@ -399,6 +555,7 @@ class PTP(Protocol, TimeoutMixin): self.buff = bytes() self.length = -1 self.factory = factory + self.tunnel = None def timeoutConnection(self): logger.info("The connection times out") @@ -425,8 +582,18 @@ class PTP(Protocol, TimeoutMixin): print self.length, PTP._MAX_REQUEST_SIZE raise BadReqError("The size of remaining part is too big") if len(self.buff) == self.length: + if self.tunnel: # received push response + self.tunnel.on_receive(self.buff) + self.buff = bytes() + self.length = -1 + return h = PTP.handlers[self.optcode]() - reply = h.handle(self.buff[5:]) + reply = h.handle(self.buff[5:], self) + if self.tunnel: + logger.info("Blocking the client...") + self.buff = bytes() + self.length = -1 + return logger.info("Wrote: %s", get_hex(reply)) self.transport.write(reply) self.transport.loseConnection() @@ -434,12 +601,16 @@ class PTP(Protocol, TimeoutMixin): raise BadReqError("The actual length is larger than promised") except BadReqError as e: logger.warn("Rejected a bad request: %s", str(e)) + self.transport.loseConnection() except DBCorruptionError: logger.error("*** Database corruption ***") - finally: + self.transport.loseConnection() + if self.tunnel is None: self.transport.loseConnection() def connectionLost(self, reason): + if self.tunnel: + self.tunnel.on_connection_lost() logger.info("The connection is lost") self.setTimeout(None) |