Line data Source code
1 : #include "tommath_private.h"
2 : #ifdef BN_MP_DIV_C
3 : /* LibTomMath, multiple-precision integer library -- Tom St Denis */
4 : /* SPDX-License-Identifier: Unlicense */
5 :
6 : #ifdef BN_MP_DIV_SMALL
7 :
8 : /* slower bit-bang division... also smaller */
9 : mp_err mp_div(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d)
10 : {
11 : mp_int ta, tb, tq, q;
12 : int n, n2;
13 : mp_err err;
14 :
15 : /* is divisor zero ? */
16 : if (MP_IS_ZERO(b)) {
17 : return MP_VAL;
18 : }
19 :
20 : /* if a < b then q=0, r = a */
21 : if (mp_cmp_mag(a, b) == MP_LT) {
22 : if (d != NULL) {
23 : err = mp_copy(a, d);
24 : } else {
25 : err = MP_OKAY;
26 : }
27 : if (c != NULL) {
28 : mp_zero(c);
29 : }
30 : return err;
31 : }
32 :
33 : /* init our temps */
34 : if ((err = mp_init_multi(&ta, &tb, &tq, &q, NULL)) != MP_OKAY) {
35 : return err;
36 : }
37 :
38 :
39 : mp_set(&tq, 1uL);
40 : n = mp_count_bits(a) - mp_count_bits(b);
41 : if ((err = mp_abs(a, &ta)) != MP_OKAY) goto LBL_ERR;
42 : if ((err = mp_abs(b, &tb)) != MP_OKAY) goto LBL_ERR;
43 : if ((err = mp_mul_2d(&tb, n, &tb)) != MP_OKAY) goto LBL_ERR;
44 : if ((err = mp_mul_2d(&tq, n, &tq)) != MP_OKAY) goto LBL_ERR;
45 :
46 : while (n-- >= 0) {
47 : if (mp_cmp(&tb, &ta) != MP_GT) {
48 : if ((err = mp_sub(&ta, &tb, &ta)) != MP_OKAY) goto LBL_ERR;
49 : if ((err = mp_add(&q, &tq, &q)) != MP_OKAY) goto LBL_ERR;
50 : }
51 : if ((err = mp_div_2d(&tb, 1, &tb, NULL)) != MP_OKAY) goto LBL_ERR;
52 : if ((err = mp_div_2d(&tq, 1, &tq, NULL)) != MP_OKAY) goto LBL_ERR;
53 : }
54 :
55 : /* now q == quotient and ta == remainder */
56 : n = a->sign;
57 : n2 = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
58 : if (c != NULL) {
59 : mp_exch(c, &q);
60 : c->sign = MP_IS_ZERO(c) ? MP_ZPOS : n2;
61 : }
62 : if (d != NULL) {
63 : mp_exch(d, &ta);
64 : d->sign = MP_IS_ZERO(d) ? MP_ZPOS : n;
65 : }
66 : LBL_ERR:
67 : mp_clear_multi(&ta, &tb, &tq, &q, NULL);
68 : return err;
69 : }
70 :
71 : #else
72 :
73 : /* integer signed division.
74 : * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
75 : * HAC pp.598 Algorithm 14.20
76 : *
77 : * Note that the description in HAC is horribly
78 : * incomplete. For example, it doesn't consider
79 : * the case where digits are removed from 'x' in
80 : * the inner loop. It also doesn't consider the
81 : * case that y has fewer than three digits, etc..
82 : *
83 : * The overall algorithm is as described as
84 : * 14.20 from HAC but fixed to treat these cases.
85 : */
86 1895 : mp_err mp_div(const mp_int *a, const mp_int *b, mp_int *c, mp_int *d)
87 : {
88 88 : mp_int q, x, y, t1, t2;
89 88 : int n, t, i, norm;
90 88 : mp_sign neg;
91 88 : mp_err err;
92 :
93 : /* is divisor zero ? */
94 1895 : if (MP_IS_ZERO(b)) {
95 0 : return MP_VAL;
96 : }
97 :
98 : /* if a < b then q=0, r = a */
99 1895 : if (mp_cmp_mag(a, b) == MP_LT) {
100 376 : if (d != NULL) {
101 376 : err = mp_copy(a, d);
102 : } else {
103 0 : err = MP_OKAY;
104 : }
105 376 : if (c != NULL) {
106 0 : mp_zero(c);
107 : }
108 376 : return err;
109 : }
110 :
111 1519 : if ((err = mp_init_size(&q, a->used + 2)) != MP_OKAY) {
112 0 : return err;
113 : }
114 1519 : q.used = a->used + 2;
115 :
116 1519 : if ((err = mp_init(&t1)) != MP_OKAY) goto LBL_Q;
117 :
118 1519 : if ((err = mp_init(&t2)) != MP_OKAY) goto LBL_T1;
119 :
120 1519 : if ((err = mp_init_copy(&x, a)) != MP_OKAY) goto LBL_T2;
121 :
122 1519 : if ((err = mp_init_copy(&y, b)) != MP_OKAY) goto LBL_X;
123 :
124 : /* fix the sign */
125 1519 : neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
126 1519 : x.sign = y.sign = MP_ZPOS;
127 :
128 : /* normalize both x and y, ensure that y >= b/2, [b == 2**MP_DIGIT_BIT] */
129 1519 : norm = mp_count_bits(&y) % MP_DIGIT_BIT;
130 1519 : if (norm < (MP_DIGIT_BIT - 1)) {
131 1519 : norm = (MP_DIGIT_BIT - 1) - norm;
132 1519 : if ((err = mp_mul_2d(&x, norm, &x)) != MP_OKAY) goto LBL_Y;
133 1519 : if ((err = mp_mul_2d(&y, norm, &y)) != MP_OKAY) goto LBL_Y;
134 : } else {
135 0 : norm = 0;
136 : }
137 :
138 : /* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
139 1519 : n = x.used - 1;
140 1519 : t = y.used - 1;
141 :
142 : /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
143 : /* y = y*b**{n-t} */
144 1519 : if ((err = mp_lshd(&y, n - t)) != MP_OKAY) goto LBL_Y;
145 :
146 1672 : while (mp_cmp(&x, &y) != MP_LT) {
147 153 : ++(q.dp[n - t]);
148 153 : if ((err = mp_sub(&x, &y, &x)) != MP_OKAY) goto LBL_Y;
149 : }
150 :
151 : /* reset y by shifting it back down */
152 1519 : mp_rshd(&y, n - t);
153 :
154 : /* step 3. for i from n down to (t + 1) */
155 76663 : for (i = n; i >= (t + 1); i--) {
156 75070 : if (i > x.used) {
157 0 : continue;
158 : }
159 :
160 : /* step 3.1 if xi == yt then set q{i-t-1} to b-1,
161 : * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
162 75070 : if (x.dp[i] == y.dp[t]) {
163 0 : q.dp[(i - t) - 1] = ((mp_digit)1 << (mp_digit)MP_DIGIT_BIT) - (mp_digit)1;
164 : } else {
165 3597 : mp_word tmp;
166 75070 : tmp = (mp_word)x.dp[i] << (mp_word)MP_DIGIT_BIT;
167 75070 : tmp |= (mp_word)x.dp[i - 1];
168 75070 : tmp /= (mp_word)y.dp[t];
169 75070 : if (tmp > (mp_word)MP_MASK) {
170 0 : tmp = MP_MASK;
171 : }
172 75070 : q.dp[(i - t) - 1] = (mp_digit)(tmp & (mp_word)MP_MASK);
173 : }
174 :
175 : /* while (q{i-t-1} * (yt * b + y{t-1})) >
176 : xi * b**2 + xi-1 * b + xi-2
177 :
178 : do q{i-t-1} -= 1;
179 : */
180 75070 : q.dp[(i - t) - 1] = (q.dp[(i - t) - 1] + 1uL) & (mp_digit)MP_MASK;
181 6009 : do {
182 119766 : q.dp[(i - t) - 1] = (q.dp[(i - t) - 1] - 1uL) & (mp_digit)MP_MASK;
183 :
184 : /* find left hand */
185 119766 : mp_zero(&t1);
186 119766 : t1.dp[0] = ((t - 1) < 0) ? 0u : y.dp[t - 1];
187 119766 : t1.dp[1] = y.dp[t];
188 119766 : t1.used = 2;
189 119766 : if ((err = mp_mul_d(&t1, q.dp[(i - t) - 1], &t1)) != MP_OKAY) goto LBL_Y;
190 :
191 : /* find right hand */
192 119766 : t2.dp[0] = ((i - 2) < 0) ? 0u : x.dp[i - 2];
193 119766 : t2.dp[1] = x.dp[i - 1]; /* i >= 1 always holds */
194 119766 : t2.dp[2] = x.dp[i];
195 119766 : t2.used = 3;
196 119766 : } while (mp_cmp_mag(&t1, &t2) == MP_GT);
197 :
198 : /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
199 75070 : if ((err = mp_mul_d(&y, q.dp[(i - t) - 1], &t1)) != MP_OKAY) goto LBL_Y;
200 :
201 75070 : if ((err = mp_lshd(&t1, (i - t) - 1)) != MP_OKAY) goto LBL_Y;
202 :
203 75070 : if ((err = mp_sub(&x, &t1, &x)) != MP_OKAY) goto LBL_Y;
204 :
205 : /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
206 75070 : if (x.sign == MP_NEG) {
207 0 : if ((err = mp_copy(&y, &t1)) != MP_OKAY) goto LBL_Y;
208 0 : if ((err = mp_lshd(&t1, (i - t) - 1)) != MP_OKAY) goto LBL_Y;
209 0 : if ((err = mp_add(&x, &t1, &x)) != MP_OKAY) goto LBL_Y;
210 :
211 0 : q.dp[(i - t) - 1] = (q.dp[(i - t) - 1] - 1uL) & MP_MASK;
212 : }
213 : }
214 :
215 : /* now q is the quotient and x is the remainder
216 : * [which we have to normalize]
217 : */
218 :
219 : /* get sign before writing to c */
220 1519 : x.sign = (x.used == 0) ? MP_ZPOS : a->sign;
221 :
222 1519 : if (c != NULL) {
223 0 : mp_clamp(&q);
224 0 : mp_exch(&q, c);
225 0 : c->sign = neg;
226 : }
227 :
228 1519 : if (d != NULL) {
229 1519 : if ((err = mp_div_2d(&x, norm, &x, NULL)) != MP_OKAY) goto LBL_Y;
230 1519 : mp_exch(&x, d);
231 : }
232 :
233 1445 : err = MP_OKAY;
234 :
235 1519 : LBL_Y:
236 1519 : mp_clear(&y);
237 1519 : LBL_X:
238 1519 : mp_clear(&x);
239 1519 : LBL_T2:
240 1519 : mp_clear(&t2);
241 1519 : LBL_T1:
242 1519 : mp_clear(&t1);
243 1519 : LBL_Q:
244 1519 : mp_clear(&q);
245 1519 : return err;
246 : }
247 :
248 : #endif
249 :
250 : #endif
|