diff --git a/bn_fast_mp_invmod.c b/bn_fast_mp_invmod.c index 08389dd..cabed0c 100644 --- a/bn_fast_mp_invmod.c +++ b/bn_fast_mp_invmod.c @@ -46,6 +46,12 @@ int fast_mp_invmod(const mp_int *a, const mp_int *b, mp_int *c) goto LBL_ERR; } + /* if one of x,y is zero return an error! */ + if ((mp_iszero(&x) == MP_YES) || (mp_iszero(&y) == MP_YES)) { + res = MP_VAL; + goto LBL_ERR; + } + /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */ if ((res = mp_copy(&x, &u)) != MP_OKAY) { goto LBL_ERR; diff --git a/bn_mp_invmod.c b/bn_mp_invmod.c index 525493a..528b0c7 100644 --- a/bn_mp_invmod.c +++ b/bn_mp_invmod.c @@ -18,14 +18,14 @@ /* hac 14.61, pp608 */ int mp_invmod(const mp_int *a, const mp_int *b, mp_int *c) { - /* b cannot be negative */ - if ((b->sign == MP_NEG) || (mp_iszero(b) == MP_YES)) { + /* b cannot be negative and has to be >1 */ + if ((b->sign == MP_NEG) || (mp_cmp_d(b, 1) != MP_GT)) { return MP_VAL; } #ifdef BN_FAST_MP_INVMOD_C /* if the modulus is odd we can use a faster routine instead */ - if ((mp_isodd(b) == MP_YES) && (mp_cmp_d(b, 1) != MP_EQ)) { + if ((mp_isodd(b) == MP_YES)) { return fast_mp_invmod(a, b, c); } #endif