6d6d641d4194f49d46ab7fd7dd948a52cbf5c29d
[mech_eap.git] / tests / hwsim / hostapd.py
1 # Python class for controlling hostapd
2 # Copyright (c) 2013-2014, Jouni Malinen <j@w1.fi>
3 #
4 # This software may be distributed under the terms of the BSD license.
5 # See README for more details.
6
7 import os
8 import time
9 import logging
10 import binascii
11 import struct
12 import wpaspy
13 import remotehost
14 import utils
15
16 logger = logging.getLogger()
17 hapd_ctrl = '/var/run/hostapd'
18 hapd_global = '/var/run/hostapd-global'
19
20 def mac2tuple(mac):
21     return struct.unpack('6B', binascii.unhexlify(mac.replace(':','')))
22
23 class HostapdGlobal:
24     def __init__(self, apdev=None):
25         try:
26             hostname = apdev['hostname']
27             port = apdev['port']
28         except:
29             hostname = None
30             port = 8878
31         self.host = remotehost.Host(hostname)
32         self.hostname = hostname
33         self.port = port
34         if hostname is None:
35             self.ctrl = wpaspy.Ctrl(hapd_global)
36             self.mon = wpaspy.Ctrl(hapd_global)
37             self.dbg = ""
38         else:
39             self.ctrl = wpaspy.Ctrl(hostname, port)
40             self.mon = wpaspy.Ctrl(hostname, port)
41             self.dbg = hostname + "/" + str(port)
42         self.mon.attach()
43
44     def request(self, cmd, timeout=10):
45         logger.debug(self.dbg + ": CTRL(global): " + cmd)
46         return self.ctrl.request(cmd, timeout)
47
48     def wait_event(self, events, timeout):
49         start = os.times()[4]
50         while True:
51             while self.mon.pending():
52                 ev = self.mon.recv()
53                 logger.debug(self.dbg + "(global): " + ev)
54                 for event in events:
55                     if event in ev:
56                         return ev
57             now = os.times()[4]
58             remaining = start + timeout - now
59             if remaining <= 0:
60                 break
61             if not self.mon.pending(timeout=remaining):
62                 break
63         return None
64
65     def add(self, ifname, driver=None):
66         cmd = "ADD " + ifname + " " + hapd_ctrl
67         if driver:
68             cmd += " " + driver
69         res = self.request(cmd)
70         if not "OK" in res:
71             raise Exception("Could not add hostapd interface " + ifname)
72
73     def add_iface(self, ifname, confname):
74         res = self.request("ADD " + ifname + " config=" + confname)
75         if not "OK" in res:
76             raise Exception("Could not add hostapd interface")
77
78     def add_bss(self, phy, confname, ignore_error=False):
79         res = self.request("ADD bss_config=" + phy + ":" + confname)
80         if not "OK" in res:
81             if not ignore_error:
82                 raise Exception("Could not add hostapd BSS")
83
84     def remove(self, ifname):
85         self.request("REMOVE " + ifname, timeout=30)
86
87     def relog(self):
88         self.request("RELOG")
89
90     def flush(self):
91         self.request("FLUSH")
92
93     def get_ctrl_iface_port(self, ifname):
94         if self.hostname is None:
95             return None
96
97         res = self.request("INTERFACES ctrl")
98         lines = res.splitlines()
99         found = False
100         for line in lines:
101             words = line.split()
102             if words[0] == ifname:
103                 found = True
104                 break
105         if not found:
106             raise Exception("Could not find UDP port for " + ifname)
107         res = line.find("ctrl_iface=udp:")
108         if res == -1:
109             raise Exception("Wrong ctrl_interface format")
110         words = line.split(":")
111         return int(words[1])
112
113     def terminate(self):
114         self.mon.detach()
115         self.mon.close()
116         self.mon = None
117         self.ctrl.terminate()
118         self.ctrl = None
119
120 class Hostapd:
121     def __init__(self, ifname, bssidx=0, hostname=None, port=8877):
122         self.hostname = hostname
123         self.host = remotehost.Host(hostname, ifname)
124         self.ifname = ifname
125         if hostname is None:
126             self.ctrl = wpaspy.Ctrl(os.path.join(hapd_ctrl, ifname))
127             self.mon = wpaspy.Ctrl(os.path.join(hapd_ctrl, ifname))
128             self.dbg = ifname
129         else:
130             self.ctrl = wpaspy.Ctrl(hostname, port)
131             self.mon = wpaspy.Ctrl(hostname, port)
132             self.dbg = hostname + "/" + ifname
133         self.mon.attach()
134         self.bssid = None
135         self.bssidx = bssidx
136
137     def close_ctrl(self):
138         if self.mon is not None:
139             self.mon.detach()
140             self.mon.close()
141             self.mon = None
142             self.ctrl.close()
143             self.ctrl = None
144
145     def own_addr(self):
146         if self.bssid is None:
147             self.bssid = self.get_status_field('bssid[%d]' % self.bssidx)
148         return self.bssid
149
150     def request(self, cmd):
151         logger.debug(self.dbg + ": CTRL: " + cmd)
152         return self.ctrl.request(cmd)
153
154     def ping(self):
155         return "PONG" in self.request("PING")
156
157     def set(self, field, value):
158         if not "OK" in self.request("SET " + field + " " + value):
159             raise Exception("Failed to set hostapd parameter " + field)
160
161     def set_defaults(self):
162         self.set("driver", "nl80211")
163         self.set("hw_mode", "g")
164         self.set("channel", "1")
165         self.set("ieee80211n", "1")
166         self.set("logger_stdout", "-1")
167         self.set("logger_stdout_level", "0")
168
169     def set_open(self, ssid):
170         self.set_defaults()
171         self.set("ssid", ssid)
172
173     def set_wpa2_psk(self, ssid, passphrase):
174         self.set_defaults()
175         self.set("ssid", ssid)
176         self.set("wpa_passphrase", passphrase)
177         self.set("wpa", "2")
178         self.set("wpa_key_mgmt", "WPA-PSK")
179         self.set("rsn_pairwise", "CCMP")
180
181     def set_wpa_psk(self, ssid, passphrase):
182         self.set_defaults()
183         self.set("ssid", ssid)
184         self.set("wpa_passphrase", passphrase)
185         self.set("wpa", "1")
186         self.set("wpa_key_mgmt", "WPA-PSK")
187         self.set("wpa_pairwise", "TKIP")
188
189     def set_wpa_psk_mixed(self, ssid, passphrase):
190         self.set_defaults()
191         self.set("ssid", ssid)
192         self.set("wpa_passphrase", passphrase)
193         self.set("wpa", "3")
194         self.set("wpa_key_mgmt", "WPA-PSK")
195         self.set("wpa_pairwise", "TKIP")
196         self.set("rsn_pairwise", "CCMP")
197
198     def set_wep(self, ssid, key):
199         self.set_defaults()
200         self.set("ssid", ssid)
201         self.set("wep_key0", key)
202
203     def enable(self):
204         if not "OK" in self.request("ENABLE"):
205             raise Exception("Failed to enable hostapd interface " + self.ifname)
206
207     def disable(self):
208         if not "OK" in self.request("DISABLE"):
209             raise Exception("Failed to disable hostapd interface " + self.ifname)
210
211     def dump_monitor(self):
212         while self.mon.pending():
213             ev = self.mon.recv()
214             logger.debug(self.dbg + ": " + ev)
215
216     def wait_event(self, events, timeout):
217         start = os.times()[4]
218         while True:
219             while self.mon.pending():
220                 ev = self.mon.recv()
221                 logger.debug(self.dbg + ": " + ev)
222                 for event in events:
223                     if event in ev:
224                         return ev
225             now = os.times()[4]
226             remaining = start + timeout - now
227             if remaining <= 0:
228                 break
229             if not self.mon.pending(timeout=remaining):
230                 break
231         return None
232
233     def get_status(self):
234         res = self.request("STATUS")
235         lines = res.splitlines()
236         vals = dict()
237         for l in lines:
238             [name,value] = l.split('=', 1)
239             vals[name] = value
240         return vals
241
242     def get_status_field(self, field):
243         vals = self.get_status()
244         if field in vals:
245             return vals[field]
246         return None
247
248     def get_driver_status(self):
249         res = self.request("STATUS-DRIVER")
250         lines = res.splitlines()
251         vals = dict()
252         for l in lines:
253             [name,value] = l.split('=', 1)
254             vals[name] = value
255         return vals
256
257     def get_driver_status_field(self, field):
258         vals = self.get_driver_status()
259         if field in vals:
260             return vals[field]
261         return None
262
263     def get_config(self):
264         res = self.request("GET_CONFIG")
265         lines = res.splitlines()
266         vals = dict()
267         for l in lines:
268             [name,value] = l.split('=', 1)
269             vals[name] = value
270         return vals
271
272     def mgmt_rx(self, timeout=5):
273         ev = self.wait_event(["MGMT-RX"], timeout=timeout)
274         if ev is None:
275             return None
276         msg = {}
277         frame = binascii.unhexlify(ev.split(' ')[1])
278         msg['frame'] = frame
279
280         hdr = struct.unpack('<HH6B6B6BH', frame[0:24])
281         msg['fc'] = hdr[0]
282         msg['subtype'] = (hdr[0] >> 4) & 0xf
283         hdr = hdr[1:]
284         msg['duration'] = hdr[0]
285         hdr = hdr[1:]
286         msg['da'] = "%02x:%02x:%02x:%02x:%02x:%02x" % hdr[0:6]
287         hdr = hdr[6:]
288         msg['sa'] = "%02x:%02x:%02x:%02x:%02x:%02x" % hdr[0:6]
289         hdr = hdr[6:]
290         msg['bssid'] = "%02x:%02x:%02x:%02x:%02x:%02x" % hdr[0:6]
291         hdr = hdr[6:]
292         msg['seq_ctrl'] = hdr[0]
293         msg['payload'] = frame[24:]
294
295         return msg
296
297     def mgmt_tx(self, msg):
298         t = (msg['fc'], 0) + mac2tuple(msg['da']) + mac2tuple(msg['sa']) + mac2tuple(msg['bssid']) + (0,)
299         hdr = struct.pack('<HH6B6B6BH', *t)
300         self.request("MGMT_TX " + binascii.hexlify(hdr + msg['payload']))
301
302     def get_sta(self, addr, info=None, next=False):
303         cmd = "STA-NEXT " if next else "STA "
304         if addr is None:
305             res = self.request("STA-FIRST")
306         elif info:
307             res = self.request(cmd + addr + " " + info)
308         else:
309             res = self.request(cmd + addr)
310         lines = res.splitlines()
311         vals = dict()
312         first = True
313         for l in lines:
314             if first and '=' not in l:
315                 vals['addr'] = l
316                 first = False
317             else:
318                 [name,value] = l.split('=', 1)
319                 vals[name] = value
320         return vals
321
322     def get_mib(self, param=None):
323         if param:
324             res = self.request("MIB " + param)
325         else:
326             res = self.request("MIB")
327         lines = res.splitlines()
328         vals = dict()
329         for l in lines:
330             name_val = l.split('=', 1)
331             if len(name_val) > 1:
332                 vals[name_val[0]] = name_val[1]
333         return vals
334
335     def get_pmksa(self, addr):
336         res = self.request("PMKSA")
337         lines = res.splitlines()
338         for l in lines:
339             if addr not in l:
340                 continue
341             vals = dict()
342             [index,aa,pmkid,expiration,opportunistic] = l.split(' ')
343             vals['index'] = index
344             vals['pmkid'] = pmkid
345             vals['expiration'] = expiration
346             vals['opportunistic'] = opportunistic
347             return vals
348         return None
349
350 def add_ap(apdev, params, wait_enabled=True, no_enable=False, timeout=30):
351         if isinstance(apdev, dict):
352             ifname = apdev['ifname']
353             try:
354                 hostname = apdev['hostname']
355                 port = apdev['port']
356                 logger.info("Starting AP " + hostname + "/" + port + " " + ifname)
357             except:
358                 logger.info("Starting AP " + ifname)
359                 hostname = None
360                 port = 8878
361         else:
362             ifname = apdev
363             logger.info("Starting AP " + ifname + " (old add_ap argument type)")
364             hostname = None
365             port = 8878
366         hapd_global = HostapdGlobal(apdev)
367         hapd_global.remove(ifname)
368         hapd_global.add(ifname)
369         port = hapd_global.get_ctrl_iface_port(ifname)
370         hapd = Hostapd(ifname, hostname=hostname, port=port)
371         if not hapd.ping():
372             raise Exception("Could not ping hostapd")
373         hapd.set_defaults()
374         fields = [ "ssid", "wpa_passphrase", "nas_identifier", "wpa_key_mgmt",
375                    "wpa",
376                    "wpa_pairwise", "rsn_pairwise", "auth_server_addr",
377                    "acct_server_addr", "osu_server_uri" ]
378         for field in fields:
379             if field in params:
380                 hapd.set(field, params[field])
381         for f,v in params.items():
382             if f in fields:
383                 continue
384             if isinstance(v, list):
385                 for val in v:
386                     hapd.set(f, val)
387             else:
388                 hapd.set(f, v)
389         if no_enable:
390             return hapd
391         hapd.enable()
392         if wait_enabled:
393             ev = hapd.wait_event(["AP-ENABLED", "AP-DISABLED"], timeout=timeout)
394             if ev is None:
395                 raise Exception("AP startup timed out")
396             if "AP-ENABLED" not in ev:
397                 raise Exception("AP startup failed")
398         return hapd
399
400 def add_bss(apdev, ifname, confname, ignore_error=False):
401     phy = utils.get_phy(apdev)
402     try:
403         hostname = apdev['hostname']
404         port = apdev['port']
405         logger.info("Starting BSS " + hostname + "/" + port + " phy=" + phy + " ifname=" + ifname)
406     except:
407         logger.info("Starting BSS phy=" + phy + " ifname=" + ifname)
408         hostname = None
409         port = 8878
410     hapd_global = HostapdGlobal(apdev)
411     hapd_global.add_bss(phy, confname, ignore_error)
412     port = hapd_global.get_ctrl_iface_port(ifname)
413     hapd = Hostapd(ifname, hostname=hostname, port=port)
414     if not hapd.ping():
415         raise Exception("Could not ping hostapd")
416     return hapd
417
418 def add_iface(apdev, confname):
419     ifname = apdev['ifname']
420     try:
421         hostname = apdev['hostname']
422         port = apdev['port']
423         logger.info("Starting interface " + hostname + "/" + port + " " + ifname)
424     except:
425         logger.info("Starting interface " + ifname)
426         hostname = None
427         port = 8878
428     hapd_global = HostapdGlobal(apdev)
429     hapd_global.add_iface(ifname, confname)
430     port = hapd_global.get_ctrl_iface_port(ifname)
431     hapd = Hostapd(ifname, hostname=hostname, port=port)
432     if not hapd.ping():
433         raise Exception("Could not ping hostapd")
434     return hapd
435
436 def remove_bss(apdev, ifname=None):
437     if ifname == None:
438         ifname = apdev['ifname']
439     try:
440         hostname = apdev['hostname']
441         port = apdev['port']
442         logger.info("Removing BSS " + hostname + "/" + port + " " + ifname)
443     except:
444         logger.info("Removing BSS " + ifname)
445     hapd_global = HostapdGlobal(apdev)
446     hapd_global.remove(ifname)
447
448 def terminate(apdev):
449     try:
450         hostname = apdev['hostname']
451         port = apdev['port']
452         logger.info("Terminating hostapd " + hostname + "/" + port)
453     except:
454         logger.info("Terminating hostapd")
455     hapd_global = HostapdGlobal(apdev)
456     hapd_global.terminate()
457
458 def wpa2_params(ssid=None, passphrase=None):
459     params = { "wpa": "2",
460                "wpa_key_mgmt": "WPA-PSK",
461                "rsn_pairwise": "CCMP" }
462     if ssid:
463         params["ssid"] = ssid
464     if passphrase:
465         params["wpa_passphrase"] = passphrase
466     return params
467
468 def wpa_params(ssid=None, passphrase=None):
469     params = { "wpa": "1",
470                "wpa_key_mgmt": "WPA-PSK",
471                "wpa_pairwise": "TKIP" }
472     if ssid:
473         params["ssid"] = ssid
474     if passphrase:
475         params["wpa_passphrase"] = passphrase
476     return params
477
478 def wpa_mixed_params(ssid=None, passphrase=None):
479     params = { "wpa": "3",
480                "wpa_key_mgmt": "WPA-PSK",
481                "wpa_pairwise": "TKIP",
482                "rsn_pairwise": "CCMP" }
483     if ssid:
484         params["ssid"] = ssid
485     if passphrase:
486         params["wpa_passphrase"] = passphrase
487     return params
488
489 def radius_params():
490     params = { "auth_server_addr": "127.0.0.1",
491                "auth_server_port": "1812",
492                "auth_server_shared_secret": "radius",
493                "nas_identifier": "nas.w1.fi" }
494     return params
495
496 def wpa_eap_params(ssid=None):
497     params = radius_params()
498     params["wpa"] = "1"
499     params["wpa_key_mgmt"] = "WPA-EAP"
500     params["wpa_pairwise"] = "TKIP"
501     params["ieee8021x"] = "1"
502     if ssid:
503         params["ssid"] = ssid
504     return params
505
506 def wpa2_eap_params(ssid=None):
507     params = radius_params()
508     params["wpa"] = "2"
509     params["wpa_key_mgmt"] = "WPA-EAP"
510     params["rsn_pairwise"] = "CCMP"
511     params["ieee8021x"] = "1"
512     if ssid:
513         params["ssid"] = ssid
514     return params
515
516 def b_only_params(channel="1", ssid=None, country=None):
517     params = { "hw_mode" : "b",
518                "channel" : channel }
519     if ssid:
520         params["ssid"] = ssid
521     if country:
522         params["country_code"] = country
523     return params
524
525 def g_only_params(channel="1", ssid=None, country=None):
526     params = { "hw_mode" : "g",
527                "channel" : channel }
528     if ssid:
529         params["ssid"] = ssid
530     if country:
531         params["country_code"] = country
532     return params
533
534 def a_only_params(channel="36", ssid=None, country=None):
535     params = { "hw_mode" : "a",
536                "channel" : channel }
537     if ssid:
538         params["ssid"] = ssid
539     if country:
540         params["country_code"] = country
541     return params
542
543 def ht20_params(channel="1", ssid=None, country=None):
544     params = { "ieee80211n" : "1",
545                "channel" : channel,
546                "hw_mode" : "g" }
547     if int(channel) > 14:
548         params["hw_mode"] = "a"
549     if ssid:
550         params["ssid"] = ssid
551     if country:
552         params["country_code"] = country
553     return params
554
555 def ht40_plus_params(channel="1", ssid=None, country=None):
556     params = ht20_params(channel, ssid, country)
557     params['ht_capab'] = "[HT40+]"
558     return params
559
560 def ht40_minus_params(channel="1", ssid=None, country=None):
561     params = ht20_params(channel, ssid, country)
562     params['ht_capab'] = "[HT40-]"
563     return params