vendor attribute removal
authorvenaas <venaas>
Wed, 5 Mar 2008 15:00:00 +0000 (15:00 +0000)
committervenaas <venaas@e88ac4ed-0b26-0410-9574-a7f39faa03bf>
Wed, 5 Mar 2008 15:00:00 +0000 (15:00 +0000)
git-svn-id: https://svn.testnett.uninett.no/radsecproxy/trunk@222 e88ac4ed-0b26-0410-9574-a7f39faa03bf

radsecproxy.c
radsecproxy.h

index d7bf838..8e319f3 100644 (file)
@@ -1377,11 +1377,66 @@ int msmppe(unsigned char *attrs, int length, uint8_t type, char *attrtxt, struct
     return 1;
 }
 
-void removeattrs(uint8_t *buf, uint8_t *rmattrs) {
+int findvendorsubattr(uint32_t *attrs, uint32_t vendor, uint8_t subattr) {
+    if (!attrs)
+       return 0;
+    
+    for (; attrs[0]; attrs += 2)
+       if (attrs[0] == vendor && attrs[1] == subattr)
+           return 1;
+    return 0;
+}
+
+int dovendorrewrite(uint8_t *attrs, uint16_t length, uint32_t *removevendorattrs) {
+    uint8_t alen, sublen, rmlen = 0;
+    uint32_t vendor = *(uint32_t *)ATTRVAL(attrs);
+    uint8_t *subattrs;
+    
+    if (!removevendorattrs)
+       return 0;
+
+    while (*removevendorattrs && *removevendorattrs != vendor)
+       removevendorattrs += 2;
+    if (!*removevendorattrs)
+       return 0;
+    
+    alen = ATTRLEN(attrs);
+
+    if (findvendorsubattr(removevendorattrs, vendor, -1)) {
+       /* remove entire vendor attribute */
+       memmove(attrs, attrs + alen, length - alen);
+       return alen;
+    }
+
+    sublen = alen - 4;
+    subattrs = ATTRVAL(attrs) + 4;
+    
+    if (!attrvalidate(subattrs, sublen)) {
+       debug(DBG_WARN, "dovendorrewrite: vendor attribute validation failed, no rewrite");
+       return 0;
+    }
+
+    length -= 6;
+    while (sublen > 1) {
+       alen = ATTRLEN(subattrs);
+       sublen -= alen;
+       length -= alen;
+       if (findvendorsubattr(removevendorattrs, vendor, ATTRTYPE(subattrs))) {
+           memmove(subattrs, subattrs + alen, length);
+           rmlen += alen;
+       } else
+           subattrs += alen;
+    }
+
+    ATTRLEN(attrs) -= rmlen;
+    return rmlen;
+}
+
+void dorewrite(uint8_t *buf, struct rewrite *rewrite) {
     uint8_t *attrs, alen;
     uint16_t len, rmlen = 0;
-
-    if (!rmattrs)
+    
+    if (!rewrite || (!rewrite->removeattrs && !rewrite->removevendorattrs))
        return;
 
     len = RADLEN(buf) - 20;
@@ -1389,10 +1444,12 @@ void removeattrs(uint8_t *buf, uint8_t *rmattrs) {
     while (len > 1) {
        alen = ATTRLEN(attrs);
        len -= alen;
-       if (strchr((char *)rmattrs, ATTRTYPE(attrs))) {
+       if (rewrite->removeattrs && strchr((char *)rewrite->removeattrs, ATTRTYPE(attrs))) {
            memmove(attrs, attrs + alen, len);
            rmlen += alen;
-       } else
+       } else if (ATTRTYPE(attrs) == RAD_Attr_Vendor_Specific && rewrite->removevendorattrs)
+           rmlen += dovendorrewrite(attrs, len, rewrite->removevendorattrs);
+       else
            attrs += alen;
     }
     if (rmlen)
@@ -1632,11 +1689,11 @@ void radsrv(struct request *rq) {
     }
 
     /* code == RAD_Access_Request */
-    if (rq->from->conf->removeattrs) {
-       removeattrs(rq->buf, rq->from->conf->removeattrs);
+    if (rq->from->conf->rewrite) {
+       dorewrite(rq->buf, rq->from->conf->rewrite);
        len = RADLEN(rq->buf) - 20;
     }
-
+    
     attr = attrget(attrs, len, RAD_Attr_User_Name);
     if (!attr) {
        debug(DBG_WARN, "radsrv: ignoring request, no username attribute");
@@ -1804,11 +1861,11 @@ int replyh(struct server *server, unsigned char *buf) {
        return 0;
     }
 
-    if (server->conf->removeattrs) {
-       removeattrs(buf, server->conf->removeattrs);
+    if (server->conf->rewrite) {
+       dorewrite(buf, server->conf->rewrite);
        len = RADLEN(buf) - 20;
     }
-
+    
     /* MS MPPE */
     for (attr = attrs; (attr = attrget(attr, len - (attr - attrs), RAD_Attr_Vendor_Specific)); attr += ATTRLEN(attr)) {
        if (ATTRVALLEN(attr) <= 4)
@@ -2562,32 +2619,46 @@ int addrewriteattr(struct clsrvconf *conf, char *rewriteattr) {
 /* should accept both names and numeric values, only numeric right now */
 uint8_t attrname2val(char *attrname) {
     int val = 0;
-
+    
     val = atoi(attrname);
     return val > 0 && val < 256 ? val : 0;
 }
 
+/* should accept both names and numeric values, only numeric right now */
+int vattrname2val(char *attrname, uint32_t *vendor, uint32_t *type) {
+    char *s;
+    
+    *vendor = atoi(attrname);
+    s = strchr(attrname, ':');
+    if (!s) {
+       *type = -1;
+       return 1;
+    }
+    *type = atoi(s + 1);
+    return *type >= 0 && *type < 256;
+}
+
 void rewritefree() {
     struct list_node *entry;
-    struct rewrite *r;
-
+    struct rewriteconf *r;
+    
     for (entry = list_first(rewriteconfs); entry; entry = list_next(entry)) {
-       r = (struct rewrite *)entry->data;
+       r = (struct rewriteconf *)entry->data;
        if (r->name)
            free(r->name);
        if (!r->count)
-           free(r->removeattrs);
+           free(r->rewrite);
     }
     list_destroy(rewriteconfs);
     rewriteconfs = NULL;
 }
 
-uint8_t *getrewrite(char *alt1, char *alt2) {
+struct rewrite *getrewrite(char *alt1, char *alt2) {
     struct list_node *entry;
-    struct rewrite *r, *r1 = NULL, *r2 = NULL;
-
+    struct rewriteconf *r, *r1 = NULL, *r2 = NULL;
+    
     for (entry = list_first(rewriteconfs); entry; entry = list_next(entry)) {
-       r = (struct rewrite *)entry->data;
+       r = (struct rewriteconf *)entry->data;
        if (!strcasecmp(r->name, alt1)) {
            r1 = r;
            break;
@@ -2600,35 +2671,60 @@ uint8_t *getrewrite(char *alt1, char *alt2) {
     if (!r)
        return NULL;
     r->count++;
-    return r->removeattrs;
+    return r->rewrite;
 }
 
-void addrewrite(char *value, char **attrs) {
-    struct rewrite *new;
+void addrewrite(char *value, char **attrs, char **vattrs) {
+    struct rewriteconf *new;
+    struct rewrite *rewrite = NULL;
     int i, n;
-    uint8_t *a;
+    uint8_t *a = NULL;
+    uint32_t *p, *va = NULL;
 
-    n = 0;
-    if (attrs)
+    if (attrs) {
+       n = 0;
        for (; attrs[n]; n++);
-    a = malloc((n + 1) * sizeof(uint8_t));
-    if (!a)
-       debugx(1, DBG_ERR, "malloc failed");
-
-    for (i = 0; i < n; i++)
-       if (!(a[i] = attrname2val(attrs[i])))
-           debugx(1, DBG_ERR, "addrewrite: invalid attribute %s", attrs[i]);
-    a[i] = 0;
-
-    new = malloc(sizeof(struct rewrite));
+       a = malloc((n + 1) * sizeof(uint8_t));
+       if (!a)
+           debugx(1, DBG_ERR, "malloc failed");
+    
+       for (i = 0; i < n; i++)
+           if (!(a[i] = attrname2val(attrs[i])))
+               debugx(1, DBG_ERR, "addrewrite: invalid attribute %s", attrs[i]);
+       a[i] = 0;
+    }
+    
+    if (vattrs) {
+       n = 0;
+       for (; vattrs[n]; n++);
+       va = malloc((2 * n + 1) * sizeof(uint32_t));
+       if (!va)
+           debugx(1, DBG_ERR, "malloc failed");
+    
+       for (p = va, i = 0; i < n; i++, p += 2)
+           if (!vattrname2val(vattrs[i], p, p + 1))
+               debugx(1, DBG_ERR, "addrewrite: invalid vendor attribute %s", vattrs[i]);
+       *p = 0;
+    }
+    
+    if (a || va) {
+       rewrite = malloc(sizeof(struct rewrite));
+       if (!rewrite)
+           debugx(1, DBG_ERR, "malloc failed");
+       rewrite->removeattrs = a;
+       rewrite->removevendorattrs = va;
+    }
+    
+    new = malloc(sizeof(struct rewriteconf));
     if (!new || !list_push(rewriteconfs, new))
        debugx(1, DBG_ERR, "malloc failed");
 
-    memset(new, 0, sizeof(struct rewrite));
+    memset(new, 0, sizeof(struct rewriteconf));
     new->name = stringcopy(value, 0);
     if (!new->name)
        debugx(1, DBG_ERR, "malloc failed");
-    new->removeattrs = a;
+       
+    new->rewrite = rewrite;
     debug(DBG_DBG, "addrewrite: added rewrite block %s", value);
 }
 
@@ -2676,8 +2772,8 @@ void confclient_cb(struct gconffile **cf, char *block, char *opt, char *val) {
        free(tls);
     if (matchcertattr)
        free(matchcertattr);
-
-    conf->removeattrs = rewrite ? getrewrite(rewrite, NULL) : getrewrite("defaultclient", "default");
+    
+    conf->rewrite = rewrite ? getrewrite(rewrite, NULL) : getrewrite("defaultclient", "default");
     
     if (rewriteattr) {
        if (!addrewriteattr(conf, rewriteattr))
@@ -2745,7 +2841,7 @@ void confserver_cb(struct gconffile **cf, char *block, char *opt, char *val) {
     if (matchcertattr)
        free(matchcertattr);
     
-    conf->removeattrs = rewrite ? getrewrite(rewrite, NULL) : getrewrite("defaultserver", "default");
+    conf->rewrite = rewrite ? getrewrite(rewrite, NULL) : getrewrite("defaultserver", "default");
     
     if (!resolvepeer(conf, 0))
        debugx(1, DBG_ERR, "failed to resolve host %s port %s, exiting", conf->host ? conf->host : "(null)", conf->port ? conf->port : "(null)");
@@ -2803,16 +2899,18 @@ void conftls_cb(struct gconffile **cf, char *block, char *opt, char *val) {
 }
 
 void confrewrite_cb(struct gconffile **cf, char *block, char *opt, char *val) {
-    char **attrs = NULL;
-
+    char **attrs = NULL, **vattrs = NULL;
+    
     debug(DBG_DBG, "confrewrite_cb called for %s", block);
-
+    
     getgenericconfig(cf, block,
                     "removeAttribute", CONF_MSTR, &attrs,
+                    "removeVendorAttribute", CONF_MSTR, &vattrs,
                     NULL
                     );
-    addrewrite(val, attrs);
+    addrewrite(val, attrs, vattrs);
     free(attrs);
+    free(vattrs);
 }
 
 void getmainconfig(const char *configfile) {
@@ -2836,12 +2934,12 @@ void getmainconfig(const char *configfile) {
  
     tlsconfs = list_create();
     if (!tlsconfs)
-       debugx(1, DBG_ERR, "malloc failed");    
-
-    rewriteconfs = list_create();
-    if (!rewriteconfs)
        debugx(1, DBG_ERR, "malloc failed");
     
+    rewriteconfs = list_create();
+    if (!rewriteconfs)
+       debugx(1, DBG_ERR, "malloc failed");    
     getgenericconfig(&cfs, NULL,
                     "ListenUDP", CONF_STR, &options.listenudp,
                     "ListenTCP", CONF_STR, &options.listentcp,
@@ -2860,7 +2958,7 @@ void getmainconfig(const char *configfile) {
     popgconffile(&cfs);
     tlsfree();
     rewritefree();
-
+    
     if (loglevel) {
        if (strlen(loglevel) != 1 || *loglevel < '1' || *loglevel > '4')
            debugx(1, DBG_ERR, "error in %s, value of option LogLevel is %s, must be 1, 2, 3 or 4", configfile, loglevel);
index 9400a6f..666514b 100644 (file)
@@ -84,7 +84,7 @@ struct clsrvconf {
     char *rewriteattrreplacement;
     uint8_t statusserver;
     SSL_CTX *ssl_ctx;
-    uint8_t *removeattrs;
+    struct rewrite *rewrite;
     struct addrinfo *addrinfo;
     uint8_t prefixlen;
     struct list *clients;
@@ -128,8 +128,13 @@ struct tls {
 };
 
 struct rewrite {
-    char *name;
     uint8_t *removeattrs;
+    uint32_t *removevendorattrs;
+};
+
+struct rewriteconf {
+    char *name;
+    struct rewrite *rewrite;
     int count;
 };