From f2b63c34ec6046b1fff8ebbaa5c101b59863ec45 Mon Sep 17 00:00:00 2001 From: Al S-M Date: Thu, 23 Jul 2020 14:52:14 +0100 Subject: [PATCH] Wiring up Persistence and looking at conn loss I've removed the OnDisconnect callback and moved to OnConnectionLost instead and used a specific Error type to indicate when it was a server initiated disconnect (see test change) #25 #26 --- paho/client.go | 200 ++++++++++++++++++++++----------------- paho/client_test.go | 94 +++++++++++++++++- paho/cmd/chat/main.go | 2 +- paho/errors.go | 12 +++ paho/noop_persistence.go | 14 ++- paho/persistence.go | 42 ++++---- paho/persistence_test.go | 84 ++++++++++++++++ 7 files changed, 336 insertions(+), 112 deletions(-) create mode 100644 paho/errors.go create mode 100644 paho/persistence_test.go diff --git a/paho/client.go b/paho/client.go index 27ab572..a2cd311 100644 --- a/paho/client.go +++ b/paho/client.go @@ -19,21 +19,41 @@ const ( MQTTv5 MQTTVersion = 5 ) +var ( + defaultServerProperties = CommsProperties{ + ReceiveMaximum: 65535, + MaximumQoS: 2, + MaximumPacketSize: 0, + TopicAliasMaximum: 0, + RetainAvailable: true, + WildcardSubAvailable: true, + SubIDAvailable: true, + SharedSubAvailable: true, + } + + defaultClientProperties = CommsProperties{ + ReceiveMaximum: 65535, + MaximumQoS: 2, + MaximumPacketSize: 0, + TopicAliasMaximum: 0, + } +) + type ( // ClientConfig are the user configurable options for the client, an // instance of this struct is passed into NewClient(), not all options // are required to be set, defaults are provided for Persistence, MIDs, // PingHandler, PacketTimeout and Router. ClientConfig struct { - ClientID string - Conn net.Conn - MIDs MIDService - AuthHandler Auther - PingHandler Pinger - Router Router - Persistence Persistence - PacketTimeout time.Duration - OnDisconnect func(*Disconnect) + ClientID string + Conn net.Conn + MIDs MIDService + AuthHandler Auther + PingHandler Pinger + Router Router + Persistence Persistence + PacketTimeout time.Duration + OnConnectionLost func(error) } // Client is the struct representing an MQTT client Client struct { @@ -50,8 +70,8 @@ type ( clientProps CommsProperties serverInflight *semaphore.Weighted clientInflight *semaphore.Weighted - debug Logger - errors Logger + Errors Logger + Debug Logger } // CommsProperties is a struct of the communication properties that may @@ -83,25 +103,11 @@ type ( // Connect() is called. func NewClient(conf ClientConfig) *Client { c := &Client{ - serverProps: CommsProperties{ - ReceiveMaximum: 65535, - MaximumQoS: 2, - MaximumPacketSize: 0, - TopicAliasMaximum: 0, - RetainAvailable: true, - WildcardSubAvailable: true, - SubIDAvailable: true, - SharedSubAvailable: true, - }, - clientProps: CommsProperties{ - ReceiveMaximum: 65535, - MaximumQoS: 2, - MaximumPacketSize: 0, - TopicAliasMaximum: 0, - }, + serverProps: defaultServerProperties, + clientProps: defaultClientProperties, ClientConfig: conf, - errors: NOOPLogger{}, - debug: NOOPLogger{}, + Errors: NOOPLogger{}, + Debug: NOOPLogger{}, } if c.Persistence == nil { @@ -145,10 +151,16 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { if c.Conn == nil { return nil, fmt.Errorf("client connection is nil") } + if err := c.ClientConfig.Persistence.Open(); err != nil { + return nil, fmt.Errorf("failed to open persistence: %w", err) + } + if cp.CleanStart { + c.ClientConfig.Persistence.Reset() + } c.stop = make(chan struct{}) - c.debug.Println("connecting") + c.Debug.Println("connecting") c.mu.Lock() defer c.mu.Unlock() @@ -169,7 +181,7 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { } } - c.debug.Println("starting Incoming") + c.Debug.Println("starting Incoming") c.workers.Add(1) go func() { defer c.workers.Done() @@ -188,18 +200,18 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { ccp.ProtocolName = "MQTT" ccp.ProtocolVersion = 5 - c.debug.Println("sending CONNECT") + c.Debug.Println("sending CONNECT") if _, err := ccp.WriteTo(c.Conn); err != nil { cleanup() return nil, err } - c.debug.Println("waiting for CONNACK") + c.Debug.Println("waiting for CONNACK") var cap *packets.Connack select { case <-connCtx.Done(): if e := connCtx.Err(); e == context.DeadlineExceeded { - c.debug.Println("timeout waiting for CONNACK") + c.Debug.Println("timeout waiting for CONNACK") cleanup() return nil, e } @@ -210,7 +222,7 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { if ca.ReasonCode >= 0x80 { var reason string - c.debug.Println("received an error code in Connack:", ca.ReasonCode) + c.Debug.Println("received an error code in Connack:", ca.ReasonCode) if ca.Properties != nil { reason = ca.Properties.ReasonString } @@ -246,7 +258,7 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { c.serverInflight = semaphore.NewWeighted(int64(c.serverProps.ReceiveMaximum)) c.clientInflight = semaphore.NewWeighted(int64(c.clientProps.ReceiveMaximum)) - c.debug.Println("received CONNACK, starting PingHandler") + c.Debug.Println("received CONNACK, starting PingHandler") c.workers.Add(1) go func() { defer c.workers.Done() @@ -265,7 +277,7 @@ func (c *Client) Incoming() { for { select { case <-c.stop: - c.debug.Println("client stopping, Incoming stopping") + c.Debug.Println("client stopping, Incoming stopping") return default: recv, err := packets.ReadPacket(c.Conn) @@ -308,7 +320,7 @@ func (c *Client) Incoming() { } _, err := pa.WriteTo(c.Conn) if err != nil { - c.errors.Printf("failed to send PUBACK for %d: %s", pa.PacketID, err) + c.Errors.Printf("failed to send PUBACK for %d: %s", pa.PacketID, err) } case 2: pr := packets.Pubrec{ @@ -317,40 +329,44 @@ func (c *Client) Incoming() { } _, err := pr.WriteTo(c.Conn) if err != nil { - c.errors.Printf("failed to send PUBREC for %d: %s", pr.PacketID, err) + c.Errors.Printf("failed to send PUBREC for %d: %s", pr.PacketID, err) } } case packets.PUBACK, packets.PUBCOMP, packets.SUBACK, packets.UNSUBACK: - c.debug.Println("received packet with id", recv.PacketID()) + c.Debug.Println("received packet with id", recv.PacketID()) if cpCtx := c.MIDs.Get(recv.PacketID()); cpCtx != nil { + c.ClientConfig.Persistence.Delete(recv.PacketID()) cpCtx.Return <- *recv } else { - c.debug.Println("received a response for a message ID we don't know:", recv.PacketID()) + c.Debug.Println("received a response for a message ID we don't know:", recv.PacketID()) } case packets.PUBREC: - c.debug.Println("received pubrec") + c.Debug.Println("received pubrec") if cpCtx := c.MIDs.Get(recv.PacketID()); cpCtx == nil { - c.debug.Println("received a PUBREC for a message ID we don't know:", recv.PacketID()) + c.Debug.Println("received a PUBREC for a message ID we don't know:", recv.PacketID()) pl := packets.Pubrel{ PacketID: recv.Content.(*packets.Pubrec).PacketID, ReasonCode: 0x92, } _, err := pl.WriteTo(c.Conn) if err != nil { - c.errors.Printf("failed to send PUBREL for %d: %s", pl.PacketID, err) + c.Errors.Printf("failed to send PUBREL for %d: %s", pl.PacketID, err) } } else { pr := recv.Content.(*packets.Pubrec) if pr.ReasonCode >= 0x80 { //Received a failure code, shortcut and return + c.ClientConfig.Persistence.Delete(pr.PacketID) cpCtx.Return <- *recv } else { pl := packets.Pubrel{ PacketID: pr.PacketID, } + //TODO: what to do about failing to persist a pubrel? + c.ClientConfig.Persistence.Put(pl.PacketID, &pl) _, err := pl.WriteTo(c.Conn) if err != nil { - c.errors.Printf("failed to send PUBREL for %d: %s", pl.PacketID, err) + c.Errors.Printf("failed to send PUBREL for %d: %s", pl.PacketID, err) } } } @@ -366,19 +382,17 @@ func (c *Client) Incoming() { } _, err := pc.WriteTo(c.Conn) if err != nil { - c.errors.Printf("failed to send PUBCOMP for %d: %s", pc.PacketID, err) + c.Errors.Printf("failed to send PUBCOMP for %d: %s", pc.PacketID, err) } } case packets.DISCONNECT: if c.raCtx != nil { c.raCtx.Return <- *recv } - c.Error(fmt.Errorf("received server initiated disconnect")) - c.debug.Println(c.OnDisconnect) - if c.OnDisconnect != nil { - c.debug.Println("calling OnDisconnect") - go c.OnDisconnect(DisconnectFromPacketDisconnect(recv.Content.(*packets.Disconnect))) - } + c.Error(&DisconnectError{ + Disconnect: DisconnectFromPacketDisconnect(recv.Content.(*packets.Disconnect)), + Err: fmt.Errorf("received server initiated disconnect"), + }) } } } @@ -389,7 +403,7 @@ func (c *Client) Incoming() { // which results in the other client goroutines terminating. // It also closes the client network connection. func (c *Client) Error(e error) { - c.debug.Println("error called:", e) + c.Debug.Println("error called:", e) c.mu.Lock() select { case <-c.stop: @@ -397,12 +411,16 @@ func (c *Client) Error(e error) { default: close(c.stop) } - c.debug.Println("client stopped") + c.Debug.Println("client stopped") c.PingHandler.Stop() - c.debug.Println("ping stopped") + c.Debug.Println("ping stopped") c.Conn.Close() - c.debug.Println("conn closed") + c.Debug.Println("conn closed") c.mu.Unlock() + if c.ClientConfig.OnConnectionLost != nil { + c.Debug.Println("calling OnConnectionLost") + go c.ClientConfig.OnConnectionLost(e) + } } // Authenticate is used to initiate a reauthentication of credentials with the @@ -411,7 +429,7 @@ func (c *Client) Error(e error) { // server until either a successful Auth packet is passed back, or a Disconnect // is received. func (c *Client) Authenticate(ctx context.Context, a *Auth) (*AuthResponse, error) { - c.debug.Println("client initiated reauthentication") + c.Debug.Println("client initiated reauthentication") c.mu.Lock() defer c.mu.Unlock() @@ -420,7 +438,7 @@ func (c *Client) Authenticate(ctx context.Context, a *Auth) (*AuthResponse, erro c.raCtx = nil }() - c.debug.Println("sending AUTH") + c.Debug.Println("sending AUTH") if _, err := a.Packet().WriteTo(c.Conn); err != nil { return nil, err } @@ -429,7 +447,7 @@ func (c *Client) Authenticate(ctx context.Context, a *Auth) (*AuthResponse, erro select { case <-ctx.Done(): if e := ctx.Err(); e == context.DeadlineExceeded { - c.debug.Println("timeout waiting for Auth to complete") + c.Debug.Println("timeout waiting for Auth to complete") return nil, e } case rp = <-c.raCtx.Return: @@ -471,7 +489,7 @@ func (c *Client) Subscribe(ctx context.Context, s *Subscribe) (*Suback, error) { } } - c.debug.Printf("subscribing to %+v", s.Subscriptions) + c.Debug.Printf("subscribing to %+v", s.Subscriptions) subCtx, cf := context.WithTimeout(ctx, c.PacketTimeout) defer cf() @@ -480,17 +498,21 @@ func (c *Client) Subscribe(ctx context.Context, s *Subscribe) (*Suback, error) { sp := s.Packet() sp.PacketID = c.MIDs.Request(cpCtx) - c.debug.Println("sending SUBSCRIBE") + if err := c.ClientConfig.Persistence.Put(sp.PacketID, sp); err != nil { + c.MIDs.Free(sp.PacketID) + return nil, fmt.Errorf("failed to persist SUBSCRIBE: %w", err) + } + c.Debug.Println("sending SUBSCRIBE") if _, err := sp.WriteTo(c.Conn); err != nil { return nil, err } - c.debug.Println("waiting for SUBACK") + c.Debug.Println("waiting for SUBACK") var sap packets.ControlPacket select { case <-subCtx.Done(): if e := subCtx.Err(); e == context.DeadlineExceeded { - c.debug.Println("timeout waiting for SUBACK") + c.Debug.Println("timeout waiting for SUBACK") return nil, e } case sap = <-cpCtx.Return: @@ -499,14 +521,14 @@ func (c *Client) Subscribe(ctx context.Context, s *Subscribe) (*Suback, error) { if sap.Type != packets.SUBACK { return nil, fmt.Errorf("received %d instead of Suback", sap.Type) } - c.debug.Println("received SUBACK") + c.Debug.Println("received SUBACK") sa := SubackFromPacketSuback(sap.Content.(*packets.Suback)) switch { case len(sa.Reasons) == 1: if sa.Reasons[0] >= 0x80 { var reason string - c.debug.Println("received an error code in Suback:", sa.Reasons[0]) + c.Debug.Println("received an error code in Suback:", sa.Reasons[0]) if sa.Properties != nil { reason = sa.Properties.ReasonString } @@ -515,7 +537,7 @@ func (c *Client) Subscribe(ctx context.Context, s *Subscribe) (*Suback, error) { default: for _, code := range sa.Reasons { if code >= 0x80 { - c.debug.Println("received an error code in Suback:", code) + c.Debug.Println("received an error code in Suback:", code) return sa, fmt.Errorf("at least one requested subscription failed") } } @@ -529,7 +551,7 @@ func (c *Client) Subscribe(ctx context.Context, s *Subscribe) (*Suback, error) { // a response Unsuback, or for the timeout to fire. Any response Unsuback // is returned from the function, along with any errors. func (c *Client) Unsubscribe(ctx context.Context, u *Unsubscribe) (*Unsuback, error) { - c.debug.Printf("unsubscribing from %+v", u.Topics) + c.Debug.Printf("unsubscribing from %+v", u.Topics) unsubCtx, cf := context.WithTimeout(ctx, c.PacketTimeout) defer cf() cpCtx := &CPContext{unsubCtx, make(chan packets.ControlPacket, 1)} @@ -537,17 +559,21 @@ func (c *Client) Unsubscribe(ctx context.Context, u *Unsubscribe) (*Unsuback, er up := u.Packet() up.PacketID = c.MIDs.Request(cpCtx) - c.debug.Println("sending UNSUBSCRIBE") + if err := c.ClientConfig.Persistence.Put(up.PacketID, up); err != nil { + c.MIDs.Free(up.PacketID) + return nil, fmt.Errorf("failed to persist UNSUBSCRIBE: %w", err) + } + c.Debug.Println("sending UNSUBSCRIBE") if _, err := up.WriteTo(c.Conn); err != nil { return nil, err } - c.debug.Println("waiting for UNSUBACK") + c.Debug.Println("waiting for UNSUBACK") var uap packets.ControlPacket select { case <-unsubCtx.Done(): if e := unsubCtx.Err(); e == context.DeadlineExceeded { - c.debug.Println("timeout waiting for UNSUBACK") + c.Debug.Println("timeout waiting for UNSUBACK") return nil, e } case uap = <-cpCtx.Return: @@ -556,14 +582,14 @@ func (c *Client) Unsubscribe(ctx context.Context, u *Unsubscribe) (*Unsuback, er if uap.Type != packets.UNSUBACK { return nil, fmt.Errorf("received %d instead of Unsuback", uap.Type) } - c.debug.Println("received SUBACK") + c.Debug.Println("received SUBACK") ua := UnsubackFromPacketUnsuback(uap.Content.(*packets.Unsuback)) switch { case len(ua.Reasons) == 1: if ua.Reasons[0] >= 0x80 { var reason string - c.debug.Println("received an error code in Unsuback:", ua.Reasons[0]) + c.Debug.Println("received an error code in Unsuback:", ua.Reasons[0]) if ua.Properties != nil { reason = ua.Properties.ReasonString } @@ -572,7 +598,7 @@ func (c *Client) Unsubscribe(ctx context.Context, u *Unsubscribe) (*Unsuback, er default: for _, code := range ua.Reasons { if code >= 0x80 { - c.debug.Println("received an error code in Suback:", code) + c.Debug.Println("received an error code in Suback:", code) return ua, fmt.Errorf("at least one requested unsubscribe failed") } } @@ -598,13 +624,13 @@ func (c *Client) Publish(ctx context.Context, p *Publish) (*PublishResponse, err return nil, fmt.Errorf("cannot send Publish with retain flag set, server does not support retained messages") } - c.debug.Printf("sending message to %s", p.Topic) + c.Debug.Printf("sending message to %s", p.Topic) pb := p.Packet() switch p.QoS { case 0: - c.debug.Println("sending QoS0 message") + c.Debug.Println("sending QoS0 message") if _, err := pb.WriteTo(c.Conn); err != nil { return nil, err } @@ -617,7 +643,7 @@ func (c *Client) Publish(ctx context.Context, p *Publish) (*PublishResponse, err } func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish) (*PublishResponse, error) { - c.debug.Println("sending QoS12 message") + c.Debug.Println("sending QoS12 message") pubCtx, cf := context.WithTimeout(ctx, c.PacketTimeout) defer cf() if err := c.serverInflight.Acquire(pubCtx, 1); err != nil { @@ -626,6 +652,10 @@ func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish) (*Publis cpCtx := &CPContext{pubCtx, make(chan packets.ControlPacket, 1)} pb.PacketID = c.MIDs.Request(cpCtx) + if err := c.ClientConfig.Persistence.Put(pb.PacketID, pb); err != nil { + c.MIDs.Free(pb.PacketID) + return nil, fmt.Errorf("failed to persist PUBLISH: %w", err) + } if _, err := pb.WriteTo(c.Conn); err != nil { return nil, err } @@ -634,7 +664,7 @@ func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish) (*Publis select { case <-pubCtx.Done(): if e := pubCtx.Err(); e == context.DeadlineExceeded { - c.debug.Println("timeout waiting for Publish response") + c.Debug.Println("timeout waiting for Publish response") return nil, e } case resp = <-cpCtx.Return: @@ -645,24 +675,24 @@ func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish) (*Publis if resp.Type != packets.PUBACK { return nil, fmt.Errorf("received %d instead of PUBACK", resp.Type) } - c.debug.Println("received PUBACK for", pb.PacketID) + c.Debug.Println("received PUBACK for", pb.PacketID) c.serverInflight.Release(1) pr := PublishResponseFromPuback(resp.Content.(*packets.Puback)) if pr.ReasonCode >= 0x80 { - c.debug.Println("received an error code in Puback:", pr.ReasonCode) + c.Debug.Println("received an error code in Puback:", pr.ReasonCode) return pr, fmt.Errorf("error publishing: %s", resp.Content.(*packets.Puback).Reason()) } return pr, nil case 2: switch resp.Type { case packets.PUBCOMP: - c.debug.Println("received PUBCOMP for", pb.PacketID) + c.Debug.Println("received PUBCOMP for", pb.PacketID) c.serverInflight.Release(1) pr := PublishResponseFromPubcomp(resp.Content.(*packets.Pubcomp)) return pr, nil case packets.PUBREC: - c.debug.Printf("received PUBREC for %s (must have errored)", pb.PacketID) + c.Debug.Printf("received PUBREC for %s (must have errored)", pb.PacketID) c.serverInflight.Release(1) pr := PublishResponseFromPubrec(resp.Content.(*packets.Pubrec)) return pr, nil @@ -671,7 +701,7 @@ func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish) (*Publis } } - c.debug.Println("ended up with a non QoS1/2 message:", pb.QoS) + c.Debug.Println("ended up with a non QoS1/2 message:", pb.QoS) return nil, fmt.Errorf("ended up with a non QoS1/2 message: %d", pb.QoS) } @@ -680,7 +710,7 @@ func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish) (*Publis // (and if it does this function returns any error) the network connection // is closed. func (c *Client) Disconnect(d *Disconnect) error { - c.debug.Println("disconnecting") + c.Debug.Println("disconnecting") c.mu.Lock() defer c.mu.Unlock() defer c.Conn.Close() @@ -693,11 +723,11 @@ func (c *Client) Disconnect(d *Disconnect) error { // SetDebugLogger takes an instance of the paho Logger interface // and sets it to be used by the debug log endpoint func (c *Client) SetDebugLogger(l Logger) { - c.debug = l + c.Debug = l } // SetErrorLogger takes an instance of the paho Logger interface // and sets it to be used by the error log endpoint func (c *Client) SetErrorLogger(l Logger) { - c.errors = l + c.Errors = l } diff --git a/paho/client_test.go b/paho/client_test.go index b5f9f42..dabb175 100644 --- a/paho/client_test.go +++ b/paho/client_test.go @@ -2,6 +2,7 @@ package paho import ( "context" + "errors" "log" "os" "testing" @@ -361,9 +362,12 @@ func TestReceiveServerDisconnect(t *testing.T) { c := NewClient(ClientConfig{ Conn: ts.ClientConn(), - OnDisconnect: func(d *Disconnect) { - assert.Equal(t, byte(packets.DisconnectServerShuttingDown), d.ReasonCode) - assert.Equal(t, d.Properties.ReasonString, "GONE!") + OnConnectionLost: func(e error) { + var d *DisconnectError + require.True(t, errors.As(e, &d)) + d = e.(*DisconnectError) + assert.Equal(t, byte(packets.DisconnectServerShuttingDown), d.Disconnect.ReasonCode) + assert.Equal(t, d.Disconnect.Properties.ReasonString, "GONE!") close(rChan) }, }) @@ -422,3 +426,87 @@ func TestAuthenticate(t *testing.T) { time.Sleep(10 * time.Millisecond) } + +func TestClientPublishQoS1Persistence(t *testing.T) { + ts := newTestServer() + ts.SetResponse(packets.PUBACK, &packets.Puback{ + ReasonCode: packets.PubackSuccess, + Properties: &packets.Properties{}, + }) + go ts.Run() + defer ts.Stop() + + ps := newTestPersistence() + + c := NewClient(ClientConfig{ + Conn: ts.ClientConn(), + Persistence: ps, + }) + require.NotNil(t, c) + c.SetDebugLogger(log.New(os.Stderr, "PUBLISHQOS1: ", log.LstdFlags)) + + c.serverInflight = semaphore.NewWeighted(10000) + c.clientInflight = semaphore.NewWeighted(10000) + c.stop = make(chan struct{}) + go c.Incoming() + go c.PingHandler.Start(c.Conn, 30*time.Second) + + p := &Publish{ + Topic: "test/1", + QoS: 1, + Payload: []byte("test payload"), + } + + pa, err := c.Publish(context.Background(), p) + require.Nil(t, err) + assert.Equal(t, uint8(0), pa.ReasonCode) + assert.Equal(t, 1, ps.putCount, "putCount") + assert.Equal(t, 1, ps.deleteCount, "deleteCount") + assert.Equal(t, 0, ps.Len()) + + time.Sleep(10 * time.Millisecond) +} + +func TestClientPublishQoS2Persistence(t *testing.T) { + ts := newTestServer() + ts.SetResponse(packets.PUBREC, &packets.Pubrec{ + ReasonCode: packets.PubrecSuccess, + Properties: &packets.Properties{}, + }) + ts.SetResponse(packets.PUBCOMP, &packets.Pubcomp{ + ReasonCode: packets.PubcompSuccess, + Properties: &packets.Properties{}, + }) + go ts.Run() + defer ts.Stop() + + ps := newTestPersistence() + + c := NewClient(ClientConfig{ + Conn: ts.ClientConn(), + Persistence: ps, + }) + require.NotNil(t, c) + c.SetDebugLogger(log.New(os.Stderr, "PUBLISHQOS2: ", log.LstdFlags)) + + c.serverInflight = semaphore.NewWeighted(10000) + c.clientInflight = semaphore.NewWeighted(10000) + c.stop = make(chan struct{}) + go c.Incoming() + go c.PingHandler.Start(c.Conn, 30*time.Second) + + p := &Publish{ + Topic: "test/2", + QoS: 2, + Payload: []byte("test payload"), + } + + pr, err := c.Publish(context.Background(), p) + require.Nil(t, err) + assert.Equal(t, uint8(0), pr.ReasonCode) + assert.Equal(t, 2, ps.putCount, "putCount") + assert.Equal(t, 1, ps.deleteCount, "deleteCount") + assert.Equal(t, 0, ps.Len()) + + time.Sleep(10 * time.Millisecond) +} diff --git a/paho/cmd/chat/main.go b/paho/cmd/chat/main.go index bbccbef..57247aa 100644 --- a/paho/cmd/chat/main.go +++ b/paho/cmd/chat/main.go @@ -82,7 +82,7 @@ func main() { if _, err := c.Subscribe(context.Background(), &paho.Subscribe{ Subscriptions: map[string]paho.SubscribeOptions{ - *topic: paho.SubscribeOptions{QoS: byte(*qos), NoLocal: true}, + *topic: {QoS: byte(*qos), NoLocal: true}, }, }); err != nil { log.Fatalln(err) diff --git a/paho/errors.go b/paho/errors.go new file mode 100644 index 0000000..307c2dd --- /dev/null +++ b/paho/errors.go @@ -0,0 +1,12 @@ +package paho + +import "fmt" + +type DisconnectError struct { + Disconnect *Disconnect + Err error +} + +func (d *DisconnectError) Error() string { + return fmt.Sprintf("%s - %d: %s", d.Err, d.Disconnect.ReasonCode, d.Disconnect.Properties.ReasonString) +} diff --git a/paho/noop_persistence.go b/paho/noop_persistence.go index d2d1570..f9275d2 100644 --- a/paho/noop_persistence.go +++ b/paho/noop_persistence.go @@ -4,15 +4,19 @@ import "github.com/eclipse/paho.golang/packets" type noopPersistence struct{} -func (n *noopPersistence) Open() {} +func (n *noopPersistence) Open() error { + return nil +} -func (n *noopPersistence) Put(id uint16, cp packets.ControlPacket) {} +func (n *noopPersistence) Put(id uint16, cp packets.Packet) error { + return nil +} -func (n *noopPersistence) Get(id uint16) packets.ControlPacket { - return packets.ControlPacket{} +func (n *noopPersistence) Get(id uint16) packets.Packet { + return nil } -func (n *noopPersistence) All() []packets.ControlPacket { +func (n *noopPersistence) All() []packets.Packet { return nil } diff --git a/paho/persistence.go b/paho/persistence.go index f02b846..7f0447a 100644 --- a/paho/persistence.go +++ b/paho/persistence.go @@ -7,53 +7,59 @@ import ( ) // Persistence is an interface of the functions for a struct -// that is used to persist ControlPackets. +// that is used to persist Packets. // Open() is an initialiser to prepare the Persistence for use -// Put() takes a uint16 which is a messageid and a ControlPacket +// Put() takes a uint16 which is a messageid and a Packet // to persist against that messageid // Get() takes a uint16 which is a messageid and returns the -// persisted ControlPacket from the Persistence for that messageid -// All() returns a slice of all ControlPackets persisted +// persisted Packet from the Persistence for that messageid +// All() returns a slice of all Packets persisted // Delete() takes a uint16 which is a messageid and deletes the -// associated stored ControlPacket from the Persistence +// associated stored Packet from the Persistence // Close() closes the Persistence // Reset() clears the Persistence and prepares it to be reused type Persistence interface { - Open() - Put(uint16, packets.ControlPacket) - Get(uint16) packets.ControlPacket - All() []packets.ControlPacket + Open() error + Put(uint16, packets.Packet) error + Get(uint16) packets.Packet + All() []packets.Packet Delete(uint16) Close() Reset() } // MemoryPersistence is an implementation of a Persistence -// that stores the ControlPackets in memory using a map +// that stores the Packets in memory using a map type MemoryPersistence struct { sync.RWMutex - packets map[uint16]packets.ControlPacket + packets map[uint16]packets.Packet } // Open is the library provided MemoryPersistence's implementation of // the required interface function() -func (m *MemoryPersistence) Open() { +func (m *MemoryPersistence) Open() error { m.Lock() - m.packets = make(map[uint16]packets.ControlPacket) + if m.packets == nil { + m.packets = make(map[uint16]packets.Packet) + } m.Unlock() + + return nil } // Put is the library provided MemoryPersistence's implementation of // the required interface function() -func (m *MemoryPersistence) Put(id uint16, cp packets.ControlPacket) { +func (m *MemoryPersistence) Put(id uint16, cp packets.Packet) error { m.Lock() m.packets[id] = cp m.Unlock() + + return nil } // Get is the library provided MemoryPersistence's implementation of // the required interface function() -func (m *MemoryPersistence) Get(id uint16) packets.ControlPacket { +func (m *MemoryPersistence) Get(id uint16) packets.Packet { m.RLock() defer m.RUnlock() return m.packets[id] @@ -61,10 +67,10 @@ func (m *MemoryPersistence) Get(id uint16) packets.ControlPacket { // All is the library provided MemoryPersistence's implementation of // the required interface function() -func (m *MemoryPersistence) All() []packets.ControlPacket { +func (m *MemoryPersistence) All() []packets.Packet { m.Lock() defer m.RUnlock() - ret := make([]packets.ControlPacket, len(m.packets)) + ret := make([]packets.Packet, len(m.packets)) for _, cp := range m.packets { ret = append(ret, cp) @@ -93,6 +99,6 @@ func (m *MemoryPersistence) Close() { // the required interface function() func (m *MemoryPersistence) Reset() { m.Lock() - m.packets = make(map[uint16]packets.ControlPacket) + m.packets = make(map[uint16]packets.Packet) m.Unlock() } diff --git a/paho/persistence_test.go b/paho/persistence_test.go new file mode 100644 index 0000000..e2b4eac --- /dev/null +++ b/paho/persistence_test.go @@ -0,0 +1,84 @@ +package paho + +import ( + "sync" + + "github.com/eclipse/paho.golang/packets" +) + +type testPersistence struct { + putCount int + getCount int + deleteCount int + sync.RWMutex + packets map[uint16]packets.Packet +} + +func newTestPersistence() *testPersistence { + return &testPersistence{ + packets: make(map[uint16]packets.Packet), + } +} + +func (t *testPersistence) Open() error { + t.Lock() + if t.packets == nil { + t.packets = make(map[uint16]packets.Packet) + } + t.Unlock() + + return nil +} + +func (t *testPersistence) Put(id uint16, cp packets.Packet) error { + t.Lock() + t.putCount++ + t.packets[id] = cp + t.Unlock() + + return nil +} + +func (t *testPersistence) Get(id uint16) packets.Packet { + t.RLock() + defer t.RUnlock() + t.getCount++ + return t.packets[id] +} + +func (t *testPersistence) All() []packets.Packet { + t.Lock() + defer t.RUnlock() + ret := make([]packets.Packet, len(t.packets)) + + for _, cp := range t.packets { + ret = append(ret, cp) + } + + return ret +} + +func (t *testPersistence) Delete(id uint16) { + t.Lock() + delete(t.packets, id) + t.deleteCount++ + t.Unlock() +} + +func (t *testPersistence) Close() { + t.Lock() + t.packets = nil + t.Unlock() +} + +func (t *testPersistence) Reset() { + t.Lock() + t.packets = make(map[uint16]packets.Packet) + t.Unlock() +} + +func (t *testPersistence) Len() int { + t.Lock() + defer t.Unlock() + return len(t.packets) +}