aboutsummaryrefslogtreecommitdiff
path: root/shamir.py
diff options
context:
space:
mode:
Diffstat (limited to 'shamir.py')
-rw-r--r--shamir.py80
1 files changed, 80 insertions, 0 deletions
diff --git a/shamir.py b/shamir.py
new file mode 100644
index 0000000..cb7312c
--- /dev/null
+++ b/shamir.py
@@ -0,0 +1,80 @@
+# Code modified from https://github.com/kurtbrose/shamir/tree/master (under CC0
+# license, so we took liberty to refactor it here)
+
+import secrets
+
+# use the prime to be compatible with the Shamir tool developed by Ava Labs:
+# https://github.com/ava-labs/mnemonic-shamir-secret-sharing-cli/tree/main
+# This is a 257-bit prime, so the returned points could be either 256-bit or
+# (infrequently) 257-bit.
+_PRIME = 187110422339161656731757292403725394067928975545356095774785896842956550853219
+# Other good choices:
+# 12th Mersenne Prime is 2**127 - 1
+# 13th Mersenne Prime is 2**521 - 1
+
+
+def split(secret, minimum, shares, prime=_PRIME):
+ def y(poly, x, prime):
+ # evaluate polynomial (coefficient tuple) at x
+ accum = 0
+ for coeff in reversed(poly):
+ accum *= x
+ accum += coeff
+ accum %= prime
+ return accum
+
+ if minimum > shares:
+ raise ValueError("pool secret would be irrecoverable")
+ poly = [secrets.randbelow(prime) for i in range(minimum - 1)]
+ poly.insert(0, secret)
+ return [y(poly, i, prime) for i in range(1, shares + 1)]
+
+
+def _divmod(num, a, p):
+ # extended gcd will find ax + py = gcd(a, p)
+ # so if p is a big prime and a < p, then ax + py = gcd(a, p) = 1,
+ # then y = 0, so ax = 1, x will be the multiplicative inverse for a modulo p
+ x = 0
+ y = 1
+ last_x = 1
+ last_y = 0
+ while p != 0:
+ quot = a // p
+ a, p = p, a % p
+ x, last_x = last_x - quot * x, x
+ y, last_y = last_y - quot * y, y
+ return num * last_x
+
+def _lagrange_interpolate(x, x_s, y_s, p):
+ k = len(x_s)
+ assert k == len(set(x_s)), "points must be distinct"
+
+ def prod(vals): # product of inputs
+ r = 1
+ for v in vals:
+ r = (r * v) % p
+ return r
+ l_s = []
+ n_all = prod(x - x_j for x_j in x_s)
+ for i in range(k):
+ others = list(x_s)
+ x_i = others.pop(i)
+ # \Prod_{j \neq i}{(x - x_j)} / \Prod_{j \neq i}{(x_i - x_j)}
+ l_s.append(_divmod(
+ n_all,
+ (prod(x_i - x_j for x_j in others) * (x - x_i)) % p, p))
+ sum = 0
+ for (y, l) in zip(y_s, l_s):
+ sum = (sum + (y * l) % p) % p
+ return (sum + p) % p
+
+
+def combine(shares, prime=_PRIME):
+ '''
+ Recover the secret from share points
+ (shares contain (x, y) as points on the polynomial)
+ '''
+ if len(shares) < 2:
+ raise ValueError("need at least two shares")
+ x_s, y_s = zip(*shares)
+ return _lagrange_interpolate(0, x_s, y_s, prime)