aboutsummaryrefslogtreecommitdiff
path: root/pkg/server/server.go
blob: 06eb258b17f99dbb85839a7bfd4e6aef698d9d76 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package server

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"time"

	"rgdd.se/silent-ct/internal/x509util"
)

// HTTP endpoint names registered by Server.
const (
	// EndpointAddChain is the endpoint that nodes POST certificate chains to.
	EndpointAddChain = "add-chain"
	// EndpointGetStatus is the endpoint that answers GET requests with "OK".
	EndpointGetStatus = "get-status"
)

// Default settings.
const (
	// DefaultNetwork is the listener network New uses when Config.Network is empty.
	DefaultNetwork    = "tcp"
	// DefaultAddress is the default listen address (hostname:port).
	DefaultAddress    = "localhost:2009"
	// DefaultConfigFile is a hard-coded path into one user's home directory.
	DefaultConfigFile = "/home/rgdd/.config/silent-ct/config.json" // FIXME
)

// Config holds the settings for a Server.
//
// Network and Address are passed to net.Listen by Server.Run. Nodes is used
// by the add-chain handler to authenticate the submitting node and then to
// check the first certificate of the submitted chain against that node.
type Config struct {
	Network string // tcp or unix
	Address string // hostname[:port] or path to a unix socket
	Nodes   Nodes  // Which nodes are trusted to issue what certificates
}

// Server serves the get-status and add-chain HTTP endpoints; see Run.
type Server struct {
	Config

	// eventCh is set by Run to its submitCh argument. The add-chain handler
	// sends a MessageNodeSubmission on it for each chain that passes the
	// submitting node's check.
	eventCh chan MessageNodeSubmission
	// errorCh is set by Run to its errorCh argument. The add-chain handler
	// sends an ErrorUnauthorizedDomainName on it when the check fails.
	errorCh chan error
}

// New returns a Server configured by cfg.
//
// An empty cfg.Network is replaced by DefaultNetwork and an empty
// cfg.Address by DefaultAddress. The returned error is currently always nil.
func New(cfg Config) (Server, error) {
	if cfg.Network == "" {
		cfg.Network = DefaultNetwork
	}
	if cfg.Address == "" {
		cfg.Address = DefaultAddress
	}
	return Server{Config: cfg}, nil
}

// Run listens on srv.Network/srv.Address and serves srv's HTTP endpoints
// until ctx is canceled or the HTTP server fails.
//
// submitCh and errorCh are stored on srv so that request handlers can report
// node submissions and submission errors on them. When ctx is canceled, Run
// shuts the server down gracefully, waiting up to 10 seconds for in-flight
// requests. It returns nil on a clean shutdown. Otherwise it returns an error
// wrapping the listen, serve or shutdown failure.
func (srv *Server) Run(ctx context.Context, submitCh chan MessageNodeSubmission, errorCh chan error) error {
	srv.eventCh = submitCh
	srv.errorCh = errorCh
	mux := http.NewServeMux()
	for _, handler := range srv.handlers() {
		handler.register(mux)
	}

	listener, err := net.Listen(srv.Network, srv.Address)
	if err != nil {
		return fmt.Errorf("listen: %w", err)
	}
	// Serve closes the listener itself; this is a harmless safety net.
	defer listener.Close()

	// ReadHeaderTimeout bounds how long a client may take to send request
	// headers, so a stalled connection cannot tie up the server indefinitely.
	s := http.Server{
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Buffered so the Serve goroutine can always deliver its result and exit,
	// even after Run has returned via the ctx.Done() branch. The channel is
	// intentionally never closed: Serve returns (and the goroutine sends) as
	// soon as Shutdown starts, so a deferred close here would race with that
	// send and could panic with "send on closed channel".
	exitCh := make(chan error, 1)
	go func() {
		exitCh <- s.Serve(listener)
	}()

	select {
	case err := <-exitCh:
		if err != nil && err != http.ErrServerClosed {
			return fmt.Errorf("serve: %w", err)
		}
	case <-ctx.Done():
		tctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := s.Shutdown(tctx); err != nil {
			return fmt.Errorf("shutdown: %w", err)
		}
	}
	return nil
}

// handlers returns the endpoints served by srv, in registration order.
// Each entry pairs an endpoint name and the HTTP method it accepts with the
// method of srv that produces the response.
func (srv *Server) handlers() []handler {
	statusHandler := handler{
		method:   http.MethodGet,
		endpoint: EndpointGetStatus,
		callback: srv.getStatus,
	}
	chainHandler := handler{
		method:   http.MethodPost,
		endpoint: EndpointAddChain,
		callback: srv.addChain,
	}
	return []handler{statusHandler, chainHandler}
}

// getStatus ignores the request and always answers with http.StatusOK and
// the body "OK".
func (srv *Server) getStatus(_ http.ResponseWriter, _ *http.Request) (int, string) {
	const code, body = http.StatusOK, "OK"
	return code, body
}

// addChain handles a certificate-chain submission from a node.
//
// The request is authenticated against srv.Nodes (HTTP Basic Auth); failure
// yields 403. The POST body is read and parsed with x509util.ParseChain;
// a read error, parse error or empty chain yields 400. The first certificate
// of the chain (presumably the leaf — confirm against x509util) is then
// checked against the authenticated node:
//   - check fails: an ErrorUnauthorizedDomainName is sent on srv.errorCh;
//   - check passes: a MessageNodeSubmission is sent on srv.eventCh.
//
// In both of those cases the node receives 200 "OK". The channel send blocks
// the request until the receiver side accepts it.
func (srv *Server) addChain(w http.ResponseWriter, r *http.Request) (int, string) {
	node, err := srv.Nodes.authenticate(r)
	if err != nil {
		return http.StatusForbidden, "Invalid HTTP Basic Auth credentials"
	}

	// No explicit Close: the net/http server closes request bodies itself
	// once the handler returns.
	b, err := io.ReadAll(r.Body)
	if err != nil {
		return http.StatusBadRequest, "Read HTTP POST body failed"
	}

	chain, err := x509util.ParseChain(b)
	if err != nil || len(chain) == 0 {
		// The empty-chain guard keeps the chain[0] access below from panicking.
		return http.StatusBadRequest, "Malformed HTTP POST body"
	}
	leaf := chain[0]

	if err := node.check(leaf); err != nil {
		srv.errorCh <- ErrorUnauthorizedDomainName{
			PEMChain: b,
			Node:     node,
			Err:      err,
		}
	} else {
		srv.eventCh <- MessageNodeSubmission{
			SerialNumber: leaf.SerialNumber.String(),
			NotBefore:    leaf.NotBefore,
			DomainNames:  leaf.DNSNames,
			PEMChain:     b,
		}
	}

	return http.StatusOK, "OK"
}