#include "rsa.h"

#include <gmp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Fixed sizes used by the signature encoding below. */
enum {
    PKCS1_FRAME_BYTES = 3,  /* leading 0x00 0x01 plus the 0x00 separator */
    PKCS1_MIN_PS_BYTES = 8, /* RFC 8017 requires at least eight 0xff bytes */
    SHA256_HASH_BYTES = 32  /* length of the digest passed to sign/verify */
};

/*
 * Overwrite a buffer that held secret material. The volatile stores keep the
 * compiler from eliminating the wipe as a dead store.
 */
static void secure_wipe(void *buf, size_t len)
{
    volatile unsigned char *p = buf;
    while (len--)
        *p++ = 0;
}

/*
 * Fill `prime` with a probable prime (mpz_nextprime is probabilistic) of
 * exactly `size` bytes, seeded from /dev/urandom.
 *
 * The two most significant bits of the seed are forced to 1, so the product
 * of two such primes always has exactly 2*size bytes. If mpz_nextprime steps
 * past 2^(8*size), a fresh seed is drawn.
 *
 * Returns 1 on success, 0 on bad arguments or if /dev/urandom cannot be
 * opened or supplies fewer than `size` bytes.
 */
static int random_prime(mpz_t prime, const size_t size)
{
    if (prime == NULL || size == 0)
        return 0;

    FILE *urandom = fopen("/dev/urandom", "rb");
    if (urandom == NULL)
        return 0;

    u8 tmp[size];
    int ok = 0;
    do {
        if (fread(tmp, 1, size, urandom) != size)
            break; /* short read: never use uninitialised seed bytes */
        tmp[0] |= 0xC0; /* big-endian import: tmp[0] is the top byte */
        mpz_import(prime, size, 1, 1, 1, 0, tmp);
        mpz_nextprime(prime, prime);
        ok = mpz_sizeinbase(prime, 2) <= 8 * size;
    } while (!ok);

    secure_wipe(tmp, size);
    fclose(urandom);
    return ok;
}

/*
 * Draw a `size`-byte prime for which gcd(e, prime - 1) == 1. This guarantees
 * that e is invertible modulo (p-1)(q-1) once both primes satisfy it.
 * Returns 1 on success, 0 if random_prime fails.
 */
static int random_rsa_prime(mpz_t prime, const mpz_t e, const size_t size)
{
    mpz_t prime_minus_1, g;
    mpz_inits(prime_minus_1, g, NULL);

    int ok;
    for (;;) {
        ok = random_prime(prime, size);
        if (!ok)
            break;
        mpz_sub_ui(prime_minus_1, prime, 1);
        mpz_gcd(g, e, prime_minus_1);
        if (mpz_cmp_ui(g, 1) == 0)
            break;
    }

    mpz_clears(prime_minus_1, g, NULL);
    return ok;
}

/*
 * Generate a fresh key: e = 65537, distinct primes p and q of MODULUS_SIZE/2
 * bytes each, n = p*q, and d = e^-1 mod (p-1)(q-1).
 *
 * Whenever `key` is non-NULL, all five bignums are initialised before any
 * failure return. Release them with rsa_free even if this returns 0.
 * Returns 1 on success, 0 on NULL key or random-source failure.
 */
static int rsa_keygen(rsa_key *key)
{
    if (key == NULL)
        return 0;

    mpz_init_set_ui(key->e, 65537);
    mpz_inits(key->p, key->q, key->n, key->d, NULL);

    if (!random_rsa_prime(key->p, key->e, MODULUS_SIZE / 2))
        return 0;
    do {
        if (!random_rsa_prime(key->q, key->e, MODULUS_SIZE / 2))
            return 0;
    } while (mpz_cmp(key->p, key->q) == 0);

    mpz_mul(key->n, key->p, key->q);

    /* phi(n) = (p-1)(q-1), computed in temporaries so p and q stay intact. */
    mpz_t phi_n, q_minus_1;
    mpz_inits(phi_n, q_minus_1, NULL);
    mpz_sub_ui(phi_n, key->p, 1);
    mpz_sub_ui(q_minus_1, key->q, 1);
    mpz_mul(phi_n, phi_n, q_minus_1);

    const int ok = mpz_invert(key->d, key->e, phi_n) != 0;

    mpz_clears(phi_n, q_minus_1, NULL);
    return ok;
}

/*
 * Initialise `key`. Currently this always generates a new key pair; see
 * rsa_keygen for the initialisation and failure contract.
 */
int rsa_init(rsa_key *key)
{
    /* TODO: load the key from sealed storage when available. */
    return rsa_keygen(key);
}

/* Release the bignums of a key initialised by rsa_init. NULL is ignored. */
void rsa_free(rsa_key *key)
{
    if (key == NULL)
        return;
    mpz_clears(key->p, key->q, key->n, key->e, key->d, NULL);
}

/*
 * Build the padded message 0x00 0x01 0xff..0xff 0x00 || data as a
 * MODULUS_SIZE-byte big-endian integer in `message`.
 *
 * NOTE(review): RFC 8017 EMSA-PKCS1-v1_5 also prepends the DER DigestInfo
 * for the hash algorithm before `data`. It is intentionally not added here,
 * to keep the existing signature format. These signatures therefore will not
 * verify with standard PKCS#1 v1.5 implementations.
 *
 * Returns 0 on NULL arguments or if `data` leaves room for fewer than eight
 * 0xff padding bytes; 1 otherwise.
 */
