diff options
Diffstat (limited to 'pkg/server/server.go')
-rw-r--r-- | pkg/server/server.go | 117 |
1 files changed, 72 insertions, 45 deletions
diff --git a/pkg/server/server.go b/pkg/server/server.go index 2d10c4b..06eb258 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -3,104 +3,131 @@ package server import ( "context" "fmt" + "io" "net" "net/http" - "strings" "time" + + "rgdd.se/silent-ct/internal/x509util" ) const ( - EndpointAddChain = "/add-chain" - EndpointGetStatus = "/get-status" + EndpointAddChain = "add-chain" + EndpointGetStatus = "get-status" + DefaultNetwork = "tcp" DefaultAddress = "localhost:2009" DefaultConfigFile = "/home/rgdd/.config/silent-ct/config.json" // FIXME ) -type Nodes []Node - -type Node struct { - Name string `json:"name"` - Secret string `json:"secret"` - Domains []string `json:"issues"` -} - type Config struct { - Address string // hostname[:port] or unix:///path/to/file.sock + Network string // tcp or unix + Address string // hostname[:port] or path to a unix socket Nodes Nodes // Which nodes are trusted to issue what certificates } type Server struct { Config - http.Server - unixSocket bool // true if listening with a unix socket - eventCh chan MessageNodeSubmission - errorCh chan error + eventCh chan MessageNodeSubmission + errorCh chan error } func New(cfg Config) (Server, error) { - mux := http.NewServeMux() - srv := Server{Config: cfg, Server: http.Server{Handler: mux}} - mux.HandleFunc(EndpointAddChain, func(w http.ResponseWriter, r *http.Request) { srv.addChain(w, r) }) - mux.HandleFunc(EndpointGetStatus, func(w http.ResponseWriter, r *http.Request) { srv.getStatus(w, r) }) - if len(srv.Address) == 0 { - srv.Config.Address = DefaultAddress + if cfg.Network == "" { + cfg.Network = DefaultNetwork } - if strings.HasPrefix(srv.Config.Address, "unix://") { - srv.Config.Address = srv.Config.Address[7:] - srv.unixSocket = true + if cfg.Address == "" { + cfg.Network = DefaultAddress } - return srv, nil + return Server{Config: cfg}, nil } func (srv *Server) Run(ctx context.Context, submitCh chan MessageNodeSubmission, errorCh chan error) error { srv.eventCh = submitCh srv.errorCh = errorCh - network := "unix" - if !srv.unixSocket { - network = "tcp" + mux := http.NewServeMux() + for _, handler := range srv.handlers() { + handler.register(mux) } - listener, err := net.Listen(network, srv.Address) + listener, err := net.Listen(srv.Network, srv.Address) if err != nil { return fmt.Errorf("listen: %v", err) } defer listener.Close() - exitErr := make(chan error, 1) - defer close(exitErr) + s := http.Server{Handler: mux} + exitCh := make(chan error, 1) + defer close(exitCh) go func() { - exitErr <- srv.Serve(listener) + exitCh <- s.Serve(listener) }() select { - case err := <-exitErr: + case err := <-exitCh: if err != nil && err != http.ErrServerClosed { return fmt.Errorf("serve: %v", err) } case <-ctx.Done(): tctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := srv.Shutdown(tctx); err != nil { + if err := s.Shutdown(tctx); err != nil { return fmt.Errorf("shutdown: %v", err) } } return nil } -func (srv *Server) getStatus(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Only HTTP GET method is allowed", http.StatusMethodNotAllowed) - return +func (srv *Server) handlers() []handler { + return []handler{ + handler{ + method: http.MethodGet, + endpoint: EndpointGetStatus, + callback: func(w http.ResponseWriter, r *http.Request) (int, string) { return srv.getStatus(w, r) }, + }, + handler{ + method: http.MethodPost, + endpoint: EndpointAddChain, + callback: func(w http.ResponseWriter, r *http.Request) (int, string) { return srv.addChain(w, r) }, + }, } - fmt.Fprintf(w, "OK\n") } -func (srv *Server) addChain(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Only HTTP POST method is allowed", http.StatusMethodNotAllowed) - return +func (srv *Server) getStatus(w http.ResponseWriter, r *http.Request) (int, string) { + return http.StatusOK, "OK" +} + +func (srv *Server) addChain(w http.ResponseWriter, r *http.Request) (int, string) { + node, err := srv.Nodes.authenticate(r) + if err != nil { + return http.StatusForbidden, "Invalid HTTP Basic Auth credentials" } - fmt.Fprintf(w, "TODO: HTTP POST /add-chain\n") + + b, err := io.ReadAll(r.Body) + if err != nil { + return http.StatusBadRequest, "Read HTTP POST body failed" + } + defer r.Body.Close() + + chain, err := x509util.ParseChain(b) + if err != nil { + return http.StatusBadRequest, "Malformed HTTP POST body" + } + if err := node.check(chain[0]); err != nil { + srv.errorCh <- ErrorUnauthorizedDomainName{ + PEMChain: b, + Node: node, + Err: err, + } + } else { + srv.eventCh <- MessageNodeSubmission{ + SerialNumber: chain[0].SerialNumber.String(), + NotBefore: chain[0].NotBefore, + DomainNames: chain[0].DNSNames, + PEMChain: b, + } + } + + return http.StatusOK, "OK" } |