# Code modified from https://github.com/kurtbrose/shamir/tree/master (under CC0
# license, so we took liberty to refactor it here)
import secrets
# Use the same prime as the Shamir tool developed by Ava Labs, so that shares
# are compatible with it:
# https://github.com/ava-labs/mnemonic-shamir-secret-sharing-cli/tree/main
# This is a 257-bit prime (it lies between 2**256 and 2**257), so values
# reduced modulo it fall in [0, _PRIME) and may need up to 257 bits to encode.
# Roughly 38% of that range is at or above 2**256, so 257-bit values are not
# rare; size any fixed-width encoding for 257 bits.
_PRIME = 187110422339161656731757292403725394067928975545356095774785896842956550853219
# Other good choices:
# 12th Mersenne Prime is 2**127 - 1
# 13th Mersenne Prime is 2**521 - 1
def split(secret, minimum, shares, prime=_PRIME):
    '''
    Split an integer secret into Shamir shares.

    A polynomial of degree (minimum - 1) is built with `secret` as its
    constant term and the remaining coefficients drawn with
    secrets.randbelow(prime). It is evaluated modulo `prime` at
    x = 1, 2, ..., shares.

    Returns a list of the y-values only: the entry at list index i is the
    share for x = i + 1, so pair each value with its x before recombining.

    Raises ValueError if:
      - minimum > shares (the secret could never be reconstructed),
      - minimum < 1,
      - shares >= prime (an x-coordinate would wrap to 0 modulo prime and
        the corresponding share would be the secret itself),
      - secret is outside [0, prime) (it would be silently reduced modulo
        prime and the original value could not be recovered).
    '''
    if minimum > shares:
        raise ValueError("pool secret would be irrecoverable")
    if minimum < 1:
        raise ValueError("minimum must be at least 1")
    if shares >= prime:
        raise ValueError("too many shares for the chosen prime")
    if not 0 <= secret < prime:
        raise ValueError("secret must be in the range [0, prime)")

    # coefficients[0] is the secret; the rest are random field elements
    coefficients = [secret]
    coefficients.extend(secrets.randbelow(prime) for _ in range(minimum - 1))

    def evaluate(x):
        # Horner's scheme, reducing modulo prime at every step
        accum = 0
        for coeff in reversed(coefficients):
            accum = (accum * x + coeff) % prime
        return accum

    return [evaluate(x) for x in range(1, shares + 1)]
# Division in the integers modulo p means finding the inverse of the
# denominator modulo p and then multiplying the numerator by this inverse.
# (Note: the inverse of A is B such that A*B % p == 1.)
# This can be computed via the extended Euclidean algorithm:
# http://en.wikipedia.org/wiki/Modular_multiplicative_inverse#Computation
def _extended_gcd(a, b):
x = 0
last_x = 1
y = 1
last_y = 0
while b != 0:
quot = a // b
a, b = b, a % b
x, last_x = last_x - quot * x, x
y, last_y = last_y - quot * y, y
return last_x, last_y
def _divmod(num, den, p):
    '''
    Compute num / den modulo prime p.

    The return value r lies in [0, p) and satisfies:
        den * r % p == num % p

    Raises ZeroDivisionError if den is a multiple of p, since such a
    denominator has no inverse modulo p.
    '''
    den %= p
    if den == 0:
        raise ZeroDivisionError("division by zero modulo p")
    # the first coefficient returned by _extended_gcd(den, p) multiplies den;
    # with gcd(den, p) == 1 it is the inverse of den modulo p
    inv, _ = _extended_gcd(den, p)
    # reduce so callers always get the canonical representative
    return num * inv % p
def _lagrange_interpolate(x, x_s, y_s, p):
    '''
    Evaluate, at the given x, the polynomial passing through the points
    (x_s[i], y_s[i]), with the result reduced modulo p.

    The x-coordinates in x_s must be pairwise distinct.
    '''
    count = len(x_s)
    assert count == len(set(x_s)), "points must be distinct"

    def _product(factors):
        # multiply all factors together (1 for an empty sequence)
        result = 1
        for factor in factors:
            result *= factor
        return result

    # Numerators and denominators of the Lagrange basis terms are kept
    # apart so the only division performed is the modular one in _divmod.
    points = list(x_s)
    numerators = []
    denominators = []
    for idx in range(count):
        x_i = points[idx]
        rest = points[:idx] + points[idx + 1:]
        numerators.append(_product(x - x_j for x_j in rest))
        denominators.append(_product(x_i - x_j for x_j in rest))

    common_den = _product(denominators)
    total = 0
    for idx in range(count):
        scaled = numerators[idx] * common_den * y_s[idx] % p
        total += _divmod(scaled, denominators[idx], p)
    return (_divmod(total, common_den, p) + p) % p
def combine(shares, prime=_PRIME):
    '''
    Reconstruct the secret from a collection of share points.

    Each element of `shares` is an (x, y) pair lying on the sharing
    polynomial; the secret is that polynomial's value at x = 0, computed
    modulo `prime`.

    Raises ValueError when fewer than two shares are supplied.
    '''
    share_count = len(shares)
    if share_count < 2:
        raise ValueError("need at least two shares")
    # transpose [(x1, y1), (x2, y2), ...] into (x1, x2, ...) and (y1, y2, ...)
    x_coords, y_coords = zip(*shares)
    secret = _lagrange_interpolate(0, x_coords, y_coords, prime)
    return secret