static int pkcs1(mpz_t message, const u8 *data, const size_t length)
{
    if (message == NULL || data == NULL)
        return 0;

    /* Check bounds before subtracting so the size_t arithmetic cannot wrap. */
    if (MODULUS_SIZE < PKCS1_FRAME_BYTES + PKCS1_MIN_PS_BYTES ||
        length > (size_t)(MODULUS_SIZE - PKCS1_FRAME_BYTES - PKCS1_MIN_PS_BYTES))
        return 0;

    const size_t padding_length = (size_t)MODULUS_SIZE - length - PKCS1_FRAME_BYTES;
    u8 padded_bytes[MODULUS_SIZE];

    padded_bytes[0] = 0x00;
    padded_bytes[1] = 0x01;
    memset(padded_bytes + 2, 0xff, padding_length);
    padded_bytes[2 + padding_length] = 0x00;
    memcpy(padded_bytes + PKCS1_FRAME_BYTES + padding_length, data, length);

    mpz_import(message, MODULUS_SIZE, 1, 1, 0, 0, padded_bytes);
    return 1;
}

/*
 * Sign a SHA-256 digest (SHA256_HASH_BYTES bytes at `sha256`) with the
 * private key.
 *
 * On success, exactly MODULUS_SIZE bytes are written to `sig`. The result is
 * big-endian and left-padded with zeros, so `sig` must hold MODULUS_SIZE
 * bytes; this is the length rsa_verify reads.
 *
 * Returns 0 on NULL arguments, if the key is unusable (d <= 0, even n, or n
 * too short for the padded message), or if padding fails. Returns 1 otherwise.
 */
// TODO RSA Blinding
int rsa_sign(u8 *sig, const u8 *sha256, const rsa_key *key)
{
    if (sig == NULL || sha256 == NULL || key == NULL)
        return 0;

    mpz_t message;
    mpz_init(message);

    /* mpz_powm_sec requires exponent > 0 and an odd modulus. The padded
     * message must also be < n, or we would silently sign m mod n. */
    const int ok = pkcs1(message, sha256, SHA256_HASH_BYTES)
                && mpz_sgn(key->d) > 0
                && mpz_odd_p(key->n)
                && mpz_cmp(message, key->n) < 0;

    if (ok) {
        /* Side-channel-resistant variant for the private exponent. */
        mpz_powm_sec(message, message, key->d, key->n);

        /* I2OSP: right-align in a zero-filled MODULUS_SIZE-byte buffer.
         * message < n < 2^(8*MODULUS_SIZE), so count <= MODULUS_SIZE. A zero
         * value makes mpz_export write nothing, which the memset covers. */
        const size_t count = (mpz_sizeinbase(message, 2) + 7) / 8;
        memset(sig, 0, MODULUS_SIZE);
        mpz_export(sig + (MODULUS_SIZE - count), NULL, 1, 1, 0, 0, message);
    }

    mpz_clear(message);
    return ok;
}

/*
 * Verify that `sig` is a signature, produced by rsa_sign, of the SHA-256
 * digest at `sha256` under public key `pk`.
 *
 * At most MODULUS_SIZE leading bytes of `sig` are used.
 * NOTE(review): extra bytes beyond MODULUS_SIZE are ignored, as before.
 * RFC 8017 would reject any length other than the modulus size.
 *
 * Returns 1 if valid. Returns 0 on NULL arguments, a signature
 * representative >= n, a padding failure, or a mismatch.
 */
int rsa_verify(const u8 *sig, const size_t sig_length, u8 *sha256, rsa_public_key *pk)
{
    if (sig == NULL || sha256 == NULL || pk == NULL)
        return 0;

    mpz_t signature, expected;
    mpz_inits(signature, expected, NULL);

    const size_t used = (sig_length < MODULUS_SIZE) ? sig_length : MODULUS_SIZE;
    mpz_import(signature, used, 1, 1, 0, 0, sig);

    int valid = 0;
    /* RSAVP1 range check: without it, s + n would verify as well. It also
     * rejects n == 0, because s >= 0. */
    if (mpz_cmp(signature, pk->n) < 0 &&
        pkcs1(expected, sha256, SHA256_HASH_BYTES)) {
        mpz_powm(signature, signature, pk->e, pk->n);
        valid = mpz_cmp(signature, expected) == 0;
    }

    /* Released on every path, including rejected signatures. */
    mpz_clears(signature, expected, NULL);
    return valid;
}

/*
 * Debug helper: print p, q, n, e and d in decimal, one per line, to stdout.
 * NOTE(review): this prints the private key material (p, q, d).
 */
void rsa_print(rsa_key *key)
{
    if (key == NULL)
        return;
    gmp_printf("%Zu\n", key->p);
    gmp_printf("%Zu\n", key->q);
    gmp_printf("%Zu\n", key->n);
    gmp_printf("%Zu\n", key->e);
    gmp_printf("%Zu\n", key->d);
}