aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xkeytree.py35
-rw-r--r--shamir.py23
2 files changed, 38 insertions, 20 deletions
diff --git a/keytree.py b/keytree.py
index fb6551d..58f5cf5 100755
--- a/keytree.py
+++ b/keytree.py
@@ -47,7 +47,7 @@ import hmac
import unicodedata
import json
from getpass import getpass as _getpass
-from itertools import zip_longest
+from itertools import zip_longest, combinations
import bech32
import mnemonic
@@ -376,8 +376,8 @@ def shamir256_combine(shares):
shares = [(i, int.from_bytes(bytearray(p), 'big')) for (i, p) in shares]
try:
secret = shamir.combine(shares)
- except ValueError:
- raise KeytreeError("invalid Shamir recovery input")
+ except ValueError as e:
+ raise KeytreeError("invalid Shamir recovery input", e)
if secret.bit_length() > 256:
raise KeytreeError("Shamir result is too long")
result.extend(secret.to_bytes(32, 'big'))
@@ -496,13 +496,32 @@ if __name__ == '__main__':
entropy = None
if entropy:
shares = shamir256_split(mgen.to_entropy(words), args.shamir_threshold, args.shamir_num)
- for idx, share in enumerate(shares):
- print("KEEP THIS PRIVATE (share) #{} {}".format(idx + 1, mgen.to_mnemonic(share)))
+ shares = [mgen.to_mnemonic(share) for share in shares]
+
+ # checking
+ for case in combinations(range(args.shamir_num), args.shamir_threshold):
+ verify = [(i + 1, mgen.to_entropy(shares[i])) for i in case]
+ recovered = mgen.to_mnemonic(shamir256_combine(verify))
+ if words != recovered:
+ raise KeytreeError('Shamir sanity check failed: {} = {}'.format(case, recovered))
+ print("checked {}".format(case))
else:
shares = shamir256_split(seed, args.shamir_threshold, args.shamir_num)
- for idx, share in enumerate(shares):
- words = mgen.to_mnemonic(share[:32]) + ' ' + mgen.to_mnemonic(share[32:])
- print("KEEP THIS PRIVATE (share) #{} {}".format(idx + 1, words))
+ shares = [mgen.to_mnemonic(share[:32]) + ' ' + mgen.to_mnemonic(share[32:]) for share in shares]
+
+ # checking
+ for case in combinations(range(args.shamir_num), args.shamir_threshold):
+ verify = []
+ for i in case:
+ swords = shares[i].split()
+ verify.append((i + 1, mgen.to_entropy(' '.join(swords[:24])) + mgen.to_entropy(' '.join(swords[24:]))))
+ recovered = shamir256_combine(verify)
+ if seed != recovered:
+ raise KeytreeError('Shamir sanity check failed: {} = {}'.format(case, recovered.hex()))
+ print("checked {}".format(case))
+
+ for idx, share in enumerate(shares):
+ print("KEEP THIS PRIVATE (share) #{} {}".format(idx + 1, share))
# derive the keys at the requested paths
diff --git a/shamir.py b/shamir.py
index 60a4d78..cb7312c 100644
--- a/shamir.py
+++ b/shamir.py
@@ -45,7 +45,6 @@ def _divmod(num, a, p):
y, last_y = last_y - quot * y, y
return num * last_x
-
def _lagrange_interpolate(x, x_s, y_s, p):
k = len(x_s)
assert k == len(set(x_s)), "points must be distinct"
@@ -55,19 +54,19 @@ def _lagrange_interpolate(x, x_s, y_s, p):
for v in vals:
r = (r * v) % p
return r
- nums = [] # avoid inexact division
- dens = []
+ l_s = []
+ n_all = prod(x - x_j for x_j in x_s)
for i in range(k):
others = list(x_s)
- cur = others.pop(i)
- # cur is the current i-th term: x_i
- # others is the list of terms excluding the current i-th term: x_j such that j != i
- nums.append(prod(x - o for o in others)) # \Prod_{j \neq i}{(x - x_j)}
- dens.append(prod(cur - o for o in others)) # \Prod_{j \neq i}{(x_i - x_j)}
- den = prod(dens)
- num = sum([_divmod((nums[i] * den * y_s[i]) % p, dens[i], p)
- for i in range(k)])
- return (_divmod(num, den, p) + p) % p
+ x_i = others.pop(i)
+ # \Prod_{j \neq i}{(x - x_j)} / \Prod_{j \neq i}{(x_i - x_j)}
+ l_s.append(_divmod(
+ n_all,
+ (prod(x_i - x_j for x_j in others) * (x - x_i)) % p, p))
+ sum = 0
+ for (y, l) in zip(y_s, l_s):
+ sum = (sum + (y * l) % p) % p
+ return (sum + p) % p
def combine(shares, prime=_PRIME):