import cyrus-sasl-2.1.23
[cyrus-sasl.git] / sample / server.c
1 /* $Id: server.c,v 1.9 2004/03/29 14:56:40 rjs3 Exp $ */
2 /* 
3  * Copyright (c) 1998-2003 Carnegie Mellon University.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer. 
11  *
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in
14  *    the documentation and/or other materials provided with the
15  *    distribution.
16  *
17  * 3. The name "Carnegie Mellon University" must not be used to
18  *    endorse or promote products derived from this software without
19  *    prior written permission. For permission or any other legal
20  *    details, please contact  
21  *      Office of Technology Transfer
22  *      Carnegie Mellon University
23  *      5000 Forbes Avenue
24  *      Pittsburgh, PA  15213-3890
25  *      (412) 268-4387, fax: (412) 268-7395
26  *      tech-transfer@andrew.cmu.edu
27  *
28  * 4. Redistributions of any form whatsoever must retain the following
29  *    acknowledgment:
30  *    "This product includes software developed by Computing Services
31  *     at Carnegie Mellon University (http://www.cmu.edu/computing/)."
32  *
33  * CARNEGIE MELLON UNIVERSITY DISCLAIMS ALL WARRANTIES WITH REGARD TO
34  * THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
35  * AND FITNESS, IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY BE LIABLE
36  * FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
37  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN
38  * AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
39  * OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
40  */
41
42 #include <config.h>
43
44 #include <stdio.h>
45 #include <stdlib.h>
46 #include <stdarg.h>
47 #include <ctype.h>
48 #include <errno.h>
49 #include <string.h>
50
51 #ifdef HAVE_UNISTD_H
52 #include <unistd.h>
53 #endif
54
55 #include <sys/socket.h>
56 #include <netinet/in.h>
57 #include <arpa/inet.h>
58 #include <netdb.h>
59
60 #include <sasl.h>
61
62 #include "common.h"
63
64 #if !defined(IPV6_BINDV6ONLY) && defined(IN6P_IPV6_V6ONLY)
65 #define IPV6_BINDV6ONLY IN6P_BINDV6ONLY
66 #endif
67 #if !defined(IPV6_V6ONLY) && defined(IPV6_BINDV6ONLY)
68 #define IPV6_V6ONLY     IPV6_BINDV6ONLY
69 #endif
70 #ifndef IPV6_BINDV6ONLY
71 #undef      IPV6_V6ONLY
72 #endif
73
74 /* create a socket listening on port 'port' */
75 /* if af is PF_UNSPEC more than one socket may be returned */
76 /* the returned list is dynamically allocated, so caller needs to free it */
77 int *listensock(const char *port, const int af)
78 {
79     struct addrinfo hints, *ai, *r;
80     int err, maxs, *sock, *socks;
81     const int on = 1;
82
83     memset(&hints, 0, sizeof(hints));
84     hints.ai_flags = AI_PASSIVE;
85     hints.ai_family = af;
86     hints.ai_socktype = SOCK_STREAM;
87     err = getaddrinfo(NULL, port, &hints, &ai);
88     if (err) {
89         fprintf(stderr, "%s\n", gai_strerror(err));
90         exit(EX_USAGE);
91     }
92
93     /* Count max number of sockets we may open */
94     for (maxs = 0, r = ai; r; r = r->ai_next, maxs++)
95         ;
96     socks = malloc((maxs + 1) * sizeof(int));
97     if (!socks) {
98         fprintf(stderr, "couldn't allocate memory for sockets\n");
99         freeaddrinfo(ai);
100         exit(EX_OSERR);
101     }
102
103     socks[0] = 0;       /* num of sockets counter at start of array */
104     sock = socks + 1;
105     for (r = ai; r; r = r->ai_next) {
106         fprintf(stderr, "trying %d, %d, %d\n",r->ai_family, r->ai_socktype, r->ai_protocol);
107         *sock = socket(r->ai_family, r->ai_socktype, r->ai_protocol);
108         if (*sock < 0) {
109             perror("socket");
110             continue;
111         }
112         if (setsockopt(*sock, SOL_SOCKET, SO_REUSEADDR, 
113                        (void *) &on, sizeof(on)) < 0) {
114             perror("setsockopt(SO_REUSEADDR)");
115             close(*sock);
116             continue;
117         }
118 #if defined(IPV6_V6ONLY) && !(defined(__FreeBSD__) && __FreeBSD__ < 3)
119         if (r->ai_family == AF_INET6) {
120             if (setsockopt(*sock, IPPROTO_IPV6, IPV6_BINDV6ONLY,
121                            (void *) &on, sizeof(on)) < 0) {
122                 perror("setsockopt (IPV6_BINDV6ONLY)");
123                 close(*sock);
124                 continue;
125             }
126         }
127 #endif
128         if (bind(*sock, r->ai_addr, r->ai_addrlen) < 0) {
129             perror("bind");
130             close(*sock);
131             continue;
132         }
133
134         if (listen(*sock, 5) < 0) {
135             perror("listen");
136             close(*sock);
137             continue;
138         }
139
140         socks[0]++;
141         sock++;
142     }
143
144     freeaddrinfo(ai);
145
146     if (socks[0] == 0) {
147         fprintf(stderr, "Couldn't bind to any socket\n");
148         free(socks);
149         exit(EX_OSERR);
150     }
151
152     return socks;
153 }
154
155 void usage(void)
156 {
157     fprintf(stderr, "usage: server [-p port] [-s service] [-m mech]\n");
158     exit(EX_USAGE);
159 }
160
161 /* globals because i'm lazy */
162 char *mech;
163
164 /* do the sasl negotiation; return -1 if it fails */
165 int mysasl_negotiate(FILE *in, FILE *out, sasl_conn_t *conn)
166 {
167     char buf[8192];
168     char chosenmech[128];
169     const char *data;
170     int len;
171     int r = SASL_FAIL;
172     const char *userid;
173     
174     /* generate the capability list */
175     if (mech) {
176         dprintf(2, "forcing use of mechanism %s\n", mech);
177         data = strdup(mech);
178         len = strlen(data);
179     } else {
180         int count;
181
182         dprintf(1, "generating client mechanism list... ");
183         r = sasl_listmech(conn, NULL, NULL, " ", NULL,
184                           &data, &len, &count);
185         if (r != SASL_OK) saslfail(r, "generating mechanism list");
186         dprintf(1, "%d mechanisms\n", count);
187     }
188
189     /* send capability list to client */
190     send_string(out, data, len);
191
192     dprintf(1, "waiting for client mechanism...\n");
193     len = recv_string(in, chosenmech, sizeof chosenmech);
194     if (len <= 0) {
195         printf("client didn't choose mechanism\n");
196         fputc('N', out); /* send NO to client */
197         fflush(out);
198         return -1;
199     }
200
201     if (mech && strcasecmp(mech, chosenmech)) {
202         printf("client didn't choose mandatory mechanism\n");
203         fputc('N', out); /* send NO to client */
204         fflush(out);
205         return -1;
206     }
207
208     len = recv_string(in, buf, sizeof(buf));
209     if(len != 1) {
210         saslerr(r, "didn't receive first-send parameter correctly");
211         fputc('N', out);
212         fflush(out);
213         return -1;
214     }
215
216     if(buf[0] == 'Y') {
217         /* receive initial response (if any) */
218         len = recv_string(in, buf, sizeof(buf));
219
220         /* start libsasl negotiation */
221         r = sasl_server_start(conn, chosenmech, buf, len,
222                               &data, &len);
223     } else {
224         r = sasl_server_start(conn, chosenmech, NULL, 0,
225                               &data, &len);
226     }
227     
228     if (r != SASL_OK && r != SASL_CONTINUE) {
229         saslerr(r, "starting SASL negotiation");
230         fputc('N', out); /* send NO to client */
231         fflush(out);
232         return -1;
233     }
234
235     while (r == SASL_CONTINUE) {
236         if (data) {
237             dprintf(2, "sending response length %d...\n", len);
238             fputc('C', out); /* send CONTINUE to client */
239             send_string(out, data, len);
240         } else {
241             dprintf(2, "sending null response...\n");
242             fputc('C', out); /* send CONTINUE to client */
243             send_string(out, "", 0);
244         }
245
246         dprintf(1, "waiting for client reply...\n");
247         len = recv_string(in, buf, sizeof buf);
248         if (len < 0) {
249             printf("client disconnected\n");
250             return -1;
251         }
252
253         r = sasl_server_step(conn, buf, len, &data, &len);
254         if (r != SASL_OK && r != SASL_CONTINUE) {
255             saslerr(r, "performing SASL negotiation");
256             fputc('N', out); /* send NO to client */
257             fflush(out);
258             return -1;
259         }
260     }
261
262     if (r != SASL_OK) {
263         saslerr(r, "incorrect authentication");
264         fputc('N', out); /* send NO to client */
265         fflush(out);
266         return -1;
267     }
268
269     fputc('O', out); /* send OK to client */
270     fflush(out);
271     dprintf(1, "negotiation complete\n");
272
273     r = sasl_getprop(conn, SASL_USERNAME, (const void **) &userid);
274     printf("successful authentication '%s'\n", userid);
275
276     return 0;
277 }
278
279 int main(int argc, char *argv[])
280 {
281     int c;
282     char *port = "12345";
283     char *service = "rcmd";
284     int *l, maxfd=0;
285     int r, i;
286     sasl_conn_t *conn;
287
288     while ((c = getopt(argc, argv, "p:s:m:")) != EOF) {
289         switch(c) {
290         case 'p':
291             port = optarg;
292             break;
293
294         case 's':
295             service = optarg;
296             break;
297
298         case 'm':
299             mech = optarg;
300             break;
301
302         default:
303             usage();
304             break;
305         }
306     }
307
308     /* initialize the sasl library */
309     r = sasl_server_init(NULL, "sample");
310     if (r != SASL_OK) saslfail(r, "initializing libsasl");
311
312     /* get a listening socket */
313     if ((l = listensock(port, PF_UNSPEC)) == NULL) {
314         saslfail(SASL_FAIL, "allocating listensock");
315     }
316
317     for (i = 1; i <= l[0]; i++) {
318        if (l[i] > maxfd)
319            maxfd = l[i];
320     }
321
322     for (;;) {
323         char localaddr[NI_MAXHOST | NI_MAXSERV],
324              remoteaddr[NI_MAXHOST | NI_MAXSERV];
325         char myhostname[1024+1];
326         char hbuf[NI_MAXHOST], pbuf[NI_MAXSERV];
327         struct sockaddr_storage local_ip, remote_ip;
328         int niflags, error;
329         int salen;
330         int nfds, fd = -1;
331         FILE *in, *out;
332         fd_set readfds;
333
334         FD_ZERO(&readfds);
335         for (i = 1; i <= l[0]; i++)
336             FD_SET(l[i], &readfds);
337
338         nfds = select(maxfd + 1, &readfds, 0, 0, 0);
339         if (nfds <= 0) {
340             if (nfds < 0 && errno != EINTR)
341                 perror("select");
342             continue;
343         }
344
345        for (i = 1; i <= l[0]; i++) 
346            if (FD_ISSET(l[i], &readfds)) {
347                fd = accept(l[i], NULL, NULL);
348                break;
349            }
350
351         if (fd < 0) {
352             if (errno != EINTR)
353                 perror("accept");
354             continue;
355         }
356
357         printf("accepted new connection\n");
358
359         /* set ip addresses */
360         salen = sizeof(local_ip);
361         if (getsockname(fd, (struct sockaddr *)&local_ip, &salen) < 0) {
362             perror("getsockname");
363         }
364         niflags = (NI_NUMERICHOST | NI_NUMERICSERV);
365 #ifdef NI_WITHSCOPEID
366         if (((struct sockaddr *)&local_ip)->sa_family == AF_INET6)
367             niflags |= NI_WITHSCOPEID;
368 #endif
369         error = getnameinfo((struct sockaddr *)&local_ip, salen, hbuf,
370                             sizeof(hbuf), pbuf, sizeof(pbuf), niflags);
371         if (error != 0) {
372             fprintf(stderr, "getnameinfo: %s\n", gai_strerror(error));
373             strcpy(hbuf, "unknown");
374             strcpy(pbuf, "unknown");
375         }
376         snprintf(localaddr, sizeof(localaddr), "%s;%s", hbuf, pbuf);
377
378         salen = sizeof(remote_ip);
379         if (getpeername(fd, (struct sockaddr *)&remote_ip, &salen) < 0) {
380             perror("getpeername");
381         }
382
383         niflags = (NI_NUMERICHOST | NI_NUMERICSERV);
384 #ifdef NI_WITHSCOPEID
385         if (((struct sockaddr *)&remote_ip)->sa_family == AF_INET6)
386             niflags |= NI_WITHSCOPEID;
387 #endif
388         error = getnameinfo((struct sockaddr *)&remote_ip, salen, hbuf,
389                             sizeof(hbuf), pbuf, sizeof(pbuf), niflags);
390         if (error != 0) {
391             fprintf(stderr, "getnameinfo: %s\n", gai_strerror(error));
392             strcpy(hbuf, "unknown");
393             strcpy(pbuf, "unknown");
394         }
395         snprintf(remoteaddr, sizeof(remoteaddr), "%s;%s", hbuf, pbuf);
396
397         r = gethostname(myhostname, sizeof(myhostname)-1);
398         if(r == -1) saslfail(r, "getting hostname");
399
400         r = sasl_server_new(service, myhostname, NULL, localaddr, remoteaddr,
401                             NULL, 0, &conn);
402         if (r != SASL_OK) saslfail(r, "allocating connection state");
403
404         /* set external properties here
405            sasl_setprop(conn, SASL_SSF_EXTERNAL, &extprops); */
406
407         /* set required security properties here
408            sasl_setprop(conn, SASL_SEC_PROPS, &secprops); */
409
410         in = fdopen(fd, "r");
411         out = fdopen(fd, "w");
412
413         r = mysasl_negotiate(in, out, conn);
414         if (r == SASL_OK) {
415             /* send/receive data */
416
417
418         }
419
420         printf("closing connection\n");
421         fclose(in);
422         fclose(out);
423         close(fd);
424         sasl_dispose(&conn);
425     }
426
427     sasl_done();
428 }