ttls: return channel bindings on half round trip success
[freeradius.git] / scripts / radtee
1 #!/usr/bin/env python
2 from __future__ import with_statement 
3
4 # RADIUS comparison tee v1.0
5 # Sniffs local RADIUS traffic, replays incoming requests to a test
6 # server, and compares the sniffed responses with the responses
7 # generated by the test server.
8 #
9 # Copyright (c) 2009, Frontier Communications
10 # Copyright (c) 2010, John Morrissey <jwm@horde.net>
11 #
12 # This program is free software; you can redistribute it and/or modify it
13 # under the terms of the GNU General Public License as published by the Free
14 # Software Foundation; either version 2 of the License, or (at your option)
15 # any later version.
16 #
17 # This program is distributed in the hope that it will be useful, but
18 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
19 # or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
20 # for more details.
21 #
22 # You should have received a copy of the GNU General Public License along
23 # with this program; if not, write to the Free Software Foundation, Inc., 59
24 # Temple Place, Suite 330, Boston, MA 02111-1307, USA.
25
26 # Requires
27 # ========
# - python 2.5 or newer (the with statement is used); Python 3 is not supported
29 # - impacket
30 # - pcapy
31 # - pyrad, ideally 1.2 or newer
32
33 # Output
34 # ======
# - .: 50 requests replayed to the test server and compared (each mismatch is also printed in full).
36 # - C=x.x.x.x: Ignored packet sniffed from unknown client.
37 # - D: Dropped sniffed packet due to processing bottleneck. Consider
38 #      increasing THREADS.
39 # - I: Invalid/unparseable packet sniffed.
40 # - Mreq: Response was sniffed without sniffing a corresponding request.
41 # - Mresp: Request was sniffed without sniffing a corresponding response.
42 # - T: Request to test server timed out.
43
44 import fcntl
45 from getopt import gnu_getopt, GetoptError
46 import os
47 import Queue
48 import re
49 import signal
50 import socket
51 import struct
52 import sys
53 import thread
54 from threading import Thread
55 from time import sleep, time
56
57 from impacket.ImpactDecoder import EthDecoder
58 import pcapy
59 from pyrad.client import Client
60 from pyrad.dictionary import Dictionary
61 from pyrad import packet
62
63
# Test RADIUS server that sniffed Access-Requests are replayed to, and the
# shared secret used when talking to it.
TEST_DEST = 'server.example.com'
TEST_SECRET = 'examplesecret'

# Dictionary to use when decoding RADIUS packets. pyrad earlier than
# v1.2 can't parse $INCLUDE directives, so you must flatten FreeRADIUS'
# dictionary into a single file yourself, for example with the script
# below (redirect its output to a file and point DICTIONARY at it).
# This example resolves relative $INCLUDE paths against the directory of
# the file containing the directive; adjust it if your layout differs.
#
# import os
# import re
# import sys
#
# def combine(path):
#     for line in open(path):
#         matches = re.search(r'^\$INCLUDE\s+(.*)$', line)
#         if not matches:
#             sys.stdout.write(line)
#             continue
#
#         included = matches.group(1).strip()
#         if not os.path.isabs(included):
#             included = os.path.join(os.path.dirname(path), included)
#         combine(included)
#
# combine('/etc/freeradius/dictionary')
DICTIONARY = '/etc/freeradius/dictionary'

# Number of worker threads that decode sniffed packets and replay
# requests to the test server.
THREADS = 32

# Mapping of RADIUS request source addresses to shared secrets, used to
# decrypt User-Password in sniffed requests. Requests from addresses not
# listed here are skipped and reported as 'C=<address>'.
#
# For example:
#     '127.0.0.1': 'test',
CLIENTS = {
}

# Ignore any sniffed packets (requests or responses) whose source IP
# address is in this list.
IGNORE_CLIENTS = [
]

