aboutsummaryrefslogtreecommitdiff
path: root/plugin/evm/user.go
blob: cfa302bfff7ce262f598505fed6efcf07f312c6b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package evm

import (
	"errors"
	"fmt"

	"github.com/ava-labs/avalanchego/database"
	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/utils/crypto"
	"github.com/ava-labs/go-ethereum/common"
)

// addressesKey is the database key (the bytes of ids.Empty) under which the
// serialized list of addresses this user controls is stored. getAddresses
// reads the list from this key and putAddress rewrites it there.
var addressesKey = ids.Empty.Bytes()

// Sentinel errors used by the user helpers in this file.
var (
	// errDBNil is returned when a method is called on a user whose db is nil.
	errDBNil = errors.New("db uninitialized")

	// errKeyNil is returned by putAddress when it is handed a nil private key.
	errKeyNil = errors.New("key uninitialized")

	// errEmptyAddress describes a zero address. Within this file it is only
	// referenced by checks that are currently commented out.
	errEmptyAddress = errors.New("address is empty")
)

// user wraps a single user's database and exposes helpers to record and look
// up the addresses and private keys stored in it. Every method that touches
// the database returns errDBNil when db is nil.
type user struct {
	// This user's database, acquired from the keystore
	db database.Database
}

// getAddresses returns the list of addresses stored for this user under
// [addressesKey].
//
// It returns errDBNil if the user's database is unset. If nothing is stored
// under [addressesKey] yet, it returns a nil slice and a nil error. Errors from
// the database or from decoding the stored list are returned unchanged.
func (u *user) getAddresses() ([]common.Address, error) {
	if u.db == nil {
		return nil, errDBNil
	}

	// No entry under the key means no address has been recorded yet.
	switch exists, err := u.db.Has(addressesKey); {
	case err != nil:
		return nil, err
	case !exists:
		return nil, nil
	}

	// The list exists: load its serialized form and decode it.
	raw, err := u.db.Get(addressesKey)
	if err != nil {
		return nil, err
	}
	addrs := make([]common.Address, 0)
	if err := Codec.Unmarshal(raw, &addrs); err != nil {
		return nil, err
	}
	return addrs, nil
}

// controlsAddress reports whether this user's database holds an entry keyed
// by [address]; putAddress is what creates such entries (address --> private
// key). It returns errDBNil if the database is unset.
func (u *user) controlsAddress(address common.Address) (bool, error) {
	if u.db == nil {
		return false, errDBNil
	}
	// NOTE: rejecting the zero address with errEmptyAddress was sketched here
	// at some point but is not enabled.
	key := address.Bytes()
	return u.db.Has(key)
}

// putAddress persists that this user controls the address derived from
// [privKey].
//
// Two records are written: address --> private key bytes, and the updated
// list of controlled addresses under [addressesKey]. If the user already
// controls the address, nothing is written.
//
// It returns errKeyNil if [privKey] is nil and errDBNil if the user's database
// is unset. Errors from the database or from encoding the list are returned
// unchanged.
func (u *user) putAddress(privKey *crypto.PrivateKeySECP256K1R) error {
	if privKey == nil {
		return errKeyNil
	}

	address := GetEthAddress(privKey) // address the privKey controls

	// controlsAddress also rejects a nil database with errDBNil, so u.db is
	// safe to use for the rest of this function.
	alreadyControlled, err := u.controlsAddress(address)
	if err != nil {
		return err
	}
	if alreadyControlled { // user already controls this address. Do nothing.
		return nil
	}

	// Do every fallible read and encoding step before writing anything.
	// getAddresses returns a nil slice when no list has been stored yet, which
	// append handles, so no separate existence check is needed.
	addresses, err := u.getAddresses()
	if err != nil {
		return err
	}
	addresses = append(addresses, address)
	bytes, err := Codec.Marshal(addresses)
	if err != nil {
		return err
	}

	// Commit both records through a single batch. Writing the key first with
	// its own Put would, if the list update then failed, leave a key that
	// controlsAddress reports as present (so retries return early) but that
	// getAddresses/getKeys never list.
	batch := u.db.NewBatch()
	if err := batch.Put(address.Bytes(), privKey.Bytes()); err != nil { // Address --> private key
		return err
	}
	if err := batch.Put(addressesKey, bytes); err != nil { // Updated list of controlled addresses
		return err
	}
	return batch.Write()
}

// getKey returns the private key stored under [address] in this user's
// database.
//
// It returns errDBNil if the database is unset, passes through any error from
// the database lookup or from parsing the stored bytes, and returns a
// descriptive error if the parsed key is not a *crypto.PrivateKeySECP256K1R.
func (u *user) getKey(address common.Address) (*crypto.PrivateKeySECP256K1R, error) {
	if u.db == nil {
		return nil, errDBNil
	}
	// NOTE: rejecting the zero address with errEmptyAddress was sketched here
	// at some point but is not enabled.

	keyBytes, err := u.db.Get(address.Bytes())
	if err != nil {
		return nil, err
	}

	factory := crypto.FactorySECP256K1R{}
	parsed, err := factory.ToPrivateKey(keyBytes)
	if err != nil {
		return nil, err
	}

	// The factory hands back a non-concrete key type; narrow it to the
	// concrete SECP256K1R key callers expect.
	secpKey, ok := parsed.(*crypto.PrivateKeySECP256K1R)
	if !ok {
		return nil, fmt.Errorf("expected private key to be type *crypto.PrivateKeySECP256K1R but is type %T", parsed)
	}
	return secpKey, nil
}

// getKeys returns the private key for every address listed by getAddresses,
// in the same order. The first error from getAddresses or getKey aborts the
// lookup and is returned with a nil slice.
func (u *user) getKeys() ([]*crypto.PrivateKeySECP256K1R, error) {
	addresses, err := u.getAddresses()
	if err != nil {
		return nil, err
	}

	privKeys := make([]*crypto.PrivateKeySECP256K1R, 0, len(addresses))
	for _, address := range addresses {
		privKey, err := u.getKey(address)
		if err != nil {
			return nil, err
		}
		privKeys = append(privKeys, privKey)
	}
	return privKeys, nil
}