Fix backslashes in SHIBSP_PREFIX variable by manually creating it during the script...
[shibboleth/sp.git] / nsapi_shib / nsapi_shib.cpp
index 55a53da..d1efdc4 100644 (file)
@@ -75,6 +75,8 @@ namespace {
     string g_ServerName;
     string g_ServerScheme;
     string g_unsetHeaderValue;
+    bool g_checkSpoofing = true;
+    bool g_catchAll = false;
 
     static const XMLCh path[] =     UNICODE_LITERAL_4(p,a,t,h);
     static const XMLCh validate[] = UNICODE_LITERAL_8(v,a,l,i,d,a,t,e);
@@ -118,24 +120,18 @@ extern "C" NSAPI_PUBLIC int nsapi_shib_init(pblock* pb, ::Session* sn, Request*
     log_error(LOG_INFORM,"nsapi_shib_init",sn,rq,"nsapi_shib loaded for host (%s)",g_ServerName.c_str());
 
     const char* schemadir=pblock_findval("shib-schemas",pb);
-    if (!schemadir)
-        schemadir=getenv("SHIBSP_SCHEMAS");
-    if (!schemadir)
-        schemadir=SHIBSP_SCHEMAS;
-    const char* config=pblock_findval("shib-config",pb);
-    if (!config)
-        config=getenv("SHIBSP_CONFIG");
-    if (!config)
-        config=SHIBSP_CONFIG;
+    const char* prefix=pblock_findval("shib-prefix",pb);
+
     g_Config=&SPConfig::getConfig();
     g_Config->setFeatures(
         SPConfig::Listener |
         SPConfig::Caching |
         SPConfig::RequestMapping |
         SPConfig::InProcess |
-        SPConfig::Logging
+        SPConfig::Logging |
+        SPConfig::Handlers
         );
-    if (!g_Config->init(schemadir)) {
+    if (!g_Config->init(schemadir,prefix)) {
         g_Config=NULL;
         pblock_nvinsert("error","unable to initialize Shibboleth libraries",pb);
         return REQ_ABORTED;
@@ -143,6 +139,12 @@ extern "C" NSAPI_PUBLIC int nsapi_shib_init(pblock* pb, ::Session* sn, Request*
 
     g_Config->RequestMapperManager.registerFactory(XML_REQUEST_MAPPER,&SunRequestMapFactory);
 
+    const char* config=pblock_findval("shib-config",pb);
+    if (!config)
+        config=getenv("SHIBSP_CONFIG");
+    if (!config)
+        config=SHIBSP_CONFIG;
+
     try {
         xercesc::DOMDocument* dummydoc=XMLToolingConfig::getConfig().getParser().newDocument();
         XercesJanitor<xercesc::DOMDocument> docjanitor(dummydoc);
@@ -170,6 +172,10 @@ extern "C" NSAPI_PUBLIC int nsapi_shib_init(pblock* pb, ::Session* sn, Request*
         pair<bool,const char*> unsetValue=props->getString("unsetHeaderValue");
         if (unsetValue.first)
             g_unsetHeaderValue = unsetValue.second;
+        pair<bool,bool> flag=props->getBool("checkSpoofing");
+        g_checkSpoofing = !flag.first || flag.second;
+        flag=props->getBool("catchAll");
+        g_catchAll = flag.first && flag.second;
     }
     return REQ_PROCEED;
 }
@@ -179,42 +185,38 @@ extern "C" NSAPI_PUBLIC int nsapi_shib_init(pblock* pb, ::Session* sn, Request*
 
 class ShibTargetNSAPI : public AbstractSPRequest
 {
-  string m_uri;
   mutable string m_body;
-  mutable bool m_gotBody;
-  vector<string> m_certs;
+  mutable bool m_gotBody,m_firsttime;
+  mutable vector<string> m_certs;
+  set<string> m_allhttp;
 
 public:
-  ShibTargetNSAPI(pblock* pb, ::Session* sn, Request* rq) : m_gotBody(false) {
-    m_pb = pb;
-    m_sn = sn;
-    m_rq = rq;
+  pblock* m_pb;
+  ::Session* m_sn;
+  Request* m_rq;
+
+  ShibTargetNSAPI(pblock* pb, ::Session* sn, Request* rq)
+      : AbstractSPRequest(SHIBSP_LOGCAT".NSAPI"), m_gotBody(false), m_firsttime(true), m_pb(pb), m_sn(sn), m_rq(rq) {
 
-    // Get everything but hostname...
     const char* uri=pblock_findval("uri", rq->reqpb);
     const char* qstr=pblock_findval("query", rq->reqpb);
 
-    string url;
-    if (uri) {
-        url = uri;
-        m_uri = uri;
+    if (qstr) {
+        string temp = string(uri) + '?' + qstr;
+        setRequestURI(temp.c_str());
     }
-    if (qstr)
-        url=url + '?' + qstr;
-    
-    const char* host=NULL;
-#ifdef vs_is_default_vs
-    // This is 6.0 or later, so we can distinguish requests to name-based vhosts.
-    if (!vs_is_default_vs)
-        // The beauty here is, a non-default vhost can *only* be accessed if the client
-        // specified the exact name in the Host header. So we can trust the Host header.
-        host=pblock_findval("host", rq->headers);
-    else
-#endif
-    // In other cases, we're going to rely on the initialization process...
-    host=g_ServerName.c_str();
+    else {
+        setRequestURI(uri);
+    }
+
+    // See if this is the first time we've run.
+    qstr = pblock_findval("auth-type", rq->vars);
+    if (qstr && !strcmp(qstr, "shibboleth"))
+        m_firsttime = false;
+    if (!m_firsttime || rq->orig_rq)
+        log(SPDebug, "nsapi_shib function running more than once");
   }
-  ~ShibTargetNSAPI() {}
+  ~ShibTargetNSAPI() { }
 
   const char* getScheme() const {
     return security_active ? "https" : "http";
@@ -222,7 +224,7 @@ public:
   const char* getHostname() const {
 #ifdef vs_is_default_vs
     // This is 6.0 or later, so we can distinguish requests to name-based vhosts.
-    if (!vs_is_default_vs)
+    if (!vs_is_default_vs(request_get_vs(m_rq)))
         // The beauty here is, a non-default vhost can *only* be accessed if the client
         // specified the exact name in the Host header. So we can trust the Host header.
         return pblock_findval("host", m_rq->headers);
@@ -234,9 +236,6 @@ public:
   int getPort() const {
     return server_portnum;
   }
-  const char* getRequestURI() const {
-    return m_uri.c_str();
-  }
   const char* getMethod() const {
     return pblock_findval("method", m_rq->reqpb);
   }
@@ -255,7 +254,7 @@ public:
   string getRemoteAddr() const {
     return pblock_findval("ip", m_sn->client);
   }
-  void log(SPLogLevel level, const string& msg) {
+  void log(SPLogLevel level, const string& msg) const {
     AbstractSPRequest::log(level,msg);
     if (level>=SPError)
         log_error(LOG_FAILURE, "nsapi_shib", m_sn, m_rq, const_cast<char*>(msg.c_str()));
@@ -286,31 +285,59 @@ public:
       return m_body.c_str();
     }
   }
-  void clearHeader(const char* name) {
-    if (!strcmp(name,"REMOTE_USER")) {
-        param_free(pblock_remove("auth-user",m_rq->vars));
-        param_free(pblock_remove("remote-user",m_rq->headers));
-    }
-    else {
-        param_free(pblock_remove(name, m_rq->headers));
-        pblock_nvinsert(name, g_unsetHeaderValue.c_str() ,m_rq->headers);
+  void clearHeader(const char* rawname, const char* cginame) {
+    if (g_checkSpoofing && m_firsttime && !m_rq->orig_rq) {
+        if (m_allhttp.empty()) {
+            // Populate the set of client-supplied headers for spoof checking.
+            const pb_entry* entry;
+            for (int i=0; i<m_rq->headers->hsize; ++i) {
+                entry = m_rq->headers->ht[i];
+                while (entry) {
+                    string cgiversion("HTTP_");
+                    const char* pch = entry->param->name;
+                    while (*pch) {
+                        cgiversion += (isalnum(*pch) ? toupper(*pch) : '_');
+                        pch++;
+                    }
+                    m_allhttp.insert(cgiversion);
+                    entry = entry->next;
+                }
+            }
+        }
+        if (m_allhttp.count(cginame) > 0)
+            throw opensaml::SecurityPolicyException("Attempt to spoof header ($1) was detected.", params(1, rawname));
     }
+    param_free(pblock_remove(rawname, m_rq->headers));
+    pblock_nvinsert(rawname, g_unsetHeaderValue.c_str(), m_rq->headers);
   }
   void setHeader(const char* name, const char* value) {
+    param_free(pblock_remove(name, m_rq->headers));
     pblock_nvinsert(name, value, m_rq->headers);
   }
   string getHeader(const char* name) const {
+    // NSAPI headers tend to be lower case. We'll special case "cookie" since it's used a lot.
     char* hdr = NULL;
-    if (request_header(const_cast<char*>(name), &hdr, m_sn, m_rq) != REQ_PROCEED)
-      hdr = NULL;
+    int cookie = strcmp(name, "Cookie");
+    if (cookie == 0)
+        name = "cookie";
+    if (request_header(const_cast<char*>(name), &hdr, m_sn, m_rq) != REQ_PROCEED) {
+      // We didn't get a hit, so we'll try a lower-casing operation, unless we already did...
+      if (cookie == 0)
+          return "";
+      string n;
+      while (*name)
+          n += tolower(*(name++));
+      if (request_header(const_cast<char*>(n.c_str()), &hdr, m_sn, m_rq) != REQ_PROCEED)
+          return "";
+    }
     return string(hdr ? hdr : "");
   }
   void setRemoteUser(const char* user) {
-    pblock_nvinsert("remote-user", user, m_rq->headers);
     pblock_nvinsert("auth-user", user, m_rq->vars);
   }
   string getRemoteUser() const {
-    return getHeader("remote-user");
+    const char* ru = pblock_findval("auth-user", m_rq->vars);
+    return ru ? ru : "";
   }
   void setResponseHeader(const char* name, const char* value) {
     pblock_nvinsert(name, value, m_rq->srvhdrs);
@@ -344,12 +371,13 @@ public:
   long returnDecline() { return REQ_NOACTION; }
   long returnOK() { return REQ_PROCEED; }
   const vector<string>& getClientCertificates() const {
+      if (m_certs.empty()) {
+          const char* cert = pblock_findval("auth-cert", m_rq->vars);
+          if (cert)
+              m_certs.push_back(cert);
+      }
       return m_certs;
   }
-
-  pblock* m_pb;
-  ::Session* m_sn;
-  Request* m_rq;
 };
 
 /********************************************************************************/
@@ -395,11 +423,12 @@ extern "C" NSAPI_PUBLIC int nsapi_shib(pblock* pb, ::Session* sn, Request* rq)
     log_error(LOG_FAILURE,FUNC,sn,rq,const_cast<char*>(e.what()));
     return WriteClientError(sn, rq, FUNC, "Shibboleth module threw an exception, see web server log for error.");
   }
-#ifndef _DEBUG
   catch (...) {
-    return WriteClientError(sn, rq, FUNC, "Shibboleth module threw an uncaught exception.");
+    log_error(LOG_FAILURE,FUNC,sn,rq,const_cast<char*>("Shibboleth module threw an unknown exception."));
+    if (g_catchAll)
+        return WriteClientError(sn, rq, FUNC, "Shibboleth module threw an unknown exception.");
+    throw;
   }
-#endif
 }
 
 
@@ -423,11 +452,11 @@ extern "C" NSAPI_PUBLIC int shib_handler(pblock* pb, ::Session* sn, Request* rq)
     log_error(LOG_FAILURE,FUNC,sn,rq,const_cast<char*>(e.what()));
     return WriteClientError(sn, rq, FUNC, "Shibboleth handler threw an exception, see web server log for error.");
   }
-#ifndef _DEBUG
   catch (...) {
-    return WriteClientError(sn, rq, FUNC, "Shibboleth handler threw an unknown exception.");
+    if (g_catchAll)
+        return WriteClientError(sn, rq, FUNC, "Shibboleth handler threw an unknown exception.");
+    throw;
   }
-#endif
 }
 
 
@@ -438,15 +467,17 @@ public:
     ~SunRequestMapper() { delete m_mapper; delete m_stKey; delete m_propsKey; }
     Lockable* lock() { return m_mapper->lock(); }
     void unlock() { m_stKey->setData(NULL); m_propsKey->setData(NULL); m_mapper->unlock(); }
-    Settings getSettings(const SPRequest& request) const;
+    Settings getSettings(const HTTPRequest& request) const;
     
+    const PropertySet* getParent() const { return NULL; }
     void setParent(const PropertySet*) {}
     pair<bool,bool> getBool(const char* name, const char* ns=NULL) const;
     pair<bool,const char*> getString(const char* name, const char* ns=NULL) const;
     pair<bool,const XMLCh*> getXMLString(const char* name, const char* ns=NULL) const;
     pair<bool,unsigned int> getUnsignedInt(const char* name, const char* ns=NULL) const;
     pair<bool,int> getInt(const char* name, const char* ns=NULL) const;
-    const PropertySet* getPropertySet(const char* name, const char* ns="urn:mace:shibboleth:2.0:native:sp:config") const;
+    void getAll(map<string,const char*>& properties) const;
+    const PropertySet* getPropertySet(const char* name, const char* ns=shibspconstants::ASCII_SHIB2SPCONFIG_NS) const;
     const xercesc::DOMElement* getElement() const;
 
 private:
@@ -467,7 +498,7 @@ SunRequestMapper::SunRequestMapper(const xercesc::DOMElement* e) : m_mapper(NULL
     m_propsKey=ThreadKey::create(NULL);
 }
 
-RequestMapper::Settings SunRequestMapper::getSettings(const SPRequest& request) const
+RequestMapper::Settings SunRequestMapper::getSettings(const HTTPRequest& request) const
 {
     Settings s=m_mapper->getSettings(request);
     m_stKey->setData((void*)dynamic_cast<const ShibTargetNSAPI*>(&request));
@@ -537,6 +568,25 @@ pair<bool,int> SunRequestMapper::getInt(const char* name, const char* ns) const
     return s ? s->getInt(name,ns) : pair<bool,int>(false,0);
 }
 
+void SunRequestMapper::getAll(map<string,const char*>& properties) const
+{
+    const ShibTargetNSAPI* stn=reinterpret_cast<const ShibTargetNSAPI*>(m_stKey->getData());
+    const PropertySet* s=reinterpret_cast<const PropertySet*>(m_propsKey->getData());
+    if (s)
+        s->getAll(properties);
+    if (!stn)
+        return;
+    properties["authType"] = "shibboleth";
+    const pb_entry* entry;
+    for (int i=0; i<stn->m_pb->hsize; ++i) {
+        entry = stn->m_pb->ht[i];
+        while (entry) {
+            properties[entry->param->name] = entry->param->value;
+            entry = entry->next;
+        }
+    }
+}
+
 const PropertySet* SunRequestMapper::getPropertySet(const char* name, const char* ns) const
 {
     const PropertySet* s=reinterpret_cast<const PropertySet*>(m_propsKey->getData());