diff --git a/bn.c b/bn.c index c0af1b7..6c4c646 100644 --- a/bn.c +++ b/bn.c @@ -113,11 +113,13 @@ int mp_set_int(mp_int *a, unsigned long b) if ((res = mp_grow(a, 32/DIGIT_BIT + 1)) != MP_OKAY) { return res; } + mp_zero(a); /* set four bits at a time, simplest solution to the what if DIGIT_BIT==7 case */ for (x = 0; x < 8; x++) { mp_mul_2d(a, 4, a); a->dp[0] |= (b>>28)&15; b <<= 4; + a->used += 32/DIGIT_BIT + 1; } mp_clamp(a); return MP_OKAY; @@ -140,8 +142,9 @@ int mp_copy(mp_int *a, mp_int *b) int res, n; /* if dst == src do nothing */ - if (a->dp == b->dp) + if (a == b || a->dp == b->dp) { return MP_OKAY; + } /* grow dest */ if ((res = mp_grow(b, a->used)) != MP_OKAY) { @@ -338,15 +341,22 @@ int mp_div_2d(mp_int *a, int b, mp_int *c, mp_int *d) { mp_digit D, r, rr; int x, res; + mp_int t; + + if ((res = mp_init(&t)) != MP_OKAY) { + return res; + } if (d != NULL) { - if ((res = mp_mod_2d(a, b, d)) != MP_OKAY) { + if ((res = mp_mod_2d(a, b, &t)) != MP_OKAY) { + mp_clear(&t); return res; } } /* copy */ if ((res = mp_copy(a, c)) != MP_OKAY) { + mp_clear(&t); return res; } @@ -364,6 +374,12 @@ int mp_div_2d(mp_int *a, int b, mp_int *c, mp_int *d) } } mp_clamp(c); + if (d != NULL) { + res = mp_copy(&t, d); + } else { + res = MP_OKAY; + } + mp_clear(&t); return MP_OKAY; } @@ -392,7 +408,7 @@ int mp_mul_2d(mp_int *a, int b, mp_int *c) d = (mp_digit)(b % DIGIT_BIT); if (d != 0) { r = 0; - for (x = 0; x < a->used; x++) { + for (x = 0; x < c->used; x++) { rr = (c->dp[x] >> (DIGIT_BIT - d)) & ((mp_digit)((1U<dp[x] = ((c->dp[x] << d) | r) & MP_MASK; r = rr; @@ -405,13 +421,49 @@ int mp_mul_2d(mp_int *a, int b, mp_int *c) /* b = a/2 */ int mp_div_2(mp_int *a, mp_int *b) { - return mp_div_2d(a, 1, b, NULL); + mp_digit r, rr; + int x, res; + + /* copy */ + if ((res = mp_copy(a, b)) != MP_OKAY) { + return res; + } + + r = 0; + for (x = b->used - 1; x >= 0; x--) { + rr = b->dp[x] & 1; + b->dp[x] = (b->dp[x] >> 1) | (r << (DIGIT_BIT-1)); + r = rr; + } + mp_clamp(b); + return MP_OKAY; } /* b = a*2 */ int mp_mul_2(mp_int *a, mp_int *b) { - return mp_mul_2d(a, 1, b); + mp_digit r, rr; + int x, res; + + /* copy */ + if ((res = mp_copy(a, b)) != MP_OKAY) { + return res; + } + + if ((res = mp_grow(b, b->used + 1)) != MP_OKAY) { + return res; + } + b->used = b->alloc; + + /* shift any bit count < DIGIT_BIT */ + r = 0; + for (x = 0; x < b->used; x++) { + rr = (b->dp[x] >> (DIGIT_BIT - 1)) & 1; + b->dp[x] = ((b->dp[x] << 1) | r) & MP_MASK; + r = rr; + } + mp_clamp(b); + return MP_OKAY; } /* low level addition */ @@ -526,8 +578,6 @@ static int fast_s_mp_mul_digs(mp_int *a, mp_int *b, mp_int *c, int digs) mp_word W[512], *_W; mp_digit tmpx, *tmpt, *tmpy; -// printf("\nHOLA\n"); - if ((res = mp_init_size(&t, digs)) != MP_OKAY) { return res; } @@ -624,7 +674,7 @@ static int fast_s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) pa = a->used; pb = b->used; - memset(W, 0, (pa + pb + 1) * sizeof(mp_word)); + memset(&W[digs], 0, (pa + pb + 1 - digs) * sizeof(mp_word)); for (ix = 0; ix < pa; ix++) { tmpx = a->dp[ix]; tmpt = &(t.dp[digs]); @@ -636,7 +686,7 @@ static int fast_s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) } /* now convert the array W downto what we need */ - for (ix = 1; ix < (pa+pb+1); ix++) { + for (ix = digs+1; ix < (pa+pb+1); ix++) { W[ix] = W[ix] + (W[ix-1] >> ((mp_word)DIGIT_BIT)); t.dp[ix-1] = W[ix-1] & ((mp_word)MP_MASK); } @@ -665,7 +715,7 @@ static int s_mp_mul_high_digs(mp_int *a, mp_int *b, mp_int *c, int digs) mp_digit tmpx, *tmpt, *tmpy; /* can we use the fast multiplier? */ - if ((digs < 512) && digs < (1<<( (CHAR_BIT*sizeof(mp_word)) - (2*DIGIT_BIT)))) { + if (((a->used + b->used + 1) < 512) && MAX(a->used, b->used) < (1<<( (CHAR_BIT*sizeof(mp_word)) - (2*DIGIT_BIT)))) { return fast_s_mp_mul_high_digs(a,b,c,digs); } @@ -959,13 +1009,14 @@ ERR : /* high level multiplication (handles sign) */ int mp_mul(mp_int *a, mp_int *b, mp_int *c) { - int res; + int res, neg; + neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG; if (MIN(a->used, b->used) > KARATSUBA_MUL_CUTOFF) { res = mp_karatsuba_mul(a, b, c); } else { res = s_mp_mul(a, b, c); } - c->sign = (a->sign == b->sign) ? MP_ZPOS : MP_NEG; + c->sign = neg; return res; } @@ -1047,13 +1098,17 @@ int mp_div(mp_int *a, mp_int *b, mp_int *c, mp_int *d) mp_int q, x, y, t1, t2; int res, n, t, i, norm, neg; + /* is divisor zero ? */ + if (mp_iszero(b) == 1) { + return MP_VAL; + } + /* if a < b then q=0, r = a */ if (mp_cmp_mag(a, b) == MP_LT) { if (d != NULL) { - res = mp_copy(a, d); - d->sign = a->sign; + res = mp_copy(a, d); } else { - res = MP_OKAY; + res = MP_OKAY; } if (c != NULL) { mp_zero(c); @@ -1182,6 +1237,8 @@ int mp_div(mp_int *a, mp_int *b, mp_int *c, mp_int *d) } /* now q is the quotient and x is the remainder [which we have to normalize] */ + /* get sign before writing to c */ + x.sign = a->sign; if (c != NULL) { mp_clamp(&q); mp_copy(&q, c); @@ -1189,7 +1246,6 @@ int mp_div(mp_int *a, mp_int *b, mp_int *c, mp_int *d) } if (d != NULL) { - x.sign = a->sign; mp_div_2d(&x, norm, &x, NULL); mp_clamp(&x); mp_copy(&x, d); @@ -1205,6 +1261,31 @@ __Q: mp_clear(&q); return res; } +/* c = a mod b, 0 <= c < b */ +int mp_mod(mp_int *a, mp_int *b, mp_int *c) +{ + mp_int t; + int res; + + if ((res = mp_init(&t)) != MP_OKAY) { + return res; + } + + if ((res = mp_div(a, b, NULL, &t)) != MP_OKAY) { + mp_clear(&t); + return res; + } + + if (t.sign == MP_NEG) { + res = mp_add(b, &t, c); + } else { + res = mp_copy(&t, c); + } + + mp_clear(&t); + return res; +} + /* single digit addition */ int mp_add_d(mp_int *a, mp_digit b, mp_int *c) { @@ -1259,6 +1340,7 @@ int mp_mul_d(mp_int *a, mp_digit b, mp_int *c) } t.dp[ix] = u; + t.sign = a->sign; mp_clamp(&t); if ((res = mp_copy(&t, c)) != MP_OKAY) { mp_clear(&t); @@ -1295,50 +1377,144 @@ int mp_div_d(mp_int *a, mp_digit b, mp_int *c, mp_digit *d) return res; } +int mp_mod_d(mp_int *a, mp_digit b, mp_digit *c) +{ + mp_int t, t2; + int res; + + if ((res = mp_init(&t)) != MP_OKAY) { + return res; + } + + if ((res = mp_init(&t2)) != MP_OKAY) { + mp_clear(&t); + return res; + } + + mp_set(&t, b); + mp_div(a, &t, NULL, &t2); + + if (t2.sign == MP_NEG) { + if ((res = mp_add_d(&t2, b, &t2)) != MP_OKAY) { + mp_clear(&t); + mp_clear(&t2); + return res; + } + } + *c = t2.dp[0]; + mp_clear(&t); + mp_clear(&t2); + return MP_OKAY; +} + +int mp_expt_d(mp_int *a, mp_digit b, mp_int *c) +{ + int res, x; + mp_int g; + + if ((res = mp_init_copy(&g, a)) != MP_OKAY) { + return res; + } + + /* set initial result */ + mp_set(c, 1); + + for (x = 0; x < (int)DIGIT_BIT; x++) { + if ((res = mp_sqr(c, c)) != MP_OKAY) { + mp_clear(&g); + return res; + } + + if (b & (mp_digit)(1<<(DIGIT_BIT-1))) { + if ((res = mp_mul(c, &g, c)) != MP_OKAY) { + mp_clear(&g); + return res; + } + } + + b <<= 1; + } + + mp_clear(&g); + return MP_OKAY; +} + /* simple modular functions */ /* d = a + b (mod c) */ int mp_addmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d) { int res; + mp_int t; - if ((res = mp_add(a, b, d)) != MP_OKAY) { + if ((res = mp_init(&t)) != MP_OKAY) { return res; } - return mp_mod(d, c, d); + + if ((res = mp_add(a, b, &t)) != MP_OKAY) { + mp_clear(&t); + return res; + } + res = mp_mod(&t, c, d); + mp_clear(&t); + return res; } /* d = a - b (mod c) */ int mp_submod(mp_int *a, mp_int *b, mp_int *c, mp_int *d) { int res; + mp_int t; - if ((res = mp_sub(a, b, d)) != MP_OKAY) { + if ((res = mp_init(&t)) != MP_OKAY) { return res; } - return mp_mod(d, c, d); + + if ((res = mp_sub(a, b, &t)) != MP_OKAY) { + mp_clear(&t); + return res; + } + res = mp_mod(&t, c, d); + mp_clear(&t); + return res; } /* d = a * b (mod c) */ int mp_mulmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d) { int res; + mp_int t; - if ((res = mp_mul(a, b, d)) != MP_OKAY) { + if ((res = mp_init(&t)) != MP_OKAY) { return res; } - return mp_mod(d, c, d); + + if ((res = mp_mul(a, b, &t)) != MP_OKAY) { + mp_clear(&t); + return res; + } + res = mp_mod(&t, c, d); + mp_clear(&t); + return res; } /* c = a * a (mod b) */ int mp_sqrmod(mp_int *a, mp_int *b, mp_int *c) { int res; + mp_int t; - if ((res = mp_sqr(a, c)) != MP_OKAY) { + if ((res = mp_init(&t)) != MP_OKAY) { return res; } - return mp_mod(c, b, c); + + if ((res = mp_sqr(a, &t)) != MP_OKAY) { + mp_clear(&t); + return res; + } + res = mp_mod(&t, b, c); + mp_clear(&t); + return res; } /* Greatest Common Divisor using the binary method [Algorithm B, page 338, vol2 of TAOCP] @@ -1462,107 +1638,184 @@ int mp_lcm(mp_int *a, mp_int *b, mp_int *c) return res; } -/* computes the modular inverse via extended euclidean algorithm, that is c = 1/a mod b */ +/* computes the modular inverse via binary extended euclidean algorithm, that is c = 1/a mod b */ int mp_invmod(mp_int *a, mp_int *b, mp_int *c) { - int res; - mp_int t1, t2, t3, u1, u2, u3, v1, v2, v3, q; + mp_int x, y, u, v, A, B, C, D; + int res, neg; - if ((res = mp_init(&t1)) != MP_OKAY) { - return res; + if ((res = mp_init(&x)) != MP_OKAY) { + goto __ERR; } - if ((res = mp_init(&t2)) != MP_OKAY) { - goto __T1; + if ((res = mp_init(&y)) != MP_OKAY) { + goto __X; } - if ((res = mp_init(&t3)) != MP_OKAY) { - goto __T2; - } - - if ((res = mp_init(&u1)) != MP_OKAY) { - goto __T3; + if ((res = mp_init(&u)) != MP_OKAY) { + goto __Y; } - if ((res = mp_init(&u2)) != MP_OKAY) { - goto __U1; + if ((res = mp_init(&v)) != MP_OKAY) { + goto __U; } - if ((res = mp_init(&u3)) != MP_OKAY) { - goto __U2; + if ((res = mp_init(&A)) != MP_OKAY) { + goto __V; } - if ((res = mp_init(&v1)) != MP_OKAY) { - goto __U3; - } - - if ((res = mp_init(&v2)) != MP_OKAY) { - goto __V1; - } - - if ((res = mp_init(&v3)) != MP_OKAY) { - goto __V2; + if ((res = mp_init(&B)) != MP_OKAY) { + goto __A; } - if ((res = mp_init(&q)) != MP_OKAY) { - goto __V3; - } - - /* (u1, u2, u3) = (1, 0, a) */ - mp_set(&u1, 1); - if ((res = mp_copy(a, &u3)) != MP_OKAY) { - goto __Q; + if ((res = mp_init(&C)) != MP_OKAY) { + goto __B; } - /* (v1, v2, v3) = (0, 1, b) */ - mp_set(&u2, 1); - if ((res = mp_copy(b, &v3)) != MP_OKAY) { - goto __Q; + if ((res = mp_init(&D)) != MP_OKAY) { + goto __C; } - while (mp_iszero(&v3) == 0) { - if ((res = mp_div(&u3, &v3, &q, NULL)) != MP_OKAY) { - goto __Q; - } - - /* (t1, t2, t3) = (u1, u2, u3) - q*(v1, v2, v3) */ - if ((res = mp_mul(&q, &v1, &t1)) != MP_OKAY) { goto __Q; } - if ((res = mp_sub(&u1, &t1, &t1)) != MP_OKAY) { goto __Q; } - if ((res = mp_mul(&q, &v2, &t2)) != MP_OKAY) { goto __Q; } - if ((res = mp_sub(&u2, &t2, &t2)) != MP_OKAY) { goto __Q; } - if ((res = mp_mul(&q, &v3, &t3)) != MP_OKAY) { goto __Q; } - if ((res = mp_sub(&u3, &t3, &t3)) != MP_OKAY) { goto __Q; } - - /* u = v */ - if ((res = mp_copy(&v1, &u1)) != MP_OKAY) { goto __Q; } - if ((res = mp_copy(&v2, &u2)) != MP_OKAY) { goto __Q; } - if ((res = mp_copy(&v3, &u3)) != MP_OKAY) { goto __Q; } - - /* v = t */ - if ((res = mp_copy(&t1, &v1)) != MP_OKAY) { goto __Q; } - if ((res = mp_copy(&t2, &v2)) != MP_OKAY) { goto __Q; } - if ((res = mp_copy(&t3, &v3)) != MP_OKAY) { goto __Q; } - } - - /* if u3 != 1, then there is no inverse */ - if (mp_cmp_d(&u3, 1) != MP_EQ) { + /* x = a, y = b */ + if ((res = mp_copy(a, &x)) != MP_OKAY) { + goto __D; + } + if ((res = mp_copy(b, &y)) != MP_OKAY) { + goto __D; + } + + if ((res = mp_abs(&x, &x)) != MP_OKAY) { + goto __D; + } + + /* 2. [modified] if x,y are both even then return an error! */ + if (mp_iseven(&x) == 1 && mp_iseven(&y) == 1) { res = MP_VAL; - goto __Q; + goto __D; } - - /* u1 is the inverse */ - res = mp_copy(&u1, c); -__Q : mp_clear(&q); -__V3: mp_clear(&v3); -__V2: mp_clear(&v1); -__V1: mp_clear(&v1); -__U3: mp_clear(&u3); -__U2: mp_clear(&u2); -__U1: mp_clear(&u1); -__T3: mp_clear(&t3); -__T2: mp_clear(&t2); -__T1: mp_clear(&t1); - return res; + + /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */ + if ((res = mp_copy(&x, &u)) != MP_OKAY) { + goto __D; + } + if ((res = mp_copy(&y, &v)) != MP_OKAY) { + goto __D; + } + mp_set(&A, 1); + mp_set(&D, 1); + + +top: + /* 4. while u is even do */ + while (mp_iseven(&u) == 1) { + /* 4.1 u = u/2 */ + if ((res = mp_div_2(&u, &u)) != MP_OKAY) { + goto __D; + } + /* 4.2 if A or B is odd then */ + if (mp_iseven(&A) == 0 || mp_iseven(&B) == 0) { + /* A = (A+y)/2, B = (B-x)/2 */ + if ((res = mp_add(&A, &y, &A)) != MP_OKAY) { + goto __D; + } + if ((res = mp_sub(&B, &x, &B)) != MP_OKAY) { + goto __D; + } + } + /* A = A/2, B = B/2 */ + if ((res = mp_div_2(&A, &A)) != MP_OKAY) { + goto __D; + } + if ((res = mp_div_2(&B, &B)) != MP_OKAY) { + goto __D; + } + } + + + /* 5. while v is even do */ + while (mp_iseven(&v) == 1) { + /* 5.1 v = v/2 */ + if ((res = mp_div_2(&v, &v)) != MP_OKAY) { + goto __D; + } + /* 5.2 if C,D are even then */ + if (mp_iseven(&C) == 0 || mp_iseven(&D) == 0) { + /* C = (C+y)/2, D = (D-x)/2 */ + if ((res = mp_add(&C, &y, &C)) != MP_OKAY) { + goto __D; + } + if ((res = mp_sub(&D, &x, &D)) != MP_OKAY) { + goto __D; + } + } + /* C = C/2, D = D/2 */ + if ((res = mp_div_2(&C, &C)) != MP_OKAY) { + goto __D; + } + if ((res = mp_div_2(&D, &D)) != MP_OKAY) { + goto __D; + } + } + + /* 6. if u >= v then */ + if (mp_cmp(&u, &v) != MP_LT) { + /* u = u - v, A = A - C, B = B - D */ + if ((res = mp_sub(&u, &v, &u)) != MP_OKAY) { + goto __D; + } + + if ((res = mp_sub(&A, &C, &A)) != MP_OKAY) { + goto __D; + } + + if ((res = mp_sub(&B, &D, &B)) != MP_OKAY) { + goto __D; + } + } else { + /* v - v - u, C = C - A, D = D - B */ + if ((res = mp_sub(&v, &u, &v)) != MP_OKAY) { + goto __D; + } + + if ((res = mp_sub(&C, &A, &C)) != MP_OKAY) { + goto __D; + } + + if ((res = mp_sub(&D, &B, &D)) != MP_OKAY) { + goto __D; + } + } + + /* if not zero goto step 4 */ + if (mp_iszero(&u) == 0) goto top; + + /* now a = C, b = D, gcd == g*v */ + + /* if v != 1 then there is no inverse */ + if (mp_cmp_d(&v, 1) != MP_EQ) { + res = MP_VAL; + goto __D; + } + + /* a is now the inverse */ + neg = a->sign; + if (C.sign == MP_NEG) { + res = mp_add(b, &C, c); + } else { + res = mp_copy(&C, c); + } + c->sign = neg; + +__D: mp_clear(&D); +__C: mp_clear(&C); +__B: mp_clear(&B); +__A: mp_clear(&A); +__V: mp_clear(&v); +__U: mp_clear(&u); +__Y: mp_clear(&y); +__X: mp_clear(&x); +__ERR: + return res; } /* pre-calculate the value required for Barrett reduction @@ -1838,7 +2091,7 @@ int mp_count_bits(mp_int *a) q = a->dp[a->used - 1]; while (q) { ++r; - q >>= 1UL; + q >>= ((mp_digit)1); } return r; } @@ -1846,13 +2099,14 @@ int mp_count_bits(mp_int *a) /* reads a unsigned char array, assumes the msb is stored first [big endian] */ int mp_read_unsigned_bin(mp_int *a, unsigned char *b, int c) { - int res; + int res, n; mp_zero(a); - a->used = (c/DIGIT_BIT) + ((c % DIGIT_BIT) != 0 ? 1: 0); + n = (c/DIGIT_BIT) + ((c % DIGIT_BIT) != 0 ? 1: 0); if ((res = mp_grow(a, a->used)) != MP_OKAY) { return res; } + a->used = n; while (c-- > 0) { if ((res = mp_mul_2d(a, 8, a)) != MP_OKAY) { return res; diff --git a/bn.h b/bn.h index 496f42d..54c8e7a 100644 --- a/bn.h +++ b/bn.h @@ -46,7 +46,9 @@ #define DIGIT_BIT ((CHAR_BIT * sizeof(mp_digit) - 1)) /* bits per digit */ #endif -#define MP_MASK ((((mp_digit)1)<<((mp_digit)DIGIT_BIT))-((mp_digit)1)) +#define MP_DIGIT_BIT DIGIT_BIT +#define MP_MASK ((((mp_digit)1)<<((mp_digit)DIGIT_BIT))-((mp_digit)1)) +#define MP_DIGIT_MAX MP_MASK /* equalities */ #define MP_LT -1 /* less than */ @@ -57,8 +59,9 @@ #define MP_NEG 1 /* negative */ #define MP_OKAY 0 /* ok result */ -#define MP_MEM 1 /* out of mem */ -#define MP_VAL 2 /* invalid input */ +#define MP_MEM -2 /* out of mem */ +#define MP_VAL -3 /* invalid input */ +#define MP_RANGE MP_VAL #define KARATSUBA_MUL_CUTOFF 80 /* Min. number of digits before Karatsuba multiplication is used. */ #define KARATSUBA_SQR_CUTOFF 80 /* Min. number of digits before Karatsuba squaring is used. */ @@ -68,6 +71,10 @@ typedef struct { mp_digit *dp; } mp_int; +#define USED(m) ((m)->used) +#define DIGIT(m,k) ((m)->dp[k]) +#define SIGN(m) ((m)->sign) + /* ---> init and deinit bignum functions <--- */ /* init a bignum */ @@ -155,8 +162,8 @@ int mp_sqr(mp_int *a, mp_int *b); /* a/b => cb + d == a */ int mp_div(mp_int *a, mp_int *b, mp_int *c, mp_int *d); -/* c == a mod b */ -#define mp_mod(a, b, c) mp_div(a, b, NULL, c) +/* c = a mod b, 0 <= c < b */ +int mp_mod(mp_int *a, mp_int *b, mp_int *c); /* ---> single digit functions <--- */ @@ -175,8 +182,11 @@ int mp_mul_d(mp_int *a, mp_digit b, mp_int *c); /* a/b => cb + d == a */ int mp_div_d(mp_int *a, mp_digit b, mp_int *c, mp_digit *d); -/* c = a mod b */ -#define mp_mod_d(a,b,c) mp_div_d(a, b, NULL, c) +/* c = a^b */ +int mp_expt_d(mp_int *a, mp_digit b, mp_int *c); + +/* c = a mod b, 0 <= c < b */ +int mp_mod_d(mp_int *a, mp_digit b, mp_digit *c); /* ---> number theory <--- */ diff --git a/bn.pdf b/bn.pdf index 43b9c7d..8f8f58d 100644 Binary files a/bn.pdf and b/bn.pdf differ diff --git a/bn.tex b/bn.tex index 7edc56b..0e651ce 100644 --- a/bn.tex +++ b/bn.tex @@ -1,7 +1,7 @@ \documentclass{article} \begin{document} -\title{LibTomMath v0.02 \\ A Free Multiple Precision Integer Library} +\title{LibTomMath v0.03 \\ A Free Multiple Precision Integer Library} \author{Tom St Denis \\ tomstdenis@iahu.ca} \maketitle \newpage @@ -82,6 +82,15 @@ used member refers to how many digits are actually used in the representation of to how many digits have been allocated off the heap. There is also the $\beta$ quantity which is equal to $2^W$ where $W$ is the number of bits in a digit (default is 28). +\subsection{Calling Functions} +Most functions expect pointers to mp\_int's as parameters. To save on memory usage it is possible to have source +variables as destinations. For example: +\begin{verbatim} + mp_add(&x, &y, &x); /* x = x + y */ + mp_mul(&x, &z, &x); /* x = x * z */ + mp_div_2(&x, &x); /* x = x / 2 */ +\end{verbatim} + \subsection{Basic Functionality} Essentially all LibTomMath functions return one of three values to indicate if the function worked as desired. A function will return \textbf{MP\_OKAY} if the function was successful. A function will return \textbf{MP\_MEM} if @@ -219,8 +228,8 @@ int mp_sqr(mp_int *a, mp_int *b); /* a/b => cb + d == a */ int mp_div(mp_int *a, mp_int *b, mp_int *c, mp_int *d); -/* c == a mod b */ -#define mp_mod(a, b, c) mp_div(a, b, NULL, c) +/* c = a mod b, 0 <= c < b */ +int mp_mod(mp_int *a, mp_int *b, mp_int *c); \end{verbatim} The \textbf{mp\_cmp} will compare two integers. It will return \textbf{MP\_LT} if the first parameter is less than @@ -251,8 +260,8 @@ int mp_mul_d(mp_int *a, mp_digit b, mp_int *c); /* a/b => cb + d == a */ int mp_div_d(mp_int *a, mp_digit b, mp_int *c, mp_digit *d); -/* c = a mod b */ -#define mp_mod_d(a,b,c) mp_div_d(a, b, NULL, c) +/* c = a mod b, 0 <= c < b */ +int mp_mod_d(mp_int *a, mp_digit b, mp_digit *c); \end{verbatim} Note that care should be taken for the value of the digit passed. By default, any 28-bit integer is a valid digit that can @@ -328,27 +337,27 @@ average. The following results were observed. \begin{tabular}{c|c|c|c} \hline \textbf{Operation} & \textbf{Size (bits)} & \textbf{Time with MPI (cycles)} & \textbf{Time with LibTomMath (cycles)} \\ \hline -Multiply & 128 & 1,394 & 893 \\ -Multiply & 256 & 2,559 & 1,744 \\ -Multiply & 512 & 7,919 & 4,484 \\ -Multiply & 1024 & 28,460 & 9,326, \\ -Multiply & 2048 & 109,637 & 30,140 \\ -Multiply & 4096 & 467,226 & 122,290 \\ +Multiply & 128 & 1,426 & 928 \\ +Multiply & 256 & 2,551 & 1,787 \\ +Multiply & 512 & 7,913 & 3,458 \\ +Multiply & 1024 & 28,496 & 9,271 \\ +Multiply & 2048 & 109,897 & 29,917 \\ +Multiply & 4096 & 469,970 & 123,934 \\ \hline -Square & 128 & 1,288 & 1,172 \\ -Square & 256 & 1,705 & 2,162 \\ -Square & 512 & 5,365 & 3,723 \\ -Square & 1024 & 18,836 & 9,063 \\ -Square & 2048 & 72,334 & 27,489 \\ -Square & 4096 & 306,252 & 110,372 \\ +Square & 128 & 1,319 & 1,230 \\ +Square & 256 & 1,776 & 2,131 \\ +Square & 512 & 5,399 & 3,694 \\ +Square & 1024 & 18,991 & 9,172 \\ +Square & 2048 & 72,126 & 27,352 \\ +Square & 4096 & 306,269 & 110,607 \\ \hline -Exptmod & 512 & 30,497,732 & 6,898,504 \\ -Exptmod & 768 & 98,943,020 & 15,510,779 \\ -Exptmod & 1024 & 221,123,749 & 27,962,904 \\ -Exptmod & 2048 & 1,694,796,907 & 146,631,975 \\ -Exptmod & 2560 & 3,262,360,107 & 305,530,060 \\ -Exptmod & 3072 & 5,647,243,373 & 472,572,762 \\ -Exptmod & 4096 & 13,345,194,048 & 984,415,240 +Exptmod & 512 & 32,021,586 & 6,880,075 \\ +Exptmod & 768 & 97,595,492 & 15,202,614 \\ +Exptmod & 1024 & 223,302,532 & 28,081,865 \\ +Exptmod & 2048 & 1,682,223,369 & 146,545,454 \\ +Exptmod & 2560 & 3,268,615,571 & 310,970,112 \\ +Exptmod & 3072 & 5,597,240,141 & 480,703,712 \\ +Exptmod & 4096 & 13,347,270,891 & 985,918,868 \end{tabular} \end{center} diff --git a/changes.txt b/changes.txt index fb6c798..ca7f537 100644 --- a/changes.txt +++ b/changes.txt @@ -1,3 +1,15 @@ +Dec 27th, 2002 +v0.03 -- Sped up s_mp_mul_high_digs by not computing the carries of the lower digits + -- Fixed a bug where mp_set_int wouldn't zero the value first and set the used member. + -- fixed a bug in s_mp_mul_high_digs where the limit placed on the result digits was not calculated properly + -- fixed bugs in add/sub/mul/sqr_mod functions where if the modulus and dest were the same it wouldn't work + -- fixed a bug in mp_mod and mp_mod_d concerning negative inputs + -- mp_mul_d didn't preserve sign + -- Many many many many fixes + -- Works in LibTomCrypt now :-) + -- Added iterations to the timing demos... more accurate. + -- Tom needs a job. + Dec 26th, 2002 v0.02 -- Fixed a few "slips" in the manual. This is "LibTomMath" afterall :-) -- Added mp_cmp_mag, mp_neg, mp_abs and mp_radix_size that were missing. diff --git a/demo.c b/demo.c index 7916091..cebec83 100644 --- a/demo.c +++ b/demo.c @@ -21,22 +21,37 @@ void reset(void) { _tt = clock(); } unsigned long long rdtsc(void) { return clock() - _tt; } #endif -static void draw(mp_int *a) +void draw(mp_int *a) { char buf[4096]; - int x; printf("a->used == %d\na->alloc == %d\na->sign == %d\n", a->used, a->alloc, a->sign); mp_toradix(a, buf, 10); printf("num == %s\n", buf); printf("\n"); } +unsigned long lfsr = 0xAAAAAAAAUL; + +int lbit(void) +{ + if (lfsr & 0x80000000UL) { + lfsr = ((lfsr << 1) ^ 0x8000001BUL) & 0xFFFFFFFFUL; + return 1; + } else { + lfsr <<= 1; + return 0; + } +} + + + int main(void) { mp_int a, b, c, d, e, f; - unsigned long expt_n, add_n, sub_n, mul_n, div_n, sqr_n, mul2d_n, div2d_n, gcd_n, lcm_n; + unsigned long expt_n, add_n, sub_n, mul_n, div_n, sqr_n, mul2d_n, div2d_n, gcd_n, lcm_n, inv_n; unsigned char cmd[4096], buf[4096]; int rr; + mp_digit tom; #ifdef TIMER int n; @@ -50,17 +65,21 @@ int main(void) mp_init(&e); mp_init(&f); + mp_read_radix(&a, "-2", 10); + mp_read_radix(&b, "2", 10); + mp_expt_d(&a, 3, &a); + draw(&a); #ifdef TIMER mp_read_radix(&a, "340282366920938463463374607431768211455", 10); while (a.used * DIGIT_BIT < 8192) { reset(); - for (rr = 0; rr < 10000; rr++) { + for (rr = 0; rr < 100000; rr++) { mp_mul(&a, &a, &b); } tt = rdtsc(); - printf("Multiplying %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)10000)); + printf("Multiplying %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)100000)); mp_copy(&b, &a); } @@ -68,11 +87,11 @@ int main(void) mp_read_radix(&a, "340282366920938463463374607431768211455", 10); while (a.used * DIGIT_BIT < 8192) { reset(); - for (rr = 0; rr < 10000; rr++) { + for (rr = 0; rr < 100000; rr++) { mp_sqr(&a, &b); } tt = rdtsc(); - printf("Squaring %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)10000)); + printf("Squaring %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)100000)); mp_copy(&b, &a); } @@ -87,19 +106,18 @@ int main(void) "1214855636816562637502584060163403830270705000634713483015101384881871978446801224798536155406895823305035467591632531067547890948695117172076954220727075688048751022421198712032848890056357845974246560748347918630050853933697792254955890439720297560693579400297062396904306270145886830719309296352765295712183040773146419022875165382778007040109957609739589875590885701126197906063620133954893216612678838507540777138437797705602453719559017633986486649523611975865005712371194067612263330335590526176087004421363598470302731349138773205901447704682181517904064735636518462452242791676541725292378925568296858010151852326316777511935037531017413910506921922450666933202278489024521263798482237150056835746454842662048692127173834433089016107854491097456725016327709663199738238442164843147132789153725513257167915555162094970853584447993125488607696008169807374736711297007473812256272245489405898470297178738029484459690836250560495461579533254473316340608217876781986188705928270735695752830825527963838355419762516246028680280988020401914551825487349990306976304093109384451438813251211051597392127491464898797406789175453067960072008590614886532333015881171367104445044718144312416815712216611576221546455968770801413440778423979", NULL }; - srand(time(NULL)); for (n = 0; primes[n]; n++) { mp_read_radix(&a, primes[n], 10); mp_zero(&b); for (rr = 0; rr < mp_count_bits(&a); rr++) { mp_mul_2d(&b, 1, &b); - b.dp[0] |= (rand()&1); + b.dp[0] |= lbit(); } mp_sub_d(&a, 1, &c); mp_mod(&b, &c, &b); mp_set(&c, 3); reset(); - for (rr = 0; rr < 20; rr++) { + for (rr = 0; rr < 35; rr++) { mp_exptmod(&c, &b, &a, &d); } tt = rdtsc(); @@ -112,15 +130,15 @@ int main(void) draw(&d); exit(0); } - printf("Exponentiating %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)20)); + printf("Exponentiating %d-bit took %llu cycles\n", mp_count_bits(&a), tt / ((unsigned long long)35)); } } #endif - expt_n = lcm_n = gcd_n = add_n = sub_n = mul_n = div_n = sqr_n = mul2d_n = div2d_n = 0; + inv_n = expt_n = lcm_n = gcd_n = add_n = sub_n = mul_n = div_n = sqr_n = mul2d_n = div2d_n = 0; for (;;) { - printf("add=%7lu sub=%7lu mul=%7lu div=%7lu sqr=%7lu mul2d=%7lu div2d=%7lu gcd=%7lu lcm=%7lu expt=%7lu\r", add_n, sub_n, mul_n, div_n, sqr_n, mul2d_n, div2d_n, gcd_n, lcm_n, expt_n); + printf("%7lu/%7lu/%7lu/%7lu/%7lu/%7lu/%7lu/%7lu/%7lu/%7lu/%7lu\r", add_n, sub_n, mul_n, div_n, sqr_n, mul2d_n, div2d_n, gcd_n, lcm_n, expt_n, inv_n); fgets(cmd, 4095, stdin); cmd[strlen(cmd)-1] = 0; printf("%s ]\r",cmd); @@ -161,6 +179,33 @@ int main(void) draw(&a);draw(&b);draw(&c);draw(&d); return 0; } + + /* test the sign/unsigned storage functions */ + + rr = mp_signed_bin_size(&c); + mp_to_signed_bin(&c, cmd); + memset(cmd+rr, rand()&255, sizeof(cmd)-rr); + mp_read_signed_bin(&d, cmd, rr); + if (mp_cmp(&c, &d) != MP_EQ) { + printf("mp_signed_bin failure!\n"); + draw(&c); + draw(&d); + return 0; + } + + + rr = mp_unsigned_bin_size(&c); + mp_to_unsigned_bin(&c, cmd); + memset(cmd+rr, rand()&255, sizeof(cmd)-rr); + mp_read_unsigned_bin(&d, cmd, rr); + if (mp_cmp_mag(&c, &d) != MP_EQ) { + printf("mp_unsigned_bin failure!\n"); + draw(&c); + draw(&d); + return 0; + } + + } else if (!strcmp(cmd, "sub")) { ++sub_n; fgets(buf, 4095, stdin); mp_read_radix(&a, buf, 10); fgets(buf, 4095, stdin); mp_read_radix(&b, buf, 10); @@ -210,7 +255,7 @@ draw(&a);draw(&b);draw(&c); mp_gcd(&a, &b, &d); d.sign = c.sign; if (mp_cmp(&c, &d) != MP_EQ) { - printf("gcd %lu failure!\n", sqr_n); + printf("gcd %lu failure!\n", gcd_n); draw(&a);draw(&b);draw(&c);draw(&d); return 0; } @@ -221,7 +266,7 @@ draw(&a);draw(&b);draw(&c);draw(&d); mp_lcm(&a, &b, &d); d.sign = c.sign; if (mp_cmp(&c, &d) != MP_EQ) { - printf("lcm %lu failure!\n", sqr_n); + printf("lcm %lu failure!\n", lcm_n); draw(&a);draw(&b);draw(&c);draw(&d); return 0; } @@ -232,11 +277,26 @@ draw(&a);draw(&b);draw(&c);draw(&d); fgets(buf, 4095, stdin); mp_read_radix(&d, buf, 10); mp_exptmod(&a, &b, &c, &e); if (mp_cmp(&d, &e) != MP_EQ) { - printf("expt %lu failure!\n", sqr_n); + printf("expt %lu failure!\n", expt_n); draw(&a);draw(&b);draw(&c);draw(&d); draw(&e); return 0; } + } else if (!strcmp(cmd, "invmod")) { ++inv_n; + fgets(buf, 4095, stdin); mp_read_radix(&a, buf, 10); + fgets(buf, 4095, stdin); mp_read_radix(&b, buf, 10); + fgets(buf, 4095, stdin); mp_read_radix(&c, buf, 10); + mp_invmod(&a, &b, &d); + mp_mulmod(&d,&a,&b,&e); + if (mp_cmp_d(&e, 1) != MP_EQ) { + printf("inv [wrong value from MPI?!] failure\n"); + draw(&a);draw(&b);draw(&c);draw(&d); + mp_gcd(&a, &b, &e); + draw(&e); + return 0; + } + } + } return 0; } diff --git a/makefile b/makefile index 369ff75..8ca82da 100644 --- a/makefile +++ b/makefile @@ -1,7 +1,7 @@ CC = gcc CFLAGS += -Wall -W -O3 -funroll-loops -VERSION=0.02 +VERSION=0.03 default: test diff --git a/mtest/mtest.c b/mtest/mtest.c index d9f919a..393065f 100644 --- a/mtest/mtest.c +++ b/mtest/mtest.c @@ -82,7 +82,7 @@ int main(void) rng = fopen("/dev/urandom", "rb"); for (;;) { - n = fgetc(rng) % 10; + n = fgetc(rng) % 11; if (n == 0) { /* add tests */ @@ -211,6 +211,21 @@ int main(void) printf("%s\n", buf); mp_todecimal(&d, buf); printf("%s\n", buf); + } else if (n == 10) { + /* invmod test */ + rand_num2(&a); + rand_num2(&b); + b.sign = MP_ZPOS; + mp_gcd(&a, &b, &c); + if (mp_cmp_d(&c, 1) != 0) continue; + mp_invmod(&a, &b, &c); + printf("invmod\n"); + mp_todecimal(&a, buf); + printf("%s\n", buf); + mp_todecimal(&b, buf); + printf("%s\n", buf); + mp_todecimal(&c, buf); + printf("%s\n", buf); } } fclose(rng);