# Expected mismatches to ignore and consider the packet matching.
#
# Each stanza is a dict with:
# - 'check' (optional): 'Attr=value' items that must all hold for the
#   sniffed request before the stanza is considered at all.
# - 'sniffed' / 'test': the response code expected on the wire and from
#   the test server respectively, plus 'attrs', the attributes that
#   differ between the two responses.
#
# Only the differences are compared to these items, so only the
# differing attrs need be listed in the attrs array. Each item is either
# a literal 'Attr=value' string (multiple values of one attribute are
# sorted and joined with ', ') or 'regex:PATTERN', which is searched with
# re.search() against those 'Attr=value' strings (anchor it with ^...$
# if needed). Every differing attribute must be covered by some item,
# and every item must cover at least one differing attribute.
#
# Examples:
# - Ignore mismatched AccessRejects whose sole difference is a
#   Reply-Message attribute with the values given.
#   {
#       'sniffed': {
#           'code': packet.AccessReject,
#           'attrs': [
#               'Reply-Message=Request Denied',
#           ],
#       },
#       'test': {
#           'code': packet.AccessReject,
#           'attrs': [
#               'Reply-Message=Account is disabled.',
#           ],
#       }
#   },
#
# - Ignore mismatched AccessRejects with Reply-Message=Request Denied
#   and arbitrary Cisco dns-servers in the sniffed packet, and
#   no Reply-Message and Cisco-AVPair attrs in the response from the
#   test RADIUS server.
#    {
#        'sniffed': {
#            'code': packet.AccessReject,
#            'attrs': [
#                'Reply-Message=Request Denied',
#                'regex:^Cisco-AVPair=ip:dns-servers=.*$',
#            ],
#        },
#        'test': {
#            'code': packet.AccessReject,
#            'attrs': [
#            ],
#        }
#    },
#
# - Only apply this stanza to sniffed requests with
#   'User-Name= user@example.com' (note the leading whitespace: all
#   text after the '=' is taken verbatim as the expected value).
#    {
#        'check': [
#            'User-Name= user@example.com',
#        ],
#        'sniffed': {
#            'code': packet.AccessReject,
#            'attrs': [
#                'Reply-Message=Request Denied',
#            ],
#        },
#        'test': {
#            'code': packet.AccessAccept,
#            'attrs': [
#                'Service-Type=Framed-User',
#                'Framed-Protocol=PPP',
#                'Framed-IP-Address=255.255.255.255',
#                'Framed-MTU=1500',
#                'Framed-Compression=Van-Jacobson-TCP-IP',
#            ],
#        }
#    },
IGNORE = [
]


# Captured frames waiting for a worker thread. When the queue is full,
# newly captured frames are dropped and reported as 'D'.
QUEUE = Queue.Queue(maxsize=25000)
# RADIUS dictionary, loaded once at import time and shared by all threads.
DICT = Dictionary(DICTIONARY)
171
def code2str(code):
    """Return the human-readable RADIUS name for a numeric packet code.

    Codes that are not in the table below yield None.
    """
    # Constants are looked up on pyrad's ``packet`` module one at a time,
    # in table order, stopping at the first code that compares equal.
    names = (
        ('AccessRequest', "Access-Request"),
        ('AccessAccept', "Access-Accept"),
        ('AccessReject', "Access-Reject"),
        ('AccountingRequest', "Accounting-Request"),
        ('AccountingResponse', "Accounting-Response"),
        ('AccessChallenge', "Access-Challenge"),
        ('StatusServer', "Status-Server"),
        ('StatusClient', "Status-Client"),
        ('DisconnectRequest', "Disconnect-Request"),
        ('DisconnectACK', "Disconnect-ACK"),
        ('DisconnectNAK', "Disconnect-NAK"),
        ('CoARequest', "CoA-Request"),
        ('CoAACK', "CoA-ACK"),
        ('CoANAK', "CoA-NAK"),
    )
    for const_name, label in names:
        if code == getattr(packet, const_name):
            return label
    return None
201
def handlePacket(header, data):
    """pcapy dispatch callback: hand one captured frame to the workers.

    The pcap header is not used. If the work queue is full, the frame is
    dropped and a 'D' is written to stdout, signalling that the worker
    threads are falling behind.
    """
    try:
        QUEUE.put_nowait(data)
    except Queue.Full:
        # Never block the capture loop; just report the drop.
        sys.stdout.write('D')
        sys.stdout.flush()
213
def ignore_applies(pkt, ignore):
    """Decide whether an ignore stanza applies to a sniffed request.

    Every item in the stanza's optional 'check' list must hold for the
    stanza to apply; a stanza without 'check' items always applies.

    Each check item is 'Attr=value' or 'Attr=regex:PATTERN'. Only the
    first '=' separates the attribute name from the expected value, so
    values may themselves contain '='. The request's values for Attr
    (a list per attribute) are sorted, converted with str() and joined
    with ', ' -- the same rendering used for mismatched attributes. The
    result is then compared literally, or searched with PATTERN for
    'regex:' items.

    Returns True if the stanza applies to pkt, False otherwise.
    """
    for pair in ignore.get('check', []):
        attr, value = pair.split('=', 1)

        if attr not in pkt:
            return False

        actual = ', '.join([str(v) for v in sorted(pkt[attr])])
        if value.startswith('regex:'):
            # Match the pattern against the packet's value, not against
            # the check item itself.
            if not re.search(value[len('regex:'):], actual):
                return False
        elif actual != value:
            return False

    return True
