diff --git a/bn.c b/bn.c index 6c4c646..429e1c8 100644 --- a/bn.c +++ b/bn.c @@ -23,6 +23,55 @@ static const char *s_rmap = #undef MAX #define MAX(x,y) ((x)>(y)?(x):(y)) +#ifdef DEBUG + +static char *_funcs[1000]; +static int _ifuncs; + +#define REGFUNC(name) { if (_ifuncs == 999) { printf("TROUBLE\n"); exit(0); } _funcs[_ifuncs++] = name; } +#define DECFUNC() --_ifuncs; +#define VERIFY(val) _verify(val, #val, __LINE__); + +static void _verify(mp_int *a, char *name, int line) +{ + int n, y; + static const char *err[] = { "Null DP", "alloc < used", "digits above used" }; + + /* dp null ? */ + y = 0; + if (a->dp == NULL) goto error; + + /* used should be <= alloc */ + ++y; + if (a->alloc < a->used) goto error; + + /* digits above used should be zero */ + ++y; + for (n = a->used; n < a->alloc; n++) { + if (a->dp[n]) goto error; + } + + /* ok */ + return; +error: + printf("Error (%s) with variable {%s} on line %d\n", err[y], name, line); + for (n = _ifuncs - 1; n >= 0; n--) { + if (_funcs[n] != NULL) { + printf("> %s\n", _funcs[n]); + } + } + printf("\n"); + exit(0); +} + +#else + +#define REGFUNC(name) +#define DECFUNC() +#define VERIFY(val) + +#endif + /* init a new bigint */ int mp_init(mp_int *a) { @@ -39,81 +88,115 @@ int mp_init(mp_int *a) /* clear one (frees) */ void mp_clear(mp_int *a) { + REGFUNC("mp_clear"); if (a->dp != NULL) { + VERIFY(a); memset(a->dp, 0, sizeof(mp_digit) * a->alloc); free(a->dp); a->dp = NULL; + a->alloc = a->used = 0; } + DECFUNC(); } +void mp_exch(mp_int *a, mp_int *b) +{ + mp_int t; + + REGFUNC("mp_exch"); + VERIFY(a); + VERIFY(b); + t = *a; *a = *b; *b = t; + DECFUNC(); +} + /* grow as required */ static int mp_grow(mp_int *a, int size) { int i; + mp_digit *tmp; + + REGFUNC("mp_grow"); + VERIFY(a); /* if the alloc size is smaller alloc more ram */ if (a->alloc < size) { - size += 16 - (size & 15); /* ensure its to the next multiple of 16 words */ - a->dp = realloc(a->dp, sizeof(mp_digit) * size); - if (a->dp == NULL) { + size += 32 - (size & 15); /* ensure there are always at least 16 digits extra on top */ + + tmp = calloc(sizeof(mp_digit), size); + if (tmp == NULL) { return MP_MEM; } - i = a->alloc; - a->alloc = size; - - /* zero top words */ - for (; i < size; i++) { - a->dp[i] = 0; + for (i = 0; i < a->used; i++) { + tmp[i] = a->dp[i]; } + free(a->dp); + a->dp = tmp; + a->alloc = size; } + DECFUNC(); return MP_OKAY; } /* shrink a bignum */ int mp_shrink(mp_int *a) { + REGFUNC("mp_shrink"); + VERIFY(a); if (a->alloc != a->used) { if ((a->dp = realloc(a->dp, sizeof(mp_digit) * a->used)) == NULL) { + DECFUNC(); return MP_MEM; } a->alloc = a->used; } + DECFUNC(); return MP_OKAY; } /* trim unused digits */ static void mp_clamp(mp_int *a) { + REGFUNC("mp_clamp"); + VERIFY(a); while (a->used > 0 && a->dp[a->used-1] == 0) --(a->used); if (a->used == 0) { a->sign = MP_ZPOS; } + DECFUNC(); } /* set to zero */ void mp_zero(mp_int *a) { + REGFUNC("mp_zero"); + VERIFY(a); a->sign = MP_ZPOS; a->used = 0; memset(a->dp, 0, sizeof(mp_digit) * a->alloc); + DECFUNC(); } /* set to a digit */ void mp_set(mp_int *a, mp_digit b) { + REGFUNC("mp_set"); + VERIFY(a); mp_zero(a); a->dp[0] = b & MP_MASK; - a->used = 1; + a->used = (a->dp[0] != 0) ? 1: 0; + DECFUNC(); } /* set a 32-bit const */ int mp_set_int(mp_int *a, unsigned long b) { - int res, x; - if ((res = mp_grow(a, 32/DIGIT_BIT + 1)) != MP_OKAY) { - return res; - } + int x; + + REGFUNC("mp_set_int"); + VERIFY(a); mp_zero(a); + /* set four bits at a time, simplest solution to the what if DIGIT_BIT==7 case */ for (x = 0; x < 8; x++) { mp_mul_2d(a, 4, a); @@ -121,7 +204,9 @@ int mp_set_int(mp_int *a, unsigned long b) b <<= 4; a->used += 32/DIGIT_BIT + 1; } + mp_clamp(a); + DECFUNC(); return MP_OKAY; } @@ -129,11 +214,15 @@ int mp_set_int(mp_int *a, unsigned long b) int mp_init_size(mp_int *a, int size) { int res; + REGFUNC("mp_init_size"); if ((res = mp_init(a)) != MP_OKAY) { + DECFUNC(); return res; } - return mp_grow(a, size); + res = mp_grow(a, size); + DECFUNC(); + return res; } /* copy, b = a */ @@ -141,13 +230,19 @@ int mp_copy(mp_int *a, mp_int *b) { int res, n; + REGFUNC("mp_copy"); + VERIFY(a); + VERIFY(b); + /* if dst == src do nothing */ if (a == b || a->dp == b->dp) { + DECFUNC(); return MP_OKAY; } /* grow dest */ if ((res = mp_grow(b, a->used)) != MP_OKAY) { + DECFUNC(); return res; } @@ -157,6 +252,7 @@ int mp_copy(mp_int *a, mp_int *b) for (n = 0; n < a->used; n++) { b->dp[n] = a->dp[n]; } + DECFUNC(); return MP_OKAY; } @@ -165,20 +261,30 @@ int mp_init_copy(mp_int *a, mp_int *b) { int res; + REGFUNC("mp_init_copy"); + VERIFY(b); if ((res = mp_init(a)) != MP_OKAY) { + DECFUNC(); return res; } - return mp_copy(b, a); + res = mp_copy(b, a); + DECFUNC(); + return res; } /* b = |a| */ int mp_abs(mp_int *a, mp_int *b) { int res; + REGFUNC("mp_abs"); + VERIFY(a); + VERIFY(b); if ((res = mp_copy(a, b)) != MP_OKAY) { + DECFUNC(); return res; } b->sign = MP_ZPOS; + DECFUNC(); return MP_OKAY; } @@ -186,65 +292,94 @@ int mp_abs(mp_int *a, mp_int *b) int mp_neg(mp_int *a, mp_int *b) { int res; + REGFUNC("mp_neg"); + VERIFY(a); + VERIFY(b); if ((res = mp_copy(a, b)) != MP_OKAY) { + DECFUNC(); return res; } b->sign = (a->sign == MP_ZPOS) ? MP_NEG : MP_ZPOS; + DECFUNC(); return MP_OKAY; } - /* compare maginitude of two ints (unsigned) */ int mp_cmp_mag(mp_int *a, mp_int *b) { int n; + REGFUNC("mp_cmp_mag"); + VERIFY(a); + VERIFY(b); + /* compare based on # of non-zero digits */ if (a->used > b->used) { + DECFUNC(); return MP_GT; } else if (a->used < b->used) { + DECFUNC(); return MP_LT; } /* compare based on digits */ for (n = a->used - 1; n >= 0; n--) { if (a->dp[n] > b->dp[n]) { + DECFUNC(); return MP_GT; } else if (a->dp[n] < b->dp[n]) { + DECFUNC(); return MP_LT; } } + DECFUNC(); return MP_EQ; } /* compare two ints (signed)*/ int mp_cmp(mp_int *a, mp_int *b) { + int res; + REGFUNC("mp_cmp"); + VERIFY(a); + VERIFY(b); /* compare based on sign */ if (a->sign == MP_NEG && b->sign == MP_ZPOS) { + DECFUNC(); return MP_LT; } else if (a->sign == MP_ZPOS && b->sign == MP_NEG) { + DECFUNC(); return MP_GT; } - return mp_cmp_mag(a, b); + res = mp_cmp_mag(a, b); + DECFUNC(); + return res; } /* compare a digit */ int mp_cmp_d(mp_int *a, mp_digit b) { + REGFUNC("mp_cmp_d"); + VERIFY(a); + if (a->sign == MP_NEG) { + DECFUNC(); return MP_LT; } if (a->used > 1) { + DECFUNC(); return MP_GT; } if (a->dp[0] > b) { + DECFUNC(); return MP_GT; } else if (a->dp[0] < b) { + DECFUNC(); return MP_LT; } else { + DECFUNC(); return MP_EQ; } } @@ -254,14 +389,19 @@ void mp_rshd(mp_int *a, int b) { int x; + REGFUNC("mp_rshd"); + VERIFY(a); + /* if b <= 0 then ignore it */ if (b <= 0) { + DECFUNC(); return; } /* if b > used then simply zero it and return */ if (a->used < b) { mp_zero(a); + DECFUNC(); return; } @@ -275,6 +415,7 @@ void mp_rshd(mp_int *a, int b) a->dp[x] = 0; } mp_clamp(a); + DECFUNC(); } /* shift left a certain amount of digits */ @@ -282,12 +423,18 @@ int mp_lshd(mp_int *a, int b) { int x, res; + REGFUNC("mp_lshd"); + VERIFY(a); + /* if its less than zero return */ - if (b <= 0) + if (b <= 0) { + DECFUNC(); return MP_OKAY; + } /* grow to fit the new digits */ if ((res = mp_grow(a, a->used + b)) != MP_OKAY) { + DECFUNC(); return res; } @@ -302,6 +449,7 @@ int mp_lshd(mp_int *a, int b) a->dp[x] = 0; } mp_clamp(a); + DECFUNC(); return MP_OKAY; } @@ -310,19 +458,27 @@ int mp_mod_2d(mp_int *a, int b, mp_int *c) { int x, res; + REGFUNC("mp_mod_2d"); + VERIFY(a); + VERIFY(c); + /* if b is <= 0 then zero the int */ if (b <= 0) { mp_zero(c); + DECFUNC(); return MP_OKAY; } /* if the modulus is larger than the value than return */ if (b > (int)(a->used * DIGIT_BIT)) { - return mp_copy(a, c); + res = mp_copy(a, c); + DECFUNC(); + return res; } /* copy */ if ((res = mp_copy(a, c)) != MP_OKAY) { + DECFUNC(); return res; } @@ -333,6 +489,7 @@ int mp_mod_2d(mp_int *a, int b, mp_int *c) /* clear the digit that is not completely outside/inside the modulus */ c->dp[b/DIGIT_BIT] &= (mp_digit)((((mp_digit)1)<<(b % DIGIT_BIT)) - ((mp_digit)1)); mp_clamp(c); + DECFUNC(); return MP_OKAY; } @@ -343,13 +500,26 @@ int mp_div_2d(mp_int *a, int b, mp_int *c, mp_int *d) int x, res; mp_int t; + REGFUNC("mp_div_2d"); + VERIFY(a); + VERIFY(c); + if (d != NULL) { VERIFY(d); } + + if (b <= 0) { + res = mp_copy(a, c); + if (d != NULL) { mp_zero(d); } + return res; + } + if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } if (d != NULL) { if ((res = mp_mod_2d(a, b, &t)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } } @@ -357,6 +527,7 @@ int mp_div_2d(mp_int *a, int b, mp_int *c, mp_int *d) /* copy */ if ((res = mp_copy(a, c)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } @@ -374,12 +545,12 @@ int mp_div_2d(mp_int *a, int b, mp_int *c, mp_int *d) } } mp_clamp(c); + res = MP_OKAY; if (d != NULL) { - res = mp_copy(&t, d); - } else { - res = MP_OKAY; + mp_exch(&t, d); } mp_clear(&t); + DECFUNC(); return MP_OKAY; } @@ -389,17 +560,24 @@ int mp_mul_2d(mp_int *a, int b, mp_int *c) mp_digit d, r, rr; int x, res; + REGFUNC("mp_mul_2d"); + VERIFY(a); + VERIFY(c); + /* copy */ if ((res = mp_copy(a, c)) != MP_OKAY) { + DECFUNC(); return res; } if ((res = mp_grow(c, c->used + b/DIGIT_BIT + 1)) != MP_OKAY) { + DECFUNC(); return res; } /* shift by as many digits in the bit count */ if ((res = mp_lshd(c, b/DIGIT_BIT)) != MP_OKAY) { + DECFUNC(); return res; } c->used = c->alloc; @@ -415,6 +593,7 @@ int mp_mul_2d(mp_int *a, int b, mp_int *c) } } mp_clamp(c); + DECFUNC(); return MP_OKAY; } @@ -423,9 +602,14 @@ int mp_div_2(mp_int *a, mp_int *b) { mp_digit r, rr; int x, res; + + REGFUNC("mp_div_2"); + VERIFY(a); + VERIFY(b); /* copy */ if ((res = mp_copy(a, b)) != MP_OKAY) { + DECFUNC(); return res; } @@ -436,6 +620,7 @@ int mp_div_2(mp_int *a, mp_int *b) r = rr; } mp_clamp(b); + DECFUNC(); return MP_OKAY; } @@ -445,12 +630,18 @@ int mp_mul_2(mp_int *a, mp_int *b) mp_digit r, rr; int x, res; + REGFUNC("mp_mul_2"); + VERIFY(a); + VERIFY(b); + /* copy */ if ((res = mp_copy(a, b)) != MP_OKAY) { + DECFUNC(); return res; } if ((res = mp_grow(b, b->used + 1)) != MP_OKAY) { + DECFUNC(); return res; } b->used = b->alloc; @@ -463,6 +654,7 @@ int mp_mul_2(mp_int *a, mp_int *b) r = rr; } mp_clamp(b); + DECFUNC(); return MP_OKAY; } @@ -473,6 +665,11 @@ static int s_mp_add(mp_int *a, mp_int *b, mp_int *c) int res, min, max, i; mp_digit u; + REGFUNC("s_mp_add"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + /* find sizes */ if (a->used > b->used) { min = b->used; @@ -489,6 +686,7 @@ static int s_mp_add(mp_int *a, mp_int *b, mp_int *c) /* init result */ if ((res = mp_init_size(&t, max+1)) != MP_OKAY) { + DECFUNC(); return res; } t.used = max+1; @@ -513,12 +711,10 @@ static int s_mp_add(mp_int *a, mp_int *b, mp_int *c) /* add carry */ t.dp[i] = u; - mp_clamp(&t); - if ((res = mp_copy(&t, c)) != MP_OKAY) { - mp_clear(&t); - return res; - } + mp_clamp(&t); + mp_exch(&t, c); mp_clear(&t); + DECFUNC(); return MP_OKAY; } @@ -529,15 +725,21 @@ static int s_mp_sub(mp_int *a, mp_int *b, mp_int *c) int res, min, max, i; mp_digit u; + REGFUNC("s_mp_sub"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + /* find sizes */ min = b->used; max = a->used; /* init result */ - if ((res = mp_init_size(&t, max+1)) != MP_OKAY) { + if ((res = mp_init_size(&t, max)) != MP_OKAY) { + DECFUNC(); return res; } - t.used = max+1; + t.used = max; /* sub digits from lower part */ u = 0; @@ -556,13 +758,10 @@ static int s_mp_sub(mp_int *a, mp_int *b, mp_int *c) } } - mp_clamp(&t); - if ((res = mp_copy(&t, c)) != MP_OKAY) { - mp_clear(&t); - return res; - } - + mp_clamp(&t); + mp_exch(&t, c); mp_clear(&t); + DECFUNC(); return MP_OKAY; } @@ -578,7 +777,13 @@ static int fast_s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) mp_word W[512], *_W; mp_digit tmpx, *tmpt, *tmpy; + REGFUNC("fast_s_mp_mul_digs"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + if ((res = mp_init_size(&t, digs)) != MP_OKAY) { + DECFUNC(); return res; } t.used = digs; @@ -606,12 +811,9 @@ static int fast_s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) t.dp[digs-1] = W[digs-1] & ((mp_word)MP_MASK); mp_clamp(&t); - if ((res = mp_copy(&t, c)) != MP_OKAY) { - mp_clear(&t); - return res; - } + mp_exch(&t, c); mp_clear(&t); - + DECFUNC(); return MP_OKAY; } @@ -624,12 +826,19 @@ static int s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) mp_word r; mp_digit tmpx, *tmpt, *tmpy; + REGFUNC("s_mp_mul_digs"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + /* can we use the fast multiplier? */ if ((digs < 512) && digs < (1<<( (CHAR_BIT*sizeof(mp_word)) - (2*DIGIT_BIT)))) { + DECFUNC(); return fast_s_mp_mul_digs(a,b,c,digs); } if ((res = mp_init_size(&t, digs)) != MP_OKAY) { + DECFUNC(); return res; } t.used = digs; @@ -651,12 +860,10 @@ static int s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) } mp_clamp(&t); - if ((res = mp_copy(&t, c)) != MP_OKAY) { - mp_clear(&t); - return res; - } + mp_exch(&t, c); + mp_clear(&t); - + DECFUNC(); return MP_OKAY; } @@ -667,7 +874,13 @@ static int fast_s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) mp_word W[512], *_W; mp_digit tmpx, *tmpt, *tmpy; + REGFUNC("fast_s_mp_mul_high_digs"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + if ((res = mp_init_size(&t, a->used + b->used + 1)) != MP_OKAY) { + DECFUNC(); return res; } t.used = a->used + b->used + 1; @@ -694,12 +907,9 @@ static int fast_s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) mp_clamp(&t); - if ((res = mp_copy(&t, c)) != MP_OKAY) { - mp_clear(&t); - return res; - } + mp_exch(&t, c); mp_clear(&t); - + DECFUNC(); return MP_OKAY; } @@ -714,12 +924,19 @@ static int s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) mp_word r; mp_digit tmpx, *tmpt, *tmpy; + REGFUNC("s_mp_mul_high_digs"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + /* can we use the fast multiplier? */ if (((a->used + b->used + 1) < 512) && MAX(a->used, b->used) < (1<<( (CHAR_BIT*sizeof(mp_word)) - (2*DIGIT_BIT)))) { + DECFUNC(); return fast_s_mp_mul_high_digs(a,b,c,digs); } if ((res = mp_init_size(&t, a->used + b->used + 1)) != MP_OKAY) { + DECFUNC(); return res; } t.used = a->used + b->used + 1; @@ -739,12 +956,9 @@ static int s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) *tmpt = u; } mp_clamp(&t); - if ((res = mp_copy(&t, c)) != MP_OKAY) { - mp_clear(&t); - return res; - } + mp_exch(&t, c); mp_clear(&t); - + DECFUNC(); return MP_OKAY; } @@ -755,9 +969,14 @@ static int fast_s_mp_sqr(mp_int *a, mp_int *b) int res, ix, iy, pa; mp_word W[512], *_W; mp_digit tmpx, *tmpy; + + REGFUNC("fast_s_mp_sqr"); + VERIFY(a); + VERIFY(b); pa = a->used; if ((res = mp_init_size(&t, pa + pa + 1)) != MP_OKAY) { + DECFUNC(); return res; } t.used = pa + pa + 1; @@ -781,11 +1000,9 @@ static int fast_s_mp_sqr(mp_int *a, mp_int *b) t.dp[(pa+pa+1)-1] = W[(pa+pa+1)-1] & ((mp_word)MP_MASK); mp_clamp(&t); - if ((res = mp_copy(&t, b)) != MP_OKAY) { - mp_clear(&t); - return res; - } + mp_exch(&t, b); mp_clear(&t); + DECFUNC(); return MP_OKAY; } @@ -797,13 +1014,19 @@ static int s_mp_sqr(mp_int *a, mp_int *b) mp_word r, u; mp_digit tmpx, *tmpt; + REGFUNC("s_mp_sqr"); + VERIFY(a); + VERIFY(b); + /* can we use the fast multiplier? */ if (((a->used * 2 + 1) < 512) && a->used < (1<<( (CHAR_BIT*sizeof(mp_word)) - (2*DIGIT_BIT) - 1))) { + DECFUNC(); return fast_s_mp_sqr(a,b); } pa = a->used; if ((res = mp_init_size(&t, pa + pa + 1)) != MP_OKAY) { + DECFUNC(); return res; } t.used = pa + pa + 1; @@ -833,11 +1056,9 @@ static int s_mp_sqr(mp_int *a, mp_int *b) } mp_clamp(&t); - if ((res = mp_copy(&t, b)) != MP_OKAY) { - mp_clear(&t); - return res; - } + mp_exch(&t, b); mp_clear(&t); + DECFUNC(); return MP_OKAY; } @@ -845,7 +1066,12 @@ static int s_mp_sqr(mp_int *a, mp_int *b) int mp_add(mp_int *a, mp_int *b, mp_int *c) { int sa, sb, res; - + + REGFUNC("mp_add"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + sa = a->sign; sb = b->sign; @@ -877,6 +1103,7 @@ int mp_add(mp_int *a, mp_int *b, mp_int *c) res = s_mp_add(a, b, c); c->sign = MP_NEG; } + DECFUNC(); return res; } @@ -885,6 +1112,11 @@ int mp_sub(mp_int *a, mp_int *b, mp_int *c) { int sa, sb, res; + REGFUNC("mp_sub"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + sa = a->sign; sb = b->sign; @@ -893,7 +1125,7 @@ int mp_sub(mp_int *a, mp_int *b, mp_int *c) /* both positive, a - b, but if b>a then we do -(b - a) */ if (mp_cmp_mag(a, b) == MP_LT) { /* b>a */ - res = s_mp_sub(b, a, c); + res = s_mp_sub(b, a, c); c->sign = MP_NEG; } else { res = s_mp_sub(a, b, c); @@ -917,6 +1149,8 @@ int mp_sub(mp_int *a, mp_int *b, mp_int *c) c->sign = MP_ZPOS; } } + + DECFUNC(); return res; } @@ -926,6 +1160,11 @@ static int mp_karatsuba_mul(mp_int *a, mp_int *b, mp_int *c) mp_int x0, x1, y0, y1, t1, t2, x0y0, x1y1; int B, err, neg, x; + REGFUNC("mp_karatsuba_mul"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + err = MP_MEM; /* min # of digits */ @@ -992,7 +1231,8 @@ static int mp_karatsuba_mul(mp_int *a, mp_int *b, mp_int *c) if (mp_add(&x0y0, &t1, &t1) != MP_OKAY) goto X1Y1; /* t1 = x0y0 + t1 */ if (mp_add(&t1, &x1y1, &t1) != MP_OKAY) goto X1Y1; /* t1 = x0y0 + t1 + x1y1 */ - err = mp_copy(&t1, c); + err = MP_OKAY; + mp_exch(&t1, c); X1Y1: mp_clear(&x1y1); X0Y0: mp_clear(&x0y0); @@ -1003,6 +1243,7 @@ Y0 : mp_clear(&y0); X1 : mp_clear(&x1); X0 : mp_clear(&x0); ERR : + DECFUNC(); return err; } @@ -1010,6 +1251,10 @@ ERR : int mp_mul(mp_int *a, mp_int *b, mp_int *c) { int res, neg; + REGFUNC("mp_mul"); + VERIFY(a); + VERIFY(b); + VERIFY(c); neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG; if (MIN(a->used, b->used) > KARATSUBA_MUL_CUTOFF) { res = mp_karatsuba_mul(a, b, c); @@ -1017,6 +1262,7 @@ int mp_mul(mp_int *a, mp_int *b, mp_int *c) res = s_mp_mul(a, b, c); } c->sign = neg; + DECFUNC(); return res; } @@ -1025,6 +1271,10 @@ static int mp_karatsuba_sqr(mp_int *a, mp_int *b) { mp_int x0, x1, t1, t2, x0x0, x1x1; int B, err; + + REGFUNC("mp_karatsuba_sqr"); + VERIFY(a); + VERIFY(b); err = MP_MEM; @@ -1067,7 +1317,9 @@ static int mp_karatsuba_sqr(mp_int *a, mp_int *b) if (mp_add(&x0x0, &t1, &t1) != MP_OKAY) goto X1X1; /* t1 = x0y0 + t1 */ if (mp_add(&t1, &x1x1, &t1) != MP_OKAY) goto X1X1; /* t1 = x0y0 + t1 + x1y1 */ - err = mp_copy(&t1, b); + err = MP_OKAY; + mp_exch(&t1, b); + X1X1: mp_clear(&x1x1); X0X0: mp_clear(&x0x0); T2 : mp_clear(&t2); @@ -1075,6 +1327,7 @@ T1 : mp_clear(&t1); X1 : mp_clear(&x1); X0 : mp_clear(&x0); ERR : + DECFUNC(); return err; } @@ -1082,12 +1335,16 @@ ERR : int mp_sqr(mp_int *a, mp_int *b) { int res; + REGFUNC("mp_sqr"); + VERIFY(a); + VERIFY(b); if (a->used > KARATSUBA_SQR_CUTOFF) { res = mp_karatsuba_sqr(a, b); } else { res = s_mp_sqr(a, b); } b->sign = MP_ZPOS; + DECFUNC(); return res; } @@ -1098,8 +1355,15 @@ int mp_div(mp_int *a, mp_int *b, mp_int *c, mp_int *d) mp_int q, x, y, t1, t2; int res, n, t, i, norm, neg; + REGFUNC("mp_div"); + VERIFY(a); + VERIFY(b); + if (c != NULL) { VERIFY(c); } + if (d != NULL) { VERIFY(d); } + /* is divisor zero ? */ if (mp_iszero(b) == 1) { + DECFUNC(); return MP_VAL; } @@ -1113,11 +1377,12 @@ int mp_div(mp_int *a, mp_int *b, mp_int *c, mp_int *d) if (c != NULL) { mp_zero(c); } + DECFUNC(); return res; } - if ((res = mp_init_size(&q, a->used + 2)) != MP_OKAY) { + DECFUNC(); return res; } q.used = a->used + 2; @@ -1175,6 +1440,8 @@ int mp_div(mp_int *a, mp_int *b, mp_int *c, mp_int *d) /* step 3. for i from n down to (t + 1) */ for (i = n; i >= (t + 1); i--) { + if (i > x.alloc) continue; + /* step 3.1 if xi == yt then set q{i-t-1} to b-1, otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */ if (x.dp[i] == y.dp[t]) { q.dp[i - t - 1] = ((1UL<sign; if (c != NULL) { mp_clamp(&q); - mp_copy(&q, c); + mp_exch(&q, c); c->sign = neg; } if (d != NULL) { mp_div_2d(&x, norm, &x, NULL); mp_clamp(&x); - mp_copy(&x, d); + mp_exch(&x, d); } res = MP_OKAY; @@ -1258,6 +1526,7 @@ __X: mp_clear(&x); __T2: mp_clear(&t2); __T1: mp_clear(&t1); __Q: mp_clear(&q); + DECFUNC(); return res; } @@ -1267,22 +1536,31 @@ int mp_mod(mp_int *a, mp_int *b, mp_int *c) mp_int t; int res; + REGFUNC("mp_mod"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } if ((res = mp_div(a, b, NULL, &t)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } if (t.sign == MP_NEG) { res = mp_add(b, &t, c); } else { - res = mp_copy(&t, c); + res = MP_OKAY; + mp_exch(&t, c); } mp_clear(&t); + DECFUNC(); return res; } @@ -1292,13 +1570,19 @@ int mp_add_d(mp_int *a, mp_digit b, mp_int *c) mp_int t; int res; + REGFUNC("mp_add_d"); + VERIFY(a); + VERIFY(c); + if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } mp_set(&t, b); res = mp_add(a, &t, c); mp_clear(&t); + DECFUNC(); return res; } @@ -1308,13 +1592,19 @@ int mp_sub_d(mp_int *a, mp_digit b, mp_int *c) mp_int t; int res; + REGFUNC("mp_sub_d"); + VERIFY(a); + VERIFY(c); + if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } mp_set(&t, b); res = mp_sub(a, &t, c); mp_clear(&t); + DECFUNC(); return res; } @@ -1326,8 +1616,13 @@ int mp_mul_d(mp_int *a, mp_digit b, mp_int *c) mp_digit u; mp_int t; + REGFUNC("mp_mul_d"); + VERIFY(a); + VERIFY(c); + pa = a->used; if ((res = mp_init_size(&t, pa + 2)) != MP_OKAY) { + DECFUNC(); return res; } t.used = pa + 2; @@ -1342,11 +1637,9 @@ int mp_mul_d(mp_int *a, mp_digit b, mp_int *c) t.sign = a->sign; mp_clamp(&t); - if ((res = mp_copy(&t, c)) != MP_OKAY) { - mp_clear(&t); - return res; - } + mp_exch(&t, c); mp_clear(&t); + DECFUNC(); return MP_OKAY; } @@ -1356,12 +1649,18 @@ int mp_div_d(mp_int *a, mp_digit b, mp_int *c, mp_digit *d) mp_int t, t2; int res; + REGFUNC("mp_div_d"); + VERIFY(a); + if (c != NULL) { VERIFY(c); } + if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } if ((res = mp_init(&t2)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } @@ -1374,6 +1673,7 @@ int mp_div_d(mp_int *a, mp_digit b, mp_int *c, mp_digit *d) mp_clear(&t); mp_clear(&t2); + DECFUNC(); return res; } @@ -1381,13 +1681,18 @@ int mp_mod_d(mp_int *a, mp_digit b, mp_digit *c) { mp_int t, t2; int res; + + REGFUNC("mp_mod_d"); + VERIFY(a); if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } if ((res = mp_init(&t2)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } @@ -1398,12 +1703,14 @@ int mp_mod_d(mp_int *a, mp_digit b, mp_digit *c) if ((res = mp_add_d(&t2, b, &t2)) != MP_OKAY) { mp_clear(&t); mp_clear(&t2); + DECFUNC(); return res; } } *c = t2.dp[0]; mp_clear(&t); mp_clear(&t2); + DECFUNC(); return MP_OKAY; } @@ -1412,7 +1719,12 @@ int mp_expt_d(mp_int *a, mp_digit b, mp_int *c) int res, x; mp_int g; + REGFUNC("mp_expt_d"); + VERIFY(a); + VERIFY(c); + if ((res = mp_init_copy(&g, a)) != MP_OKAY) { + DECFUNC(); return res; } @@ -1422,12 +1734,14 @@ int mp_expt_d(mp_int *a, mp_digit b, mp_int *c) for (x = 0; x < (int)DIGIT_BIT; x++) { if ((res = mp_sqr(c, c)) != MP_OKAY) { mp_clear(&g); + DECFUNC(); return res; } if (b & (mp_digit)(1<<(DIGIT_BIT-1))) { if ((res = mp_mul(c, &g, c)) != MP_OKAY) { mp_clear(&g); + DECFUNC(); return res; } } @@ -1436,6 +1750,7 @@ int mp_expt_d(mp_int *a, mp_digit b, mp_int *c) } mp_clear(&g); + DECFUNC(); return MP_OKAY; } @@ -1447,16 +1762,25 @@ int mp_addmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d) int res; mp_int t; + REGFUNC("mp_addmod"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + VERIFY(d); + if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } if ((res = mp_add(a, b, &t)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } res = mp_mod(&t, c, d); mp_clear(&t); + DECFUNC(); return res; } @@ -1466,16 +1790,25 @@ int mp_submod(mp_int *a, mp_int *b, mp_int *c, mp_int *d) int res; mp_int t; + REGFUNC("mp_submod"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + VERIFY(d); + if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } if ((res = mp_sub(a, b, &t)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } res = mp_mod(&t, c, d); mp_clear(&t); + DECFUNC(); return res; } @@ -1485,16 +1818,25 @@ int mp_mulmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d) int res; mp_int t; + REGFUNC("mp_mulmod"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + VERIFY(d); + if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } if ((res = mp_mul(a, b, &t)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } res = mp_mod(&t, c, d); mp_clear(&t); + DECFUNC(); return res; } @@ -1504,16 +1846,24 @@ int mp_sqrmod(mp_int *a, mp_int *b, mp_int *c) int res; mp_int t; + REGFUNC("mp_sqrmod"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } if ((res = mp_sqr(a, &t)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } res = mp_mod(&t, b, c); mp_clear(&t); + DECFUNC(); return res; } @@ -1524,15 +1874,23 @@ int mp_gcd(mp_int *a, mp_int *b, mp_int *c) mp_int u, v, t; int k, res, neg; + REGFUNC("mp_gcd"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + /* either zero than gcd is the largest */ if (mp_iszero(a) == 1 && mp_iszero(b) == 0) { + DECFUNC(); return mp_copy(b, c); } if (mp_iszero(a) == 0 && mp_iszero(b) == 1) { + DECFUNC(); return mp_copy(a, c); } if (mp_iszero(a) == 1 && mp_iszero(b) == 1) { mp_set(c, 1); + DECFUNC(); return MP_OKAY; } @@ -1540,6 +1898,7 @@ int mp_gcd(mp_int *a, mp_int *b, mp_int *c) neg = (a->sign == b->sign) ? a->sign : MP_ZPOS; if ((res = mp_init_copy(&u, a)) != MP_OKAY) { + DECFUNC(); return res; } @@ -1602,14 +1961,13 @@ int mp_gcd(mp_int *a, mp_int *b, mp_int *c) goto __T; } - if ((res = mp_copy(&u, c)) != MP_OKAY) { - goto __T; - } + mp_exch(&u, c); c->sign = neg; + res = MP_OKAY; __T: mp_clear(&t); __V: mp_clear(&u); __U: mp_clear(&v); - + DECFUNC(); return res; } @@ -1619,31 +1977,206 @@ int mp_lcm(mp_int *a, mp_int *b, mp_int *c) int res; mp_int t; + REGFUNC("mp_lcm"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + if ((res = mp_init(&t)) != MP_OKAY) { + DECFUNC(); return res; } if ((res = mp_mul(a, b, &t)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } if ((res = mp_gcd(a, b, c)) != MP_OKAY) { mp_clear(&t); + DECFUNC(); return res; } res = mp_div(&t, c, c, NULL); mp_clear(&t); + DECFUNC(); return res; } /* computes the modular inverse via binary extended euclidean algorithm, that is c = 1/a mod b */ +static int fast_mp_invmod(mp_int *a, mp_int *b, mp_int *c) +{ + mp_int x, y, u, v, B, D; + int res, neg; + + REGFUNC("fast_mp_invmod"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + + if ((res = mp_init(&x)) != MP_OKAY) { + goto __ERR; + } + + if ((res = mp_init(&y)) != MP_OKAY) { + goto __X; + } + + if ((res = mp_init(&u)) != MP_OKAY) { + goto __Y; + } + + if ((res = mp_init(&v)) != MP_OKAY) { + goto __U; + } + + if ((res = mp_init(&B)) != MP_OKAY) { + goto __V; + } + + if ((res = mp_init(&D)) != MP_OKAY) { + goto __B; + } + + /* x == modulus, y == value to invert */ + if ((res = mp_copy(b, &x)) != MP_OKAY) { + goto __D; + } + if ((res = mp_copy(a, &y)) != MP_OKAY) { + goto __D; + } + + if ((res = mp_abs(&y, &y)) != MP_OKAY) { + goto __D; + } + + /* 2. [modified] if x,y are both even then return an error! */ + if (mp_iseven(&x) == 1 && mp_iseven(&y) == 1) { + res = MP_VAL; + goto __D; + } + + /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */ + if ((res = mp_copy(&x, &u)) != MP_OKAY) { + goto __D; + } + if ((res = mp_copy(&y, &v)) != MP_OKAY) { + goto __D; + } + mp_set(&D, 1); + + +top: + /* 4. while u is even do */ + while (mp_iseven(&u) == 1) { + /* 4.1 u = u/2 */ + if ((res = mp_div_2(&u, &u)) != MP_OKAY) { + goto __D; + } + /* 4.2 if A or B is odd then */ + if (mp_iseven(&B) == 0) { + if ((res = mp_sub(&B, &x, &B)) != MP_OKAY) { + goto __D; + } + } + /* A = A/2, B = B/2 */ + if ((res = mp_div_2(&B, &B)) != MP_OKAY) { + goto __D; + } + } + + + /* 5. while v is even do */ + while (mp_iseven(&v) == 1) { + /* 5.1 v = v/2 */ + if ((res = mp_div_2(&v, &v)) != MP_OKAY) { + goto __D; + } + /* 5.2 if C,D are even then */ + if (mp_iseven(&D) == 0) { + /* C = (C+y)/2, D = (D-x)/2 */ + if ((res = mp_sub(&D, &x, &D)) != MP_OKAY) { + goto __D; + } + } + /* C = C/2, D = D/2 */ + if ((res = mp_div_2(&D, &D)) != MP_OKAY) { + goto __D; + } + } + + /* 6. if u >= v then */ + if (mp_cmp(&u, &v) != MP_LT) { + /* u = u - v, A = A - C, B = B - D */ + if ((res = mp_sub(&u, &v, &u)) != MP_OKAY) { + goto __D; + } + + if ((res = mp_sub(&B, &D, &B)) != MP_OKAY) { + goto __D; + } + } else { + /* v - v - u, C = C - A, D = D - B */ + if ((res = mp_sub(&v, &u, &v)) != MP_OKAY) { + goto __D; + } + + if ((res = mp_sub(&D, &B, &D)) != MP_OKAY) { + goto __D; + } + } + + /* if not zero goto step 4 */ + if (mp_iszero(&u) == 0) goto top; + + /* now a = C, b = D, gcd == g*v */ + + /* if v != 1 then there is no inverse */ + if (mp_cmp_d(&v, 1) != MP_EQ) { + res = MP_VAL; + goto __D; + } + + /* b is now the inverse */ + neg = a->sign; + while (D.sign == MP_NEG) { + if ((res = mp_add(&D, b, &D)) != MP_OKAY) { + goto __D; + } + } + mp_exch(&D, c); + c->sign = neg; + res = MP_OKAY; + +__D: mp_clear(&D); +__B: mp_clear(&B); +__V: mp_clear(&v); +__U: mp_clear(&u); +__Y: mp_clear(&y); +__X: mp_clear(&x); +__ERR: + DECFUNC(); + return res; +} + int mp_invmod(mp_int *a, mp_int *b, mp_int *c) { mp_int x, y, u, v, A, B, C, D; int res, neg; + REGFUNC("mp_invmod"); + VERIFY(a); + VERIFY(b); + VERIFY(c); + + if (mp_iseven(b) == 0) { + res = fast_mp_invmod(a,b,c); + DECFUNC(); + return res; + } + if ((res = mp_init(&x)) != MP_OKAY) { goto __ERR; } @@ -1802,7 +2335,8 @@ top: if (C.sign == MP_NEG) { res = mp_add(b, &C, c); } else { - res = mp_copy(&C, c); + mp_exch(&C, c); + res = MP_OKAY; } c->sign = neg; @@ -1815,6 +2349,7 @@ __U: mp_clear(&u); __Y: mp_clear(&y); __X: mp_clear(&x); __ERR: + DECFUNC(); return res; } @@ -1825,11 +2360,18 @@ int mp_reduce_setup(mp_int *a, mp_int *b) { int res; + REGFUNC("mp_reduce_setup"); + VERIFY(a); + VERIFY(b); + mp_set(a, 1); if ((res = mp_lshd(a, b->used * 2)) != MP_OKAY) { + DECFUNC(); return res; } - return mp_div(a, b, a, NULL); + res = mp_div(a, b, a, NULL); + DECFUNC(); + return res; } /* reduces x mod m, assumes 0 < x < m^2, mu is precomputed via mp_reduce_setup */ @@ -1838,8 +2380,15 @@ int mp_reduce(mp_int *x, mp_int *m, mp_int *mu) mp_int q; int res, um = m->used; - if((res = mp_init_copy(&q, x)) != MP_OKAY) + REGFUNC("mp_reduce"); + VERIFY(x); + VERIFY(m); + VERIFY(mu); + + if((res = mp_init_copy(&q, x)) != MP_OKAY) { + DECFUNC(); return res; + } mp_rshd(&q, um - 1); /* q1 = x / b^(k-1) */ @@ -1883,6 +2432,7 @@ int mp_reduce(mp_int *x, mp_int *m, mp_int *mu) CLEANUP: mp_clear(&q); + DECFUNC(); return res; } @@ -1893,6 +2443,12 @@ int mp_exptmod(mp_int *G, mp_int *X, mp_int *P, mp_int *Y) mp_digit buf; int err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, z, winsize, tab[64]; + REGFUNC("mp_exptmod"); + VERIFY(G); + VERIFY(X); + VERIFY(P); + VERIFY(Y); + /* find window size */ x = mp_count_bits(X); if (x <= 18) { winsize = 2; } @@ -1907,6 +2463,7 @@ int mp_exptmod(mp_int *G, mp_int *X, mp_int *P, mp_int *Y) for (y = 0; y < x; y++) { mp_clear(&M[y]); } + DECFUNC(); return err; } } @@ -2051,13 +2608,15 @@ int mp_exptmod(mp_int *G, mp_int *X, mp_int *P, mp_int *Y) } } - err = mp_copy(&res, Y); + mp_exch(&res, Y); + err = MP_OKAY; __RES: mp_clear(&res); __MU : mp_clear(&mu); __M : for (x = 0; x < (1<used)) != MP_OKAY) { - return res; - } - a->used = n; while (c-- > 0) { if ((res = mp_mul_2d(a, 8, a)) != MP_OKAY) { return res; @@ -2154,9 +2707,13 @@ int mp_to_unsigned_bin(mp_int *a, unsigned char *b) } else { b[x++] = (unsigned char)(t.dp[0] | ((t.dp[1] & 0x01) << 7)); } - mp_div_2d(&t, 8, &t, NULL); + if ((res = mp_div_2d(&t, 8, &t, NULL)) != MP_OKAY) { + mp_clear(&t); + return res; + } } reverse(b, x); + mp_clear(&t); return MP_OKAY; } @@ -2249,6 +2806,7 @@ int mp_toradix(mp_int *a, unsigned char *str, int radix) digs = 0; while (mp_iszero(&t) == 0) { if ((res = mp_div_d(&t, (mp_digit)radix, &t, &d)) != MP_OKAY) { + mp_clear(&t); return res; } *str++ = (unsigned char)s_rmap[d]; @@ -2284,6 +2842,7 @@ int mp_radix_size(mp_int *a, int radix) while (mp_iszero(&t) == 0) { if ((res = mp_div_d(&t, (mp_digit)radix, &t, &d)) != MP_OKAY) { + mp_clear(&t); return 0; } ++digs; diff --git a/bn.h b/bn.h index 54c8e7a..b0131ce 100644 --- a/bn.h +++ b/bn.h @@ -83,6 +83,9 @@ int mp_init(mp_int *a); /* free a bignum */ void mp_clear(mp_int *a); +/* exchange two ints */ +void mp_exch(mp_int *a, mp_int *b); + /* shrink ram required for a bignum */ int mp_shrink(mp_int *a); @@ -214,7 +217,11 @@ int mp_lcm(mp_int *a, mp_int *b, mp_int *c); /* used to setup the Barrett reduction for a given modulus b */ int mp_reduce_setup(mp_int *a, mp_int *b); -/* Barrett Reduction, computes a (mod b) with a precomputed value c */ +/* Barrett Reduction, computes a (mod b) with a precomputed value c + * + * Assumes that 0 < a <= b^2, note if 0 > a > -(b^2) then you can merely + * compute the reduction as -1 * mp_reduce(mp_abs(a)) [pseudo code]. + */ int mp_reduce(mp_int *a, mp_int *b, mp_int *c); /* d = a^b (mod c) */ diff --git a/bn.pdf b/bn.pdf index 8f8f58d..4834ba1 100644 Binary files a/bn.pdf and b/bn.pdf differ diff --git a/bn.tex b/bn.tex index 0e651ce..f10cee4 100644 --- a/bn.tex +++ b/bn.tex @@ -1,7 +1,7 @@ \documentclass{article} \begin{document} -\title{LibTomMath v0.03 \\ A Free Multiple Precision Integer Library} +\title{LibTomMath v0.04 \\ A Free Multiple Precision Integer Library} \author{Tom St Denis \\ tomstdenis@iahu.ca} \maketitle \newpage @@ -323,46 +323,161 @@ The integers are stored in big endian format as most libraries (and MPI) expect. \textbf{mp\_toradix} functions read and write (respectively) null terminated ASCII strings in a given radix. Valid values for the radix are between 2 and 64 (inclusively). +\section{Function Analysis} + +Throughout the function analysis the variable $N$ will denote the average size of an input to a function as measured +by the number of digits it has. The variable $W$ will denote the number of bits per word and $c$ will denote a small +constant amount of work. The big-oh notation will be abused slightly to consider numbers that do not grow to infinity. +That is we shall consider $O(N/2) \ne O(N)$ which is an abuse of the notation. + +\subsection{Digit Manipulation Functions} +The class of digit manipulation functions such as \textbf{mp\_rshd}, \textbf{mp\_lshd} and \textbf{mp\_mul\_2} are all +very simple functions to analyze. + +\subsubsection{mp\_rshd(mp\_int *a, int b)} +If the shift count ``b'' is less than or equal to zero the function returns without doing any work. If the +the shift count is larger than the number of digits in ``a'' then ``a'' is simply zeroed without shifting digits. + +This function requires no additional memory and $O(N)$ time. + +\subsubsection{mp\_lshd(mp\_int *a, int b)} +If the shift count ``b'' is less than or equal to zero the function returns success without doing any work. + +This function requires $O(b)$ additional digits of memory and $O(N)$ time. + +\subsubsection{mp\_div\_2d(mp\_int *a, int b, mp\_int *c, mp\_int *d)} +If the shift count ``b'' is less than or equal to zero the function places ``a'' in ``c'' and returns success. + +This function requires $O(2 \cdot N)$ additional digits of memory and $O(2 \cdot N)$ time. + +\subsubsection{mp\_mul\_2d(mp\_int *a, int b, mp\_int *c)} +If the shift count ``b'' is less than or equal to zero the function places ``a'' in ``c'' and returns success. + +This function requires $O(N)$ additional digits of memory and $O(2 \cdot N)$ time. + +\subsubsection{mp\_mod\_2d(mp\_int *a, int b, mp\_int *c)} +If the shift count ``b'' is less than or equal to zero the function places ``a'' in ``c'' and returns success. + +This function requires $O(N)$ additional digits of memory and $O(2 \cdot N)$ time. + +\subsection{Basic Arithmetic} + +\subsubsection{mp\_cmp(mp\_int *a, mp\_int *b)} +Performs a \textbf{signed} comparison between ``a'' and ``b'' returning +\textbf{MP\_GT} is ``a'' is larger than ``b''. + +This function requires no additional memory and $O(N)$ time. + +\subsubsection{mp\_cmp\_mag(mp\_int *a, mp\_int *b)} +Performs a \textbf{unsigned} comparison between ``a'' and ``b'' returning +\textbf{MP\_GT} is ``a'' is larger than ``b''. Note that this comparison is unsigned which means it will report, for +example, $-5 > 3$. By comparison mp\_cmp will report $-5 < 3$. + +This function requires no additional memory and $O(N)$ time. + +\subsubsection{mp\_add(mp\_int *a, mp\_int *b, mp\_int *c)} +Handles the sign of the numbers correctly which means it will subtract as required, e.g. $a + -b$ turns into $a - b$. + +This function requires no additional memory and $O(N)$ time. + +\subsubsection{mp\_sub(mp\_int *a, mp\_int *b, mp\_int *c)} +Handles the sign of the numbers correctly which means it will add as required, e.g. $a - -b$ turns into $a + b$. + +This function requires no additional memory and $O(N)$ time. + +\subsubsection{mp\_mul(mp\_int *a, mp\_int *b, mp\_int *c)} +Handles the sign of the numbers correctly which means it will correct the sign of the product as required, +e.g. $a \cdot -b$ turns into $-ab$. + +For relatively small inputs, that is less than 80 digits a standard baseline or comba-baseline multiplier is used. It +requires no additional memory and $O(N^2)$ time. The comba-baseline multiplier is only used if it can safely be used +without losing carry digits. The comba method is faster than the baseline method but cannot always be used which is why +both are provided. The code will automatically determine when it can be used. If the digit count is higher +than 80 for the inputs than a Karatsuba multiplier is used which requires approximately $O(6 \cdot N)$ memory and +$O(N^{lg(3)})$ time. + +\subsubsection{mp\_sqr(mp\_int *a, mp\_int *b)} +For relatively small inputs, that is less than 80 digits a modified squaring or comba-squaring algorithm is used. It +requires no additional memory and $O((N^2 + N)/2)$ time. The comba-squaring method is used only if it can be safely used +without losing carry digits. After 80 digits a Karatsuba squaring algorithm is used whcih requires approximately +$O(4 \cdot N)$ memory and $O(N^{lg(3)})$ time. + +\subsubsection{mp\_div(mp\_int *a, mp\_int *b, mp\_int *c, mp\_int *d)} +The quotient is placed in ``c'' and the remainder in ``d''. Either (or both) of ``c'' and ``d'' can be set to NULL +if the value is not desired. + +This function requires $O(4 \cdot N)$ memory and $O(N^2 + N)$ time. + +\subsection{Modular Arithmetic} + +\subsubsection{mp\_addmod, mp\_submod, mp\_mulmod, mp\_sqrmod} +These functions take the time of their host function plus the time it takes to perform a division. For example, +mp\_addmod takes $O(N + (N^2 + N))$ time. Note that if you are performing many modular operations in a row with +the same modulus you should consider Barrett reductions. + +NOTE: This section will be expanded upon in future releases of the library. + +\subsubsection{mp\_invmod(mp\_int *a, mp\_int *b, mp\_int *c)} +This function is technically only defined for moduli who are positive and inputs that are positive. That is it will find +$c = 1/a \mbox{ (mod }b\mbox{)}$ for any $a > 0$ and $b > 0$. The function will work for negative values of $a$ since +it merely computes $c = -1 \cdot (1/{\vert a \vert}) \mbox{ (mod }b\mbox{)}$. In general the input is only +\textbf{guaranteed} to lead to a correct output if $-b < a < b$ and $(a, b) = 1$. + +NOTE: This function will be revised to accept a wider range of inputs in future releases. + \section{Timing Analysis} \subsection{Observed Timings} A simple test program ``demo.c'' was developed which builds with either MPI or LibTomMath (without modification). The test was conducted on an AMD Athlon XP processor with 266Mhz DDR memory and the GCC 3.2 compiler\footnote{With build -options ``-O3 -fomit-frame-pointer -funroll-loops''}. The multiplications and squarings were repeated 10,000 times -each while the modular exponentiation (exptmod) were performed 10 times each. The RDTSC (Read Time Stamp Counter) instruction -was used to measure the time the entire iterations took and was divided by the number of iterations to get an -average. The following results were observed. +options ``-O3 -fomit-frame-pointer -funroll-loops''}. The multiplications and squarings were repeated 100,000 times +each while the modular exponentiation (exptmod) were performed 50 times each. The ``inversions'' refers to multiplicative +inversions modulo an odd number of a given size. The RDTSC (Read Time Stamp Counter) instruction was used to measure the +time the entire iterations took and was divided by the number of iterations to get an average. The following results +were observed. \begin{small} \begin{center} \begin{tabular}{c|c|c|c} \hline \textbf{Operation} & \textbf{Size (bits)} & \textbf{Time with MPI (cycles)} & \textbf{Time with LibTomMath (cycles)} \\ \hline -Multiply & 128 & 1,426 & 928 \\ -Multiply & 256 & 2,551 & 1,787 \\ -Multiply & 512 & 7,913 & 3,458 \\ -Multiply & 1024 & 28,496 & 9,271 \\ -Multiply & 2048 & 109,897 & 29,917 \\ -Multiply & 4096 & 469,970 & 123,934 \\ +Inversion & 128 & 264,083 & 172,381 \\ +Inversion & 256 & 549,370 & 381,237 \\ +Inversion & 512 & 1,675,975 & 1,212,341 \\ +Inversion & 1024 & 5,237,957 & 3,114,144 \\ +Inversion & 2048 & 17,871,944 & 8,137,896 \\ +Inversion & 4096 & 66,610,468 & 22,469,360 \\ +\hline +Multiply & 128 & 1,426 & 847 \\ +Multiply & 256 & 2,551 & 1,848 \\ +Multiply & 512 & 7,913 & 3,505 \\ +Multiply & 1024 & 28,496 & 9,097 \\ +Multiply & 2048 & 109,897 & 29,497 \\ +Multiply & 4096 & 469,970 & 112,651 \\ \hline -Square & 128 & 1,319 & 1,230 \\ -Square & 256 & 1,776 & 2,131 \\ -Square & 512 & 5,399 & 3,694 \\ -Square & 1024 & 18,991 & 9,172 \\ -Square & 2048 & 72,126 & 27,352 \\ -Square & 4096 & 306,269 & 110,607 \\ +Square & 128 & 1,319 & 883 \\ +Square & 256 & 1,776 & 1,895 \\ +Square & 512 & 5,399 & 3,543 \\ +Square & 1024 & 18,991 & 8,692 \\ +Square & 2048 & 72,126 & 26,792 \\ +Square & 4096 & 306,269 & 103,263 \\ \hline -Exptmod & 512 & 32,021,586 & 6,880,075 \\ -Exptmod & 768 & 97,595,492 & 15,202,614 \\ -Exptmod & 1024 & 223,302,532 & 28,081,865 \\ -Exptmod & 2048 & 1,682,223,369 & 146,545,454 \\ -Exptmod & 2560 & 3,268,615,571 & 310,970,112 \\ -Exptmod & 3072 & 5,597,240,141 & 480,703,712 \\ -Exptmod & 4096 & 13,347,270,891 & 985,918,868 +Exptmod & 512 & 32,021,586 & 7,096,687 \\ +Exptmod & 768 & 97,595,492 & 14,849,813 \\ +Exptmod & 1024 & 223,302,532 & 27,826,489 \\ +Exptmod & 2048 & 1,682,223,369 & 142,026,274 \\ +Exptmod & 2560 & 3,268,615,571 & 292,597,205 \\ +Exptmod & 3072 & 5,597,240,141 & 452,731,243 \\ +Exptmod & 4096 & 13,347,270,891 & 941,433,401 \end{tabular} \end{center} \end{small} +Note that the figures do fluctuate but their magnitudes are relatively intact. The purpose of the chart is not to +get an exact timing but to compare the two libraries. For example, in all of the tests the exact time for a 512-bit +squaring operation was not the same. The observed times were all approximately 3,500 cycles, more importantly they +were always faster than the timings observed with MPI by about the same magnitude. + \subsection{Digit Size} The first major constribution to the time savings is the fact that 28 bits are stored per digit instead of the MPI defualt of 16. This means in many of the algorithms the savings can be considerable. Consider a baseline multiplier diff --git a/changes.txt b/changes.txt index ca7f537..bf89d99 100644 --- a/changes.txt +++ b/changes.txt @@ -1,3 +1,10 @@ +Dec 29th, 2002 +v0.04 -- Fixed a memory leak in mp_to_unsigned_bin + -- optimized invmod code + -- Fixed bug in mp_div + -- use exchange instead of copy for results + -- added a bit more to the manual + Dec 27th, 2002 v0.03 -- Sped up s_mp_mul_high_digs by not computing the carries of the lower digits -- Fixed a bug where mp_set_int wouldn't zero the value first and set the used member. diff --git a/demo.c b/demo.c index cebec83..ed697c9 100644 --- a/demo.c +++ b/demo.c @@ -21,15 +21,20 @@ void reset(void) { _tt = clock(); } unsigned long long rdtsc(void) { return clock() - _tt; } #endif -void draw(mp_int *a) +void ndraw(mp_int *a, char *name) { char buf[4096]; - printf("a->used == %d\na->alloc == %d\na->sign == %d\n", a->used, a->alloc, a->sign); + printf("%s: ", name); mp_toradix(a, buf, 10); - printf("num == %s\n", buf); - printf("\n"); + printf("%s\n", buf); } +static void draw(mp_int *a) +{ + ndraw(a, ""); +} + + unsigned long lfsr = 0xAAAAAAAAUL; int lbit(void) @@ -43,7 +48,18 @@ int lbit(void) } } - +#ifdef U_MPI +int mp_reduce_setup(mp_int *a, mp_int *b) +{ + int res; + + mp_set(a, 1); + if ((res = s_mp_lshd(a, b->used * 2)) != MP_OKAY) { + return res; + } + return mp_div(a, b, a, NULL); +} +#endif int main(void) { @@ -51,7 +67,6 @@ int main(void) unsigned long expt_n, add_n, sub_n, mul_n, div_n, sqr_n, mul2d_n, div2d_n, gcd_n, lcm_n, inv_n; unsigned char cmd[4096], buf[4096]; int rr; - mp_digit tom; #ifdef TIMER int n; @@ -63,35 +78,63 @@ int main(void) mp_init(&c); mp_init(&d); mp_init(&e); - mp_init(&f); + mp_init(&f); - mp_read_radix(&a, "-2", 10); - mp_read_radix(&b, "2", 10); - mp_expt_d(&a, 3, &a); - draw(&a); + mp_read_radix(&a, "V//////////////////////////////////////////////////////////////////////////////////////", 64); + mp_reduce_setup(&b, &a); + printf("\n\n----\n\n"); + mp_toradix(&b, buf, 10); + printf("b == %s\n\n\n", buf); + mp_read_radix(&b, "4982748972349724892742", 10); + mp_sub_d(&a, 1, &c); + mp_exptmod(&b, &c, &a, &d); + mp_toradix(&d, buf, 10); + printf("b^p-1 == %s\n", buf); + + #ifdef TIMER mp_read_radix(&a, "340282366920938463463374607431768211455", 10); + mp_read_radix(&b, "234892374891378913789237289378973232333", 10); while (a.used * DIGIT_BIT < 8192) { reset(); - for (rr = 0; rr < 100000; rr++) { + for (rr = 0; rr < 1000; rr++) { + mp_invmod(&b, &a, &c); + } + tt = rdtsc(); + mp_mulmod(&b, &c, &a, &d); + if (mp_cmp_d(&d, 1) != MP_EQ) { + printf("Failed to invert\n"); + return 0; + } + printf("Inverting mod %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)1000)); + mp_sqr(&a, &a); + mp_sqr(&b, &b); + } + + mp_read_radix(&a, "340282366920938463463374607431768211455", 10); + while (a.used * DIGIT_BIT < 8192) { + reset(); + for (rr = 0; rr < 1000000; rr++) { mp_mul(&a, &a, &b); } tt = rdtsc(); - printf("Multiplying %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)100000)); + printf("Multiplying %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)1000000)); mp_copy(&b, &a); } + + mp_read_radix(&a, "340282366920938463463374607431768211455", 10); while (a.used * DIGIT_BIT < 8192) { reset(); - for (rr = 0; rr < 100000; rr++) { + for (rr = 0; rr < 1000000; rr++) { mp_sqr(&a, &b); } tt = rdtsc(); - printf("Squaring %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)100000)); + printf("Squaring %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)1000000)); mp_copy(&b, &a); } @@ -117,7 +160,7 @@ int main(void) mp_mod(&b, &c, &b); mp_set(&c, 3); reset(); - for (rr = 0; rr < 35; rr++) { + for (rr = 0; rr < 50; rr++) { mp_exptmod(&c, &b, &a, &d); } tt = rdtsc(); @@ -130,7 +173,7 @@ int main(void) draw(&d); exit(0); } - printf("Exponentiating %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)35)); + printf("Exponentiating %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)50)); } } @@ -141,7 +184,7 @@ int main(void) printf("%7lu/%7lu/%7lu/%7lu/%7lu/%7lu/%7lu/%7lu/%7lu/%7lu/%7lu\r", add_n, sub_n, mul_n, div_n, sqr_n, mul2d_n, div2d_n, gcd_n, lcm_n, expt_n, inv_n); fgets(cmd, 4095, stdin); cmd[strlen(cmd)-1] = 0; - printf("%s ]\r",cmd); + printf("%s ]\r",cmd); fflush(stdout); if (!strcmp(cmd, "mul2d")) { ++mul2d_n; fgets(buf, 4095, stdin); mp_read_radix(&a, buf, 10); fgets(buf, 4095, stdin); sscanf(buf, "%d", &rr); @@ -173,7 +216,8 @@ int main(void) fgets(buf, 4095, stdin); mp_read_radix(&a, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&b, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&c, buf, 10); - mp_add(&a, &b, &d); + mp_copy(&a, &d); + mp_add(&d, &b, &d); if (mp_cmp(&c, &d) != MP_EQ) { printf("add %lu failure!\n", add_n); draw(&a);draw(&b);draw(&c);draw(&d); @@ -204,13 +248,13 @@ draw(&a);draw(&b);draw(&c);draw(&d); draw(&d); return 0; } - - + } else if (!strcmp(cmd, "sub")) { ++sub_n; fgets(buf, 4095, stdin); mp_read_radix(&a, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&b, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&c, buf, 10); - mp_sub(&a, &b, &d); + mp_copy(&a, &d); + mp_sub(&d, &b, &d); if (mp_cmp(&c, &d) != MP_EQ) { printf("sub %lu failure!\n", sub_n); draw(&a);draw(&b);draw(&c);draw(&d); @@ -220,7 +264,8 @@ draw(&a);draw(&b);draw(&c);draw(&d); fgets(buf, 4095, stdin); mp_read_radix(&a, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&b, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&c, buf, 10); - mp_mul(&a, &b, &d); + mp_copy(&a, &d); + mp_mul(&d, &b, &d); if (mp_cmp(&c, &d) != MP_EQ) { printf("mul %lu failure!\n", mul_n); draw(&a);draw(&b);draw(&c);draw(&d); @@ -242,7 +287,8 @@ draw(&a);draw(&b);draw(&c);draw(&d); draw(&e); draw(&f); } else if (!strcmp(cmd, "sqr")) { ++sqr_n; fgets(buf, 4095, stdin); mp_read_radix(&a, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&b, buf, 10); - mp_sqr(&a, &c); + mp_copy(&a, &c); + mp_sqr(&c, &c); if (mp_cmp(&b, &c) != MP_EQ) { printf("sqr %lu failure!\n", sqr_n); draw(&a);draw(&b);draw(&c); @@ -252,7 +298,8 @@ draw(&a);draw(&b);draw(&c); fgets(buf, 4095, stdin); mp_read_radix(&a, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&b, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&c, buf, 10); - mp_gcd(&a, &b, &d); + mp_copy(&a, &d); + mp_gcd(&d, &b, &d); d.sign = c.sign; if (mp_cmp(&c, &d) != MP_EQ) { printf("gcd %lu failure!\n", gcd_n); @@ -263,7 +310,8 @@ draw(&a);draw(&b);draw(&c);draw(&d); fgets(buf, 4095, stdin); mp_read_radix(&a, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&b, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&c, buf, 10); - mp_lcm(&a, &b, &d); + mp_copy(&a, &d); + mp_lcm(&d, &b, &d); d.sign = c.sign; if (mp_cmp(&c, &d) != MP_EQ) { printf("lcm %lu failure!\n", lcm_n); @@ -275,7 +323,8 @@ draw(&a);draw(&b);draw(&c);draw(&d); fgets(buf, 4095, stdin); mp_read_radix(&b, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&c, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&d, buf, 10); - mp_exptmod(&a, &b, &c, &e); + mp_copy(&a, &e); + mp_exptmod(&e, &b, &c, &e); if (mp_cmp(&d, &e) != MP_EQ) { printf("expt %lu failure!\n", expt_n); draw(&a);draw(&b);draw(&c);draw(&d); draw(&e); diff --git a/makefile b/makefile index 8ca82da..52e0735 100644 --- a/makefile +++ b/makefile @@ -1,7 +1,7 @@ CC = gcc -CFLAGS += -Wall -W -O3 -funroll-loops +CFLAGS += -DDEBUG -Wall -W -Os -VERSION=0.03 +VERSION=0.04 default: test diff --git a/mtest/mtest.c b/mtest/mtest.c index 393065f..a32e0e5 100644 --- a/mtest/mtest.c +++ b/mtest/mtest.c @@ -41,7 +41,7 @@ void rand_num(mp_int *a) unsigned char buf[512]; top: - size = 1 + (fgetc(rng) % 96); + size = 1 + ((fgetc(rng)*fgetc(rng)) % 32); buf[0] = (fgetc(rng)&1)?1:0; fread(buf+1, 1, size, rng); for (n = 0; n < size; n++) { @@ -57,7 +57,7 @@ void rand_num2(mp_int *a) unsigned char buf[512]; top: - size = 1 + (fgetc(rng) % 128); + size = 1 + ((fgetc(rng)*fgetc(rng)) % 32); buf[0] = (fgetc(rng)&1)?1:0; fread(buf+1, 1, size, rng); for (n = 0; n < size; n++) { @@ -196,7 +196,7 @@ int main(void) mp_todecimal(&c, buf); printf("%s\n", buf); } else if (n == 9) { - /* lcm test */ + /* exptmod test */ rand_num2(&a); rand_num2(&b); rand_num2(&c); @@ -216,8 +216,10 @@ int main(void) rand_num2(&a); rand_num2(&b); b.sign = MP_ZPOS; + a.sign = MP_ZPOS; mp_gcd(&a, &b, &c); if (mp_cmp_d(&c, 1) != 0) continue; + if (mp_cmp_d(&b, 1) == 0) continue; mp_invmod(&a, &b, &c); printf("invmod\n"); mp_todecimal(&a, buf);