// BSD 2-Clause License
//
// Copyright (c) 2022, the ct authors
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
//  1. Redistributions of source code must retain the above copyright notice, this
//     list of conditions and the following disclaimer.
//
//  2. Redistributions in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// From:
// https://gitlab.torproject.org/rgdd/ct/-/tree/main/pkg/merkle
package merkle

import (
	"crypto/sha256"
	"fmt"
)

// HashEmptyTree computes the Merkle Tree Hash of an empty tree, which is the
// SHA-256 digest of the empty string.  See RFC 6962, §2.1:
//
//	MTH({}) = SHA-256()
func HashEmptyTree() [sha256.Size]byte {
	return sha256.Sum256(nil)
}

// HashLeafNode computes the Merkle Tree Hash of a single leaf's data.  The
// 0x00 prefix distinguishes leaf hashes from interior-node hashes, which are
// prefixed with 0x01 (see HashInteriorNode).  See RFC 6962, §2.1:
//
//	MTH({d(0)}) = SHA-256(0x00 || d(0))
func HashLeafNode(data []byte) (hash [sha256.Size]byte) {
	h := sha256.New()
	// hash.Hash.Write never returns an error, so its results are ignored.
	h.Write([]byte{0x00})
	h.Write(data)
	copy(hash[:], h.Sum(nil))
	return hash
}

// HashInteriorNode computes the hash of an interior node from the hashes of
// its left and right children.  The 0x01 prefix distinguishes interior-node
// hashes from leaf hashes (prefixed with 0x00).  See RFC 6962, §2.1:
//
//	MTH(D[n]) = SHA-256(0x01 || MTH(D[0:k]) || MTH(D[k:n]))
func HashInteriorNode(left, right [sha256.Size]byte) (hash [sha256.Size]byte) {
	h := sha256.New()
	// hash.Hash.Write never returns an error, so its results are ignored.
	h.Write([]byte{0x01})
	h.Write(left[:])
	h.Write(right[:])
	copy(hash[:], h.Sum(nil))
	return hash
}

// inclusion implements the inclusion-proof verification algorithm specified
// in RFC 9162, Section 2.1.3.2, for the leaf hash leaf at position index in a
// tree with size leaves.  If every step succeeds, the recomputed root is
// passed to confirmRoot and its result is returned unchanged.
//
// In addition, the caller is allowed to confirm right-node subtree hashes:
// confirmHash is called with every proof hash that is combined as a right
// child.  A non-nil error from confirmHash aborts verification and is wrapped
// with %w, so callers can still match it with errors.Is / errors.As.
func inclusion(leaf [sha256.Size]byte, index, size uint64, proof [][sha256.Size]byte,
	confirmRoot func([sha256.Size]byte) error, confirmHash func([sha256.Size]byte) error) error {
	// Reject empty trees up front; otherwise size-1 in the step 1 error
	// message would wrap around to the maximum uint64 value.
	if size == 0 {
		return fmt.Errorf("tree size must be larger than zero")
	}

	// Step 1
	if index >= size {
		return fmt.Errorf("leaf index must be in [%d, %d]", 0, size-1)
	}

	// Step 2: fn starts at the leaf index and sn at the last leaf index; both
	// are right-shifted together as the walk moves up the tree.
	fn := index
	sn := size - 1

	// Step 3
	r := leaf

	// Step 4
	for i, p := range proof {
		// Step 4a: the root level was reached but proof hashes remain.
		if sn == 0 {
			return fmt.Errorf("reached tree head with %d remaining proof hash(es)", len(proof[i:]))
		}

		// Step 4b: p goes on the left when fn is odd or fn == sn, otherwise
		// on the right.
		if isLSB(fn) || fn == sn {
			// Step 4b, i
			r = HashInteriorNode(p, r)

			// Step 4b, ii: here fn is even and equal to sn, and sn != 0 was
			// checked in step 4a, so fn > 0 and at least one shift is needed
			// before testing "LSB(fn) is set or fn is 0".
			if !isLSB(fn) {
				for {
					fn = rshift(fn)
					sn = rshift(sn)

					if isLSB(fn) || fn == 0 {
						break
					}
				}
			}
		} else {
			// Step 4b, i
			r = HashInteriorNode(r, p)

			// Extension: allow the caller to confirm right-node subtree hashes
			if err := confirmHash(p); err != nil {
				return fmt.Errorf("subtree index %d: %w", fn, err)
			}
		}

		// Step 4c
		fn = rshift(fn)
		sn = rshift(sn)
	}

	// Step 5: all levels must have been consumed before comparing the root.
	if sn != 0 {
		return fmt.Errorf("stopped at subtree with index %d due to missing proof hashes", fn)
	}
	return confirmRoot(r)
}

