Added radtee
authorAlan T. DeKok <aland@freeradius.org>
Fri, 18 Feb 2011 09:37:17 +0000 (10:37 +0100)
committerAlan T. DeKok <aland@freeradius.org>
Fri, 18 Feb 2011 12:48:43 +0000 (13:48 +0100)
Shamelessly taken from http://horde.net/~jwm/software/misc/comparison-tee

scripts/radtee [new file with mode: 0755]

diff --git a/scripts/radtee b/scripts/radtee
new file mode 100755 (executable)
index 0000000..123769d
--- /dev/null
@@ -0,0 +1,563 @@
+#!/usr/bin/env python
+from __future__ import with_statement 
+
+# RADIUS comparison tee v1.0
+# Sniffs local RADIUS traffic, replays incoming requests to a test
+# server, and compares the sniffed responses with the responses
+# generated by the test server.
+#
+# Copyright (c) 2009, Frontier Communications
+# Copyright (c) 2010, John Morrissey <jwm@horde.net>
+#
+# This program is free software; you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by the Free
+# Software Foundation; either version 2 of the License, or (at your option)
+# any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+# for more details.
+#
+# You should have received a copy of the GNU General Public License along
+# with this program; if not, write to the Free Software Foundation, Inc., 59
+# Temple Place, Suite 330, Boston, MA 02111-1307, USA.
+
+# Requires
+# ========
+# - python 2.4 or newer
+# - impacket
+# - pcapy
+# - pyrad, ideally 1.2 or newer
+
+# Output
+# ======
+# - .: 50 successful, matching responses processed.
+# - C=x.x.x.x: Ignored packet sniffed from unknown client.
+# - D: Dropped sniffed packet due to processing bottleneck. Consider
+#      increasing THREADS.
+# - I: Invalid/unparseable packet sniffed.
+# - Mreq: Response was sniffed without sniffing a corresponding request.
+# - Mresp: Request was sniffed without sniffing a corresponding response.
+# - T: Request to test server timed out.
+
+import fcntl
+from getopt import gnu_getopt, GetoptError
+import os
+import Queue
+import re
+import signal
+import socket
+import struct
+import sys
+import thread
+from threading import Thread
+from time import sleep, time
+
+from impacket.ImpactDecoder import EthDecoder
+import pcapy
+from pyrad.client import Client
+from pyrad.dictionary import Dictionary
+from pyrad import packet
+
+
+TEST_DEST = 'server.example.com'
+TEST_SECRET = 'examplesecret'
+
+# Dictionary to use when decoding RADIUS packets. pyrad earlier than
+# v1.2 can't parse $INCLUDE directives, so you must combine FreeRADIUS'
+# dictionary manually, with something like this:
+#
+# import re
+# import sys
+# 
+# def combine(file):
+#     for line in open(file):
+#         matches = re.search(r'^\$INCLUDE\s+(.*)$', line)
+#         if not matches:
+#             sys.stdout.write(line)
+#             continue
+# 
+#         combine(matches.group(1))
+# 
+# combine('/etc/freeradius/dictionary')
+DICTIONARY = '/etc/freeradius/dictionary'
+
+# Number of worker threads to run.
+THREADS = 32
+
+# Mapping of RADIUS request source addresses to shared secrets,
+# so we can decode incoming RADIUS requests.
+#
+# For example:
+#     '127.0.0.1': 'test',
+CLIENTS = {
+}
+
+# Ignore any sniffed requests from these IP addresses.
+IGNORE_CLIENTS = [
+]
+
+# Expected mismatches to ignore and consider the packet matching.
+# Only the differences are compared to these items, so only the
+# differing attrs need be listed in the attrs array.
+#
+# Examples:
+# - Ignore mismatched AccessRejects whose sole difference is a
+#   Reply-Message attribute with the values given.
+#   {
+#       'sniffed': {
+#           'code': packet.AccessReject,
+#           'attrs': [
+#               'Reply-Message=Request Denied',
+#           ],
+#       },
+#       'test': {
+#           'code': packet.AccessReject,
+#           'attrs': [
+#               'Reply-Message=Account is disabled.',
+#           ],
+#       }
+#   },
+#
+# - Ignore mismatched AccessRejects with Reply-Message=Request Denied
+#   and arbitrary Cisco dns-servers in the sniffed packet, and
+#   no Reply-Message and Cisco-AVPair attrs in the response from the
+#   test RADIUS server.
+#    {
+#        'sniffed': {
+#            'code': packet.AccessReject,
+#            'attrs': [
+#                'Reply-Message=Request Denied',
+#                'regex:^Cisco-AVPair=ip:dns-servers=.*$',
+#            ],
+#        },
+#        'test': {
+#            'code': packet.AccessReject,
+#            'attrs': [
+#            ],
+#        }
+#    },
+#
+# - Only apply this stanza to sniffed requests with
+#   'User-Name= user@example.com' (note the leading whitespace).
+#    {
+#        'check': [
+#            'User-Name= user@example.com',
+#        ],
+#        'sniffed': {
+#            'code': packet.AccessReject,
+#            'attrs': [
+#                'Reply-Message=Request Denied',
+#            ],
+#        },
+#        'test': {
+#            'code': packet.AccessAccept,
+#            'attrs': [
+#                'Service-Type=Framed-User',
+#                'Framed-Protocol=PPP',
+#                'Framed-IP-Address=255.255.255.255',
+#                'Framed-MTU=1500',
+#                'Framed-Compression=Van-Jacobson-TCP-IP',
+#            ],
+#        }
+#    },
+IGNORE = [
+]
+
+
+QUEUE = Queue.Queue(maxsize=25000)
+DICT = Dictionary(DICTIONARY)
+
+def code2str(code):
+       if code == packet.AccessRequest:
+               return "Access-Request"
+       elif code == packet.AccessAccept:
+               return "Access-Accept"
+       elif code == packet.AccessReject:
+               return "Access-Reject"
+       elif code == packet.AccountingRequest:
+               return "Accounting-Request"
+       elif code == packet.AccountingResponse:
+               return "Accounting-Response"
+       elif code == packet.AccessChallenge:
+               return "Access-Challenge"
+       elif code == packet.StatusServer:
+               return "Status-Server"
+       elif code == packet.StatusClient:
+               return "Status-Client"
+       elif code == packet.DisconnectRequest:
+               return "Disconnect-Request"
+       elif code == packet.DisconnectACK:
+               return "Disconnect-ACK"
+       elif code == packet.DisconnectNAK:
+               return "Disconnect-NAK"
+       elif code == packet.CoARequest:
+               return "CoA-Request"
+       elif code == packet.CoAACK:
+               return "CoA-ACK"
+       elif code == packet.CoANAK:
+               return "CoA-NAK"
+
+def handlePacket(header, data):
+       """Place captured packets in the queue to be picked up
+       by worker threads."""
+
+       global QUEUE
+
+       try:
+               QUEUE.put_nowait(data)
+       except Queue.Full:
+               sys.stdout.write('D')
+               sys.stdout.flush()
+
+def ignore_applies(pkt, ignore):
+       """Determine whether an ignore stanza (based on its check
+       items) applies to a packet."""
+
+       # All check items must match for this ignore stanza to apply.
+       stanza_applies = True
+       for pair in ignore.get('check', []):
+               attr, value = pair.split('=')
+
+               if attr not in pkt:
+                       return False
+               if value.startswith('regex:'):
+                       if not re.search(value.replace('regex:', '', 1), value):
+                               return False
+               elif pkt[attr] != value:
+                       return False
+
+       return True
+
+def ignores_match(pkt, mismatched, ignore):
+       """Determine whether mismatched AV pairs remain after accounting
+       for ignored differences."""
+
+       non_regex_ignore = [
+               q
+               for q
+                in ignore['attrs']
+                if not q.startswith('regex:')
+       ]
+       regex_ignore = [
+               q
+               for q
+                in ignore['attrs']
+                if q.startswith('regex:')
+       ]
+
+       unmatched_av = mismatched[:]
+       unmatched_rules = ignore['attrs'][:]
+       for av in mismatched:
+               if av in non_regex_ignore:
+                       unmatched_av.remove(av)
+                       unmatched_rules.remove(av)
+                       continue
+               for regex in regex_ignore:
+                       if re.search(regex.replace('regex:', '', 1), av):
+                               unmatched_av.remove(av)
+                               if regex in unmatched_rules:
+                                       unmatched_rules.remove(regex)
+                               break
+
+       if unmatched_av or unmatched_rules:
+               return False
+       return True
+
+def matches(req, sniffed_pkt, test_pkt):
+       """Determine whether a response from the test server matches
+       the response sniffed from the wire, accounting for ignored
+       differences."""
+
+       global IGNORE
+
+       mis_attrs_sniffed = []
+       for k in sniffed_pkt.keys():
+               if sorted(sniffed_pkt[k]) == sorted(test_pkt.get(k, [])):
+                       continue
+               mis_attrs_sniffed.append('%s=%s' % (
+                       k, ', '.join([str(v) for v in sorted(sniffed_pkt[k])])))
+
+       mis_attrs_test = []
+       for k in test_pkt.keys():
+               if sorted(test_pkt[k]) == sorted(sniffed_pkt.get(k, [])):
+                       continue
+               mis_attrs_test.append('%s=%s' % (
+                       k, ', '.join([str(v) for v in sorted(test_pkt[k])])))
+
+       # The packets match without having to consider any ignores.
+       if sniffed_pkt.code == test_pkt.code and \
+          not mis_attrs_sniffed and not mis_attrs_test:
+               return True
+
+       for ignore in IGNORE:
+               if not ignore_applies(req, ignore):
+                       continue
+
+               if ignore['sniffed']['code'] != sniffed_pkt.code or \
+                  ignore['test']['code'] != test_pkt.code:
+                       continue
+
+               if ignores_match(sniffed_pkt, mis_attrs_sniffed, i['sniffed']):
+                       return True
+               if ignores_match(test_pkt, mis_attrs_test, i['test']):
+                       return True
+
+       return False
+
+def log_mismatch(nas, req, passwd, expected, got):
+       """Emit notification that the test server has returned a response
+       that differs from the sniffed response."""
+
+       print 'Mismatch: %s' % nas
+
+       print 'Request: %s' % code2str(req.code)
+       for key in req.keys():
+               if key == 'User-Password':
+                       print '\t%s: %s' % (key, passwd)
+                       continue
+               print '\t%s: %s' % (
+                       key, ', '.join([str(v) for v in req[key]]))
+
+       print 'Expected: %s' % code2str(expected.code)
+       for key in expected.keys():
+               print '\t%s: %s' % (
+                       key, ', '.join([str(v) for v in expected[key]]))
+
+       print 'Got: %s' % code2str(got.code)
+       for key in got.keys():
+               print '\t%s: %s' % (
+                       key, ', '.join([str(v) for v in got[key]]))
+
+       print
+
+REQUESTS = {}
+REQUESTS_LOCK = thread.allocate_lock()
+NUM_SUCCESSFUL = 0
+def check_for_match(key, req_resp):
+       """Send a copy of the original request to the test server and
+       determine whether the response matches the response sniffed from
+       the wire."""
+
+       global DICT, NUM_SUCCESSFUL, TEST_DEST, TEST_SECRET
+       global REQUESTS, REQUESTS_LOCK
+
+       client = Client(server=TEST_DEST,
+               secret=TEST_SECRET, dict=DICT)
+       fwd_req = client.CreateAuthPacket(code=packet.AccessRequest)
+       fwd_req.authenticator = req_resp['req']['pkt'].authenticator
+
+       keys = req_resp['req']['pkt'].keys()
+       for k in keys:
+               for value in req_resp['req']['pkt'][k]:
+                       fwd_req.AddAttribute(k, value)
+       if 'User-Password' in keys:
+               fwd_req['User-Password'] = fwd_req.PwCrypt(req_resp['req']['passwd'])
+       if 'NAS-IP-Address' in fwd_req:
+               del fwd_req['NAS-IP-Address']
+       fwd_req.AddAttribute('NAS-IP-Address', req_resp['req']['ip'])
+
+       try:
+               test_reply = client.SendPacket(fwd_req)
+       except:
+               # Request to test server timed out.
+               sys.stdout.write('T')
+               sys.stdout.flush()
+               with REQUESTS_LOCK:
+                       del REQUESTS[key]
+               return
+
+       if not matches(req_resp['req']['pkt'],
+               req_resp['response']['pkt'], test_reply):
+
+               print
+               log_mismatch(req_resp['req']['ip'],
+                       req_resp['req']['pkt'],
+                       req_resp['req']['passwd'],
+                       req_resp['response']['pkt'], test_reply)
+
+       with REQUESTS_LOCK:
+               # Occasionally, this key isn't present. Maybe retransmissions
+               # due to a short timeout on the remote RADIUS client's end
+               # and a subsequent race?
+               if key in REQUESTS:
+                       del REQUESTS[key]
+
+       NUM_SUCCESSFUL += 1
+       if NUM_SUCCESSFUL % 50 == 0:
+               sys.stdout.write('.')
+               sys.stdout.flush()
+
+class RadiusComparer(Thread):
+       def run(self):
+               global DICT, IGNORE_CLIENTS, QUEUE, REQUESTS, REQUESTS_LOCK
+
+               while True:
+                       data = QUEUE.get()
+                       if not data:
+                               return
+
+                       frame = EthDecoder().decode(data)
+                       ip = frame.child()
+                       udp = ip.child()
+                       rad_raw = udp.child().get_buffer_as_string()
+
+                       try:
+                               pkt = packet.Packet(dict=DICT, packet=rad_raw)
+                       except packet.PacketError:
+                               sys.stdout.write('I')
+                               sys.stdout.flush()
+                               continue
+
+                       if ip.get_ip_src() in IGNORE_CLIENTS:
+                               continue
+
+                       if pkt.code == packet.AccessRequest:
+                               auth = packet.AuthPacket(data[42:])
+                               auth.authenticator = pkt.authenticator
+                               auth.secret = clients.CLIENTS.get(ip.get_ip_src(), None)
+                               if not auth.secret:
+                                       # No configuration for this client.
+                                       sys.stdout.write('C=%s' % ip.get_ip_src())
+                                       sys.stdout.flush()
+                                       continue
+                               passwd = None
+                               if 'User-Password' in pkt.keys():
+                                       passwd = auth.PwDecrypt(pkt['User-Password'][0])
+
+                               key = '%s:%d:%d' % (ip.get_ip_src(),
+                                       udp.get_uh_sport(), pkt.id)
+                               do_compare = None
+                               with REQUESTS_LOCK:
+                                       if key not in REQUESTS:
+                                               REQUESTS[key] = {}
+                                       REQUESTS[key]['req'] = {
+                                               'ip': ip.get_ip_src(),
+                                               'port': udp.get_uh_sport(),
+                                               'pkt': pkt,
+                                               'passwd': passwd,
+                                       }
+                                       REQUESTS[key]['tstamp'] = time()
+                                       if 'response' in REQUESTS[key]:
+                                               do_compare = REQUESTS[key]
+
+                               if do_compare:
+                                       check_for_match(key, do_compare)
+                       elif pkt.code in [packet.AccessAccept, packet.AccessReject]:
+                               key = '%s:%d:%d' % (ip.get_ip_dst(),
+                                       udp.get_uh_dport(), pkt.id)
+                               do_compare = None
+                               with REQUESTS_LOCK:
+                                       if key not in REQUESTS:
+                                               REQUESTS[key] = {}
+                                       REQUESTS[key]['response'] = {
+                                               'ip': ip.get_ip_src(),
+                                               'port': udp.get_uh_sport(),
+                                               'pkt': pkt,
+                                       }
+                                       REQUESTS[key]['tstamp'] = time()
+                                       if 'req' in REQUESTS[key]:
+                                               do_compare = REQUESTS[key]
+
+                               if do_compare:
+                                       check_for_match(key, do_compare)
+                       else:
+                               print >>sys.stderr, \
+                                       'Unsupported packet type received: %d' % pkt.code
+
+class RequestsPruner(Thread):
+       """Prune stale request state periodically."""
+
+       def run(self):
+               global REQUESTS, REQUESTS_LOCK
+
+               while True:
+                       sleep(30)
+
+                       now = time()
+                       with REQUESTS_LOCK:
+                               keys = REQUESTS.keys()
+                               for key in keys:
+                                       if REQUESTS[key]['tstamp'] + 60 >= now:
+                                               continue
+
+                                       if 'req' not in REQUESTS[key]:
+                                               sys.stdout.write('Mreq')
+                                               sys.stdout.flush()
+                                       if 'response' not in REQUESTS[key]:
+                                               sys.stdout.write('Mresp')
+                                               sys.stdout.flush()
+
+                                       del REQUESTS[key]
+
+def usage():
+       print 'Usage: %s INTERFACE' % os.path.basename(sys.argv[0])
+       print ''
+       print '    -h, --help  display this help and exit'
+
+if __name__ == '__main__':
+       global PID_FILE
+
+       progname = os.path.basename(sys.argv[0])
+
+       try:
+               options, iface = gnu_getopt(sys.argv[1:], 'h', ['help'])
+       except GetoptError, e:
+               print '%s: %s' % (progname, str(e))
+               usage()
+               sys.exit(1)
+
+       for option in options:
+               if option[0] == '-h' or option[0] == '--help':
+                       usage()
+                       sys.exit(0)
+
+       if len(iface) != 1:
+               usage()
+               sys.exit(1)
+       iface = iface[0]
+
+       if os.geteuid() != 0:
+               print >>sys.stderr, '%s: must be run as root.' % progname
+               sys.exit(1)
+
+       for i in range(0, THREADS):
+               RadiusComparer().start()
+       RequestsPruner().start()
+
+       s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+
+       # This is Linux-specific, and there's no tenable way to make
+       # it portable.
+       # 
+       # Unfortunately, we need the interface's IP address to filter out
+       # only RADIUS traffic destined for this host (avoiding traffic sent
+       # *by* this host, such as proxied requests or our own traffic) to
+       # avoid replaying requests not directed to the local radiusd.
+       #
+       # Furthermore, this only obtains the interface's *first* IP address,
+       # so we won't notice traffic sent to additional IP addresses on
+       # the given interface.
+       #
+       # This is Good Enough For Me given the effort I care to invest.
+       # Of course, patches enhancing this are welcome.
+       if os.uname()[0] == 'Linux':
+               local_ipaddr = socket.inet_ntoa(fcntl.ioctl(
+                       s.fileno(),
+                       0x8915,  # SIOCGIFADDR
+                       struct.pack('256s', iface[:15])
+               )[20:24])
+       else:
+               raise Exception('Only the Linux operating system is currently supported.')
+
+       p = pcapy.open_live(iface, 1600, 0, 100)
+       p.setfilter('''
+               (dst host %s and udp and dst port 1812) or
+               (src host %s and udp and src port 1812)''' % \
+               (local_ipaddr, local_ipaddr))
+       while True:
+               try:
+                       p.dispatch(1, handlePacket)
+               except KeyboardInterrupt:
+                       sys.exit(0)