@@ -10,6 +10,7 @@ import (
1010 "net/http"
1111 "net/http/httptest"
1212 "strings"
13+ "sync"
1314 "testing"
1415
1516 "nhooyr.io/websocket/internal/test/assert"
@@ -142,6 +143,42 @@ func TestAccept(t *testing.T) {
142143 _ , err := Accept (w , r , nil )
143144 assert .Contains (t , err , `failed to hijack connection` )
144145 })
146+ t .Run ("closeRace" , func (t * testing.T ) {
147+ t .Parallel ()
148+
149+ server , _ := net .Pipe ()
150+
151+ rw := bufio .NewReadWriter (bufio .NewReader (server ), bufio .NewWriter (server ))
152+ newResponseWriter := func () http.ResponseWriter {
153+ return mockHijacker {
154+ ResponseWriter : httptest .NewRecorder (),
155+ hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
156+ return server , rw , nil
157+ },
158+ }
159+ }
160+ w := newResponseWriter ()
161+
162+ r := httptest .NewRequest ("GET" , "/" , nil )
163+ r .Header .Set ("Connection" , "Upgrade" )
164+ r .Header .Set ("Upgrade" , "websocket" )
165+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
166+ r .Header .Set ("Sec-WebSocket-Key" , xrand .Base64 (16 ))
167+
168+ c , err := Accept (w , r , nil )
169+ wg := & sync.WaitGroup {}
170+ wg .Add (2 )
171+ go func () {
172+ c .Close (StatusInternalError , "the sky is falling" )
173+ wg .Done ()
174+ }()
175+ go func () {
176+ c .CloseNow ()
177+ wg .Done ()
178+ }()
179+ wg .Wait ()
180+ assert .Success (t , err )
181+ })
145182}
146183
147184func Test_verifyClientHandshake (t * testing.T ) {
0 commit comments