package main

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/hex"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"os"
	"strings"
	"time"

	"github.com/blinsay/homer/version"
	"golang.org/x/net/dns/dnsmessage"
	"golang.org/x/net/http2"
)

var (
	// dns-over-https options
	post              = flag.Bool("post", false, "use a POST request to make a query. slightly smaller, but less cache friendly")
	resolver          = flag.String("resolver", "", "the url of a dns-over-https resolver to use")
	bootstrapResolver = flag.String("bootstrap-resolver", "", "the ip address of a dns resolver to use to bootstrap the address of the dns-over-https resolver")

	// query options
	qtypeArg  = flag.String("type", "A", "the `type` of record to query for.")
	qclassArg = flag.String("class", "IN", "the `class` of record to query for.")

	// output options
	printVersion = flag.Bool("version", false, "print the version and exit")
	short        = flag.Bool("short", false, "provide a terse answer, like dig's +short")
	dumpHTTP     = flag.Bool("dump-http", false, "dumps http request/response headers")
)

var (
	stringToClass = map[string]dnsmessage.Class{
		"IN":  dnsmessage.ClassINET,
		"CS":  dnsmessage.ClassCSNET,
		"CH":  dnsmessage.ClassCHAOS,
		"HS":  dnsmessage.ClassHESIOD,
		"ANY": dnsmessage.ClassANY,
	}

	stringToType = map[string]dnsmessage.Type{
		"A":     dnsmessage.TypeA,
		"NS":    dnsmessage.TypeNS,
		"CNAME": dnsmessage.TypeCNAME,
		"SOA":   dnsmessage.TypeSOA,
		"PTR":   dnsmessage.TypePTR,
		"MX":    dnsmessage.TypeMX,
		"TXT":   dnsmessage.TypeTXT,
		"AAAA":  dnsmessage.TypeAAAA,
		"SRV":   dnsmessage.TypeSRV,
		"OPT":   dnsmessage.TypeOPT,
	}
)

const (
	schemeHTTPS           = "https"
	headerAccept          = "Accept"
	headerUserAgent       = "User-Agent"
	headerContentType     = "Content-Type"
	contentTypeDNSMessage = "application/dns-message"
	queryParameterDNS     = "dns"
)

var (
	userAgent = fmt.Sprintf("homer/%s", version.VERSION)
)

func init() {
	log.SetFlags(0)

	flag.Usage = func() {
		fmt.Fprintf(flag.CommandLine.Output(), "usage: %s [query flags] -resolver [url] [names...]\n\n", os.Args[0])
		fmt.Fprintf(flag.CommandLine.Output(), "%s makes a dns-over-https query.\n\n", os.Args[0])
		fmt.Fprintf(flag.CommandLine.Output(), "available options:\n")
		flag.PrintDefaults()
	}
}

