gh-143756: Avoid borrowed reference in SSL code (gh-143816)

GET_SOCKET() returned a borrowed reference, which was potentially
unsafe. Also, refactor out some common code.
This commit is contained in:
Sam Gross 2026-01-22 14:02:48 -05:00 committed by GitHub
parent bcf9cb0217
commit ee4e14aa4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -423,26 +423,6 @@ typedef enum {
#define ERRSTR1(x,y,z) (x ":" y ": " z)
#define ERRSTR(x) ERRSTR1("_ssl.c", Py_STRINGIFY(__LINE__), x)
// Get the socket from a PySSLSocket, if it has one.
// Return a borrowed reference.
static inline PySocketSockObject* GET_SOCKET(PySSLSocket *obj) {
if (obj->Socket) {
PyObject *sock;
if (PyWeakref_GetRef(obj->Socket, &sock)) {
// GET_SOCKET() returns a borrowed reference
Py_DECREF(sock);
}
else {
// dead weak reference
sock = Py_None;
}
return (PySocketSockObject *)sock; // borrowed reference
}
else {
return NULL;
}
}
/* If sock is NULL, use a timeout of 0 second */
#define GET_SOCKET_TIMEOUT(sock) \
((sock != NULL) ? (sock)->sock_timeout : 0)
@ -794,6 +774,35 @@ _ssl_deprecated(const char* msg, int stacklevel) {
#define PY_SSL_DEPRECATED(name, stacklevel, ret) \
if (_ssl_deprecated((name), (stacklevel)) == -1) return (ret)
// Get the socket from a PySSLSocket, if it has one.
// Stores a strong reference in out_sock.
static int
get_socket(PySSLSocket *obj, PySocketSockObject **out_sock,
const char *filename, int lineno)
{
if (!obj->Socket) {
*out_sock = NULL;
return 0;
}
PySocketSockObject *sock;
int res = PyWeakref_GetRef(obj->Socket, (PyObject **)&sock);
if (res == 0 || sock->sock_fd == INVALID_SOCKET) {
_setSSLError(get_state_sock(obj),
"Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, filename, lineno);
*out_sock = NULL;
return -1;
}
if (sock != NULL) {
/* just in case the blocking state of the socket has been changed */
int nonblocking = (sock->sock_timeout >= 0);
BIO_set_nbio(SSL_get_rbio(obj->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(obj->ssl), nonblocking);
}
*out_sock = sock;
return res;
}
/*
* SSL objects
*/
@ -1021,24 +1030,13 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self)
int ret;
_PySSLError err;
PyObject *exc = NULL;
int sockstate, nonblocking;
PySocketSockObject *sock = GET_SOCKET(self);
int sockstate;
PyTime_t timeout, deadline = 0;
int has_timeout;
if (sock) {
if (((PyObject*)sock) == Py_None) {
_setSSLError(get_state_sock(self),
"Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
return NULL;
}
Py_INCREF(sock);
/* just in case the blocking state of the socket has been changed */
nonblocking = (sock->sock_timeout >= 0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
PySocketSockObject *sock = NULL;
if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
return NULL;
}
timeout = GET_SOCKET_TIMEOUT(sock);
@ -2610,22 +2608,12 @@ _ssl__SSLSocket_sendfile_impl(PySSLSocket *self, int fd, Py_off_t offset,
int sockstate;
_PySSLError err;
PyObject *exc = NULL;
PySocketSockObject *sock = GET_SOCKET(self);
PyTime_t timeout, deadline = 0;
int has_timeout;
if (sock != NULL) {
if ((PyObject *)sock == Py_None) {
_setSSLError(get_state_sock(self),
"Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
return NULL;
}
Py_INCREF(sock);
/* just in case the blocking state of the socket has been changed */
int nonblocking = (sock->sock_timeout >= 0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
PySocketSockObject *sock = NULL;
if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
return NULL;
}
timeout = GET_SOCKET_TIMEOUT(sock);
@ -2747,26 +2735,12 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b)
int sockstate;
_PySSLError err;
PyObject *exc = NULL;
int nonblocking;
PySocketSockObject *sock = GET_SOCKET(self);
PyTime_t timeout, deadline = 0;
int has_timeout;
if (sock != NULL) {
if (((PyObject*)sock) == Py_None) {
_setSSLError(get_state_sock(self),
"Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
return NULL;
}
Py_INCREF(sock);
}
if (sock != NULL) {
/* just in case the blocking state of the socket has been changed */
nonblocking = (sock->sock_timeout >= 0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
PySocketSockObject *sock = NULL;
if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
return NULL;
}
timeout = GET_SOCKET_TIMEOUT(sock);
@ -2896,8 +2870,6 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
int sockstate;
_PySSLError err;
PyObject *exc = NULL;
int nonblocking;
PySocketSockObject *sock = GET_SOCKET(self);
PyTime_t timeout, deadline = 0;
int has_timeout;
@ -2906,14 +2878,9 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
return NULL;
}
if (sock != NULL) {
if (((PyObject*)sock) == Py_None) {
_setSSLError(get_state_sock(self),
"Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
return NULL;
}
Py_INCREF(sock);
PySocketSockObject *sock = NULL;
if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
return NULL;
}
if (!group_right_1) {
@ -2944,13 +2911,6 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len,
}
}
if (sock != NULL) {
/* just in case the blocking state of the socket has been changed */
nonblocking = (sock->sock_timeout >= 0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
}
timeout = GET_SOCKET_TIMEOUT(sock);
has_timeout = (timeout > 0);
if (has_timeout)
@ -3041,26 +3001,14 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self)
{
_PySSLError err;
PyObject *exc = NULL;
int sockstate, nonblocking, ret;
int sockstate, ret;
int zeros = 0;
PySocketSockObject *sock = GET_SOCKET(self);
PyTime_t timeout, deadline = 0;
int has_timeout;
if (sock != NULL) {
/* Guard against closed socket */
if ((((PyObject*)sock) == Py_None) || (sock->sock_fd == INVALID_SOCKET)) {
_setSSLError(get_state_sock(self),
"Underlying socket connection gone",
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
return NULL;
}
Py_INCREF(sock);
/* Just in case the blocking state of the socket has been changed */
nonblocking = (sock->sock_timeout >= 0);
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
PySocketSockObject *sock = NULL;
if (get_socket(self, &sock, __FILE__, __LINE__) < 0) {
return NULL;
}
timeout = GET_SOCKET_TIMEOUT(sock);