Collapse entity/role lookup in metadata API.
[shibboleth/sp.git] / shibsp / AbstractSPRequest.cpp
index 0229b2c..dffe00c 100644 (file)
 #include "Application.h"
 #include "ServiceProvider.h"
 #include "SessionCache.h"
-#include "util/CGIParser.h"
-
-#include <log4cpp/Category.hh>
 
 using namespace shibsp;
+using namespace opensaml;
 using namespace xmltooling;
-using namespace log4cpp;
 using namespace std;
 
 AbstractSPRequest::AbstractSPRequest()
-    : m_sp(NULL), m_mapper(NULL), m_app(NULL), m_session(NULL), m_log(&Category::getInstance(SHIBSP_LOGCAT)), m_parser(NULL)
+    : m_sp(NULL), m_mapper(NULL), m_app(NULL), m_sessionTried(false), m_session(NULL),
+        m_log(&Category::getInstance(SHIBSP_LOGCAT".SPRequest")), m_parser(NULL)
 {
     m_sp=SPConfig::getConfig().getServiceProvider();
     m_sp->lock();
@@ -75,7 +73,95 @@ const Application& AbstractSPRequest::getApplication() const
     return *m_app;
 }
 
-const char* AbstractSPRequest::getRequestURL() const {
+Session* AbstractSPRequest::getSession(bool checkTimeout, bool ignoreAddress, bool cache) const
+{
+    // Only attempt this once.
+    if (cache && m_sessionTried)
+        return m_session;
+    else if (cache)
+        m_sessionTried = true;
+
+    // Get session ID from cookie.
+    const Application& app = getApplication();
+    pair<string,const char*> shib_cookie = app.getCookieNameProps("_shibsession_");
+    const char* session_id = getCookie(shib_cookie.first.c_str());
+    if (!session_id || !*session_id)
+        return NULL;
+
+    // Need address checking and timeout settings.
+    time_t timeout=0;
+    if (checkTimeout || !ignoreAddress) {
+        const PropertySet* props=app.getPropertySet("Sessions");
+        if (props) {
+            if (checkTimeout) {
+                pair<bool,unsigned int> p=props->getUnsignedInt("timeout");
+                if (p.first)
+                    timeout = p.second;
+            }
+            pair<bool,bool> pcheck=props->getBool("consistentAddress");
+            if (pcheck.first)
+                ignoreAddress = !pcheck.second;
+        }
+    }
+
+    // The cache will either silently pass a session or NULL back, or throw an exception out.
+    Session* session = getServiceProvider().getSessionCache()->find(
+        session_id, app, ignoreAddress ? NULL : getRemoteAddr().c_str(), checkTimeout ? &timeout : NULL
+        );
+    if (cache)
+        m_session = session;
+    return session;
+}
+
+static char _x2c(const char *what)\r
+{\r
+    register char digit;\r
+\r
+    digit = (what[0] >= 'A' ? ((what[0] & 0xdf) - 'A')+10 : (what[0] - '0'));\r
+    digit *= 16;\r
+    digit += (what[1] >= 'A' ? ((what[1] & 0xdf) - 'A')+10 : (what[1] - '0'));\r
+    return(digit);\r
+}\r
+
+void AbstractSPRequest::setRequestURI(const char* uri)
+{
+    // Fix for bug 574, secadv 20061002\r
+    // Unescape URI up to query string delimiter by looking for %XX escapes.\r
+    // Adapted from Apache's util.c, ap_unescape_url function.\r
+    if (uri) {\r
+        while (*uri) {\r
+            if (*uri == '?') {\r
+                m_uri += uri;\r
+                break;\r
+            }\r
+            else if (*uri == ';') {\r
+                // If this is Java being stupid, skip everything up to the query string, if any.\r
+                if (!strncmp(uri, ";jsessionid=", 12)) {\r
+                    if (uri = strchr(uri, '?'))\r
+                        m_uri += uri;\r
+                    break;\r
+                }\r
+                else {\r
+                    m_uri += *uri;\r
+                }\r
+            }\r
+            else if (*uri != '%') {\r
+                m_uri += *uri;\r
+            }\r
+            else {\r
+                ++uri;\r
+                if (!isxdigit(*uri) || !isxdigit(*(uri+1)))\r
+                    throw ConfigurationException("Bad request, contained unsupported encoded characters.");\r
+                m_uri += _x2c(uri);\r
+                ++uri;\r
+            }\r
+            ++uri;\r
+        }\r
+    }\r
+}
+
+const char* AbstractSPRequest::getRequestURL() const
+{
     if (m_url.empty()) {
         // Compute the full target URL
         int port = getPort();
@@ -86,9 +172,7 @@ const char* AbstractSPRequest::getRequestURL() const {
             portstr << port;
             m_url += ":" + portstr.str();
         }
-        scheme = getRequestURI();
-        if (scheme)
-            m_url += scheme;
+        m_url += m_uri;
     }
     return m_url.c_str();
 }
@@ -148,6 +232,9 @@ const char* AbstractSPRequest::getCookie(const char* name) const
 
 const char* AbstractSPRequest::getHandlerURL(const char* resource) const
 {
+    if (!resource)
+        resource = getRequestURL();
+
     if (!m_handlerURL.empty() && resource && !strcmp(getRequestURL(),resource))
         return m_handlerURL.c_str();
         
@@ -233,10 +320,10 @@ const char* AbstractSPRequest::getHandlerURL(const char* resource) const
 void AbstractSPRequest::log(SPLogLevel level, const std::string& msg) const
 {
     reinterpret_cast<Category*>(m_log)->log(
-        (level == SPDebug ? log4cpp::Priority::DEBUG :
-        (level == SPInfo ? log4cpp::Priority::INFO :
-        (level == SPWarn ? log4cpp::Priority::WARN :
-        (level == SPError ? log4cpp::Priority::ERROR : log4cpp::Priority::CRIT)))),
+        (level == SPDebug ? Priority::DEBUG :
+        (level == SPInfo ? Priority::INFO :
+        (level == SPWarn ? Priority::WARN :
+        (level == SPError ? Priority::ERROR : Priority::CRIT)))),
         msg
         );
 }
@@ -244,9 +331,9 @@ void AbstractSPRequest::log(SPLogLevel level, const std::string& msg) const
 bool AbstractSPRequest::isPriorityEnabled(SPLogLevel level) const
 {
     return reinterpret_cast<Category*>(m_log)->isPriorityEnabled(
-        (level == SPDebug ? log4cpp::Priority::DEBUG :
-        (level == SPInfo ? log4cpp::Priority::INFO :
-        (level == SPWarn ? log4cpp::Priority::WARN :
-        (level == SPError ? log4cpp::Priority::ERROR : log4cpp::Priority::CRIT))))
+        (level == SPDebug ? Priority::DEBUG :
+        (level == SPInfo ? Priority::INFO :
+        (level == SPWarn ? Priority::WARN :
+        (level == SPError ? Priority::ERROR : Priority::CRIT))))
         );
 }