# For compatibility with Python 2.6 we need the unittest2 package, which is
# not available on Python 3.3 or 3.4, hence the fallback to unittest below.
import warnings
from binascii import hexlify
try:
import unittest2 as unittest
except ImportError:
import unittest
from six import b
import hypothesis.strategies as st
from hypothesis import given, example
import pytest
from ._compat import str_idx_as_int
from .curves import NIST256p, NIST224p
from .der import remove_integer, UnexpectedDER, read_length, encode_bitstring,\
remove_bitstring, remove_object, encode_oid
class TestRemoveInteger(unittest.TestCase):
    """Tests for ``remove_integer``, which returns ``(value, remainder)``.

    DER permits a leading 0x00 octet on an INTEGER only when it is needed to
    stop the value from being read as negative; every other padding, as well
    as negative values and truncated/empty encodings, must be rejected with
    ``UnexpectedDER``.
    """

    # --- malformed encodings -------------------------------------------

    def test_non_minimal_encoding(self):
        # 0x00 before 0x01 is superfluous padding.
        self.assertRaises(UnexpectedDER, remove_integer, b'\x02\x02\x00\x01')

    def test_negative_with_high_bit_set(self):
        self.assertRaises(UnexpectedDER, remove_integer, b'\x02\x01\x80')

    def test_two_zero_bytes_with_high_bit_set(self):
        # One 0x00 would be enough to keep 0xff positive; two is not minimal.
        self.assertRaises(UnexpectedDER, remove_integer,
                          b'\x02\x03\x00\x00\xff')

    def test_zero_length_integer(self):
        self.assertRaises(UnexpectedDER, remove_integer, b'\x02\x00')

    def test_empty_string(self):
        self.assertRaises(UnexpectedDER, remove_integer, b'')

    # --- well-formed encodings -----------------------------------------

    def test_minimal_with_high_bit_set(self):
        value, remainder = remove_integer(b'\x02\x02\x00\x80')
        self.assertEqual(0x80, value)
        self.assertFalse(remainder)

    def test_encoding_of_zero(self):
        value, remainder = remove_integer(b'\x02\x01\x00')
        self.assertEqual(0, value)
        self.assertFalse(remainder)

    def test_encoding_of_127(self):
        value, remainder = remove_integer(b'\x02\x01\x7f')
        self.assertEqual(127, value)
        self.assertFalse(remainder)

    def test_encoding_of_128(self):
        value, remainder = remove_integer(b'\x02\x02\x00\x80')
        self.assertEqual(128, value)
        self.assertFalse(remainder)
class TestReadLength(unittest.TestCase):
    """Tests for ``read_length``, which returns ``(length, octets_used)``.

    DER mandates the single-octet short form for lengths 0..127 and, above
    that, a long form using the minimal number of length octets. Any other
    encoding, including truncated input, must raise ``UnexpectedDER``.
    """

    def test_zero_length(self):
        self.assertEqual(read_length(b'\x00'), (0, 1))

    def test_two_byte_zero_length(self):
        # Zero fits the short form, so the long form is not allowed.
        self.assertRaises(UnexpectedDER, read_length, b'\x81\x00')

    def test_two_byte_small_length(self):
        # 127 still fits the short form.
        self.assertRaises(UnexpectedDER, read_length, b'\x81\x7f')

    def test_long_form_with_zero_length(self):
        # Long form announcing zero length octets.
        self.assertRaises(UnexpectedDER, read_length, b'\x80')

    def test_smallest_two_byte_length(self):
        self.assertEqual(read_length(b'\x81\x80'), (128, 2))

    def test_zero_padded_length(self):
        # 0x80 needs only one length octet; the leading 0x00 is padding.
        self.assertRaises(UnexpectedDER, read_length, b'\x82\x00\x80')

    def test_two_three_byte_length(self):
        self.assertEqual(read_length(b'\x82\x01\x00'), (256, 3))

    def test_empty_string(self):
        self.assertRaises(UnexpectedDER, read_length, b'')

    def test_length_overflow(self):
        # Announces three length octets but only two follow.
        self.assertRaises(UnexpectedDER, read_length, b'\x83\x01\x00')
