package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	logger "log"
	"os"
	"os/exec"
	"strconv"
	"time"

	ct "github.com/google/certificate-transparency-go"
	"gitlab.torproject.org/rgdd/ct/pkg/metadata"
)

// assemble builds the uncompressed dataset directory for the current snapshot.
//
// It reads the metadata file from opts.Directory and, for every log returned
// by logs(md), checks that the locally stored tree head (readState) matches
// the snapshot's signed tree head (readSnapshot) in both tree size and root
// hash.  Any mismatch aborts assembly with an error naming the log.
//
// On success the following files are written to the directory
// <opts.archiveDirectory>/<YYYY-MM-DD>-ct-sans, where the date is today's:
//
//   - opts.sansFile: all per-log SAN files merged and de-duplicated by dedup
//   - opts.noticeFile: all existing per-log notice files concatenated
//   - README.md: the text produced by makeREADME
//   - opts.metadataFile and opts.metadataSignatureFile: copied from opts.Directory
//   - opts.sthsFile: the collected signed tree heads as indented JSON
//
// The first error encountered is returned; files already written are left in
// place.
func assemble(opts options) error {
	now := time.Now()
	metadataBytes, err := os.ReadFile(fmt.Sprintf("%s/%s", opts.Directory, opts.metadataFile))
	if err != nil {
		return err
	}
	var md metadata.Metadata
	if err := json.Unmarshal(metadataBytes, &md); err != nil {
		return err
	}

	// Verify every log's local state against its snapshot and remember which
	// per-log input files belong in the archive.
	var sanFiles []string
	var noticeFiles []string
	var sths []ct.SignedTreeHead
	for _, log := range logs(md) {
		// NOTE(review): the second return value of ID is discarded; confirm it
		// cannot signal a failure for the logs returned by logs(md).
		id, _ := log.Key.ID()
		th, err := readState(opts, id[:])
		if err != nil {
			return err
		}
		sth, err := readSnapshot(opts, id[:])
		if err != nil {
			return err
		}
		if uint64(th.TreeSize) != sth.TreeSize {
			return fmt.Errorf("%s: at tree size %d, want %d", *log.Description, th.TreeSize, sth.TreeSize)
		}
		if th.RootHash != sth.SHA256RootHash {
			return fmt.Errorf("%s: root hash mismatch", *log.Description)
		}

		sanFiles = append(sanFiles, fmt.Sprintf("%s/%x/%s", opts.logDirectory, id[:], opts.sansFile))
		noticeFiles = append(noticeFiles, fmt.Sprintf("%s/%x/%s", opts.logDirectory, id[:], opts.noticeFile))
		sths = append(sths, sth)
	}

	logger.Printf("INFO: merging and de-duplicating %d input files with GNU sort", len(sanFiles))
	archiveDir := fmt.Sprintf("%s/%s-ct-sans", opts.archiveDirectory, now.Format("2006-01-02"))
	if err := os.MkdirAll(archiveDir, os.ModePerm); err != nil {
		return err
	}
	sansFile := fmt.Sprintf("%s/%s", archiveDir, opts.sansFile)
	if err := dedup(opts, sansFile, sanFiles); err != nil {
		return err
	}
	size, err := fileSize(sansFile)
	if err != nil {
		return err
	}
	logger.Printf("INFO: created %s (%s)", sansFile, size)

	logger.Printf("INFO: adding notice file")
	var notes []byte
	for _, noticeFile := range noticeFiles {
		b, err := os.ReadFile(noticeFile)
		if errors.Is(err, os.ErrNotExist) {
			continue // no notes, great
		} else if err != nil {
			return err
		}

		notes = append(notes, b...)
	}
	if err := os.WriteFile(fmt.Sprintf("%s/%s", archiveDir, opts.noticeFile), notes, 0644); err != nil {
		return err
	}
	// The number of notes is the number of newlines in the merged notice file.
	numNotes := bytes.Count(notes, []byte("\n"))

	logger.Printf("INFO: adding README")
	readme, err := makeREADME(opts, sths, numNotes, now)
	if err != nil {
		return err
	}
	if err := os.WriteFile(fmt.Sprintf("%s/README.md", archiveDir), []byte(readme), 0644); err != nil {
		return err
	}

	logger.Printf("INFO: adding signed metadata file")
	sigBytes, err := os.ReadFile(fmt.Sprintf("%s/%s", opts.Directory, opts.metadataSignatureFile))
	if err != nil {
		return err
	}
	if err := os.WriteFile(fmt.Sprintf("%s/%s", archiveDir, opts.metadataFile), metadataBytes, 0644); err != nil {
		return err
	}
	if err := os.WriteFile(fmt.Sprintf("%s/%s", archiveDir, opts.metadataSignatureFile), sigBytes, 0644); err != nil {
		return err
	}

	logger.Printf("INFO: adding signed tree heads")
	sthsBytes, err := json.MarshalIndent(sths, "", "\t")
	if err != nil {
		return err
	}
	if err := os.WriteFile(fmt.Sprintf("%s/%s", archiveDir, opts.sthsFile), sthsBytes, 0644); err != nil {
		return err
	}

	logger.Printf("INFO: uncompressed dataset available in %s", archiveDir)
	return nil
}

