Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 92 additions & 29 deletions ap/ap.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
Expand All @@ -27,6 +28,8 @@ import (

const pongAckInterval = 120 * time.Second

var ErrAccesspointClosed = errors.New("accesspoint closed")

type AccesspointLoginError struct {
Message *pb.APLoginFailed
}
Expand All @@ -48,6 +51,7 @@ type Accesspoint struct {
conn net.Conn
encConn *shannonConn

closed bool
stop bool
pongAckTickerStop chan struct{}
recvLoopStop chan struct{}
Expand All @@ -65,7 +69,14 @@ type Accesspoint struct {
}

func NewAccesspoint(log librespot.Logger, addr librespot.GetAddressFunc, deviceId string) *Accesspoint {
return &Accesspoint{log: log, addr: addr, deviceId: deviceId, recvChans: make(map[PacketType][]chan Packet)}
return &Accesspoint{
log: log,
addr: addr,
deviceId: deviceId,
pongAckTickerStop: make(chan struct{}, 1),
recvLoopStop: make(chan struct{}, 1),
recvChans: make(map[PacketType][]chan Packet),
}
}

func (ap *Accesspoint) init(ctx context.Context) (err error) {
Expand Down Expand Up @@ -188,6 +199,10 @@ func (ap *Accesspoint) Connect(ctx context.Context, creds *pb.LoginCredentials)
ap.connMu.Lock()
defer ap.connMu.Unlock()

if ap.closed {
return ErrAccesspointClosed
}

return backoff.Retry(func() error {
err := ap.connect(ctx, creds)
if err != nil {
Expand All @@ -199,9 +214,6 @@ func (ap *Accesspoint) Connect(ctx context.Context, creds *pb.LoginCredentials)
}

func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials) error {
ap.recvLoopStop = make(chan struct{}, 1)
ap.pongAckTickerStop = make(chan struct{}, 1)

if err := ap.init(ctx); err != nil {
return err
}
Expand Down Expand Up @@ -232,27 +244,33 @@ func (ap *Accesspoint) connect(ctx context.Context, creds *pb.LoginCredentials)

func (ap *Accesspoint) Close() {
ap.connMu.Lock()
defer ap.connMu.Unlock()

ap.closed = true
ap.stop = true

if ap.conn == nil {
return
}

ap.recvLoopStop <- struct{}{}
ap.pongAckTickerStop <- struct{}{}
_ = ap.conn.Close()
ap.signalStop()
ap.closeConnLocked()
ap.connMu.Unlock()
}

func (ap *Accesspoint) Send(ctx context.Context, pktType PacketType, payload []byte) error {
ap.connMu.RLock()
defer ap.connMu.RUnlock()

if ap.closed {
return ErrAccesspointClosed
}

return ap.encConn.sendPacket(ctx, pktType, payload)
}

func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet {
ch := make(chan Packet)
ap.connMu.RLock()
if ap.closed {
ap.connMu.RUnlock()
close(ch)
return ch
}

ap.recvChansLock.Lock()
for _, type_ := range types {
ll, _ := ap.recvChans[type_]
Expand All @@ -263,18 +281,17 @@ func (ap *Accesspoint) Receive(types ...PacketType) <-chan Packet {

// start the recv loop if necessary
ap.startReceiving()
ap.connMu.RUnlock()

return ch
}

func (ap *Accesspoint) startReceiving() {
ap.recvLoopOnce.Do(func() {
ap.log.Tracef("starting accesspoint recv loop")
go ap.recvLoop()

// set last ping in the future
ap.lastPongAck = time.Now().Add(pongAckInterval)
ap.resetPongAckDeadline()
go ap.pongAckTicker()
go ap.recvLoop()
})
}

Expand All @@ -288,7 +305,7 @@ loop:
// no need to hold the connMu since reconnection happens in this routine
pkt, payload, err := ap.encConn.receivePacket(context.TODO())
if err != nil {
if !ap.stop {
if !ap.isStopped() {
ap.log.WithError(err).Errorf("failed receiving packet")
}

Expand All @@ -298,15 +315,13 @@ loop:
switch pkt {
case PacketTypePing:
ap.log.Tracef("received accesspoint ping")
if err := ap.encConn.sendPacket(context.TODO(), PacketTypePong, payload); err != nil {
if err := ap.Send(context.TODO(), PacketTypePong, payload); err != nil {
ap.log.WithError(err).Errorf("failed sending Pong packet")
break loop
}
case PacketTypePongAck:
ap.log.Tracef("received accesspoint pong ack")
ap.lastPongAckLock.Lock()
ap.lastPongAck = time.Now()
ap.lastPongAckLock.Unlock()
ap.notePongAck()
continue
default:
ap.recvChansLock.RLock()
Expand All @@ -327,10 +342,10 @@ loop:
}

// always close as we might end up here because of application errors
_ = ap.conn.Close()
ap.closeConn()

// if we shouldn't stop, try to reconnect
if !ap.stop {
if !ap.isStopped() {
ap.connMu.Lock()
if err := backoff.Retry(ap.reconnect, backoff.NewExponentialBackOff()); err != nil {
ap.log.WithError(err).Errorf("failed reconnecting accesspoint")
Expand Down Expand Up @@ -369,15 +384,13 @@ loop:
case <-ap.pongAckTickerStop:
break loop
case <-ticker.C:
ap.lastPongAckLock.Lock()
timePassed := time.Since(ap.lastPongAck)
ap.lastPongAckLock.Unlock()
timePassed := ap.timeSinceLastPongAck()
if timePassed > pongAckInterval {
ap.log.Errorf("did not receive last pong ack from accesspoint, %.0fs passed", timePassed.Seconds())

// closing the connection should make the read on the "recvLoop" fail,
// continue hoping for a new connection
_ = ap.conn.Close()
ap.closeConn()
continue
}
}
Expand All @@ -399,13 +412,63 @@ func (ap *Accesspoint) reconnect() (err error) {
return err
}

ap.resetPongAckDeadline()

// if we are here the "recvLoop" has already died, restart it
go ap.recvLoop()

ap.log.Debugf("re-established accesspoint connection")
return nil
}

func (ap *Accesspoint) resetPongAckDeadline() {
ap.lastPongAckLock.Lock()
ap.lastPongAck = time.Now().Add(pongAckInterval)
ap.lastPongAckLock.Unlock()
}

func (ap *Accesspoint) notePongAck() {
ap.lastPongAckLock.Lock()
ap.lastPongAck = time.Now()
ap.lastPongAckLock.Unlock()
}

func (ap *Accesspoint) timeSinceLastPongAck() time.Duration {
ap.lastPongAckLock.Lock()
defer ap.lastPongAckLock.Unlock()
return time.Since(ap.lastPongAck)
}

func (ap *Accesspoint) closeConn() {
ap.connMu.Lock()
ap.closeConnLocked()
ap.connMu.Unlock()
}

func (ap *Accesspoint) closeConnLocked() {
if ap.conn != nil {
_ = ap.conn.Close()
}
}

func (ap *Accesspoint) signalStop() {
select {
case ap.recvLoopStop <- struct{}{}:
default:
}

select {
case ap.pongAckTickerStop <- struct{}{}:
default:
}
}

func (ap *Accesspoint) isStopped() bool {
ap.connMu.RLock()
defer ap.connMu.RUnlock()
return ap.stop
}

func (ap *Accesspoint) performKeyExchange() ([]byte, error) {
// accumulate transferred data for challenge
cc := &connAccumulator{Conn: ap.conn}
Expand Down
Loading
Loading