package qna

import (
	"crypto/tls"
	"fmt"
	"net"
	"net/url"
	"os"
	"strings"
)

// Question names a single domain to look up.
type Question struct {
	Domain string // domain name to visit via HTTPS
}

// Answer is the outcome of visiting one domain: either the Onion-Location
// values that were found, or the first error that prevented the visit.
type Answer struct {
	Domain string // domain name of the visited HTTPS site
	HTTP   string // value set in the Onion-Location HTTP header (if any)
	HTML   string // value set in the Onion-Location HTML attribute (if any)

	CtxErr bool  // true if context deadline was exceeded
	ReqErr error // nil if HTTP GET request could be constructed
	DoErr  error // nil if HTTP GET request could be executed
}

// String renders the answer as a single line prefixed with the domain name.
// Errors take precedence over results and are checked in the order CtxErr,
// ReqErr, DoErr; only the first one that is set is reported.
func (a Answer) String() string {
	switch {
	case a.CtxErr:
		return fmt.Sprintf("%s: context deadline exceeded", a.Domain)
	case a.ReqErr != nil:
		return fmt.Sprintf("%s: %v", a.Domain, a.ReqErr)
	case a.DoErr != nil:
		return fmt.Sprintf("%s: %v", a.Domain, a.DoErr)
	case a.HTTP == "" && a.HTML == "":
		return fmt.Sprintf("%s: connected but found no Onion-Location", a.Domain)
	}
	return fmt.Sprintf("%s header=%s attribute=%s", a.Domain, a.HTTP, a.HTML)
}

// OnionLocation reports whether an Onion-Location value was recorded, either
// from the HTTP header or from the HTML attribute.
func (a *Answer) OnionLocation() bool {
	return a.HTTP != "" || a.HTML != ""
}

// Progress holds running counters over processed answers: successes and a
// breakdown of failures by category. Counters are updated by AddAnswer.
type Progress struct {
	NumOK    int // answers without any recorded error
	NumOnion int // successful answers that carried an Onion-Location

	NumMakeReqErr     int // request could not be constructed
	NumDNSNotFoundErr int // DNS error flagged as not found
	NumDNSTimeoutErr  int // DNS error flagged as timeout
	NumDNSOtherErr    int // any other DNS error
	NumConnTimeoutErr int // connection-level "i/o timeout"
	NumConnSyscallErr int // connection-level connect/read syscall error
	NumTLSCertErr     int // TLS certificate verification error
	NumTLSOtherErr    int // any other TLS-classified error
	Num3xxErr         int // redirect limit reached
	NumEOFErr         int // EOF or unexpected EOF
	NumDeadlineErr    int // context deadline exceeded
	NumOtherErr       int // errors that match no other category
}

// String formats all counters as a multi-line, human-readable report. The
// final line carries no trailing newline.
func (p Progress) String() string {
	var report strings.Builder
	fmt.Fprintf(&report, "  Processed: %d\n", p.NumProcess())
	fmt.Fprintf(&report, "    Success: %d (Onion-Location:%d)\n", p.NumOK, p.NumOnion)
	fmt.Fprintf(&report, "    Failure: %d (See breakdown below)\n", p.NumError())
	fmt.Fprintf(&report, "        Req: %d (Before sending request)\n", p.NumMakeReqErr)
	fmt.Fprintf(&report, "        DNS: %d (NotFound:%d Timeout:%d Other:%d)\n",
		p.NumDNSErr(), p.NumDNSNotFoundErr, p.NumDNSTimeoutErr, p.NumDNSOtherErr)
	fmt.Fprintf(&report, "        TCP: %d (Timeout:%d Syscall:%d)\n",
		p.NumConnErr(), p.NumConnTimeoutErr, p.NumConnSyscallErr)
	fmt.Fprintf(&report, "        TLS: %d (Cert:%d Other:%d)\n",
		p.NumTLSErr(), p.NumTLSCertErr, p.NumTLSOtherErr)
	fmt.Fprintf(&report, "        3xx: %d (Too many redirects)\n", p.Num3xxErr)
	fmt.Fprintf(&report, "        EOF: %d (Unclear meaning)\n", p.NumEOFErr)
	fmt.Fprintf(&report, "        CTX: %d (Deadline exceeded)\n", p.NumDeadlineErr)
	fmt.Fprintf(&report, "        ???: %d (Other errors)", p.NumOtherErr)
	return report.String()
}

// NumDNSErr returns the total number of DNS errors across all DNS categories.
func (p *Progress) NumDNSErr() int {
	total := p.NumDNSNotFoundErr
	total += p.NumDNSTimeoutErr
	total += p.NumDNSOtherErr
	return total
}

// NumConnErr returns the total number of connection errors (timeouts plus
// syscall errors).
func (p *Progress) NumConnErr() int {
	total := p.NumConnTimeoutErr
	total += p.NumConnSyscallErr
	return total
}

// NumTLSErr returns the total number of TLS errors (certificate plus other).
func (p *Progress) NumTLSErr() int {
	total := p.NumTLSCertErr
	total += p.NumTLSOtherErr
	return total
}