func main() {
	flag.Parse()

	if *printVersion {
		log.Printf("%s (%s)", version.VERSION, version.GITCOMMIT)
		return
	}

	if *resolver == "" {
		log.Fatalf("--resolver is required")
	}

	qclass, ok := stringToClass[strings.ToUpper(*qclassArg)]
	if !ok {
		// FIXME: print supported classes
		log.Fatalf("unrecognized query class: %s", *qclassArg)
	}
	qtype, ok := stringToType[strings.ToUpper(*qtypeArg)]
	if !ok {
		// FIXME: print supported types
		log.Fatalf("unrecognized query type: %s", *qtypeArg)
	}

	// configure an http client to use for dns-over-https.
	//
	// to avoid using DNS to build the resolver we're using because we don't trust
	// DNS, the transport is configured with a net.Resolver that will cowardly
	// refuse to do anything.
	//
	// if a user specifies a custom bootstrap dns server by IP address that they
	// trust to locate their DoH resolver, replace the cowardly net.Resolver with
	// one that always connects to the bootstrap IP.
	dialer := net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
		DualStack: true,
		Resolver: &net.Resolver{
			PreferGo: true,
			Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
				log.Fatal("bootstrapping DNS is disabled unless --bootstrap-resolver is specified", address)
				panic("never dial")
			},
		},
	}

	if *bootstrapResolver != "" {
		if parsed := net.ParseIP(*bootstrapResolver); parsed == nil {
			log.Fatalf("boostrap-resolver must be an IP address")
		}

		bootstrapResolverAddress := net.JoinHostPort(*bootstrapResolver, "53")
		dialer.Resolver = &net.Resolver{
			PreferGo: true,
			Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
				var d net.Dialer
				return d.DialContext(ctx, network, bootstrapResolverAddress)
			},
		}
	}

	transport := &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialer.DialContext,
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	if err := http2.ConfigureTransport(transport); err != nil {
		panic(fmt.Sprintf("unable to setup http2: %s", err))
	}
	client := http.Client{Transport: transport}

	// per the RFC, application/dns-message requests should use a message id of
	// zero for cache friendliness.
	//
	// TODO(benl): a cli flag to disable recursive queries
	question := dnsmessage.Message{
		Header: dnsmessage.Header{
			ID:               0,
			RecursionDesired: true,
		},
	}

	for _, name := range flag.Args() {
		canonicalName := name
		if !strings.HasSuffix(canonicalName, ".") {
			canonicalName = canonicalName + "."
		}
		question.Questions = append(question.Questions, dnsmessage.Question{
			Class: qclass,
			Type:  qtype,
			Name:  dnsmessage.MustNewName(canonicalName),
		})
	}

	var err error
	var request *http.Request
	if *post {
		request, err = postRequest(*resolver, question)
	} else {
		request, err = getRequest(*resolver, question)
	}
	if err != nil {
		log.Fatalln("error building request:", err)
	}
	request.Header.Set(headerUserAgent, userAgent)

	if *dumpHTTP {
		reqDump, err := dumpRequest(request)
		if err != nil {
			panic(err)
		}
		log.Println(reqDump)
	}

	response, err := client.Do(request)
	if err != nil {
		log.Fatal(err)
	}

	if *dumpHTTP {
		respDump, err := dumpResponse(response)
		if err != nil {
			panic(err)
		}
		log.Println(respDump)
	}

	if response.StatusCode == 200 {
		msg, err := dnsResponse(response)
		if err != nil {
			log.Fatalf("error parsing response from server: %s", err)
		}

		for _, answer := range msg.Answers {
			if *short {
				log.Println(formatBody(answer.Body))
			} else {
				log.Println(formatHeader(answer.Header), formatBody(answer.Body))
			}
		}
	}
}

// pack a dns Message into a POST request. according to the RFC, POST requests
// should include Accept: application/dns-message, and should set their
// content-type appropriately
func postRequest(resolverURL string, msg dnsmessage.Message) (*http.Request, error) {
	bs, err := msg.Pack()
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequest(http.MethodPost, resolverURL, bytes.NewBuffer(bs))
	if err != nil {
		return nil, err
	}
	if req.URL.Scheme != schemeHTTPS {
		return nil, fmt.Errorf("dns-over-https requires https: url scheme is %q", req.URL.Scheme)
	}

	req.Header.Set(headerContentType, contentTypeDNSMessage)
	req.Header.Set(headerAccept, contentTypeDNSMessage)

	return req, nil
}

// pack a dns Message into a GET request. according to the RFC, GET requests
// should include Accept: application/dns-message, and must base64 their request
// into a "dns" query parameter
func getRequest(resolverURL string, msg dnsmessage.Message) (*http.Request, error) {
	bs, err := msg.Pack()
	if err != nil {
		return nil, err
	}
	encodedMessage := base64.RawURLEncoding.EncodeToString(bs)

	req, err := http.NewRequest(http.MethodGet, resolverURL, nil)
	if err != nil {
		return nil, err
	}
	if req.URL.Scheme != schemeHTTPS {
		return nil, fmt.Errorf("dns-over-https requires https: url scheme is %q", req.URL.Scheme)
	}

	req.Header.Set(headerAccept, contentTypeDNSMessage)

	query := req.URL.Query()
	query.Set(queryParameterDNS, encodedMessage)
	req.URL.RawQuery = query.Encode()

	return req, nil
}