232
def ignores_match(pkt, mismatched, ignore):
    """Check whether an ignore side fully explains a list of differences.

    pkt        -- the response the differences were taken from (unused).
    mismatched -- 'Attr=value' strings that differ from the other response.
    ignore     -- one side ('sniffed' or 'test') of an IGNORE stanza; its
                  'attrs' holds literal 'Attr=value' items and
                  'regex:PATTERN' items.

    Returns True only if every mismatched string is covered by some item
    and every item covers at least one mismatched string.
    """
    rules = ignore['attrs']
    literal_rules = [rule for rule in rules if not rule.startswith('regex:')]
    pattern_rules = [rule for rule in rules if rule.startswith('regex:')]

    leftover_avs = mismatched[:]
    leftover_rules = rules[:]
    for av in mismatched:
        if av in literal_rules:
            leftover_avs.remove(av)
            leftover_rules.remove(av)
            continue
        for rule in pattern_rules:
            if not re.search(rule[len('regex:'):], av):
                continue
            leftover_avs.remove(av)
            # A pattern may cover several attributes; consume it only once.
            if rule in leftover_rules:
                leftover_rules.remove(rule)
            break

    return not (leftover_avs or leftover_rules)
267
def _mismatched_attrs(pkt, other):
    """Return 'Attr=v1, v2' strings for attributes of pkt whose values differ in other."""
    mismatched = []
    for k in pkt.keys():
        values = sorted(pkt[k])
        if values == sorted(other.get(k, [])):
            continue
        mismatched.append('%s=%s' % (k, ', '.join([str(v) for v in values])))
    return mismatched


def matches(req, sniffed_pkt, test_pkt):
    """Determine whether the test server's response matches the sniffed one.

    req         -- the sniffed request; IGNORE stanzas' 'check' items are
                   evaluated against it.
    sniffed_pkt -- the response sniffed from the wire.
    test_pkt    -- the response returned by the test server.

    Returns True when both codes and all attribute values agree. Also
    returns True when some IGNORE stanza applies to req, names both
    response codes, and explains every difference on BOTH sides.
    Otherwise returns False.
    """
    mis_attrs_sniffed = _mismatched_attrs(sniffed_pkt, test_pkt)
    mis_attrs_test = _mismatched_attrs(test_pkt, sniffed_pkt)

    # The packets match without having to consider any ignores.
    if sniffed_pkt.code == test_pkt.code and \
       not mis_attrs_sniffed and not mis_attrs_test:
        return True

    for ignore in IGNORE:
        if not ignore_applies(req, ignore):
            continue

        if ignore['sniffed']['code'] != sniffed_pkt.code or \
           ignore['test']['code'] != test_pkt.code:
            continue

        # A stanza only excuses the mismatch if it accounts for the
        # differences of both responses, not just one of them.
        if ignores_match(sniffed_pkt, mis_attrs_sniffed, ignore['sniffed']) and \
           ignores_match(test_pkt, mis_attrs_test, ignore['test']):
            return True

    return False
308
def log_mismatch(nas, req, passwd, expected, got):
    """Print a full report of a response mismatch to stdout.

    nas      -- source address of the original request
    req      -- the sniffed request packet
    passwd   -- decrypted User-Password, printed instead of the raw
                (encrypted) attribute value
    expected -- the response sniffed from the wire
    got      -- the response returned by the test server
    """
    def dump(title, pkt, overrides):
        # One header line, then one tab-indented line per attribute.
        print('%s: %s' % (title, code2str(pkt.code)))
        for name in pkt.keys():
            if name in overrides:
                shown = overrides[name]
            else:
                shown = ', '.join([str(v) for v in pkt[name]])
            print('\t%s: %s' % (name, shown))

    print('Mismatch: %s' % nas)
    dump('Request', req, {'User-Password': passwd})
    dump('Expected', expected, {})
    dump('Got', got, {})
    print('')
334
# State for sniffed exchanges, keyed by 'client-ip:client-port:radius-id'.
# Each entry holds a 'req' and/or 'response' dict plus 'tstamp', the time
# of its last update. Shared by the worker threads and the pruner; only
# touch it while holding REQUESTS_LOCK.
REQUESTS = {}
REQUESTS_LOCK = thread.allocate_lock()
# Number of requests replayed and compared; updated under REQUESTS_LOCK.
NUM_SUCCESSFUL = 0