class TestEncodeBitstring(unittest.TestCase):
    """Tests for ``encode_bitstring``.

    A DER BIT STRING stores, as its first content octet, the number of
    unused (padding) bits in the last octet. That count must be between
    0 and 7, must be 0 for an empty string, and the padding bits themselves
    must be zero; violations raise ``ValueError``.
    """

    def test_old_call_convention(self):
        """This is the old way to use the function."""
        with pytest.warns(DeprecationWarning) as warns:
            # Set the filter inside the recorder's context so it is undone
            # on exit instead of leaking into the rest of the test run.
            warnings.simplefilter('always')
            der = encode_bitstring(b'\x00\xff')

        self.assertEqual(len(warns), 1)
        self.assertIn("unused= needs to be specified",
                      warns[0].message.args[0])
        self.assertEqual(der, b'\x03\x02\x00\xff')

    def test_new_call_convention(self):
        """This is how it should be called now."""
        # pytest.warns(None) was deprecated in pytest 7 and removed in
        # pytest 8; record every warning with the stdlib instead.
        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter('always')
            der = encode_bitstring(b'\xff', 0)

        # verify that the new call convention doesn't raise warnings
        self.assertEqual(len(warns), 0)
        self.assertEqual(der, b'\x03\x02\x00\xff')

    def test_implicit_unused_bits(self):
        """
        Writing a bit string that already includes the number of unused bits.
        """
        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter('always')
            # unused=None: the first octet of the input is the unused count
            der = encode_bitstring(b'\x00\xff', None)

        # verify that the new call convention doesn't raise warnings
        self.assertEqual(len(warns), 0)
        self.assertEqual(der, b'\x03\x02\x00\xff')

    def test_explicit_unused_bits(self):
        der = encode_bitstring(b'\xff\xf0', 4)
        self.assertEqual(der, b'\x03\x03\x04\xff\xf0')

    def test_empty_string(self):
        self.assertEqual(encode_bitstring(b'', 0), b'\x03\x01\x00')

    def test_invalid_unused_count(self):
        with self.assertRaises(ValueError):
            encode_bitstring(b'\xff\x00', 8)

    def test_invalid_unused_with_empty_string(self):
        with self.assertRaises(ValueError):
            encode_bitstring(b'', 1)

    def test_non_zero_padding_bits(self):
        # the two low bits of 0xff are set, so they cannot be padding
        with self.assertRaises(ValueError):
            encode_bitstring(b'\xff', 2)
class TestRemoveBitstring(unittest.TestCase):
    """Tests for ``remove_bitstring``.

    With an explicit ``expect_unused`` count the function returns the bare
    bit-string payload. With ``expect_unused=None`` it returns a
    ``(payload, unused_bits)`` pair. Malformed input raises
    ``UnexpectedDER``.
    """

    def test_old_call_convention(self):
        """This is the old way to call the function."""
        with pytest.warns(DeprecationWarning) as warns:
            # Set the filter inside the recorder's context so it is undone
            # on exit instead of leaking into the rest of the test run.
            warnings.simplefilter('always')
            bits, rest = remove_bitstring(b'\x03\x02\x00\xff')

        self.assertEqual(len(warns), 1)
        self.assertIn("expect_unused= needs to be specified",
                      warns[0].message.args[0])
        self.assertEqual(bits, b'\x00\xff')
        self.assertEqual(rest, b'')

    def test_new_call_convention(self):
        # pytest.warns(None) was deprecated in pytest 7 and removed in
        # pytest 8; record every warning with the stdlib instead.
        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter('always')
            bits, rest = remove_bitstring(b'\x03\x02\x00\xff', 0)

        self.assertEqual(len(warns), 0)
        self.assertEqual(bits, b'\xff')
        self.assertEqual(rest, b'')

    def test_implicit_unexpected_unused(self):
        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter('always')
            bits, rest = remove_bitstring(b'\x03\x02\x00\xff', None)

        self.assertEqual(len(warns), 0)
        self.assertEqual(bits, (b'\xff', 0))
        self.assertEqual(rest, b'')

    def test_with_padding(self):
        ret, rest = remove_bitstring(b'\x03\x02\x04\xf0', None)
        self.assertEqual(ret, (b'\xf0', 4))
        self.assertEqual(rest, b'')

    def test_not_a_bitstring(self):
        # tag 0x02 is INTEGER, not BIT STRING
        with self.assertRaises(UnexpectedDER):
            remove_bitstring(b'\x02\x02\x00\xff', None)

    def test_empty_encoding(self):
        with self.assertRaises(UnexpectedDER):
            remove_bitstring(b'\x03\x00', None)

    def test_empty_string(self):
        with self.assertRaises(UnexpectedDER):
            remove_bitstring(b'', None)

    def test_no_length(self):
        with self.assertRaises(UnexpectedDER):
            remove_bitstring(b'\x03', None)

    def test_unexpected_number_of_unused_bits(self):
        with self.assertRaises(UnexpectedDER):
            remove_bitstring(b'\x03\x02\x00\xff', 1)

    def test_invalid_encoding_of_unused_bits(self):
        # unused-bit count must be at most 7
        with self.assertRaises(UnexpectedDER):
            remove_bitstring(b'\x03\x03\x08\xff\x00', None)

    def test_invalid_encoding_of_empty_string(self):
        # an empty payload cannot have unused bits
        with self.assertRaises(UnexpectedDER):
            remove_bitstring(b'\x03\x01\x01', None)

    def test_invalid_padding_bits(self):
        # the declared padding bit of 0xff is not zero
        with self.assertRaises(UnexpectedDER):
            remove_bitstring(b'\x03\x02\x01\xff', None)
