| // Copyright (c) 2021, Google Inc. |
| // |
| // Permission to use, copy, modify, and/or distribute this software for any |
| // purpose with or without fee is hereby granted, provided that the above |
| // copyright notice and this permission notice appear in all copies. |
| // |
| // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES |
| // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF |
| // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY |
| // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES |
| // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION |
| // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN |
| // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
| |
| package main |
| |
import (
	"errors"
	"flag"
	"fmt"
	"log"
	"net"
	"os"
	"path"
	"sort"
	"strings"
	"time"

	"golang.org/x/crypto/cryptobyte"
	"golang.org/x/net/dns/dnsmessage"
)
| |
// httpsType is the RRTYPE for HTTPS records.
const httpsType = 65

// SvcParamKey codepoints defined in draft-ietf-dnsop-svcb-https-06. The
// trailing comment on each key is the name used when pretty-printing it.
const (
	httpsKeyMandatory     = 0 // mandatory
	httpsKeyALPN          = 1 // alpn
	httpsKeyNoDefaultALPN = 2 // no-default-alpn
	httpsKeyPort          = 3 // port
	httpsKeyIPV4Hint      = 4 // ipv4hint
	httpsKeyECH           = 5 // ech
	httpsKeyIPV6Hint      = 6 // ipv6hint
)
| |
// Command-line flags.
var (
	// name is the domain whose HTTPS records are looked up.
	name = flag.String("name", "", "The name to look up in DNS. Required.")
	// server is the "host:port" address of the DNS server to query over UDP.
	server = flag.String("server", "8.8.8.8:53", "Colon-separated host and UDP port that defines the DNS server to query.")
	// outDir, when non-empty, is the directory that receives ECHConfigList files.
	outDir = flag.String("out-dir", "", "The directory where ECHConfigList values will be written. If unspecified, bytes are hexdumped to stdout.")
)
| |
// httpsRecord is the parsed form of a single HTTPS resource record.
type httpsRecord struct {
	priority   uint16 // SvcPriority
	targetName string // TargetName in dotted form

	// SvcParams:
	mandatory     []uint16          // keys listed in the mandatory SvcParam
	alpn          []string          // protocol IDs from the alpn SvcParam
	noDefaultALPN bool              // true if no-default-alpn was present
	hasPort       bool              // true if the port SvcParam was present
	port          uint16            // only meaningful when hasPort is true
	ipv4hint      []net.IP          // addresses from the ipv4hint SvcParam
	ech           []byte            // raw value of the ech SvcParam
	ipv6hint      []net.IP          // addresses from the ipv6hint SvcParam
	unknownParams map[uint16][]byte // raw values of unrecognized SvcParams, by key
}