def check_for_match(key, req_resp):
    """Replay a sniffed request to the test server and compare responses.

    key      -- REQUESTS key of the exchange; its entry is removed once the
                comparison is done or the replay fails.
    req_resp -- the REQUESTS entry, holding the sniffed 'req' (with the
                decrypted password under 'passwd') and 'response'.

    A copy of the request is sent to TEST_DEST. In the copy, User-Password
    is re-encrypted for the test server and NAS-IP-Address is set to the
    original client's address. A failed send is reported as 'T'. A
    response that does not match the sniffed one is logged in full.
    Every 50 completed comparisons, a '.' is written to stdout.
    """
    global NUM_SUCCESSFUL

    req = req_resp['req']

    client = Client(server=TEST_DEST, secret=TEST_SECRET, dict=DICT)
    fwd_req = client.CreateAuthPacket(code=packet.AccessRequest)
    fwd_req.authenticator = req['pkt'].authenticator

    keys = req['pkt'].keys()
    for k in keys:
        for value in req['pkt'][k]:
            fwd_req.AddAttribute(k, value)
    if 'User-Password' in keys:
        # The sniffed value was hidden with the NAS's secret; hide the
        # cleartext password again for the test server.
        fwd_req['User-Password'] = fwd_req.PwCrypt(req['passwd'])
    if 'NAS-IP-Address' in fwd_req:
        del fwd_req['NAS-IP-Address']
    fwd_req.AddAttribute('NAS-IP-Address', req['ip'])

    try:
        test_reply = client.SendPacket(fwd_req)
    except Exception:
        # Normally the request to the test server timed out. Catching
        # Exception (not everything) lets KeyboardInterrupt/SystemExit
        # propagate.
        sys.stdout.write('T')
        sys.stdout.flush()
        with REQUESTS_LOCK:
            # The entry may already be gone (e.g. pruned); a KeyError here
            # would kill this worker thread.
            REQUESTS.pop(key, None)
        return

    sniffed_reply = req_resp['response']['pkt']
    if not matches(req['pkt'], sniffed_reply, test_reply):
        print('')
        log_mismatch(req['ip'], req['pkt'], req['passwd'],
                     sniffed_reply, test_reply)

    with REQUESTS_LOCK:
        # Occasionally, this key isn't present. Maybe retransmissions
        # due to a short timeout on the remote RADIUS client's end
        # and a subsequent race?
        REQUESTS.pop(key, None)
        # += on a shared global is not atomic across threads.
        NUM_SUCCESSFUL += 1
        show_progress = NUM_SUCCESSFUL % 50 == 0

    if show_progress:
        sys.stdout.write('.')
        sys.stdout.flush()
391
class RadiusComparer(Thread):
    """Worker thread that decodes sniffed RADIUS packets and pairs them up.

    Frames are taken from QUEUE until a falsy item is received. Requests
    and responses are stored in REQUESTS under a shared
    'client-ip:client-port:radius-id' key. Once both halves of an exchange
    have been seen, check_for_match() replays the request to the test
    server and compares the responses.
    """

    def run(self):
        while True:
            data = QUEUE.get()
            if not data:
                # A falsy queue item tells the worker to exit.
                return

            frame = EthDecoder().decode(data)
            ip = frame.child()
            udp = ip.child()
            rad_raw = udp.child().get_buffer_as_string()

            try:
                pkt = packet.Packet(dict=DICT, packet=rad_raw)
            except packet.PacketError:
                sys.stdout.write('I')
                sys.stdout.flush()
                continue

            if ip.get_ip_src() in IGNORE_CLIENTS:
                continue

            if pkt.code == packet.AccessRequest:
                self._on_request(ip, udp, pkt)
            elif pkt.code in (packet.AccessAccept, packet.AccessReject):
                self._on_response(ip, udp, pkt)
            else:
                sys.stderr.write(
                    'Unsupported packet type received: %d\n' % pkt.code)

    def _on_request(self, ip, udp, pkt):
        """Record a sniffed Access-Request, decrypting its User-Password."""
        src = ip.get_ip_src()
        secret = CLIENTS.get(src, None)
        if not secret:
            # No configuration for this client.
            sys.stdout.write('C=%s' % src)
            sys.stdout.flush()
            return

        # Decrypting User-Password needs the client's shared secret and
        # the request authenticator, so set both on a scratch packet.
        auth = packet.AuthPacket()
        auth.authenticator = pkt.authenticator
        auth.secret = secret
        passwd = None
        if 'User-Password' in pkt.keys():
            passwd = auth.PwDecrypt(pkt['User-Password'][0])

        key = '%s:%d:%d' % (src, udp.get_uh_sport(), pkt.id)
        self._store_half(key, 'req', 'response', {
            'ip': src,
            'port': udp.get_uh_sport(),
            'pkt': pkt,
            'passwd': passwd,
        })

    def _on_response(self, ip, udp, pkt):
        """Record a sniffed Access-Accept/Reject, keyed by its destination."""
        # The response's destination is the original request's source.
        key = '%s:%d:%d' % (ip.get_ip_dst(), udp.get_uh_dport(), pkt.id)
        self._store_half(key, 'response', 'req', {
            'ip': ip.get_ip_src(),
            'port': udp.get_uh_sport(),
            'pkt': pkt,
        })

    def _store_half(self, key, side, other_side, info):
        """Store one half of an exchange; compare once both halves exist."""
        do_compare = None
        with REQUESTS_LOCK:
            entry = REQUESTS.setdefault(key, {})
            entry[side] = info
            entry['tstamp'] = time()
            if other_side in entry:
                do_compare = entry

        # Replay outside the lock: it waits on the network.
        if do_compare:
            check_for_match(key, do_compare)