class TestStrIdxAsInt(unittest.TestCase):
    """``str_idx_as_int`` must return the integer code of the indexed element
    for ``str``, ``bytes`` and ``bytearray`` alike (115 == ord('s'))."""

    def test_str(self):
        result = str_idx_as_int('str', 0)
        self.assertEqual(result, 115)

    def test_bytes(self):
        result = str_idx_as_int(b'str', 0)
        self.assertEqual(result, 115)

    def test_bytearray(self):
        result = str_idx_as_int(bytearray(b'str'), 0)
        self.assertEqual(result, 115)
class TestEncodeOid(unittest.TestCase):
    """Tests for ``encode_oid``, which returns a complete DER OBJECT
    IDENTIFIER (tag 0x06, length, contents).

    Invalid arc combinations (first arc above 2, or a second arc above 39
    under arcs 0 and 1) are rejected with ``AssertionError``.
    """

    def test_pub_key_oid(self):
        encoded = encode_oid(1, 2, 840, 10045, 2, 1)
        self.assertEqual(b"06072a8648ce3d0201", hexlify(encoded))

    def test_nist224p_oid(self):
        self.assertEqual(b"06052b81040021", hexlify(NIST224p.encoded_oid))

    def test_nist256p_oid(self):
        self.assertEqual(b"06082a8648ce3d030107",
                         hexlify(NIST256p.encoded_oid))

    def test_large_second_subid(self):
        # example taken from X.690, section 8.19.5
        self.assertEqual(b'\x06\x03\x88\x37\x03', encode_oid(2, 999, 3))

    def test_with_two_subids(self):
        self.assertEqual(b'\x06\x02\x88\x37', encode_oid(2, 999))

    def test_zero_zero(self):
        self.assertEqual(b'\x06\x01\x00', encode_oid(0, 0))

    def test_with_wrong_types(self):
        self.assertRaises((TypeError, AssertionError), encode_oid, 0, None)

    def test_with_small_first_large_second(self):
        self.assertRaises(AssertionError, encode_oid, 1, 40)

    def test_small_first_max_second(self):
        self.assertEqual(b'\x06\x01\x4f', encode_oid(1, 39))

    def test_with_invalid_first(self):
        self.assertRaises(AssertionError, encode_oid, 3, 39)
