Updated through tag hostap_2_5 from git://w1.fi/hostap.git
[mech_eap.git] / libeap / tests / test-rsa-sig-ver.c
1 /*
2  * Testing tool for RSA PKCS #1 v1.5 signature verification
3  * Copyright (c) 2014, Jouni Malinen <j@w1.fi>
4  *
5  * This software may be distributed under the terms of the BSD license.
6  * See README for more details.
7  */
8
9 #include "utils/includes.h"
10
11 #include "utils/common.h"
12 #include "crypto/crypto.h"
13 #include "tls/rsa.h"
14 #include "tls/asn1.h"
15 #include "tls/pkcs1.h"
16
17
18 static int cavp_rsa_sig_ver(const char *fname)
19 {
20         FILE *f;
21         int ret = 0;
22         char buf[15000], *pos, *pos2;
23         u8 msg[200], n[512], s[512], em[512], e[512];
24         size_t msg_len = 0, n_len = 0, s_len = 0, em_len, e_len = 0;
25         size_t tmp_len;
26         char sha_alg[20];
27         int ok = 0;
28
29         printf("CAVP RSA SigVer test vectors from %s\n", fname);
30
31         f = fopen(fname, "r");
32         if (f == NULL) {
33                 printf("%s does not exist - cannot validate CAVP RSA SigVer test vectors\n",
34                         fname);
35                 return 0;
36         }
37
38         while (fgets(buf, sizeof(buf), f)) {
39                 pos = os_strchr(buf, '=');
40                 if (pos == NULL)
41                         continue;
42                 pos2 = pos - 1;
43                 while (pos2 >= buf && *pos2 == ' ')
44                         *pos2-- = '\0';
45                 *pos++ = '\0';
46                 while (*pos == ' ')
47                         *pos++ = '\0';
48                 pos2 = os_strchr(pos, '\r');
49                 if (!pos2)
50                         pos2 = os_strchr(pos, '\n');
51                 if (pos2)
52                         *pos2 = '\0';
53                 else
54                         pos2 = pos + os_strlen(pos);
55
56                 if (os_strcmp(buf, "SHAAlg") == 0) {
57                         os_strlcpy(sha_alg, pos, sizeof(sha_alg));
58                 } else if (os_strcmp(buf, "Msg") == 0) {
59                         tmp_len = os_strlen(pos);
60                         if (tmp_len > sizeof(msg) * 2) {
61                                 printf("Too long Msg\n");
62                                 fclose(f);
63                                 return -1;
64                         }
65                         msg_len = tmp_len / 2;
66                         if (hexstr2bin(pos, msg, msg_len) < 0) {
67                                 printf("Invalid hex string '%s'\n", pos);
68                                 ret++;
69                                 break;
70                         }
71                 } else if (os_strcmp(buf, "n") == 0) {
72                         tmp_len = os_strlen(pos);
73                         if (tmp_len > sizeof(n) * 2) {
74                                 printf("Too long n\n");
75                                 fclose(f);
76                                 return -1;
77                         }
78                         n_len = tmp_len / 2;
79                         if (hexstr2bin(pos, n, n_len) < 0) {
80                                 printf("Invalid hex string '%s'\n", pos);
81                                 ret++;
82                                 break;
83                         }
84                 } else if (os_strcmp(buf, "e") == 0) {
85                         tmp_len = os_strlen(pos);
86                         if (tmp_len > sizeof(e) * 2) {
87                                 printf("Too long e\n");
88                                 fclose(f);
89                                 return -1;
90                         }
91                         e_len = tmp_len / 2;
92                         if (hexstr2bin(pos, e, e_len) < 0) {
93                                 printf("Invalid hex string '%s'\n", pos);
94                                 ret++;
95                                 break;
96                         }
97                 } else if (os_strcmp(buf, "S") == 0) {
98                         tmp_len = os_strlen(pos);
99                         if (tmp_len > sizeof(s) * 2) {
100                                 printf("Too long S\n");
101                                 fclose(f);
102                                 return -1;
103                         }
104                         s_len = tmp_len / 2;
105                         if (hexstr2bin(pos, s, s_len) < 0) {
106                                 printf("Invalid hex string '%s'\n", pos);
107                                 ret++;
108                                 break;
109                         }
110                 } else if (os_strncmp(buf, "EM", 2) == 0) {
111                         tmp_len = os_strlen(pos);
112                         if (tmp_len > sizeof(em) * 2) {
113                                 fclose(f);
114                                 return -1;
115                         }
116                         em_len = tmp_len / 2;
117                         if (hexstr2bin(pos, em, em_len) < 0) {
118                                 printf("Invalid hex string '%s'\n", pos);
119                                 ret++;
120                                 break;
121                         }
122                 } else if (os_strcmp(buf, "Result") == 0) {
123                         const u8 *addr[1];
124                         size_t len[1];
125                         struct crypto_public_key *pk;
126                         int res;
127                         u8 hash[32];
128                         size_t hash_len;
129                         const struct asn1_oid *alg;
130
131                         addr[0] = msg;
132                         len[0] = msg_len;
133                         if (os_strcmp(sha_alg, "SHA1") == 0) {
134                                 if (sha1_vector(1, addr, len, hash) < 0) {
135                                         fclose(f);
136                                         return -1;
137                                 }
138                                 hash_len = 20;
139                                 alg = &asn1_sha1_oid;
140                         } else if (os_strcmp(sha_alg, "SHA256") == 0) {
141                                 if (sha256_vector(1, addr, len, hash) < 0) {
142                                         fclose(f);
143                                         return -1;
144                                 }
145                                 hash_len = 32;
146                                 alg = &asn1_sha256_oid;
147                         } else {
148                                 continue;
149                         }
150
151                         printf("\nExpected result: %s\n", pos);
152                         wpa_hexdump(MSG_INFO, "Hash(Msg)", hash, hash_len);
153
154                         pk = crypto_public_key_import_parts(n, n_len,
155                                                             e, e_len);
156                         if (pk == NULL) {
157                                 printf("Failed to import public key\n");
158                                 ret++;
159                                 continue;
160                         }
161
162                         res = pkcs1_v15_sig_ver(pk, s, s_len, alg,
163                                                 hash, hash_len);
164                         crypto_public_key_free(pk);
165                         if ((*pos == 'F' && !res) || (*pos != 'F' && res)) {
166                                 printf("FAIL\n");
167                                 ret++;
168                                 continue;
169                         }
170
171                         printf("PASS\n");
172                         ok++;
173                 }
174         }
175
176         fclose(f);
177
178         if (ret)
179                 printf("Test case failed\n");
180         else
181                 printf("%d test vectors OK\n", ok);
182
183         return ret;
184 }
185
186
187 int main(int argc, char *argv[])
188 {
189         int ret = 0;
190
191         wpa_debug_level = 0;
192
193         if (cavp_rsa_sig_ver("CAVP/SigVer15_186-3.rsp"))
194                 ret++;
195         if (cavp_rsa_sig_ver("CAVP/SigVer15EMTest.txt"))
196                 ret++;
197
198         return ret;
199 }