// Copyright (c) 2021, Google Inc.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

package main

import (
	"errors"
	"flag"
	"fmt"
	"log"
	"net"
	"os"
	"path"
	"sort"
	"strings"
	"time"

	"golang.org/x/crypto/cryptobyte"
	"golang.org/x/net/dns/dnsmessage"
)

// httpsType is the DNS RRTYPE number used when querying for HTTPS records.
const httpsType = 65

// SvcParamKey codepoints understood by this tool, as assigned in
// draft-ietf-dnsop-svcb-https-06. The trailing comment on each constant is
// the key's name as it appears in the pretty-printed output.
const (
	httpsKeyMandatory     = 0 // mandatory
	httpsKeyALPN          = 1 // alpn
	httpsKeyNoDefaultALPN = 2 // no-default-alpn
	httpsKeyPort          = 3 // port
	httpsKeyIPV4Hint      = 4 // ipv4hint
	httpsKeyECH           = 5 // ech
	httpsKeyIPV6Hint      = 6 // ipv6hint
)

// Command-line flags.
var (
	name   = flag.String("name", "", "The name to look up in DNS. Required.")
	server = flag.String("server", "8.8.8.8:53", "The host:port of the DNS server to query over UDP.")
	outDir = flag.String("out-dir", "", "The directory where ECHConfigList values will be written. If unspecified, nothing is written; the values are still hex-encoded in the output on stdout.")
)

// httpsRecord holds the decoded RDATA of one HTTPS record, as produced by
// parseHTTPSRecord (draft-ietf-dnsop-svcb-https-06, Section 2.2).
type httpsRecord struct {
	priority   uint16 // SvcPriority.
	targetName string // TargetName, as dot-terminated labels.

	// SvcParams. A nil slice, false bool, or missing map entry means the
	// corresponding parameter was not present in the record.
	mandatory     []uint16          // Keys listed in "mandatory", strictly increasing.
	alpn          []string          // Protocol IDs listed in "alpn".
	noDefaultALPN bool              // Whether "no-default-alpn" was present.
	hasPort       bool              // Whether "port" was present (distinguishes an explicit port 0).
	port          uint16            // Value of "port"; meaningful only when hasPort is set.
	ipv4hint      []net.IP          // Addresses listed in "ipv4hint".
	ech           []byte            // Raw "ech" value (an ECHConfigList).
	ipv6hint      []net.IP          // Addresses listed in "ipv6hint".
	unknownParams map[uint16][]byte // Raw values of unrecognized SvcParams, keyed by SvcParamKey.
}

// String pretty-prints |h| as a multi-line string with bullet points. The
// first line shows the priority and TargetName. After that there is one bullet
// per SvcParam that is present. Unknown SvcParams are listed in increasing key
// order so the output is the same on every run.
func (h httpsRecord) String() string {
	var b strings.Builder
	fmt.Fprintf(&b, "HTTPS SvcPriority:%d TargetName:%q", h.priority, h.targetName)

	if len(h.mandatory) != 0 {
		fmt.Fprintf(&b, "\n  * mandatory: %v", h.mandatory)
	}
	if len(h.alpn) != 0 {
		fmt.Fprintf(&b, "\n  * alpn: %q", h.alpn)
	}
	if h.noDefaultALPN {
		b.WriteString("\n  * no-default-alpn")
	}
	if h.hasPort {
		fmt.Fprintf(&b, "\n  * port: %d", h.port)
	}
	if len(h.ipv4hint) != 0 {
		b.WriteString("\n  * ipv4hint:")
		for _, address := range h.ipv4hint {
			fmt.Fprintf(&b, "\n    - %s", address)
		}
	}
	if len(h.ech) != 0 {
		fmt.Fprintf(&b, "\n  * ech: %x", h.ech)
	}
	if len(h.ipv6hint) != 0 {
		b.WriteString("\n  * ipv6hint:")
		for _, address := range h.ipv6hint {
			fmt.Fprintf(&b, "\n    - %s", address)
		}
	}
	if len(h.unknownParams) != 0 {
		b.WriteString("\n  * unknown SvcParams:")
		// Map iteration order is randomized; sort the keys so repeated runs
		// print identical output (and match the order used on the wire).
		keys := make([]uint16, 0, len(h.unknownParams))
		for key := range h.unknownParams {
			keys = append(keys, key)
		}
		sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
		for _, key := range keys {
			fmt.Fprintf(&b, "\n    - %d: %x", key, h.unknownParams[key])
		}
	}
	return b.String()
}