// dedup merges inputFiles into outputFile by running the external "sort"
// command with the flags -V, -u and -o, so the output is sorted and contains
// each line once.  The sort buffer size (in gigabytes), temporary directory
// and degree of parallelism are taken from opts.
//
// If sort cannot be started or exits unsuccessfully, the returned error wraps
// the underlying exec error and includes whatever sort wrote to stderr.
func dedup(opts options, outputFile string, inputFiles []string) error {
	cmd := exec.Command("sort", append([]string{
		"-Vuo", outputFile,
		"--buffer-size", fmt.Sprintf("%dG", opts.BufferSize),
		"--temporary-directory", fmt.Sprintf("%s", opts.TempDir),
		"--parallel", fmt.Sprintf("%d", opts.Parallel),
	}, inputFiles...)...)
	// exec.Command refuses to run a binary that was resolved relative to the
	// current directory; clearing ErrDot explicitly permits that here.
	if errors.Is(cmd.Err, exec.ErrDot) {
		cmd.Err = nil
	}

	var stderr bytes.Buffer
	cmd.Stderr = &stderr
	// All output goes to outputFile via -o, so stdout need not be captured.
	if err := cmd.Run(); err != nil {
		// Keep the underlying error: stderr is empty when sort never started
		// (for example because the binary is missing).
		if msg := bytes.TrimSpace(stderr.Bytes()); len(msg) > 0 {
			return fmt.Errorf("sort: %w: %s", err, msg)
		}
		return fmt.Errorf("sort: %w", err)
	}
	return nil
}

// makeREADME renders the README.md text for an assembled dataset.
//
// The text lists the dataset's file names (taken from opts), the assembly time
// now, the snapshot time read via readSnapshotTime, the total number of
// certificates across sths, the number of logs (len(sths)), and numNotes as
// the number of certificates whose SANs could not be parsed.  Both times are
// formatted with time.UnixDate.  An error is returned only if the snapshot
// time cannot be read.
func makeREADME(opts options, sths []ct.SignedTreeHead, numNotes int, now time.Time) (string, error) {
	const template = `# ct-sans dataset

Dataset assembled at %s.  Contents:

  - README.md
  - %s
  - %s
  - %s
  - %s
  - %s

The signed [metadata file][] and tree heads were downloaded at
%s.

[metadata file]: https://groups.google.com/a/chromium.org/g/ct-policy/c/IdbrdAcDQto

In total, %d certificates were downloaded from %d CT logs;
%d certificates contained SANs that could not be parsed.
For more information about these errors, see %s.

The SANs data set is sorted and de-duplicated, one SAN per line.
`

	downloadedAt, err := readSnapshotTime(opts)
	if err != nil {
		return "", err
	}

	assembled := now.Format(time.UnixDate)
	downloaded := downloadedAt.Format(time.UnixDate)
	totalCerts := numCertificates(sths)
	numLogs := len(sths)

	readme := fmt.Sprintf(template,
		assembled,
		opts.metadataFile,
		opts.metadataSignatureFile,
		opts.sthsFile,
		opts.noticeFile,
		opts.sansFile,
		downloaded,
		totalCerts,
		numLogs,
		numNotes,
		opts.noticeFile,
	)
	return readme, nil
}

// fileSize returns the size of the named file as a human-readable string with
// one decimal place: in MiB when the file is smaller than 1 GiB, otherwise in
// GiB.  The error from os.Stat is returned unchanged if the file cannot be
// inspected.
func fileSize(name string) (string, error) {
	const (
		mib = 1024 * 1024
		gib = 1024 * mib
	)

	info, err := os.Stat(name)
	if err != nil {
		return "", err
	}

	n := info.Size()
	if n < gib {
		return fmt.Sprintf("%.1f MiB", float64(n)/float64(mib)), nil
	}
	return fmt.Sprintf("%.1f GiB", float64(n)/float64(gib)), nil
}

// noticeReport returns the number of lines in the file at path, formatted as
// a base-10 string.
//
// An error is returned if the file cannot be opened or if reading it fails
// part-way, including when a line exceeds bufio.Scanner's token limit; a
// partial count is never reported as a success.
func noticeReport(path string) (string, error) {
	fp, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer fp.Close()

	num := 0
	scanner := bufio.NewScanner(fp)
	for scanner.Scan() {
		num++
	}
	// Scan returns false both at EOF and on failure; only Err tells them apart.
	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("counting lines in %s: %w", path, err)
	}
	return strconv.Itoa(num), nil
}

// numCertificates returns the sum of the TreeSize fields of all given signed
// tree heads.  The result is zero for an empty or nil slice.
func numCertificates(sths []ct.SignedTreeHead) uint64 {
	var total uint64
	for i := range sths {
		total += sths[i].TreeSize
	}
	return total
}

// readSnapshotTime reads the snapshot timestamp stored in the file
// opts.metadataTimestampFile inside opts.Directory.
//
// The file must contain a base-10 integer giving the time in seconds since
// the Unix epoch.  Leading and trailing whitespace, such as a final newline,
// is ignored.  An error is returned if the file cannot be read or does not
// hold a valid integer.
func readSnapshotTime(opts options) (time.Time, error) {
	path := fmt.Sprintf("%s/%s", opts.Directory, opts.metadataTimestampFile)
	b, err := os.ReadFile(path)
	if err != nil {
		return time.Time{}, err
	}
	// Trim first: ParseInt rejects any surrounding whitespace.
	secs, err := strconv.ParseInt(string(bytes.TrimSpace(b)), 10, 64)
	if err != nil {
		return time.Time{}, fmt.Errorf("%s: invalid timestamp: %w", path, err)
	}
	return time.Unix(secs, 0), nil
}