Skip to content

Commit 5cbb72d

Browse files
fix: harden accesspoint connection lifecycle
Prevent panics and races during AP close/reconnect by:
- Initializing stop channels in the constructor instead of on each connect
- Adding a permanent closed state that rejects Send/Receive/Connect after Close
- Using non-blocking signal sends to avoid goroutine leaks
- Protecting conn access with proper locking (closeConn, isStopped helpers)
- Resetting the pong-ack deadline on reconnect to avoid false timeouts
1 parent b030f61 commit 5cbb72d

2 files changed

Lines changed: 266 additions & 29 deletions

File tree

ap/ap.go

Lines changed: 92 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"crypto/sha1"
1010
"encoding/base64"
1111
"encoding/binary"
12+
"errors"
1213
"fmt"
1314
"io"
1415
"net"
@@ -27,6 +28,8 @@ import (
2728

2829
const pongAckInterval = 120 * time.Second
2930

31+
var ErrAccesspointClosed = errors.New("accesspoint closed")
32+
3033
type AccesspointLoginError struct {
3134
Message *pb.APLoginFailed
3235
}
@@ -48,6 +51,7 @@ type Accesspoint struct {
4851
conn net.Conn
4952
encConn *shannonConn
5053

54+
closed bool
5155
stop bool
5256
pongAckTickerStop chan struct{}
5357
recvLoopStop chan struct{}
@@ -65,7 +69,14 @@ type Accesspoint struct {
6569
}
6670

6771
func NewAccesspoint(log librespot.Logger, addr librespot.GetAddressFunc, deviceId string) *Accesspoint {
68-
return &Accesspoint{log: log, addr: addr, deviceId: deviceId, recvChans: make(map[PacketType][]chan Packet)}
72+
return &Accesspoint{
73+
log: log,
74+
addr: addr,
75+
deviceId: deviceId,
76+
pongAckTickerStop: make(chan struct{}, 1),
77+
recvLoopStop: make(chan struct{}, 1),
78+
recvChans: make(map[PacketType][]chan Packet),
79+
}
6980
}
7081

7182
func (ap *Accesspoint) init(ctx context.Context) (err error) {
@@ -188,6 +199,10 @@ func (ap *Accesspoint) Connect(ctx context.Context, creds *pb.LoginCredentials)
188199
ap.connMu.Lock()
189200
defer ap.connMu.Unlock()
190201

202+
if ap.closed {
203+
return ErrAccesspointClosed
204+
}
205+
191206
return backoff.Retry(func() error {
192207
err := ap.connect(ctx, creds)
193208
if err != nil {
@@ -199,9 +214,6 @@ func (ap *Accesspoint) Connect(ctx context.Context, creds *pb.LoginCredentials)
199214
}
200215

201216
func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials) error {
202-
ap.recvLoopStop = make(chan struct{}, 1)
203-
ap.pongAckTickerStop = make(chan struct{}, 1)
204-
205217
if err := ap.init(ctx); err != nil {
206218
return err
207219
}
@@ -232,27 +244,33 @@ func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials)
232244

233245
func (ap *Accesspoint) Close() {
234246
ap.connMu.Lock()
235-
defer ap.connMu.Unlock()
236-
247+
ap.closed = true
237248
ap.stop = true
238-
239-
if ap.conn == nil {
240-
return
241-
}
242-
243-
ap.recvLoopStop <- struct{}{}
244-
ap.pongAckTickerStop <- struct{}{}
245-
_ = ap.conn.Close()
249+
ap.signalStop()
250+
ap.closeConnLocked()
251+
ap.connMu.Unlock()
246252
}
247253

248254
func (ap *Accesspoint) Send(ctx context.Context, pktType PacketType, payload []byte) error {
249255
ap.connMu.RLock()
250256
defer ap.connMu.RUnlock()
257+
258+
if ap.closed {
259+
return ErrAccesspointClosed
260+
}
261+
251262
return ap.encConn.sendPacket(ctx, pktType, payload)
252263
}
253264

254265
func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet {
255266
ch := make(chan Packet)
267+
ap.connMu.RLock()
268+
if ap.closed {
269+
ap.connMu.RUnlock()
270+
close(ch)
271+
return ch
272+
}
273+
256274
ap.recvChansLock.Lock()
257275
for _, type_ := range types {
258276
ll, _ := ap.recvChans[type_]
@@ -263,18 +281,17 @@ func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet {
263281

264282
// start the recv loop if necessary
265283
ap.startReceiving()
284+
ap.connMu.RUnlock()
266285

267286
return ch
268287
}
269288

270289
func (ap *Accesspoint) startReceiving() {
271290
ap.recvLoopOnce.Do(func() {
272291
ap.log.Tracef("starting accesspoint recv loop")
273-
go ap.recvLoop()
274-
275-
// set last ping in the future
276-
ap.lastPongAck = time.Now().Add(pongAckInterval)
292+
ap.resetPongAckDeadline()
277293
go ap.pongAckTicker()
294+
go ap.recvLoop()
278295
})
279296
}
280297

@@ -288,7 +305,7 @@ loop:
288305
// no need to hold the connMu since reconnection happens in this routine
289306
pkt, payload, err := ap.encConn.receivePacket(context.TODO())
290307
if err != nil {
291-
if !ap.stop {
308+
if !ap.isStopped() {
292309
ap.log.WithError(err).Errorf("failed receiving packet")
293310
}
294311

@@ -298,15 +315,13 @@ loop:
298315
switch pkt {
299316
case PacketTypePing:
300317
ap.log.Tracef("received accesspoint ping")
301-
if err := ap.encConn.sendPacket(context.TODO(), PacketTypePong, payload); err != nil {
318+
if err := ap.Send(context.TODO(), PacketTypePong, payload); err != nil {
302319
ap.log.WithError(err).Errorf("failed sending Pong packet")
303320
break loop
304321
}
305322
case PacketTypePongAck:
306323
ap.log.Tracef("received accesspoint pong ack")
307-
ap.lastPongAckLock.Lock()
308-
ap.lastPongAck = time.Now()
309-
ap.lastPongAckLock.Unlock()
324+
ap.notePongAck()
310325
continue
311326
default:
312327
ap.recvChansLock.RLock()
@@ -327,10 +342,10 @@ loop:
327342
}
328343

329344
// always close as we might end up here because of application errors
330-
_ = ap.conn.Close()
345+
ap.closeConn()
331346

332347
// if we shouldn't stop, try to reconnect
333-
if !ap.stop {
348+
if !ap.isStopped() {
334349
ap.connMu.Lock()
335350
if err := backoff.Retry(ap.reconnect, backoff.NewExponentialBackOff()); err != nil {
336351
ap.log.WithError(err).Errorf("failed reconnecting accesspoint")
@@ -369,15 +384,13 @@ loop:
369384
case <-ap.pongAckTickerStop:
370385
break loop
371386
case <-ticker.C:
372-
ap.lastPongAckLock.Lock()
373-
timePassed := time.Since(ap.lastPongAck)
374-
ap.lastPongAckLock.Unlock()
387+
timePassed := ap.timeSinceLastPongAck()
375388
if timePassed > pongAckInterval {
376389
ap.log.Errorf("did not receive last pong ack from accesspoint, %.0fs passed", timePassed.Seconds())
377390

378391
// closing the connection should make the read on the "recvLoop" fail,
379392
// continue hoping for a new connection
380-
_ = ap.conn.Close()
393+
ap.closeConn()
381394
continue
382395
}
383396
}
@@ -399,13 +412,63 @@ func (ap *Accesspoint) reconnect() (err error) {
399412
return err
400413
}
401414

415+
ap.resetPongAckDeadline()
416+
402417
// if we are here the "recvLoop" has already died, restart it
403418
go ap.recvLoop()
404419

405420
ap.log.Debugf("re-established accesspoint connection")
406421
return nil
407422
}
408423

424+
func (ap *Accesspoint) resetPongAckDeadline() {
425+
ap.lastPongAckLock.Lock()
426+
ap.lastPongAck = time.Now().Add(pongAckInterval)
427+
ap.lastPongAckLock.Unlock()
428+
}
429+
430+
func (ap *Accesspoint) notePongAck() {
431+
ap.lastPongAckLock.Lock()
432+
ap.lastPongAck = time.Now()
433+
ap.lastPongAckLock.Unlock()
434+
}
435+
436+
func (ap *Accesspoint) timeSinceLastPongAck() time.Duration {
437+
ap.lastPongAckLock.Lock()
438+
defer ap.lastPongAckLock.Unlock()
439+
return time.Since(ap.lastPongAck)
440+
}
441+
442+
func (ap *Accesspoint) closeConn() {
443+
ap.connMu.Lock()
444+
ap.closeConnLocked()
445+
ap.connMu.Unlock()
446+
}
447+
448+
func (ap *Accesspoint) closeConnLocked() {
449+
if ap.conn != nil {
450+
_ = ap.conn.Close()
451+
}
452+
}
453+
454+
func (ap *Accesspoint) signalStop() {
455+
select {
456+
case ap.recvLoopStop <- struct{}{}:
457+
default:
458+
}
459+
460+
select {
461+
case ap.pongAckTickerStop <- struct{}{}:
462+
default:
463+
}
464+
}
465+
466+
func (ap *Accesspoint) isStopped() bool {
467+
ap.connMu.RLock()
468+
defer ap.connMu.RUnlock()
469+
return ap.stop
470+
}
471+
409472
func (ap *Accesspoint) performKeyExchange() ([]byte, error) {
410473
// accumulate transferred data for challenge
411474
cc := &connAccumulator{Conn: ap.conn}

0 commit comments

Comments
 (0)