Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,19 @@ import (
"golang.org/x/crypto/hkdf"
)

type CloseWriteConn interface {
net.Conn
type canCloseWrite interface {
CloseWrite() error
}

func tryCloseWrite(c net.Conn, forceCloseOnFail bool) error {
if conn, ok := c.(canCloseWrite); ok {
return conn.CloseWrite()
} else if forceCloseOnFail {
return c.Close()
}
return nil
}

type MirrorConn struct {
*sync.Mutex
net.Conn
Expand Down Expand Up @@ -183,7 +191,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
if pc, ok := conn.(*proxyproto.Conn); ok {
raw = pc.Raw() // for TCP splicing in io.Copy()
}
underlying := raw.(CloseWriteConn) // *net.TCPConn or *net.UnixConn
underlying := raw // *net.TCPConn or *net.UnixConn or sth strange (from tcp mask)

mutex := new(sync.Mutex)

Expand Down Expand Up @@ -275,10 +283,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
_, err := io.Copy(target, NewRatelimitedConn(underlying, &config.LimitFallbackUpload))
// close target writer when received FIN (err==nil)
if err == nil {
targetWriterCloser, ok := target.(CloseWriteConn)
if ok {
targetWriterCloser.CloseWrite()
}
tryCloseWrite(target, false)
} else {
// Close target when encountering RST (or any other errors)
target.Close()
Expand Down Expand Up @@ -443,7 +448,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
// Here is bidirectional direct forwarding:
// client ---underlying--- server ---target--- dest
// Call `underlying.CloseWrite()` once `io.Copy()` returned
underlying.CloseWrite()
tryCloseWrite(underlying, true)
}
waitGroup.Done()
}()
Expand Down