// String pretty-prints |h| as a multi-line string with bullet points. Only the
// SvcParams that are present are listed. Unknown SvcParams are listed in
// ascending key order so that the output is deterministic.
func (h httpsRecord) String() string {
	var b strings.Builder
	fmt.Fprintf(&b, "HTTPS SvcPriority:%d TargetName:%q", h.priority, h.targetName)

	if len(h.mandatory) != 0 {
		fmt.Fprintf(&b, "\n * mandatory: %v", h.mandatory)
	}
	if len(h.alpn) != 0 {
		fmt.Fprintf(&b, "\n * alpn: %q", h.alpn)
	}
	if h.noDefaultALPN {
		fmt.Fprint(&b, "\n * no-default-alpn")
	}
	if h.hasPort {
		fmt.Fprintf(&b, "\n * port: %d", h.port)
	}
	if len(h.ipv4hint) != 0 {
		fmt.Fprint(&b, "\n * ipv4hint:")
		for _, address := range h.ipv4hint {
			fmt.Fprintf(&b, "\n - %s", address)
		}
	}
	if len(h.ech) != 0 {
		fmt.Fprintf(&b, "\n * ech: %x", h.ech)
	}
	if len(h.ipv6hint) != 0 {
		fmt.Fprint(&b, "\n * ipv6hint:")
		for _, address := range h.ipv6hint {
			fmt.Fprintf(&b, "\n - %s", address)
		}
	}
	if len(h.unknownParams) != 0 {
		// Map iteration order is randomized, so sort the keys first.
		keys := make([]uint16, 0, len(h.unknownParams))
		for key := range h.unknownParams {
			keys = append(keys, key)
		}
		sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })

		fmt.Fprint(&b, "\n * unknown SvcParams:")
		for _, key := range keys {
			fmt.Fprintf(&b, "\n - %d: %x", key, h.unknownParams[key])
		}
	}
	return b.String()
}
| |
| // dnsQueryForHTTPS queries the DNS server over UDP for any HTTPS records |
| // associated with |domain|. It scans the response's answers and returns all the |
| // HTTPS records it finds. It returns an error if any connection steps fail. |
| func dnsQueryForHTTPS(domain string) ([][]byte, error) { |
| udpAddr, err := net.ResolveUDPAddr("udp", *server) |
| if err != nil { |
| return nil, err |
| } |
| conn, err := net.DialUDP("udp", nil, udpAddr) |
| if err != nil { |
| return nil, fmt.Errorf("failed to dial: %s", err) |
| } |
| defer conn.Close() |
| |
| // Domain name must be canonical or message packing will fail. |
| if domain[len(domain)-1] != '.' { |
| domain += "." |
| } |
| dnsName, err := dnsmessage.NewName(domain) |
| if err != nil { |
| return nil, fmt.Errorf("failed to create DNS name from %q: %s", domain, err) |
| } |
| question := dnsmessage.Question{ |
| Name: dnsName, |
| Type: httpsType, |
| Class: dnsmessage.ClassINET, |
| } |
| msg := dnsmessage.Message{ |
| Header: dnsmessage.Header{ |
| RecursionDesired: true, |
| }, |
| Questions: []dnsmessage.Question{question}, |
| } |
| packedMsg, err := msg.Pack() |
| if err != nil { |
| return nil, fmt.Errorf("failed to pack msg: %s", err) |
| } |
| |
| if _, err = conn.Write(packedMsg); err != nil { |
| return nil, fmt.Errorf("failed to send the DNS query: %s", err) |
| } |
| |
| for { |
| response := make([]byte, 512) |
| n, err := conn.Read(response) |
| if err != nil { |
| return nil, fmt.Errorf("failed to read the DNS response: %s", err) |
| } |
| response = response[:n] |
| |
| var p dnsmessage.Parser |
| header, err := p.Start(response) |
| if err != nil { |
| return nil, err |
| } |
| if !header.Response { |
| return nil, errors.New("received DNS message is not a response") |
| } |
| if header.RCode != dnsmessage.RCodeSuccess { |
| return nil, fmt.Errorf("response from DNS has non-success RCode: %s", header.RCode.String()) |
| } |
| if header.ID != 0 { |
| return nil, errors.New("received a DNS response with the wrong ID") |
| } |
| if !header.RecursionAvailable { |
| return nil, errors.New("server does not support recursion") |
| } |
| // Verify that this response answers the question that we asked in the |
| // query. If the resolver encountered any CNAMEs, it's not guaranteed |
| // that the response will contain a question with the same QNAME as our |
| // query. However, RFC 8499 Section 4 indicates that in general use, the |
| // response's QNAME should match the query, so we will make that |
| // assumption. |
| q, err := p.Question() |
| if err != nil { |
| return nil, err |
| } |
| if q != question { |
| return nil, fmt.Errorf("response answers the wrong question: %v", q) |
| } |
| if q, err = p.Question(); err != dnsmessage.ErrSectionDone { |
| return nil, fmt.Errorf("response contains an unexpected question: %v", q) |
| } |
| |
| var httpsRecords [][]byte |
| for { |
| h, err := p.AnswerHeader() |
| if err == dnsmessage.ErrSectionDone { |
| break |
| } |
| if err != nil { |
| return nil, err |
| } |
| |
| switch h.Type { |
| case httpsType: |
| // This should continue to work when golang.org/x/net/dns/dnsmessage |
| // adds support for HTTPS records. |
| r, err := p.UnknownResource() |
| if err != nil { |
| return nil, err |
| } |
| httpsRecords = append(httpsRecords, r.Data) |
| default: |
| if _, err := p.UnknownResource(); err != nil { |
| return nil, err |
| } |
| } |
| } |
| return httpsRecords, nil |
| } |
| } |
| |
| // parseHTTPSRecord parses an HTTPS record (draft-ietf-dnsop-svcb-https-06, |
| // Section 2.2) from |raw|. If there are syntax errors, it returns an error. |
| func parseHTTPSRecord(raw []byte) (httpsRecord, error) { |
| reader := cryptobyte.String(raw) |
| |
| var priority uint16 |
| if !reader.ReadUint16(&priority) { |
| return httpsRecord{}, errors.New("failed to parse HTTPS record priority") |
| } |
| |
| // Read the TargetName. |
| var dottedDomain string |
| for { |
| var label cryptobyte.String |
| if !reader.ReadUint8LengthPrefixed(&label) { |
| return httpsRecord{}, errors.New("failed to parse HTTPS record TargetName") |
| } |
| if label.Empty() { |
| break |
| } |
| dottedDomain += string(label) + "." |
| } |
| |
| if priority == 0 { |
| // TODO(dmcardle) Recursively follow AliasForm records. |
| return httpsRecord{}, fmt.Errorf("received an AliasForm HTTPS record with TargetName=%q", dottedDomain) |
| } |
| |
| record := httpsRecord{ |
| priority: priority, |
| targetName: dottedDomain, |
| unknownParams: make(map[uint16][]byte), |
| } |
| |
| // Read the SvcParams. |
| var lastSvcParamKey uint16 |
| for svcParamCount := 0; !reader.Empty(); svcParamCount++ { |
| var svcParamKey uint16 |
| var svcParamValue cryptobyte.String |
| if !reader.ReadUint16(&svcParamKey) || |
| !reader.ReadUint16LengthPrefixed(&svcParamValue) { |
| return httpsRecord{}, errors.New("failed to parse HTTPS record SvcParam") |
| } |
| if svcParamCount > 0 && svcParamKey <= lastSvcParamKey { |
| return httpsRecord{}, errors.New("malformed HTTPS record contains out-of-order SvcParamKey") |
| } |
| lastSvcParamKey = svcParamKey |
| |
| switch svcParamKey { |
| case httpsKeyMandatory: |
| if svcParamValue.Empty() { |
| return httpsRecord{}, errors.New("malformed mandatory SvcParamValue") |
| } |
| var lastKey uint16 |
| for !svcParamValue.Empty() { |
| // |httpsKeyMandatory| may not appear in the mandatory list. |
| // |httpsKeyMandatory| is zero, so checking against the initial |
| // value of |lastKey| handles ordering and the invalid code point. |
| var key uint16 |
| if !svcParamValue.ReadUint16(&key) || |
| key <= lastKey { |
| return httpsRecord{}, errors.New("malformed mandatory SvcParamValue") |
| } |
| lastKey = key |
| record.mandatory = append(record.mandatory, key) |
| } |
| case httpsKeyALPN: |
| if svcParamValue.Empty() { |
| return httpsRecord{}, errors.New("malformed alpn SvcParamValue") |
| } |
| for !svcParamValue.Empty() { |
| var alpn cryptobyte.String |
| if !svcParamValue.ReadUint8LengthPrefixed(&alpn) || alpn.Empty() { |
| return httpsRecord{}, errors.New("malformed alpn SvcParamValue") |
| } |
| record.alpn = append(record.alpn, string(alpn)) |
| } |
| case httpsKeyNoDefaultALPN: |
| if !svcParamValue.Empty() { |
| return httpsRecord{}, errors.New("malformed no-default-alpn SvcParamValue") |
| } |
| record.noDefaultALPN = true |
| case httpsKeyPort: |
| if !svcParamValue.ReadUint16(&record.port) || |
| !svcParamValue.Empty() { |
| return httpsRecord{}, errors.New("malformed port SvcParamValue") |
| } |
| record.hasPort = true |
| case httpsKeyIPV4Hint: |
| if svcParamValue.Empty() { |
| return httpsRecord{}, errors.New("malformed ipv4hint SvcParamValue") |
| } |
| for !svcParamValue.Empty() { |
| var address []byte |
| if !svcParamValue.ReadBytes(&address, 4) { |
| return httpsRecord{}, errors.New("malformed ipv4hint SvcParamValue") |
| } |
| record.ipv4hint = append(record.ipv4hint, address) |
| } |
| case httpsKeyECH: |
| if svcParamValue.Empty() { |
| return httpsRecord{}, errors.New("malformed ech SvcParamValue") |
| } |
| record.ech = svcParamValue |
| case httpsKeyIPV6Hint: |
| if svcParamValue.Empty() { |
| return httpsRecord{}, errors.New("malformed ipv6hint SvcParamValue") |
| } |
| for !svcParamValue.Empty() { |
| var address []byte |
| if !svcParamValue.ReadBytes(&address, 16) { |
| return httpsRecord{}, errors.New("malformed ipv6hint SvcParamValue") |
| } |
| record.ipv6hint = append(record.ipv6hint, address) |
| } |
| default: |
| record.unknownParams[svcParamKey] = svcParamValue |
| } |
| } |
| return record, nil |
| } |
| |
| func main() { |
| flag.Parse() |
| log.SetFlags(log.Lshortfile | log.LstdFlags) |
| |
| if len(*name) == 0 { |
| flag.Usage() |
| os.Exit(1) |
| } |
| |
| httpsRecords, err := dnsQueryForHTTPS(*name) |
| if err != nil { |
| log.Printf("Error querying %q: %s\n", *name, err) |
| os.Exit(1) |
| } |
| if len(httpsRecords) == 0 { |
| log.Println("No HTTPS records found in DNS response.") |
| os.Exit(1) |
| } |
| |
| if len(*outDir) > 0 { |
| if err = os.Mkdir(*outDir, 0755); err != nil && !os.IsExist(err) { |
| log.Printf("Failed to create out directory %q: %s\n", *outDir, err) |
| os.Exit(1) |
| } |
| } |
| |
| var echConfigListCount int |
| for _, httpsRecord := range httpsRecords { |
| record, err := parseHTTPSRecord(httpsRecord) |
| if err != nil { |
| log.Printf("Failed to parse HTTPS record: %s", err) |
| os.Exit(1) |
| } |
| fmt.Printf("%s\n", record) |
| if len(*outDir) == 0 { |
| continue |
| } |
| |
| outFile := path.Join(*outDir, fmt.Sprintf("ech-config-list-%d", echConfigListCount)) |
| if err = os.WriteFile(outFile, record.ech, 0644); err != nil { |
| log.Printf("Failed to write file: %s\n", err) |
| os.Exit(1) |
| } |
| fmt.Printf("Wrote ECHConfigList to %q\n", outFile) |
| echConfigListCount++ |
| } |
| } |