aboutsummaryrefslogtreecommitdiff
path: root/rpc/http.go
diff options
context:
space:
mode:
Diffstat (limited to 'rpc/http.go')
-rw-r--r--rpc/http.go163
1 files changed, 43 insertions, 120 deletions
diff --git a/rpc/http.go b/rpc/http.go
index 2dffc5d..87a96e4 100644
--- a/rpc/http.go
+++ b/rpc/http.go
@@ -25,14 +25,10 @@ import (
"io"
"io/ioutil"
"mime"
- "net"
"net/http"
- "strings"
+ "net/url"
"sync"
"time"
-
- "github.com/ava-labs/go-ethereum/log"
- "github.com/rs/cors"
)
const (
@@ -45,31 +41,33 @@ var acceptedContentTypes = []string{contentType, "application/json-rpc", "applic
type httpConn struct {
client *http.Client
- req *http.Request
+ url string
closeOnce sync.Once
- closed chan interface{}
+ closeCh chan interface{}
+ mu sync.Mutex // protects headers
+ headers http.Header
}
// httpConn is treated specially by Client.
-func (hc *httpConn) Write(context.Context, interface{}) error {
- panic("Write called on httpConn")
+func (hc *httpConn) writeJSON(context.Context, interface{}) error {
+ panic("writeJSON called on httpConn")
}
-func (hc *httpConn) RemoteAddr() string {
- return hc.req.URL.String()
+func (hc *httpConn) remoteAddr() string {
+ return hc.url
}
-func (hc *httpConn) Read() ([]*jsonrpcMessage, bool, error) {
- <-hc.closed
+func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
+ <-hc.closeCh
return nil, false, io.EOF
}
-func (hc *httpConn) Close() {
- hc.closeOnce.Do(func() { close(hc.closed) })
+func (hc *httpConn) close() {
+ hc.closeOnce.Do(func() { close(hc.closeCh) })
}
-func (hc *httpConn) Closed() <-chan interface{} {
- return hc.closed
+func (hc *httpConn) closed() <-chan interface{} {
+ return hc.closeCh
}
// HTTPTimeouts represents the configuration params for the HTTP RPC server.
@@ -107,16 +105,24 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
// using the provided HTTP Client.
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
- req, err := http.NewRequest(http.MethodPost, endpoint, nil)
+ // Sanity check URL so we don't end up with a client that will fail every request.
+ _, err := url.Parse(endpoint)
if err != nil {
return nil, err
}
- req.Header.Set("Content-Type", contentType)
- req.Header.Set("Accept", contentType)
initctx := context.Background()
+ headers := make(http.Header, 2)
+ headers.Set("accept", contentType)
+ headers.Set("content-type", contentType)
return newClient(initctx, func(context.Context) (ServerCodec, error) {
- return &httpConn{client: client, req: req, closed: make(chan interface{})}, nil
+ hc := &httpConn{
+ client: client,
+ headers: headers,
+ url: endpoint,
+ closeCh: make(chan interface{}),
+ }
+ return hc, nil
})
}
@@ -136,7 +142,7 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
if respBody != nil {
buf := new(bytes.Buffer)
if _, err2 := buf.ReadFrom(respBody); err2 == nil {
- return fmt.Errorf("%v %v", err, buf.String())
+ return fmt.Errorf("%v: %v", err, buf.String())
}
}
return err
@@ -171,10 +177,18 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
if err != nil {
return nil, err
}
- req := hc.req.WithContext(ctx)
- req.Body = ioutil.NopCloser(bytes.NewReader(body))
+ req, err := http.NewRequestWithContext(ctx, "POST", hc.url, ioutil.NopCloser(bytes.NewReader(body)))
+ if err != nil {
+ return nil, err
+ }
req.ContentLength = int64(len(body))
+ // set headers
+ hc.mu.Lock()
+ req.Header = hc.headers.Clone()
+ hc.mu.Unlock()
+
+ // do request
resp, err := hc.client.Do(req)
if err != nil {
return nil, err
@@ -195,7 +209,7 @@ type httpServerConn struct {
func newHTTPServerConn(r *http.Request, w http.ResponseWriter) ServerCodec {
body := io.LimitReader(r.Body, maxRequestContentLength)
conn := &httpServerConn{Reader: body, Writer: w, r: r}
- return NewJSONCodec(conn)
+ return NewCodec(conn)
}
// Close does nothing and always returns nil.
@@ -209,49 +223,19 @@ func (t *httpServerConn) RemoteAddr() string {
// SetWriteDeadline does nothing and always returns nil.
func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil }
-// NewHTTPServer creates a new HTTP RPC server around an API provider.
-//
-// Deprecated: Server implements http.Handler
-func NewHTTPServer(cors []string, vhosts []string, timeouts HTTPTimeouts, srv http.Handler) *http.Server {
- // Wrap the CORS-handler within a host-handler
- handler := newCorsHandler(srv, cors)
- handler = newVHostHandler(vhosts, handler)
- handler = newGzipHandler(handler)
-
- // Make sure timeout values are meaningful
- if timeouts.ReadTimeout < time.Second {
- log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", DefaultHTTPTimeouts.ReadTimeout)
- timeouts.ReadTimeout = DefaultHTTPTimeouts.ReadTimeout
- }
- if timeouts.WriteTimeout < time.Second {
- log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", DefaultHTTPTimeouts.WriteTimeout)
- timeouts.WriteTimeout = DefaultHTTPTimeouts.WriteTimeout
- }
- if timeouts.IdleTimeout < time.Second {
- log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", DefaultHTTPTimeouts.IdleTimeout)
- timeouts.IdleTimeout = DefaultHTTPTimeouts.IdleTimeout
- }
- // Bundle and start the HTTP server
- return &http.Server{
- Handler: handler,
- ReadTimeout: timeouts.ReadTimeout,
- WriteTimeout: timeouts.WriteTimeout,
- IdleTimeout: timeouts.IdleTimeout,
- }
-}
-
// ServeHTTP serves JSON-RPC requests over HTTP.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Permit dumb empty requests for remote health-checks (AWS)
if r.Method == http.MethodGet && r.ContentLength == 0 && r.URL.RawQuery == "" {
+ w.WriteHeader(http.StatusOK)
return
}
if code, err := validateRequest(r); err != nil {
http.Error(w, err.Error(), code)
return
}
- // All checks passed, create a codec that reads direct from the request body
- // untilEOF and writes the response to w and order the server to process a
+ // All checks passed, create a codec that reads directly from the request body
+ // until EOF, writes the response to w, and orders the server to process a
// single request.
ctx := r.Context()
ctx = context.WithValue(ctx, "remote", r.RemoteAddr)
@@ -266,7 +250,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", contentType)
codec := newHTTPServerConn(r, w)
- defer codec.Close()
+ defer codec.close()
s.serveSingleRequest(ctx, codec)
}
@@ -296,64 +280,3 @@ func validateRequest(r *http.Request) (int, error) {
err := fmt.Errorf("invalid content type, only %s is supported", contentType)
return http.StatusUnsupportedMediaType, err
}
-
-func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
- // disable CORS support if user has not specified a custom CORS configuration
- if len(allowedOrigins) == 0 {
- return srv
- }
- c := cors.New(cors.Options{
- AllowedOrigins: allowedOrigins,
- AllowedMethods: []string{http.MethodPost, http.MethodGet},
- MaxAge: 600,
- AllowedHeaders: []string{"*"},
- })
- return c.Handler(srv)
-}
-
-// virtualHostHandler is a handler which validates the Host-header of incoming requests.
-// The virtualHostHandler can prevent DNS rebinding attacks, which do not utilize CORS-headers,
-// since they do in-domain requests against the RPC api. Instead, we can see on the Host-header
-// which domain was used, and validate that against a whitelist.
-type virtualHostHandler struct {
- vhosts map[string]struct{}
- next http.Handler
-}
-
-// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler
-func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- // if r.Host is not set, we can continue serving since a browser would set the Host header
- if r.Host == "" {
- h.next.ServeHTTP(w, r)
- return
- }
- host, _, err := net.SplitHostPort(r.Host)
- if err != nil {
- // Either invalid (too many colons) or no port specified
- host = r.Host
- }
- if ipAddr := net.ParseIP(host); ipAddr != nil {
- // It's an IP address, we can serve that
- h.next.ServeHTTP(w, r)
- return
-
- }
- // Not an ip address, but a hostname. Need to validate
- if _, exist := h.vhosts["*"]; exist {
- h.next.ServeHTTP(w, r)
- return
- }
- if _, exist := h.vhosts[host]; exist {
- h.next.ServeHTTP(w, r)
- return
- }
- http.Error(w, "invalid host specified", http.StatusForbidden)
-}
-
-func newVHostHandler(vhosts []string, next http.Handler) http.Handler {
- vhostMap := make(map[string]struct{})
- for _, allowedHost := range vhosts {
- vhostMap[strings.ToLower(allowedHost)] = struct{}{}
- }
- return &virtualHostHandler{vhostMap, next}
-}