diff options
Diffstat (limited to 'pkg/server/server.go')
| -rw-r--r-- | pkg/server/server.go | 117 | 
1 files changed, 72 insertions, 45 deletions
| diff --git a/pkg/server/server.go b/pkg/server/server.go index 2d10c4b..06eb258 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -3,104 +3,131 @@ package server  import (  	"context"  	"fmt" +	"io"  	"net"  	"net/http" -	"strings"  	"time" + +	"rgdd.se/silent-ct/internal/x509util"  )  const ( -	EndpointAddChain  = "/add-chain" -	EndpointGetStatus = "/get-status" +	EndpointAddChain  = "add-chain" +	EndpointGetStatus = "get-status" +	DefaultNetwork    = "tcp"  	DefaultAddress    = "localhost:2009"  	DefaultConfigFile = "/home/rgdd/.config/silent-ct/config.json" // FIXME  ) -type Nodes []Node - -type Node struct { -	Name    string   `json:"name"` -	Secret  string   `json:"secret"` -	Domains []string `json:"issues"` -} -  type Config struct { -	Address string // hostname[:port] or unix:///path/to/file.sock +	Network string // tcp or unix +	Address string // hostname[:port] or path to a unix socket  	Nodes   Nodes  // Which nodes are trusted to issue what certificates  }  type Server struct {  	Config -	http.Server -	unixSocket bool // true if listening with a unix socket -	eventCh    chan MessageNodeSubmission -	errorCh    chan error +	eventCh chan MessageNodeSubmission +	errorCh chan error  }  func New(cfg Config) (Server, error) { -	mux := http.NewServeMux() -	srv := Server{Config: cfg, Server: http.Server{Handler: mux}} -	mux.HandleFunc(EndpointAddChain, func(w http.ResponseWriter, r *http.Request) { srv.addChain(w, r) }) -	mux.HandleFunc(EndpointGetStatus, func(w http.ResponseWriter, r *http.Request) { srv.getStatus(w, r) }) -	if len(srv.Address) == 0 { -		srv.Config.Address = DefaultAddress +	if cfg.Network == "" { +		cfg.Network = DefaultNetwork  	} -	if strings.HasPrefix(srv.Config.Address, "unix://") { -		srv.Config.Address = srv.Config.Address[7:] -		srv.unixSocket = true +	if cfg.Address == "" { +		cfg.Network = DefaultAddress  	} -	return srv, nil +	return Server{Config: cfg}, nil  }  func (srv *Server) Run(ctx context.Context, submitCh chan MessageNodeSubmission, errorCh chan error) error {  	srv.eventCh = submitCh  	srv.errorCh = errorCh -	network := "unix" -	if !srv.unixSocket { -		network = "tcp" +	mux := http.NewServeMux() +	for _, handler := range srv.handlers() { +		handler.register(mux)  	} -	listener, err := net.Listen(network, srv.Address) +	listener, err := net.Listen(srv.Network, srv.Address)  	if err != nil {  		return fmt.Errorf("listen: %v", err)  	}  	defer listener.Close() -	exitErr := make(chan error, 1) -	defer close(exitErr) +	s := http.Server{Handler: mux} +	exitCh := make(chan error, 1) +	defer close(exitCh)  	go func() { -		exitErr <- srv.Serve(listener) +		exitCh <- s.Serve(listener)  	}()  	select { -	case err := <-exitErr: +	case err := <-exitCh:  		if err != nil && err != http.ErrServerClosed {  			return fmt.Errorf("serve: %v", err)  		}  	case <-ctx.Done():  		tctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)  		defer cancel() -		if err := srv.Shutdown(tctx); err != nil { +		if err := s.Shutdown(tctx); err != nil {  			return fmt.Errorf("shutdown: %v", err)  		}  	}  	return nil  } -func (srv *Server) getStatus(w http.ResponseWriter, r *http.Request) { -	if r.Method != http.MethodGet { -		http.Error(w, "Only HTTP GET method is allowed", http.StatusMethodNotAllowed) -		return +func (srv *Server) handlers() []handler { +	return []handler{ +		handler{ +			method:   http.MethodGet, +			endpoint: EndpointGetStatus, +			callback: func(w http.ResponseWriter, r *http.Request) (int, string) { return srv.getStatus(w, r) }, +		}, +		handler{ +			method:   http.MethodPost, +			endpoint: EndpointAddChain, +			callback: func(w http.ResponseWriter, r *http.Request) (int, string) { return srv.addChain(w, r) }, +		},  	} -	fmt.Fprintf(w, "OK\n")  } -func (srv *Server) addChain(w http.ResponseWriter, r *http.Request) { -	if r.Method != http.MethodGet { -		http.Error(w, "Only HTTP POST method is allowed", http.StatusMethodNotAllowed) -		return +func (srv *Server) getStatus(w http.ResponseWriter, r *http.Request) (int, string) { +	return http.StatusOK, "OK" +} + +func (srv *Server) addChain(w http.ResponseWriter, r *http.Request) (int, string) { +	node, err := srv.Nodes.authenticate(r) +	if err != nil { +		return http.StatusForbidden, "Invalid HTTP Basic Auth credentials"  	} -	fmt.Fprintf(w, "TODO: HTTP POST /add-chain\n") + +	b, err := io.ReadAll(r.Body) +	if err != nil { +		return http.StatusBadRequest, "Read HTTP POST body failed" +	} +	defer r.Body.Close() + +	chain, err := x509util.ParseChain(b) +	if err != nil { +		return http.StatusBadRequest, "Malformed HTTP POST body" +	} +	if err := node.check(chain[0]); err != nil { +		srv.errorCh <- ErrorUnauthorizedDomainName{ +			PEMChain: b, +			Node:     node, +			Err:      err, +		} +	} else { +		srv.eventCh <- MessageNodeSubmission{ +			SerialNumber: chain[0].SerialNumber.String(), +			NotBefore:    chain[0].NotBefore, +			DomainNames:  chain[0].DNSNames, +			PEMChain:     b, +		} +	} + +	return http.StatusOK, "OK"  } | 
