diff --git a/src/internal/poll/fd_windows.go b/src/internal/poll/fd_windows.go index edad6563508bc6..7953ad6aac9f90 100644 --- a/src/internal/poll/fd_windows.go +++ b/src/internal/poll/fd_windows.go @@ -149,7 +149,7 @@ var wsaMsgPool = sync.Pool{ // newWSAMsg creates a new WSAMsg with the provided parameters. // Use [freeWSAMsg] to free it. -func newWSAMsg(p []byte, oob []byte, flags int, unconnected bool) *windows.WSAMsg { +func newWSAMsg(p []byte, oob []byte, flags int, rsa *wsaRsa) *windows.WSAMsg { // The returned object can't be allocated in the stack because it is accessed asynchronously // by Windows in between several system calls. If the stack frame is moved while that happens, // then Windows may access invalid memory. @@ -164,34 +164,46 @@ func newWSAMsg(p []byte, oob []byte, flags int, unconnected bool) *windows.WSAMs Buf: unsafe.SliceData(oob), } msg.Flags = uint32(flags) - if unconnected { - msg.Name = wsaRsaPool.Get().(*syscall.RawSockaddrAny) - msg.Namelen = int32(unsafe.Sizeof(syscall.RawSockaddrAny{})) + if rsa != nil { + msg.Name = &rsa.name + msg.Namelen = rsa.namelen } return msg } func freeWSAMsg(msg *windows.WSAMsg) { // Clear pointers to buffers so they can be released by garbage collector. + msg.Name = nil + msg.Namelen = 0 msg.Buffers.Len = 0 msg.Buffers.Buf = nil msg.Control.Len = 0 msg.Control.Buf = nil - if msg.Name != nil { - *msg.Name = syscall.RawSockaddrAny{} - wsaRsaPool.Put(msg.Name) - msg.Name = nil - msg.Namelen = 0 - } wsaMsgPool.Put(msg) } +// wsaRsa bundles a [syscall.RawSockaddrAny] with its length for efficient caching. +// +// When used by WSARecvFrom, wsaRsa must be on the heap. See +// https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsarecvfrom. +type wsaRsa struct { + name syscall.RawSockaddrAny + namelen int32 +} + var wsaRsaPool = sync.Pool{ New: func() any { - return new(syscall.RawSockaddrAny) + return new(wsaRsa) }, } +func newWSARsa() *wsaRsa { + rsa := wsaRsaPool.Get().(*wsaRsa) + rsa.name = syscall.RawSockaddrAny{} + rsa.namelen = int32(unsafe.Sizeof(syscall.RawSockaddrAny{})) + return rsa +} + var operationPool = sync.Pool{ New: func() any { return new(operation) @@ -737,19 +749,18 @@ func (fd *FD) ReadFrom(buf []byte) (int, syscall.Sockaddr, error) { fd.pin('r', &buf[0]) - rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny) + rsa := newWSARsa() defer wsaRsaPool.Put(rsa) n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) { - rsan := int32(unsafe.Sizeof(*rsa)) var flags uint32 - err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, rsa, &rsan, &o.o, nil) + err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, &rsa.name, &rsa.namelen, &o.o, nil) return qty, err }) err = fd.eofError(n, err) if err != nil { return n, nil, err } - sa, _ := rsa.Sockaddr() + sa, _ := rsa.name.Sockaddr() return n, sa, nil } @@ -768,19 +779,18 @@ func (fd *FD) ReadFromInet4(buf []byte, sa4 *syscall.SockaddrInet4) (int, error) fd.pin('r', &buf[0]) - rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny) + rsa := newWSARsa() defer wsaRsaPool.Put(rsa) n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) { - rsan := int32(unsafe.Sizeof(*rsa)) var flags uint32 - err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, rsa, &rsan, &o.o, nil) + err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, &rsa.name, &rsa.namelen, &o.o, nil) return qty, err }) err = fd.eofError(n, err) if err != nil { return n, err } - rawToSockaddrInet4(rsa, sa4) + rawToSockaddrInet4(&rsa.name, sa4) return n, err } @@ -799,19 +809,18 @@ func (fd *FD) ReadFromInet6(buf []byte, sa6 *syscall.SockaddrInet6) (int, error) fd.pin('r', &buf[0]) - rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny) + rsa := newWSARsa() defer wsaRsaPool.Put(rsa) n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) { - rsan := int32(unsafe.Sizeof(*rsa)) var flags uint32 - err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, rsa, &rsan, &o.o, nil) + err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, &rsa.name, &rsa.namelen, &o.o, nil) return qty, err }) err = fd.eofError(n, err) if err != nil { return n, err } - rawToSockaddrInet6(rsa, sa6) + rawToSockaddrInet6(&rsa.name, sa6) return n, err } @@ -1371,7 +1380,9 @@ func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.S p = p[:maxRW] } - msg := newWSAMsg(p, oob, flags, true) + rsa := newWSARsa() + defer wsaRsaPool.Put(rsa) + msg := newWSAMsg(p, oob, flags, rsa) defer freeWSAMsg(msg) n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) { err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil) @@ -1396,7 +1407,9 @@ func (fd *FD) ReadMsgInet4(p []byte, oob []byte, flags int, sa4 *syscall.Sockadd p = p[:maxRW] } - msg := newWSAMsg(p, oob, flags, true) + rsa := newWSARsa() + defer wsaRsaPool.Put(rsa) + msg := newWSAMsg(p, oob, flags, rsa) defer freeWSAMsg(msg) n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) { err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil) @@ -1420,7 +1433,9 @@ func (fd *FD) ReadMsgInet6(p []byte, oob []byte, flags int, sa6 *syscall.Sockadd p = p[:maxRW] } - msg := newWSAMsg(p, oob, flags, true) + rsa := newWSARsa() + defer wsaRsaPool.Put(rsa) + msg := newWSAMsg(p, oob, flags, rsa) defer freeWSAMsg(msg) n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) { err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil) @@ -1444,15 +1459,18 @@ func (fd *FD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (int, int, err } defer fd.writeUnlock() - msg := newWSAMsg(p, oob, 0, sa != nil) - defer freeWSAMsg(msg) + var rsa *wsaRsa if sa != nil { + rsa = newWSARsa() + defer wsaRsaPool.Put(rsa) var err error - msg.Namelen, err = sockaddrToRaw(msg.Name, sa) + rsa.namelen, err = sockaddrToRaw(&rsa.name, sa) if err != nil { return 0, 0, err } } + msg := newWSAMsg(p, oob, 0, rsa) + defer freeWSAMsg(msg) n, err := fd.execIO('w', func(o *operation) (qty uint32, err error) { err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil) return qty, err @@ -1471,11 +1489,14 @@ func (fd *FD) WriteMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (in } defer fd.writeUnlock() - msg := newWSAMsg(p, oob, 0, sa != nil) - defer freeWSAMsg(msg) + var rsa *wsaRsa if sa != nil { - msg.Namelen = sockaddrInet4ToRaw(msg.Name, sa) + rsa = newWSARsa() + defer wsaRsaPool.Put(rsa) + rsa.namelen = sockaddrInet4ToRaw(&rsa.name, sa) } + msg := newWSAMsg(p, oob, 0, rsa) + defer freeWSAMsg(msg) n, err := fd.execIO('w', func(o *operation) (qty uint32, err error) { err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil) return qty, err @@ -1494,11 +1515,14 @@ func (fd *FD) WriteMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (in } defer fd.writeUnlock() - msg := newWSAMsg(p, oob, 0, sa != nil) - defer freeWSAMsg(msg) + var rsa *wsaRsa if sa != nil { - msg.Namelen = sockaddrInet6ToRaw(msg.Name, sa) + rsa = newWSARsa() + defer wsaRsaPool.Put(rsa) + rsa.namelen = sockaddrInet6ToRaw(&rsa.name, sa) } + msg := newWSAMsg(p, oob, 0, rsa) + defer freeWSAMsg(msg) n, err := fd.execIO('w', func(o *operation) (qty uint32, err error) { err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil) return qty, err