aboutsummaryrefslogblamecommitdiff
path: root/shamir.py
blob: cb7312cc246042e405932f1864492480278fce9c (plain) (tree)
1
2
3
4
5
6
7
8
9
10

                                                                               
 
              
 




                                                                                       

                                   

                                   

                                                 








                                                      

                                                              
                                                                 
                          
                                                            

 



                                                                                
         
         
              
              


                        

                                        
                       
 
                                          

                                                        

                                        
             
                      

                           

                                        

                          








                                                                     




                                        
                                                       




                                                    
# Code modified from https://github.com/kurtbrose/shamir/tree/master (released
# under the CC0 license, so we took the liberty of refactoring it here)

import secrets

# Default field modulus. Use this prime to stay compatible with the Shamir tool
# developed by Ava Labs:
# https://github.com/ava-labs/mnemonic-shamir-secret-sharing-cli/tree/main
# The value lies between 2**256 and 2**257, i.e. it is 257 bits long. Values
# reduced modulo it (such as the points returned by split()) therefore fit in
# 256 bits only when they are below 2**256; uniformly distributed residues need
# the 257th bit roughly 38% of the time, so callers must not assume 32 bytes.
_PRIME = 187110422339161656731757292403725394067928975545356095774785896842956550853219
# Other good choices:
# 12th Mersenne Prime is 2**127 - 1
# 13th Mersenne Prime is 2**521 - 1


def split(secret, minimum, shares, prime=_PRIME):
    '''
    Split an integer secret into Shamir shares

    A polynomial of degree (minimum - 1) is built whose constant term is the
    secret and whose other coefficients are drawn with secrets.randbelow(prime).
    It is evaluated modulo prime at x = 1, 2, ..., shares.

    secret:  integer to protect, must satisfy 0 <= secret < prime
    minimum: number of shares needed to recover the secret (at least 1)
    shares:  number of shares to generate (minimum <= shares < prime)
    prime:   modulus of the field the arithmetic is done in

    Returns a list of `shares` integers in range(prime); the element at index
    i is the polynomial's value at x = i + 1, so the caller has to pair each
    value with its x coordinate to obtain an (x, y) point.

    Raises ValueError when the parameters cannot produce recoverable, safe
    shares.
    '''
    if minimum > shares:
        raise ValueError("pool secret would be irrecoverable")
    if minimum < 1:
        raise ValueError("minimum must be at least 1")
    # A secret outside the field would be reduced modulo prime by the
    # evaluation below and could never be reconstructed.
    if not 0 <= secret < prime:
        raise ValueError("secret must satisfy 0 <= secret < prime")
    # x = prime is congruent to 0, and the polynomial's value at 0 is the
    # secret itself, so that share would reveal it.
    if shares >= prime:
        raise ValueError("number of shares must be smaller than prime")

    # Coefficients in ascending order of power; index 0 is the secret.
    coeffs = [secret]
    coeffs.extend(secrets.randbelow(prime) for _ in range(minimum - 1))

    def evaluate(x):
        # Horner's rule, reducing at every step to keep the numbers small.
        acc = 0
        for coeff in reversed(coeffs):
            acc = (acc * x + coeff) % prime
        return acc

    return [evaluate(x) for x in range(1, shares + 1)]


def _divmod(num, a, p):
    # extended gcd will find ax + py = gcd(a, p)
    # so if p is a big prime and a < p, then ax + py = gcd(a, p) = 1,
    # then y = 0, so ax = 1, x will be the multiplicative inverse for a modulo p
    x = 0
    y = 1
    last_x = 1
    last_y = 0
    while p != 0:
        quot = a // p
        a, p = p,  a % p
        x, last_x = last_x - quot * x, x
        y, last_y = last_y - quot * y, y
    return num * last_x

def _lagrange_interpolate(x, x_s, y_s, p):
    '''
    Evaluate at x, modulo p, the polynomial through the given points

    x_s and y_s are equally long sequences holding the coordinates of the
    points (x_s[i], y_s[i]). The result is

        sum_i  y_i * prod_{j != i} (x - x_j) / prod_{j != i} (x_i - x_j)

    computed in the integers modulo p, with divisions done by _divmod.
    Returns an integer in range(p) (0 when no points are given).

    Raises ValueError when the sequences differ in length or when two x
    coordinates coincide modulo p, since the denominators would then be zero.
    '''
    if len(x_s) != len(y_s):
        raise ValueError("x_s and y_s must have the same length")
    # Explicit raise rather than assert: asserts disappear under `python -O`.
    # Compared modulo p because congruent coordinates also zero a denominator.
    if len({x_j % p for x_j in x_s}) != len(x_s):
        raise ValueError("points must be distinct")

    total = 0
    for i, (x_i, y_i) in enumerate(zip(x_s, y_s)):
        # Build numerator and denominator of the i-th basis polynomial
        # directly from the other nodes. Dividing the full product by
        # (x - x_i) instead would break when x coincides with a node.
        num = 1
        den = 1
        for j, x_j in enumerate(x_s):
            if j != i:
                num = (num * (x - x_j)) % p
                den = (den * (x_i - x_j)) % p
        total = (total + y_i * _divmod(num, den, p)) % p
    return total


def combine(shares, prime=_PRIME):
    '''
    Reconstruct the secret from a collection of shares

    shares is a sized collection of (x, y) points lying on the sharing
    polynomial; the secret is that polynomial's value at x = 0, obtained by
    Lagrange interpolation modulo prime.

    Raises ValueError when fewer than two shares are supplied.
    '''
    if len(shares) <= 1:
        raise ValueError("need at least two shares")
    # Transpose [(x1, y1), (x2, y2), ...] into (x1, x2, ...) and (y1, y2, ...)
    columns = tuple(zip(*shares))
    xs, ys = columns
    return _lagrange_interpolate(0, xs, ys, prime)