Skip to content

Commit

Permalink
support read and write timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
Tantalor93 committed Jul 10, 2023
1 parent f9b834a commit 349a6e6
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 30 deletions.
114 changes: 85 additions & 29 deletions doq/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/binary"
"io"
"sync"
"time"

"github.com/miekg/dns"
"github.com/quic-go/quic-go"
Expand All @@ -14,15 +15,20 @@ import (
// Client encapsulates and provides logic for querying DNS servers over QUIC.
// The client should be thread-safe. The client reuses single QUIC connection to the server, while creating multiple parallel QUIC streams.
type Client struct {
lock sync.Mutex
addr string
tlsconfig *tls.Config
conn quic.Connection
lock sync.Mutex
addr string
tlsconfig *tls.Config
conn quic.Connection
writeTimeout time.Duration
readTimeout time.Duration
}

// Options encapsulates configuration options for doq.Client.
// By default, WriteTimeout and ReadTimeout is zero, meaning there is no timeout.
type Options struct {
TLSConfig *tls.Config
TLSConfig *tls.Config
WriteTimeout time.Duration
ReadTimeout time.Duration
}

// NewClient creates a new doq.Client used for sending DoQ queries.
Expand All @@ -39,6 +45,8 @@ func NewClient(addr string, options Options) (*Client, error) {

// override protocol negotiation to DoQ, all the other stuff (like certificates, cert pools, insecure skip) is up to the user of library
client.tlsconfig.NextProtos = []string{"doq"}
client.readTimeout = options.ReadTimeout
client.writeTimeout = options.WriteTimeout

return &client, nil
}
Expand Down Expand Up @@ -94,43 +102,91 @@ func (c *Client) Send(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
}
}

pack, err := msg.Pack()
stream, err := c.conn.OpenStreamSync(ctx)
if err != nil {
return nil, err
}
packWithPrefix := make([]byte, 2+len(pack))
binary.BigEndian.PutUint16(packWithPrefix, uint16(len(pack)))
copy(packWithPrefix[2:], pack)

streamSync, err := c.conn.OpenStreamSync(ctx)
if err != nil {
writeCtx := ctx
if c.writeTimeout != 0 {
var cancel context.CancelFunc
writeCtx, cancel = context.WithTimeout(writeCtx, c.writeTimeout)
defer cancel()
}
if err := writeMsg(writeCtx, stream, msg); err != nil {
return nil, err
}

_, err = streamSync.Write(packWithPrefix)
// close the stream to indicate we are done sending or the server might wait till we close the stream or timeout is hit
_ = streamSync.Close()
if err != nil {
return nil, err
readCtx := ctx
if c.readTimeout != 0 {
var cancel context.CancelFunc
readCtx, cancel = context.WithTimeout(readCtx, c.readTimeout)
defer cancel()
}
return readMsg(readCtx, stream)
}

// read 2-octet length field to know how long the DNS message is
sizeBuf := make([]byte, 2)
_, err = io.ReadFull(streamSync, sizeBuf)
func writeMsg(ctx context.Context, stream quic.Stream, msg *dns.Msg) error {
pack, err := msg.Pack()
if err != nil {
return nil, err
return err
}
packWithPrefix := make([]byte, 2+len(pack))
binary.BigEndian.PutUint16(packWithPrefix, uint16(len(pack)))
copy(packWithPrefix[2:], pack)

size := binary.BigEndian.Uint16(sizeBuf)
buf := make([]byte, size)
_, err = io.ReadFull(streamSync, buf)
if err != nil {
return nil, err
done := make(chan error)
go func() {
_, err = stream.Write(packWithPrefix)
// close the stream to indicate we are done sending or the server might wait till we close the stream or timeout is hit
_ = stream.Close()
done <- err
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err
}
}

resp := dns.Msg{}
if err := resp.Unpack(buf); err != nil {
return nil, err
func readMsg(ctx context.Context, stream quic.Stream) (*dns.Msg, error) {
done := make(chan interface{})
go func() {
// read 2-octet length field to know how long the DNS message is
sizeBuf := make([]byte, 2)
_, err := io.ReadFull(stream, sizeBuf)
if err != nil {
done <- err
return
}

size := binary.BigEndian.Uint16(sizeBuf)
buf := make([]byte, size)
_, err = io.ReadFull(stream, buf)
if err != nil {
done <- err
return
}

resp := dns.Msg{}
if err := resp.Unpack(buf); err != nil {
done <- err
return
}
done <- &resp
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case res := <-done:
switch r := res.(type) {
case error:
return nil, r
case *dns.Msg:
return r, nil
default:
panic("unknown response")
}
}
return &resp, nil
}
40 changes: 39 additions & 1 deletion doq/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"sync/atomic"
"testing"
"time"

"github.com/miekg/dns"
"github.com/quic-go/quic-go"
Expand Down Expand Up @@ -38,10 +39,15 @@ func (d *doqServer) start() {
}
return
}

stream, err := conn.AcceptStream(context.Background())
if err != nil {
panic(err)
}

// to reliably test read timeout
time.Sleep(time.Second)

resp := dns.Msg{
MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess},
Question: []dns.Question{{Name: "example.org.", Qtype: dns.TypeA}},
Expand Down Expand Up @@ -79,7 +85,7 @@ func Test(t *testing.T) {
server.start()
defer server.stop()

client, err := NewClient(server.addr, Options{generateTLSConfig()})
client, err := NewClient(server.addr, Options{TLSConfig: generateTLSConfig()})
require.NoError(t, err)

msg := dns.Msg{}
Expand All @@ -92,6 +98,38 @@ func Test(t *testing.T) {
assert.Equal(t, net.ParseIP("127.0.0.1").To4(), resp.Answer[0].(*dns.A).A)
}

func TestWriteTimeout(t *testing.T) {
server := doqServer{}
server.start()
defer server.stop()

client, err := NewClient(server.addr, Options{TLSConfig: generateTLSConfig(), WriteTimeout: 1 * time.Nanosecond})
require.NoError(t, err)

msg := dns.Msg{}
msg.SetQuestion("example.org.", dns.TypeA)
resp, err := client.Send(context.Background(), &msg)

require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, resp)
}

func TestReadTimeout(t *testing.T) {
server := doqServer{}
server.start()
defer server.stop()

client, err := NewClient(server.addr, Options{TLSConfig: generateTLSConfig(), ReadTimeout: 1 * time.Nanosecond})
require.NoError(t, err)

msg := dns.Msg{}
msg.SetQuestion("example.org.", dns.TypeA)
resp, err := client.Send(context.Background(), &msg)

require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, resp)
}

func generateTLSConfig() *tls.Config {
cert, err := tls.LoadX509KeyPair("test.crt", "test.key")
if err != nil {
Expand Down

0 comments on commit 349a6e6

Please sign in to comment.