aboutsummaryrefslogtreecommitdiff
path: root/shamir.py
diff options
context:
space:
mode:
Diffstat (limited to 'shamir.py')
-rw-r--r--shamir.py63
1 files changed, 23 insertions, 40 deletions
diff --git a/shamir.py b/shamir.py
index c7d606a..8db66fc 100644
--- a/shamir.py
+++ b/shamir.py
@@ -1,51 +1,33 @@
-# Code modified from https://github.com/kurtbrose/shamir/tree/master
+# Code modified from https://github.com/kurtbrose/shamir/tree/master (under CC0
+# license, so we took liberty to refactor it here)
-'''
-An implementation of shamir secret sharing algorithm.
+import secrets
-To the extent possible under law,
-all copyright and related or neighboring rights
-are hereby waived. (CC0, see LICENSE file.)
-
-All possible patents arising from this code
-are renounced under the terms of
-the Open Web Foundation CLA 1.0
-(http://www.openwebfoundation.org/legal/the-owf-1-0-agreements/owfa-1-0)
-'''
-
-from __future__ import division
-import random
-import functools
-
-# 12th Mersenne Prime is 2**127 - 1
# use the prime to be compatible with the Shamir tool developed by Ava Labs:
# https://github.com/ava-labs/mnemonic-shamir-secret-sharing-cli/tree/main
# This is a 257-bit prime, so the returned points could be either 256-bit or
# (infrequently) 257-bit.
_PRIME = 187110422339161656731757292403725394067928975545356095774785896842956550853219
+# Other good choices:
+# 12th Mersenne Prime is 2**127 - 1
# 13th Mersenne Prime is 2**521 - 1
-def _eval_at(poly, x, prime):
- 'evaluate polynomial (coefficient tuple) at x'
- accum = 0
- for coeff in reversed(poly):
- accum *= x
- accum += coeff
- accum %= prime
- return accum
-
def split(secret, minimum, shares, prime=_PRIME):
- '''
- Generates a random shamir pool, returns
- the secret and the share points.
- '''
- randint = functools.partial(random.SystemRandom().randint, 0)
+ def y(poly, x, prime):
+ # evaluate polynomial (coefficient tuple) at x
+ accum = 0
+ for coeff in reversed(poly):
+ accum *= x
+ accum += coeff
+ accum %= prime
+ return accum
+
if minimum > shares:
raise ValueError("pool secret would be irrecoverable")
- poly = [randint(prime) for i in range(minimum - 1)]
+ poly = [secrets.randbelow(prime) for i in range(minimum - 1)]
poly.insert(0, secret)
- return [_eval_at(poly, i, prime) for i in range(1, shares + 1)]
+ return [y(poly, i, prime) for i in range(1, shares + 1)]
# division in integers modulus p means finding the inverse of the denominator
@@ -60,7 +42,7 @@ def _extended_gcd(a, b):
last_y = 0
while b != 0:
quot = a // b
- a, b = b, a%b
+ a, b = b, a % b
x, last_x = last_x - quot * x, x
y, last_y = last_y - quot * y, y
return last_x, last_y
@@ -84,7 +66,8 @@ def _lagrange_interpolate(x, x_s, y_s, p):
'''
k = len(x_s)
assert k == len(set(x_s)), "points must be distinct"
- def PI(vals): # upper-case PI -- product of inputs
+
+ def prod(vals): # product of inputs
accum = 1
for v in vals:
accum *= v
@@ -94,9 +77,9 @@ def _lagrange_interpolate(x, x_s, y_s, p):
for i in range(k):
others = list(x_s)
cur = others.pop(i)
- nums.append(PI(x - o for o in others))
- dens.append(PI(cur - o for o in others))
- den = PI(dens)
+ nums.append(prod(x - o for o in others))
+ dens.append(prod(cur - o for o in others))
+ den = prod(dens)
num = sum([_divmod(nums[i] * den * y_s[i] % p, dens[i], p)
for i in range(k)])
return (_divmod(num, den, p) + p) % p
@@ -105,7 +88,7 @@ def _lagrange_interpolate(x, x_s, y_s, p):
def combine(shares, prime=_PRIME):
'''
Recover the secret from share points
- (x,y points on the polynomial)
+ (shares contain (x, y) as points on the polynomial)
'''
if len(shares) < 2:
raise ValueError("need at least two shares")