diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/server/.nodes.go.swp | bin | 0 -> 12288 bytes | |||
-rw-r--r-- | pkg/server/errors.go | 1 | ||||
-rw-r--r-- | pkg/server/handler.go | 32 | ||||
-rw-r--r-- | pkg/server/messages.go | 19 | ||||
-rw-r--r-- | pkg/server/nodes.go | 53 | ||||
-rw-r--r-- | pkg/server/server.go | 117 |
6 files changed, 176 insertions, 46 deletions
diff --git a/pkg/server/.nodes.go.swp b/pkg/server/.nodes.go.swp Binary files differnew file mode 100644 index 0000000..b775c96 --- /dev/null +++ b/pkg/server/.nodes.go.swp diff --git a/pkg/server/errors.go b/pkg/server/errors.go deleted file mode 100644 index abb4e43..0000000 --- a/pkg/server/errors.go +++ /dev/null @@ -1 +0,0 @@ -package server diff --git a/pkg/server/handler.go b/pkg/server/handler.go new file mode 100644 index 0000000..6d17af7 --- /dev/null +++ b/pkg/server/handler.go @@ -0,0 +1,32 @@ +package server + +import ( + "fmt" + "net/http" +) + +// handler implements the http.Handler interface +type handler struct { + method string + endpoint string + callback func(w http.ResponseWriter, r *http.Request) (int, string) +} + +func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Cache-Control", "no-cache") + if r.Method != h.method { + http.Error(w, "Invalid HTTP method, expected "+h.method, http.StatusMethodNotAllowed) + return + } + code, text := h.callback(w, r) + if code != http.StatusOK { + http.Error(w, text, code) + return + } + fmt.Fprintf(w, fmt.Sprintf("%s\n", text)) +} + +func (h handler) register(mux *http.ServeMux) { + mux.Handle("/"+h.endpoint, h) +} diff --git a/pkg/server/messages.go b/pkg/server/messages.go index 50edded..a6ea243 100644 --- a/pkg/server/messages.go +++ b/pkg/server/messages.go @@ -1,4 +1,23 @@ package server +import ( + "fmt" + "time" +) + type MessageNodeSubmission struct { + SerialNumber string + NotBefore time.Time + DomainNames []string + PEMChain []byte +} + +type ErrorUnauthorizedDomainName struct { + PEMChain []byte + Node Node + Err error +} + +func (e ErrorUnauthorizedDomainName) Error() string { + return fmt.Sprintf("%v", e.Err) } diff --git a/pkg/server/nodes.go b/pkg/server/nodes.go new file mode 100644 index 0000000..164c06f --- /dev/null +++ b/pkg/server/nodes.go @@ -0,0 +1,53 @@ +package server + +import ( + "crypto/x509" + "fmt" + "net/http" +) + +// Node is an identified system that can request certificates +type Node struct { + Name string `json:"name"` // Artbirary node name for authentication + Secret string `json:"secret"` // Arbitrary node secret for authentication + Domains []string `json:"issues"` // Exact-match domain names that are allowed +} + +func (node *Node) authenticate(r *http.Request) error { + user, password, ok := r.BasicAuth() + if !ok { + return fmt.Errorf("no http basic auth credentials") + } + if user != node.Name || password != node.Secret { + return fmt.Errorf("invalid http basic auth credentials") + } + return nil +} + +func (node *Node) check(crt x509.Certificate) error { + for _, san := range crt.DNSNames { + ok := false + for _, domain := range node.Domains { + if domain == san { + ok = true + break + } + } + if !ok { + return fmt.Errorf("%s: not authorized to issue certificates for %s", node.Name, san) + } + } + return nil +} + +// Nodes is a list of nodes that can request certificates +type Nodes []Node + +func (nodes *Nodes) authenticate(r *http.Request) (Node, error) { + for _, node := range (*nodes)[:] { + if err := node.authenticate(r); err == nil { + return node, nil + } + } + return Node{}, fmt.Errorf("no valid HTTP basic auth credentials") +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 2d10c4b..06eb258 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -3,104 +3,131 @@ package server import ( "context" "fmt" + "io" "net" "net/http" - "strings" "time" + + "rgdd.se/silent-ct/internal/x509util" ) const ( - EndpointAddChain = "/add-chain" - EndpointGetStatus = "/get-status" + EndpointAddChain = "add-chain" + EndpointGetStatus = "get-status" + DefaultNetwork = "tcp" DefaultAddress = "localhost:2009" DefaultConfigFile = "/home/rgdd/.config/silent-ct/config.json" // FIXME ) -type Nodes []Node - -type Node struct { - Name string `json:"name"` - Secret string `json:"secret"` - Domains []string `json:"issues"` -} - type Config struct { - Address string // hostname[:port] or unix:///path/to/file.sock + Network string // tcp or unix + Address string // hostname[:port] or path to a unix socket Nodes Nodes // Which nodes are trusted to issue what certificates } type Server struct { Config - http.Server - unixSocket bool // true if listening with a unix socket - eventCh chan MessageNodeSubmission - errorCh chan error + eventCh chan MessageNodeSubmission + errorCh chan error } func New(cfg Config) (Server, error) { - mux := http.NewServeMux() - srv := Server{Config: cfg, Server: http.Server{Handler: mux}} - mux.HandleFunc(EndpointAddChain, func(w http.ResponseWriter, r *http.Request) { srv.addChain(w, r) }) - mux.HandleFunc(EndpointGetStatus, func(w http.ResponseWriter, r *http.Request) { srv.getStatus(w, r) }) - if len(srv.Address) == 0 { - srv.Config.Address = DefaultAddress + if cfg.Network == "" { + cfg.Network = DefaultNetwork } - if strings.HasPrefix(srv.Config.Address, "unix://") { - srv.Config.Address = srv.Config.Address[7:] - srv.unixSocket = true + if cfg.Address == "" { + cfg.Network = DefaultAddress } - return srv, nil + return Server{Config: cfg}, nil } func (srv *Server) Run(ctx context.Context, submitCh chan MessageNodeSubmission, errorCh chan error) error { srv.eventCh = submitCh srv.errorCh = errorCh - network := "unix" - if !srv.unixSocket { - network = "tcp" + mux := http.NewServeMux() + for _, handler := range srv.handlers() { + handler.register(mux) } - listener, err := net.Listen(network, srv.Address) + listener, err := net.Listen(srv.Network, srv.Address) if err != nil { return fmt.Errorf("listen: %v", err) } defer listener.Close() - exitErr := make(chan error, 1) - defer close(exitErr) + s := http.Server{Handler: mux} + exitCh := make(chan error, 1) + defer close(exitCh) go func() { - exitErr <- srv.Serve(listener) + exitCh <- s.Serve(listener) }() select { - case err := <-exitErr: + case err := <-exitCh: if err != nil && err != http.ErrServerClosed { return fmt.Errorf("serve: %v", err) } case <-ctx.Done(): tctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := srv.Shutdown(tctx); err != nil { + if err := s.Shutdown(tctx); err != nil { return fmt.Errorf("shutdown: %v", err) } } return nil } -func (srv *Server) getStatus(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Only HTTP GET method is allowed", http.StatusMethodNotAllowed) - return +func (srv *Server) handlers() []handler { + return []handler{ + handler{ + method: http.MethodGet, + endpoint: EndpointGetStatus, + callback: func(w http.ResponseWriter, r *http.Request) (int, string) { return srv.getStatus(w, r) }, + }, + handler{ + method: http.MethodPost, + endpoint: EndpointAddChain, + callback: func(w http.ResponseWriter, r *http.Request) (int, string) { return srv.addChain(w, r) }, + }, } - fmt.Fprintf(w, "OK\n") } -func (srv *Server) addChain(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Only HTTP POST method is allowed", http.StatusMethodNotAllowed) - return +func (srv *Server) getStatus(w http.ResponseWriter, r *http.Request) (int, string) { + return http.StatusOK, "OK" +} + +func (srv *Server) addChain(w http.ResponseWriter, r *http.Request) (int, string) { + node, err := srv.Nodes.authenticate(r) + if err != nil { + return http.StatusForbidden, "Invalid HTTP Basic Auth credentials" } - fmt.Fprintf(w, "TODO: HTTP POST /add-chain\n") + + b, err := io.ReadAll(r.Body) + if err != nil { + return http.StatusBadRequest, "Read HTTP POST body failed" + } + defer r.Body.Close() + + chain, err := x509util.ParseChain(b) + if err != nil { + return http.StatusBadRequest, "Malformed HTTP POST body" + } + if err := node.check(chain[0]); err != nil { + srv.errorCh <- ErrorUnauthorizedDomainName{ + PEMChain: b, + Node: node, + Err: err, + } + } else { + srv.eventCh <- MessageNodeSubmission{ + SerialNumber: chain[0].SerialNumber.String(), + NotBefore: chain[0].NotBefore, + DomainNames: chain[0].DNSNames, + PEMChain: b, + } + } + + return http.StatusOK, "OK" } |