@@ -10,6 +10,7 @@ import (
1010 "reflect"
1111 "runtime"
1212 "sync"
13+ "sync/atomic"
1314 "syscall/js"
1415
1516 "nhooyr.io/websocket/internal/wsjs"
@@ -19,9 +20,10 @@ import (
1920type Conn struct {
2021 ws wsjs.WebSocket
2122
22- closeOnce sync.Once
23- closed chan struct {}
24- closeErr error
23+ readClosed int64
24+ closeOnce sync.Once
25+ closed chan struct {}
26+ closeErr error
2527
2628 releaseOnClose func ()
2729 releaseOnMessage func ()
@@ -67,6 +69,10 @@ func (c *Conn) init() {
6769// Read attempts to read a message from the connection.
6870// The maximum time spent waiting is bounded by the context.
6971func (c * Conn ) Read (ctx context.Context ) (MessageType , []byte , error ) {
72+ if atomic .LoadInt64 (& c .readClosed ) == 1 {
73+ return 0 , nil , fmt .Errorf ("websocket connection read closed" )
74+ }
75+
7076 typ , p , err := c .read (ctx )
7177 if err != nil {
7278 return 0 , nil , fmt .Errorf ("failed to read: %w" , err )
@@ -78,6 +84,7 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
7884 var me wsjs.MessageEvent
7985 select {
8086 case <- ctx .Done ():
87+ c .Close (StatusPolicyViolation , "read timed out" )
8188 return 0 , nil , ctx .Err ()
8289 case me = <- c .readch :
8390 case <- c .closed :
@@ -198,6 +205,7 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp
198205
199206 select {
200207 case <- ctx .Done ():
208+ c .Close (StatusPolicyViolation , "dial timed out" )
201209 return nil , nil , ctx .Err ()
202210 case <- opench :
203211 case <- c .closed :
@@ -215,3 +223,8 @@ func (c *netConn) netConnReader(ctx context.Context) (MessageType, io.Reader, er
215223 }
216224 return typ , bytes .NewReader (p ), nil
217225}
226+
227+ // Only implemented for use by *Conn.CloseRead in netconn.go
228+ func (c * Conn ) reader (ctx context.Context ) {
229+ c .read (ctx )
230+ }
0 commit comments