d3f5c9f19a59368859d528452d718242a2c1199e
[mech_eap.git] / tests / hwsim / wpasupplicant.py
1 # Python class for controlling wpa_supplicant
2 # Copyright (c) 2013-2014, Jouni Malinen <j@w1.fi>
3 #
4 # This software may be distributed under the terms of the BSD license.
5 # See README for more details.
6
7 import os
8 import time
9 import logging
10 import binascii
11 import re
12 import struct
13 import wpaspy
14 import remotehost
15 import subprocess
16
# Root logger (getLogger() with no name); every helper below logs through it.
logger = logging.getLogger()
# Directory of local per-interface control sockets; joined with an ifname
# to locate an interface's control socket.
wpas_ctrl = '/var/run/wpa_supplicant'
19
20 class WpaSupplicant:
21     def __init__(self, ifname=None, global_iface=None, hostname=None,
22                  port=9877, global_port=9878):
23         self.hostname = hostname
24         self.group_ifname = None
25         self.gctrl_mon = None
26         self.host = remotehost.Host(hostname, ifname)
27         self._group_dbg = None
28         if ifname:
29             self.set_ifname(ifname, hostname, port)
30             res = self.get_driver_status()
31             if 'capa.flags' in res and int(res['capa.flags'], 0) & 0x20000000:
32                 self.p2p_dev_ifname = 'p2p-dev-' + self.ifname
33             else:
34                 self.p2p_dev_ifname = ifname
35         else:
36             self.ifname = None
37
38         self.global_iface = global_iface
39         if global_iface:
40             if hostname != None:
41                 self.global_ctrl = wpaspy.Ctrl(hostname, global_port)
42                 self.global_mon = wpaspy.Ctrl(hostname, global_port)
43                 self.global_dbg = hostname + "/" + str(global_port) + "/"
44             else:
45                 self.global_ctrl = wpaspy.Ctrl(global_iface)
46                 self.global_mon = wpaspy.Ctrl(global_iface)
47                 self.global_dbg = ""
48             self.global_mon.attach()
49         else:
50             self.global_mon = None
51
52     def cmd_execute(self, cmd_array, shell=False):
53         if self.hostname is None:
54             if shell:
55                 cmd = ' '.join(cmd_array)
56             else:
57                 cmd = cmd_array
58             proc = subprocess.Popen(cmd, stderr=subprocess.STDOUT,
59                                     stdout=subprocess.PIPE, shell=shell)
60             out = proc.communicate()[0]
61             ret = proc.returncode
62             return ret, out
63         else:
64             return self.host.execute(cmd_array)
65
66     def terminate(self):
67         if self.global_mon:
68             self.global_mon.detach()
69             self.global_mon = None
70             self.global_ctrl.terminate()
71             self.global_ctrl = None
72
73     def close_ctrl(self):
74         if self.global_mon:
75             self.global_mon.detach()
76             self.global_mon = None
77             self.global_ctrl = None
78         self.remove_ifname()
79
80     def set_ifname(self, ifname, hostname=None, port=9877):
81         self.ifname = ifname
82         if hostname != None:
83             self.ctrl = wpaspy.Ctrl(hostname, port)
84             self.mon = wpaspy.Ctrl(hostname, port)
85             self.host = remotehost.Host(hostname, ifname)
86             self.dbg = hostname + "/" + ifname
87         else:
88             self.ctrl = wpaspy.Ctrl(os.path.join(wpas_ctrl, ifname))
89             self.mon = wpaspy.Ctrl(os.path.join(wpas_ctrl, ifname))
90             self.dbg = ifname
91         self.mon.attach()
92
93     def remove_ifname(self):
94         if self.ifname:
95             self.mon.detach()
96             self.mon = None
97             self.ctrl = None
98             self.ifname = None
99
100     def get_ctrl_iface_port(self, ifname):
101         if self.hostname is None:
102             return None
103
104         res = self.global_request("INTERFACES ctrl")
105         lines = res.splitlines()
106         found = False
107         for line in lines:
108             words = line.split()
109             if words[0] == ifname:
110                 found = True
111                 break
112         if not found:
113             raise Exception("Could not find UDP port for " + ifname)
114         res = line.find("ctrl_iface=udp:")
115         if res == -1:
116             raise Exception("Wrong ctrl_interface format")
117         words = line.split(":")
118         return int(words[1])
119
120     def interface_add(self, ifname, config="", driver="nl80211",
121                       drv_params=None, br_ifname=None, create=False,
122                       set_ifname=True, all_params=False, if_type=None):
123         status, groups = self.host.execute(["id"])
124         if status != 0:
125             group = "admin"
126         group = "admin" if "(admin)" in groups else "adm"
127         cmd = "INTERFACE_ADD " + ifname + "\t" + config + "\t" + driver + "\tDIR=/var/run/wpa_supplicant GROUP=" + group
128         if drv_params:
129             cmd = cmd + '\t' + drv_params
130         if br_ifname:
131             if not drv_params:
132                 cmd += '\t'
133             cmd += '\t' + br_ifname
134         if create:
135             if not br_ifname:
136                 cmd += '\t'
137                 if not drv_params:
138                     cmd += '\t'
139             cmd += '\tcreate'
140             if if_type:
141                 cmd += '\t' + if_type
142         if all_params and not create:
143             if not br_ifname:
144                 cmd += '\t'
145                 if not drv_params:
146                     cmd += '\t'
147             cmd += '\t'
148         if "FAIL" in self.global_request(cmd):
149             raise Exception("Failed to add a dynamic wpa_supplicant interface")
150         if not create and set_ifname:
151             port = self.get_ctrl_iface_port(ifname)
152             self.set_ifname(ifname, self.hostname, port)
153             res = self.get_driver_status()
154             if 'capa.flags' in res and int(res['capa.flags'], 0) & 0x20000000:
155                 self.p2p_dev_ifname = 'p2p-dev-' + self.ifname
156             else:
157                 self.p2p_dev_ifname = ifname
158
159     def interface_remove(self, ifname):
160         self.remove_ifname()
161         self.global_request("INTERFACE_REMOVE " + ifname)
162
163     def request(self, cmd, timeout=10):
164         logger.debug(self.dbg + ": CTRL: " + cmd)
165         return self.ctrl.request(cmd, timeout=timeout)
166
167     def global_request(self, cmd):
168         if self.global_iface is None:
169             return self.request(cmd)
170         else:
171             ifname = self.ifname or self.global_iface
172             logger.debug(self.global_dbg + ifname + ": CTRL(global): " + cmd)
173             return self.global_ctrl.request(cmd)
174
175     @property
176     def group_dbg(self):
177         if self._group_dbg is not None:
178             return self._group_dbg
179         if self.group_ifname is None:
180             raise Exception("Cannot have group_dbg without group_ifname")
181         if self.hostname is None:
182             self._group_dbg = self.group_ifname
183         else:
184             self._group_dbg = self.hostname + "/" + self.group_ifname
185         return self._group_dbg
186
187     def group_request(self, cmd):
188         if self.group_ifname and self.group_ifname != self.ifname:
189             if self.hostname is None:
190                 gctrl = wpaspy.Ctrl(os.path.join(wpas_ctrl, self.group_ifname))
191             else:
192                 port = self.get_ctrl_iface_port(self.group_ifname)
193                 gctrl = wpaspy.Ctrl(self.hostname, port)
194             logger.debug(self.group_dbg + ": CTRL(group): " + cmd)
195             return gctrl.request(cmd)
196         return self.request(cmd)
197
198     def ping(self):
199         return "PONG" in self.request("PING")
200
201     def global_ping(self):
202         return "PONG" in self.global_request("PING")
203
204     def reset(self):
205         self.dump_monitor()
206         res = self.request("FLUSH")
207         if not "OK" in res:
208             logger.info("FLUSH to " + self.ifname + " failed: " + res)
209         self.global_request("REMOVE_NETWORK all")
210         self.global_request("SET p2p_no_group_iface 1")
211         self.global_request("P2P_FLUSH")
212         if self.gctrl_mon:
213             try:
214                 self.gctrl_mon.detach()
215             except:
216                 pass
217             self.gctrl_mon = None
218         self.group_ifname = None
219         self.dump_monitor()
220
221         iter = 0
222         while iter < 60:
223             state1 = self.get_driver_status_field("scan_state")
224             p2pdev = "p2p-dev-" + self.ifname
225             state2 = self.get_driver_status_field("scan_state", ifname=p2pdev)
226             states = str(state1) + " " + str(state2)
227             if "SCAN_STARTED" in states or "SCAN_REQUESTED" in states:
228                 logger.info(self.ifname + ": Waiting for scan operation to complete before continuing")
229                 time.sleep(1)
230             else:
231                 break
232             iter = iter + 1
233         if iter == 60:
234             logger.error(self.ifname + ": Driver scan state did not clear")
235             print "Trying to clear cfg80211/mac80211 scan state"
236             status, buf = self.host.execute(["ifconfig", self.ifname, "down"])
237             if status != 0:
238                 logger.info("ifconfig failed: " + buf)
239                 logger.info(status)
240             status, buf = self.host.execute(["ifconfig", self.ifname, "up"])
241             if status != 0:
242                 logger.info("ifconfig failed: " + buf)
243                 logger.info(status)
244         if iter > 0:
245             # The ongoing scan could have discovered BSSes or P2P peers
246             logger.info("Run FLUSH again since scan was in progress")
247             self.request("FLUSH")
248             self.dump_monitor()
249
250         if not self.ping():
251             logger.info("No PING response from " + self.ifname + " after reset")
252
253     def add_network(self):
254         id = self.request("ADD_NETWORK")
255         if "FAIL" in id:
256             raise Exception("ADD_NETWORK failed")
257         return int(id)
258
259     def remove_network(self, id):
260         id = self.request("REMOVE_NETWORK " + str(id))
261         if "FAIL" in id:
262             raise Exception("REMOVE_NETWORK failed")
263         return None
264
265     def get_network(self, id, field):
266         res = self.request("GET_NETWORK " + str(id) + " " + field)
267         if res == "FAIL\n":
268             return None
269         return res
270
271     def set_network(self, id, field, value):
272         res = self.request("SET_NETWORK " + str(id) + " " + field + " " + value)
273         if "FAIL" in res:
274             raise Exception("SET_NETWORK failed")
275         return None
276
277     def set_network_quoted(self, id, field, value):
278         res = self.request("SET_NETWORK " + str(id) + " " + field + ' "' + value + '"')
279         if "FAIL" in res:
280             raise Exception("SET_NETWORK failed")
281         return None
282
283     def p2pdev_request(self, cmd):
284         return self.global_request("IFNAME=" + self.p2p_dev_ifname + " " + cmd)
285
286     def p2pdev_add_network(self):
287         id = self.p2pdev_request("ADD_NETWORK")
288         if "FAIL" in id:
289             raise Exception("p2pdev ADD_NETWORK failed")
290         return int(id)
291
292     def p2pdev_set_network(self, id, field, value):
293         res = self.p2pdev_request("SET_NETWORK " + str(id) + " " + field + " " + value)
294         if "FAIL" in res:
295             raise Exception("p2pdev SET_NETWORK failed")
296         return None
297
298     def p2pdev_set_network_quoted(self, id, field, value):
299         res = self.p2pdev_request("SET_NETWORK " + str(id) + " " + field + ' "' + value + '"')
300         if "FAIL" in res:
301             raise Exception("p2pdev SET_NETWORK failed")
302         return None
303
304     def list_networks(self, p2p=False):
305         if p2p:
306             res = self.global_request("LIST_NETWORKS")
307         else:
308             res = self.request("LIST_NETWORKS")
309         lines = res.splitlines()
310         networks = []
311         for l in lines:
312             if "network id" in l:
313                 continue
314             [id,ssid,bssid,flags] = l.split('\t')
315             network = {}
316             network['id'] = id
317             network['ssid'] = ssid
318             network['bssid'] = bssid
319             network['flags'] = flags
320             networks.append(network)
321         return networks
322
323     def hs20_enable(self, auto_interworking=False):
324         self.request("SET interworking 1")
325         self.request("SET hs20 1")
326         if auto_interworking:
327             self.request("SET auto_interworking 1")
328         else:
329             self.request("SET auto_interworking 0")
330
331     def interworking_add_network(self, bssid):
332         id = self.request("INTERWORKING_ADD_NETWORK " + bssid)
333         if "FAIL" in id or "OK" in id:
334             raise Exception("INTERWORKING_ADD_NETWORK failed")
335         return int(id)
336
337     def add_cred(self):
338         id = self.request("ADD_CRED")
339         if "FAIL" in id:
340             raise Exception("ADD_CRED failed")
341         return int(id)
342
343     def remove_cred(self, id):
344         id = self.request("REMOVE_CRED " + str(id))
345         if "FAIL" in id:
346             raise Exception("REMOVE_CRED failed")
347         return None
348
349     def set_cred(self, id, field, value):
350         res = self.request("SET_CRED " + str(id) + " " + field + " " + value)
351         if "FAIL" in res:
352             raise Exception("SET_CRED failed")
353         return None
354
355     def set_cred_quoted(self, id, field, value):
356         res = self.request("SET_CRED " + str(id) + " " + field + ' "' + value + '"')
357         if "FAIL" in res:
358             raise Exception("SET_CRED failed")
359         return None
360
361     def get_cred(self, id, field):
362         return self.request("GET_CRED " + str(id) + " " + field)
363
364     def add_cred_values(self, params):
365         id = self.add_cred()
366
367         quoted = [ "realm", "username", "password", "domain", "imsi",
368                    "excluded_ssid", "milenage", "ca_cert", "client_cert",
369                    "private_key", "domain_suffix_match", "provisioning_sp",
370                    "roaming_partner", "phase1", "phase2", "private_key_passwd" ]
371         for field in quoted:
372             if field in params:
373                 self.set_cred_quoted(id, field, params[field])
374
375         not_quoted = [ "eap", "roaming_consortium", "priority",
376                        "required_roaming_consortium", "sp_priority",
377                        "max_bss_load", "update_identifier", "req_conn_capab",
378                        "min_dl_bandwidth_home", "min_ul_bandwidth_home",
379                        "min_dl_bandwidth_roaming", "min_ul_bandwidth_roaming" ]
380         for field in not_quoted:
381             if field in params:
382                 self.set_cred(id, field, params[field])
383
384         return id
385
386     def select_network(self, id, freq=None):
387         if freq:
388             extra = " freq=" + str(freq)
389         else:
390             extra = ""
391         id = self.request("SELECT_NETWORK " + str(id) + extra)
392         if "FAIL" in id:
393             raise Exception("SELECT_NETWORK failed")
394         return None
395
396     def mesh_group_add(self, id):
397         id = self.request("MESH_GROUP_ADD " + str(id))
398         if "FAIL" in id:
399             raise Exception("MESH_GROUP_ADD failed")
400         return None
401
402     def mesh_group_remove(self):
403         id = self.request("MESH_GROUP_REMOVE " + str(self.ifname))
404         if "FAIL" in id:
405             raise Exception("MESH_GROUP_REMOVE failed")
406         return None
407
408     def connect_network(self, id, timeout=None):
409         if timeout is None:
410             timeout = 10 if self.hostname is None else 60
411         self.dump_monitor()
412         self.select_network(id)
413         self.wait_connected(timeout=timeout)
414         self.dump_monitor()
415
416     def get_status(self, extra=None):
417         if extra:
418             extra = "-" + extra
419         else:
420             extra = ""
421         res = self.request("STATUS" + extra)
422         lines = res.splitlines()
423         vals = dict()
424         for l in lines:
425             try:
426                 [name,value] = l.split('=', 1)
427                 vals[name] = value
428             except ValueError, e:
429                 logger.info(self.ifname + ": Ignore unexpected STATUS line: " + l)
430         return vals
431
432     def get_status_field(self, field, extra=None):
433         vals = self.get_status(extra)
434         if field in vals:
435             return vals[field]
436         return None
437
438     def get_group_status(self, extra=None):
439         if extra:
440             extra = "-" + extra
441         else:
442             extra = ""
443         res = self.group_request("STATUS" + extra)
444         lines = res.splitlines()
445         vals = dict()
446         for l in lines:
447             try:
448                 [name,value] = l.split('=', 1)
449             except ValueError:
450                 logger.info(self.ifname + ": Ignore unexpected status line: " + l)
451                 continue
452             vals[name] = value
453         return vals
454
455     def get_group_status_field(self, field, extra=None):
456         vals = self.get_group_status(extra)
457         if field in vals:
458             return vals[field]
459         return None
460
461     def get_driver_status(self, ifname=None):
462         if ifname is None:
463             res = self.request("STATUS-DRIVER")
464         else:
465             res = self.global_request("IFNAME=%s STATUS-DRIVER" % ifname)
466             if res.startswith("FAIL"):
467                 return dict()
468         lines = res.splitlines()
469         vals = dict()
470         for l in lines:
471             try:
472                 [name,value] = l.split('=', 1)
473             except ValueError:
474                 logger.info(self.ifname + ": Ignore unexpected status-driver line: " + l)
475                 continue
476             vals[name] = value
477         return vals
478
479     def get_driver_status_field(self, field, ifname=None):
480         vals = self.get_driver_status(ifname)
481         if field in vals:
482             return vals[field]
483         return None
484
485     def get_mcc(self):
486         mcc = int(self.get_driver_status_field('capa.num_multichan_concurrent'))
487         return 1 if mcc < 2 else mcc
488
489     def get_mib(self):
490         res = self.request("MIB")
491         lines = res.splitlines()
492         vals = dict()
493         for l in lines:
494             try:
495                 [name,value] = l.split('=', 1)
496                 vals[name] = value
497             except ValueError, e:
498                 logger.info(self.ifname + ": Ignore unexpected MIB line: " + l)
499         return vals
500
501     def p2p_dev_addr(self):
502         return self.get_status_field("p2p_device_address")
503
504     def p2p_interface_addr(self):
505         return self.get_group_status_field("address")
506
507     def own_addr(self):
508         try:
509             res = self.p2p_interface_addr()
510         except:
511             res = self.p2p_dev_addr()
512         return res
513
514     def p2p_listen(self):
515         return self.global_request("P2P_LISTEN")
516
517     def p2p_ext_listen(self, period, interval):
518         return self.global_request("P2P_EXT_LISTEN %d %d" % (period, interval))
519
520     def p2p_cancel_ext_listen(self):
521         return self.global_request("P2P_EXT_LISTEN")
522
523     def p2p_find(self, social=False, progressive=False, dev_id=None,
524                  dev_type=None, delay=None, freq=None):
525         cmd = "P2P_FIND"
526         if social:
527             cmd = cmd + " type=social"
528         elif progressive:
529             cmd = cmd + " type=progressive"
530         if dev_id:
531             cmd = cmd + " dev_id=" + dev_id
532         if dev_type:
533             cmd = cmd + " dev_type=" + dev_type
534         if delay:
535             cmd = cmd + " delay=" + str(delay)
536         if freq:
537             cmd = cmd + " freq=" + str(freq)
538         return self.global_request(cmd)
539
540     def p2p_stop_find(self):
541         return self.global_request("P2P_STOP_FIND")
542
543     def wps_read_pin(self):
544         self.pin = self.request("WPS_PIN get").rstrip("\n")
545         if "FAIL" in self.pin:
546             raise Exception("Could not generate PIN")
547         return self.pin
548
549     def peer_known(self, peer, full=True):
550         res = self.global_request("P2P_PEER " + peer)
551         if peer.lower() not in res.lower():
552             return False
553         if not full:
554             return True
555         return "[PROBE_REQ_ONLY]" not in res
556
557     def discover_peer(self, peer, full=True, timeout=15, social=True,
558                       force_find=False, freq=None):
559         logger.info(self.ifname + ": Trying to discover peer " + peer)
560         if not force_find and self.peer_known(peer, full):
561             return True
562         self.p2p_find(social, freq=freq)
563         count = 0
564         while count < timeout * 4:
565             time.sleep(0.25)
566             count = count + 1
567             if self.peer_known(peer, full):
568                 return True
569         return False
570
571     def get_peer(self, peer):
572         res = self.global_request("P2P_PEER " + peer)
573         if peer.lower() not in res.lower():
574             raise Exception("Peer information not available")
575         lines = res.splitlines()
576         vals = dict()
577         for l in lines:
578             if '=' in l:
579                 [name,value] = l.split('=', 1)
580                 vals[name] = value
581         return vals
582
583     def group_form_result(self, ev, expect_failure=False, go_neg_res=None):
584         if expect_failure:
585             if "P2P-GROUP-STARTED" in ev:
586                 raise Exception("Group formation succeeded when expecting failure")
587             exp = r'<.>(P2P-GO-NEG-FAILURE) status=([0-9]*)'
588             s = re.split(exp, ev)
589             if len(s) < 3:
590                 return None
591             res = {}
592             res['result'] = 'go-neg-failed'
593             res['status'] = int(s[2])
594             return res
595
596         if "P2P-GROUP-STARTED" not in ev:
597             raise Exception("No P2P-GROUP-STARTED event seen")
598
599         exp = r'<.>(P2P-GROUP-STARTED) ([^ ]*) ([^ ]*) ssid="(.*)" freq=([0-9]*) ((?:psk=.*)|(?:passphrase=".*")) go_dev_addr=([0-9a-f:]*) ip_addr=([0-9.]*) ip_mask=([0-9.]*) go_ip_addr=([0-9.]*)'
600         s = re.split(exp, ev)
601         if len(s) < 11:
602             exp = r'<.>(P2P-GROUP-STARTED) ([^ ]*) ([^ ]*) ssid="(.*)" freq=([0-9]*) ((?:psk=.*)|(?:passphrase=".*")) go_dev_addr=([0-9a-f:]*)'
603             s = re.split(exp, ev)
604             if len(s) < 8:
605                 raise Exception("Could not parse P2P-GROUP-STARTED")
606         res = {}
607         res['result'] = 'success'
608         res['ifname'] = s[2]
609         self.group_ifname = s[2]
610         try:
611             if self.hostname is None:
612                 self.gctrl_mon = wpaspy.Ctrl(os.path.join(wpas_ctrl,
613                                                           self.group_ifname))
614             else:
615                 port = self.get_ctrl_iface_port(self.group_ifname)
616                 self.gctrl_mon = wpaspy.Ctrl(self.hostname, port)
617             self.gctrl_mon.attach()
618         except:
619             logger.debug("Could not open monitor socket for group interface")
620             self.gctrl_mon = None
621         res['role'] = s[3]
622         res['ssid'] = s[4]
623         res['freq'] = s[5]
624         if "[PERSISTENT]" in ev:
625             res['persistent'] = True
626         else:
627             res['persistent'] = False
628         p = re.match(r'psk=([0-9a-f]*)', s[6])
629         if p:
630             res['psk'] = p.group(1)
631         p = re.match(r'passphrase="(.*)"', s[6])
632         if p:
633             res['passphrase'] = p.group(1)
634         res['go_dev_addr'] = s[7]
635
636         if len(s) > 8 and len(s[8]) > 0:
637             res['ip_addr'] = s[8]
638         if len(s) > 9:
639             res['ip_mask'] = s[9]
640         if len(s) > 10:
641             res['go_ip_addr'] = s[10]
642
643         if go_neg_res:
644             exp = r'<.>(P2P-GO-NEG-SUCCESS) role=(GO|client) freq=([0-9]*)'
645             s = re.split(exp, go_neg_res)
646             if len(s) < 4:
647                 raise Exception("Could not parse P2P-GO-NEG-SUCCESS")
648             res['go_neg_role'] = s[2]
649             res['go_neg_freq'] = s[3]
650
651         return res
652
653     def p2p_go_neg_auth(self, peer, pin, method, go_intent=None,
654                         persistent=False, freq=None, freq2=None,
655                         max_oper_chwidth=None, ht40=False, vht=False):
656         if not self.discover_peer(peer):
657             raise Exception("Peer " + peer + " not found")
658         self.dump_monitor()
659         if pin:
660             cmd = "P2P_CONNECT " + peer + " " + pin + " " + method + " auth"
661         else:
662             cmd = "P2P_CONNECT " + peer + " " + method + " auth"
663         if go_intent:
664             cmd = cmd + ' go_intent=' + str(go_intent)
665         if freq:
666             cmd = cmd + ' freq=' + str(freq)
667         if freq2:
668             cmd = cmd + ' freq2=' + str(freq2)
669         if max_oper_chwidth:
670             cmd = cmd + ' max_oper_chwidth=' + str(max_oper_chwidth)
671         if ht40:
672             cmd = cmd + ' ht40'
673         if vht:
674             cmd = cmd + ' vht'
675         if persistent:
676             cmd = cmd + " persistent"
677         if "OK" in self.global_request(cmd):
678             return None
679         raise Exception("P2P_CONNECT (auth) failed")
680
681     def p2p_go_neg_auth_result(self, timeout=1, expect_failure=False):
682         go_neg_res = None
683         ev = self.wait_global_event(["P2P-GO-NEG-SUCCESS",
684                                      "P2P-GO-NEG-FAILURE"], timeout)
685         if ev is None:
686             if expect_failure:
687                 return None
688             raise Exception("Group formation timed out")
689         if "P2P-GO-NEG-SUCCESS" in ev:
690             go_neg_res = ev
691             ev = self.wait_global_event(["P2P-GROUP-STARTED"], timeout)
692             if ev is None:
693                 if expect_failure:
694                     return None
695                 raise Exception("Group formation timed out")
696         self.dump_monitor()
697         return self.group_form_result(ev, expect_failure, go_neg_res)
698
699     def p2p_go_neg_init(self, peer, pin, method, timeout=0, go_intent=None,
700                         expect_failure=False, persistent=False,
701                         persistent_id=None, freq=None, provdisc=False,
702                         wait_group=True, freq2=None, max_oper_chwidth=None,
703                         ht40=False, vht=False):
704         if not self.discover_peer(peer):
705             raise Exception("Peer " + peer + " not found")
706         self.dump_monitor()
707         if pin:
708             cmd = "P2P_CONNECT " + peer + " " + pin + " " + method
709         else:
710             cmd = "P2P_CONNECT " + peer + " " + method
711         if go_intent is not None:
712             cmd = cmd + ' go_intent=' + str(go_intent)
713         if freq:
714             cmd = cmd + ' freq=' + str(freq)
715         if freq2:
716             cmd = cmd + ' freq2=' + str(freq2)
717         if max_oper_chwidth:
718             cmd = cmd + ' max_oper_chwidth=' + str(max_oper_chwidth)
719         if ht40:
720             cmd = cmd + ' ht40'
721         if vht:
722             cmd = cmd + ' vht'
723         if persistent:
724             cmd = cmd + " persistent"
725         elif persistent_id:
726             cmd = cmd + " persistent=" + persistent_id
727         if provdisc:
728             cmd = cmd + " provdisc"
729         if "OK" in self.global_request(cmd):
730             if timeout == 0:
731                 return None
732             go_neg_res = None
733             ev = self.wait_global_event(["P2P-GO-NEG-SUCCESS",
734                                          "P2P-GO-NEG-FAILURE"], timeout)
735             if ev is None:
736                 if expect_failure:
737                     return None
738                 raise Exception("Group formation timed out")
739             if "P2P-GO-NEG-SUCCESS" in ev:
740                 if not wait_group:
741                     return ev
742                 go_neg_res = ev
743                 ev = self.wait_global_event(["P2P-GROUP-STARTED"], timeout)
744                 if ev is None:
745                     if expect_failure:
746                         return None
747                     raise Exception("Group formation timed out")
748             self.dump_monitor()
749             return self.group_form_result(ev, expect_failure, go_neg_res)
750         raise Exception("P2P_CONNECT failed")
751
752     def wait_event(self, events, timeout=10):
753         start = os.times()[4]
754         while True:
755             while self.mon.pending():
756                 ev = self.mon.recv()
757                 logger.debug(self.dbg + ": " + ev)
758                 for event in events:
759                     if event in ev:
760                         return ev
761             now = os.times()[4]
762             remaining = start + timeout - now
763             if remaining <= 0:
764                 break
765             if not self.mon.pending(timeout=remaining):
766                 break
767         return None
768
769     def wait_global_event(self, events, timeout):
770         if self.global_iface is None:
771             self.wait_event(events, timeout)
772         else:
773             start = os.times()[4]
774             while True:
775                 while self.global_mon.pending():
776                     ev = self.global_mon.recv()
777                     logger.debug(self.global_dbg + self.ifname + "(global): " + ev)
778                     for event in events:
779                         if event in ev:
780                             return ev
781                 now = os.times()[4]
782                 remaining = start + timeout - now
783                 if remaining <= 0:
784                     break
785                 if not self.global_mon.pending(timeout=remaining):
786                     break
787         return None
788
789     def wait_group_event(self, events, timeout=10):
790         if self.group_ifname and self.group_ifname != self.ifname:
791             if self.gctrl_mon is None:
792                 return None
793             start = os.times()[4]
794             while True:
795                 while self.gctrl_mon.pending():
796                     ev = self.gctrl_mon.recv()
797                     logger.debug(self.group_dbg + "(group): " + ev)
798                     for event in events:
799                         if event in ev:
800                             return ev
801                 now = os.times()[4]
802                 remaining = start + timeout - now
803                 if remaining <= 0:
804                     break
805                 if not self.gctrl_mon.pending(timeout=remaining):
806                     break
807             return None
808
809         return self.wait_event(events, timeout)
810
811     def wait_go_ending_session(self):
812         if self.gctrl_mon:
813             try:
814                 self.gctrl_mon.detach()
815             except:
816                 pass
817             self.gctrl_mon = None
818         ev = self.wait_global_event(["P2P-GROUP-REMOVED"], timeout=3)
819         if ev is None:
820             raise Exception("Group removal event timed out")
821         if "reason=GO_ENDING_SESSION" not in ev:
822             raise Exception("Unexpected group removal reason")
823
824     def dump_monitor(self):
825         count_iface = 0
826         count_global = 0
827         while self.mon.pending():
828             ev = self.mon.recv()
829             logger.debug(self.dbg + ": " + ev)
830             count_iface += 1
831         while self.global_mon and self.global_mon.pending():
832             ev = self.global_mon.recv()
833             logger.debug(self.global_dbg + self.ifname + "(global): " + ev)
834             count_global += 1
835         return (count_iface, count_global)
836
837     def remove_group(self, ifname=None):
838         if self.gctrl_mon:
839             try:
840                 self.gctrl_mon.detach()
841             except:
842                 pass
843             self.gctrl_mon = None
844         if ifname is None:
845             ifname = self.group_ifname if self.group_ifname else self.ifname
846         if "OK" not in self.global_request("P2P_GROUP_REMOVE " + ifname):
847             raise Exception("Group could not be removed")
848         self.group_ifname = None
849
850     def p2p_start_go(self, persistent=None, freq=None, no_event_clear=False):
851         self.dump_monitor()
852         cmd = "P2P_GROUP_ADD"
853         if persistent is None:
854             pass
855         elif persistent is True:
856             cmd = cmd + " persistent"
857         else:
858             cmd = cmd + " persistent=" + str(persistent)
859         if freq:
860             cmd = cmd + " freq=" + str(freq)
861         if "OK" in self.global_request(cmd):
862             ev = self.wait_global_event(["P2P-GROUP-STARTED"], timeout=5)
863             if ev is None:
864                 raise Exception("GO start up timed out")
865             if not no_event_clear:
866                 self.dump_monitor()
867             return self.group_form_result(ev)
868         raise Exception("P2P_GROUP_ADD failed")
869
870     def p2p_go_authorize_client(self, pin):
871         cmd = "WPS_PIN any " + pin
872         if "FAIL" in self.group_request(cmd):
873             raise Exception("Failed to authorize client connection on GO")
874         return None
875
876     def p2p_go_authorize_client_pbc(self):
877         cmd = "WPS_PBC"
878         if "FAIL" in self.group_request(cmd):
879             raise Exception("Failed to authorize client connection on GO")
880         return None
881
882     def p2p_connect_group(self, go_addr, pin, timeout=0, social=False,
883                           freq=None):
884         self.dump_monitor()
885         if not self.discover_peer(go_addr, social=social, freq=freq):
886             if social or not self.discover_peer(go_addr, social=social):
887                 raise Exception("GO " + go_addr + " not found")
888         self.p2p_stop_find()
889         self.dump_monitor()
890         cmd = "P2P_CONNECT " + go_addr + " " + pin + " join"
891         if freq:
892             cmd += " freq=" + str(freq)
893         if "OK" in self.global_request(cmd):
894             if timeout == 0:
895                 self.dump_monitor()
896                 return None
897             ev = self.wait_global_event(["P2P-GROUP-STARTED",
898                                          "P2P-GROUP-FORMATION-FAILURE"],
899                                         timeout)
900             if ev is None:
901                 raise Exception("Joining the group timed out")
902             if "P2P-GROUP-STARTED" not in ev:
903                 raise Exception("Failed to join the group")
904             self.dump_monitor()
905             return self.group_form_result(ev)
906         raise Exception("P2P_CONNECT(join) failed")
907
908     def tdls_setup(self, peer):
909         cmd = "TDLS_SETUP " + peer
910         if "FAIL" in self.group_request(cmd):
911             raise Exception("Failed to request TDLS setup")
912         return None
913
914     def tdls_teardown(self, peer):
915         cmd = "TDLS_TEARDOWN " + peer
916         if "FAIL" in self.group_request(cmd):
917             raise Exception("Failed to request TDLS teardown")
918         return None
919
920     def tdls_link_status(self, peer):
921         cmd = "TDLS_LINK_STATUS " + peer
922         ret = self.group_request(cmd)
923         if "FAIL" in ret:
924             raise Exception("Failed to request TDLS link status")
925         return ret
926
927     def tspecs(self):
928         """Return (tsid, up) tuples representing current tspecs"""
929         res = self.request("WMM_AC_STATUS")
930         tspecs = re.findall(r"TSID=(\d+) UP=(\d+)", res)
931         tspecs = [tuple(map(int, tspec)) for tspec in tspecs]
932
933         logger.debug("tspecs: " + str(tspecs))
934         return tspecs
935
936     def add_ts(self, tsid, up, direction="downlink", expect_failure=False,
937                extra=None):
938         params = {
939             "sba": 9000,
940             "nominal_msdu_size": 1500,
941             "min_phy_rate": 6000000,
942             "mean_data_rate": 1500,
943         }
944         cmd = "WMM_AC_ADDTS %s tsid=%d up=%d" % (direction, tsid, up)
945         for (key, value) in params.iteritems():
946             cmd += " %s=%d" % (key, value)
947         if extra:
948             cmd += " " + extra
949
950         if self.request(cmd).strip() != "OK":
951             raise Exception("ADDTS failed (tsid=%d up=%d)" % (tsid, up))
952
953         if expect_failure:
954             ev = self.wait_event(["TSPEC-REQ-FAILED"], timeout=2)
955             if ev is None:
956                 raise Exception("ADDTS failed (time out while waiting failure)")
957             if "tsid=%d" % (tsid) not in ev:
958                 raise Exception("ADDTS failed (invalid tsid in TSPEC-REQ-FAILED")
959             return
960
961         ev = self.wait_event(["TSPEC-ADDED"], timeout=1)
962         if ev is None:
963             raise Exception("ADDTS failed (time out)")
964         if "tsid=%d" % (tsid) not in ev:
965             raise Exception("ADDTS failed (invalid tsid in TSPEC-ADDED)")
966
967         if not (tsid, up) in self.tspecs():
968             raise Exception("ADDTS failed (tsid not in tspec list)")
969
970     def del_ts(self, tsid):
971         if self.request("WMM_AC_DELTS %d" % (tsid)).strip() != "OK":
972             raise Exception("DELTS failed")
973
974         ev = self.wait_event(["TSPEC-REMOVED"], timeout=1)
975         if ev is None:
976             raise Exception("DELTS failed (time out)")
977         if "tsid=%d" % (tsid) not in ev:
978             raise Exception("DELTS failed (invalid tsid in TSPEC-REMOVED)")
979
980         tspecs = [(t, u) for (t, u) in self.tspecs() if t == tsid]
981         if tspecs:
982             raise Exception("DELTS failed (still in tspec list)")
983
984     def connect(self, ssid=None, ssid2=None, **kwargs):
985         logger.info("Connect STA " + self.ifname + " to AP")
986         id = self.add_network()
987         if ssid:
988             self.set_network_quoted(id, "ssid", ssid)
989         elif ssid2:
990             self.set_network(id, "ssid", ssid2)
991
992         quoted = [ "psk", "identity", "anonymous_identity", "password",
993                    "ca_cert", "client_cert", "private_key",
994                    "private_key_passwd", "ca_cert2", "client_cert2",
995                    "private_key2", "phase1", "phase2", "domain_suffix_match",
996                    "altsubject_match", "subject_match", "pac_file", "dh_file",
997                    "bgscan", "ht_mcs", "id_str", "openssl_ciphers",
998                    "domain_match" ]
999         for field in quoted:
1000             if field in kwargs and kwargs[field]:
1001                 self.set_network_quoted(id, field, kwargs[field])
1002
1003         not_quoted = [ "proto", "key_mgmt", "ieee80211w", "pairwise",
1004                        "group", "wep_key0", "wep_key1", "wep_key2", "wep_key3",
1005                        "wep_tx_keyidx", "scan_freq", "freq_list", "eap",
1006                        "eapol_flags", "fragment_size", "scan_ssid", "auth_alg",
1007                        "wpa_ptk_rekey", "disable_ht", "disable_vht", "bssid",
1008                        "disable_max_amsdu", "ampdu_factor", "ampdu_density",
1009                        "disable_ht40", "disable_sgi", "disable_ldpc",
1010                        "ht40_intolerant", "update_identifier", "mac_addr",
1011                        "erp", "bg_scan_period", "bssid_blacklist",
1012                        "bssid_whitelist", "mem_only_psk", "eap_workaround",
1013                        "engine" ]
1014         for field in not_quoted:
1015             if field in kwargs and kwargs[field]:
1016                 self.set_network(id, field, kwargs[field])
1017
1018         if "raw_psk" in kwargs and kwargs['raw_psk']:
1019             self.set_network(id, "psk", kwargs['raw_psk'])
1020         if "password_hex" in kwargs and kwargs['password_hex']:
1021             self.set_network(id, "password", kwargs['password_hex'])
1022         if "peerkey" in kwargs and kwargs['peerkey']:
1023             self.set_network(id, "peerkey", "1")
1024         if "okc" in kwargs and kwargs['okc']:
1025             self.set_network(id, "proactive_key_caching", "1")
1026         if "ocsp" in kwargs and kwargs['ocsp']:
1027             self.set_network(id, "ocsp", str(kwargs['ocsp']))
1028         if "only_add_network" in kwargs and kwargs['only_add_network']:
1029             return id
1030         if "wait_connect" not in kwargs or kwargs['wait_connect']:
1031             if "eap" in kwargs:
1032                 self.connect_network(id, timeout=20)
1033             else:
1034                 self.connect_network(id)
1035         else:
1036             self.dump_monitor()
1037             self.select_network(id)
1038         return id
1039
1040     def scan(self, type=None, freq=None, no_wait=False, only_new=False):
1041         if type:
1042             cmd = "SCAN TYPE=" + type
1043         else:
1044             cmd = "SCAN"
1045         if freq:
1046             cmd = cmd + " freq=" + str(freq)
1047         if only_new:
1048             cmd += " only_new=1"
1049         if not no_wait:
1050             self.dump_monitor()
1051         if not "OK" in self.request(cmd):
1052             raise Exception("Failed to trigger scan")
1053         if no_wait:
1054             return
1055         ev = self.wait_event(["CTRL-EVENT-SCAN-RESULTS"], 15)
1056         if ev is None:
1057             raise Exception("Scan timed out")
1058
1059     def scan_for_bss(self, bssid, freq=None, force_scan=False, only_new=False):
1060         if not force_scan and self.get_bss(bssid) is not None:
1061             return
1062         for i in range(0, 10):
1063             self.scan(freq=freq, type="ONLY", only_new=only_new)
1064             if self.get_bss(bssid) is not None:
1065                 return
1066         raise Exception("Could not find BSS " + bssid + " in scan")
1067
1068     def flush_scan_cache(self, freq=2417):
1069         self.request("BSS_FLUSH 0")
1070         self.scan(freq=freq, only_new=True)
1071         res = self.request("SCAN_RESULTS")
1072         if len(res.splitlines()) > 1:
1073             self.request("BSS_FLUSH 0")
1074             self.scan(freq=2422, only_new=True)
1075             res = self.request("SCAN_RESULTS")
1076             if len(res.splitlines()) > 1:
1077                 logger.info("flush_scan_cache: Could not clear all BSS entries. These remain:\n" + res)
1078
1079     def roam(self, bssid, fail_test=False):
1080         self.dump_monitor()
1081         if "OK" not in self.request("ROAM " + bssid):
1082             raise Exception("ROAM failed")
1083         if fail_test:
1084             ev = self.wait_event(["CTRL-EVENT-CONNECTED"], timeout=1)
1085             if ev is not None:
1086                 raise Exception("Unexpected connection")
1087             self.dump_monitor()
1088             return
1089         self.wait_connected(timeout=10, error="Roaming with the AP timed out")
1090         self.dump_monitor()
1091
1092     def roam_over_ds(self, bssid, fail_test=False):
1093         self.dump_monitor()
1094         if "OK" not in self.request("FT_DS " + bssid):
1095             raise Exception("FT_DS failed")
1096         if fail_test:
1097             ev = self.wait_event(["CTRL-EVENT-CONNECTED"], timeout=1)
1098             if ev is not None:
1099                 raise Exception("Unexpected connection")
1100             self.dump_monitor()
1101             return
1102         self.wait_connected(timeout=10, error="Roaming with the AP timed out")
1103         self.dump_monitor()
1104
1105     def wps_reg(self, bssid, pin, new_ssid=None, key_mgmt=None, cipher=None,
1106                 new_passphrase=None, no_wait=False):
1107         self.dump_monitor()
1108         if new_ssid:
1109             self.request("WPS_REG " + bssid + " " + pin + " " +
1110                          new_ssid.encode("hex") + " " + key_mgmt + " " +
1111                          cipher + " " + new_passphrase.encode("hex"))
1112             if no_wait:
1113                 return
1114             ev = self.wait_event(["WPS-SUCCESS"], timeout=15)
1115         else:
1116             self.request("WPS_REG " + bssid + " " + pin)
1117             if no_wait:
1118                 return
1119             ev = self.wait_event(["WPS-CRED-RECEIVED"], timeout=15)
1120             if ev is None:
1121                 raise Exception("WPS cred timed out")
1122             ev = self.wait_event(["WPS-FAIL"], timeout=15)
1123         if ev is None:
1124             raise Exception("WPS timed out")
1125         self.wait_connected(timeout=15)
1126
1127     def relog(self):
1128         self.global_request("RELOG")
1129
1130     def wait_completed(self, timeout=10):
1131         for i in range(0, timeout * 2):
1132             if self.get_status_field("wpa_state") == "COMPLETED":
1133                 return
1134             time.sleep(0.5)
1135         raise Exception("Timeout while waiting for COMPLETED state")
1136
1137     def get_capability(self, field):
1138         res = self.request("GET_CAPABILITY " + field)
1139         if "FAIL" in res:
1140             return None
1141         return res.split(' ')
1142
1143     def get_bss(self, bssid, ifname=None):
1144         if not ifname or ifname == self.ifname:
1145             res = self.request("BSS " + bssid)
1146         elif ifname == self.group_ifname:
1147             res = self.group_request("BSS " + bssid)
1148         else:
1149             return None
1150
1151         if "FAIL" in res:
1152             return None
1153         lines = res.splitlines()
1154         vals = dict()
1155         for l in lines:
1156             [name,value] = l.split('=', 1)
1157             vals[name] = value
1158         if len(vals) == 0:
1159             return None
1160         return vals
1161
1162     def get_pmksa(self, bssid):
1163         res = self.request("PMKSA")
1164         lines = res.splitlines()
1165         for l in lines:
1166             if bssid not in l:
1167                 continue
1168             vals = dict()
1169             [index,aa,pmkid,expiration,opportunistic] = l.split(' ')
1170             vals['index'] = index
1171             vals['pmkid'] = pmkid
1172             vals['expiration'] = expiration
1173             vals['opportunistic'] = opportunistic
1174             return vals
1175         return None
1176
1177     def get_sta(self, addr, info=None, next=False):
1178         cmd = "STA-NEXT " if next else "STA "
1179         if addr is None:
1180             res = self.request("STA-FIRST")
1181         elif info:
1182             res = self.request(cmd + addr + " " + info)
1183         else:
1184             res = self.request(cmd + addr)
1185         lines = res.splitlines()
1186         vals = dict()
1187         first = True
1188         for l in lines:
1189             if first:
1190                 vals['addr'] = l
1191                 first = False
1192             else:
1193                 [name,value] = l.split('=', 1)
1194                 vals[name] = value
1195         return vals
1196
1197     def mgmt_rx(self, timeout=5):
1198         ev = self.wait_event(["MGMT-RX"], timeout=timeout)
1199         if ev is None:
1200             return None
1201         msg = {}
1202         items = ev.split(' ')
1203         field,val = items[1].split('=')
1204         if field != "freq":
1205             raise Exception("Unexpected MGMT-RX event format: " + ev)
1206         msg['freq'] = val
1207
1208         field,val = items[2].split('=')
1209         if field != "datarate":
1210             raise Exception("Unexpected MGMT-RX event format: " + ev)
1211         msg['datarate'] = val
1212
1213         field,val = items[3].split('=')
1214         if field != "ssi_signal":
1215             raise Exception("Unexpected MGMT-RX event format: " + ev)
1216         msg['ssi_signal'] = val
1217
1218         frame = binascii.unhexlify(items[4])
1219         msg['frame'] = frame
1220
1221         hdr = struct.unpack('<HH6B6B6BH', frame[0:24])
1222         msg['fc'] = hdr[0]
1223         msg['subtype'] = (hdr[0] >> 4) & 0xf
1224         hdr = hdr[1:]
1225         msg['duration'] = hdr[0]
1226         hdr = hdr[1:]
1227         msg['da'] = "%02x:%02x:%02x:%02x:%02x:%02x" % hdr[0:6]
1228         hdr = hdr[6:]
1229         msg['sa'] = "%02x:%02x:%02x:%02x:%02x:%02x" % hdr[0:6]
1230         hdr = hdr[6:]
1231         msg['bssid'] = "%02x:%02x:%02x:%02x:%02x:%02x" % hdr[0:6]
1232         hdr = hdr[6:]
1233         msg['seq_ctrl'] = hdr[0]
1234         msg['payload'] = frame[24:]
1235
1236         return msg
1237
1238     def wait_connected(self, timeout=10, error="Connection timed out"):
1239         ev = self.wait_event(["CTRL-EVENT-CONNECTED"], timeout=timeout)
1240         if ev is None:
1241             raise Exception(error)
1242         return ev
1243
1244     def wait_disconnected(self, timeout=None, error="Disconnection timed out"):
1245         if timeout is None:
1246             timeout = 10 if self.hostname is None else 30
1247         ev = self.wait_event(["CTRL-EVENT-DISCONNECTED"], timeout=timeout)
1248         if ev is None:
1249             raise Exception(error)
1250         return ev
1251
1252     def get_group_ifname(self):
1253         return self.group_ifname if self.group_ifname else self.ifname
1254
1255     def get_config(self):
1256         res = self.request("DUMP")
1257         if res.startswith("FAIL"):
1258             raise Exception("DUMP failed")
1259         lines = res.splitlines()
1260         vals = dict()
1261         for l in lines:
1262             [name,value] = l.split('=', 1)
1263             vals[name] = value
1264         return vals
1265
1266     def asp_provision(self, peer, adv_id, adv_mac, session_id, session_mac,
1267                       method="1000", info="", status=None, cpt=None, role=None):
1268         if status is None:
1269             cmd = "P2P_ASP_PROVISION"
1270             params = "info='%s' method=%s" % (info, method)
1271         else:
1272             cmd = "P2P_ASP_PROVISION_RESP"
1273             params = "status=%d" % status
1274
1275         if role is not None:
1276             params += " role=" + role
1277         if cpt is not None:
1278             params += " cpt=" + cpt
1279
1280         if "OK" not in self.global_request("%s %s adv_id=%s adv_mac=%s session=%d session_mac=%s %s" %
1281                                            (cmd, peer, adv_id, adv_mac, session_id, session_mac, params)):
1282             raise Exception("%s request failed" % cmd)