diff options
Diffstat (limited to 'shamir.py')
-rw-r--r-- | shamir.py | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/shamir.py b/shamir.py new file mode 100644 index 0000000..cb7312c --- /dev/null +++ b/shamir.py @@ -0,0 +1,80 @@ +# Code modified from https://github.com/kurtbrose/shamir/tree/master (under CC0 +# license, so we took liberty to refactor it here) + +import secrets + +# use the prime to be compatible with the Shamir tool developed by Ava Labs: +# https://github.com/ava-labs/mnemonic-shamir-secret-sharing-cli/tree/main +# This is a 257-bit prime, so the returned points could be either 256-bit or +# (infrequently) 257-bit. +_PRIME = 187110422339161656731757292403725394067928975545356095774785896842956550853219 +# Other good choices: +# 12th Mersenne Prime is 2**127 - 1 +# 13th Mersenne Prime is 2**521 - 1 + + +def split(secret, minimum, shares, prime=_PRIME): + def y(poly, x, prime): + # evaluate polynomial (coefficient tuple) at x + accum = 0 + for coeff in reversed(poly): + accum *= x + accum += coeff + accum %= prime + return accum + + if minimum > shares: + raise ValueError("pool secret would be irrecoverable") + poly = [secrets.randbelow(prime) for i in range(minimum - 1)] + poly.insert(0, secret) + return [y(poly, i, prime) for i in range(1, shares + 1)] + + +def _divmod(num, a, p): + # extended gcd will find ax + py = gcd(a, p) + # so if p is a big prime and a < p, then ax + py = gcd(a, p) = 1, + # then y = 0, so ax = 1, x will be the multiplicative inverse for a modulo p + x = 0 + y = 1 + last_x = 1 + last_y = 0 + while p != 0: + quot = a // p + a, p = p, a % p + x, last_x = last_x - quot * x, x + y, last_y = last_y - quot * y, y + return num * last_x + +def _lagrange_interpolate(x, x_s, y_s, p): + k = len(x_s) + assert k == len(set(x_s)), "points must be distinct" + + def prod(vals): # product of inputs + r = 1 + for v in vals: + r = (r * v) % p + return r + l_s = [] + n_all = prod(x - x_j for x_j in x_s) + for i in range(k): + others = list(x_s) + x_i = others.pop(i) + # \Prod_{j \neq i}{(x - x_j)} / \Prod_{j \neq i}{(x_i - x_j)} + l_s.append(_divmod( + n_all, + (prod(x_i - x_j for x_j in others) * (x - x_i)) % p, p)) + sum = 0 + for (y, l) in zip(y_s, l_s): + sum = (sum + (y * l) % p) % p + return (sum + p) % p + + +def combine(shares, prime=_PRIME): + ''' + Recover the secret from share points + (shares contain (x, y) as points on the polynomial) + ''' + if len(shares) < 2: + raise ValueError("need at least two shares") + x_s, y_s = zip(*shares) + return _lagrange_interpolate(0, x_s, y_s, prime) |