From 69c966134cdba3c8a38038b9fd28a0c78e88cc73 Mon Sep 17 00:00:00 2001 From: Determinant Date: Thu, 22 Aug 2024 17:39:58 -0700 Subject: support Shamir's secret sharing with minimal code (`shamir.py`) --- keytree.py | 90 ++++++++++++++++++++++++++++++++++++++++++++++-- shamir.py | 113 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 200 insertions(+), 3 deletions(-) create mode 100644 shamir.py diff --git a/keytree.py b/keytree.py index 78cdf4d..71d5998 100755 --- a/keytree.py +++ b/keytree.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +#!/usr/bin/env python3.10 # MIT License # # Copyright (c) 2020 Ted Yin @@ -47,6 +47,7 @@ import hmac import unicodedata import json from getpass import getpass as _getpass +from itertools import zip_longest import bech32 import mnemonic @@ -58,6 +59,7 @@ from sha3 import keccak_256 from uuid import uuid4 from Cryptodome.Cipher import AES from Cryptodome.Util import Counter +import shamir def getpass(prompt): @@ -124,7 +126,7 @@ def is_infinity(P): # parse256(p): interprets a 32-byte sequence as a 256-bit number, most # significant byte first. def parse256(p): - assert(len(p) == 32) + assert len(p) == 32 return int.from_bytes(p, byteorder='big') @@ -346,6 +348,42 @@ def save_to_mew(priv_keys, n=1 << 18, p=1, r=8, dklen=32): raise KeytreeError("failed while saving") +def to_chunks(n, iterable): + return zip_longest(*[iter(iterable)]*n, fillvalue=0) + +def shamir256_split(secret, t, n): + shares = [bytearray() for i in range(n)] + for chunk in to_chunks(32, secret): + secret = int.from_bytes(chunk, 'big') + while True: + points = shamir.split(secret, t, n) + good = True + for p in points: + if p.bit_length() > 256: + good = False + break + if good: + # all shares are within 256 bits + for (p, s) in zip(points, shares): + s.extend(p.to_bytes(32, 'big')) + break + return shares + + +def shamir256_combine(shares): + result = bytearray() + for shares in zip(*[[(i, g) for g in to_chunks(32, s)] for (i, s) in shares]): + shares = [(i, int.from_bytes(bytearray(p), 'big')) for (i, p) in shares] + try: + secret = shamir.combine(shares) + except ValueError: + raise KeytreeError("invalid Shamir recovery input") + if secret.bit_length() > 256: + raise KeytreeError("Shamir result is too long") + result.extend(secret.to_bytes(32, 'big')) + return result + + metamask_path = r"44'/60'/0'/0" avax_path = r"44'/9000'/0'/0" @@ -355,6 +393,10 @@ if __name__ == '__main__': parser.add_argument('--save', type=str, default=None, help='save mnemonic to a file (AVAX Wallet compatible)') parser.add_argument('--export-mew', action='store_true', default=False, help='export keys to MEW keystore files (mnemonic is NOT saved, only keys are saved)') parser.add_argument('--show-private', action='store_true', default=False, help='also show private keys and the mnemonic') + parser.add_argument('--gen-shamir', action='store_true', default=False, help='generate Shamir\'s secret shares') + parser.add_argument('--shamir-threshold', type=int, default=2, help='Shamir\'s secret sharing threshold (number of shares to decode)') + parser.add_argument('--shamir-num', type=int, default=3, help='Shamir\'s secret sharing number (total number of shares)') + parser.add_argument('--recover-shamir', type=str, default=None, help='recover the secret from Shamir shares') parser.add_argument('--custom', action='store_true', default=False, help='use an arbitrary word combination as mnemonic') parser.add_argument('--seed', action='store_true', default=False, help='load mnemonic from seed') parser.add_argument('--path', default=avax_path, help="path prefix for key deriving (e.g. \"{}\" for Metamask)".format(metamask_path)) @@ -371,13 +413,40 @@ if __name__ == '__main__': for arg in unknown: if len(arg) > 0: raise KeytreeError("invalid argument: `{}`".format(arg)) + shares = [] try: + mgen = mnemonic.Mnemonic(args.lang) if args.gen_mnemonic: - mgen = mnemonic.Mnemonic(args.lang) words = mgen.generate(256) else: if args.load: words = load_from_keystore(args.load) + elif args.recover_shamir: + try: + idxes = [int(i) for i in args.recover_shamir.split(',')] + except ValueError: + raise KeytreeError("invalid Shamir share spec, should be something like \"1,2\"") + custom_mnemonic = None + for idx in idxes: + swords = getpass('Enter the mnemonic for Shamir share #{}: '.format(idx)) + if len(swords) == 48: + if not custom_mnemonic: + raise KeytreeError("invalid Shamir share format") + custom_mnemonic = True + share = mgen.to_entropy(swords[:24]) + mgen.to_entropy(swords[24:]) + else: + if custom_mnemonic: + raise KeytreeError("invalid Shamir share format") + custom_mnemonic = False + try: + share = mgen.to_entropy(swords) + except ValueError: + raise KeytreeError('invalid mnemonic') + shares.append((idx, share)) + if custom_mnemonic: + seed = shamir256_combine(shares) + else: + words = mgen.to_mnemonic(shamir256_combine(shares)) elif not args.seed: words = getpass('Enter the mnemonic: ').strip() if not args.custom: @@ -402,6 +471,21 @@ if __name__ == '__main__': if not args.seed: print("KEEP THIS PRIVATE (mnemonic): {}".format(words)) print("KEEP THIS PRIVATE (seed): {}".format(seed.hex())) + if args.shamir_threshold: + if args.shamir_num > 20: + raise KeytreeError('Shamir threshold should be <= 20') + if args.shamir_threshold < 2 or args.shamir_threshold > args.shamir_num: + raise KeytreeError('Shamir threshold should be (2, N]') + if args.gen_shamir: + if args.seed: + shares = shamir256_split(seed, args.shamir_threshold, args.shamir_num) + for idx, share in enumerate(shares): + words = mgen.to_mnemonic(share[:32]) + mgen.to_mnemonic(share[32:]) + print("KEEP THIS PRIVATE: secret_{}: {}".format(idx + 1, words)) + else: + shares = shamir256_split(mgen.to_entropy(words), args.shamir_threshold, args.shamir_num) + for idx, share in enumerate(shares): + print("KEEP THIS PRIVATE: secret_{}: {}".format(idx + 1, mgen.to_mnemonic(share))) gen = BIP32(seed) if args.start_idx < 0 or args.end_idx < 0: raise KeytreeError("invalid start/end index") diff --git a/shamir.py b/shamir.py new file mode 100644 index 0000000..c7d606a --- /dev/null +++ b/shamir.py @@ -0,0 +1,113 @@ +# Code modified from https://github.com/kurtbrose/shamir/tree/master + +''' +An implementation of shamir secret sharing algorithm. + +To the extent possible under law, +all copyright and related or neighboring rights +are hereby waived. (CC0, see LICENSE file.) + +All possible patents arising from this code +are renounced under the terms of +the Open Web Foundation CLA 1.0 +(http://www.openwebfoundation.org/legal/the-owf-1-0-agreements/owfa-1-0) +''' + +from __future__ import division +import random +import functools + +# 12th Mersenne Prime is 2**127 - 1 +# use the prime to be compatible with the Shamir tool developed by Ava Labs: +# https://github.com/ava-labs/mnemonic-shamir-secret-sharing-cli/tree/main +# This is a 257-bit prime, so the returned points could be either 256-bit or +# (infrequently) 257-bit. +_PRIME = 187110422339161656731757292403725394067928975545356095774785896842956550853219 +# 13th Mersenne Prime is 2**521 - 1 + +def _eval_at(poly, x, prime): + 'evaluate polynomial (coefficient tuple) at x' + accum = 0 + for coeff in reversed(poly): + accum *= x + accum += coeff + accum %= prime + return accum + + +def split(secret, minimum, shares, prime=_PRIME): + ''' + Generates a random shamir pool, returns + the secret and the share points. + ''' + randint = functools.partial(random.SystemRandom().randint, 0) + if minimum > shares: + raise ValueError("pool secret would be irrecoverable") + poly = [randint(prime) for i in range(minimum - 1)] + poly.insert(0, secret) + return [_eval_at(poly, i, prime) for i in range(1, shares + 1)] + + +# division in integers modulus p means finding the inverse of the denominator +# modulo p and then multiplying the numerator by this inverse +# (Note: inverse of A is B such that A*B % p == 1) +# this can be computed via extended euclidean algorithm +# http://en.wikipedia.org/wiki/Modular_multiplicative_inverse#Computation +def _extended_gcd(a, b): + x = 0 + last_x = 1 + y = 1 + last_y = 0 + while b != 0: + quot = a // b + a, b = b, a%b + x, last_x = last_x - quot * x, x + y, last_y = last_y - quot * y, y + return last_x, last_y + + +def _divmod(num, den, p): + ''' + compute num / den modulo prime p + To explain what this means, the return + value will be such that the following is true: + den * _divmod(num, den, p) % p == num + ''' + inv, _ = _extended_gcd(den, p) + return num * inv + + +def _lagrange_interpolate(x, x_s, y_s, p): + ''' + Find the y-value for the given x, given n (x, y) points; + k points will define a polynomial of up to kth order + ''' + k = len(x_s) + assert k == len(set(x_s)), "points must be distinct" + def PI(vals): # upper-case PI -- product of inputs + accum = 1 + for v in vals: + accum *= v + return accum + nums = [] # avoid inexact division + dens = [] + for i in range(k): + others = list(x_s) + cur = others.pop(i) + nums.append(PI(x - o for o in others)) + dens.append(PI(cur - o for o in others)) + den = PI(dens) + num = sum([_divmod(nums[i] * den * y_s[i] % p, dens[i], p) + for i in range(k)]) + return (_divmod(num, den, p) + p) % p + + +def combine(shares, prime=_PRIME): + ''' + Recover the secret from share points + (x,y points on the polynomial) + ''' + if len(shares) < 2: + raise ValueError("need at least two shares") + x_s, y_s = zip(*shares) + return _lagrange_interpolate(0, x_s, y_s, prime) -- cgit v1.2.3-70-g09d2