From 33e459095ae82d2a6123012b82173fd77e7d6ba8 Mon Sep 17 00:00:00 2001 From: Determinant Date: Thu, 22 Aug 2024 22:16:25 -0700 Subject: clean up the shamir code --- shamir.py | 63 +++++++++++++++++++++++---------------------------------------- 1 file changed, 23 insertions(+), 40 deletions(-) (limited to 'shamir.py') diff --git a/shamir.py b/shamir.py index c7d606a..8db66fc 100644 --- a/shamir.py +++ b/shamir.py @@ -1,51 +1,33 @@ -# Code modified from https://github.com/kurtbrose/shamir/tree/master +# Code modified from https://github.com/kurtbrose/shamir/tree/master (under CC0 +# license, so we took liberty to refactor it here) -''' -An implementation of shamir secret sharing algorithm. +import secrets -To the extent possible under law, -all copyright and related or neighboring rights -are hereby waived. (CC0, see LICENSE file.) - -All possible patents arising from this code -are renounced under the terms of -the Open Web Foundation CLA 1.0 -(http://www.openwebfoundation.org/legal/the-owf-1-0-agreements/owfa-1-0) -''' - -from __future__ import division -import random -import functools - -# 12th Mersenne Prime is 2**127 - 1 # use the prime to be compatible with the Shamir tool developed by Ava Labs: # https://github.com/ava-labs/mnemonic-shamir-secret-sharing-cli/tree/main # This is a 257-bit prime, so the returned points could be either 256-bit or # (infrequently) 257-bit. _PRIME = 187110422339161656731757292403725394067928975545356095774785896842956550853219 +# Other good choices: +# 12th Mersenne Prime is 2**127 - 1 # 13th Mersenne Prime is 2**521 - 1 -def _eval_at(poly, x, prime): - 'evaluate polynomial (coefficient tuple) at x' - accum = 0 - for coeff in reversed(poly): - accum *= x - accum += coeff - accum %= prime - return accum - def split(secret, minimum, shares, prime=_PRIME): - ''' - Generates a random shamir pool, returns - the secret and the share points. - ''' - randint = functools.partial(random.SystemRandom().randint, 0) + def y(poly, x, prime): + # evaluate polynomial (coefficient tuple) at x + accum = 0 + for coeff in reversed(poly): + accum *= x + accum += coeff + accum %= prime + return accum + if minimum > shares: raise ValueError("pool secret would be irrecoverable") - poly = [randint(prime) for i in range(minimum - 1)] + poly = [secrets.randbelow(prime) for i in range(minimum - 1)] poly.insert(0, secret) - return [_eval_at(poly, i, prime) for i in range(1, shares + 1)] + return [y(poly, i, prime) for i in range(1, shares + 1)] # division in integers modulus p means finding the inverse of the denominator @@ -60,7 +42,7 @@ def _extended_gcd(a, b): last_y = 0 while b != 0: quot = a // b - a, b = b, a%b + a, b = b, a % b x, last_x = last_x - quot * x, x y, last_y = last_y - quot * y, y return last_x, last_y @@ -84,7 +66,8 @@ def _lagrange_interpolate(x, x_s, y_s, p): ''' k = len(x_s) assert k == len(set(x_s)), "points must be distinct" - def PI(vals): # upper-case PI -- product of inputs + + def prod(vals): # product of inputs accum = 1 for v in vals: accum *= v @@ -94,9 +77,9 @@ def _lagrange_interpolate(x, x_s, y_s, p): for i in range(k): others = list(x_s) cur = others.pop(i) - nums.append(PI(x - o for o in others)) - dens.append(PI(cur - o for o in others)) - den = PI(dens) + nums.append(prod(x - o for o in others)) + dens.append(prod(cur - o for o in others)) + den = prod(dens) num = sum([_divmod(nums[i] * den * y_s[i] % p, dens[i], p) for i in range(k)]) return (_divmod(num, den, p) + p) % p @@ -105,7 +88,7 @@ def _lagrange_interpolate(x, x_s, y_s, p): def combine(shares, prime=_PRIME): ''' Recover the secret from share points - (x,y points on the polynomial) + (shares contain (x, y) as points on the polynomial) ''' if len(shares) < 2: raise ValueError("need at least two shares") -- cgit v1.2.3-70-g09d2