From 5cbb72dc2ad2e96ecf8c92171363e9314ef9ee54 Mon Sep 17 00:00:00 2001 From: Gjermund Garaba Date: Mon, 6 Apr 2026 12:56:15 +0200 Subject: [PATCH] fix: harden accesspoint connection lifecycle Prevent panics and races during AP close/reconnect by: - Initializing stop channels in the constructor instead of on each connect - Adding a permanent closed state that rejects Send/Receive/Connect after Close - Using non-blocking signal sends to avoid goroutine leaks - Protecting conn access with proper locking (closeConn, isStopped helpers) - Resetting pong-ack deadline on reconnect to avoid false timeouts --- ap/ap.go | 121 ++++++++++++++++++++++++++--------- ap/ap_test.go | 174 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+), 29 deletions(-) create mode 100644 ap/ap_test.go diff --git a/ap/ap.go b/ap/ap.go index 5c9bb80b..5a77e091 100644 --- a/ap/ap.go +++ b/ap/ap.go @@ -9,6 +9,7 @@ import ( "crypto/sha1" "encoding/base64" "encoding/binary" + "errors" "fmt" "io" "net" @@ -27,6 +28,8 @@ import ( const pongAckInterval = 120 * time.Second +var ErrAccesspointClosed = errors.New("accesspoint closed") + type AccesspointLoginError struct { Message *pb.APLoginFailed } @@ -48,6 +51,7 @@ type Accesspoint struct { conn net.Conn encConn *shannonConn + closed bool stop bool pongAckTickerStop chan struct{} recvLoopStop chan struct{} @@ -65,7 +69,14 @@ type Accesspoint struct { } func NewAccesspoint(log librespot.Logger, addr librespot.GetAddressFunc, deviceId string) *Accesspoint { - return &Accesspoint{log: log, addr: addr, deviceId: deviceId, recvChans: make(map[PacketType][]chan Packet)} + return &Accesspoint{ + log: log, + addr: addr, + deviceId: deviceId, + pongAckTickerStop: make(chan struct{}, 1), + recvLoopStop: make(chan struct{}, 1), + recvChans: make(map[PacketType][]chan Packet), + } } func (ap *Accesspoint) init(ctx context.Context) (err error) { @@ -188,6 +199,10 @@ func (ap *Accesspoint) Connect(ctx context.Context, creds *pb.LoginCredentials) ap.connMu.Lock() defer ap.connMu.Unlock() + if ap.closed { + return ErrAccesspointClosed + } + return backoff.Retry(func() error { err := ap.connect(ctx, creds) if err != nil { @@ -199,9 +214,6 @@ func (ap *Accesspoint) Connect(ctx context.Context, creds *pb.LoginCredentials) } func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials) error { - ap.recvLoopStop = make(chan struct{}, 1) - ap.pongAckTickerStop = make(chan struct{}, 1) - if err := ap.init(ctx); err != nil { return err } @@ -232,27 +244,33 @@ func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials) func (ap *Accesspoint) Close() { ap.connMu.Lock() - defer ap.connMu.Unlock() - + ap.closed = true ap.stop = true - - if ap.conn == nil { - return - } - - ap.recvLoopStop <- struct{}{} - ap.pongAckTickerStop <- struct{}{} - _ = ap.conn.Close() + ap.signalStop() + ap.closeConnLocked() + ap.connMu.Unlock() } func (ap *Accesspoint) Send(ctx context.Context, pktType PacketType, payload []byte) error { ap.connMu.RLock() defer ap.connMu.RUnlock() + + if ap.closed { + return ErrAccesspointClosed + } + return ap.encConn.sendPacket(ctx, pktType, payload) } func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet { ch := make(chan Packet) + ap.connMu.RLock() + if ap.closed { + ap.connMu.RUnlock() + close(ch) + return ch + } + ap.recvChansLock.Lock() for _, type_ := range types { ll, _ := ap.recvChans[type_] @@ -263,6 +281,7 @@ func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet { // start the recv loop if necessary ap.startReceiving() + ap.connMu.RUnlock() return ch } @@ -270,11 +289,9 @@ func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet { func (ap *Accesspoint) startReceiving() { ap.recvLoopOnce.Do(func() { ap.log.Tracef("starting accesspoint recv loop") - go ap.recvLoop() - - // set last ping in the future - ap.lastPongAck = time.Now().Add(pongAckInterval) + ap.resetPongAckDeadline() go ap.pongAckTicker() + go ap.recvLoop() }) } @@ -288,7 +305,7 @@ loop: // no need to hold the connMu since reconnection happens in this routine pkt, payload, err := ap.encConn.receivePacket(context.TODO()) if err != nil { - if !ap.stop { + if !ap.isStopped() { ap.log.WithError(err).Errorf("failed receiving packet") } @@ -298,15 +315,13 @@ loop: switch pkt { case PacketTypePing: ap.log.Tracef("received accesspoint ping") - if err := ap.encConn.sendPacket(context.TODO(), PacketTypePong, payload); err != nil { + if err := ap.Send(context.TODO(), PacketTypePong, payload); err != nil { ap.log.WithError(err).Errorf("failed sending Pong packet") break loop } case PacketTypePongAck: ap.log.Tracef("received accesspoint pong ack") - ap.lastPongAckLock.Lock() - ap.lastPongAck = time.Now() - ap.lastPongAckLock.Unlock() + ap.notePongAck() continue default: ap.recvChansLock.RLock() @@ -327,10 +342,10 @@ loop: } // always close as we might end up here because of application errors - _ = ap.conn.Close() + ap.closeConn() // if we shouldn't stop, try to reconnect - if !ap.stop { + if !ap.isStopped() { ap.connMu.Lock() if err := backoff.Retry(ap.reconnect, backoff.NewExponentialBackOff()); err != nil { ap.log.WithError(err).Errorf("failed reconnecting accesspoint") @@ -369,15 +384,13 @@ loop: case <-ap.pongAckTickerStop: break loop case <-ticker.C: - ap.lastPongAckLock.Lock() - timePassed := time.Since(ap.lastPongAck) - ap.lastPongAckLock.Unlock() + timePassed := ap.timeSinceLastPongAck() if timePassed > pongAckInterval { ap.log.Errorf("did not receive last pong ack from accesspoint, %.0fs passed", timePassed.Seconds()) // closing the connection should make the read on the "recvLoop" fail, // continue hoping for a new connection - _ = ap.conn.Close() + ap.closeConn() continue } } @@ -399,6 +412,8 @@ func (ap *Accesspoint) reconnect() (err error) { return err } + ap.resetPongAckDeadline() + // if we are here the "recvLoop" has already died, restart it go ap.recvLoop() @@ -406,6 +421,54 @@ func (ap *Accesspoint) reconnect() (err error) { return nil } +func (ap *Accesspoint) resetPongAckDeadline() { + ap.lastPongAckLock.Lock() + ap.lastPongAck = time.Now().Add(pongAckInterval) + ap.lastPongAckLock.Unlock() +} + +func (ap *Accesspoint) notePongAck() { + ap.lastPongAckLock.Lock() + ap.lastPongAck = time.Now() + ap.lastPongAckLock.Unlock() +} + +func (ap *Accesspoint) timeSinceLastPongAck() time.Duration { + ap.lastPongAckLock.Lock() + defer ap.lastPongAckLock.Unlock() + return time.Since(ap.lastPongAck) +} + +func (ap *Accesspoint) closeConn() { + ap.connMu.Lock() + ap.closeConnLocked() + ap.connMu.Unlock() +} + +func (ap *Accesspoint) closeConnLocked() { + if ap.conn != nil { + _ = ap.conn.Close() + } +} + +func (ap *Accesspoint) signalStop() { + select { + case ap.recvLoopStop <- struct{}{}: + default: + } + + select { + case ap.pongAckTickerStop <- struct{}{}: + default: + } +} + +func (ap *Accesspoint) isStopped() bool { + ap.connMu.RLock() + defer ap.connMu.RUnlock() + return ap.stop +} + func (ap *Accesspoint) performKeyExchange() ([]byte, error) { // accumulate transferred data for challenge cc := &connAccumulator{Conn: ap.conn} diff --git a/ap/ap_test.go b/ap/ap_test.go new file mode 100644 index 00000000..ee1708b6 --- /dev/null +++ b/ap/ap_test.go @@ -0,0 +1,174 @@ +package ap + +import ( + "context" + "io" + "net" + "sync" + "testing" + "testing/synctest" + "time" + + librespot "github.com/devgianlu/go-librespot" +) + +type stubAddr string + +func (a stubAddr) Network() string { return string(a) } +func (a stubAddr) String() string { return string(a) } + +type countingConn struct { + mu sync.Mutex + closes int +} + +func (c *countingConn) Read([]byte) (int, error) { return 0, io.EOF } +func (c *countingConn) Write(b []byte) (int, error) { return len(b), nil } +func (c *countingConn) Close() error { c.mu.Lock(); c.closes++; c.mu.Unlock(); return nil } +func (c *countingConn) LocalAddr() net.Addr { return stubAddr("local") } +func (c *countingConn) RemoteAddr() net.Addr { return stubAddr("remote") } +func (c *countingConn) SetDeadline(time.Time) error { return nil } +func (c *countingConn) SetReadDeadline(time.Time) error { return nil } +func (c *countingConn) SetWriteDeadline(time.Time) error { return nil } + +func (c *countingConn) CloseCount() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.closes +} + +type blockingConn struct { + started sync.Once + startCh chan struct{} + blockCh chan struct{} +} + +func newBlockingConn() *blockingConn { + return &blockingConn{ + startCh: make(chan struct{}), + blockCh: make(chan struct{}), + } +} + +func (c *blockingConn) Read([]byte) (int, error) { return 0, io.EOF } +func (c *blockingConn) LocalAddr() net.Addr { return stubAddr("local") } +func (c *blockingConn) RemoteAddr() net.Addr { return stubAddr("remote") } +func (c *blockingConn) SetDeadline(time.Time) error { return nil } +func (c *blockingConn) SetReadDeadline(time.Time) error { return nil } +func (c *blockingConn) SetWriteDeadline(time.Time) error { return nil } + +func (c *blockingConn) Write(b []byte) (int, error) { + c.started.Do(func() { close(c.startCh) }) + <-c.blockCh + return len(b), nil +} + +func (c *blockingConn) Close() error { + return nil +} + +func TestPongAckTickerDoesNotPanicWhenConnNil(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ap := NewAccesspoint(&librespot.NullLogger{}, nil, "") + + panicCh := make(chan any, 1) + go func() { + defer func() { + panicCh <- recover() + }() + ap.pongAckTicker() + }() + + time.Sleep(pongAckInterval + time.Nanosecond) + synctest.Wait() + + select { + case p := <-panicCh: + if p != nil { + t.Fatalf("pongAckTicker panicked when conn was nil: %v", p) + } + default: + } + + ap.pongAckTickerStop <- struct{}{} + synctest.Wait() + + select { + case p := <-panicCh: + if p != nil { + t.Fatalf("pongAckTicker panicked when conn was nil: %v", p) + } + default: + t.Fatal("pongAckTicker did not stop") + } + }) +} + +func TestCloseStopsPongAckTickerWhenConnNil(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ap := NewAccesspoint(&librespot.NullLogger{}, nil, "") + done := make(chan struct{}) + + go func() { + defer close(done) + ap.pongAckTicker() + }() + + synctest.Wait() + ap.Close() + synctest.Wait() + + select { + case <-done: + default: + t.Fatal("pongAckTicker did not stop when closing with nil conn") + } + }) +} + +func TestCloseWaitsForInFlightSend(t *testing.T) { + conn := newBlockingConn() + ap := NewAccesspoint(&librespot.NullLogger{}, nil, "") + ap.conn = conn + ap.encConn = newShannonConn(conn, make([]byte, 32), make([]byte, 32)) + + sendDone := make(chan error, 1) + go func() { + sendDone <- ap.Send(context.Background(), PacketTypePing, []byte("payload")) + }() + + select { + case <-conn.startCh: + case <-time.After(time.Second): + t.Fatal("timed out waiting for send to start") + } + + closeDone := make(chan struct{}) + go func() { + ap.Close() + close(closeDone) + }() + + select { + case <-closeDone: + t.Fatal("close returned before in-flight send finished") + case <-time.After(50 * time.Millisecond): + } + + close(conn.blockCh) + + select { + case err := <-sendDone: + if err != nil { + t.Fatalf("send error = %v", err) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for send to finish") + } + + select { + case <-closeDone: + case <-time.After(time.Second): + t.Fatal("timed out waiting for close to finish") + } +}