@@ -733,7 +733,6 @@ class Connection(object):
733733 _socket = None
734734
735735 _socket_impl = socket
736- _ssl_impl = ssl
737736
738737 _check_hostname = False
739738 _product_type = None
@@ -757,7 +756,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
757756 self .endpoint = host if isinstance (host , EndPoint ) else DefaultEndPoint (host , port )
758757
759758 self .authenticator = authenticator
760- self .ssl_options = ssl_options .copy () if ssl_options else None
759+ self .ssl_options = ssl_options .copy () if ssl_options else {}
761760 self .ssl_context = ssl_context
762761 self .sockopts = sockopts
763762 self .compression = compression
@@ -777,15 +776,20 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
777776 self ._on_orphaned_stream_released = on_orphaned_stream_released
778777
779778 if ssl_options :
780- self ._check_hostname = bool (self .ssl_options .pop ('check_hostname' , False ))
781- if self ._check_hostname :
782- if not getattr (ssl , 'match_hostname' , None ):
783- raise RuntimeError ("ssl_options specify 'check_hostname', but ssl.match_hostname is not provided. "
784- "Patch or upgrade Python to use this option." )
785779 self .ssl_options .update (self .endpoint .ssl_options or {})
786780 elif self .endpoint .ssl_options :
787781 self .ssl_options = self .endpoint .ssl_options
788782
783+ # PYTHON-1331
784+ #
785+ # We always use SSLContext.wrap_socket() now but legacy configs may have other params that were passed to ssl.wrap_socket()...
786+ # and either could have 'check_hostname'. Remove these params into a separate map and use them to build an SSLContext if
787+ # we need to do so.
788+ #
789+ # Note the use of pop() here; we are very deliberately removing these params from ssl_options if they're present. After this
790+ # operation ssl_options should contain only args needed for the ssl_context.wrap_socket() call.
791+ if not self .ssl_context and self .ssl_options :
792+ self .ssl_context = self ._build_ssl_context_from_options ()
789793
790794 if protocol_version >= 3 :
791795 self .max_request_id = min (self .max_in_flight - 1 , (2 ** 15 ) - 1 )
@@ -852,21 +856,57 @@ def factory(cls, endpoint, timeout, *args, **kwargs):
852856 else :
853857 return conn
854858
859+ def _build_ssl_context_from_options (self ):
860+
861+ # Extract a subset of names from self.ssl_options which apply to SSLContext creation
862+ ssl_context_opt_names = ['ssl_version' , 'cert_reqs' , 'check_hostname' , 'keyfile' , 'certfile' , 'ca_certs' , 'ciphers' ]
863+ opts = {k :self .ssl_options .get (k , None ) for k in ssl_context_opt_names if k in self .ssl_options }
864+
865+ # Python >= 3.10 requires either PROTOCOL_TLS_CLIENT or PROTOCOL_TLS_SERVER so we'll get ahead of things by always
866+ # being explicit
867+ ssl_version = opts .get ('ssl_version' , None ) or ssl .PROTOCOL_TLS_CLIENT
868+ cert_reqs = opts .get ('cert_reqs' , None ) or ssl .CERT_REQUIRED
869+ rv = ssl .SSLContext (protocol = int (ssl_version ))
870+ rv .check_hostname = bool (opts .get ('check_hostname' , False ))
871+ rv .options = int (cert_reqs )
872+
873+ certfile = opts .get ('certfile' , None )
874+ keyfile = opts .get ('keyfile' , None )
875+ if certfile :
876+ rv .load_cert_chain (certfile , keyfile )
877+ ca_certs = opts .get ('ca_certs' , None )
878+ if ca_certs :
879+ rv .load_verify_locations (ca_certs )
880+ ciphers = opts .get ('ciphers' , None )
881+ if ciphers :
882+ rv .set_ciphers (ciphers )
883+
884+ return rv
885+
855886 def _wrap_socket_from_context (self ):
856- ssl_options = self .ssl_options or {}
887+
888+ # Extract a subset of names from self.ssl_options which apply to SSLContext.wrap_socket (or at least the parts
889+ # of it that don't involve building an SSLContext under the covers)
890+ wrap_socket_opt_names = ['server_side' , 'do_handshake_on_connect' , 'suppress_ragged_eofs' , 'server_hostname' ]
891+ opts = {k :self .ssl_options .get (k , None ) for k in wrap_socket_opt_names if k in self .ssl_options }
892+
857893 # PYTHON-1186: set the server_hostname only if the SSLContext has
858894 # check_hostname enabled and it is not already provided by the EndPoint ssl options
859- if (self .ssl_context .check_hostname and
860- 'server_hostname' not in ssl_options ):
861- ssl_options = ssl_options .copy ()
862- ssl_options ['server_hostname' ] = self .endpoint .address
863- self ._socket = self .ssl_context .wrap_socket (self ._socket , ** ssl_options )
895+ #opts['server_hostname'] = self.endpoint.address
896+ if (self .ssl_context .check_hostname and 'server_hostname' not in opts ):
897+ server_hostname = self .endpoint .address
898+ opts ['server_hostname' ] = server_hostname
899+
900+ return self .ssl_context .wrap_socket (self ._socket , ** opts )
864901
865902 def _initiate_connection (self , sockaddr ):
866903 self ._socket .connect (sockaddr )
867904
868- def _match_hostname (self ):
869- ssl .match_hostname (self ._socket .getpeercert (), self .endpoint .address )
905+ # PYTHON-1331
906+ #
907+ # Allow implementations specific to an event loop to add additional behaviours
908+ def _validate_hostname (self ):
909+ pass
870910
871911 def _get_socket_addresses (self ):
872912 address , port = self .endpoint .resolve ()
@@ -887,16 +927,18 @@ def _connect_socket(self):
887927 try :
888928 self ._socket = self ._socket_impl .socket (af , socktype , proto )
889929 if self .ssl_context :
890- self ._wrap_socket_from_context ()
891- elif self .ssl_options :
892- if not self ._ssl_impl :
893- raise RuntimeError ("This version of Python was not compiled with SSL support" )
894- self ._socket = self ._ssl_impl .wrap_socket (self ._socket , ** self .ssl_options )
930+ self ._socket = self ._wrap_socket_from_context ()
895931 self ._socket .settimeout (self .connect_timeout )
896932 self ._initiate_connection (sockaddr )
897933 self ._socket .settimeout (None )
934+
935+ # PYTHON-1331
936+ #
937+ # Most checking is done via the check_hostname param on the SSLContext.
938+ # Subclasses can add additional behaviours via _validate_hostname() so
939+ # run that here.
898940 if self ._check_hostname :
899- self ._match_hostname ()
941+ self ._validate_hostname ()
900942 sockerr = None
901943 break
902944 except socket .error as err :
0 commit comments