Remove libradius.h from the top of the standard header list.
[freeradius.git] / src / modules / rlm_sqlcounter / rlm_sqlcounter.c
1 /*
2  * rlm_sqlcounter.c
3  *
4  * Version:  $Id$
5  *
6  *   This program is free software; you can redistribute it and/or modify
7  *   it under the terms of the GNU General Public License as published by
8  *   the Free Software Foundation; either version 2 of the License, or
9  *   (at your option) any later version.
10  *
11  *   This program is distributed in the hope that it will be useful,
12  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *   GNU General Public License for more details.
15  *
16  *   You should have received a copy of the GNU General Public License
17  *   along with this program; if not, write to the Free Software
18  *   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19  *
20  * Copyright 2001  The FreeRADIUS server project
21  * Copyright 2001  Alan DeKok <aland@ox.org>
22  */
23
24 /* This module is based directly on the rlm_counter module */
25
26
27 #include "autoconf.h"
28
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 #include <ctype.h>
33
34 #include "radiusd.h"
35 #include "modules.h"
36 #include "conffile.h"
37
38 #define MAX_QUERY_LEN 1024
39
40 #include <time.h>
41
42
43 /*      Note: When your counter spans more than 1 period (ie 3 months or 2 weeks), this module
44  *      probably does NOT do what you want!  It calculates the range of dates to count across
45  *      by first calculating the End of the Current period and then subtracting the number of
46  *      periods you specify from that to determine the beginning of the range.
47  *
48  *      For example, if you specify a 3 month counter and today is June 15th, the end of the current
49  *      period is June 30. Subtracting 3 months from that gives April 1st.  So, the counter will
50  *      sum radacct entries from April 1st to June 30. Then, next month, it will sum entries
51  *      from May 1st to July 31st.
52  *
53  *      To fix this behavior, we need to add some way of storing the Next Reset Time
54  */
55
56
57 static const char rcsid[] = "$Id$";
58
59 /*
60  *      Define a structure for our module configuration.
61  *
62  *      These variables do not need to be in a structure, but it's
63  *      a lot cleaner to do so, and a pointer to the structure can
64  *      be used as the instance handle.
65  */
66 typedef struct rlm_sqlcounter_t {
67         char *counter_name;     /* Daily-Session-Time */
68         char *check_name;       /* Max-Daily-Session */
69         char *key_name;         /* User-Name */
70         char *sqlmod_inst;      /* instance of SQL module to use, usually just 'sql' */
71         char *query;            /* SQL query to retrieve current session time */
72         char *reset;            /* daily, weekly, monthly, never or user defined */
73         time_t reset_time;
74         time_t last_reset;
75         int  key_attr;          /* attribute number for key field */
76         int  dict_attr;         /* attribute number for the counter. */
77 } rlm_sqlcounter_t;
78
79 /*
80  *      A mapping of configuration file names to internal variables.
81  *
82  *      Note that the string is dynamically allocated, so it MUST
83  *      be freed.  When the configuration file parse re-reads the string,
84  *      it free's the old one, and strdup's the new one, placing the pointer
85  *      to the strdup'd string into 'config.string'.  This gets around
86  *      buffer over-flows.
87  */
88 static const CONF_PARSER module_config[] = {
89   { "counter-name", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,counter_name), NULL,  NULL },
90   { "check-name", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,check_name), NULL, NULL },
91   { "key", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,key_name), NULL, NULL },
92   { "sqlmod-inst", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,sqlmod_inst), NULL, NULL },
93   { "query", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,query), NULL, NULL },
94   { "reset", PW_TYPE_STRING_PTR, offsetof(rlm_sqlcounter_t,reset), NULL,  NULL },
95   { NULL, -1, 0, NULL, NULL }
96 };
97
98
99 static int find_next_reset(rlm_sqlcounter_t *data, time_t timeval)
100 {
101         int ret=0;
102         unsigned int num=1;
103         char last = 0;
104         struct tm *tm, s_tm;
105         char sCurrentTime[40], sNextTime[40];
106
107         tm = localtime_r(&timeval, &s_tm);
108         strftime(sCurrentTime, sizeof(sCurrentTime),"%Y-%m-%d %H:%M:%S",tm);
109         tm->tm_sec = tm->tm_min = 0;
110
111         if (data->reset == NULL)
112                 return -1;
113         if (isdigit((int) data->reset[0])){
114                 unsigned int len=0;
115
116                 len = strlen(data->reset);
117                 if (len == 0)
118                         return -1;
119                 last = data->reset[len - 1];
120                 if (!isalpha((int) last))
121                         last = 'd';
122 /*              num = atoi(data->reset); */
123                 DEBUG("rlm_sqlcounter: num=%d, last=%c",num,last);
124         }
125         if (strcmp(data->reset, "hourly") == 0 || last == 'h') {
126                 /*
127                  *  Round up to the next nearest hour.
128                  */
129                 tm->tm_hour += num;
130                 data->reset_time = mktime(tm);
131         } else if (strcmp(data->reset, "daily") == 0 || last == 'd') {
132                 /*
133                  *  Round up to the next nearest day.
134                  */
135                 tm->tm_hour = 0;
136                 tm->tm_mday += num;
137                 data->reset_time = mktime(tm);
138         } else if (strcmp(data->reset, "weekly") == 0 || last == 'w') {
139                 /*
140                  *  Round up to the next nearest week.
141                  */
142                 tm->tm_hour = 0;
143                 tm->tm_mday += (7 - tm->tm_wday) +(7*(num-1));
144                 data->reset_time = mktime(tm);
145         } else if (strcmp(data->reset, "monthly") == 0 || last == 'm') {
146                 tm->tm_hour = 0;
147                 tm->tm_mday = 1;
148                 tm->tm_mon += num;
149                 data->reset_time = mktime(tm);
150         } else if (strcmp(data->reset, "never") == 0) {
151                 data->reset_time = 0;
152         } else {
153                 radlog(L_ERR, "rlm_sqlcounter: Unknown reset timer \"%s\"",
154                         data->reset);
155                 return -1;
156         }
157         strftime(sNextTime, sizeof(sNextTime),"%Y-%m-%d %H:%M:%S",tm);
158         DEBUG2("rlm_sqlcounter: Current Time: %d [%s], Next reset %d [%s]",
159                 (int)timeval,sCurrentTime,(int)data->reset_time, sNextTime);
160
161         return ret;
162 }
163
164
165 /*  I don't believe that this routine handles Daylight Saving Time adjustments
166     properly.  Any suggestions?
167 */
168
169 static int find_prev_reset(rlm_sqlcounter_t *data, time_t timeval)
170 {
171         int ret=0;
172         unsigned int num=1;
173         char last = 0;
174         struct tm *tm, s_tm;
175         char sCurrentTime[40], sPrevTime[40];
176
177         tm = localtime_r(&timeval, &s_tm);
178         strftime(sCurrentTime, sizeof(sCurrentTime),"%Y-%m-%d %H:%M:%S",tm);
179         tm->tm_sec = tm->tm_min = 0;
180
181         if (data->reset == NULL)
182                 return -1;
183         if (isdigit((int) data->reset[0])){
184                 unsigned int len=0;
185
186                 len = strlen(data->reset);
187                 if (len == 0)
188                         return -1;
189                 last = data->reset[len - 1];
190                 if (!isalpha((int) last))
191                         last = 'd';
192                 num = atoi(data->reset);
193                 DEBUG("rlm_sqlcounter: num=%d, last=%c",num,last);
194         }
195         if (strcmp(data->reset, "hourly") == 0 || last == 'h') {
196                 /*
197                  *  Round down to the prev nearest hour.
198                  */
199                 tm->tm_hour -= num - 1;
200                 data->last_reset = mktime(tm);
201         } else if (strcmp(data->reset, "daily") == 0 || last == 'd') {
202                 /*
203                  *  Round down to the prev nearest day.
204                  */
205                 tm->tm_hour = 0;
206                 tm->tm_mday -= num - 1;
207                 data->last_reset = mktime(tm);
208         } else if (strcmp(data->reset, "weekly") == 0 || last == 'w') {
209                 /*
210                  *  Round down to the prev nearest week.
211                  */
212                 tm->tm_hour = 0;
213                 tm->tm_mday -= (7 - tm->tm_wday) +(7*(num-1));
214                 data->last_reset = mktime(tm);
215         } else if (strcmp(data->reset, "monthly") == 0 || last == 'm') {
216                 tm->tm_hour = 0;
217                 tm->tm_mday = 1;
218                 tm->tm_mon -= num - 1;
219                 data->last_reset = mktime(tm);
220         } else if (strcmp(data->reset, "never") == 0) {
221                 data->reset_time = 0;
222         } else {
223                 radlog(L_ERR, "rlm_sqlcounter: Unknown reset timer \"%s\"",
224                         data->reset);
225                 return -1;
226         }
227         strftime(sPrevTime, sizeof(sPrevTime),"%Y-%m-%d %H:%M:%S",tm);
228         DEBUG2("rlm_sqlcounter: Current Time: %d [%s], Prev reset %d [%s]",
229                 (int)timeval,sCurrentTime,(int)data->last_reset, sPrevTime);
230
231         return ret;
232 }
233
234
235 /*
236  *      Replace %<whatever> in a string.
237  *
238  *      %b      last_reset
239  *      %e      reset_time
240  *      %k      key_name
241  *      %S      sqlmod_inst
242  *
243  */
244
245 static int sqlcounter_expand(char *out, int outlen, const char *fmt, void *instance)
246 {
247         rlm_sqlcounter_t *data = (rlm_sqlcounter_t *) instance;
248         int c,freespace;
249         const char *p;
250         char *q;
251         char tmpdt[40]; /* For temporary storing of dates */
252         int openbraces=0;
253
254         q = out;
255         for (p = fmt; *p ; p++) {
256         /* Calculate freespace in output */
257         freespace = outlen - (q - out);
258                 if (freespace <= 1)
259                         break;
260                 c = *p;
261                 if ((c != '%') && (c != '$') && (c != '\\')) {
262                         /*
263                          * We check if we're inside an open brace.  If we are
264                          * then we assume this brace is NOT literal, but is
265                          * a closing brace and apply it
266                          */
267                         if((c == '}') && openbraces) {
268                                 openbraces--;
269                                 continue;
270                         }
271                         *q++ = *p;
272                         continue;
273                 }
274                 if (*++p == '\0') break;
275                 if (c == '\\') switch(*p) {
276                         case '\\':
277                                 *q++ = *p;
278                                 break;
279                         case 't':
280                                 *q++ = '\t';
281                                 break;
282                         case 'n':
283                                 *q++ = '\n';
284                                 break;
285                         default:
286                                 *q++ = c;
287                                 *q++ = *p;
288                                 break;
289
290                 } else if (c == '%') switch(*p) {
291
292                         case '%':
293                                 *q++ = *p;
294                         case 'b': /* last_reset */
295                                 sprintf(tmpdt, "%lu", data->last_reset);
296                                 strNcpy(q, tmpdt, freespace);
297                                 q += strlen(q);
298                                 break;
299                         case 'e': /* reset_time */
300                                 sprintf(tmpdt, "%lu", data->reset_time);
301                                 strNcpy(q, tmpdt, freespace);
302                                 q += strlen(q);
303                                 break;
304                         case 'k': /* Key Name */
305                                 strNcpy(q, data->key_name, freespace);
306                                 q += strlen(q);
307                                 break;
308                         case 'S': /* SQL module instance */
309                                 strNcpy(q, data->sqlmod_inst, freespace);
310                                 q += strlen(q);
311                                 break;
312                         default:
313                                 *q++ = '%';
314                                 *q++ = *p;
315                                 break;
316                 }
317         }
318         *q = '\0';
319
320         DEBUG2("sqlcounter_expand:  '%s'", out);
321
322         return strlen(out);
323 }
324
325
326 /*
327  *      See if the counter matches.
328  */
329 static int sqlcounter_cmp(void *instance, REQUEST *req, VALUE_PAIR *request, VALUE_PAIR *check,
330                 VALUE_PAIR *check_pairs, VALUE_PAIR **reply_pairs)
331 {
332         rlm_sqlcounter_t *data = (rlm_sqlcounter_t *) instance;
333         int counter;
334         char querystr[MAX_QUERY_LEN];
335         char responsestr[MAX_QUERY_LEN];
336
337         check_pairs = check_pairs; /* shut the compiler up */
338         reply_pairs = reply_pairs;
339
340         /* first, expand %k, %b and %e in query */
341         sqlcounter_expand(querystr, MAX_QUERY_LEN, data->query, instance);
342
343         /* second, xlat any request attribs in query */
344         radius_xlat(responsestr, MAX_QUERY_LEN, querystr, req, NULL);
345
346         /* third, wrap query with sql module call & expand */
347         sprintf(querystr, "%%{%%S:%s}", responsestr);
348         sqlcounter_expand(responsestr, MAX_QUERY_LEN, querystr, instance);
349
350         /* Finally, xlat resulting SQL query */
351         radius_xlat(querystr, MAX_QUERY_LEN, responsestr, req, NULL);
352
353         counter = atoi(querystr);
354
355         return counter - check->lvalue;
356 }
357
358
359 /*
360  *      Do any per-module initialization that is separate to each
361  *      configured instance of the module.  e.g. set up connections
362  *      to external databases, read configuration files, set up
363  *      dictionary entries, etc.
364  *
365  *      If configuration information is given in the config section
366  *      that must be referenced in later calls, store a handle to it
367  *      in *instance otherwise put a null pointer there.
368  */
369 static int sqlcounter_instantiate(CONF_SECTION *conf, void **instance)
370 {
371         rlm_sqlcounter_t *data;
372         DICT_ATTR *dattr;
373         ATTR_FLAGS flags;
374         time_t now;
375
376         /*
377          *      Set up a storage area for instance data
378          */
379         data = rad_malloc(sizeof(*data));
380         if (!data) {
381                 return -1;
382         }
383         memset(data, 0, sizeof(*data));
384
385         /*
386          *      If the configuration parameters can't be parsed, then
387          *      fail.
388          */
389         if (cf_section_parse(conf, data, module_config) < 0) {
390                 free(data);
391                 return -1;
392         }
393
394         /*
395          *      Discover the attribute number of the key.
396          */
397         if (data->key_name == NULL) {
398                 radlog(L_ERR, "rlm_sqlcounter: 'key' must be set.");
399                 return -1;
400         }
401         dattr = dict_attrbyname(data->key_name);
402         if (dattr == NULL) {
403                 radlog(L_ERR, "rlm_sqlcounter: No such attribute %s",
404                                 data->key_name);
405                 return -1;
406         }
407         data->key_attr = dattr->attr;
408
409
410         /*
411          *  Create a new attribute for the counter.
412          */
413         if (data->counter_name == NULL) {
414                 radlog(L_ERR, "rlm_sqlcounter: 'counter-name' must be set.");
415                 return -1;
416         }
417
418         memset(&flags, 0, sizeof(flags));
419         dict_addattr(data->counter_name, 0, PW_TYPE_INTEGER, -1, flags);
420         dattr = dict_attrbyname(data->counter_name);
421         if (dattr == NULL) {
422                 radlog(L_ERR, "rlm_sqlcounter: Failed to create counter attribute %s",
423                                 data->counter_name);
424                 return -1;
425         }
426         data->dict_attr = dattr->attr;
427         DEBUG2("rlm_sqlcounter: Counter attribute %s is number %d",
428                         data->counter_name, data->dict_attr);
429
430         /*
431          * Create a new attribute for the check item.
432          */
433         if (data->check_name == NULL) {
434                 radlog(L_ERR, "rlm_sqlcounter: 'check-name' must be set.");
435                 return -1;
436         }
437         dict_addattr(data->check_name, 0, PW_TYPE_INTEGER, -1, flags);
438         dattr = dict_attrbyname(data->check_name);
439         if (dattr == NULL) {
440                 radlog(L_ERR, "rlm_sqlcounter: Failed to create check attribute %s",
441                                 data->counter_name);
442                 return -1;
443         }
444         DEBUG2("rlm_sqlcounter: Check attribute %s is number %d",
445                         data->check_name, dattr->attr);
446
447         /*
448          *  Discover the end of the current time period.
449          */
450         if (data->reset == NULL) {
451                 radlog(L_ERR, "rlm_sqlcounter: 'reset' must be set.");
452                 return -1;
453         }
454         now = time(NULL);
455         data->reset_time = 0;
456
457         if (find_next_reset(data,now) == -1)
458                 return -1;
459
460         /*
461          *  Discover the beginning of the current time period.
462          */
463         data->last_reset = 0;
464
465         if (find_prev_reset(data,now) == -1)
466                 return -1;
467
468
469         /*
470          *      Register the counter comparison operation.
471          */
472         paircompare_register(data->dict_attr, 0, sqlcounter_cmp, data);
473
474         *instance = data;
475
476         return 0;
477 }
478
479 /*
480  *      Find the named user in this modules database.  Create the set
481  *      of attribute-value pairs to check and reply with for this user
482  *      from the database. The authentication code only needs to check
483  *      the password, the rest is done here.
484  */
485 static int sqlcounter_authorize(void *instance, REQUEST *request)
486 {
487         rlm_sqlcounter_t *data = (rlm_sqlcounter_t *) instance;
488         int ret=RLM_MODULE_NOOP;
489         int counter=0;
490         int res=0;
491         DICT_ATTR *dattr;
492         VALUE_PAIR *key_vp, *check_vp;
493         VALUE_PAIR *reply_item;
494         char msg[128];
495         char querystr[MAX_QUERY_LEN];
496         char responsestr[MAX_QUERY_LEN];
497
498         /* quiet the compiler */
499         instance = instance;
500         request = request;
501
502         /*
503          *      Before doing anything else, see if we have to reset
504          *      the counters.
505          */
506         if (data->reset_time && (data->reset_time <= request->timestamp)) {
507
508                 /*
509                  *      Re-set the next time and prev_time for this counters range
510                  */
511                 data->last_reset = data->reset_time;
512                 find_next_reset(data,request->timestamp);
513         }
514
515
516         /*
517          *      Look for the key.  User-Name is special.  It means
518          *      The REAL username, after stripping.
519          */
520         DEBUG2("rlm_sqlcounter: Entering module authorize code");
521         key_vp = (data->key_attr == PW_USER_NAME) ? request->username : pairfind(request->packet->vps, data->key_attr);
522         if (key_vp == NULL) {
523                 DEBUG2("rlm_sqlcounter: Could not find Key value pair");
524                 return ret;
525         }
526
527         /*
528          *      Look for the check item
529          */
530         if ((dattr = dict_attrbyname(data->check_name)) == NULL) {
531                 return ret;
532         }
533         /* DEBUG2("rlm_sqlcounter: Found Check item attribute %d", dattr->attr); */
534         if ((check_vp= pairfind(request->config_items, dattr->attr)) == NULL) {
535                 DEBUG2("rlm_sqlcounter: Could not find Check item value pair");
536                 return ret;
537         }
538
539         /* first, expand %k, %b and %e in query */
540         sqlcounter_expand(querystr, MAX_QUERY_LEN, data->query, instance);
541
542         /* second, xlat any request attribs in query */
543         radius_xlat(responsestr, MAX_QUERY_LEN, querystr, request, NULL);
544
545         /* third, wrap query with sql module & expand */
546         sprintf(querystr, "%%{%%S:%s}", responsestr);
547         sqlcounter_expand(responsestr, MAX_QUERY_LEN, querystr, instance);
548
549         /* Finally, xlat resulting SQL query */
550         radius_xlat(querystr, MAX_QUERY_LEN, responsestr, request, NULL);
551
552         counter = atoi(querystr);
553
554
555         /*
556          * Check if check item > counter
557          */
558         res=check_vp->lvalue - counter;
559         if (res > 0) {
560                 DEBUG2("rlm_sqlcounter: (Check item - counter) is greater than zero");
561                 /*
562                  *      We are assuming that simultaneous-use=1. But
563                  *      even if that does not happen then our user
564                  *      could login at max for 2*max-usage-time Is
565                  *      that acceptable?
566                  */
567
568                 /*
569                  *      User is allowed, but set Session-Timeout.
570                  *      Stolen from main/auth.c
571                  */
572
573                 /*
574                  *      If we are near a reset then add the next
575                  *      limit, so that the user will not need to
576                  *      login again
577                  */
578                 if (data->reset_time && (
579                         res >= (data->reset_time - request->timestamp))) {
580                         res = data->reset_time - request->timestamp;
581                         res += check_vp->lvalue;
582                 }
583
584                 if ((reply_item = pairfind(request->reply->vps, PW_SESSION_TIMEOUT)) != NULL) {
585                         if (reply_item->lvalue > res)
586                                 reply_item->lvalue = res;
587                 } else {
588                         if ((reply_item = paircreate(PW_SESSION_TIMEOUT, PW_TYPE_INTEGER)) == NULL) {
589                                 radlog(L_ERR|L_CONS, "no memory");
590                                 return RLM_MODULE_NOOP;
591                         }
592                         reply_item->lvalue = res;
593                         pairadd(&request->reply->vps, reply_item);
594                 }
595
596                 ret=RLM_MODULE_OK;
597
598                 DEBUG2("rlm_sqlcounter: Authorized user %s, check_item=%d, counter=%d",
599                                 key_vp->strvalue,check_vp->lvalue,counter);
600                 DEBUG2("rlm_sqlcounter: Sent Reply-Item for user %s, Type=Session-Timeout, value=%d",
601                                 key_vp->strvalue,reply_item->lvalue);
602         }
603         else{
604                 char module_fmsg[MAX_STRING_LEN];
605                 VALUE_PAIR *module_fmsg_vp;
606
607                 DEBUG2("rlm_sqlcounter: (Check item - counter) is less than zero");
608
609                 /*
610                  * User is denied access, send back a reply message
611                 */
612                 sprintf(msg, "Your maximum %s usage time has been reached", data->reset);
613                 reply_item=pairmake("Reply-Message", msg, T_OP_EQ);
614                 pairadd(&request->reply->vps, reply_item);
615
616                 snprintf(module_fmsg, sizeof(module_fmsg), "rlm_sqlcounter: Maximum %s usage time reached", data->reset);
617                 module_fmsg_vp = pairmake("Module-Failure-Message", module_fmsg, T_OP_EQ);
618                 pairadd(&request->packet->vps, module_fmsg_vp);
619
620                 ret=RLM_MODULE_REJECT;
621
622                 DEBUG2("rlm_sqlcounter: Rejected user %s, check_item=%d, counter=%d",
623                                 key_vp->strvalue,check_vp->lvalue,counter);
624         }
625
626         return ret;
627 }
628
629 static int sqlcounter_detach(void *instance)
630 {
631         rlm_sqlcounter_t *data = (rlm_sqlcounter_t *) instance;
632
633         paircompare_unregister(data->dict_attr, sqlcounter_cmp);
634         free(data->reset);
635         free(data->query);
636         free(data->check_name);
637         free(data->sqlmod_inst);
638         free(data->counter_name);
639
640         free(instance);
641         return 0;
642 }
643
644 /*
645  *      The module name should be the only globally exported symbol.
646  *      That is, everything else should be 'static'.
647  *
648  *      If the module needs to temporarily modify it's instantiation
649  *      data, the type should be changed to RLM_TYPE_THREAD_UNSAFE.
650  *      The server will then take care of ensuring that the module
651  *      is single-threaded.
652  */
653 module_t rlm_sqlcounter = {
654         "SQL Counter",
655         RLM_TYPE_THREAD_SAFE,           /* type */
656         NULL,                           /* initialization */
657         sqlcounter_instantiate,         /* instantiation */
658         {
659                 NULL,                   /* authentication */
660                 sqlcounter_authorize,   /* authorization */
661                 NULL,                   /* preaccounting */
662                 NULL,                   /* accounting */
663                 NULL,                   /* checksimul */
664                 NULL,                   /* pre-proxy */
665                 NULL,                   /* post-proxy */
666                 NULL                    /* post-auth */
667         },
668         sqlcounter_detach,              /* detach */
669         NULL,                           /* destroy */
670 };
671