a6d4cad7cd87148958375423a6c54598086b0e44
[mech_eap.git] / tests / hwsim / hostapd.py
1 # Python class for controlling hostapd
2 # Copyright (c) 2013-2014, Jouni Malinen <j@w1.fi>
3 #
4 # This software may be distributed under the terms of the BSD license.
5 # See README for more details.
6
7 import os
8 import time
9 import logging
10 import binascii
11 import struct
12 import wpaspy
13 import remotehost
14 import utils
15
16 logger = logging.getLogger()
17 hapd_ctrl = '/var/run/hostapd'
18 hapd_global = '/var/run/hostapd-global'
19
20 def mac2tuple(mac):
21     return struct.unpack('6B', binascii.unhexlify(mac.replace(':','')))
22
23 class HostapdGlobal:
24     def __init__(self, apdev=None):
25         try:
26             hostname = apdev['hostname']
27             port = apdev['port']
28         except:
29             hostname = None
30             port = 8878
31         self.host = remotehost.Host(hostname)
32         self.hostname = hostname
33         self.port = port
34         if hostname is None:
35             self.ctrl = wpaspy.Ctrl(hapd_global)
36             self.mon = wpaspy.Ctrl(hapd_global)
37             self.dbg = ""
38         else:
39             self.ctrl = wpaspy.Ctrl(hostname, port)
40             self.mon = wpaspy.Ctrl(hostname, port)
41             self.dbg = hostname + "/" + str(port)
42         self.mon.attach()
43
44     def request(self, cmd, timeout=10):
45         logger.debug(self.dbg + ": CTRL(global): " + cmd)
46         return self.ctrl.request(cmd, timeout)
47
48     def wait_event(self, events, timeout):
49         start = os.times()[4]
50         while True:
51             while self.mon.pending():
52                 ev = self.mon.recv()
53                 logger.debug(self.dbg + "(global): " + ev)
54                 for event in events:
55                     if event in ev:
56                         return ev
57             now = os.times()[4]
58             remaining = start + timeout - now
59             if remaining <= 0:
60                 break
61             if not self.mon.pending(timeout=remaining):
62                 break
63         return None
64
65     def add(self, ifname, driver=None):
66         cmd = "ADD " + ifname + " " + hapd_ctrl
67         if driver:
68             cmd += " " + driver
69         res = self.request(cmd)
70         if not "OK" in res:
71             raise Exception("Could not add hostapd interface " + ifname)
72
73     def add_iface(self, ifname, confname):
74         res = self.request("ADD " + ifname + " config=" + confname)
75         if not "OK" in res:
76             raise Exception("Could not add hostapd interface")
77
78     def add_bss(self, phy, confname, ignore_error=False):
79         res = self.request("ADD bss_config=" + phy + ":" + confname)
80         if not "OK" in res:
81             if not ignore_error:
82                 raise Exception("Could not add hostapd BSS")
83
84     def remove(self, ifname):
85         self.request("REMOVE " + ifname, timeout=30)
86
87     def relog(self):
88         self.request("RELOG")
89
90     def flush(self):
91         self.request("FLUSH")
92
93     def get_ctrl_iface_port(self, ifname):
94         if self.hostname is None:
95             return None
96
97         res = self.request("INTERFACES ctrl")
98         lines = res.splitlines()
99         found = False
100         for line in lines:
101             words = line.split()
102             if words[0] == ifname:
103                 found = True
104                 break
105         if not found:
106             raise Exception("Could not find UDP port for " + ifname)
107         res = line.find("ctrl_iface=udp:")
108         if res == -1:
109             raise Exception("Wrong ctrl_interface format")
110         words = line.split(":")
111         return int(words[1])
112
113     def terminate(self):
114         self.mon.detach()
115         self.mon.close()
116         self.mon = None
117         self.ctrl.terminate()
118         self.ctrl = None
119
120 class Hostapd:
121     def __init__(self, ifname, bssidx=0, hostname=None, port=8877):
122         self.host = remotehost.Host(hostname, ifname)
123         self.ifname = ifname
124         if hostname is None:
125             self.ctrl = wpaspy.Ctrl(os.path.join(hapd_ctrl, ifname))
126             self.mon = wpaspy.Ctrl(os.path.join(hapd_ctrl, ifname))
127             self.dbg = ifname
128         else:
129             self.ctrl = wpaspy.Ctrl(hostname, port)
130             self.mon = wpaspy.Ctrl(hostname, port)
131             self.dbg = hostname + "/" + ifname
132         self.mon.attach()
133         self.bssid = None
134         self.bssidx = bssidx
135
136     def close_ctrl(self):
137         if self.mon is not None:
138             self.mon.detach()
139             self.mon.close()
140             self.mon = None
141             self.ctrl.close()
142             self.ctrl = None
143
144     def own_addr(self):
145         if self.bssid is None:
146             self.bssid = self.get_status_field('bssid[%d]' % self.bssidx)
147         return self.bssid
148
149     def request(self, cmd):
150         logger.debug(self.dbg + ": CTRL: " + cmd)
151         return self.ctrl.request(cmd)
152
153     def ping(self):
154         return "PONG" in self.request("PING")
155
156     def set(self, field, value):
157         if not "OK" in self.request("SET " + field + " " + value):
158             raise Exception("Failed to set hostapd parameter " + field)
159
160     def set_defaults(self):
161         self.set("driver", "nl80211")
162         self.set("hw_mode", "g")
163         self.set("channel", "1")
164         self.set("ieee80211n", "1")
165         self.set("logger_stdout", "-1")
166         self.set("logger_stdout_level", "0")
167
168     def set_open(self, ssid):
169         self.set_defaults()
170         self.set("ssid", ssid)
171
172     def set_wpa2_psk(self, ssid, passphrase):
173         self.set_defaults()
174         self.set("ssid", ssid)
175         self.set("wpa_passphrase", passphrase)
176         self.set("wpa", "2")
177         self.set("wpa_key_mgmt", "WPA-PSK")
178         self.set("rsn_pairwise", "CCMP")
179
180     def set_wpa_psk(self, ssid, passphrase):
181         self.set_defaults()
182         self.set("ssid", ssid)
183         self.set("wpa_passphrase", passphrase)
184         self.set("wpa", "1")
185         self.set("wpa_key_mgmt", "WPA-PSK")
186         self.set("wpa_pairwise", "TKIP")
187
188     def set_wpa_psk_mixed(self, ssid, passphrase):
189         self.set_defaults()
190         self.set("ssid", ssid)
191         self.set("wpa_passphrase", passphrase)
192         self.set("wpa", "3")
193         self.set("wpa_key_mgmt", "WPA-PSK")
194         self.set("wpa_pairwise", "TKIP")
195         self.set("rsn_pairwise", "CCMP")
196
197     def set_wep(self, ssid, key):
198         self.set_defaults()
199         self.set("ssid", ssid)
200         self.set("wep_key0", key)
201
202     def enable(self):
203         if not "OK" in self.request("ENABLE"):
204             raise Exception("Failed to enable hostapd interface " + self.ifname)
205
206     def disable(self):
207         if not "OK" in self.request("DISABLE"):
208             raise Exception("Failed to disable hostapd interface " + self.ifname)
209
210     def dump_monitor(self):
211         while self.mon.pending():
212             ev = self.mon.recv()
213             logger.debug(self.dbg + ": " + ev)
214
215     def wait_event(self, events, timeout):
216         start = os.times()[4]
217         while True:
218             while self.mon.pending():
219                 ev = self.mon.recv()
220                 logger.debug(self.dbg + ": " + ev)
221                 for event in events:
222                     if event in ev:
223                         return ev
224             now = os.times()[4]
225             remaining = start + timeout - now
226             if remaining <= 0:
227                 break
228             if not self.mon.pending(timeout=remaining):
229                 break
230         return None
231
232     def get_status(self):
233         res = self.request("STATUS")
234         lines = res.splitlines()
235         vals = dict()
236         for l in lines:
237             [name,value] = l.split('=', 1)
238             vals[name] = value
239         return vals
240
241     def get_status_field(self, field):
242         vals = self.get_status()
243         if field in vals:
244             return vals[field]
245         return None
246
247     def get_driver_status(self):
248         res = self.request("STATUS-DRIVER")
249         lines = res.splitlines()
250         vals = dict()
251         for l in lines:
252             [name,value] = l.split('=', 1)
253             vals[name] = value
254         return vals
255
256     def get_driver_status_field(self, field):
257         vals = self.get_driver_status()
258         if field in vals:
259             return vals[field]
260         return None
261
262     def get_config(self):
263         res = self.request("GET_CONFIG")
264         lines = res.splitlines()
265         vals = dict()
266         for l in lines:
267             [name,value] = l.split('=', 1)
268             vals[name] = value
269         return vals
270
271     def mgmt_rx(self, timeout=5):
272         ev = self.wait_event(["MGMT-RX"], timeout=timeout)
273         if ev is None:
274             return None
275         msg = {}
276         frame = binascii.unhexlify(ev.split(' ')[1])
277         msg['frame'] = frame
278
279         hdr = struct.unpack('<HH6B6B6BH', frame[0:24])
280         msg['fc'] = hdr[0]
281         msg['subtype'] = (hdr[0] >> 4) & 0xf
282         hdr = hdr[1:]
283         msg['duration'] = hdr[0]
284         hdr = hdr[1:]
285         msg['da'] = "%02x:%02x:%02x:%02x:%02x:%02x" % hdr[0:6]
286         hdr = hdr[6:]
287         msg['sa'] = "%02x:%02x:%02x:%02x:%02x:%02x" % hdr[0:6]
288         hdr = hdr[6:]
289         msg['bssid'] = "%02x:%02x:%02x:%02x:%02x:%02x" % hdr[0:6]
290         hdr = hdr[6:]
291         msg['seq_ctrl'] = hdr[0]
292         msg['payload'] = frame[24:]
293
294         return msg
295
296     def mgmt_tx(self, msg):
297         t = (msg['fc'], 0) + mac2tuple(msg['da']) + mac2tuple(msg['sa']) + mac2tuple(msg['bssid']) + (0,)
298         hdr = struct.pack('<HH6B6B6BH', *t)
299         self.request("MGMT_TX " + binascii.hexlify(hdr + msg['payload']))
300
301     def get_sta(self, addr, info=None, next=False):
302         cmd = "STA-NEXT " if next else "STA "
303         if addr is None:
304             res = self.request("STA-FIRST")
305         elif info:
306             res = self.request(cmd + addr + " " + info)
307         else:
308             res = self.request(cmd + addr)
309         lines = res.splitlines()
310         vals = dict()
311         first = True
312         for l in lines:
313             if first and '=' not in l:
314                 vals['addr'] = l
315                 first = False
316             else:
317                 [name,value] = l.split('=', 1)
318                 vals[name] = value
319         return vals
320
321     def get_mib(self, param=None):
322         if param:
323             res = self.request("MIB " + param)
324         else:
325             res = self.request("MIB")
326         lines = res.splitlines()
327         vals = dict()
328         for l in lines:
329             name_val = l.split('=', 1)
330             if len(name_val) > 1:
331                 vals[name_val[0]] = name_val[1]
332         return vals
333
334     def get_pmksa(self, addr):
335         res = self.request("PMKSA")
336         lines = res.splitlines()
337         for l in lines:
338             if addr not in l:
339                 continue
340             vals = dict()
341             [index,aa,pmkid,expiration,opportunistic] = l.split(' ')
342             vals['index'] = index
343             vals['pmkid'] = pmkid
344             vals['expiration'] = expiration
345             vals['opportunistic'] = opportunistic
346             return vals
347         return None
348
349 def add_ap(apdev, params, wait_enabled=True, no_enable=False, timeout=30):
350         if isinstance(apdev, dict):
351             ifname = apdev['ifname']
352             try:
353                 hostname = apdev['hostname']
354                 port = apdev['port']
355                 logger.info("Starting AP " + hostname + "/" + port + " " + ifname)
356             except:
357                 logger.info("Starting AP " + ifname)
358                 hostname = None
359                 port = 8878
360         else:
361             ifname = apdev
362             logger.info("Starting AP " + ifname + " (old add_ap argument type)")
363             hostname = None
364             port = 8878
365         hapd_global = HostapdGlobal(apdev)
366         hapd_global.remove(ifname)
367         hapd_global.add(ifname)
368         port = hapd_global.get_ctrl_iface_port(ifname)
369         hapd = Hostapd(ifname, hostname=hostname, port=port)
370         if not hapd.ping():
371             raise Exception("Could not ping hostapd")
372         hapd.set_defaults()
373         fields = [ "ssid", "wpa_passphrase", "nas_identifier", "wpa_key_mgmt",
374                    "wpa",
375                    "wpa_pairwise", "rsn_pairwise", "auth_server_addr",
376                    "acct_server_addr", "osu_server_uri" ]
377         for field in fields:
378             if field in params:
379                 hapd.set(field, params[field])
380         for f,v in params.items():
381             if f in fields:
382                 continue
383             if isinstance(v, list):
384                 for val in v:
385                     hapd.set(f, val)
386             else:
387                 hapd.set(f, v)
388         if no_enable:
389             return hapd
390         hapd.enable()
391         if wait_enabled:
392             ev = hapd.wait_event(["AP-ENABLED", "AP-DISABLED"], timeout=timeout)
393             if ev is None:
394                 raise Exception("AP startup timed out")
395             if "AP-ENABLED" not in ev:
396                 raise Exception("AP startup failed")
397         return hapd
398
399 def add_bss(apdev, ifname, confname, ignore_error=False):
400     phy = utils.get_phy(apdev)
401     try:
402         hostname = apdev['hostname']
403         port = apdev['port']
404         logger.info("Starting BSS " + hostname + "/" + port + " phy=" + phy + " ifname=" + ifname)
405     except:
406         logger.info("Starting BSS phy=" + phy + " ifname=" + ifname)
407         hostname = None
408         port = 8878
409     hapd_global = HostapdGlobal(apdev)
410     hapd_global.add_bss(phy, confname, ignore_error)
411     port = hapd_global.get_ctrl_iface_port(ifname)
412     hapd = Hostapd(ifname, hostname=hostname, port=port)
413     if not hapd.ping():
414         raise Exception("Could not ping hostapd")
415     return hapd
416
417 def add_iface(apdev, confname):
418     ifname = apdev['ifname']
419     try:
420         hostname = apdev['hostname']
421         port = apdev['port']
422         logger.info("Starting interface " + hostname + "/" + port + " " + ifname)
423     except:
424         logger.info("Starting interface " + ifname)
425         hostname = None
426         port = 8878
427     hapd_global = HostapdGlobal(apdev)
428     hapd_global.add_iface(ifname, confname)
429     port = hapd_global.get_ctrl_iface_port(ifname)
430     hapd = Hostapd(ifname, hostname=hostname, port=port)
431     if not hapd.ping():
432         raise Exception("Could not ping hostapd")
433     return hapd
434
435 def remove_bss(apdev, ifname=None):
436     if ifname == None:
437         ifname = apdev['ifname']
438     try:
439         hostname = apdev['hostname']
440         port = apdev['port']
441         logger.info("Removing BSS " + hostname + "/" + port + " " + ifname)
442     except:
443         logger.info("Removing BSS " + ifname)
444     hapd_global = HostapdGlobal(apdev)
445     hapd_global.remove(ifname)
446
447 def terminate(apdev):
448     try:
449         hostname = apdev['hostname']
450         port = apdev['port']
451         logger.info("Terminating hostapd " + hostname + "/" + port)
452     except:
453         logger.info("Terminating hostapd")
454     hapd_global = HostapdGlobal(apdev)
455     hapd_global.terminate()
456
457 def wpa2_params(ssid=None, passphrase=None):
458     params = { "wpa": "2",
459                "wpa_key_mgmt": "WPA-PSK",
460                "rsn_pairwise": "CCMP" }
461     if ssid:
462         params["ssid"] = ssid
463     if passphrase:
464         params["wpa_passphrase"] = passphrase
465     return params
466
467 def wpa_params(ssid=None, passphrase=None):
468     params = { "wpa": "1",
469                "wpa_key_mgmt": "WPA-PSK",
470                "wpa_pairwise": "TKIP" }
471     if ssid:
472         params["ssid"] = ssid
473     if passphrase:
474         params["wpa_passphrase"] = passphrase
475     return params
476
477 def wpa_mixed_params(ssid=None, passphrase=None):
478     params = { "wpa": "3",
479                "wpa_key_mgmt": "WPA-PSK",
480                "wpa_pairwise": "TKIP",
481                "rsn_pairwise": "CCMP" }
482     if ssid:
483         params["ssid"] = ssid
484     if passphrase:
485         params["wpa_passphrase"] = passphrase
486     return params
487
488 def radius_params():
489     params = { "auth_server_addr": "127.0.0.1",
490                "auth_server_port": "1812",
491                "auth_server_shared_secret": "radius",
492                "nas_identifier": "nas.w1.fi" }
493     return params
494
495 def wpa_eap_params(ssid=None):
496     params = radius_params()
497     params["wpa"] = "1"
498     params["wpa_key_mgmt"] = "WPA-EAP"
499     params["wpa_pairwise"] = "TKIP"
500     params["ieee8021x"] = "1"
501     if ssid:
502         params["ssid"] = ssid
503     return params
504
505 def wpa2_eap_params(ssid=None):
506     params = radius_params()
507     params["wpa"] = "2"
508     params["wpa_key_mgmt"] = "WPA-EAP"
509     params["rsn_pairwise"] = "CCMP"
510     params["ieee8021x"] = "1"
511     if ssid:
512         params["ssid"] = ssid
513     return params
514
515 def b_only_params(channel="1", ssid=None, country=None):
516     params = { "hw_mode" : "b",
517                "channel" : channel }
518     if ssid:
519         params["ssid"] = ssid
520     if country:
521         params["country_code"] = country
522     return params
523
524 def g_only_params(channel="1", ssid=None, country=None):
525     params = { "hw_mode" : "g",
526                "channel" : channel }
527     if ssid:
528         params["ssid"] = ssid
529     if country:
530         params["country_code"] = country
531     return params
532
533 def a_only_params(channel="36", ssid=None, country=None):
534     params = { "hw_mode" : "a",
535                "channel" : channel }
536     if ssid:
537         params["ssid"] = ssid
538     if country:
539         params["country_code"] = country
540     return params
541
542 def ht20_params(channel="1", ssid=None, country=None):
543     params = { "ieee80211n" : "1",
544                "channel" : channel,
545                "hw_mode" : "g" }
546     if int(channel) > 14:
547         params["hw_mode"] = "a"
548     if ssid:
549         params["ssid"] = ssid
550     if country:
551         params["country_code"] = country
552     return params
553
554 def ht40_plus_params(channel="1", ssid=None, country=None):
555     params = ht20_params(channel, ssid, country)
556     params['ht_capab'] = "[HT40+]"
557     return params
558
559 def ht40_minus_params(channel="1", ssid=None, country=None):
560     params = ht20_params(channel, ssid, country)
561     params['ht_capab'] = "[HT40-]"
562     return params