468
class RequestsPruner(Thread):
    """Background thread that expires stale entries in REQUESTS.

    Every 30 seconds, entries whose last update is more than 60 seconds
    old are removed. An expired entry that never saw its request is
    reported as 'Mreq'; one that never saw its response as 'Mresp'.
    """

    def run(self):
        while True:
            sleep(30)

            now = time()
            with REQUESTS_LOCK:
                stale = [key for key in REQUESTS.keys()
                         if REQUESTS[key]['tstamp'] + 60 < now]
                for key in stale:
                    entry = REQUESTS.pop(key)
                    if 'req' not in entry:
                        sys.stdout.write('Mreq')
                        sys.stdout.flush()
                    if 'response' not in entry:
                        sys.stdout.write('Mresp')
                        sys.stdout.flush()
493
def usage():
    """Print command-line usage to stdout."""
    help_lines = [
        'Usage: %s INTERFACE' % os.path.basename(sys.argv[0]),
        '',
        '    -h, --help  display this help and exit',
    ]
    for text in help_lines:
        print(text)
498
if __name__ == '__main__':
    progname = os.path.basename(sys.argv[0])

    try:
        options, iface = gnu_getopt(sys.argv[1:], 'h', ['help'])
    except GetoptError:
        # sys.exc_info() avoids version-specific 'except X, e' syntax.
        e = sys.exc_info()[1]
        print('%s: %s' % (progname, str(e)))
        usage()
        sys.exit(1)

    for option in options:
        if option[0] in ('-h', '--help'):
            usage()
            sys.exit(0)

    if len(iface) != 1:
        usage()
        sys.exit(1)
    iface = iface[0]

    if os.geteuid() != 0:
        sys.stderr.write('%s: must be run as root.\n' % progname)
        sys.exit(1)

    # This is Linux-specific, and there's no tenable way to make
    # it portable.
    #
    # Unfortunately, we need the interface's IP address to filter out
    # only RADIUS traffic destined for this host (avoiding traffic sent
    # *by* this host, such as proxied requests or our own traffic) to
    # avoid replaying requests not directed to the local radiusd.
    #
    # Furthermore, this only obtains the interface's *first* IP address,
    # so we won't notice traffic sent to additional IP addresses on
    # the given interface.
    #
    # This is Good Enough For Me given the effort I care to invest.
    # Of course, patches enhancing this are welcome.
    if os.uname()[0] != 'Linux':
        raise Exception('Only the Linux operating system is currently supported.')

    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        local_ipaddr = socket.inet_ntoa(fcntl.ioctl(
            s.fileno(),
            0x8915,  # SIOCGIFADDR
            struct.pack('256s', iface[:15])
        )[20:24])
    finally:
        s.close()

    p = pcapy.open_live(iface, 1600, 0, 100)
    p.setfilter('''
                (dst host %s and udp and dst port 1812) or
                (src host %s and udp and src port 1812)''' %
                (local_ipaddr, local_ipaddr))

    # Start the workers only once capture is set up. They loop forever, so
    # they are daemon threads: otherwise they would keep the process alive
    # after sys.exit() below.
    for _ in range(THREADS):
        worker = RadiusComparer()
        worker.setDaemon(True)
        worker.start()
    pruner = RequestsPruner()
    pruner.setDaemon(True)
    pruner.start()

    while True:
        try:
            p.dispatch(1, handlePacket)
        except KeyboardInterrupt:
            sys.exit(0)