From 349a6e68d1b9bc27ecac16e32cc75d3d338d3afc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20Benkovsk=C3=BD?= Date: Mon, 10 Jul 2023 21:34:08 +0200 Subject: [PATCH] support read and write timeouts --- doq/client.go | 114 +++++++++++++++++++++++++++++++++------------ doq/client_test.go | 40 +++++++++++++++- 2 files changed, 124 insertions(+), 30 deletions(-) diff --git a/doq/client.go b/doq/client.go index d223396..d79aad3 100644 --- a/doq/client.go +++ b/doq/client.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "io" "sync" + "time" "github.com/miekg/dns" "github.com/quic-go/quic-go" @@ -14,15 +15,20 @@ import ( // Client encapsulates and provides logic for querying DNS servers over QUIC. // The client should be thread-safe. The client reuses single QUIC connection to the server, while creating multiple parallel QUIC streams. type Client struct { - lock sync.Mutex - addr string - tlsconfig *tls.Config - conn quic.Connection + lock sync.Mutex + addr string + tlsconfig *tls.Config + conn quic.Connection + writeTimeout time.Duration + readTimeout time.Duration } // Options encapsulates configuration options for doq.Client. +// By default, WriteTimeout and ReadTimeout is zero, meaning there is no timeout. type Options struct { - TLSConfig *tls.Config + TLSConfig *tls.Config + WriteTimeout time.Duration + ReadTimeout time.Duration } // NewClient creates a new doq.Client used for sending DoQ queries. @@ -39,6 +45,8 @@ func NewClient(addr string, options Options) (*Client, error) { // override protocol negotiation to DoQ, all the other stuff (like certificates, cert pools, insecure skip) is up to the user of library client.tlsconfig.NextProtos = []string{"doq"} + client.readTimeout = options.ReadTimeout + client.writeTimeout = options.WriteTimeout return &client, nil } @@ -94,43 +102,91 @@ func (c *Client) Send(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { } } - pack, err := msg.Pack() + stream, err := c.conn.OpenStreamSync(ctx) if err != nil { return nil, err } - packWithPrefix := make([]byte, 2+len(pack)) - binary.BigEndian.PutUint16(packWithPrefix, uint16(len(pack))) - copy(packWithPrefix[2:], pack) - streamSync, err := c.conn.OpenStreamSync(ctx) - if err != nil { + writeCtx := ctx + if c.writeTimeout != 0 { + var cancel context.CancelFunc + writeCtx, cancel = context.WithTimeout(writeCtx, c.writeTimeout) + defer cancel() + } + if err := writeMsg(writeCtx, stream, msg); err != nil { return nil, err } - _, err = streamSync.Write(packWithPrefix) - // close the stream to indicate we are done sending or the server might wait till we close the stream or timeout is hit - _ = streamSync.Close() - if err != nil { - return nil, err + readCtx := ctx + if c.readTimeout != 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(readCtx, c.readTimeout) + defer cancel() } + return readMsg(readCtx, stream) +} - // read 2-octet length field to know how long the DNS message is - sizeBuf := make([]byte, 2) - _, err = io.ReadFull(streamSync, sizeBuf) +func writeMsg(ctx context.Context, stream quic.Stream, msg *dns.Msg) error { + pack, err := msg.Pack() if err != nil { - return nil, err + return err } + packWithPrefix := make([]byte, 2+len(pack)) + binary.BigEndian.PutUint16(packWithPrefix, uint16(len(pack))) + copy(packWithPrefix[2:], pack) - size := binary.BigEndian.Uint16(sizeBuf) - buf := make([]byte, size) - _, err = io.ReadFull(streamSync, buf) - if err != nil { - return nil, err + done := make(chan error) + go func() { + _, err = stream.Write(packWithPrefix) + // close the stream to indicate we are done sending or the server might wait till we close the stream or timeout is hit + _ = stream.Close() + done <- err + }() + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + return err } +} - resp := dns.Msg{} - if err := resp.Unpack(buf); err != nil { - return nil, err +func readMsg(ctx context.Context, stream quic.Stream) (*dns.Msg, error) { + done := make(chan interface{}) + go func() { + // read 2-octet length field to know how long the DNS message is + sizeBuf := make([]byte, 2) + _, err := io.ReadFull(stream, sizeBuf) + if err != nil { + done <- err + return + } + + size := binary.BigEndian.Uint16(sizeBuf) + buf := make([]byte, size) + _, err = io.ReadFull(stream, buf) + if err != nil { + done <- err + return + } + + resp := dns.Msg{} + if err := resp.Unpack(buf); err != nil { + done <- err + return + } + done <- &resp + }() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case res := <-done: + switch r := res.(type) { + case error: + return nil, r + case *dns.Msg: + return r, nil + default: + panic("unknown response") + } } - return &resp, nil } diff --git a/doq/client_test.go b/doq/client_test.go index f1bbe6f..abd715a 100644 --- a/doq/client_test.go +++ b/doq/client_test.go @@ -9,6 +9,7 @@ import ( "os" "sync/atomic" "testing" + "time" "github.com/miekg/dns" "github.com/quic-go/quic-go" @@ -38,10 +39,15 @@ func (d *doqServer) start() { } return } + stream, err := conn.AcceptStream(context.Background()) if err != nil { panic(err) } + + // to reliably test read timeout + time.Sleep(time.Second) + resp := dns.Msg{ MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, Question: []dns.Question{{Name: "example.org.", Qtype: dns.TypeA}}, @@ -79,7 +85,7 @@ func Test(t *testing.T) { server.start() defer server.stop() - client, err := NewClient(server.addr, Options{generateTLSConfig()}) + client, err := NewClient(server.addr, Options{TLSConfig: generateTLSConfig()}) require.NoError(t, err) msg := dns.Msg{} @@ -92,6 +98,38 @@ func Test(t *testing.T) { assert.Equal(t, net.ParseIP("127.0.0.1").To4(), resp.Answer[0].(*dns.A).A) } +func TestWriteTimeout(t *testing.T) { + server := doqServer{} + server.start() + defer server.stop() + + client, err := NewClient(server.addr, Options{TLSConfig: generateTLSConfig(), WriteTimeout: 1 * time.Nanosecond}) + require.NoError(t, err) + + msg := dns.Msg{} + msg.SetQuestion("example.org.", dns.TypeA) + resp, err := client.Send(context.Background(), &msg) + + require.ErrorIs(t, err, context.DeadlineExceeded) + require.Nil(t, resp) +} + +func TestReadTimeout(t *testing.T) { + server := doqServer{} + server.start() + defer server.stop() + + client, err := NewClient(server.addr, Options{TLSConfig: generateTLSConfig(), ReadTimeout: 1 * time.Nanosecond}) + require.NoError(t, err) + + msg := dns.Msg{} + msg.SetQuestion("example.org.", dns.TypeA) + resp, err := client.Send(context.Background(), &msg) + + require.ErrorIs(t, err, context.DeadlineExceeded) + require.Nil(t, resp) +} + func generateTLSConfig() *tls.Config { cert, err := tls.LoadX509KeyPair("test.crt", "test.key") if err != nil {