// consistency implements the consistency-proof verification algorithm
// specified in RFC 9162, §2.1.4.2.  It returns nil if proof shows that the
// tree of size oldSize with root oldRoot is consistent with the tree of size
// newSize with root newRoot, and a descriptive error otherwise.
//
// NOTE(review): callers must ensure 0 < oldSize < newSize (VerifyConsistency
// checks this before calling); with oldSize == 0, oldSize-1 in step 3 would
// wrap around.
func consistency(oldSize, newSize uint64, oldRoot, newRoot [sha256.Size]byte, proof [][sha256.Size]byte) error {
	// Step 1: fail on an empty proof.
	if len(proof) == 0 {
		return fmt.Errorf("need at least one proof hash")
	}

	// Step 2: when oldSize is a power of two, the old root is not part of the
	// proof and is prepended here.  append builds a fresh slice from the
	// literal, so the caller's proof slice is not modified.
	if isPOW2(oldSize) {
		proof = append([][sha256.Size]byte{oldRoot}, proof...)
	}

	// Step 3: fn starts at the last leaf index of the old tree, sn at the
	// last leaf index of the new tree.
	fn := oldSize - 1
	sn := newSize - 1

	// Step 4: right-shift fn and sn together while LSB(fn) is set.
	for isLSB(fn) {
		fn = rshift(fn)
		sn = rshift(sn)
	}

	// Step 5: seed both running hashes (old root fr, new root sr) with the
	// first proof hash.
	fr := proof[0]
	sr := proof[0]

	// Step 6: i indexes proof[1:], so proof[i+1:] are the hashes still left.
	for i, c := range proof[1:] {
		// Step 6a
		if sn == 0 {
			return fmt.Errorf("reached tree head with %d remaining proof hash(es)", len(proof[i+1:]))
		}

		// Step 6b: c is combined on the left into both fr and sr, or on the
		// right into sr only.
		if isLSB(fn) || fn == sn {
			// Step 6b, i
			fr = HashInteriorNode(c, fr)
			// Step 6b, ii
			sr = HashInteriorNode(c, sr)
			// Step 6b, iii: fn is even and equals sn here, and sn != 0 was
			// checked in step 6a, so at least one shift is always needed.
			if !isLSB(fn) {
				for {
					fn = rshift(fn)
					sn = rshift(sn)

					if isLSB(fn) || fn == 0 {
						break
					}
				}
			}
		} else {
			// Step 6b, i
			sr = HashInteriorNode(sr, c)
		}

		// Step 6c
		fn = rshift(fn)
		sn = rshift(sn)
	}

	// Step 7: sn must be 0 and both recomputed roots must match.
	if sn != 0 {
		return fmt.Errorf("stopped at subtree with index %d due to missing proof hashes", fn)
	}
	if fr != oldRoot {
		return fmt.Errorf("recomputed old tree head %x is not equal to reference tree head %x", fr[:], oldRoot[:])
	}
	if sr != newRoot {
		return fmt.Errorf("recomputed new tree head %x is not equal to reference tree head %x", sr[:], newRoot[:])
	}
	return nil
}

