From: Luke Howard Date: Wed, 13 Oct 2010 04:07:52 +0000 (+1100) Subject: add set/delete attribute to SAML provider X-Git-Tag: vm/20110310~117 X-Git-Url: http://www.project-moonshot.org/gitweb/?p=mech_eap.git;a=commitdiff_plain;h=1b6f36f4407ebb4cd09cb916118a19c1d850581c add set/delete attribute to SAML provider --- diff --git a/util_saml.cpp b/util_saml.cpp index 2017f3e..b4afd21 100644 --- a/util_saml.cpp +++ b/util_saml.cpp @@ -329,14 +329,25 @@ gss_eap_saml_assertion_provider::createAttrContext(void) return new gss_eap_saml_assertion_provider; } +saml2::Assertion * +gss_eap_saml_assertion_provider::initAssertion(void) +{ + delete m_assertion; + m_assertion = saml2::AssertionBuilder::buildAssertion(); + m_authenticated = false; + + return m_assertion; +} + /* * gss_eap_saml_attr_provider is for retrieving the underlying attributes. */ bool gss_eap_saml_attr_provider::getAssertion(int *authenticated, - const saml2::Assertion **pAssertion) const + saml2::Assertion **pAssertion, + bool createIfAbsent) const { - const gss_eap_saml_assertion_provider *saml; + gss_eap_saml_assertion_provider *saml; if (authenticated != NULL) *authenticated = false; @@ -353,14 +364,24 @@ gss_eap_saml_attr_provider::getAssertion(int *authenticated, if (pAssertion != NULL) *pAssertion = saml->getAssertion(); - return (saml->getAssertion() != NULL); + if (saml->getAssertion() == NULL) { + if (createIfAbsent) { + if (authenticated != NULL) + *authenticated = false; + if (pAssertion != NULL) + *pAssertion = saml->initAssertion(); + } else + return false; + } + + return true; } bool gss_eap_saml_attr_provider::getAttributeTypes(gss_eap_attr_enumeration_cb addAttribute, void *data) const { - const saml2::Assertion *assertion; + saml2::Assertion *assertion; int authenticated; if (!getAssertion(&authenticated, &assertion)) @@ -381,13 +402,13 @@ gss_eap_saml_attr_provider::getAttributeTypes(gss_eap_attr_enumeration_cb addAtt * using the same name syntax. */ /* For each attribute statement, look for an attribute match */ - const vector &statements = - assertion->getAttributeStatements(); + const vector &statements = + const_cast(assertion)->getAttributeStatements(); for (vector::const_iterator s = statements.begin(); s != statements.end(); ++s) { - const vector& attrs = + const vector &attrs = const_cast(*s)->getAttributes(); for (vector::const_iterator a = attrs.begin(); a != attrs.end(); ++a) { @@ -419,10 +440,22 @@ gss_eap_saml_attr_provider::getAttributeTypes(gss_eap_attr_enumeration_cb addAtt return true; } -ssize_t -gss_eap_saml_attr_provider::getAttributeIndex(const gss_buffer_t attr) const +static BaseRefVectorOf * +decomposeAttributeName(const gss_buffer_t attr) { - return -1; + XMLCh *qualifiedAttr = new XMLCh[attr->length + 1]; + XMLString::transcode((const char *)attr->value, qualifiedAttr, attr->length); + + BaseRefVectorOf *components = XMLString::tokenizeString(qualifiedAttr); + + delete qualifiedAttr; + + if (components->size() != 2) { + delete components; + components = NULL; + } + + return components; } bool @@ -430,26 +463,93 @@ gss_eap_saml_attr_provider::setAttribute(int complete, const gss_buffer_t attr, const gss_buffer_t value) { - return false; + saml2::Assertion *assertion; + saml2::Attribute *attribute; + saml2::AttributeValue *attributeValue; + saml2::AttributeStatement *attributeStatement; + + if (!getAssertion(NULL, &assertion, true)) + return false; + + if (assertion->getAttributeStatements().size() != 0) { + attributeStatement = assertion->getAttributeStatements().front(); + } else { + attributeStatement = saml2::AttributeStatementBuilder::buildAttributeStatement(); + assertion->getAttributeStatements().push_back(attributeStatement); + } + + /* Check the attribute name consists of name format | whsp | name */ + BaseRefVectorOf *components = decomposeAttributeName(attr); + if (components == NULL) + return false; + + attribute = saml2::AttributeBuilder::buildAttribute(); + attribute->setNameFormat(components->elementAt(0)); + attribute->setName(components->elementAt(1)); + + XMLCh *xmlValue = new XMLCh[value->length + 1]; + XMLString::transcode((const char *)value->value, xmlValue, attr->length); + + attributeValue = saml2::AttributeValueBuilder::buildAttributeValue(); + attributeValue->setTextContent(xmlValue); + + attribute->getAttributeValues().push_back(attributeValue); + + assert(attributeStatement != NULL); + attributeStatement->getAttributes().push_back(attribute); + + delete components; + delete xmlValue; + + return true; } bool -gss_eap_saml_attr_provider::deleteAttribute(const gss_buffer_t value) +gss_eap_saml_attr_provider::deleteAttribute(const gss_buffer_t attr) { - return false; -} + saml2::Assertion *assertion; + bool ret = false; -static BaseRefVectorOf * -decomposeAttributeName(const gss_buffer_t attr) -{ - XMLCh *qualifiedAttr = new XMLCh[attr->length + 1]; - XMLString::transcode((const char *)attr->value, qualifiedAttr, attr->length); + if (!getAssertion(NULL, &assertion) || + assertion->getAttributeStatements().size() == 0) + return false; - BaseRefVectorOf *components = XMLString::tokenizeString(qualifiedAttr); + /* Check the attribute name consists of name format | whsp | name */ + BaseRefVectorOf *components = decomposeAttributeName(attr); + if (components == NULL) + return false; - delete qualifiedAttr; + /* For each attribute statement, look for an attribute match */ + const vector &statements = + const_cast(assertion)->getAttributeStatements(); - return components; + for (vector::const_iterator s = statements.begin(); + s != statements.end(); + ++s) { + const vector &attrs = + const_cast(*s)->getAttributes(); + ssize_t index = -1, i = 0; + + /* There's got to be an easier way to do this */ + for (vector::const_iterator a = attrs.begin(); + a != attrs.end(); + ++a) { + if (XMLString::equals((*a)->getNameFormat(), components->elementAt(0)) && + XMLString::equals((*a)->getName(), components->elementAt(1))) { + index = i; + break; + } + ++i; + } + if (index != -1) { + (*s)->getAttributes().erase((*s)->getAttributes().begin() + index); + ret = true; + } + } + + delete components; + + return ret; } bool @@ -458,7 +558,7 @@ gss_eap_saml_attr_provider::getAttribute(const gss_buffer_t attr, int *complete, const saml2::Attribute **pAttribute) const { - const saml2::Assertion *assertion; + saml2::Assertion *assertion; if (authenticated != NULL) *authenticated = false; @@ -472,23 +572,21 @@ gss_eap_saml_attr_provider::getAttribute(const gss_buffer_t attr, /* Check the attribute name consists of name format | whsp | name */ BaseRefVectorOf *components = decomposeAttributeName(attr); - if (components == NULL || components->size() != 2) { - delete components; + if (components == NULL) return false; - } /* For each attribute statement, look for an attribute match */ - const vector &statements = - assertion->getAttributeStatements(); + const vector &statements = + const_cast(assertion)->getAttributeStatements(); const saml2::Attribute *ret = NULL; for (vector::const_iterator s = statements.begin(); s != statements.end(); ++s) { - const vector& attrs = + const vector &attrs = const_cast(*s)->getAttributes(); - for (vector::const_iterator a = attrs.begin(); a != attrs.end(); ++a) { + for (vector::const_iterator a = attrs.begin(); a != attrs.end(); ++a) { if (XMLString::equals((*a)->getNameFormat(), components->elementAt(0)) && XMLString::equals((*a)->getName(), components->elementAt(1))) { ret = *a; diff --git a/util_saml.h b/util_saml.h index 8beb821..25647db 100644 --- a/util_saml.h +++ b/util_saml.h @@ -74,7 +74,9 @@ public: bool initFromBuffer(const gss_eap_attr_ctx *ctx, const gss_buffer_t buffer); - const opensaml::saml2::Assertion *getAssertion(void) const { + opensaml::saml2::Assertion *initAssertion(void); + + opensaml::saml2::Assertion *getAssertion(void) const { return m_assertion; } bool authenticated(void) const { @@ -131,8 +133,8 @@ public: int *complete, const opensaml::saml2::Attribute **pAttribute) const; bool getAssertion(int *authenticated, - const opensaml::saml2::Assertion **pAssertion) const; - ssize_t getAttributeIndex(const gss_buffer_t attr) const; + opensaml::saml2::Assertion **pAssertion, + bool createIfAbsent = false) const; static bool init(void); static void finalize(void);