-
Notifications
You must be signed in to change notification settings - Fork 0
/
websock_proxy_rt.go
98 lines (91 loc) · 2.14 KB
/
websock_proxy_rt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package main
import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"errors"
	"io"
	"log"
	"net"
	"net/http"
)
// WSProxyRoundTripper is a stateless http.RoundTripper that forwards each
// request through the websocket proxy service chosen by the request URL's
// host (see its RoundTrip method).
type WSProxyRoundTripper struct{}
// DialTo registers a fresh client sink for clientId and returns one end of an
// in-memory net.Pipe that acts as the HTTP connection to that client.
//
// Before returning it sends {"connected": clientId} to the listener and starts
// two pump goroutines on the other pipe end:
//   - request pump: bytes written to the returned conn are forwarded to the
//     listener, each chunk prefixed with clientId as a 4-byte little-endian
//     header;
//   - response pump: values taken from the sink's buf are written to the
//     returned conn; []byte values pass through unchanged, anything else is
//     JSON-encoded first.
//
// When either pump stops, it closes both pipe ends and the sink and calls
// closeClient(clientId). DialTo returns net.ErrClosed when the service has no
// listener.
//
// NOTE(review): svc.listener is read here and in the pumps without holding
// svc.lock, and sends on svc.listener.buf may block if the listener is not
// draining it. Confirm both are intended.
func (svc *wsProxyService) DialTo(clientId uint32) (net.Conn, error) {
	if svc.listener == nil {
		return nil, net.ErrClosed
	}
	reqConn, conn := net.Pipe()
	svc.lock.Lock()
	client := newWSSink(nil, 0)
	svc.clients[clientId] = client
	svc.lock.Unlock()
	svc.listener.buf <- map[string]interface{}{"connected": clientId}

	// Request pump: HTTP transport -> listener.
	go func(s *wsSink) {
		defer s.Close()
		defer conn.Close()
		defer reqConn.Close()
		defer svc.closeClient(clientId)
		// A single scratch buffer is safe to reuse: every chunk is copied
		// into a freshly allocated msg before it is handed to the listener.
		data := make([]byte, 16*1024)
		for {
			n, err := conn.Read(data)
			if err != nil {
				// io.EOF: the transport closed its end. io.ErrClosedPipe: the
				// response pump already tore the pipe down. Neither is an error
				// worth logging.
				if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) {
					log.Printf("Could not read[%d] from http request pipe: %s", n, err)
				}
				return
			}
			if n == 0 {
				// Per the io.Reader contract, (0, nil) means "nothing happened",
				// not EOF; a zero-length Write on the other pipe end yields it.
				continue
			}
			msg := make([]byte, 4+n)
			binary.LittleEndian.PutUint32(msg, clientId)
			copy(msg[4:], data[:n])
			svc.listener.buf <- msg
		}
	}(client)

	// Response pump: client sink -> HTTP transport.
	go func(s *wsSink) {
		defer s.Close()
		defer conn.Close()
		defer reqConn.Close()
		defer svc.closeClient(clientId)
		for {
			var data interface{}
			select {
			case <-s.done:
			case data = <-s.buf:
			}
			// data stays nil when done fires; a nil value received from buf
			// also ends the pump.
			if data == nil {
				return
			}
			dataOut, ok := data.([]byte)
			if !ok {
				var err error
				dataOut, err = json.Marshal(data)
				if err != nil {
					log.Printf("ERROR: could not marshal %#v: %s", data, err)
					continue
				}
			}
			if n, err := io.Copy(conn, bytes.NewReader(dataOut)); err != nil {
				// A closed pipe just means the request side finished first.
				if !errors.Is(err, io.ErrClosedPipe) {
					log.Printf("could not send %d/%d bytes to http req: %s", n, len(dataOut), err)
				}
				return
			}
		}
	}(client)
	return reqConn, nil
}
// RoundTrip implements http.RoundTripper by sending request through the
// websocket proxy service returned by getWsPrxSvc for request.URL.Host.
//
// The request is cloned before its scheme is rewritten to "http", because the
// RoundTripper contract forbids modifying the caller's request. Without the
// clone, the client would also resolve relative redirects against http://,
// bypassing the proxy.
//
// A throwaway http.Transport is built per call. Its dialer opens a pipe via
// DialTo, keyed by the int request number stored in the dial context under
// requestNumberContext. If that value is missing, the dial fails with an error
// instead of panicking.
//
// Keep-alives are disabled for two reasons:
//   - the pipe belongs to this one request number, so it must never be pooled
//     and reused for another request;
//   - an idle pooled connection on an unreferenced Transport with no
//     IdleConnTimeout would never be closed, leaking the pipe and its pump
//     goroutines.
func (*WSProxyRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
	outReq := request.Clone(request.Context())
	outReq.URL.Scheme = "http"
	wsprxSvc := getWsPrxSvc(outReq.URL.Host)
	transport := &http.Transport{
		DisableKeepAlives: true,
		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
			reqNum, ok := ctx.Value(requestNumberContext).(int)
			if !ok {
				return nil, errors.New("wsprx: request number missing from dial context")
			}
			return wsprxSvc.DialTo(uint32(reqNum))
		},
	}
	return transport.RoundTrip(outReq)
}
func init() {
customHttpSchemas["wsprx"] = func() http.RoundTripper { return &WSProxyRoundTripper{} }
}