perl -i -npe "s/[ \t]+$//g" `find src -name "*.[ch]" -print`
[freeradius.git] / src / modules / rlm_sqlcounter / rlm_sqlcounter.c
1 /*
2  * rlm_sqlcounter.c
3  *
4  * Version:  $Id$
5  *
6  *   This program is free software; you can redistribute it and/or modify
7  *   it under the terms of the GNU General Public License as published by
8  *   the Free Software Foundation; either version 2 of the License, or
9  *   (at your option) any later version.
10  *
11  *   This program is distributed in the hope that it will be useful,
12  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *   GNU General Public License for more details.
15  *
16  *   You should have received a copy of the GNU General Public License
17  *   along with this program; if not, write to the Free Software
18  *   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19  *
20  * Copyright 2001  The FreeRADIUS server project
21  * Copyright 2001  Alan DeKok <aland@ox.org>
22  */
23
24 /* This module is based directly on the rlm_counter module */
25
26
27 #include "config.h"
28 #include "autoconf.h"
29 #include "libradius.h"
30
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34 #include <ctype.h>
35
36 #include "radiusd.h"
37 #include "modules.h"
38 #include "conffile.h"
39
40 #define MAX_QUERY_LEN 1024
41
42 #include <time.h>
43
44
45 /*      Note: When your counter spans more than 1 period (ie 3 months or 2 weeks), this module
46  *      probably does NOT do what you want!  It calculates the range of dates to count across
47  *      by first calculating the End of the Current period and then subtracting the number of
48  *      periods you specify from that to determine the beginning of the range.
49  *
50  *      For example, if you specify a 3 month counter and today is June 15th, the end of the current
51  *      period is June 30. Subtracting 3 months from that gives April 1st.  So, the counter will
52  *      sum radacct entries from April 1st to June 30. Then, next month, it will sum entries
53  *      from May 1st to July 31st.
54  *
55  *      To fix this behavior, we need to add some way of storing the Next Reset Time
56  */
57
58
59 static const char rcsid[] = "$Id$";
60
61 /*
62  *      Define a structure for our module configuration.
63  *
64  *      These variables do not need to be in a structure, but it's
65  *      a lot cleaner to do so, and a pointer to the structure can
66  *      be used as the instance handle.
67  */
68 typedef struct rlm_sqlcounter_t {
69         char *counter_name;     /* Daily-Session-Time */
70         char *check_name;       /* Max-Daily-Session */
71         char *key_name;         /* User-Name */
72         char *sqlmod_inst;      /* instance of SQL module to use, usually just 'sql' */
73         char *query;            /* SQL query to retrieve current session time */
74         char *reset;            /* daily, weekly, monthly, never or user defined */
75         time_t reset_time;
76         time_t last_reset;
77         int  key_attr;          /* attribute number for key field */
78         int  dict_attr;         /* attribute number for the counter. */
79 } rlm_sqlcounter_t;
80
81 /*
82  *      A mapping of configuration file names to internal variables.
83  *
84  *      Note that the string is dynamically allocated, so it MUST
85  *      be freed.  When the configuration file parse re-reads the string,
86  *      it free's the old one, and strdup's the new one, placing the pointer
87  *      to the strdup'd string into 'config.string'.  This gets around
88  *      buffer over-flows.
89  */
90 static CONF_PARSER module_config[] = {
91   { "counter-name", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,counter_name), NULL,  NULL },
92   { "check-name", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,check_name), NULL, NULL },
93   { "key", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,key_name), NULL, NULL },
94   { "sqlmod-inst", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,sqlmod_inst), NULL, NULL },
95   { "query", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,query), NULL, NULL },
96   { "reset", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,reset), NULL,  NULL },
97   { NULL, -1, 0, NULL, NULL }
98 };
99
100
101 static int find_next_reset(rlm_sqlcounter_t *data, time_t timeval)
102 {
103         int ret=0;
104         unsigned int num=1;
105         char last = 0;
106         struct tm *tm, s_tm;
107         char sCurrentTime[40], sNextTime[40];
108
109         tm = localtime_r(&timeval, &s_tm);
110         strftime(sCurrentTime, sizeof(sCurrentTime),"%Y-%m-%d %H:%M:%S",tm);
111         tm->tm_sec = tm->tm_min = 0;
112
113         if (data->reset == NULL)
114                 return -1;
115         if (isdigit((int) data->reset[0])){
116                 unsigned int len=0;
117
118                 len = strlen(data->reset);
119                 if (len == 0)
120                         return -1;
121                 last = data->reset[len - 1];
122                 if (!isalpha((int) last))
123                         last = 'd';
124 /*              num = atoi(data->reset); */
125                 DEBUG("rlm_sqlcounter: num=%d, last=%c",num,last);
126         }
127         if (strcmp(data->reset, "hourly") == 0 || last == 'h') {
128                 /*
129                  *  Round up to the next nearest hour.
130                  */
131                 tm->tm_hour += num;
132                 data->reset_time = mktime(tm);
133         } else if (strcmp(data->reset, "daily") == 0 || last == 'd') {
134                 /*
135                  *  Round up to the next nearest day.
136                  */
137                 tm->tm_hour = 0;
138                 tm->tm_mday += num;
139                 data->reset_time = mktime(tm);
140         } else if (strcmp(data->reset, "weekly") == 0 || last == 'w') {
141                 /*
142                  *  Round up to the next nearest week.
143                  */
144                 tm->tm_hour = 0;
145                 tm->tm_mday += (7 - tm->tm_wday) +(7*(num-1));
146                 data->reset_time = mktime(tm);
147         } else if (strcmp(data->reset, "monthly") == 0 || last == 'm') {
148                 tm->tm_hour = 0;
149                 tm->tm_mday = 1;
150                 tm->tm_mon += num;
151                 data->reset_time = mktime(tm);
152         } else if (strcmp(data->reset, "never") == 0) {
153                 data->reset_time = 0;
154         } else {
155                 radlog(L_ERR, "rlm_sqlcounter: Unknown reset timer \"%s\"",
156                         data->reset);
157                 return -1;
158         }
159         strftime(sNextTime, sizeof(sNextTime),"%Y-%m-%d %H:%M:%S",tm);
160         DEBUG2("rlm_sqlcounter: Current Time: %d [%s], Next reset %d [%s]",
161                 (int)timeval,sCurrentTime,(int)data->reset_time, sNextTime);
162
163         return ret;
164 }
165
166
167 /*  I don't believe that this routine handles Daylight Saving Time adjustments
168     properly.  Any suggestions?
169 */
170
171 static int find_prev_reset(rlm_sqlcounter_t *data, time_t timeval)
172 {
173         int ret=0;
174         unsigned int num=1;
175         char last = 0;
176         struct tm *tm, s_tm;
177         char sCurrentTime[40], sPrevTime[40];
178
179         tm = localtime_r(&timeval, &s_tm);
180         strftime(sCurrentTime, sizeof(sCurrentTime),"%Y-%m-%d %H:%M:%S",tm);
181         tm->tm_sec = tm->tm_min = 0;
182
183         if (data->reset == NULL)
184                 return -1;
185         if (isdigit((int) data->reset[0])){
186                 unsigned int len=0;
187
188                 len = strlen(data->reset);
189                 if (len == 0)
190                         return -1;
191                 last = data->reset[len - 1];
192                 if (!isalpha((int) last))
193                         last = 'd';
194                 num = atoi(data->reset);
195                 DEBUG("rlm_sqlcounter: num=%d, last=%c",num,last);
196         }
197         if (strcmp(data->reset, "hourly") == 0 || last == 'h') {
198                 /*
199                  *  Round down to the prev nearest hour.
200                  */
201                 tm->tm_hour -= num - 1;
202                 data->last_reset = mktime(tm);
203         } else if (strcmp(data->reset, "daily") == 0 || last == 'd') {
204                 /*
205                  *  Round down to the prev nearest day.
206                  */
207                 tm->tm_hour = 0;
208                 tm->tm_mday -= num - 1;
209                 data->last_reset = mktime(tm);
210         } else if (strcmp(data->reset, "weekly") == 0 || last == 'w') {
211                 /*
212                  *  Round down to the prev nearest week.
213                  */
214                 tm->tm_hour = 0;
215                 tm->tm_mday -= (7 - tm->tm_wday) +(7*(num-1));
216                 data->last_reset = mktime(tm);
217         } else if (strcmp(data->reset, "monthly") == 0 || last == 'm') {
218                 tm->tm_hour = 0;
219                 tm->tm_mday = 1;
220                 tm->tm_mon -= num - 1;
221                 data->last_reset = mktime(tm);
222         } else if (strcmp(data->reset, "never") == 0) {
223                 data->reset_time = 0;
224         } else {
225                 radlog(L_ERR, "rlm_sqlcounter: Unknown reset timer \"%s\"",
226                         data->reset);
227                 return -1;
228         }
229         strftime(sPrevTime, sizeof(sPrevTime),"%Y-%m-%d %H:%M:%S",tm);
230         DEBUG2("rlm_sqlcounter: Current Time: %d [%s], Prev reset %d [%s]",
231                 (int)timeval,sCurrentTime,(int)data->last_reset, sPrevTime);
232
233         return ret;
234 }
235
236
237 /*
238  *      Replace %<whatever> in a string.
239  *
240  *      %b      last_reset
241  *      %e      reset_time
242  *      %k      key_name
243  *      %S      sqlmod_inst
244  *
245  */
246
247 static int sqlcounter_expand(char *out, int outlen, const char *fmt, void *instance)
248 {
249         rlm_sqlcounter_t *data = (rlm_sqlcounter_t *) instance;
250         int c,freespace;
251         const char *p;
252         char *q;
253         char tmpdt[40]; /* For temporary storing of dates */
254         int openbraces=0;
255
256         q = out;
257         for (p = fmt; *p ; p++) {
258         /* Calculate freespace in output */
259         freespace = outlen - (q - out);
260                 if (freespace <= 1)
261                         break;
262                 c = *p;
263                 if ((c != '%') && (c != '$') && (c != '\\')) {
264                         /*
265                          * We check if we're inside an open brace.  If we are
266                          * then we assume this brace is NOT literal, but is
267                          * a closing brace and apply it
268                          */
269                         if((c == '}') && openbraces) {
270                                 openbraces--;
271                                 continue;
272                         }
273                         *q++ = *p;
274                         continue;
275                 }
276                 if (*++p == '\0') break;
277                 if (c == '\\') switch(*p) {
278                         case '\\':
279                                 *q++ = *p;
280                                 break;
281                         case 't':
282                                 *q++ = '\t';
283                                 break;
284                         case 'n':
285                                 *q++ = '\n';
286                                 break;
287                         default:
288                                 *q++ = c;
289                                 *q++ = *p;
290                                 break;
291
292                 } else if (c == '%') switch(*p) {
293
294                         case '%':
295                                 *q++ = *p;
296                         case 'b': /* last_reset */
297                                 sprintf(tmpdt, "%lu", data->last_reset);
298                                 strNcpy(q, tmpdt, freespace);
299                                 q += strlen(q);
300                                 break;
301                         case 'e': /* reset_time */
302                                 sprintf(tmpdt, "%lu", data->reset_time);
303                                 strNcpy(q, tmpdt, freespace);
304                                 q += strlen(q);
305                                 break;
306                         case 'k': /* Key Name */
307                                 strNcpy(q, data->key_name, freespace);
308                                 q += strlen(q);
309                                 break;
310                         case 'S': /* SQL module instance */
311                                 strNcpy(q, data->sqlmod_inst, freespace);
312                                 q += strlen(q);
313                                 break;
314                         default:
315                                 *q++ = '%';
316                                 *q++ = *p;
317                                 break;
318                 }
319         }
320         *q = '\0';
321
322         DEBUG2("sqlcounter_expand:  '%s'", out);
323
324         return strlen(out);
325 }
326
327
328 /*
329  *      See if the counter matches.
330  */
331 static int sqlcounter_cmp(void *instance, REQUEST *req, VALUE_PAIR *request, VALUE_PAIR *check,
332                 VALUE_PAIR *check_pairs, VALUE_PAIR **reply_pairs)
333 {
334         rlm_sqlcounter_t *data = (rlm_sqlcounter_t *) instance;
335         int counter;
336         char querystr[MAX_QUERY_LEN];
337         char responsestr[MAX_QUERY_LEN];
338
339         check_pairs = check_pairs; /* shut the compiler up */
340         reply_pairs = reply_pairs;
341
342         /* first, expand %k, %b and %e in query */
343         sqlcounter_expand(querystr, MAX_QUERY_LEN, data->query, instance);
344
345         /* second, xlat any request attribs in query */
346         radius_xlat(responsestr, MAX_QUERY_LEN, querystr, req, NULL);
347
348         /* third, wrap query with sql module call & expand */
349         sprintf(querystr, "%%{%%S:%s}", responsestr);
350         sqlcounter_expand(responsestr, MAX_QUERY_LEN, querystr, instance);
351
352         /* Finally, xlat resulting SQL query */
353         radius_xlat(querystr, MAX_QUERY_LEN, responsestr, req, NULL);
354
355         counter = atoi(querystr);
356
357         return counter - check->lvalue;
358 }
359
360
361 /*
362  *      Do any per-module initialization that is separate to each
363  *      configured instance of the module.  e.g. set up connections
364  *      to external databases, read configuration files, set up
365  *      dictionary entries, etc.
366  *
367  *      If configuration information is given in the config section
368  *      that must be referenced in later calls, store a handle to it
369  *      in *instance otherwise put a null pointer there.
370  */
371 static int sqlcounter_instantiate(CONF_SECTION *conf, void **instance)
372 {
373         rlm_sqlcounter_t *data;
374         DICT_ATTR *dattr;
375         ATTR_FLAGS flags;
376         time_t now;
377
378         /*
379          *      Set up a storage area for instance data
380          */
381         data = rad_malloc(sizeof(*data));
382         if (!data) {
383                 return -1;
384         }
385         memset(data, 0, sizeof(*data));
386
387         /*
388          *      If the configuration parameters can't be parsed, then
389          *      fail.
390          */
391         if (cf_section_parse(conf, data, module_config) < 0) {
392                 free(data);
393                 return -1;
394         }
395
396         /*
397          *      Discover the attribute number of the key.
398          */
399         if (data->key_name == NULL) {
400                 radlog(L_ERR, "rlm_sqlcounter: 'key' must be set.");
401                 exit(0);
402         }
403         dattr = dict_attrbyname(data->key_name);
404         if (dattr == NULL) {
405                 radlog(L_ERR, "rlm_sqlcounter: No such attribute %s",
406                                 data->key_name);
407                 return -1;
408         }
409         data->key_attr = dattr->attr;
410
411
412         /*
413          *  Create a new attribute for the counter.
414          */
415         if (data->counter_name == NULL) {
416                 radlog(L_ERR, "rlm_sqlcounter: 'counter-name' must be set.");
417                 exit(0);
418         }
419
420         memset(&flags, 0, sizeof(flags));
421         dict_addattr(data->counter_name, 0, PW_TYPE_INTEGER, -1, flags);
422         dattr = dict_attrbyname(data->counter_name);
423         if (dattr == NULL) {
424                 radlog(L_ERR, "rlm_sqlcounter: Failed to create counter attribute %s",
425                                 data->counter_name);
426                 return -1;
427         }
428         data->dict_attr = dattr->attr;
429         DEBUG2("rlm_sqlcounter: Counter attribute %s is number %d",
430                         data->counter_name, data->dict_attr);
431
432         /*
433          * Create a new attribute for the check item.
434          */
435         if (data->check_name == NULL) {
436                 radlog(L_ERR, "rlm_sqlcounter: 'check-name' must be set.");
437                 exit(0);
438         }
439         dict_addattr(data->check_name, 0, PW_TYPE_INTEGER, -1, flags);
440         dattr = dict_attrbyname(data->check_name);
441         if (dattr == NULL) {
442                 radlog(L_ERR, "rlm_sqlcounter: Failed to create check attribute %s",
443                                 data->counter_name);
444                 return -1;
445         }
446         DEBUG2("rlm_sqlcounter: Check attribute %s is number %d",
447                         data->check_name, dattr->attr);
448
449         /*
450          *  Discover the end of the current time period.
451          */
452         if (data->reset == NULL) {
453                 radlog(L_ERR, "rlm_sqlcounter: 'reset' must be set.");
454                 exit(0);
455         }
456         now = time(NULL);
457         data->reset_time = 0;
458
459         if (find_next_reset(data,now) == -1)
460                 return -1;
461
462         /*
463          *  Discover the beginning of the current time period.
464          */
465         data->last_reset = 0;
466
467         if (find_prev_reset(data,now) == -1)
468                 return -1;
469
470
471         /*
472          *      Register the counter comparison operation.
473          */
474         paircompare_register(data->dict_attr, 0, sqlcounter_cmp, data);
475
476         *instance = data;
477
478         return 0;
479 }
480
481 /*
482  *      Find the named user in this modules database.  Create the set
483  *      of attribute-value pairs to check and reply with for this user
484  *      from the database. The authentication code only needs to check
485  *      the password, the rest is done here.
486  */
487 static int sqlcounter_authorize(void *instance, REQUEST *request)
488 {
489         rlm_sqlcounter_t *data = (rlm_sqlcounter_t *) instance;
490         int ret=RLM_MODULE_NOOP;
491         int counter=0;
492         int res=0;
493         DICT_ATTR *dattr;
494         VALUE_PAIR *key_vp, *check_vp;
495         VALUE_PAIR *reply_item;
496         char msg[128];
497         char querystr[MAX_QUERY_LEN];
498         char responsestr[MAX_QUERY_LEN];
499
500         /* quiet the compiler */
501         instance = instance;
502         request = request;
503
504         /*
505          *      Before doing anything else, see if we have to reset
506          *      the counters.
507          */
508         if (data->reset_time && (data->reset_time <= request->timestamp)) {
509
510                 /*
511                  *      Re-set the next time and prev_time for this counters range
512                  */
513                 data->last_reset = data->reset_time;
514                 find_next_reset(data,request->timestamp);
515         }
516
517
518         /*
519          *      Look for the key.  User-Name is special.  It means
520          *      The REAL username, after stripping.
521          */
522         DEBUG2("rlm_sqlcounter: Entering module authorize code");
523         key_vp = (data->key_attr == PW_USER_NAME) ? request->username : pairfind(request->packet->vps, data->key_attr);
524         if (key_vp == NULL) {
525                 DEBUG2("rlm_sqlcounter: Could not find Key value pair");
526                 return ret;
527         }
528
529         /*
530          *      Look for the check item
531          */
532         if ((dattr = dict_attrbyname(data->check_name)) == NULL) {
533                 return ret;
534         }
535         /* DEBUG2("rlm_sqlcounter: Found Check item attribute %d", dattr->attr); */
536         if ((check_vp= pairfind(request->config_items, dattr->attr)) == NULL) {
537                 DEBUG2("rlm_sqlcounter: Could not find Check item value pair");
538                 return ret;
539         }
540
541         /* first, expand %k, %b and %e in query */
542         sqlcounter_expand(querystr, MAX_QUERY_LEN, data->query, instance);
543
544         /* second, xlat any request attribs in query */
545         radius_xlat(responsestr, MAX_QUERY_LEN, querystr, request, NULL);
546
547         /* third, wrap query with sql module & expand */
548         sprintf(querystr, "%%{%%S:%s}", responsestr);
549         sqlcounter_expand(responsestr, MAX_QUERY_LEN, querystr, instance);
550
551         /* Finally, xlat resulting SQL query */
552         radius_xlat(querystr, MAX_QUERY_LEN, responsestr, request, NULL);
553
554         counter = atoi(querystr);
555
556
557         /*
558          * Check if check item > counter
559          */
560         res=check_vp->lvalue - counter;
561         if (res > 0) {
562                 DEBUG2("rlm_sqlcounter: (Check item - counter) is greater than zero");
563                 /*
564                  *      We are assuming that simultaneous-use=1. But
565                  *      even if that does not happen then our user
566                  *      could login at max for 2*max-usage-time Is
567                  *      that acceptable?
568                  */
569
570                 /*
571                  *      User is allowed, but set Session-Timeout.
572                  *      Stolen from main/auth.c
573                  */
574
575                 /*
576                  *      If we are near a reset then add the next
577                  *      limit, so that the user will not need to
578                  *      login again
579                  */
580                 if (data->reset_time && (
581                         res >= (data->reset_time - request->timestamp))) {
582                         res += check_vp->lvalue;
583                 }
584
585                 if ((reply_item = pairfind(request->reply->vps, PW_SESSION_TIMEOUT)) != NULL) {
586                         if (reply_item->lvalue > res)
587                                 reply_item->lvalue = res;
588                 } else {
589                         if ((reply_item = paircreate(PW_SESSION_TIMEOUT, PW_TYPE_INTEGER)) == NULL) {
590                                 radlog(L_ERR|L_CONS, "no memory");
591                                 return RLM_MODULE_NOOP;
592                         }
593                         reply_item->lvalue = res;
594                         pairadd(&request->reply->vps, reply_item);
595                 }
596
597                 ret=RLM_MODULE_OK;
598
599                 DEBUG2("rlm_sqlcounter: Authorized user %s, check_item=%d, counter=%d",
600                                 key_vp->strvalue,check_vp->lvalue,counter);
601                 DEBUG2("rlm_sqlcounter: Sent Reply-Item for user %s, Type=Session-Timeout, value=%d",
602                                 key_vp->strvalue,reply_item->lvalue);
603         }
604         else{
605                 char module_fmsg[MAX_STRING_LEN];
606                 VALUE_PAIR *module_fmsg_vp;
607
608                 DEBUG2("rlm_sqlcounter: (Check item - counter) is less than zero");
609
610                 /*
611                  * User is denied access, send back a reply message
612                 */
613                 sprintf(msg, "Your maximum %s usage time has been reached", data->reset);
614                 reply_item=pairmake("Reply-Message", msg, T_OP_EQ);
615                 pairadd(&request->reply->vps, reply_item);
616
617                 snprintf(module_fmsg, sizeof(module_fmsg), "rlm_sqlcounter: Maximum %s usage time reached", data->reset);
618                 module_fmsg_vp = pairmake("Module-Failure-Message", module_fmsg, T_OP_EQ);
619                 pairadd(&request->packet->vps, module_fmsg_vp);
620
621                 ret=RLM_MODULE_REJECT;
622
623                 DEBUG2("rlm_sqlcounter: Rejected user %s, check_item=%d, counter=%d",
624                                 key_vp->strvalue,check_vp->lvalue,counter);
625         }
626
627         return ret;
628 }
629
630 static int sqlcounter_detach(void *instance)
631 {
632         rlm_sqlcounter_t *data = (rlm_sqlcounter_t *) instance;
633
634         paircompare_unregister(data->dict_attr, sqlcounter_cmp);
635         free(data->reset);
636         free(data->query);
637         free(data->check_name);
638         free(data->sqlmod_inst);
639         free(data->counter_name);
640
641         free(instance);
642         return 0;
643 }
644
645 /*
646  *      The module name should be the only globally exported symbol.
647  *      That is, everything else should be 'static'.
648  *
649  *      If the module needs to temporarily modify it's instantiation
650  *      data, the type should be changed to RLM_TYPE_THREAD_UNSAFE.
651  *      The server will then take care of ensuring that the module
652  *      is single-threaded.
653  */
654 module_t rlm_sqlcounter = {
655         "SQL Counter",
656         RLM_TYPE_THREAD_SAFE,           /* type */
657         NULL,                           /* initialization */
658         sqlcounter_instantiate,         /* instantiation */
659         {
660                 NULL,                   /* authentication */
661                 sqlcounter_authorize,   /* authorization */
662                 NULL,                   /* preaccounting */
663                 NULL,                   /* accounting */
664                 NULL,                   /* checksimul */
665                 NULL,                   /* pre-proxy */
666                 NULL,                   /* post-proxy */
667                 NULL                    /* post-auth */
668         },
669         sqlcounter_detach,              /* detach */
670         NULL,                           /* destroy */
671 };
672