Removed pthread_mutex stuff, as the sem_trywait(), etc. semaphores
[freeradius.git] / src / modules / rlm_sql / sql.c
1 /*
2  *  sql.c               rlm_sql - FreeRADIUS SQL Module
3  *              Main code directly taken from ICRADIUS
4  *
5  * Version:     $Id$
6  *
7  *   This program is free software; you can redistribute it and/or modify
8  *   it under the terms of the GNU General Public License as published by
9  *   the Free Software Foundation; either version 2 of the License, or
10  *   (at your option) any later version.
11  *
12  *   This program is distributed in the hope that it will be useful,
13  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *   GNU General Public License for more details.
16  *
17  *   You should have received a copy of the GNU General Public License
18  *   along with this program; if not, write to the Free Software
19  *   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
20  *
21  * Copyright 2001  The FreeRADIUS server project
22  * Copyright 2000  Mike Machado <mike@innercite.com>
23  * Copyright 2000  Alan DeKok <aland@ox.org>
24  * Copyright 2001  Chad Miller <cmiller@surfsouth.com>
25  */
26
27
28 #include        <sys/types.h>
29 #include        <sys/socket.h>
30 #include        <sys/time.h>
31 #include        <sys/file.h>
32 #include        <string.h>
33 #include        <sys/stat.h>
34 #include        <netinet/in.h>
35
36 #include        <stdio.h>
37 #include        <stdlib.h>
38 #include        <netdb.h>
39 #include        <pwd.h>
40 #include        <time.h>
41 #include        <ctype.h>
42 #include        <unistd.h>
43 #include        <signal.h>
44 #include        <errno.h>
45 #include        <sys/wait.h>
46
47 #if HAVE_PTHREAD_H
48 #include        <pthread.h>
49 #endif
50
51 #include        "radiusd.h"
52 #include        "conffile.h"
53 #include        "rlm_sql.h"
54
55
56 /*
57  * Connect to a server.  If error, set this socket's state to be "sockunconnected"
58  * and set a grace period, during which we won't try connecting again (to prevent unduly
59  * lagging the server and being impolite to a DB server that may be having other 
60  * issues).  If successful in connecting, set state to sockconnected.   - chad
61  */
62 static int connect_single_socket(SQLSOCK *sqlsocket, SQL_INST *inst) {
63         if ((inst->module->sql_init_socket)(sqlsocket, inst->config) < 0) {
64                 radlog(L_CONS | L_ERR, "rlm_sql:  Failed to connect DB handle #%d", sqlsocket->id);
65                 inst->connect_after = time(NULL) + inst->config->connect_failure_retry_delay;
66                 sqlsocket->state = sockunconnected;
67                 return(-1);
68         } else {
69                 radlog(L_DBG, "rlm_sql:  Connected new DB handle, #%d", sqlsocket->id);
70                 sqlsocket->state = sockconnected;
71                 return(0);
72         }
73 }
74
75
76 /*************************************************************************
77  *
78  *      Function: sql_init_socketpool
79  *
80  *      Purpose: Connect to the sql server, if possible
81  *
82  *************************************************************************/
83 int sql_init_socketpool(SQL_INST * inst) {
84
85         SQLSOCK *sqlsocket;
86         int     i;
87
88         inst->connect_after = 0;
89         inst->used = 0;
90         inst->sqlpool = NULL;
91
92         for (i = 0; i < inst->config->num_sql_socks; i++) {
93
94                 sqlsocket = rad_malloc(sizeof(SQLSOCK));
95                 if (sqlsocket == NULL) {
96                         return -1;
97                 }
98                 sqlsocket->conn = NULL;
99                 sqlsocket->id = i;
100                 sqlsocket->state = sockunconnected;
101
102 #if HAVE_PTHREAD_H
103                 /*
104                  *  FIXME! Check return codes!
105                  */
106                 sqlsocket->semaphore = (sem_t *) rad_malloc(sizeof(sem_t));
107                 sem_init(sqlsocket->semaphore, 0, SQLSOCK_UNLOCKED);
108 #else
109                 sqlsocket->in_use = SQLSOCK_UNLOCKED;
110 #endif
111
112                 if (time(NULL) > inst->connect_after) {
113                         /* this sets the sqlsocket->state, and possibly sets inst->connect_after */
114                         connect_single_socket(sqlsocket, inst);
115                 }
116
117                 /* Add this socket to the list of sockets */
118                 sqlsocket->next = inst->sqlpool;
119                 inst->sqlpool = sqlsocket;
120         }
121
122         return 1;
123 }
124
125 /*************************************************************************
126  *
127  *     Function: sql_poolfree
128  *
129  *     Purpose: Clean up and free sql pool
130  *
131  *************************************************************************/
132 void sql_poolfree(SQL_INST * inst) {
133
134         SQLSOCK *cur;
135
136         for (cur = inst->sqlpool; cur; cur = cur->next) {
137                 sql_close_socket(inst, cur);
138         }
139 }
140
141
142 /*************************************************************************
143  *
144  *      Function: sql_close_socket
145  *
146  *      Purpose: Close and free a sql sqlsocket
147  *
148  *************************************************************************/
149 int sql_close_socket(SQL_INST *inst, SQLSOCK * sqlsocket) {
150
151         radlog(L_DBG, "rlm_sql: Closing sqlsocket %d", sqlsocket->id);
152         (inst->module->sql_close)(sqlsocket, inst->config);
153 #if HAVE_PTHREAD_H
154         sem_destroy(sqlsocket->semaphore);
155 #endif
156         free(sqlsocket);
157         return 1;
158 }
159
160
161 /*************************************************************************
162  *
163  *      Function: sql_get_socket
164  *
165  *      Purpose: Return a SQL sqlsocket from the connection pool           
166  *
167  *************************************************************************/
168 SQLSOCK * sql_get_socket(SQL_INST * inst) {
169         SQLSOCK *cur;
170         struct timeval timeout;
171         int tried_to_connect = 0;
172
173         while (inst->used == inst->config->num_sql_socks) {
174                 radlog(L_ERR, "rlm_sql: All sockets are being used! Please increase the maximum number of sockets!");
175           return NULL;
176         }
177
178         for (cur = inst->sqlpool; cur; cur = cur->next) {
179
180                 /* if we happen upon an unconnected socket, and this instance's grace 
181                  * period on (re)connecting has expired, then try to connect it.  This 
182                  * should be really rare.  - chad
183                  */
184                 if ((cur->state == sockunconnected) && (time(NULL) > inst->connect_after)) {
185                         tried_to_connect = 1;
186                         radlog(L_INFO, "rlm_sql: Trying to (re)connect an unconnected handle...");
187                         connect_single_socket(cur, inst);
188                 }
189
190                 /* if we still aren't connected, ignore this handle */
191                 if (cur->state == sockunconnected) {
192                         radlog(L_DBG, "rlm_sql: Ignoring unconnected handle");
193                         continue;
194                 }
195
196 #if HAVE_PTHREAD_H
197                 if (sem_trywait(cur->semaphore) == 0) {
198 #else
199                 if (cur->in_use == SQLSOCK_UNLOCKED) {
200 #endif
201                         (inst->used)++;
202 #ifne HAVE_PTHREAD_H
203                         cur->in_use = SQLSOCK_LOCKED;
204 #endif
205                         radlog(L_DBG, "rlm_sql: Reserving sql socket id: %d", cur->id);
206                         return cur;
207                 }
208         }
209
210         /* We get here if every DB handle is unconnected and unconnectABLE */
211         radlog((tried_to_connect = 0) ? (L_DBG) : (L_CONS | L_ERR), "rlm_sql:  There are no DB handles to use!");
212         return NULL;
213 }
214
215 /*************************************************************************
216  *
217  *      Function: sql_release_socket
218  *
219  *      Purpose: Frees a SQL sqlsocket back to the connection pool           
220  *
221  *************************************************************************/
222 int sql_release_socket(SQL_INST * inst, SQLSOCK * sqlsocket) {
223
224         (inst->used)--;
225 #if HAVE_PTHREAD_H
226         sem_post(sqlsocket->semaphore);
227 #else
228         sqlsocket->in_use = SQLSOCK_UNLOCKED;
229 #endif
230
231         radlog(L_DBG, "rlm_sql: Released sql socket id: %d", sqlsocket->id);
232
233         return 1;
234 }
235
236
237 /*************************************************************************
238  *
239  *      Function: sql_userparse
240  *
241  *      Purpose: Read entries from the database and fill VALUE_PAIR structures
242  *
243  *************************************************************************/
244 int sql_userparse(VALUE_PAIR ** first_pair, SQL_ROW row, int mode) {
245
246         DICT_ATTR *attr;
247         VALUE_PAIR *pair, *check;
248
249         if ((attr = dict_attrbyname(row[2])) == (DICT_ATTR *) NULL) {
250                 radlog(L_ERR | L_CONS, "rlm_sql: unknown attribute %s", row[2]);
251                 return (-1);
252         }
253
254         /*
255          * If attribute is already there, skip it because we checked usercheck first 
256          * and we want user settings to over ride group settings 
257          */
258         if ((check = pairfind(*first_pair, attr->attr)) != NULL &&
259 #if defined( BINARY_FILTERS )
260                         attr->type != PW_TYPE_ABINARY &&
261 #endif
262                         mode == PW_VP_GROUPDATA)
263                 return 0;
264
265         pair = pairmake(row[2], row[3], T_OP_CMP_EQ);
266         pairadd(first_pair, pair);
267
268         vp_printlist(stderr, *first_pair);
269
270         return 0;
271 }
272
273
274 /*************************************************************************
275  *
276  *      Function: sql_getvpdata
277  *
278  *      Purpose: Get any group check or reply pairs
279  *
280  *************************************************************************/
281 int sql_getvpdata(SQL_INST * inst, SQLSOCK * sqlsocket, VALUE_PAIR **pair, char *query, int mode) {
282
283         SQL_ROW row;
284         int     rows = 0;
285
286         if ((inst->module->sql_select_query)(sqlsocket, inst->config, query) < 0) {
287                 radlog(L_ERR, "rlm_sql_getvpdata: database query error");
288                 return -1;
289         }
290         while ((row = (inst->module->sql_fetch_row)(sqlsocket, inst->config))) {
291                 if (sql_userparse(pair, row, mode) != 0) {
292                         radlog(L_ERR | L_CONS, "rlm_sql:  Error getting data from database");
293                         (inst->module->sql_finish_select_query)(sqlsocket, inst->config);
294                         return -1;
295                 }
296                 rows++;
297         }
298         (inst->module->sql_finish_select_query)(sqlsocket, inst->config);
299
300         return rows;
301 }
302
303
/*
 *	Set by alrm_handler() when SIGALRM arrives; polled by sql_check_ts()
 *	to detect the checkrad timeout.  An object written from a signal
 *	handler must be volatile sig_atomic_t to be accessed reliably.
 */
static volatile sig_atomic_t got_alrm;

/*
 *	SIGALRM handler: only records that the alarm fired.
 */
static void
alrm_handler(int sig) {
	(void) sig;
	got_alrm = 1;
}
309
310 /*************************************************************************
311  *
312  *      Function: sql_check_ts
313  *
314  *      Purpose: Checks the terminal server for a spacific login entry
315  *
316  *************************************************************************/
317 static int sql_check_ts(SQL_ROW row) {
318
319         int     pid, st, e;
320         int     n;
321         NAS    *nas;
322         char    session_id[12];
323         char   *s;
324         void    (*handler) (int);
325
326         /*
327          *      Find NAS type.
328          */
329         if ((nas = nas_find(ip_addr(row[4]))) == NULL) {
330                 radlog(L_ERR, "rlm_sql:  unknown NAS [%s]", row[4]);
331                 return -1;
332         }
333
334         /*
335          *      Fork.
336          */
337         handler = signal(SIGCHLD, SIG_DFL);
338         if ((pid = fork()) < 0) {
339                 radlog(L_ERR, "rlm_sql: fork: %s", strerror(errno));
340                 signal(SIGCHLD, handler);
341                 return -1;
342         }
343
344         if (pid > 0) {
345                 /*
346                  *      Parent - Wait for checkrad to terminate.
347                  *      We timeout in 10 seconds.
348                  */
349                 got_alrm = 0;
350                 signal(SIGALRM, alrm_handler);
351                 alarm(10);
352                 while ((e = waitpid(pid, &st, 0)) != pid)
353                         if (e < 0 && (errno != EINTR || got_alrm))
354                                 break;
355                 alarm(0);
356                 signal(SIGCHLD, handler);
357                 if (got_alrm) {
358                         kill(pid, SIGTERM);
359                         sleep(1);
360                         kill(pid, SIGKILL);
361                         radlog(L_ERR, "rlm_sql:  Check-TS: timeout waiting for checkrad");
362                         return 2;
363                 }
364                 if (e < 0) {
365                         radlog(L_ERR, "rlm_sql:  Check-TS: unknown error in waitpid()");
366                         return 2;
367                 }
368                 return WEXITSTATUS(st);
369         }
370
371         /*
372          *      Child - exec checklogin with the right parameters.
373          */
374         for (n = 32; n >= 3; n--)
375                 close(n);
376
377         sprintf(session_id, "%.8s", row[1]);
378
379         s = CHECKRAD2;
380         execl(CHECKRAD2, "checkrad", nas->nastype, row[4], row[5],
381                                 row[2], session_id, NULL);
382         if (errno == ENOENT) {
383                 s = CHECKRAD1;
384                 execl(CHECKRAD1, "checklogin", nas->nastype, row[4], row[5],
385                                         row[2], session_id, NULL);
386         }
387         radlog(L_ERR, "rlm_sql:  Check-TS: exec %s: %s", s, strerror(errno));
388
389         /*
390          *      Exit - 2 means "some error occured".
391          */
392         exit(2);
393
394 }
395
396
397 /*************************************************************************
398  *
399  *      Function: sql_check_multi
400  *
401  *      Purpose: Check radius accounting for duplicate logins
402  *
403  *************************************************************************/
404 int sql_check_multi(SQL_INST * inst, SQLSOCK * sqlsocket, char *name, VALUE_PAIR * request, int maxsimul) {
405
406         char    querystr[MAX_QUERY_LEN];
407         char    authstr[256];
408         VALUE_PAIR *fra;
409         SQL_ROW row;
410         int     count = 0;
411         uint32_t ipno = 0;
412         int     mpp = 1;
413
414         sprintf(authstr, "UserName = '%s'", name);
415         sprintf(querystr, "SELECT COUNT(*) FROM %s WHERE %s AND AcctStopTime = 0", inst->config->sql_acct_table, authstr);
416         if ((inst->module->sql_select_query)(sqlsocket, inst->config, querystr) < 0) {
417                 radlog(L_ERR, "sql_check_multi: database query error");
418                 return -1;
419         }
420
421         row = (inst->module->sql_fetch_row)(sqlsocket, inst->config);
422         count = atoi(row[0]);
423         (inst->module->sql_finish_select_query)(sqlsocket, inst->config);
424
425         if (count < maxsimul)
426                 return 0;
427
428         /*
429          * *      Setup some stuff, like for MPP detection.
430          */
431         if ((fra = pairfind(request, PW_FRAMED_IP_ADDRESS)) != NULL)
432                 ipno = htonl(fra->lvalue);
433
434         count = 0;
435         sprintf(querystr, "SELECT * FROM %s WHERE %s AND AcctStopTime = 0", inst->config->sql_acct_table, authstr);
436         if ((inst->module->sql_select_query)(sqlsocket, inst->config, querystr) < 0) {
437                 radlog(L_ERR, "sql_check_multi: database query error");
438                 return -1;
439         }
440         while ((row = (inst->module->sql_fetch_row)(sqlsocket, inst->config))) {
441                 int     check = sql_check_ts(row);
442
443                 if (check == 1) {
444                         count++;
445
446                         if (ipno && atoi(row[19]) == ipno)
447                                 mpp = 2;
448
449                 } else if (check == 2)
450                         radlog(L_ERR, "rlm_sql:  Problem with checkrad [%s] (from nas %s)", name, row[4]);
451                 else {
452                         /*
453                          *      False record - zap it
454                          */
455
456                         if (inst->config->deletestalesessions) {
457                                 SQLSOCK *sqlsocket1;
458
459                                 radlog(L_ERR, "rlm_sql:  Deleteing stale session [%s] (from nas %s/%s)", row[2], row[4], row[5]);
460                                 sqlsocket1 = sql_get_socket(inst);
461                                 sprintf(querystr, "DELETE FROM %s WHERE RadAcctId = '%s'", inst->config->sql_acct_table, row[0]);
462                                 (inst->module->sql_query)(sqlsocket1, inst->config, querystr);
463                                 (inst->module->sql_finish_query)(sqlsocket1, inst->config);
464                                 sql_release_socket(inst, sqlsocket1);
465                         }
466                 }
467         }
468         (inst->module->sql_finish_select_query)(sqlsocket, inst->config);
469
470         return (count < maxsimul) ? 0 : mpp;
471 }
472
473 void query_log(SQL_INST * inst, char *querystr) {
474         FILE   *sqlfile = 0;
475
476         if (inst->config->sqltrace) {
477                 if ((sqlfile = fopen(inst->config->tracefile, "a")) == (FILE *) NULL) {
478                         radlog(L_ERR, "rlm_sql: Couldn't open file %s",
479                                                  inst->config->tracefile);
480                 } else {
481 #if defined(F_LOCK) && !defined(BSD)
482                         (void) lockf((int) sqlfile, (int) F_LOCK, (off_t) MAX_QUERY_LEN);
483 #else
484                         (void) flock(sqlfile, SQL_LOCK_EX);
485 #endif
486                         fputs(querystr, sqlfile);
487                         fputs(";\n", sqlfile);
488                         fclose(sqlfile);
489                 }
490         }
491 }
492
493 int sql_set_user(SQL_INST *inst, REQUEST *request, char *sqlusername, char *username) {
494         VALUE_PAIR *vp=NULL;
495         char    tmpuser[MAX_STRING_LEN];
496
497         tmpuser[0]=0;
498         sqlusername[0]=0;
499
500         /* Remove any user attr we added previously */
501         pairdelete(&request->packet->vps, PW_SQL_USER_NAME);
502
503         if(username) {
504                 strNcpy(tmpuser, username, MAX_STRING_LEN);
505         } else if(strlen(inst->config->query_user)) {
506                 radius_xlat(tmpuser, MAX_STRING_LEN, inst->config->query_user, request, NULL);
507         } else {
508                 return 0;
509         }
510
511         if(strlen(tmpuser)) {
512                 sql_escape_string(sqlusername, tmpuser, MAX_STRING_LEN);
513                 DEBUG2("sql_set_user:  escaped user --> '%s'", sqlusername);
514                 vp = pairmake("SQL-User-Name", sqlusername, 0);
515                 if (!vp) {
516                         radlog(L_ERR, "%s", librad_errstr);
517                         return -1;
518                 }
519
520                 pairadd(&request->packet->vps, vp);
521                 return 0;
522         }
523         return -1;
524 }
525
526 /*
527  *      Purpose: Esacpe "'" and any other wierd charactors
528  */
529 int sql_escape_string(char *to, char *from, int length) {
530         int x, y;
531
532         DEBUG2("sql_escape in:  '%s'", from);
533
534         for(x=0, y=0; (x < length) && (from[x]!='\0'); x++) {
535     switch (from[x]) {
536     case 0:                             
537       to[y++]= '\\';
538       to[y++]= '0';
539       break;
540     case '\n':                          
541       to[y++]= '\\';
542       to[y++]= 'n';
543       break;
544     case '\r':
545       to[y++]= '\\';
546       to[y++]= 'r';
547       break;
548     case '\\':
549       to[y++]= '\\';
550       to[y++]= '\\';
551       break;
552     case '\'':
553       to[y++]= '\\';
554       to[y++]= '\'';
555       break;
556     case '"':                           
557       to[y++]= '\\';
558       to[y++]= '"';
559       break;
560     case ';':                           
561       to[y++]= '\\';
562       to[y++]= ';';
563       break;
564                 /* Ascii file separator */
565     case '\032':                        
566       to[y++]= '\\';
567       to[y++]= 'Z';
568       break;
569     default:
570       to[y++]= from[x];
571     }
572   }
573         to[y]=0;
574
575         DEBUG2("sql_escape out:  '%s'", to);
576         return 1;
577 }