Skip to content

Commit 1456d81

Browse files
Merge branch 'devgianlu:master' into master
2 parents 4292571 + aaed97c commit 1456d81

9 files changed

Lines changed: 561 additions & 82 deletions

File tree

ap/ap.go

Lines changed: 92 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"crypto/sha1"
1010
"encoding/base64"
1111
"encoding/binary"
12+
"errors"
1213
"fmt"
1314
"io"
1415
"net"
@@ -27,6 +28,8 @@ import (
2728

2829
const pongAckInterval = 120 * time.Second
2930

31+
var ErrAccesspointClosed = errors.New("accesspoint closed")
32+
3033
type AccesspointLoginError struct {
3134
Message *pb.APLoginFailed
3235
}
@@ -48,6 +51,7 @@ type Accesspoint struct {
4851
conn net.Conn
4952
encConn *shannonConn
5053

54+
closed bool
5155
stop bool
5256
pongAckTickerStop chan struct{}
5357
recvLoopStop chan struct{}
@@ -65,7 +69,14 @@ type Accesspoint struct {
6569
}
6670

6771
func NewAccesspoint(log librespot.Logger, addr librespot.GetAddressFunc, deviceId string) *Accesspoint {
68-
return &Accesspoint{log: log, addr: addr, deviceId: deviceId, recvChans: make(map[PacketType][]chan Packet)}
72+
return &Accesspoint{
73+
log: log,
74+
addr: addr,
75+
deviceId: deviceId,
76+
pongAckTickerStop: make(chan struct{}, 1),
77+
recvLoopStop: make(chan struct{}, 1),
78+
recvChans: make(map[PacketType][]chan Packet),
79+
}
6980
}
7081

7182
func (ap *Accesspoint) init(ctx context.Context) (err error) {
@@ -186,6 +197,10 @@ func (ap *Accesspoint) Connect(ctx context.Context, creds *pb.LoginCredentials)
186197
ap.connMu.Lock()
187198
defer ap.connMu.Unlock()
188199

200+
if ap.closed {
201+
return ErrAccesspointClosed
202+
}
203+
189204
return backoff.Retry(func() error {
190205
err := ap.connect(ctx, creds)
191206
if err != nil {
@@ -197,9 +212,6 @@ func (ap *Accesspoint) Connect(ctx context.Context, creds *pb.LoginCredentials)
197212
}
198213

199214
func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials) error {
200-
ap.recvLoopStop = make(chan struct{}, 1)
201-
ap.pongAckTickerStop = make(chan struct{}, 1)
202-
203215
if err := ap.init(ctx); err != nil {
204216
return err
205217
}
@@ -230,27 +242,33 @@ func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials)
230242

231243
func (ap *Accesspoint) Close() {
232244
ap.connMu.Lock()
233-
defer ap.connMu.Unlock()
234-
245+
ap.closed = true
235246
ap.stop = true
236-
237-
if ap.conn == nil {
238-
return
239-
}
240-
241-
ap.recvLoopStop <- struct{}{}
242-
ap.pongAckTickerStop <- struct{}{}
243-
_ = ap.conn.Close()
247+
ap.signalStop()
248+
ap.closeConnLocked()
249+
ap.connMu.Unlock()
244250
}
245251

246252
func (ap *Accesspoint) Send(ctx context.Context, pktType PacketType, payload []byte) error {
247253
ap.connMu.RLock()
248254
defer ap.connMu.RUnlock()
255+
256+
if ap.closed {
257+
return ErrAccesspointClosed
258+
}
259+
249260
return ap.encConn.sendPacket(ctx, pktType, payload)
250261
}
251262

252263
func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet {
253264
ch := make(chan Packet)
265+
ap.connMu.RLock()
266+
if ap.closed {
267+
ap.connMu.RUnlock()
268+
close(ch)
269+
return ch
270+
}
271+
254272
ap.recvChansLock.Lock()
255273
for _, type_ := range types {
256274
ll, _ := ap.recvChans[type_]
@@ -261,18 +279,17 @@ func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet {
261279

262280
// start the recv loop if necessary
263281
ap.startReceiving()
282+
ap.connMu.RUnlock()
264283

265284
return ch
266285
}
267286

268287
func (ap *Accesspoint) startReceiving() {
269288
ap.recvLoopOnce.Do(func() {
270289
ap.log.Tracef("starting accesspoint recv loop")
271-
go ap.recvLoop()
272-
273-
// set last ping in the future
274-
ap.lastPongAck = time.Now().Add(pongAckInterval)
290+
ap.resetPongAckDeadline()
275291
go ap.pongAckTicker()
292+
go ap.recvLoop()
276293
})
277294
}
278295

@@ -286,7 +303,7 @@ loop:
286303
// no need to hold the connMu since reconnection happens in this routine
287304
pkt, payload, err := ap.encConn.receivePacket(context.TODO())
288305
if err != nil {
289-
if !ap.stop {
306+
if !ap.isStopped() {
290307
ap.log.WithError(err).Errorf("failed receiving packet")
291308
}
292309

@@ -296,15 +313,13 @@ loop:
296313
switch pkt {
297314
case PacketTypePing:
298315
ap.log.Tracef("received accesspoint ping")
299-
if err := ap.encConn.sendPacket(context.TODO(), PacketTypePong, payload); err != nil {
316+
if err := ap.Send(context.TODO(), PacketTypePong, payload); err != nil {
300317
ap.log.WithError(err).Errorf("failed sending Pong packet")
301318
break loop
302319
}
303320
case PacketTypePongAck:
304321
ap.log.Tracef("received accesspoint pong ack")
305-
ap.lastPongAckLock.Lock()
306-
ap.lastPongAck = time.Now()
307-
ap.lastPongAckLock.Unlock()
322+
ap.notePongAck()
308323
continue
309324
default:
310325
ap.recvChansLock.RLock()
@@ -325,10 +340,10 @@ loop:
325340
}
326341

327342
// always close as we might end up here because of application errors
328-
_ = ap.conn.Close()
343+
ap.closeConn()
329344

330345
// if we shouldn't stop, try to reconnect
331-
if !ap.stop {
346+
if !ap.isStopped() {
332347
ap.connMu.Lock()
333348
if err := backoff.Retry(ap.reconnect, backoff.NewExponentialBackOff()); err != nil {
334349
ap.log.WithError(err).Errorf("failed reconnecting accesspoint")
@@ -367,15 +382,13 @@ loop:
367382
case <-ap.pongAckTickerStop:
368383
break loop
369384
case <-ticker.C:
370-
ap.lastPongAckLock.Lock()
371-
timePassed := time.Since(ap.lastPongAck)
372-
ap.lastPongAckLock.Unlock()
385+
timePassed := ap.timeSinceLastPongAck()
373386
if timePassed > pongAckInterval {
374387
ap.log.Errorf("did not receive last pong ack from accesspoint, %.0fs passed", timePassed.Seconds())
375388

376389
// closing the connection should make the read on the "recvLoop" fail,
377390
// continue hoping for a new connection
378-
_ = ap.conn.Close()
391+
ap.closeConn()
379392
continue
380393
}
381394
}
@@ -397,13 +410,63 @@ func (ap *Accesspoint) reconnect() (err error) {
397410
return err
398411
}
399412

413+
ap.resetPongAckDeadline()
414+
400415
// if we are here the "recvLoop" has already died, restart it
401416
go ap.recvLoop()
402417

403418
ap.log.Debugf("re-established accesspoint connection")
404419
return nil
405420
}
406421

422+
func (ap *Accesspoint) resetPongAckDeadline() {
423+
ap.lastPongAckLock.Lock()
424+
ap.lastPongAck = time.Now().Add(pongAckInterval)
425+
ap.lastPongAckLock.Unlock()
426+
}
427+
428+
func (ap *Accesspoint) notePongAck() {
429+
ap.lastPongAckLock.Lock()
430+
ap.lastPongAck = time.Now()
431+
ap.lastPongAckLock.Unlock()
432+
}
433+
434+
func (ap *Accesspoint) timeSinceLastPongAck() time.Duration {
435+
ap.lastPongAckLock.Lock()
436+
defer ap.lastPongAckLock.Unlock()
437+
return time.Since(ap.lastPongAck)
438+
}
439+
440+
func (ap *Accesspoint) closeConn() {
441+
ap.connMu.Lock()
442+
ap.closeConnLocked()
443+
ap.connMu.Unlock()
444+
}
445+
446+
func (ap *Accesspoint) closeConnLocked() {
447+
if ap.conn != nil {
448+
_ = ap.conn.Close()
449+
}
450+
}
451+
452+
func (ap *Accesspoint) signalStop() {
453+
select {
454+
case ap.recvLoopStop <- struct{}{}:
455+
default:
456+
}
457+
458+
select {
459+
case ap.pongAckTickerStop <- struct{}{}:
460+
default:
461+
}
462+
}
463+
464+
func (ap *Accesspoint) isStopped() bool {
465+
ap.connMu.RLock()
466+
defer ap.connMu.RUnlock()
467+
return ap.stop
468+
}
469+
407470
func (ap *Accesspoint) performKeyExchange() ([]byte, error) {
408471
// accumulate transferred data for challenge
409472
cc := &connAccumulator{Conn: ap.conn}

0 commit comments

Comments
 (0)