Move RC4 into crypto.h as a replaceable crypto function
[libeap.git] / src / crypto / crypto_libtomcrypt.c
1 /*
2  * WPA Supplicant / Crypto wrapper for LibTomCrypt (for internal TLSv1)
3  * Copyright (c) 2005-2006, Jouni Malinen <j@w1.fi>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  *
9  * Alternatively, this software may be distributed under the terms of BSD
10  * license.
11  *
12  * See README and COPYING for more details.
13  */
14
15 #include "includes.h"
16 #include <tomcrypt.h>
17
18 #include "common.h"
19 #include "crypto.h"
20
21 #ifndef mp_init_multi
22 #define mp_init_multi                ltc_init_multi
23 #define mp_clear_multi               ltc_deinit_multi
24 #define mp_unsigned_bin_size(a)      ltc_mp.unsigned_size(a)
25 #define mp_to_unsigned_bin(a, b)     ltc_mp.unsigned_write(a, b)
26 #define mp_read_unsigned_bin(a, b, c) ltc_mp.unsigned_read(a, b, c)
27 #define mp_exptmod(a,b,c,d)          ltc_mp.exptmod(a,b,c,d)
28 #endif
29
30
31 int md4_vector(size_t num_elem, const u8 *addr[], const size_t *len, u8 *mac)
32 {
33         hash_state md;
34         size_t i;
35
36         md4_init(&md);
37         for (i = 0; i < num_elem; i++)
38                 md4_process(&md, addr[i], len[i]);
39         md4_done(&md, mac);
40         return 0;
41 }
42
43
44 void des_encrypt(const u8 *clear, const u8 *key, u8 *cypher)
45 {
46         u8 pkey[8], next, tmp;
47         int i;
48         symmetric_key skey;
49
50         /* Add parity bits to the key */
51         next = 0;
52         for (i = 0; i < 7; i++) {
53                 tmp = key[i];
54                 pkey[i] = (tmp >> i) | next | 1;
55                 next = tmp << (7 - i);
56         }
57         pkey[i] = next | 1;
58
59         des_setup(pkey, 8, 0, &skey);
60         des_ecb_encrypt(clear, cypher, &skey);
61         des_done(&skey);
62 }
63
64
65 #ifdef EAP_TLS_FUNCS
66 int md5_vector(size_t num_elem, const u8 *addr[], const size_t *len, u8 *mac)
67 {
68         hash_state md;
69         size_t i;
70
71         md5_init(&md);
72         for (i = 0; i < num_elem; i++)
73                 md5_process(&md, addr[i], len[i]);
74         md5_done(&md, mac);
75         return 0;
76 }
77
78
79 int sha1_vector(size_t num_elem, const u8 *addr[], const size_t *len, u8 *mac)
80 {
81         hash_state md;
82         size_t i;
83
84         sha1_init(&md);
85         for (i = 0; i < num_elem; i++)
86                 sha1_process(&md, addr[i], len[i]);
87         sha1_done(&md, mac);
88         return 0;
89 }
90
91
92 void * aes_encrypt_init(const u8 *key, size_t len)
93 {
94         symmetric_key *skey;
95         skey = os_malloc(sizeof(*skey));
96         if (skey == NULL)
97                 return NULL;
98         if (aes_setup(key, len, 0, skey) != CRYPT_OK) {
99                 os_free(skey);
100                 return NULL;
101         }
102         return skey;
103 }
104
105
106 void aes_encrypt(void *ctx, const u8 *plain, u8 *crypt)
107 {
108         symmetric_key *skey = ctx;
109         aes_ecb_encrypt(plain, crypt, skey);
110 }
111
112
113 void aes_encrypt_deinit(void *ctx)
114 {
115         symmetric_key *skey = ctx;
116         aes_done(skey);
117         os_free(skey);
118 }
119
120
121 void * aes_decrypt_init(const u8 *key, size_t len)
122 {
123         symmetric_key *skey;
124         skey = os_malloc(sizeof(*skey));
125         if (skey == NULL)
126                 return NULL;
127         if (aes_setup(key, len, 0, skey) != CRYPT_OK) {
128                 os_free(skey);
129                 return NULL;
130         }
131         return skey;
132 }
133
134
135 void aes_decrypt(void *ctx, const u8 *crypt, u8 *plain)
136 {
137         symmetric_key *skey = ctx;
138         aes_ecb_encrypt(plain, (u8 *) crypt, skey);
139 }
140
141
142 void aes_decrypt_deinit(void *ctx)
143 {
144         symmetric_key *skey = ctx;
145         aes_done(skey);
146         os_free(skey);
147 }
148
149
150 #ifdef CONFIG_TLS_INTERNAL
151
152 struct crypto_hash {
153         enum crypto_hash_alg alg;
154         int error;
155         union {
156                 hash_state md;
157                 hmac_state hmac;
158         } u;
159 };
160
161
162 struct crypto_hash * crypto_hash_init(enum crypto_hash_alg alg, const u8 *key,
163                                       size_t key_len)
164 {
165         struct crypto_hash *ctx;
166
167         ctx = os_zalloc(sizeof(*ctx));
168         if (ctx == NULL)
169                 return NULL;
170
171         ctx->alg = alg;
172
173         switch (alg) {
174         case CRYPTO_HASH_ALG_MD5:
175                 if (md5_init(&ctx->u.md) != CRYPT_OK)
176                         goto fail;
177                 break;
178         case CRYPTO_HASH_ALG_SHA1:
179                 if (sha1_init(&ctx->u.md) != CRYPT_OK)
180                         goto fail;
181                 break;
182         case CRYPTO_HASH_ALG_HMAC_MD5:
183                 if (hmac_init(&ctx->u.hmac, find_hash("md5"), key, key_len) !=
184                     CRYPT_OK)
185                         goto fail;
186                 break;
187         case CRYPTO_HASH_ALG_HMAC_SHA1:
188                 if (hmac_init(&ctx->u.hmac, find_hash("sha1"), key, key_len) !=
189                     CRYPT_OK)
190                         goto fail;
191                 break;
192         default:
193                 goto fail;
194         }
195
196         return ctx;
197
198 fail:
199         os_free(ctx);
200         return NULL;
201 }
202
203 void crypto_hash_update(struct crypto_hash *ctx, const u8 *data, size_t len)
204 {
205         if (ctx == NULL || ctx->error)
206                 return;
207
208         switch (ctx->alg) {
209         case CRYPTO_HASH_ALG_MD5:
210                 ctx->error = md5_process(&ctx->u.md, data, len) != CRYPT_OK;
211                 break;
212         case CRYPTO_HASH_ALG_SHA1:
213                 ctx->error = sha1_process(&ctx->u.md, data, len) != CRYPT_OK;
214                 break;
215         case CRYPTO_HASH_ALG_HMAC_MD5:
216         case CRYPTO_HASH_ALG_HMAC_SHA1:
217                 ctx->error = hmac_process(&ctx->u.hmac, data, len) != CRYPT_OK;
218                 break;
219         }
220 }
221
222
223 int crypto_hash_finish(struct crypto_hash *ctx, u8 *mac, size_t *len)
224 {
225         int ret = 0;
226         unsigned long clen;
227
228         if (ctx == NULL)
229                 return -2;
230
231         if (mac == NULL || len == NULL) {
232                 os_free(ctx);
233                 return 0;
234         }
235
236         if (ctx->error) {
237                 os_free(ctx);
238                 return -2;
239         }
240
241         switch (ctx->alg) {
242         case CRYPTO_HASH_ALG_MD5:
243                 if (*len < 16) {
244                         *len = 16;
245                         os_free(ctx);
246                         return -1;
247                 }
248                 *len = 16;
249                 if (md5_done(&ctx->u.md, mac) != CRYPT_OK)
250                         ret = -2;
251                 break;
252         case CRYPTO_HASH_ALG_SHA1:
253                 if (*len < 20) {
254                         *len = 20;
255                         os_free(ctx);
256                         return -1;
257                 }
258                 *len = 20;
259                 if (sha1_done(&ctx->u.md, mac) != CRYPT_OK)
260                         ret = -2;
261                 break;
262         case CRYPTO_HASH_ALG_HMAC_SHA1:
263                 if (*len < 20) {
264                         *len = 20;
265                         os_free(ctx);
266                         return -1;
267                 }
268                 /* continue */
269         case CRYPTO_HASH_ALG_HMAC_MD5:
270                 if (*len < 16) {
271                         *len = 16;
272                         os_free(ctx);
273                         return -1;
274                 }
275                 clen = *len;
276                 if (hmac_done(&ctx->u.hmac, mac, &clen) != CRYPT_OK) {
277                         os_free(ctx);
278                         return -1;
279                 }
280                 *len = clen;
281                 break;
282         default:
283                 ret = -2;
284                 break;
285         }
286
287         os_free(ctx);
288
289         return ret;
290 }
291
292
293 struct crypto_cipher {
294         int rc4;
295         union {
296                 symmetric_CBC cbc;
297                 struct {
298                         size_t used_bytes;
299                         u8 key[16];
300                         size_t keylen;
301                 } rc4;
302         } u;
303 };
304
305
306 struct crypto_cipher * crypto_cipher_init(enum crypto_cipher_alg alg,
307                                           const u8 *iv, const u8 *key,
308                                           size_t key_len)
309 {       
310         struct crypto_cipher *ctx;
311         int idx, res, rc4 = 0;
312
313         switch (alg) {
314         case CRYPTO_CIPHER_ALG_AES:
315                 idx = find_cipher("aes");
316                 break;
317         case CRYPTO_CIPHER_ALG_3DES:
318                 idx = find_cipher("3des");
319                 break;
320         case CRYPTO_CIPHER_ALG_DES:
321                 idx = find_cipher("des");
322                 break;
323         case CRYPTO_CIPHER_ALG_RC2:
324                 idx = find_cipher("rc2");
325                 break;
326         case CRYPTO_CIPHER_ALG_RC4:
327                 idx = -1;
328                 rc4 = 1;
329                 break;
330         default:
331                 return NULL;
332         }
333
334         ctx = os_zalloc(sizeof(*ctx));
335         if (ctx == NULL)
336                 return NULL;
337
338         if (rc4) {
339                 ctx->rc4 = 1;
340                 if (key_len > sizeof(ctx->u.rc4.key)) {
341                         os_free(ctx);
342                         return NULL;
343                 }
344                 ctx->u.rc4.keylen = key_len;
345                 os_memcpy(ctx->u.rc4.key, key, key_len);
346         } else {
347                 res = cbc_start(idx, iv, key, key_len, 0, &ctx->u.cbc);
348                 if (res != CRYPT_OK) {
349                         wpa_printf(MSG_DEBUG, "LibTomCrypt: Cipher start "
350                                    "failed: %s", error_to_string(res));
351                         os_free(ctx);
352                         return NULL;
353                 }
354         }
355
356         return ctx;
357 }
358
359 int crypto_cipher_encrypt(struct crypto_cipher *ctx, const u8 *plain,
360                           u8 *crypt, size_t len)
361 {
362         int res;
363
364         if (ctx->rc4) {
365                 if (plain != crypt)
366                         os_memcpy(crypt, plain, len);
367                 rc4_skip(ctx->u.rc4.key, ctx->u.rc4.keylen,
368                          ctx->u.rc4.used_bytes, crypt, len);
369                 ctx->u.rc4.used_bytes += len;
370                 return 0;
371         }
372
373         res = cbc_encrypt(plain, crypt, len, &ctx->u.cbc);
374         if (res != CRYPT_OK) {
375                 wpa_printf(MSG_DEBUG, "LibTomCrypt: CBC encryption "
376                            "failed: %s", error_to_string(res));
377                 return -1;
378         }
379         return 0;
380 }
381
382
383 int crypto_cipher_decrypt(struct crypto_cipher *ctx, const u8 *crypt,
384                           u8 *plain, size_t len)
385 {
386         int res;
387
388         if (ctx->rc4) {
389                 if (plain != crypt)
390                         os_memcpy(plain, crypt, len);
391                 rc4_skip(ctx->u.rc4.key, ctx->u.rc4.keylen,
392                          ctx->u.rc4.used_bytes, plain, len);
393                 ctx->u.rc4.used_bytes += len;
394                 return 0;
395         }
396
397         res = cbc_decrypt(crypt, plain, len, &ctx->u.cbc);
398         if (res != CRYPT_OK) {
399                 wpa_printf(MSG_DEBUG, "LibTomCrypt: CBC decryption "
400                            "failed: %s", error_to_string(res));
401                 return -1;
402         }
403
404         return 0;
405 }
406
407
408 void crypto_cipher_deinit(struct crypto_cipher *ctx)
409 {
410         if (!ctx->rc4)
411                 cbc_done(&ctx->u.cbc);
412         os_free(ctx);
413 }
414
415
416 struct crypto_public_key {
417         rsa_key rsa;
418 };
419
420 struct crypto_private_key {
421         rsa_key rsa;
422 };
423
424
425 struct crypto_public_key * crypto_public_key_import(const u8 *key, size_t len)
426 {
427         int res;
428         struct crypto_public_key *pk;
429
430         pk = os_zalloc(sizeof(*pk));
431         if (pk == NULL)
432                 return NULL;
433
434         res = rsa_import(key, len, &pk->rsa);
435         if (res != CRYPT_OK) {
436                 wpa_printf(MSG_ERROR, "LibTomCrypt: Failed to import "
437                            "public key (res=%d '%s')",
438                            res, error_to_string(res));
439                 os_free(pk);
440                 return NULL;
441         }
442
443         if (pk->rsa.type != PK_PUBLIC) {
444                 wpa_printf(MSG_ERROR, "LibTomCrypt: Public key was not of "
445                            "correct type");
446                 rsa_free(&pk->rsa);
447                 os_free(pk);
448                 return NULL;
449         }
450
451         return pk;
452 }
453
454
455 struct crypto_private_key * crypto_private_key_import(const u8 *key,
456                                                       size_t len)
457 {
458         int res;
459         struct crypto_private_key *pk;
460
461         pk = os_zalloc(sizeof(*pk));
462         if (pk == NULL)
463                 return NULL;
464
465         res = rsa_import(key, len, &pk->rsa);
466         if (res != CRYPT_OK) {
467                 wpa_printf(MSG_ERROR, "LibTomCrypt: Failed to import "
468                            "private key (res=%d '%s')",
469                            res, error_to_string(res));
470                 os_free(pk);
471                 return NULL;
472         }
473
474         if (pk->rsa.type != PK_PRIVATE) {
475                 wpa_printf(MSG_ERROR, "LibTomCrypt: Private key was not of "
476                            "correct type");
477                 rsa_free(&pk->rsa);
478                 os_free(pk);
479                 return NULL;
480         }
481
482         return pk;
483 }
484
485
486 struct crypto_public_key * crypto_public_key_from_cert(const u8 *buf,
487                                                        size_t len)
488 {
489         /* No X.509 support in LibTomCrypt */
490         return NULL;
491 }
492
493
494 static int pkcs1_generate_encryption_block(u8 block_type, size_t modlen,
495                                            const u8 *in, size_t inlen,
496                                            u8 *out, size_t *outlen)
497 {
498         size_t ps_len;
499         u8 *pos;
500
501         /*
502          * PKCS #1 v1.5, 8.1:
503          *
504          * EB = 00 || BT || PS || 00 || D
505          * BT = 00 or 01 for private-key operation; 02 for public-key operation
506          * PS = k-3-||D||; at least eight octets
507          * (BT=0: PS=0x00, BT=1: PS=0xff, BT=2: PS=pseudorandom non-zero)
508          * k = length of modulus in octets (modlen)
509          */
510
511         if (modlen < 12 || modlen > *outlen || inlen > modlen - 11) {
512                 wpa_printf(MSG_DEBUG, "PKCS #1: %s - Invalid buffer "
513                            "lengths (modlen=%lu outlen=%lu inlen=%lu)",
514                            __func__, (unsigned long) modlen,
515                            (unsigned long) *outlen,
516                            (unsigned long) inlen);
517                 return -1;
518         }
519
520         pos = out;
521         *pos++ = 0x00;
522         *pos++ = block_type; /* BT */
523         ps_len = modlen - inlen - 3;
524         switch (block_type) {
525         case 0:
526                 os_memset(pos, 0x00, ps_len);
527                 pos += ps_len;
528                 break;
529         case 1:
530                 os_memset(pos, 0xff, ps_len);
531                 pos += ps_len;
532                 break;
533         case 2:
534                 if (os_get_random(pos, ps_len) < 0) {
535                         wpa_printf(MSG_DEBUG, "PKCS #1: %s - Failed to get "
536                                    "random data for PS", __func__);
537                         return -1;
538                 }
539                 while (ps_len--) {
540                         if (*pos == 0x00)
541                                 *pos = 0x01;
542                         pos++;
543                 }
544                 break;
545         default:
546                 wpa_printf(MSG_DEBUG, "PKCS #1: %s - Unsupported block type "
547                            "%d", __func__, block_type);
548                 return -1;
549         }
550         *pos++ = 0x00;
551         os_memcpy(pos, in, inlen); /* D */
552
553         return 0;
554 }
555
556
557 static int crypto_rsa_encrypt_pkcs1(int block_type, rsa_key *key, int key_type,
558                                     const u8 *in, size_t inlen,
559                                     u8 *out, size_t *outlen)
560 {
561         unsigned long len, modlen;
562         int res;
563
564         modlen = mp_unsigned_bin_size(key->N);
565
566         if (pkcs1_generate_encryption_block(block_type, modlen, in, inlen,
567                                             out, outlen) < 0)
568                 return -1;
569
570         len = *outlen;
571         res = rsa_exptmod(out, modlen, out, &len, key_type, key);
572         if (res != CRYPT_OK) {
573                 wpa_printf(MSG_DEBUG, "LibTomCrypt: rsa_exptmod failed: %s",
574                            error_to_string(res));
575                 return -1;
576         }
577         *outlen = len;
578
579         return 0;
580 }
581
582
583 int crypto_public_key_encrypt_pkcs1_v15(struct crypto_public_key *key,
584                                         const u8 *in, size_t inlen,
585                                         u8 *out, size_t *outlen)
586 {
587         return crypto_rsa_encrypt_pkcs1(2, &key->rsa, PK_PUBLIC, in, inlen,
588                                         out, outlen);
589 }
590
591
592 int crypto_private_key_sign_pkcs1(struct crypto_private_key *key,
593                                   const u8 *in, size_t inlen,
594                                   u8 *out, size_t *outlen)
595 {
596         return crypto_rsa_encrypt_pkcs1(1, &key->rsa, PK_PRIVATE, in, inlen,
597                                         out, outlen);
598 }
599
600
601 void crypto_public_key_free(struct crypto_public_key *key)
602 {
603         if (key) {
604                 rsa_free(&key->rsa);
605                 os_free(key);
606         }
607 }
608
609
610 void crypto_private_key_free(struct crypto_private_key *key)
611 {
612         if (key) {
613                 rsa_free(&key->rsa);
614                 os_free(key);
615         }
616 }
617
618
619 int crypto_public_key_decrypt_pkcs1(struct crypto_public_key *key,
620                                     const u8 *crypt, size_t crypt_len,
621                                     u8 *plain, size_t *plain_len)
622 {
623         int res;
624         unsigned long len;
625         u8 *pos;
626
627         len = *plain_len;
628         res = rsa_exptmod(crypt, crypt_len, plain, &len, PK_PUBLIC,
629                           &key->rsa);
630         if (res != CRYPT_OK) {
631                 wpa_printf(MSG_DEBUG, "LibTomCrypt: rsa_exptmod failed: %s",
632                            error_to_string(res));
633                 return -1;
634         }
635
636         /*
637          * PKCS #1 v1.5, 8.1:
638          *
639          * EB = 00 || BT || PS || 00 || D
640          * BT = 01
641          * PS = k-3-||D|| times FF
642          * k = length of modulus in octets
643          */
644
645         if (len < 3 + 8 + 16 /* min hash len */ ||
646             plain[0] != 0x00 || plain[1] != 0x01 || plain[2] != 0xff) {
647                 wpa_printf(MSG_INFO, "LibTomCrypt: Invalid signature EB "
648                            "structure");
649                 return -1;
650         }
651
652         pos = plain + 3;
653         while (pos < plain + len && *pos == 0xff)
654                 pos++;
655         if (pos - plain - 2 < 8) {
656                 /* PKCS #1 v1.5, 8.1: At least eight octets long PS */
657                 wpa_printf(MSG_INFO, "LibTomCrypt: Too short signature "
658                            "padding");
659                 return -1;
660         }
661
662         if (pos + 16 /* min hash len */ >= plain + len || *pos != 0x00) {
663                 wpa_printf(MSG_INFO, "LibTomCrypt: Invalid signature EB "
664                            "structure (2)");
665                 return -1;
666         }
667         pos++;
668         len -= pos - plain;
669
670         /* Strip PKCS #1 header */
671         os_memmove(plain, pos, len);
672         *plain_len = len;
673
674         return 0;
675 }
676
677
678 int crypto_global_init(void)
679 {
680         ltc_mp = tfm_desc;
681         /* TODO: only register algorithms that are really needed */
682         if (register_hash(&md4_desc) < 0 ||
683             register_hash(&md5_desc) < 0 ||
684             register_hash(&sha1_desc) < 0 ||
685             register_cipher(&aes_desc) < 0 ||
686             register_cipher(&des_desc) < 0 ||
687             register_cipher(&des3_desc) < 0) {
688                 wpa_printf(MSG_ERROR, "TLSv1: Failed to register "
689                            "hash/cipher functions");
690                 return -1;
691         }
692
693         return 0;
694 }
695
696
697 void crypto_global_deinit(void)
698 {
699 }
700
701
702 #if defined(EAP_FAST) || defined(EAP_SERVER_FAST)
703
704 int crypto_mod_exp(const u8 *base, size_t base_len,
705                    const u8 *power, size_t power_len,
706                    const u8 *modulus, size_t modulus_len,
707                    u8 *result, size_t *result_len)
708 {
709         void *b, *p, *m, *r;
710
711         if (mp_init_multi(&b, &p, &m, &r, NULL) != CRYPT_OK)
712                 return -1;
713
714         if (mp_read_unsigned_bin(b, (u8 *) base, base_len) != CRYPT_OK ||
715             mp_read_unsigned_bin(p, (u8 *) power, power_len) != CRYPT_OK ||
716             mp_read_unsigned_bin(m, (u8 *) modulus, modulus_len) != CRYPT_OK)
717                 goto fail;
718
719         if (mp_exptmod(b, p, m, r) != CRYPT_OK)
720                 goto fail;
721
722         *result_len = mp_unsigned_bin_size(r);
723         if (mp_to_unsigned_bin(r, result) != CRYPT_OK)
724                 goto fail;
725
726         mp_clear_multi(b, p, m, r, NULL);
727         return 0;
728
729 fail:
730         mp_clear_multi(b, p, m, r, NULL);
731         return -1;
732 }
733
734 #endif /* EAP_FAST || EAP_SERVER_FAST */
735
736 #endif /* CONFIG_TLS_INTERNAL */
737
738 #endif /* EAP_TLS_FUNCS */