aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xkeytree.py90
-rw-r--r--shamir.py113
2 files changed, 200 insertions, 3 deletions
diff --git a/keytree.py b/keytree.py
index 78cdf4d..71d5998 100755
--- a/keytree.py
+++ b/keytree.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python3
+#!/usr/bin/env python3.10
# MIT License
#
# Copyright (c) 2020 Ted Yin <[email protected]>
@@ -47,6 +47,7 @@ import hmac
import unicodedata
import json
from getpass import getpass as _getpass
+from itertools import zip_longest
import bech32
import mnemonic
@@ -58,6 +59,7 @@ from sha3 import keccak_256
from uuid import uuid4
from Cryptodome.Cipher import AES
from Cryptodome.Util import Counter
+import shamir
def getpass(prompt):
@@ -124,7 +126,7 @@ def is_infinity(P):
# parse256(p): interprets a 32-byte sequence as a 256-bit number, most
# significant byte first.
def parse256(p):
- assert(len(p) == 32)
+ assert len(p) == 32
return int.from_bytes(p, byteorder='big')
@@ -346,6 +348,42 @@ def save_to_mew(priv_keys, n=1 << 18, p=1, r=8, dklen=32):
raise KeytreeError("failed while saving")
+def to_chunks(n, iterable):
+ return zip_longest(*[iter(iterable)]*n, fillvalue=0)
+
+def shamir256_split(secret, t, n):
+ shares = [bytearray() for i in range(n)]
+ for chunk in to_chunks(32, secret):
+ secret = int.from_bytes(chunk, 'big')
+ while True:
+ points = shamir.split(secret, t, n)
+ good = True
+ for p in points:
+ if p.bit_length() > 256:
+ good = False
+ break
+ if good:
+ # all shares are within 256 bits
+ for (p, s) in zip(points, shares):
+ s.extend(p.to_bytes(32, 'big'))
+ break
+ return shares
+
+
+def shamir256_combine(shares):
+ result = bytearray()
+ for shares in zip(*[[(i, g) for g in to_chunks(32, s)] for (i, s) in shares]):
+ shares = [(i, int.from_bytes(bytearray(p), 'big')) for (i, p) in shares]
+ try:
+ secret = shamir.combine(shares)
+ except ValueError:
+ raise KeytreeError("invalid Shamir recovery input")
+ if secret.bit_length() > 256:
+ raise KeytreeError("Shamir result is too long")
+ result.extend(secret.to_bytes(32, 'big'))
+ return result
+
+
metamask_path = r"44'/60'/0'/0"
avax_path = r"44'/9000'/0'/0"
@@ -355,6 +393,10 @@ if __name__ == '__main__':
parser.add_argument('--save', type=str, default=None, help='save mnemonic to a file (AVAX Wallet compatible)')
parser.add_argument('--export-mew', action='store_true', default=False, help='export keys to MEW keystore files (mnemonic is NOT saved, only keys are saved)')
parser.add_argument('--show-private', action='store_true', default=False, help='also show private keys and the mnemonic')
+ parser.add_argument('--gen-shamir', action='store_true', default=False, help='generate Shamir\'s secret shares')
+ parser.add_argument('--shamir-threshold', type=int, default=2, help='Shamir\'s secret sharing threshold (number of shares to decode)')
+ parser.add_argument('--shamir-num', type=int, default=3, help='Shamir\'s secret sharing number (total number of shares)')
+ parser.add_argument('--recover-shamir', type=str, default=None, help='recover the secret from Shamir shares')
parser.add_argument('--custom', action='store_true', default=False, help='use an arbitrary word combination as mnemonic')
parser.add_argument('--seed', action='store_true', default=False, help='load mnemonic from seed')
parser.add_argument('--path', default=avax_path, help="path prefix for key deriving (e.g. \"{}\" for Metamask)".format(metamask_path))
@@ -371,13 +413,40 @@ if __name__ == '__main__':
for arg in unknown:
if len(arg) > 0:
raise KeytreeError("invalid argument: `{}`".format(arg))
+ shares = []
try:
+ mgen = mnemonic.Mnemonic(args.lang)
if args.gen_mnemonic:
- mgen = mnemonic.Mnemonic(args.lang)
words = mgen.generate(256)
else:
if args.load:
words = load_from_keystore(args.load)
+ elif args.recover_shamir:
+ try:
+ idxes = [int(i) for i in args.recover_shamir.split(',')]
+ except ValueError:
+ raise KeytreeError("invalid Shamir share spec, should be something like \"1,2\"")
+ custom_mnemonic = None
+ for idx in idxes:
+ swords = getpass('Enter the mnemonic for Shamir share #{}: '.format(idx))
+ if len(swords) == 48:
+ if not custom_mnemonic:
+ raise KeytreeError("invalid Shamir share format")
+ custom_mnemonic = True
+ share = mgen.to_entropy(swords[:24]) + mgen.to_entropy(swords[24:])
+ else:
+ if custom_mnemonic:
+ raise KeytreeError("invalid Shamir share format")
+ custom_mnemonic = False
+ try:
+ share = mgen.to_entropy(swords)
+ except ValueError:
+ raise KeytreeError('invalid mnemonic')
+ shares.append((idx, share))
+ if custom_mnemonic:
+ seed = shamir256_combine(shares)
+ else:
+ words = mgen.to_mnemonic(shamir256_combine(shares))
elif not args.seed:
words = getpass('Enter the mnemonic: ').strip()
if not args.custom:
@@ -402,6 +471,21 @@ if __name__ == '__main__':
if not args.seed:
print("KEEP THIS PRIVATE (mnemonic): {}".format(words))
print("KEEP THIS PRIVATE (seed): {}".format(seed.hex()))
+ if args.shamir_threshold:
+ if args.shamir_num > 20:
+ raise KeytreeError('Shamir threshold should be <= 20')
+ if args.shamir_threshold < 2 or args.shamir_threshold > args.shamir_num:
+ raise KeytreeError('Shamir threshold should be (2, N]')
+ if args.gen_shamir:
+ if args.seed:
+ shares = shamir256_split(seed, args.shamir_threshold, args.shamir_num)
+ for idx, share in enumerate(shares):
+ words = mgen.to_mnemonic(share[:32]) + mgen.to_mnemonic(share[32:])
+ print("KEEP THIS PRIVATE: secret_{}: {}".format(idx + 1, words))
+ else:
+ shares = shamir256_split(mgen.to_entropy(words), args.shamir_threshold, args.shamir_num)
+ for idx, share in enumerate(shares):
+ print("KEEP THIS PRIVATE: secret_{}: {}".format(idx + 1, mgen.to_mnemonic(share)))
gen = BIP32(seed)
if args.start_idx < 0 or args.end_idx < 0:
raise KeytreeError("invalid start/end index")
diff --git a/shamir.py b/shamir.py
new file mode 100644
index 0000000..c7d606a
--- /dev/null
+++ b/shamir.py
@@ -0,0 +1,113 @@
+# Code modified from https://github.com/kurtbrose/shamir/tree/master
+
+'''
+An implementation of shamir secret sharing algorithm.
+
+To the extent possible under law,
+all copyright and related or neighboring rights
+are hereby waived. (CC0, see LICENSE file.)
+
+All possible patents arising from this code
+are renounced under the terms of
+the Open Web Foundation CLA 1.0
+(http://www.openwebfoundation.org/legal/the-owf-1-0-agreements/owfa-1-0)
+'''
+
+from __future__ import division
+import random
+import functools
+
+# 12th Mersenne Prime is 2**127 - 1
+# use the prime to be compatible with the Shamir tool developed by Ava Labs:
+# https://github.com/ava-labs/mnemonic-shamir-secret-sharing-cli/tree/main
+# This is a 257-bit prime, so the returned points could be either 256-bit or
+# (infrequently) 257-bit.
+_PRIME = 187110422339161656731757292403725394067928975545356095774785896842956550853219
+# 13th Mersenne Prime is 2**521 - 1
+
+def _eval_at(poly, x, prime):
+ 'evaluate polynomial (coefficient tuple) at x'
+ accum = 0
+ for coeff in reversed(poly):
+ accum *= x
+ accum += coeff
+ accum %= prime
+ return accum
+
+
+def split(secret, minimum, shares, prime=_PRIME):
+ '''
+ Generates a random shamir pool, returns
+ the secret and the share points.
+ '''
+ randint = functools.partial(random.SystemRandom().randint, 0)
+ if minimum > shares:
+ raise ValueError("pool secret would be irrecoverable")
+ poly = [randint(prime) for i in range(minimum - 1)]
+ poly.insert(0, secret)
+ return [_eval_at(poly, i, prime) for i in range(1, shares + 1)]
+
+
+# division in integers modulus p means finding the inverse of the denominator
+# modulo p and then multiplying the numerator by this inverse
+# (Note: inverse of A is B such that A*B % p == 1)
+# this can be computed via extended euclidean algorithm
+# http://en.wikipedia.org/wiki/Modular_multiplicative_inverse#Computation
+def _extended_gcd(a, b):
+ x = 0
+ last_x = 1
+ y = 1
+ last_y = 0
+ while b != 0:
+ quot = a // b
+ a, b = b, a%b
+ x, last_x = last_x - quot * x, x
+ y, last_y = last_y - quot * y, y
+ return last_x, last_y
+
+
+def _divmod(num, den, p):
+ '''
+ compute num / den modulo prime p
+ To explain what this means, the return
+ value will be such that the following is true:
+ den * _divmod(num, den, p) % p == num
+ '''
+ inv, _ = _extended_gcd(den, p)
+ return num * inv
+
+
+def _lagrange_interpolate(x, x_s, y_s, p):
+ '''
+ Find the y-value for the given x, given n (x, y) points;
+ k points will define a polynomial of up to kth order
+ '''
+ k = len(x_s)
+ assert k == len(set(x_s)), "points must be distinct"
+ def PI(vals): # upper-case PI -- product of inputs
+ accum = 1
+ for v in vals:
+ accum *= v
+ return accum
+ nums = [] # avoid inexact division
+ dens = []
+ for i in range(k):
+ others = list(x_s)
+ cur = others.pop(i)
+ nums.append(PI(x - o for o in others))
+ dens.append(PI(cur - o for o in others))
+ den = PI(dens)
+ num = sum([_divmod(nums[i] * den * y_s[i] % p, dens[i], p)
+ for i in range(k)])
+ return (_divmod(num, den, p) + p) % p
+
+
+def combine(shares, prime=_PRIME):
+ '''
+ Recover the secret from share points
+ (x,y points on the polynomial)
+ '''
+ if len(shares) < 2:
+ raise ValueError("need at least two shares")
+ x_s, y_s = zip(*shares)
+ return _lagrange_interpolate(0, x_s, y_s, prime)