I have been exploring different ways we could proxy WebSocket data. I came across this implementation.
Option 1:
// serveWebsocket tunnels a WebSocket upgrade request to the host named in
// req.URL.Host. The client's handshake is relayed to the target by
// websocketHandshake, and afterwards raw bytes are copied in both directions
// by proxyWebsocket until one direction finishes.
//
// Failures that happen before the client connection is hijacked are reported
// to the client as HTTP errors. After the hijack the ResponseWriter can no
// longer be used, so failures are only logged. Both the target and the client
// connections are closed when this function returns.
func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, w http.ResponseWriter, req *http.Request) {
	// Only the host (and port, if present) is used for dialing; the path is
	// carried by the relayed request itself.
	targetURL := url.URL{Scheme: "ws", Host: req.URL.Host, Path: req.URL.Path}

	// Check hijack support before dialing, and report it instead of
	// panicking: a server that cannot hijack should not crash the proxy.
	hj, ok := w.(http.Hijacker)
	if !ok {
		ctx.Warnf("httpserver does not support hijacking")
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	targetConn, err := proxy.connectDial(ctx, "tcp", targetURL.Host)
	if err != nil {
		ctx.Warnf("Error dialing target site: %v", err)
		// The client has not been hijacked yet, so it can still get a proper response.
		http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
		return
	}
	defer targetConn.Close()

	// Take over the client connection. The returned bufio.ReadWriter is
	// ignored. NOTE(review): this relies on the client not sending data
	// before it receives the handshake response (RFC 6455 section 4.1), so
	// nothing should be sitting in that buffer.
	clientConn, _, err := hj.Hijack()
	if err != nil {
		ctx.Warnf("Hijack error: %v", err)
		return
	}
	// After Hijack the HTTP server no longer manages this connection, so it
	// must be closed here or it leaks.
	defer clientConn.Close()

	// Relay the upgrade handshake between client and target.
	if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil {
		ctx.Warnf("Websocket handshake error: %v", err)
		return
	}

	// Copy bytes both ways; the deferred Close calls above unblock whichever
	// direction is still running once this returns.
	proxy.proxyWebsocket(ctx, targetConn, clientConn)
}
// websocketHandshake relays the client's upgrade request to the target and
// the target's (filtered) response back to the client.
//
// req is written to targetSiteConn unchanged. The response is read from
// targetSiteConn, passed through proxy.filterResponse, and written to
// clientConn. Any bytes the target sent immediately after the response are
// then forwarded to clientConn too; see the comment near the end.
//
// On failure the error is logged via ctx and returned. The response status
// code is not checked here, so a non-101 reply still returns nil after being
// forwarded to the client.
func (proxy *ProxyHttpServer) websocketHandshake(ctx *ProxyCtx, req *http.Request, targetSiteConn io.ReadWriter, clientConn io.ReadWriter) error {
	// Forward the client's upgrade request to the target.
	if err := req.Write(targetSiteConn); err != nil {
		ctx.Warnf("Error writing upgrade request: %v", err)
		return err
	}

	// Read the target's handshake response. The bufio.Reader may pull in
	// more bytes than the response itself.
	targetReader := bufio.NewReader(targetSiteConn)
	resp, err := http.ReadResponse(targetReader, req)
	if err != nil {
		ctx.Warnf("Error reading handshake response %v", err)
		return err
	}

	// Let the proxy's response handlers inspect or modify the response.
	resp = proxy.filterResponse(resp, ctx)

	// Send the handshake response back to the client.
	if err := resp.Write(clientConn); err != nil {
		ctx.Warnf("Error writing handshake response: %v", err)
		return err
	}

	// The caller copies raw bytes from targetSiteConn after this returns, so
	// anything still sitting in targetReader's buffer would be lost. Such
	// bytes are WebSocket data the server sent right after its response.
	// Forward them now, keeping the target->client byte order intact.
	if n := targetReader.Buffered(); n > 0 {
		early, _ := targetReader.Peek(n) // cannot fail: n bytes are already buffered
		if _, err := clientConn.Write(early); err != nil {
			ctx.Warnf("Error writing buffered websocket data: %v", err)
			return err
		}
	}
	return nil
}
// proxyWebsocket copies bytes from source to dest and from dest to source
// concurrently, and returns as soon as either direction finishes (EOF or
// error).
//
// It does not close either connection; the caller is responsible for that.
// Closing the connections is also what unblocks the direction that is still
// running. errChan has room for both results, so that goroutine can deliver
// its result and exit even after this function has returned.
func (proxy *ProxyHttpServer) proxyWebsocket(ctx *ProxyCtx, dest io.ReadWriter, source io.ReadWriter) {
	errChan := make(chan error, 2)
	cp := func(dst io.Writer, src io.Reader) {
		_, err := io.Copy(dst, src)
		// io.Copy reports a clean EOF as nil; only log real failures.
		if err != nil {
			ctx.Warnf("Websocket error: %v", err)
		}
		errChan <- err
	}
	// Start proxying websocket data in both directions.
	go cp(dest, source)
	go cp(source, dest)
	// Return when the first direction finishes.
	<-errChan
}
This is a little bit different from the implementation that I came up with:
Option 2: My implementation
...
import "nhooyr.io/websocket"
...
func handleWebSocketConnection(w http.ResponseWriter, r *http.Request) {
connA, err := websocket.Accept(w, r, nil)
if err != nil {
log.Printf("Failed to accept WebSocket connection: %v", err)
return
}
defer connA.Close(200, "Closing connectionA")
// Connect to dest server
caCert, _ := os.ReadFile(caCertPath)
caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(caCert) {
log.Fatalf("Failed to append CA certificate")
}
tlsConfig := &tls.Config{
RootCAs: caPool,
}
httpClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
}
connB, _, err := websocket.Dial(r.Context(), AggWsAddr, &websocket.DialOptions{
HTTPClient: httpClient,
})
if err != nil {
log.Printf("Failed to connect: %v", err)
return
}
defer connB.Close(200, "Closing connection")
// Start goroutines to forward messages between client and agg
go forwardMessages(connA, connB, "Conn A -> Conn B")
go forwardMessages(connB, connA, "Conn B -> Conn A")
// Keep the connection open indefinitely
select {}
}
I was just curious to understand why one would choose one implementation over the other.
A few differences I see are:
- The first implementation is library-agnostic, but my implementation (the second one) specifically uses https://pkg.go.dev/nhooyr.io/websocket. Also, in the second implementation the library performs the WebSocket handshake with each peer. In the first implementation, the proxy does not take part in the handshake: it relays the client's handshake request to the server and the server's response back to the client.
- More importantly, my implementation makes two separate WebSocket connections: client <-> proxy and proxy <-> server. So I need to copy any headers from the client's request that are meant for the destination server (e.g. authentication headers or cookies) onto the proxy's own handshake with the server. In the first implementation, on the other hand, there is a single WebSocket handshake between the client and the server. The proxy forwards the handshake request and response as-is and then just copies the raw bytes in both directions. So I guess we don't have to worry about copying headers to the destination server.
Thanks!