// VerifyInclusion verifies that a leaf's data is committed at a given index in
// a reference tree of the given size and root hash, using an RFC 9162
// inclusion proof.  It returns nil if the proof is valid and an error
// describing the first failed check otherwise.  A size of zero is rejected.
func VerifyInclusion(data []byte, index, size uint64, root [sha256.Size]byte, proof [][sha256.Size]byte) error {
	if size == 0 {
		return fmt.Errorf("tree size must be larger than zero")
	}

	// No compact range extension: every right-node subtree hash is accepted.
	confirmHash := func([sha256.Size]byte) error { return nil }
	confirmRoot := func(r [sha256.Size]byte) error {
		if r != root {
			return fmt.Errorf("recomputed tree head %x is not equal to reference tree head %x", r[:], root[:])
		}
		return nil
	}
	return inclusion(HashLeafNode(data), index, size, proof, confirmRoot, confirmHash)
}

// VerifyConsistency verifies that an old tree is consistent with a new tree,
// i.e., that the tree of size oldSize with root oldRoot is a prefix of the
// tree of size newSize with root newRoot, using an RFC 9162 consistency proof.
// It returns nil on success and a descriptive error otherwise.
//
// Both tree heads are sanity-checked first.  A tree of size zero must have
// the empty-tree hash as its root, and the proof must then be empty.  A
// non-empty tree must not have the empty-tree hash as its root.  After that:
//   - an empty old tree is accepted without further checks;
//   - trees of equal size are consistent only if their roots are equal and
//     the proof is empty;
//   - an old tree larger than the new tree is rejected.
func VerifyConsistency(oldSize, newSize uint64, oldRoot, newRoot [sha256.Size]byte, proof [][sha256.Size]byte) error {
	// checkTree validates one (size, root) pair, including the rule that a
	// size-zero tree comes with an empty proof.
	checkTree := func(size uint64, root [sha256.Size]byte) error {
		if size == 0 {
			if root != HashEmptyTree() {
				return fmt.Errorf("non-empty tree head %x for size zero", root[:])
			}
			if len(proof) != 0 {
				return fmt.Errorf("non-empty proof with %d hashes for size zero", len(proof))
			}
		} else if root == HashEmptyTree() {
			return fmt.Errorf("empty tree head %x for tree size %d", root[:], size)
		}
		return nil
	}

	if err := checkTree(oldSize, oldRoot); err != nil {
		return fmt.Errorf("old: %w", err)
	}
	if err := checkTree(newSize, newRoot); err != nil {
		return fmt.Errorf("new: %w", err)
	}
	if oldSize == 0 {
		return nil // the empty tree is consistent with any well-formed tree
	}

	if oldSize == newSize {
		if oldRoot != newRoot {
			return fmt.Errorf("different tree heads %x and %x with equal tree size %d", oldRoot[:], newRoot[:], oldSize)
		}
		if len(proof) != 0 {
			return fmt.Errorf("non-empty proof with %d hashes for equal tree size %d", len(proof), oldSize)
		}
		return nil
	}
	if oldSize > newSize {
		return fmt.Errorf("old tree size %d must be smaller than or equal to the new tree size %d", oldSize, newSize)
	}

	return consistency(oldSize, newSize, oldRoot, newRoot, proof)
}

// isLSB reports whether the lowest bit of num is 1, i.e., whether num is odd.
func isLSB(num uint64) bool {
	return num%2 == 1
}

// isPOW2 returns true if num is a power of two (1, 2, 4, 8, ...).  Zero is not
// a power of two, so isPOW2(0) is false.
func isPOW2(num uint64) bool {
	// A power of two has exactly one bit set, so clearing its lowest set bit
	// (num & (num-1)) yields zero.  Zero must be excluded explicitly because
	// 0 & (0-1) is also zero.
	return num != 0 && num&(num-1) == 0
}

// rshift returns num shifted right by one bit, i.e., num divided by two and
// rounded down.  The proof-verification loops use it to move fn and sn up one
// tree level.
func rshift(num uint64) uint64 {
	return num / 2
}