// dnsQueryForHTTPS sends one recursive query for the HTTPS records of |domain|
// to the UDP DNS server named by the -server flag. It returns the raw RDATA of
// every HTTPS record found in the answer section. It returns an error in any of
// these cases:
//   - |domain| is empty;
//   - a network step fails, or no reply arrives before the timeout;
//   - the reply is malformed or truncated;
//   - the reply reports a non-success RCode;
//   - the reply answers a different question.
func dnsQueryForHTTPS(domain string) ([][]byte, error) {
	// Without a deadline, a single lost UDP datagram would block conn.Read
	// forever.
	const queryTimeout = 10 * time.Second
	// Plain DNS over UDP (no EDNS(0)) caps responses at 512 bytes; larger
	// answers are truncated by the server and flagged with the TC bit.
	const maxUDPResponseSize = 512

	if domain == "" {
		return nil, errors.New("domain name is empty")
	}
	// Domain name must be canonical (end in a dot) or message packing will
	// fail.
	if !strings.HasSuffix(domain, ".") {
		domain += "."
	}

	udpAddr, err := net.ResolveUDPAddr("udp", *server)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve DNS server address %q: %w", *server, err)
	}
	conn, err := net.DialUDP("udp", nil, udpAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}
	defer conn.Close()
	if err := conn.SetDeadline(time.Now().Add(queryTimeout)); err != nil {
		return nil, fmt.Errorf("failed to set the connection deadline: %w", err)
	}

	dnsName, err := dnsmessage.NewName(domain)
	if err != nil {
		return nil, fmt.Errorf("failed to create DNS name from %q: %w", domain, err)
	}
	question := dnsmessage.Question{
		Name:  dnsName,
		Type:  httpsType,
		Class: dnsmessage.ClassINET,
	}
	// Header.ID is left at zero; the response's ID is checked against that
	// below.
	msg := dnsmessage.Message{
		Header: dnsmessage.Header{
			RecursionDesired: true,
		},
		Questions: []dnsmessage.Question{question},
	}
	packedMsg, err := msg.Pack()
	if err != nil {
		return nil, fmt.Errorf("failed to pack msg: %w", err)
	}

	if _, err := conn.Write(packedMsg); err != nil {
		return nil, fmt.Errorf("failed to send the DNS query: %w", err)
	}

	response := make([]byte, maxUDPResponseSize)
	n, err := conn.Read(response)
	if err != nil {
		return nil, fmt.Errorf("failed to read the DNS response: %w", err)
	}
	response = response[:n]

	var p dnsmessage.Parser
	header, err := p.Start(response)
	if err != nil {
		return nil, fmt.Errorf("failed to parse the DNS response header: %w", err)
	}
	if !header.Response {
		return nil, errors.New("received DNS message is not a response")
	}
	if header.RCode != dnsmessage.RCodeSuccess {
		return nil, fmt.Errorf("response from DNS has non-success RCode: %s", header.RCode.String())
	}
	if header.ID != 0 {
		return nil, errors.New("received a DNS response with the wrong ID")
	}
	if header.Truncated {
		// The answer section of a truncated reply may be incomplete, and
		// retrying over TCP is not implemented.
		return nil, errors.New("DNS response was truncated")
	}
	if !header.RecursionAvailable {
		return nil, errors.New("server does not support recursion")
	}
	// Verify that this response answers the question that we asked in the
	// query. If the resolver encountered any CNAMEs, it's not guaranteed
	// that the response will contain a question with the same QNAME as our
	// query. However, RFC 8499 Section 4 indicates that in general use, the
	// response's QNAME should match the query, so we will make that
	// assumption.
	q, err := p.Question()
	if err != nil {
		return nil, fmt.Errorf("failed to parse the question section: %w", err)
	}
	if q != question {
		return nil, fmt.Errorf("response answers the wrong question: %v", q)
	}
	if extra, err := p.Question(); err != dnsmessage.ErrSectionDone {
		if err != nil {
			return nil, fmt.Errorf("failed to parse the question section: %w", err)
		}
		return nil, fmt.Errorf("response contains an unexpected question: %v", extra)
	}

	var httpsRecords [][]byte
	for {
		h, err := p.AnswerHeader()
		if err == dnsmessage.ErrSectionDone {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("failed to parse an answer header: %w", err)
		}

		// Every resource body is consumed via UnknownResource: non-HTTPS
		// answers are discarded, HTTPS RDATA is kept. This should continue
		// to work when golang.org/x/net/dns/dnsmessage adds support for
		// HTTPS records.
		r, err := p.UnknownResource()
		if err != nil {
			return nil, fmt.Errorf("failed to parse an answer: %w", err)
		}
		if h.Type == httpsType {
			httpsRecords = append(httpsRecords, r.Data)
		}
	}
	return httpsRecords, nil
}