class TestRemoveObject(unittest.TestCase):
    """Tests for ``remove_object``, which parses a DER OBJECT IDENTIFIER and
    returns ``(oid_tuple, remaining_bytes)``.

    Non-minimal sub-identifier encodings, truncated data, a wrong tag and an
    over-long length all raise ``UnexpectedDER``.
    """

    # arcs of id-ecPublicKey, encoded once in setUpClass
    _EC_PUBLIC_KEY_ARCS = (1, 2, 840, 10045, 2, 1)

    @classmethod
    def setUpClass(cls):
        cls.oid_ecPublicKey = encode_oid(*cls._EC_PUBLIC_KEY_ARCS)

    def _assert_decodes(self, data, expected_oid, expected_rest=b''):
        # decode ``data`` and compare both the remainder and the OID
        decoded, trailing = remove_object(data)
        self.assertEqual(expected_rest, trailing)
        self.assertEqual(expected_oid, decoded)

    def test_pub_key_oid(self):
        self._assert_decodes(self.oid_ecPublicKey, self._EC_PUBLIC_KEY_ARCS)

    def test_with_extra_bytes(self):
        self._assert_decodes(self.oid_ecPublicKey + b'more',
                             self._EC_PUBLIC_KEY_ARCS, b'more')

    def test_with_large_second_subid(self):
        # example taken from X.690, section 8.19.5
        self._assert_decodes(b'\x06\x03\x88\x37\x03', (2, 999, 3))

    def test_with_padded_first_subid(self):
        self.assertRaises(UnexpectedDER, remove_object, b'\x06\x02\x80\x00')

    def test_with_padded_second_subid(self):
        self.assertRaises(UnexpectedDER, remove_object,
                          b'\x06\x04\x88\x37\x80\x01')

    def test_with_missing_last_byte_of_multi_byte(self):
        # 0x83 has the continuation bit set but nothing follows
        self.assertRaises(UnexpectedDER, remove_object,
                          b'\x06\x03\x88\x37\x83')

    def test_with_two_subids(self):
        self._assert_decodes(b'\x06\x02\x88\x37', (2, 999))

    def test_zero_zero(self):
        self._assert_decodes(b'\x06\x01\x00', (0, 0))

    def test_empty_string(self):
        self.assertRaises(UnexpectedDER, remove_object, b'')

    def test_missing_length(self):
        self.assertRaises(UnexpectedDER, remove_object, b'\x06')

    def test_empty_oid(self):
        self.assertRaises(UnexpectedDER, remove_object, b'\x06\x00')

    def test_empty_oid_overflow(self):
        self.assertRaises(UnexpectedDER, remove_object, b'\x06\x01')

    def test_with_wrong_type(self):
        # tag 0x04 is OCTET STRING, not OBJECT IDENTIFIER
        self.assertRaises(UnexpectedDER, remove_object, b'\x04\x02\x88\x37')

    def test_with_too_long_length(self):
        self.assertRaises(UnexpectedDER, remove_object, b'\x06\x03\x88\x37')
@st.composite
def st_oid(draw, max_value=2**512, max_size=50):
    """
    Hypothesis strategy producing valid OBJECT IDENTIFIERs as int tuples.

    The first arc is 0, 1 or 2. Under arcs 0 and 1 the second arc is limited
    to 0..39; under arc 2 it may go up to ``max_value``.

    :param max_value: upper bound (inclusive) for every sub-identifier after
        the first, subject to the 0..39 cap described above
    :param max_size: maximum number of sub-identifiers generated after the
        first two (so the OID has at most ``max_size + 2`` elements)
    """
    first = draw(st.integers(min_value=0, max_value=2))
    second_limit = 39 if first < 2 else max_value
    second = draw(st.integers(min_value=0, max_value=second_limit))
    subid = st.integers(min_value=0, max_value=max_value)
    tail = draw(st.lists(subid, max_size=max_size))
    return (first, second) + tuple(tail)
@given(st_oid())
@example((0, 0))
@example((1, 39))
@example((2, 999))
@example((2, 999, 3))
@example((1, 2, 840, 10045, 2, 1))
@example((1, 2, 127, 128, 16383, 16384))
@example((2, 2**512, 2**512))
def test_oids(ids):
    """Encoding an OID and decoding it again must yield the same tuple and
    leave no trailing bytes.

    The explicit examples pin cases that random generation may miss:
    - the smallest OID;
    - the maximal second arc under arc 1;
    - a large second arc under arc 2 (X.690, 8.19.5);
    - a real-world OID;
    - the one-, two- and three-octet sub-identifier boundaries;
    - the strategy's ``max_value`` extreme.
    """
    encoded_oid = encode_oid(*ids)
    decoded_oid, rest = remove_object(encoded_oid)
    assert rest == b''
    assert decoded_oid == ids