https://issues.shibboleth.net/jira/browse/SSPCPP-179
[shibboleth/cpp-sp.git] / nsapi_shib / nsapi_shib.cpp
index fe82299..1925021 100644 (file)
@@ -1,6 +1,6 @@
 /*
  *  Copyright 2001-2005 Internet2
- * 
+ *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
    12/13/04
 */
 
-#include "config_win32.h"
+#if defined (_MSC_VER) || defined(__BORLANDC__)
+# include "config_win32.h"
+#else
+# include "config.h"
+#endif
+
 
 // SAML Runtime
 #include <saml/saml.h>
@@ -60,7 +65,9 @@ using namespace shibtarget;
 namespace {
     ShibTargetConfig* g_Config=NULL;
     string g_ServerName;
-    string g_ServerScheme;
+    string g_unsetHeaderValue;
+    bool g_checkSpoofing = false;
+    bool g_catchAll = true;
 }
 
 PlugManager::Factory SunRequestMapFactory;
@@ -94,15 +101,10 @@ extern "C" NSAPI_PUBLIC int nsapi_shib_init(pblock* pb, Session* sn, Request* rq
             }
         }
     }
-    name=pblock_findval("server-scheme",pb);
-    if (name)
-        g_ServerScheme=name;
 
     log_error(LOG_INFORM,"nsapi_shib_init",sn,rq,"nsapi_shib loaded for host (%s)",g_ServerName.c_str());
 
-#ifndef _DEBUG
     try {
-#endif
         const char* schemadir=pblock_findval("shib-schemas",pb);
         if (!schemadir)
             schemadir=getenv("SHIBSCHEMAS");
@@ -139,14 +141,25 @@ extern "C" NSAPI_PUBLIC int nsapi_shib_init(pblock* pb, Session* sn, Request* rq
         }
 
         daemon_atrestart(nsapi_shib_exit,NULL);
-#ifndef _DEBUG
+
+        IConfig* conf=g_Config->getINI();
+        Locker locker(conf);
+        const IPropertySet* props=conf->getPropertySet("Local");
+        if (props) {
+            pair<bool,const char*> unsetValue=props->getString("unsetHeaderValue");
+            if (unsetValue.first)
+                g_unsetHeaderValue = unsetValue.second;
+            pair<bool,bool> flag=props->getBool("checkSpoofing");
+            g_checkSpoofing = !flag.first || flag.second;
+            flag=props->getBool("catchAll");
+            g_catchAll = !flag.first || flag.second;
+        }
     }
-    catch (...) {
+    catch (exception&) {
         g_Config=NULL;
         pblock_nvinsert("error","caught exception, unable to initialize Shibboleth libraries",pb);
         return REQ_ABORTED;
     }
-#endif
     return REQ_PROCEED;
 }
 
