From 97a5c7f9f0149192ebec2fc84b262f3addfb992c Mon Sep 17 00:00:00 2001 From: Determinant Date: Fri, 23 Aug 2024 00:12:18 -0700 Subject: clean up interpolation code; add sanity check --- keytree.py | 35 +++++++++++++++++++++++++++-------- shamir.py | 23 +++++++++++------------ 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/keytree.py b/keytree.py index fb6551d..58f5cf5 100755 --- a/keytree.py +++ b/keytree.py @@ -47,7 +47,7 @@ import hmac import unicodedata import json from getpass import getpass as _getpass -from itertools import zip_longest +from itertools import zip_longest, combinations import bech32 import mnemonic @@ -376,8 +376,8 @@ def shamir256_combine(shares): shares = [(i, int.from_bytes(bytearray(p), 'big')) for (i, p) in shares] try: secret = shamir.combine(shares) - except ValueError: - raise KeytreeError("invalid Shamir recovery input") + except ValueError as e: + raise KeytreeError("invalid Shamir recovery input", e) if secret.bit_length() > 256: raise KeytreeError("Shamir result is too long") result.extend(secret.to_bytes(32, 'big')) @@ -496,13 +496,32 @@ if __name__ == '__main__': entropy = None if entropy: shares = shamir256_split(mgen.to_entropy(words), args.shamir_threshold, args.shamir_num) - for idx, share in enumerate(shares): - print("KEEP THIS PRIVATE (share) #{} {}".format(idx + 1, mgen.to_mnemonic(share))) + shares = [mgen.to_mnemonic(share) for share in shares] + + # checking + for case in combinations(range(args.shamir_num), args.shamir_threshold): + verify = [(i + 1, mgen.to_entropy(shares[i])) for i in case] + recovered = mgen.to_mnemonic(shamir256_combine(verify)) + if words != recovered: + raise KeytreeError('Shamir sanity check failed: {} = {}'.format(case, recovered)) + print("checked {}".format(case)) else: shares = shamir256_split(seed, args.shamir_threshold, args.shamir_num) - for idx, share in enumerate(shares): - words = mgen.to_mnemonic(share[:32]) + ' ' + mgen.to_mnemonic(share[32:]) - print("KEEP THIS PRIVATE (share) #{} {}".format(idx + 1, words)) + shares = [mgen.to_mnemonic(share[:32]) + ' ' + mgen.to_mnemonic(share[32:]) for share in shares] + + # checking + for case in combinations(range(args.shamir_num), args.shamir_threshold): + verify = [] + for i in case: + swords = shares[i].split() + verify.append((i + 1, mgen.to_entropy(' '.join(swords[:24])) + mgen.to_entropy(' '.join(swords[24:])))) + recovered = shamir256_combine(verify) + if seed != recovered: + raise KeytreeError('Shamir sanity check failed: {} = {}'.format(case, recovered.hex())) + print("checked {}".format(case)) + + for idx, share in enumerate(shares): + print("KEEP THIS PRIVATE (share) #{} {}".format(idx + 1, share)) # derive the keys at the requested paths diff --git a/shamir.py b/shamir.py index 60a4d78..cb7312c 100644 --- a/shamir.py +++ b/shamir.py @@ -45,7 +45,6 @@ def _divmod(num, a, p): y, last_y = last_y - quot * y, y return num * last_x - def _lagrange_interpolate(x, x_s, y_s, p): k = len(x_s) assert k == len(set(x_s)), "points must be distinct" @@ -55,19 +54,19 @@ def _lagrange_interpolate(x, x_s, y_s, p): for v in vals: r = (r * v) % p return r - nums = [] # avoid inexact division - dens = [] + l_s = [] + n_all = prod(x - x_j for x_j in x_s) for i in range(k): others = list(x_s) - cur = others.pop(i) - # cur is the current i-th term: x_i - # others is the list of terms excluding the current i-th term: x_j such that j != i - nums.append(prod(x - o for o in others)) # \Prod_{j \neq i}{(x - x_j)} - dens.append(prod(cur - o for o in others)) # \Prod_{j \neq i}{(x_i - x_j)} - den = prod(dens) - num = sum([_divmod((nums[i] * den * y_s[i]) % p, dens[i], p) - for i in range(k)]) - return (_divmod(num, den, p) + p) % p + x_i = others.pop(i) + # \Prod_{j \neq i}{(x - x_j)} / \Prod_{j \neq i}{(x_i - x_j)} + l_s.append(_divmod( + n_all, + (prod(x_i - x_j for x_j in others) * (x - x_i)) % p, p)) + sum = 0 + for (y, l) in zip(y_s, l_s): + sum = (sum + (y * l) % p) % p + return (sum + p) % p def combine(shares, prime=_PRIME): -- cgit v1.2.3-70-g09d2