// NumError returns the total number of failures across every error category.
func (p *Progress) NumError() int {
	grouped := p.NumDNSErr() + p.NumConnErr() + p.NumTLSErr()
	single := p.NumMakeReqErr + p.Num3xxErr + p.NumEOFErr + p.NumDeadlineErr + p.NumOtherErr
	return grouped + single
}

// NumProcess returns the number of answers counted so far, successful or not.
func (p *Progress) NumProcess() int {
	return p.NumError() + p.NumOK
}

// AddAnswer classifies a and increments exactly one outcome counter (plus
// NumOnion for a successful answer that has an Onion-Location).
//
// Classification order: context deadline, request construction error, then
// the request execution error broken down as DNS, connection timeout,
// connection syscall, TLS, redirect limit, EOF, and finally "other". The
// first matching category wins.
func (p *Progress) AddAnswer(a Answer) {
	switch {
	case a.CtxErr:
		p.NumDeadlineErr++
		return
	case a.ReqErr != nil:
		p.NumMakeReqErr++
		return
	case a.DoErr == nil:
		// No error of any kind: count as success.
		p.NumOK++
		if a.OnionLocation() {
			p.NumOnion++
		}
		return
	}

	doErr := a.DoErr

	// DNS errors are checked first and get their own three-way breakdown.
	if dns := dnsError(doErr); dns != nil {
		switch {
		case dns.IsTimeout:
			p.NumDNSTimeoutErr++
		case dns.IsNotFound:
			p.NumDNSNotFoundErr++
		default:
			p.NumDNSOtherErr++
		}
		return
	}

	switch {
	case isConnTimeoutError(doErr):
		p.NumConnTimeoutErr++
	case isConnSyscallError(doErr):
		p.NumConnSyscallErr++
	case isTLSError(doErr):
		if isTLSCertError(doErr) {
			p.NumTLSCertErr++
		} else {
			p.NumTLSOtherErr++
		}
	case is3xxErr(doErr):
		p.Num3xxErr++
	case isEOFError(doErr):
		p.NumEOFErr++
	default:
		p.NumOtherErr++
	}
}

// dnsError returns the *net.DNSError found when err is exactly a *url.Error
// whose Err is a *net.OpError whose Err is a *net.DNSError. Any other shape
// yields nil.
func dnsError(err error) *net.DNSError {
	if wrapped, isURLErr := err.(*url.Error); isURLErr {
		if op, isOpErr := wrapped.Err.(*net.OpError); isOpErr {
			// A failed assertion leaves found nil, which is the "no match" result.
			found, _ := op.Err.(*net.DNSError)
			return found
		}
	}
	return nil
}

// isConnTimeoutError reports whether err is a *url.Error wrapping a
// *net.OpError whose inner error message is exactly "i/o timeout".
func isConnTimeoutError(err error) bool {
	if wrapped, isURLErr := err.(*url.Error); isURLErr {
		if op, isOpErr := wrapped.Err.(*net.OpError); isOpErr {
			return op.Err.Error() == "i/o timeout"
		}
	}
	return false
}

// isConnSyscallError reports whether err is a *url.Error wrapping a
// *net.OpError wrapping an *os.SyscallError raised by the "connect" or "read"
// syscall.
func isConnSyscallError(err error) bool {
	wrapped, isURLErr := err.(*url.Error)
	if !isURLErr {
		return false
	}
	if op, isOpErr := wrapped.Err.(*net.OpError); isOpErr {
		if sys, isSysErr := op.Err.(*os.SyscallError); isSysErr {
			switch sys.Syscall {
			case "connect", "read":
				return true
			}
		}
	}
	return false
}

// isTLSError reports whether err is a *url.Error whose inner message either
// contains "tls: " or is the fixed message produced when a server answers an
// HTTPS request with plain HTTP.
func isTLSError(err error) bool {
	wrapped, isURLErr := err.(*url.Error)
	if !isURLErr {
		return false
	}
	msg := wrapped.Err.Error()
	return msg == "http: server gave HTTP response to HTTPS client" ||
		strings.Contains(msg, "tls: ")
}

// isTLSCertError reports whether err is a *url.Error whose inner error is a
// *tls.CertificateVerificationError.
func isTLSCertError(err error) bool {
	if wrapped, isURLErr := err.(*url.Error); isURLErr {
		_, isCertErr := wrapped.Err.(*tls.CertificateVerificationError)
		return isCertErr
	}
	return false
}

// is3xxErr reports whether err is a *url.Error whose inner message is exactly
// "stopped after 10 redirects".
func is3xxErr(err error) bool {
	wrapped, isURLErr := err.(*url.Error)
	if !isURLErr {
		return false
	}
	return wrapped.Err.Error() == "stopped after 10 redirects"
}

// isEOFError reports whether err is a *url.Error whose inner message is
// exactly "EOF" or "unexpected EOF".
func isEOFError(err error) bool {
	wrapped, isURLErr := err.(*url.Error)
	if !isURLErr {
		return false
	}
	switch wrapped.Err.Error() {
	case "EOF", "unexpected EOF":
		return true
	default:
		return false
	}
}

// TrimWildcard strips a single leading "*." from san, if present, and returns
// san unchanged otherwise.
func TrimWildcard(san string) string {
	return strings.TrimPrefix(san, "*.")
}