// parseHTTPSRecord parses an HTTPS record (draft-ietf-dnsop-svcb-https-06,
// Section 2.2) from |raw|. If there are syntax errors, it returns an error.
func parseHTTPSRecord(raw []byte) (httpsRecord, error) {
	reader := cryptobyte.String(raw)

	var priority uint16
	if !reader.ReadUint16(&priority) {
		return httpsRecord{}, errors.New("failed to parse HTTPS record priority")
	}

	// Read the TargetName.
	var dottedDomain string
	for {
		var label cryptobyte.String
		if !reader.ReadUint8LengthPrefixed(&label) {
			return httpsRecord{}, errors.New("failed to parse HTTPS record TargetName")
		}
		if label.Empty() {
			break
		}
		dottedDomain += string(label) + "."
	}

	if priority == 0 {
		// TODO(dmcardle) Recursively follow AliasForm records.
		return httpsRecord{}, fmt.Errorf("received an AliasForm HTTPS record with TargetName=%q", dottedDomain)
	}

	record := httpsRecord{
		priority:      priority,
		targetName:    dottedDomain,
		unknownParams: make(map[uint16][]byte),
	}

	// Read the SvcParams.
	var lastSvcParamKey uint16
	for svcParamCount := 0; !reader.Empty(); svcParamCount++ {
		var svcParamKey uint16
		var svcParamValue cryptobyte.String
		if !reader.ReadUint16(&svcParamKey) ||
			!reader.ReadUint16LengthPrefixed(&svcParamValue) {
			return httpsRecord{}, errors.New("failed to parse HTTPS record SvcParam")
		}
		if svcParamCount > 0 && svcParamKey <= lastSvcParamKey {
			return httpsRecord{}, errors.New("malformed HTTPS record contains out-of-order SvcParamKey")
		}
		lastSvcParamKey = svcParamKey

		switch svcParamKey {
		case httpsKeyMandatory:
			if svcParamValue.Empty() {
				return httpsRecord{}, errors.New("malformed mandatory SvcParamValue")
			}
			var lastKey uint16
			for !svcParamValue.Empty() {
				// |httpsKeyMandatory| may not appear in the mandatory list.
				// |httpsKeyMandatory| is zero, so checking against the initial
				// value of |lastKey| handles ordering and the invalid code point.
				var key uint16
				if !svcParamValue.ReadUint16(&key) ||
					key <= lastKey {
					return httpsRecord{}, errors.New("malformed mandatory SvcParamValue")
				}
				lastKey = key
				record.mandatory = append(record.mandatory, key)
			}
		case httpsKeyALPN:
			if svcParamValue.Empty() {
				return httpsRecord{}, errors.New("malformed alpn SvcParamValue")
			}
			for !svcParamValue.Empty() {
				var alpn cryptobyte.String
				if !svcParamValue.ReadUint8LengthPrefixed(&alpn) || alpn.Empty() {
					return httpsRecord{}, errors.New("malformed alpn SvcParamValue")
				}
				record.alpn = append(record.alpn, string(alpn))
			}
		case httpsKeyNoDefaultALPN:
			if !svcParamValue.Empty() {
				return httpsRecord{}, errors.New("malformed no-default-alpn SvcParamValue")
			}
			record.noDefaultALPN = true
		case httpsKeyPort:
			if !svcParamValue.ReadUint16(&record.port) ||
				!svcParamValue.Empty() {
				return httpsRecord{}, errors.New("malformed port SvcParamValue")
			}
			record.hasPort = true
		case httpsKeyIPV4Hint:
			if svcParamValue.Empty() {
				return httpsRecord{}, errors.New("malformed ipv4hint SvcParamValue")
			}
			for !svcParamValue.Empty() {
				var address []byte
				if !svcParamValue.ReadBytes(&address, 4) {
					return httpsRecord{}, errors.New("malformed ipv4hint SvcParamValue")
				}
				record.ipv4hint = append(record.ipv4hint, address)
			}
		case httpsKeyECH:
			if svcParamValue.Empty() {
				return httpsRecord{}, errors.New("malformed ech SvcParamValue")
			}
			record.ech = svcParamValue
		case httpsKeyIPV6Hint:
			if svcParamValue.Empty() {
				return httpsRecord{}, errors.New("malformed ipv6hint SvcParamValue")
			}
			for !svcParamValue.Empty() {
				var address []byte
				if !svcParamValue.ReadBytes(&address, 16) {
					return httpsRecord{}, errors.New("malformed ipv6hint SvcParamValue")
				}
				record.ipv6hint = append(record.ipv6hint, address)
			}
		default:
			record.unknownParams[svcParamKey] = svcParamValue
		}
	}
	return record, nil
}

