22package websocketproxy
33
44import (
5+ "context"
56 "fmt"
67 "io"
78 "log"
4748 // If nil, DefaultDialer is used.
4849 Dialer * websocket.Dialer
4950
50- // Done specifies a channel for which all proxied websocket connections
51+ // done specifies a channel for which all proxied websocket connections
5152 // can be closed on demand by closing the channel.
52- Done chan struct {}
53+ done chan struct {}
5354 }
5455
5556 websocketMsg struct {
@@ -186,6 +187,9 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
186187
187188 errClient := make (chan error , 1 )
188189 errBackend := make (chan error , 1 )
190+ if w .done == nil {
191+ w .done = make (chan struct {})
192+ }
189193
190194 replicateWebsocketConn := func (dst , src * websocket.Conn , errc chan error ) {
191195 websocketMsgRcverC := make (chan websocketMsg , 1 )
@@ -214,9 +218,7 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
214218 errc <- err
215219 break
216220 }
217- case <- w .Done :
218- m := websocket .FormatCloseMessage (websocket .CloseGoingAway , "websocketproxy: closing connection" )
219- dst .WriteMessage (websocket .CloseMessage , m )
221+ case <- w .done :
220222 break
221223 }
222224 }
@@ -234,8 +236,21 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
234236 if e , ok := err .(* websocket.CloseError ); ! ok || e .Code == websocket .CloseAbnormalClosure {
235237 log .Printf ("websocketproxy: Error when copying from client to backend: %v" , err )
236238 }
237- case <- w .Done :
239+ case <- w .done :
240+ m := websocket .FormatCloseMessage (websocket .CloseGoingAway , "websocketproxy: closing connection" )
241+ connPub .WriteMessage (websocket .CloseMessage , m )
242+ connBackend .WriteMessage (websocket .CloseMessage , m )
243+ }
244+ }
245+
246+ // Shutdown closes ws connections by closing the done channel they are subscribed to.
247+ func (w * WebsocketProxy ) Shutdown (ctx context.Context ) error {
248+ // TODO: support using context for control and return error when applicable
249+ // Currently implemented such that the method signature matches http.Server.Shutdown()
250+ if w .done != nil {
251+ close (w .done )
238252 }
253+ return nil
239254}
240255
241256func copyHeader (dst , src http.Header ) {
0 commit comments