// parse a DNS response out of an http response body. returns an error if the
// content-type isn't application/dns-message or if there is no body
// to the response
func dnsResponse(response *http.Response) (*dnsmessage.Message, error) {
	if response.Body == nil {
		return nil, fmt.Errorf("nil body")
	}

	if contentType := response.Header.Get(headerContentType); strings.ToLower(contentType) != contentTypeDNSMessage {
		return nil, fmt.Errorf("unrecognized content type: %q", contentType)
	}

	defer response.Body.Close()
	bs, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return nil, err
	}

	if len(bs) == 0 {
		return nil, fmt.Errorf("body was empty")
	}

	var msg dnsmessage.Message
	if err := msg.Unpack(bs); err != nil {
		return nil, err
	}
	return &msg, nil
}

// format a dns message header for output. looks a little bit like dig, with
// less whitepsace
func formatHeader(h dnsmessage.ResourceHeader) string {
	typ := strings.TrimPrefix(h.Type.String(), "Type")
	return fmt.Sprintf("%s %d %s", h.Name, h.TTL, typ)
}

// format a dns resource body for output. mimics dig's output where appropriate
func formatBody(b dnsmessage.ResourceBody) string {
	switch rr := b.(type) {
	case *dnsmessage.AResource:
		return net.IP(rr.A[:]).String()
	case *dnsmessage.NSResource:
		return rr.NS.String()
	case *dnsmessage.CNAMEResource:
		return rr.CNAME.String()
	case *dnsmessage.SOAResource:
		return fmt.Sprintf("%s %s %d %d %d %d %d", rr.NS, rr.MBox, rr.Serial, rr.Refresh, rr.Retry, rr.Expire, rr.MinTTL)
	case *dnsmessage.PTRResource:
		return rr.PTR.String()
	case *dnsmessage.MXResource:
		return fmt.Sprintf("%d %s", rr.Pref, rr.MX)
	case *dnsmessage.TXTResource:
		return fmt.Sprintf("%q", strings.Join(rr.TXT, " "))
	case *dnsmessage.AAAAResource:
		return net.IP(rr.AAAA[:]).String()
	}

	return "(unknown)"
}

// dump an http request as a string, including the body if present. buffers
// the entire body into memory.handles replacing the body with a copy.
func dumpRequest(request *http.Request) (string, error) {
	requestBytes, err := httputil.DumpRequestOut(request, false)
	if err != nil {
		return "", err
	}
	headers := string(requestBytes)

	if request.Body == http.NoBody || request.Body == nil {
		return headers, nil
	}

	oldBody := request.Body
	defer oldBody.Close()

	bodyBytes, err := ioutil.ReadAll(request.Body)
	if err != nil {
		return "", err
	}

	request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
	request.ContentLength = int64(len(bodyBytes))
	request.GetBody = func() (io.ReadCloser, error) {
		return ioutil.NopCloser(bytes.NewReader(bodyBytes)), nil
	}

	if len(bodyBytes) > 0 {
		return strings.Join([]string{headers, hex.Dump(bodyBytes)}, "\n"), nil
	}
	return headers, nil
}

// dump an http response as a string, including the body if present. buffers
// the entire body into memory. handles replacing the body with a copy.
func dumpResponse(response *http.Response) (string, error) {
	bs, err := httputil.DumpResponse(response, false)
	if err != nil {
		return "", err
	}
	headers := string(bs)

	oldBody := response.Body
	defer oldBody.Close()

	bodyBytes, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return "", err
	}
	response.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))

	if len(bodyBytes) > 0 {
		return strings.Join([]string{headers, hex.Dump(bodyBytes)}, "\n"), nil
	}
	return headers, nil
}