summaryrefslogtreecommitdiff
path: root/server/piztor/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'server/piztor/server.py')
-rw-r--r--server/piztor/server.py191
1 files changed, 181 insertions, 10 deletions
diff --git a/server/piztor/server.py b/server/piztor/server.py
index 3f1f2cb..70cc13a 100644
--- a/server/piztor/server.py
+++ b/server/piztor/server.py
@@ -7,6 +7,8 @@ from sqlalchemy import create_engine, and_
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound
+from collections import deque
+
import struct
import os
import logging
@@ -45,18 +47,76 @@ _MAX_AUTH_HEAD_SIZE = _SectionSize.USER_TOKEN + \
_HEADER_SIZE = _SectionSize.LENGTH + \
_SectionSize.OPT_ID
+_MAX_TEXT_MESG_SIZE = 1024
+
class _OptCode:
user_auth = 0x00
location_update = 0x01
location_info= 0x02
user_info = 0x03
user_logout = 0x04
+ open_push_tunnel = 0x05
+ send_text_mesg = 0x06
class _StatusCode:
sucess = 0x00
failure = 0x01
+class PushData(object):
+ from hashlib import sha256
+ def __init__(self, data):
+ self.data = data
+ self.finger_print = sha256(data).digest()
+
+class PushTextMesgData(PushData):
+ def __init__(self, mesg):
+ self.finger_print = sha256(mesg).digest()
+ buff = struct.pack("!B32s", 0x00, self.finger_print)
+ buff += mesg
+ buff += chr(0)
+ buff = struct.pack("!L", _SectionSize.LENGTH + len(buff)) + buff
+ self.data = buff
+
+
+class PushTunnel(object):
+ def __init__(self):
+ self.pending = deque()
+ self.conn = None
+
+ def __del__(self):
+ if self.conn:
+ self.conn.loseConnection()
+
+ def add(self, pdata):
+ logger.info("-- Push data enqued --")
+ logger.info("Data: %s", get_hex(pdata.data))
+ self.pending.append(pdata)
+
+ def on_receive(self, data):
+ front = self.pending.popleft()
+ length, optcode, fingerprint = struct.unpack("!LB32s", data)
+ if front.finger_print != fingerprint:
+ raise PiztorError
+ logger.info("-- Push data confirmed --")
+ self.push()
+
+ def push(self):
+ if (self.conn is None) or len(self.pending) == 0:
+ return
+ print "Pushing"
+ front = self.pending.popleft()
+ self.pending.appendleft(front)
+ self.conn.transport.write(front.data)
+
+ def connect(self, conn):
+ conn.tunnel = self
+ self.conn = conn
+
+ def on_connection_lost(self):
+ self.conn = None
+
class RequestHandler(object):
+ push_tunnels = dict()
def __init__(self):
Session = sessionmaker(bind = engine)
self.session = Session()
@@ -78,7 +138,9 @@ class RequestHandler(object):
if uauth.user.username != username:
logger.warning("Toke and username mismatch")
return None
-
+ uid = uauth.uid
+ if not cls.push_tunnels.has_key(uid):
+ cls.push_tunnels[uid] = PushTunnel()
return uauth
except NoResultFound:
@@ -122,7 +184,7 @@ class UserAuthHandler(RequestHandler):
bytes('\x00' * 32))
- def handle(self, tr_data):
+ def handle(self, tr_data, conn):
self.check_size(tr_data)
logger.info("Reading auth data...")
pos = -1
@@ -178,7 +240,7 @@ class LocationUpdateHandler(RequestHandler):
_SectionSize.OPT_ID + \
_SectionSize.STATUS
- def handle(self, tr_data):
+ def handle(self, tr_data, conn):
self.check_size(tr_data)
logger.info("Reading location update data...")
try:
@@ -224,7 +286,7 @@ class LocationInfoHandler(RequestHandler):
_SectionSize.STATUS + \
_SectionSize.LOCATION_ENTRY * item_num
- def handle(self, tr_data):
+ def handle(self, tr_data, conn):
self.check_size(tr_data)
logger.info("Reading location request data..")
try:
@@ -298,7 +360,7 @@ class UserInfoHandler(RequestHandler):
info_key = entry_code
return struct.pack("!B", info_key) + pack_method(user)
- def handle(self, tr_data):
+ def handle(self, tr_data, conn):
self.check_size(tr_data)
logger.info("Reading user info request data...")
try:
@@ -348,7 +410,7 @@ class UserLogoutHandler(RequestHandler):
_SectionSize.OPT_ID + \
_SectionSize.STATUS
- def handle(self, tr_data):
+ def handle(self, tr_data, conn):
self.check_size(tr_data)
logger.info("Reading user logout data...")
try:
@@ -370,13 +432,105 @@ class UserLogoutHandler(RequestHandler):
return struct.pack("!LBB", self._response_size,
_OptCode.location_update,
_StatusCode.failure)
+ del RequestHandler.push_tunnels[uauth.uid]
uauth.regen_token()
logger.info("User Logged out successfully!")
self.session.commit()
return struct.pack("!LBB", self._response_size,
_OptCode.user_logout,
_StatusCode.sucess)
-
+
+class OpenPushTunnelHandler(RequestHandler):
+
+ _max_tr_data_size = _MAX_AUTH_HEAD_SIZE
+
+ _response_size = \
+ _SectionSize.LENGTH + \
+ _SectionSize.OPT_ID + \
+ _SectionSize.STATUS
+
+ def handle(self, tr_data, conn):
+ self.check_size(tr_data)
+ logger.info("Reading open push tunnel data...")
+ try:
+ token, = struct.unpack("!32s", tr_data[:32])
+ username, tail = RequestHandler.trunc_padding(tr_data[32:])
+ if username is None:
+ raise struct.error
+ except struct.error:
+ raise BadReqError("Open push tunnel: Malformed request body")
+
+ logger.info("Trying to open push tunnel with "
+ "(token = {0}, username = {1})"\
+ .format(get_hex(token), username))
+
+ uauth = RequestHandler.get_uauth(token, username, self.session)
+ # Authentication failure
+ if uauth is None:
+ logger.warning("Authentication failure")
+ return struct.pack("!LBB", self._response_size,
+ _OptCode.location_update,
+ _StatusCode.failure)
+
+ tunnel = RequestHandler.push_tunnels[uauth.uid]
+ pt = RequestHandler.push_tunnels
+ uid = uauth.uid
+ if pt.has_key(uid):
+ tunnel = pt[uid]
+ tunnel.connect(conn)
+ tunnel.push()
+
+ logger.info("Push tunnel opened successfully!")
+ return struct.pack("!LBB", self._response_size,
+ _OptCode.user_logout,
+ _StatusCode.sucess)
+
+class SendTextMessageHandler(RequestHandler):
+
+ _max_tr_data_size = _MAX_AUTH_HEAD_SIZE + \
+ _MAX_TEXT_MESG_SIZE + \
+ _SectionSize.PADDING
+
+ _response_size = \
+ _SectionSize.LENGTH + \
+ _SectionSize.OPT_ID + \
+ _SectionSize.STATUS
+
+ def handle(self, tr_data, conn):
+ self.check_size(tr_data)
+ logger.info("Reading send text mesg data...")
+ try:
+ token, = struct.unpack("!32s", tr_data[:32])
+ username, tail = RequestHandler.trunc_padding(tr_data[32:])
+ mesg = tail[:-1]
+ if username is None:
+ raise struct.error
+ except struct.error:
+ raise BadReqError("Send text mesg: Malformed request body")
+
+ logger.info("Trying to send text mesg with "
+ "(token = {0}, username = {1})"\
+ .format(get_hex(token), username))
+
+ uauth = RequestHandler.get_uauth(token, username, self.session)
+ # Authentication failure
+ if uauth is None:
+ logger.warning("Authentication failure")
+ return struct.pack("!LBB", self._response_size,
+ _OptCode.location_update,
+ _StatusCode.failure)
+
+ pt = RequestHandler.push_tunnels
+ uid = uauth.uid
+ if pt.has_key(uid):
+ tunnel = pt[uid]
+ tunnel.add(PushTextMesgData(mesg))
+ tunnel.push()
+ logger.info("Sent text mesg successfully!")
+ return struct.pack("!LBB", self._response_size,
+ _OptCode.user_logout,
+ _StatusCode.sucess)
+
class PTP(Protocol, TimeoutMixin):
@@ -384,7 +538,9 @@ class PTP(Protocol, TimeoutMixin):
LocationUpdateHandler,
LocationInfoHandler,
UserInfoHandler,
- UserLogoutHandler]
+ UserLogoutHandler,
+ OpenPushTunnelHandler,
+ SendTextMessageHandler]
handler_num = len(handlers)
@@ -399,6 +555,7 @@ class PTP(Protocol, TimeoutMixin):
self.buff = bytes()
self.length = -1
self.factory = factory
+ self.tunnel = None
def timeoutConnection(self):
logger.info("The connection times out")
@@ -425,8 +582,18 @@ class PTP(Protocol, TimeoutMixin):
print self.length, PTP._MAX_REQUEST_SIZE
raise BadReqError("The size of remaining part is too big")
if len(self.buff) == self.length:
+ if self.tunnel: # received push response
+ self.tunnel.on_receive(self.buff)
+ self.buff = bytes()
+ self.length = -1
+ return
h = PTP.handlers[self.optcode]()
- reply = h.handle(self.buff[5:])
+ reply = h.handle(self.buff[5:], self)
+ if self.tunnel:
+ logger.info("Blocking the client...")
+ self.buff = bytes()
+ self.length = -1
+ return
logger.info("Wrote: %s", get_hex(reply))
self.transport.write(reply)
self.transport.loseConnection()
@@ -434,12 +601,16 @@ class PTP(Protocol, TimeoutMixin):
raise BadReqError("The actual length is larger than promised")
except BadReqError as e:
logger.warn("Rejected a bad request: %s", str(e))
+ self.transport.loseConnection()
except DBCorruptionError:
logger.error("*** Database corruption ***")
- finally:
+ self.transport.loseConnection()
+ if self.tunnel is None:
self.transport.loseConnection()
def connectionLost(self, reason):
+ if self.tunnel:
+ self.tunnel.on_connection_lost()
logger.info("The connection is lost")
self.setTimeout(None)