Move credential usage enum to Credential class.
[shibboleth/xmltooling.git] / xmltooling / security / impl / CredentialCriteria.cpp
1 /*
2  *  Copyright 2001-2007 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * CredentialCriteria.cpp
19  * 
20  * Class for specifying criteria by which a CredentialResolver should resolve credentials.
21  */
22
23 #include "internal.h"
24 #include "logging.h"
25 #include "security/Credential.h"
26 #include "security/CredentialCriteria.h"
27 #include "security/KeyInfoResolver.h"
28
29 #include <openssl/dsa.h>
30 #include <openssl/rsa.h>
31 #include <xsec/enc/OpenSSL/OpenSSLCryptoKeyDSA.hpp>
32 #include <xsec/enc/OpenSSL/OpenSSLCryptoKeyRSA.hpp>
33
34 using namespace xmltooling;
35 using namespace std;
36
37 bool CredentialCriteria::matches(const Credential& credential) const
38 {
39     // Usage check, if specified and we have one.
40     if (getUsage() != Credential::UNSPECIFIED_CREDENTIAL) {
41         if (credential.getUsage() != Credential::UNSPECIFIED_CREDENTIAL)
42             if (getUsage() != credential.getUsage())
43                 return false;
44     }
45
46     // Algorithm check, if specified and we have one.
47     const char* alg = getKeyAlgorithm();
48     if (alg && *alg) {
49         const char* alg2 = credential.getAlgorithm();
50         if (alg2 && *alg2)
51             if (strcmp(alg,alg2))
52                 return false;
53     }
54
55     // KeySize check, if specified and we have one.
56     if (credential.getKeySize()>0 && getKeySize()>0 && credential.getKeySize() != getKeySize())
57         return false;
58
59     // See if we can test key names.
60     const set<string>& critnames = getKeyNames();
61     const set<string>& crednames = credential.getKeyNames();
62     if (!critnames.empty() && !crednames.empty()) {
63         bool found = false;
64         for (set<string>::const_iterator n = critnames.begin(); n!=critnames.end(); ++n) {
65             if (crednames.count(*n)>0) {
66                 found = true;
67                 break;
68             }
69         }
70         if (!found)
71             return false;
72     }
73
74     // See if we have to match a specific key.
75     const XSECCryptoKey* key1 = getPublicKey();
76     if (!key1)
77         return true;    // no key to compare against, so we're done
78
79     const XSECCryptoKey* key2 = credential.getPublicKey();
80     if (!key2)
81         return true;   // no key here, so we can't test it
82
83     if (key1->getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL ||
84         key2->getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL) {
85         logging::Category::getInstance(XMLTOOLING_LOGCAT".Credential").warn("comparison of non-OpenSSL credentials are not supported");
86         return false;
87     }
88
89     if (key1->getKeyType()==XSECCryptoKey::KEY_RSA_PUBLIC || key1->getKeyType()==XSECCryptoKey::KEY_RSA_PAIR) {
90         if (key2->getKeyType()!=XSECCryptoKey::KEY_RSA_PUBLIC && key2->getKeyType()!=XSECCryptoKey::KEY_RSA_PAIR)
91             return false;
92         const RSA* rsa1 = static_cast<const OpenSSLCryptoKeyRSA*>(key1)->getOpenSSLRSA();
93         const RSA* rsa2 = static_cast<const OpenSSLCryptoKeyRSA*>(key2)->getOpenSSLRSA();
94         return (BN_cmp(rsa1->n,rsa2->n) == 0 && BN_cmp(rsa1->e,rsa2->e) == 0);
95     }
96
97     if (key1->getKeyType()==XSECCryptoKey::KEY_DSA_PUBLIC || key1->getKeyType()==XSECCryptoKey::KEY_DSA_PAIR) {
98         if (key2->getKeyType()!=XSECCryptoKey::KEY_DSA_PUBLIC && key2->getKeyType()!=XSECCryptoKey::KEY_DSA_PAIR)
99             return false;
100         const DSA* dsa1 = static_cast<const OpenSSLCryptoKeyDSA*>(key1)->getOpenSSLDSA();
101         const DSA* dsa2 = static_cast<const OpenSSLCryptoKeyDSA*>(key2)->getOpenSSLDSA();
102         return (BN_cmp(dsa1->pub_key,dsa2->pub_key) == 0);
103     }
104     
105     logging::Category::getInstance(XMLTOOLING_LOGCAT".CredentialCriteria").warn("unsupported key type for comparison");
106     return false;
107 }