6060from saml2 .samlp import SessionIndex
6161from saml2 .samlp import artifact_resolve_from_string
6262from saml2 .samlp import response_from_string
63- from saml2 .sigver import SignatureError
63+ from saml2 .sigver import SignatureError , XMLSEC_SESSION_KEY_URI_TO_ALG , RSA_OAEP
6464from saml2 .sigver import SigverError
6565from saml2 .sigver import get_pem_wrapped_unwrapped
6666from saml2 .sigver import make_temp
7878from saml2 .xmldsig import SIG_ALLOWED_ALG
7979from saml2 .xmldsig import DefaultSignature
8080
81-
8281logger = logging .getLogger (__name__ )
8382
8483__author__ = "rolandh"
@@ -181,6 +180,10 @@ def __init__(self, entity_type, config=None, config_file="", virtual_organizatio
181180
182181 self .sec = security_context (self .config )
183182
183+ self .encrypt_assertion_session_key_algs = self .config .encrypt_assertion_session_key_algs
184+ self .encrypt_assertion_cert_key_algs = self .config .encrypt_assertion_cert_key_algs
185+ self .default_rsa_oaep_mgf_alg = self .config .default_rsa_oaep_mgf_alg
186+
184187 if virtual_organization :
185188 if isinstance (virtual_organization , str ):
186189 self .vorg = self .config .vorg [virtual_organization ]
@@ -194,7 +197,6 @@ def __init__(self, entity_type, config=None, config_file="", virtual_organizatio
194197 self .sourceid = self .metadata .construct_source_id ()
195198 else :
196199 self .sourceid = {}
197-
198200 self .msg_cb = msg_cb
199201
200202 def reload_metadata (self , metadata_conf ):
@@ -644,34 +646,105 @@ def has_encrypt_cert_in_metadata(self, sp_entity_id):
644646 return True
645647 return False
646648
647- def _encrypt_assertion (self , encrypt_cert , sp_entity_id , response , node_xpath = None ):
649+ def _get_first_matching_alg (self , priority_list , metadata_list ):
650+ for alg in priority_list :
651+ for cert_method in metadata_list :
652+ if cert_method .get ("algorithm" ) == alg :
653+ return cert_method
654+ return None
655+
656+ def _encrypt_assertion (
657+ self ,
658+ encrypt_cert ,
659+ sp_entity_id ,
660+ response ,
661+ node_xpath = None ,
662+ encrypt_cert_session_key_alg = None ,
663+ encrypt_cert_cert_key_alg = None ,
664+ ):
648665 """Encryption of assertions.
649666
650667 :param encrypt_cert: Certificate to be used for encryption.
651668 :param sp_entity_id: Entity ID for the calling service provider.
652669 :param response: A samlp.Response
670+ :param encrypt_cert_cert_key_alg: algorithm used for encrypting session key
671+ :param encrypt_cert_session_key_alg: algorithm used for encrypting assertion
672+ :param encrypt_cert_cert_key_alg:
653673 :param node_xpath: Unquie path to the element to be encrypted.
654674 :return: A new samlp.Resonse with the designated assertion encrypted.
655675 """
656676 _certs = []
657677
658678 if encrypt_cert :
659- _certs .append ((None , encrypt_cert ))
679+ _certs .append ((None , encrypt_cert , None , None ))
660680 elif sp_entity_id is not None :
661- _certs = self .metadata .certs (sp_entity_id , "any" , "encryption" )
681+ _certs = self .metadata .certs (sp_entity_id , "any" , "encryption" , get_with_usage_and_encryption_methods = True )
662682 exception = None
663- for _cert_name , _cert in _certs :
683+
684+ # take certs with encryption and encryption_methods first (priority 1)
685+ sorted_certs = []
686+ for _unpacked_cert in _certs :
687+ _cert_name , _cert , _cert_use , _cert_encryption_methods = _unpacked_cert
688+ if _cert_use == "encryption" and _cert_encryption_methods :
689+ sorted_certs .append (_unpacked_cert )
690+
691+ # take certs with encryption or encryption_methods (priority 2)
692+ for _unpacked_cert in _certs :
693+ _cert_name , _cert , _cert_use , _cert_encryption_methods = _unpacked_cert
694+ if _cert_use == "encryption" and _unpacked_cert not in sorted_certs :
695+ sorted_certs .append (_unpacked_cert )
696+
697+ for _unpacked_cert in _certs :
698+ if _unpacked_cert not in sorted_certs :
699+ sorted_certs .append (_unpacked_cert )
700+
701+ for _cert_name , _cert , _cert_use , _cert_encryption_methods in sorted_certs :
664702 wrapped_cert , unwrapped_cert = get_pem_wrapped_unwrapped (_cert )
665703 try :
666704 tmp = make_temp (
667705 wrapped_cert .encode ("ascii" ),
668706 decode = False ,
669707 delete_tmpfiles = self .config .delete_tmpfiles ,
670708 )
709+
710+ msg_enc = (
711+ encrypt_cert_session_key_alg
712+ if encrypt_cert_session_key_alg
713+ else self .encrypt_assertion_session_key_algs [0 ]
714+ )
715+ key_enc = (
716+ encrypt_cert_cert_key_alg if encrypt_cert_cert_key_alg else self .encrypt_assertion_cert_key_algs [0 ]
717+ )
718+
719+ rsa_oaep_mgf_alg = self .default_rsa_oaep_mgf_alg if key_enc == RSA_OAEP else None
720+ if encrypt_cert != _cert and _cert_encryption_methods :
721+ viable_session_key_alg = self ._get_first_matching_alg (
722+ self .encrypt_assertion_session_key_algs , _cert_encryption_methods
723+ )
724+ if viable_session_key_alg :
725+ msg_enc = viable_session_key_alg .get ("algorithm" )
726+
727+ viable_cert_alg = self ._get_first_matching_alg (
728+ self .encrypt_assertion_cert_key_algs , _cert_encryption_methods
729+ )
730+ if viable_cert_alg :
731+ key_enc = viable_cert_alg .get ("algorithm" )
732+ mgf = viable_cert_alg .get ("mgf" )
733+ rsa_oaep_mgf_alg = mgf .get ("algorithm" ) if mgf else None
734+
735+ key_type = XMLSEC_SESSION_KEY_URI_TO_ALG .get (msg_enc )
736+
671737 response = self .sec .encrypt_assertion (
672738 response ,
673739 tmp .name ,
674- pre_encryption_part (key_name = _cert_name , encrypt_cert = unwrapped_cert ),
740+ pre_encryption_part (
741+ key_name = _cert_name ,
742+ encrypt_cert = unwrapped_cert ,
743+ msg_enc = msg_enc ,
744+ key_enc = key_enc ,
745+ rsa_oaep_mgf_alg = rsa_oaep_mgf_alg ,
746+ ),
747+ key_type = key_type ,
675748 node_xpath = node_xpath ,
676749 )
677750 return response
@@ -697,7 +770,11 @@ def _response(
697770 encrypt_assertion_self_contained = False ,
698771 encrypted_advice_attributes = False ,
699772 encrypt_cert_advice = None ,
773+ encrypt_cert_advice_cert_key_alg = None ,
774+ encrypt_cert_advice_session_key_alg = None ,
700775 encrypt_cert_assertion = None ,
776+ encrypt_cert_assertion_cert_key_alg = None ,
777+ encrypt_cert_assertion_session_key_alg = None ,
701778 sign_assertion = None ,
702779 pefim = False ,
703780 sign_alg = None ,
@@ -731,8 +808,16 @@ def _response(
731808 element should be encrypted.
732809 :param encrypt_cert_advice: Certificate to be used for encryption of
733810 assertions in the advice element.
811+ :param encrypt_cert_advice_cert_key_alg: algorithm used for encrypting session key
812+ by encrypt_cert_advice
813+ :param encrypt_cert_advice_session_key_alg: algorithm used for encrypting assertion
814+ when using encrypt_cert_advice
734815 :param encrypt_cert_assertion: Certificate to be used for encryption
735816 of assertions.
817+ :param encrypt_cert_assertion_cert_key_alg: algorithm used for encrypting session key
818+ by encrypt_cert_assertion
819+ :param encrypt_cert_assertion_session_key_alg: algorithm used for encrypting assertion when
820+ using encrypt_cert_assertion
736821 :param sign_assertion: True if assertions should be signed.
737822 :param pefim: True if a response according to the PEFIM profile
738823 should be created.
@@ -856,6 +941,8 @@ def _response(
856941 sp_entity_id ,
857942 response ,
858943 node_xpath = node_xpath ,
944+ encrypt_cert_session_key_alg = encrypt_cert_advice_session_key_alg ,
945+ encrypt_cert_cert_key_alg = encrypt_cert_advice_cert_key_alg ,
859946 )
860947 response = response_from_string (response )
861948
@@ -900,7 +987,13 @@ def _response(
900987 response = signed_instance_factory (response , self .sec , to_sign_assertion )
901988
902989 # XXX encrypt assertion
903- response = self ._encrypt_assertion (encrypt_cert_assertion , sp_entity_id , response )
990+ response = self ._encrypt_assertion (
991+ encrypt_cert_assertion ,
992+ sp_entity_id ,
993+ response ,
994+ encrypt_cert_session_key_alg = encrypt_cert_assertion_session_key_alg ,
995+ encrypt_cert_cert_key_alg = encrypt_cert_assertion_cert_key_alg ,
996+ )
904997 else :
905998 # XXX sign other parts! (defiend by to_sign)
906999 if to_sign :
@@ -1357,7 +1450,6 @@ def create_manage_name_id_response(
13571450 digest_alg = None ,
13581451 ** kwargs ,
13591452 ):
1360-
13611453 rinfo = self .response_args (request , bindings )
13621454
13631455 response = self ._status_response (
0 commit comments