package main

import (
	"container/heap"
	"context"
	"crypto/sha256"
	"encoding/json"
	"fmt"
	logger "log"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"git.cs.kau.se/rasmoste/ct-sans/internal/chunk"
	"git.cs.kau.se/rasmoste/ct-sans/internal/merkle"
	"git.cs.kau.se/rasmoste/ct-sans/internal/utils"
	ct "github.com/google/certificate-transparency-go"
	"github.com/google/certificate-transparency-go/client"
	"github.com/google/certificate-transparency-go/jsonclient"
	"github.com/google/certificate-transparency-go/scanner"
	"gitlab.torproject.org/rgdd/ct/pkg/metadata"
)

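// collect fetches not-yet-downloaded entries from every CT log listed in the
// metadata file, verifies the downloaded leaves against each log's signed
// tree head, and appends the observed SANs to per-log files on disk.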
func collect(opts options) error {
	b, err := os.ReadFile(fmt.Sprintf("%s/%s", opts.Directory, opts.metadataFile))
	if err != nil {
		return err
	}
	var md metadata.Metadata
	if err := json.Unmarshal(b, &md); err != nil {
		return err
	}

	var await sync.WaitGroup
	defer await.Wait()
	ctx, cancel := context.WithCancel(context.Background())
	await.Add(1)
	go func() {
		defer await.Done()
		handleSignals(ctx, cancel)

		//
		// Sometimes a worker in scanner.Fetcher isn't shut down
		// properly despite the parent context (including getRanges)
		// being done.  The below is an ugly hack to avoid hanging.
		//
		wait := time.Second * 5 // TODO: 15s
		logger.Printf("INFO: about to exit, please wait %v...\n", wait)
		time.Sleep(wait)
		os.Exit(0)
	}()

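	// Progress reports from all per-log goroutines are sent on a single
	// metrics channel and consumed by one dedicated goroutine.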
	metricsCh := make(chan metrics)
	defer close(metricsCh)
	await.Add(1)
	go func() {
		defer await.Done()
		handleMetrics(ctx, metricsCh, utils.Logs(md))
	}()

	defer cancel()
	var wg sync.WaitGroup
	defer wg.Wait()
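	// Spawn one fetch-and-sequence pipeline per log.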
	for _, log := range utils.Logs(md) {
		wg.Add(1)
		go func(log metadata.Log) {
			defer wg.Done()

			chunks := make(chan *chunk.Chunk)
			defer close(chunks)

			id, _ := log.Key.ID()
			th, err := readState(opts, id[:])
			if err != nil {
				logger.Printf("ERROR: %s: %v\n", *log.Description, err)
				cancel()
				return
			}
			sth, err := readSnapshot(opts, id[:])
			if err != nil {
				logger.Printf("ERROR: %s: %v\n", *log.Description, err)
				cancel()
				return
			}
			cli, err := client.New(string(log.URL),
				&http.Client{Transport: &http.Transport{IdleConnTimeout: 120 * time.Second}},
				jsonclient.Options{UserAgent: opts.HTTPAgent},
			)
			if err != nil {
				logger.Printf("ERROR: %s: %v\n", *log.Description, err)
				cancel()
				return
			}
			fetcher := scanner.NewFetcher(cli, &scanner.FetcherOptions{
				BatchSize:     int(opts.BatchSize),
				StartIndex:    th.TreeSize,
				EndIndex:      int64(sth.TreeSize),
				ParallelFetch: int(opts.WorkersPerLog),
			})
			if uint64(th.TreeSize) == sth.TreeSize {
				logger.Printf("INFO: %s: up-to-date with tree size %d", *log.Description, th.TreeSize)
				metricsCh <- metrics{Description: *log.Description, End: th.TreeSize, Done: true}
				return
			}

			//
			// Callback that puts downloaded certificates into a
			// chunk that a single sequencer can verify and persist
			//
			callback := func(eb scanner.EntryBatch) {
				leafHashes := [][sha256.Size]byte{}
				for i := 0; i < len(eb.Entries); i++ {
					leafHashes = append(leafHashes, merkle.HashLeafNode(eb.Entries[i].LeafInput))
				}
				sans, errs := utils.SANsFromLeafEntries(eb.Start, eb.Entries)
				for _, err := range errs {
					logger.Printf("NOTICE: %s: %v", *log.Description, err)
				}
				chunks <- &chunk.Chunk{Start: eb.Start, LeafHashes: leafHashes, SANs: sans}
			}

			//
			// Sequencer that waits for sufficiently large chunks
			// before verifying inclusion proofs and persisting an
			// intermediate tree head (size and root hash) as well
			// as the SANs that were observed up until that point.
			//
			cctx, fetchDone := context.WithCancel(ctx)
			defer fetchDone()
			wg.Add(1)
			go func() {
				defer wg.Done()

				h := &chunk.ChunkHeap{}
				heap.Init(h)
				curr := th.TreeSize
				for {
					select {
					case <-cctx.Done():
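						// Final flush before exit: persist whatever has been
						// sequenced so far, regardless of size (minSequence 0).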
						if h.Sequence(curr) {
							c := h.TPop()
							if _, err := persistChunk(metricsCh, cli, opts, id[:], *log.Description, 0, c); err != nil {
								logger.Printf("ERROR: %s: %v\n", *log.Description, err)
							}
						}
						return
					case c, ok := <-chunks:
						if ok {
							h.TPush(c)
						}
						if !h.Sequence(curr) {
							continue
						}

						c = h.TPop()
						putBack, err := persistChunk(metricsCh, cli, opts, id[:], *log.Description, int64(opts.PersistSize), c)
						if err != nil {
							cancel()
							logger.Printf("ERROR: %s: %v\n", *log.Description, err)
							return
						}
						if putBack {
							h.TPush(c)
							continue
						}

						curr += int64(len(c.LeafHashes))
					}
				}
			}()

			logger.Printf("INFO: %s: working from tree size %d to %d", *log.Description, th.TreeSize, sth.TreeSize)
			if err := fetcher.Run(ctx, callback); err != nil {
				logger.Printf("ERROR: %s: %v\n", *log.Description, err)
				cancel()
				return
			}
			if ctx.Err() == nil {
				logger.Printf("INFO: %s: completed fetch at tree size %d", *log.Description, sth.TreeSize)
			}

			for len(chunks) > 0 {
				select {
				case <-ctx.Done():
					return // some goroutine cancelled due to an error
				case <-time.After(1 * time.Second):
					logger.Printf("DEBUG: %s: waiting for chunks to be consumed\n", *log.Description)
				}
			}
		}(log)
	}

	logger.Printf("INFO: collect is up-and-running, ctrl+C to exit\n")
	time.Sleep(3 * time.Second) // give the per-log goroutines time to spawn
	return nil
}

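// treeHead is the persisted local view of a log: how many entries have been
// downloaded and verified so far, and the root hash of that tree prefix.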
type treeHead struct {
	TreeSize int64             `json:"tree_size"`
	RootHash [sha256.Size]byte `json:"root_hash"`
}

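// readState loads the persisted tree head for a log.  If no state has been
// persisted yet, the empty tree (size zero) is returned so that collection
// starts from the first entry.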
func readState(opts options, logID []byte) (treeHead, error) {
	b, err := os.ReadFile(fmt.Sprintf("%s/%x/%s", opts.logDirectory, logID, opts.stateFile))
	if errors.Is(err, os.ErrNotExist) {
		return treeHead{0, sha256.Sum256(nil)}, nil // no state yet: start from an empty tree
	}
	if err != nil {
		return treeHead{}, err
	}
	var th treeHead
	if err := json.Unmarshal(b, &th); err != nil {
		return treeHead{}, err
	}
	return th, nil
}

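// readSnapshot loads a log's signed tree head (STH) snapshot from disk.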
func readSnapshot(opts options, logID []byte) (ct.SignedTreeHead, error) {
	b, err := os.ReadFile(fmt.Sprintf("%s/%x/%s", opts.logDirectory, logID, opts.sthFile))
	if err != nil {
		return ct.SignedTreeHead{}, err
	}
	var sth ct.SignedTreeHead
	if err := json.Unmarshal(b, &sth); err != nil {
		return ct.SignedTreeHead{}, err
	}
	return sth, nil
}

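// persistChunk verifies a sequenced chunk against the log and appends its
// SANs and the resulting tree state to disk.  A true return value means the
// chunk was not persisted and should be pushed back on the heap: either it is
// smaller than minSequence, or a proof could not be fetched right now.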
func persistChunk(metricsCh chan metrics, cli *client.LogClient, opts options, logID []byte, logDesc string, minSequence int64, c *chunk.Chunk) (bool, error) {
	chunkSize := int64(len(c.LeafHashes))
	if chunkSize == 0 {
		return false, nil // nothing to persist
	}
	if chunkSize < minSequence {
		return true, nil // wait for more leaves
	}

	// Read persisted tree state from disk
	oldTH, err := readState(opts, logID)
	if err != nil {
		return false, err
	}
	if oldTH.TreeSize != c.Start {
		return false, fmt.Errorf("disk state says next index is %d, in-memory says %d", oldTH.TreeSize, c.Start)
	}
	// Read signed tree head from disk
	sth, err := readSnapshot(opts, logID)
	if err != nil {
		return false, err
	}
	// Derive next intermediate tree state from a compact range
	//
	// Sanity checks: expected indices/sizes and consistent root hashes.
	// This is redundant, but could, e.g., catch bugs with our storage.
	//
	// Independent context because we need to run inclusion and consistency
	// queries after the parent context is cancelled to persist on shutdown
	//
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	newTH := treeHead{TreeSize: c.Start + chunkSize}
	p, err := cli.GetProofByHash(ctx, c.LeafHashes[0][:], uint64(newTH.TreeSize))
	if err != nil {
		return true, nil // try again later
	}
	if p.LeafIndex != c.Start {
		return false, fmt.Errorf("log says proof for entry %d is at index %d", c.Start, p.LeafIndex)
	}
	if newTH.RootHash, err = merkle.TreeHeadFromRangeProof(c.LeafHashes, uint64(c.Start), utils.Proof(p.AuditPath)); err != nil {
		return false, err
	}
	var hashes [][]byte
	if oldTH.TreeSize > 0 {
		if hashes, err = cli.GetSTHConsistency(ctx, uint64(oldTH.TreeSize), uint64(newTH.TreeSize)); err != nil {
			return true, nil // try again later
		}
	}
	if err := merkle.VerifyConsistency(uint64(oldTH.TreeSize), uint64(newTH.TreeSize), oldTH.RootHash, newTH.RootHash, utils.Proof(hashes)); err != nil {
		return false, fmt.Errorf("%d %x is inconsistent with on-disk state: %v", newTH.TreeSize, newTH.RootHash, err)
	}

	// Check that new tree state is consistent with the signed tree head
	if hashes, err = cli.GetSTHConsistency(ctx, uint64(newTH.TreeSize), sth.TreeSize); err != nil {
		return true, nil // try again later
	}
	if err := merkle.VerifyConsistency(uint64(newTH.TreeSize), sth.TreeSize, newTH.RootHash, sth.SHA256RootHash, utils.Proof(hashes)); err != nil {
		return false, fmt.Errorf("%d %x is inconsistent with signed tree head: %v", newTH.TreeSize, newTH.RootHash, err)
	}

	// Persist SANs to disk
	fp, err := os.OpenFile(fmt.Sprintf("%s/%x/%s", opts.logDirectory, logID, opts.sansFile), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		return false, err
	}
	defer fp.Close()
	if _, err := fp.WriteString(strings.Join(c.SANs, "\n") + "\n"); err != nil {
		return false, err
	}
	if err := fp.Sync(); err != nil {
		return false, err
	}

	// Persist new tree state to disk
	b, err := json.Marshal(&newTH)
	if err != nil {
		return false, err
	}
	if err := os.WriteFile(fmt.Sprintf("%s/%x/%s", opts.logDirectory, logID, opts.stateFile), b, 0644); err != nil {
		return false, err
	}

	// Output metrics
	metricsCh <- metrics{
		Description: logDesc,
		NumEntries:  newTH.TreeSize - oldTH.TreeSize,
		Timestamp:   time.Now().Unix(),
		Start:       newTH.TreeSize,
		End:         int64(sth.TreeSize),
		Done:        uint64(newTH.TreeSize) == sth.TreeSize,
	}

	logger.Printf("DEBUG: %s: persisted [%d, %d]\n", logDesc, oldTH.TreeSize, newTH.TreeSize)
	return false, nil
}