99 "crypto/sha1"
1010 "encoding/base64"
1111 "encoding/binary"
12+ "errors"
1213 "fmt"
1314 "io"
1415 "net"
@@ -27,6 +28,8 @@ import (
2728
2829const pongAckInterval = 120 * time .Second
2930
31+ var ErrAccesspointClosed = errors .New ("accesspoint closed" )
32+
3033type AccesspointLoginError struct {
3134 Message * pb.APLoginFailed
3235}
@@ -48,6 +51,7 @@ type Accesspoint struct {
4851 conn net.Conn
4952 encConn * shannonConn
5053
54+ closed bool
5155 stop bool
5256 pongAckTickerStop chan struct {}
5357 recvLoopStop chan struct {}
@@ -65,7 +69,14 @@ type Accesspoint struct {
6569}
6670
6771func NewAccesspoint (log librespot.Logger , addr librespot.GetAddressFunc , deviceId string ) * Accesspoint {
68- return & Accesspoint {log : log , addr : addr , deviceId : deviceId , recvChans : make (map [PacketType ][]chan Packet )}
72+ return & Accesspoint {
73+ log : log ,
74+ addr : addr ,
75+ deviceId : deviceId ,
76+ pongAckTickerStop : make (chan struct {}, 1 ),
77+ recvLoopStop : make (chan struct {}, 1 ),
78+ recvChans : make (map [PacketType ][]chan Packet ),
79+ }
6980}
7081
7182func (ap * Accesspoint ) init (ctx context.Context ) (err error ) {
@@ -188,6 +199,10 @@ func (ap *Accesspoint) Connect(ctx context.Context, creds *pb.LoginCredentials)
188199 ap .connMu .Lock ()
189200 defer ap .connMu .Unlock ()
190201
202+ if ap .closed {
203+ return ErrAccesspointClosed
204+ }
205+
191206 return backoff .Retry (func () error {
192207 err := ap .connect (ctx , creds )
193208 if err != nil {
@@ -199,9 +214,6 @@ func (ap *Accesspoint) Connect(ctx context.Context, creds *pb.LoginCredentials)
199214}
200215
201216func (ap * Accesspoint ) connect (ctx context.Context , creds * pb.LoginCredentials ) error {
202- ap .recvLoopStop = make (chan struct {}, 1 )
203- ap .pongAckTickerStop = make (chan struct {}, 1 )
204-
205217 if err := ap .init (ctx ); err != nil {
206218 return err
207219 }
@@ -232,27 +244,33 @@ func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials)
232244
233245func (ap * Accesspoint ) Close () {
234246 ap .connMu .Lock ()
235- defer ap .connMu .Unlock ()
236-
247+ ap .closed = true
237248 ap .stop = true
238-
239- if ap .conn == nil {
240- return
241- }
242-
243- ap .recvLoopStop <- struct {}{}
244- ap .pongAckTickerStop <- struct {}{}
245- _ = ap .conn .Close ()
249+ ap .signalStop ()
250+ ap .closeConnLocked ()
251+ ap .connMu .Unlock ()
246252}
247253
248254func (ap * Accesspoint ) Send (ctx context.Context , pktType PacketType , payload []byte ) error {
249255 ap .connMu .RLock ()
250256 defer ap .connMu .RUnlock ()
257+
258+ if ap .closed {
259+ return ErrAccesspointClosed
260+ }
261+
251262 return ap .encConn .sendPacket (ctx , pktType , payload )
252263}
253264
254265func (ap * Accesspoint ) Receive (types ... PacketType ) <- chan Packet {
255266 ch := make (chan Packet )
267+ ap .connMu .RLock ()
268+ if ap .closed {
269+ ap .connMu .RUnlock ()
270+ close (ch )
271+ return ch
272+ }
273+
256274 ap .recvChansLock .Lock ()
257275 for _ , type_ := range types {
258276 ll , _ := ap .recvChans [type_ ]
@@ -263,18 +281,17 @@ func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet {
263281
264282 // start the recv loop if necessary
265283 ap .startReceiving ()
284+ ap .connMu .RUnlock ()
266285
267286 return ch
268287}
269288
270289func (ap * Accesspoint ) startReceiving () {
271290 ap .recvLoopOnce .Do (func () {
272291 ap .log .Tracef ("starting accesspoint recv loop" )
273- go ap .recvLoop ()
274-
275- // set last ping in the future
276- ap .lastPongAck = time .Now ().Add (pongAckInterval )
292+ ap .resetPongAckDeadline ()
277293 go ap .pongAckTicker ()
294+ go ap .recvLoop ()
278295 })
279296}
280297
@@ -288,7 +305,7 @@ loop:
288305 // no need to hold the connMu since reconnection happens in this routine
289306 pkt , payload , err := ap .encConn .receivePacket (context .TODO ())
290307 if err != nil {
291- if ! ap .stop {
308+ if ! ap .isStopped () {
292309 ap .log .WithError (err ).Errorf ("failed receiving packet" )
293310 }
294311
@@ -298,15 +315,13 @@ loop:
298315 switch pkt {
299316 case PacketTypePing :
300317 ap .log .Tracef ("received accesspoint ping" )
301- if err := ap .encConn . sendPacket (context .TODO (), PacketTypePong , payload ); err != nil {
318+ if err := ap .Send (context .TODO (), PacketTypePong , payload ); err != nil {
302319 ap .log .WithError (err ).Errorf ("failed sending Pong packet" )
303320 break loop
304321 }
305322 case PacketTypePongAck :
306323 ap .log .Tracef ("received accesspoint pong ack" )
307- ap .lastPongAckLock .Lock ()
308- ap .lastPongAck = time .Now ()
309- ap .lastPongAckLock .Unlock ()
324+ ap .notePongAck ()
310325 continue
311326 default :
312327 ap .recvChansLock .RLock ()
@@ -327,10 +342,10 @@ loop:
327342 }
328343
329344 // always close as we might end up here because of application errors
330- _ = ap .conn . Close ()
345+ ap .closeConn ()
331346
332347 // if we shouldn't stop, try to reconnect
333- if ! ap .stop {
348+ if ! ap .isStopped () {
334349 ap .connMu .Lock ()
335350 if err := backoff .Retry (ap .reconnect , backoff .NewExponentialBackOff ()); err != nil {
336351 ap .log .WithError (err ).Errorf ("failed reconnecting accesspoint" )
@@ -369,15 +384,13 @@ loop:
369384 case <- ap .pongAckTickerStop :
370385 break loop
371386 case <- ticker .C :
372- ap .lastPongAckLock .Lock ()
373- timePassed := time .Since (ap .lastPongAck )
374- ap .lastPongAckLock .Unlock ()
387+ timePassed := ap .timeSinceLastPongAck ()
375388 if timePassed > pongAckInterval {
376389 ap .log .Errorf ("did not receive last pong ack from accesspoint, %.0fs passed" , timePassed .Seconds ())
377390
378391 // closing the connection should make the read on the "recvLoop" fail,
379392 // continue hoping for a new connection
380- _ = ap .conn . Close ()
393+ ap .closeConn ()
381394 continue
382395 }
383396 }
@@ -399,13 +412,63 @@ func (ap *Accesspoint) reconnect() (err error) {
399412 return err
400413 }
401414
415+ ap .resetPongAckDeadline ()
416+
402417 // if we are here the "recvLoop" has already died, restart it
403418 go ap .recvLoop ()
404419
405420 ap .log .Debugf ("re-established accesspoint connection" )
406421 return nil
407422}
408423
424+ func (ap * Accesspoint ) resetPongAckDeadline () {
425+ ap .lastPongAckLock .Lock ()
426+ ap .lastPongAck = time .Now ().Add (pongAckInterval )
427+ ap .lastPongAckLock .Unlock ()
428+ }
429+
430+ func (ap * Accesspoint ) notePongAck () {
431+ ap .lastPongAckLock .Lock ()
432+ ap .lastPongAck = time .Now ()
433+ ap .lastPongAckLock .Unlock ()
434+ }
435+
436+ func (ap * Accesspoint ) timeSinceLastPongAck () time.Duration {
437+ ap .lastPongAckLock .Lock ()
438+ defer ap .lastPongAckLock .Unlock ()
439+ return time .Since (ap .lastPongAck )
440+ }
441+
442+ func (ap * Accesspoint ) closeConn () {
443+ ap .connMu .Lock ()
444+ ap .closeConnLocked ()
445+ ap .connMu .Unlock ()
446+ }
447+
448+ func (ap * Accesspoint ) closeConnLocked () {
449+ if ap .conn != nil {
450+ _ = ap .conn .Close ()
451+ }
452+ }
453+
454+ func (ap * Accesspoint ) signalStop () {
455+ select {
456+ case ap .recvLoopStop <- struct {}{}:
457+ default :
458+ }
459+
460+ select {
461+ case ap .pongAckTickerStop <- struct {}{}:
462+ default :
463+ }
464+ }
465+
466+ func (ap * Accesspoint ) isStopped () bool {
467+ ap .connMu .RLock ()
468+ defer ap .connMu .RUnlock ()
469+ return ap .stop
470+ }
471+
409472func (ap * Accesspoint ) performKeyExchange () ([]byte , error ) {
410473 // accumulate transferred data for challenge
411474 cc := & connAccumulator {Conn : ap .conn }
0 commit comments