@@ -14,7 +14,7 @@ import (
14
14
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
15
15
)
16
16
17
- func (h * tunHandler ) HandleClient (ctx context.Context , tun net.Conn , remoteAddr * net. UDPAddr ) {
17
+ func (h * tunHandler ) HandleClient (ctx context.Context , tun net.Conn ) {
18
18
device := & ClientDevice {
19
19
tun : tun ,
20
20
tunInbound : make (chan * Packet , MaxSize ),
@@ -23,7 +23,7 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn, remoteAddr
23
23
}
24
24
25
25
defer device .Close ()
26
- go device .handlePacket (ctx , remoteAddr , h .forward )
26
+ go device .handlePacket (ctx , h .forward )
27
27
go device .readFromTun (ctx )
28
28
go device .writeToTun (ctx )
29
29
go heartbeats (ctx , device .tun )
@@ -43,56 +43,40 @@ type ClientDevice struct {
43
43
forward * Forwarder
44
44
}
45
45
46
- func (d * ClientDevice ) handlePacket (ctx context.Context , remoteAddr * net. UDPAddr , forward * Forwarder ) {
46
+ func (d * ClientDevice ) handlePacket (ctx context.Context , forward * Forwarder ) {
47
47
for ctx .Err () == nil {
48
- packetConn , err := getRemotePacketConn (ctx , forward )
48
+ conn , err := forwardConn (ctx , forward )
49
49
if err != nil {
50
- plog .G (ctx ).Errorf ("Failed to get remote conn from %s -> %s: %s" , d .tun .LocalAddr (), remoteAddr , err )
50
+ plog .G (ctx ).Errorf ("Failed to get remote conn from %s -> %s: %s" , d .tun .LocalAddr (), forward . node . Remote , err )
51
51
time .Sleep (time .Second * 1 )
52
52
continue
53
53
}
54
- err = handlePacketClient (ctx , d .tunInbound , d .tunOutbound , packetConn , remoteAddr )
54
+ err = handlePacketClient (ctx , d .tunInbound , d .tunOutbound , conn )
55
55
if err != nil {
56
- plog .G (ctx ).Errorf ("Failed to transport data to remote %s: %v" , remoteAddr , err )
56
+ plog .G (ctx ).Errorf ("Failed to transport data to remote %s: %v" , conn . RemoteAddr () , err )
57
57
}
58
58
}
59
59
}
60
60
61
- func getRemotePacketConn (ctx context.Context , forwarder * Forwarder ) (net.PacketConn , error ) {
61
+ func forwardConn (ctx context.Context , forwarder * Forwarder ) (net.Conn , error ) {
62
62
conn , err := forwarder .DialContext (ctx )
63
63
if err != nil {
64
64
return nil , errors .Wrap (err , "failed to dial forwarder" )
65
65
}
66
-
67
- if packetConn , ok := conn .(net.PacketConn ); ! ok {
68
- return nil , errors .Errorf ("failed to cast packet conn to PacketConn" )
69
- } else {
70
- return packetConn , nil
71
- }
66
+ return conn , nil
72
67
}
73
68
74
- func handlePacketClient (ctx context.Context , tunInbound <- chan * Packet , tunOutbound chan <- * Packet , packetConn net.PacketConn , remoteAddr net. Addr ) error {
69
+ func handlePacketClient (ctx context.Context , tunInbound <- chan * Packet , tunOutbound chan <- * Packet , conn net.Conn ) error {
75
70
errChan := make (chan error , 2 )
76
- defer packetConn .Close ()
71
+ defer conn .Close ()
77
72
78
73
go func () {
79
74
defer util .HandleCrash ()
80
75
for packet := range tunInbound {
81
- if packet .src .Equal (packet .dst ) {
82
- util .SafeWrite (tunOutbound , packet , func (v * Packet ) {
83
- var p = "unknown"
84
- if _ , _ , protocol , err := util .ParseIP (v .data [:v .length ]); err == nil {
85
- p = layers .IPProtocol (protocol ).String ()
86
- }
87
- config .LPool .Put (v .data [:])
88
- plog .G (context .Background ()).Errorf ("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d" , v .src , v .dst , p , v .length )
89
- })
90
- continue
91
- }
92
- _ , err := packetConn .WriteTo (packet .data [:packet .length ], remoteAddr )
76
+ _ , err := conn .Write (packet .data [:packet .length ])
93
77
config .LPool .Put (packet .data [:])
94
78
if err != nil {
95
- util .SafeWrite (errChan , errors .Wrap (err , fmt . Sprintf ( "failed to write packet to remote %s" , remoteAddr ) ))
79
+ util .SafeWrite (errChan , errors .Wrap (err , "failed to write packet to remote" ))
96
80
return
97
81
}
98
82
}
@@ -102,10 +86,10 @@ func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbo
102
86
defer util .HandleCrash ()
103
87
for {
104
88
buf := config .LPool .Get ().([]byte )[:]
105
- n , _ , err := packetConn . ReadFrom (buf [:])
89
+ n , err := conn . Read (buf [:])
106
90
if err != nil {
107
91
config .LPool .Put (buf [:])
108
- util .SafeWrite (errChan , errors .Wrap (err , fmt .Sprintf ("failed to read packet from remote %s" , remoteAddr )))
92
+ util .SafeWrite (errChan , errors .Wrap (err , fmt .Sprintf ("failed to read packet from remote %s" , conn . RemoteAddr () )))
109
93
return
110
94
}
111
95
if n == 0 {
@@ -115,7 +99,7 @@ func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbo
115
99
}
116
100
util .SafeWrite (tunOutbound , NewPacket (buf [:], n , nil , nil ), func (v * Packet ) {
117
101
config .LPool .Put (v .data [:])
118
- plog .G (context .Background ()).Errorf ("Drop packet, LocalAddr: %s, Remote: %s, Length: %d" , packetConn .LocalAddr (), remoteAddr , v .length )
102
+ plog .G (context .Background ()).Errorf ("Drop packet, LocalAddr: %s, Remote: %s, Length: %d" , conn .LocalAddr (), conn . RemoteAddr () , v .length )
119
103
})
120
104
}
121
105
}()
@@ -150,10 +134,16 @@ func (d *ClientDevice) readFromTun(ctx context.Context) {
150
134
continue
151
135
}
152
136
plog .G (context .Background ()).Debugf ("SRC: %s, DST: %s, Protocol: %s, Length: %d" , src , dst , layers .IPProtocol (protocol ).String (), n )
153
- util .SafeWrite (d .tunInbound , NewPacket (buf [:], n , src , dst ), func (v * Packet ) {
137
+ packet := NewPacket (buf [:], n , src , dst )
138
+ f := func (v * Packet ) {
154
139
config .LPool .Put (v .data [:])
155
140
plog .G (context .Background ()).Errorf ("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d" , v .src , v .dst , layers .IPProtocol (protocol ).String (), v .length )
156
- })
141
+ }
142
+ if packet .src .Equal (packet .dst ) {
143
+ util .SafeWrite (d .tunOutbound , packet , f )
144
+ continue
145
+ }
146
+ util .SafeWrite (d .tunInbound , packet , f )
157
147
}
158
148
}
159
149
@@ -188,7 +178,7 @@ func heartbeats(ctx context.Context, tun net.Conn) {
188
178
return
189
179
}
190
180
191
- ticker := time .NewTicker (time . Second * 60 )
181
+ ticker := time .NewTicker (config . KeepAliveTime )
192
182
defer ticker .Stop ()
193
183
194
184
for ; ctx .Err () == nil ; <- ticker .C {
0 commit comments