add
This commit is contained in:
@@ -0,0 +1,219 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type ConnectionInfo struct {
|
||||
ID string
|
||||
Conn net.Conn
|
||||
PacketConn network.PacketConn
|
||||
Inbound string
|
||||
Type string // "tcp" or "udp"
|
||||
}
|
||||
|
||||
type ConnTracker struct {
|
||||
access sync.Mutex
|
||||
connections map[string]*ConnectionInfo
|
||||
}
|
||||
|
||||
func NewConnTracker() *ConnTracker {
|
||||
return &ConnTracker{
|
||||
connections: make(map[string]*ConnectionInfo),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConnTracker) Reset() {
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
for _, connInfo := range c.connections {
|
||||
if connInfo.Conn != nil {
|
||||
_ = connInfo.Conn.Close()
|
||||
}
|
||||
if connInfo.PacketConn != nil {
|
||||
_ = connInfo.PacketConn.Close()
|
||||
}
|
||||
}
|
||||
c.connections = make(map[string]*ConnectionInfo)
|
||||
}
|
||||
|
||||
func (c *ConnTracker) generateConnectionID() string {
|
||||
return uuid.Must(uuid.NewV4()).String()
|
||||
}
|
||||
|
||||
func (c *ConnTracker) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn {
|
||||
connID := c.generateConnectionID()
|
||||
connInfo := &ConnectionInfo{
|
||||
ID: connID,
|
||||
Conn: conn,
|
||||
Inbound: metadata.Inbound,
|
||||
Type: "tcp",
|
||||
}
|
||||
|
||||
c.trackConnection(connID, connInfo)
|
||||
|
||||
return c.createWrappedConn(conn, connID)
|
||||
}
|
||||
|
||||
func (c *ConnTracker) RoutedPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) network.PacketConn {
|
||||
connID := c.generateConnectionID()
|
||||
connInfo := &ConnectionInfo{
|
||||
ID: connID,
|
||||
PacketConn: conn,
|
||||
Inbound: metadata.Inbound,
|
||||
Type: "udp",
|
||||
}
|
||||
|
||||
c.trackConnection(connID, connInfo)
|
||||
|
||||
return c.createWrappedPacketConn(conn, connID)
|
||||
}
|
||||
|
||||
func (c *ConnTracker) CloseConnByInbound(inbound string) int {
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
closedCount := 0
|
||||
for connID, connInfo := range c.connections {
|
||||
if connInfo.Inbound == inbound {
|
||||
if connInfo.Conn != nil {
|
||||
connInfo.Conn.Close()
|
||||
}
|
||||
if connInfo.PacketConn != nil {
|
||||
connInfo.PacketConn.Close()
|
||||
}
|
||||
delete(c.connections, connID)
|
||||
closedCount++
|
||||
}
|
||||
}
|
||||
return closedCount
|
||||
}
|
||||
|
||||
func (c *ConnTracker) trackConnection(connID string, connInfo *ConnectionInfo) {
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
c.connections[connID] = connInfo
|
||||
}
|
||||
|
||||
func (c *ConnTracker) untrackConnection(connID string) {
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
delete(c.connections, connID)
|
||||
}
|
||||
|
||||
// shouldUntrackIOErr reports whether err indicates the connection is done (peer closed, reset, etc.).
|
||||
func shouldUntrackIOErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
var ne net.Error
|
||||
if errors.As(err, &ne) {
|
||||
return !ne.Temporary()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *ConnTracker) createWrappedConn(conn net.Conn, connID string) *wrappedConn {
|
||||
return &wrappedConn{
|
||||
Conn: conn,
|
||||
tracker: c,
|
||||
connID: connID,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConnTracker) createWrappedPacketConn(conn network.PacketConn, connID string) *wrappedPacketConn {
|
||||
return &wrappedPacketConn{
|
||||
PacketConn: conn,
|
||||
tracker: c,
|
||||
connID: connID,
|
||||
}
|
||||
}
|
||||
|
||||
type wrappedConn struct {
|
||||
net.Conn
|
||||
tracker *ConnTracker
|
||||
connID string
|
||||
untrackOnce sync.Once
|
||||
}
|
||||
|
||||
func (w *wrappedConn) doUntrack() {
|
||||
w.untrackOnce.Do(func() {
|
||||
w.tracker.untrackConnection(w.connID)
|
||||
})
|
||||
}
|
||||
|
||||
func (w *wrappedConn) Read(b []byte) (int, error) {
|
||||
n, err := w.Conn.Read(b)
|
||||
if shouldUntrackIOErr(err) {
|
||||
w.doUntrack()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *wrappedConn) Write(b []byte) (int, error) {
|
||||
n, err := w.Conn.Write(b)
|
||||
if err != nil && shouldUntrackIOErr(err) {
|
||||
w.doUntrack()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *wrappedConn) Close() error {
|
||||
w.doUntrack()
|
||||
return w.Conn.Close()
|
||||
}
|
||||
|
||||
func (w *wrappedConn) Upstream() any {
|
||||
return w.Conn
|
||||
}
|
||||
|
||||
type wrappedPacketConn struct {
|
||||
network.PacketConn
|
||||
tracker *ConnTracker
|
||||
connID string
|
||||
untrackOnce sync.Once
|
||||
}
|
||||
|
||||
func (w *wrappedPacketConn) doUntrack() {
|
||||
w.untrackOnce.Do(func() {
|
||||
w.tracker.untrackConnection(w.connID)
|
||||
})
|
||||
}
|
||||
|
||||
func (w *wrappedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
dest, err := w.PacketConn.ReadPacket(buffer)
|
||||
if shouldUntrackIOErr(err) {
|
||||
w.doUntrack()
|
||||
}
|
||||
return dest, err
|
||||
}
|
||||
|
||||
func (w *wrappedPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
err := w.PacketConn.WritePacket(buffer, destination)
|
||||
if err != nil && shouldUntrackIOErr(err) {
|
||||
w.doUntrack()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *wrappedPacketConn) Close() error {
|
||||
w.doUntrack()
|
||||
return w.PacketConn.Close()
|
||||
}
|
||||
|
||||
func (w *wrappedPacketConn) Upstream() any {
|
||||
return w.PacketConn
|
||||
}
|
||||
Reference in New Issue
Block a user