@@ -156,16 +169,35 @@ extern "C" NSAPI_PUBLIC int nsapi_shib_init(pblock* pb, Session* sn, Request* rq
 class ShibTargetNSAPI : public ShibTarget
 {
 public:
-  ShibTargetNSAPI(pblock* pb, Session* sn, Request* rq) {
-    m_pb = pb;
-    m_sn = sn;
-    m_rq = rq;
-
-    // Get everything but hostname...
+  ShibTargetNSAPI(pblock* pb, Session* sn, Request* rq) : m_pb(pb), m_sn(sn), m_rq(rq), m_firsttime(true) {
+
+      // To determine whether SSL is active or not, we're supposed to rely
+      // on the security_active macro. For iPlanet 4.x, this works.
+      // For Sun 7.x, it's useless and appears to be on or off based
+      // on whether ANY SSL support is enabled for a vhost. Sun 6.x is unknown.
+      // As a fix, there's a conf variable called $security that can be mapped
+      // into a function parameter: security_active="$security"
+      // We check for this parameter, and rely on the macro if it isn't set.
+      // This doubles as a scheme virtualizer for load balanced scenarios
+      // since you can set the parameter to 1 or 0 as needed.
+      const char* scheme;
+      const char* sa = pblock_findval("security_active", pb);
+      if (sa)
+          scheme = (*sa == '1') ? "https" : "http";
+      else if (security_active)
+          scheme = "https";
+      else
+          scheme = "http";
+
+      // A similar issue exists for the port. server_portnum is no longer
+      // working on at least Sun 7.x, and returns the first listener's port
+      // rather than whatever port is actually used for the request. Nice job, Sun.
+      sa = pblock_findval("server_portnum", pb);
+      int port = (sa && *sa) ? atoi(sa) : server_portnum;
+
+    // Get everything else but hostname...
     const char* uri=pblock_findval("uri", rq->reqpb);
     const char* qstr=pblock_findval("query", rq->reqpb);
-    int port=server_portnum;
-    const char* scheme=security_active ? "https" : "http";
     const char* host=NULL;
 
     string url;
@@ -173,10 +205,10 @@ public:
         url=uri;
     if (qstr)
         url=url + '?' + qstr;
-    
+
 #ifdef vs_is_default_vs
     // This is 6.0 or later, so we can distinguish requests to name-based vhosts.
-    if (!vs_is_default_vs)
+    if (!vs_is_default_vs(request_get_vs(m_rq)))
         // The beauty here is, a non-default vhost can *only* be accessed if the client
         // specified the exact name in the Host header. So we can trust the Host header.
         host=pblock_findval("host", rq->headers);
@@ -187,13 +219,21 @@ public:
 
     char* content_type = "";
     request_header("content-type", &content_type, sn, rq);
-      
-    const char *remote_ip = pblock_findval("ip", sn->client);
-    const char *method = pblock_findval("method", rq->reqpb);
+
+    const charremote_ip = pblock_findval("ip", sn->client);
+    const charmethod = pblock_findval("method", rq->reqpb);
 
     init(scheme, host, port, url.c_str(), content_type, remote_ip, method);
+
+    // See if this is the first time we've run.
+    method = pblock_findval("auth-type", rq->vars);
+    if (method && !strcmp(method, "shibboleth"))
+        m_firsttime = false;
+    if (!m_firsttime || rq->orig_rq)
+        log(LogLevelDebug, "nsapi_shib function running more than once");
+  }
+  ~ShibTargetNSAPI() {
   }
-  ~ShibTargetNSAPI() {}
 
   virtual void log(ShibLogLevel level, const string &msg) {
     ShibTarget::log(level,msg);
@@ -210,7 +250,7 @@ public:
     string cookie = name + '=' + value;
     pblock_nvinsert("Set-Cookie", cookie.c_str(), m_rq->srvhdrs);
   }
-  virtual string getArgs(void) { 
+  virtual string getArgs(void) {
     const char *q = pblock_findval("query", m_rq->reqpb);
     return string(q ? q : "");
   }
@@ -225,7 +265,7 @@ public:
       string cgistr;
       while (cl && ch != IO_EOF) {
         ch=netbuf_getc(m_sn->inbuf);
-      
+
         // Check for error.
         if(ch==IO_ERROR)
           break;
@@ -238,23 +278,70 @@ public:
     }
   }
   virtual void clearHeader(const string &name) {
-    param_free(pblock_remove(name.c_str(), m_rq->headers));
+    if (g_checkSpoofing && m_firsttime && !m_rq->orig_rq && m_allhttp.empty()) {
+      // Populate the set of client-supplied headers for spoof checking.
+      const pb_entry* entry;
+      for (int i=0; i<m_rq->headers->hsize; ++i) {
+          entry = m_rq->headers->ht[i];
+          while (entry) {
+              string cgiversion("HTTP_");
+              const char* pch = entry->param->name;
+              while (*pch) {
+                  cgiversion += (isalnum(*pch) ? toupper(*pch) : '_');
+                  pch++;
+              }
+              m_allhttp.insert(cgiversion);
+              entry = entry->next;
+          }
+      }
+    }
+    if (name=="REMOTE_USER") {
+        if (g_checkSpoofing && m_firsttime && !m_rq->orig_rq && m_allhttp.count("HTTP_REMOTE_USER") > 0)
+            throw SAMLException("Attempt to spoof header ($1) was detected.", params(1, name.c_str()));
+        param_free(pblock_remove("auth-user",m_rq->vars));
+        param_free(pblock_remove("remote-user",m_rq->headers));
+        pblock_nvinsert("remote-user", g_unsetHeaderValue.c_str(), m_rq->headers);
+    }
+    else {
+        if (g_checkSpoofing && m_firsttime && !m_rq->orig_rq) {
+            // Map to the expected CGI variable name.
+            string transformed("HTTP_");
+            const char* pch = name.c_str();
+            while (*pch) {
+                transformed += (isalnum(*pch) ? toupper(*pch) : '_');
+                pch++;
+            }
+            if (m_allhttp.count(transformed) > 0)
+                throw SAMLException("Attempt to spoof header ($1) was detected.", params(1, name.c_str()));
+        }
+        param_free(pblock_remove(name.c_str(), m_rq->headers));
+        pblock_nvinsert(name.c_str(), g_unsetHeaderValue.c_str(), m_rq->headers);
+    }
   }
   virtual void setHeader(const string &name, const string &value) {
+    param_free(pblock_remove(name.c_str(), m_rq->headers));
     pblock_nvinsert(name.c_str(), value.c_str() ,m_rq->headers);
   }
   virtual string getHeader(const string &name) {
     char *hdr = NULL;
-    if (request_header(const_cast<char*>(name.c_str()), &hdr, m_sn, m_rq) != REQ_PROCEED)
-      hdr = NULL;
+    if (request_header(const_cast<char*>(name.c_str()), &hdr, m_sn, m_rq) != REQ_PROCEED) {
+      string n;
+      const char* pch = name.c_str();
+      while (*pch)
+          n += tolower(*(pch++));
+      if (request_header(const_cast<char*>(n.c_str()), &hdr, m_sn, m_rq) != REQ_PROCEED)
+          return "";
+    }
     return string(hdr ? hdr : "");
   }
   virtual void setRemoteUser(const string &user) {
+    param_free(pblock_remove("remote-user",m_rq->headers));
     pblock_nvinsert("remote-user", user.c_str(), m_rq->headers);
     pblock_nvinsert("auth-user", user.c_str(), m_rq->vars);
   }
   virtual string getRemoteUser(void) {
-    return getHeader("remote-user");
+    const char* ru = pblock_findval("auth-user", m_rq->vars);
+    return ru ? ru : "";
   }
 
   virtual void* sendPage(
@@ -282,6 +369,7 @@ public:
     pblock_nvinsert("expires", "01-Jan-1997 12:00:00 GMT", m_rq->srvhdrs);
     pblock_nvinsert("cache-control", "private,no-store,no-cache", m_rq->srvhdrs);
     pblock_nvinsert("location", url.c_str(), m_rq->srvhdrs);
+    pblock_nvinsert("connection","close",m_rq->srvhdrs);
     protocol_status(m_sn, m_rq, PROTOCOL_REDIRECT, NULL);
     protocol_start_response(m_sn, m_rq);
     return (void*)REQ_ABORTED;
@@ -292,6 +380,8 @@ public:
   pblock* m_pb;
   Session* m_sn;
   Request* m_rq;
+  set<string> m_allhttp;
+  bool m_firsttime;
 };
 
 /********************************************************************************/
@@ -307,41 +397,41 @@ int WriteClientError(Session* sn, Request* rq, char* func, char* msg)
 #define FUNC "shibboleth"
 extern "C" NSAPI_PUBLIC int nsapi_shib(pblock* pb, Session* sn, Request* rq)
 {
-  ostringstream threadid;
-  threadid << "[" << getpid() << "] nsapi_shib" << '\0';
-  saml::NDC ndc(threadid.str().c_str());
-
-  try {
-    ShibTargetNSAPI stn(pb, sn, rq);
-
-    // Check user authentication
-    pair<bool,void*> res = stn.doCheckAuthN();
-    if (res.first) return (int)res.second;
-
-    // user authN was okay -- export the assertions now
-    param_free(pblock_remove("auth-user",rq->vars));
-    // This seems to be required in order to eventually set
-    // the auth-user var.
-    pblock_nvinsert("auth-type","shibboleth",rq->vars);
-    res = stn.doExportAssertions();
-    if (res.first) return (int)res.second;
-
-    // Check the Authorization
-    res = stn.doCheckAuthZ();
-    if (res.first) return (int)res.second;
-
-    // this user is ok.
-    return REQ_PROCEED;
-  }
-  catch (SAMLException& e) {
-    log_error(LOG_FAILURE,FUNC,sn,rq,const_cast<char*>(e.what()));
-    return WriteClientError(sn, rq, FUNC, "Shibboleth filter threw an exception, see web server log for error.");
-  }
-#ifndef _DEBUG
-  catch (...) {
-    return WriteClientError(sn, rq, FUNC, "Shibboleth filter threw an uncaught exception.");
-  }
-#endif
+    ostringstream threadid;
+    threadid << "[" << getpid() << "] nsapi_shib" << '\0';
+    saml::NDC ndc(threadid.str().c_str());
+
+    try {
+        ShibTargetNSAPI stn(pb, sn, rq);
+
+        // Check user authentication
+        pair<bool,void*> res = stn.doCheckAuthN();
+        if (res.first) return (int)res.second;
+
+        // user authN was okay -- export the assertions now
+        param_free(pblock_remove("auth-user",rq->vars));
+        // This seems to be required in order to eventually set
+        // the auth-user var.
+        pblock_nvinsert("auth-type","shibboleth",rq->vars);
+        res = stn.doExportAssertions();
+        if (res.first) return (int)res.second;
+
+        // Check the Authorization
+        res = stn.doCheckAuthZ();
+        if (res.first) return (int)res.second;
+
+        // this user is ok.
+        return REQ_PROCEED;
+    }
+    catch (exception& e) {
+        log_error(LOG_FAILURE,FUNC,sn,rq,const_cast<char*>(e.what()));
+        return WriteClientError(sn, rq, FUNC, "Shibboleth filter threw an exception, see web server log for error.");
+    }
+    catch (...) {
+        if (g_catchAll)
+            return WriteClientError(sn, rq, FUNC, "Shibboleth filter threw an uncaught exception.");
+        throw;
+    }
 }
 
 
@@ -349,27 +439,27 @@ extern "C" NSAPI_PUBLIC int nsapi_shib(pblock* pb, Session* sn, Request* rq)
 #define FUNC "shib_handler"
 extern "C" NSAPI_PUBLIC int shib_handler(pblock* pb, Session* sn, Request* rq)
 {
-  ostringstream threadid;
-  threadid << "[" << getpid() << "] shib_handler" << '\0';
-  saml::NDC ndc(threadid.str().c_str());
+    ostringstream threadid;
+    threadid << "[" << getpid() << "] shib_handler" << '\0';
+    saml::NDC ndc(threadid.str().c_str());
 
-  try {
-    ShibTargetNSAPI stn(pb, sn, rq);
+    try {
+        ShibTargetNSAPI stn(pb, sn, rq);
 
-    pair<bool,void*> res = stn.doHandler();
-    if (res.first) return (int)res.second;
+        pair<bool,void*> res = stn.doHandler();
+        if (res.first) return (int)res.second;
 
-    return WriteClientError(sn, rq, FUNC, "Shibboleth handler did not do anything.");
-  }
-  catch (SAMLException& e) {
-    log_error(LOG_FAILURE,FUNC,sn,rq,const_cast<char*>(e.what()));
-    return WriteClientError(sn, rq, FUNC, "Shibboleth handler threw an exception, see web server log for error.");
-  }
-#ifndef _DEBUG
-  catch (...) {
-    return WriteClientError(sn, rq, FUNC, "Shibboleth handler threw an unknown exception.");
-  }
-#endif
+        return WriteClientError(sn, rq, FUNC, "Shibboleth handler did not do anything.");
+    }
+    catch (exception& e) {
+        log_error(LOG_FAILURE,FUNC,sn,rq,const_cast<char*>(e.what()));
+        return WriteClientError(sn, rq, FUNC, "Shibboleth handler threw an exception, see web server log for error.");
+    }
+    catch (...) {
+        if (g_catchAll)
+            return WriteClientError(sn, rq, FUNC, "Shibboleth handler threw an unknown exception.");
+        throw;
+    }
 }
 
 
@@ -381,7 +471,7 @@ public:
     void lock() { m_mapper->lock(); }
     void unlock() { m_stKey->setData(NULL); m_propsKey->setData(NULL); m_mapper->unlock(); }
     Settings getSettings(ShibTarget* st) const;
-    
+
     pair<bool,bool> getBool(const char* name, const char* ns=NULL) const;
     pair<bool,const char*> getString(const char* name, const char* ns=NULL) const;
     pair<bool,const XMLCh*> getXMLString(const char* name, const char* ns=NULL) const;
@@ -459,13 +549,27 @@ pair<bool,const XMLCh*> SunRequestMapper::getXMLString(const char* name, const c
 
 pair<bool,unsigned int> SunRequestMapper::getUnsignedInt(const char* name, const char* ns) const
 {
+    ShibTargetNSAPI* stn=reinterpret_cast<ShibTargetNSAPI*>(m_stKey->getData());
     const IPropertySet* s=reinterpret_cast<const IPropertySet*>(m_propsKey->getData());
+    if (stn && !ns && name) {
+        // Override int properties.
+        const char* param=pblock_findval(name,stn->m_pb);
+        if (param)
+            return pair<bool,unsigned int>(true,strtol(param,NULL,10));
+    }
     return s ? s->getUnsignedInt(name,ns) : pair<bool,unsigned int>(false,0);
 }
 
 pair<bool,int> SunRequestMapper::getInt(const char* name, const char* ns) const
 {
+    ShibTargetNSAPI* stn=reinterpret_cast<ShibTargetNSAPI*>(m_stKey->getData());
     const IPropertySet* s=reinterpret_cast<const IPropertySet*>(m_propsKey->getData());
+    if (stn && !ns && name) {
+        // Override int properties.
+        const char* param=pblock_findval(name,stn->m_pb);
+        if (param)
+            return pair<bool,int>(true,atoi(param));
+    }
     return s ? s->getInt(name,ns) : pair<bool,int>(false,0);
 }