// main queries DNS for the HTTPS records of -name and prints each parsed
// record. When -out-dir is set, it also writes each record's ech SvcParam
// value (an ECHConfigList) to its own numbered file in that directory. The
// program exits with status 1 on any failure, or when no HTTPS records are
// found.
func main() {
	flag.Parse()
	log.SetFlags(log.Lshortfile | log.LstdFlags)

	if len(*name) == 0 {
		flag.Usage()
		os.Exit(1)
	}

	httpsRecords, err := dnsQueryForHTTPS(*name)
	if err != nil {
		log.Fatalf("Error querying %q: %s\n", *name, err)
	}
	if len(httpsRecords) == 0 {
		log.Fatalln("No HTTPS records found in DNS response.")
	}

	if len(*outDir) > 0 {
		// MkdirAll creates missing parent directories and succeeds if the
		// directory already exists. Unlike Mkdir plus IsExist, it still
		// reports an error when the path exists but is not a directory.
		if err := os.MkdirAll(*outDir, 0755); err != nil {
			log.Fatalf("Failed to create out directory %q: %s\n", *outDir, err)
		}
	}

	var echConfigListCount int
	for _, raw := range httpsRecords {
		record, err := parseHTTPSRecord(raw)
		if err != nil {
			log.Fatalf("Failed to parse HTTPS record: %s", err)
		}
		fmt.Printf("%s\n", record)
		if len(*outDir) == 0 {
			continue
		}
		if len(record.ech) == 0 {
			// Skip writing: an empty file is not a valid ECHConfigList.
			fmt.Println("HTTPS record has no ech SvcParam; nothing written.")
			continue
		}

		outFile := path.Join(*outDir, fmt.Sprintf("ech-config-list-%d", echConfigListCount))
		if err := os.WriteFile(outFile, record.ech, 0644); err != nil {
			log.Fatalf("Failed to write file: %s\n", err)
		}
		fmt.Printf("Wrote ECHConfigList to %q\n", outFile)
		echConfigListCount++
	}
}
