ref: 79d9a3789b58d821eda59843f7b90211001765f1
dir: /3rd/mp/mplogic.c/
#include "platform.h"

/*
 * Bitwise logic on signed multi-precision integers.
 *
 * mpint values are stored as sign + magnitude, but the operations
 * below produce the result one would get from two's complement
 * operands.  The identity used throughout is that, for x < 0, the
 * two's complement bit pattern of x equals ~(|x| - 1); conversely a
 * negative result with bit pattern ~m has magnitude m + 1.
 */

/*
 * mplogic calculates b1|b2 subject to the following flag bits (fl)
 *
 *	bit 0: subtract 1 from b1
 *	bit 1: invert b1
 *	bit 2: subtract 1 from b2
 *	bit 3: invert b2
 *	bit 4: add 1 to output
 *	bit 5: invert output
 *
 * It inverts the appropriate bits automatically depending on the
 * signs of the inputs.  Operands flagged MPtimesafe are rejected by
 * assertion.
 */
static void
mplogic(mpint *b1, mpint *b2, mpint *sum, uint32_t fl)
{
	mpint *t;
	mpdigit *dp1, *dp2, *dpo, d1, d2, d;
	uint32_t c1, c2, co, i;

	assert(((b1->flags | b2->flags | sum->flags) & MPtimesafe) == 0);

	/* a negative operand is read as ~(magnitude - 1) */
	if(b1->sign < 0)
		fl ^= 0x03;
	if(b2->sign < 0)
		fl ^= 0x0c;

	/*
	 * The result is negative when (invert b1 | invert b2) ^ invert
	 * output is set, i.e. bit 1 of the folded flags.  Tested
	 * explicitly rather than by shifting the bit into the sign
	 * position of an int, which is not portable C.
	 */
	sum->sign = ((((fl | (fl >> 2)) ^ (fl >> 4)) & 2) != 0) ? -1 : 1;

	/* a negative result is emitted as magnitude = ~pattern + 1 */
	if(sum->sign < 0)
		fl ^= 0x30;

	/* make b1 the operand with more digits; swap its flag pair too */
	if(b2->top > b1->top){
		t = b1;
		b1 = b2;
		b2 = t;
		fl = ((fl >> 2) & 0x03) | ((fl << 2) & 0x0c) | (fl & 0x30);
	}

	/* one bit more than b1's digits: the final carry may need a digit */
	mpbits(sum, b1->top*Dbits+1);
	dp1 = b1->p;
	dp2 = b2->p;
	dpo = sum->p;

	c1 = fl & 1;		/* pending borrow for b1 - 1 */
	c2 = (fl >> 2) & 1;	/* pending borrow for b2 - 1 */
	co = (fl >> 4) & 1;	/* pending carry for output + 1 */

	for(i = 0; i < b1->top; i++){
		d1 = dp1[i] - c1;
		if(i < b2->top)
			d2 = dp2[i] - c2;
		else
			d2 = 0;

		/* a borrow keeps propagating only while the digit wrapped */
		if(d1 != (mpdigit)-1)
			c1 = 0;
		if(d2 != (mpdigit)-1)
			c2 = 0;

		if((fl & 2) != 0)
			d1 ^= -(mpdigit)1;
		if((fl & 8) != 0)
			d2 ^= -(mpdigit)1;

		d = d1 | d2;
		if((fl & 32) != 0)
			d ^= -(mpdigit)1;

		/* a carry keeps propagating only while the digit wrapped to 0 */
		d += co;
		if(d != 0)
			co = 0;
		dpo[i] = d;
	}
	sum->top = i;
	if(co)
		dpo[sum->top++] = co;
	mpnorm(sum);
}

/* sum = b1 | b2 */
void
mpor(mpint *b1, mpint *b2, mpint *sum)
{
	mplogic(b1, b2, sum, 0);
}

/* sum = b1 & b2, computed as ~(~b1 | ~b2) */
void
mpand(mpint *b1, mpint *b2, mpint *sum)
{
	mplogic(b1, b2, sum, 0x2a);
}

/* sum = b1 & ~b2 (bit clear), computed as ~(~b1 | b2) */
void
mpbic(mpint *b1, mpint *b2, mpint *sum)
{
	mplogic(b1, b2, sum, 0x22);
}

/* r = ~b, computed as -(b + 1) */
void
mpnot(mpint *b, mpint *r)
{
	mpadd(b, mpone, r);
	/* flip the sign between 1 and -1, leaving a zero result alone */
	if(r->top != 0)
		r->sign ^= -2;
}

/*
 * sum = b1 ^ b2
 *
 * Same digit loop as mplogic with | replaced by ^.  Because
 * ~x ^ y == ~(x ^ y), the inversions never have to be applied to the
 * digits; only the -1 on negative inputs and the +1 on a negative
 * output remain.  Operands flagged MPtimesafe are rejected by
 * assertion.
 */
void
mpxor(mpint *b1, mpint *b2, mpint *sum)
{
	mpint *t;
	mpdigit *dp1, *dp2, *dpo, d1, d2, d;
	uint32_t c1, c2, co, i, fl;

	assert(((b1->flags | b2->flags | sum->flags) & MPtimesafe) == 0);

	/* make b1 the operand with more digits */
	if(b2->top > b1->top){
		t = b1;
		b1 = b2;
		b2 = t;
	}

	/*
	 * With sign being 1 or -1 the masks yield 0 or the full mask:
	 *	bit 1: b1 negative (subtract 1 from b1)
	 *	bit 2: b2 negative (subtract 1 from b2)
	 *	bit 3: exactly one negative (negative result, add 1)
	 * fl is unsigned to match mplogic, and the sign is taken with an
	 * explicit test: shifting bit 3 of a signed int into the sign
	 * position is undefined behaviour.
	 */
	fl = (uint32_t)(b1->sign & 10) ^ (uint32_t)(b2->sign & 12);
	sum->sign = ((fl & 8) != 0) ? -1 : 1;

	/* one bit more than b1's digits: the final carry may need a digit */
	mpbits(sum, b1->top*Dbits+1);
	dp1 = b1->p;
	dp2 = b2->p;
	dpo = sum->p;

	c1 = (fl >> 1) & 1;	/* pending borrow for b1 - 1 */
	c2 = (fl >> 2) & 1;	/* pending borrow for b2 - 1 */
	co = (fl >> 3) & 1;	/* pending carry for output + 1 */

	for(i = 0; i < b1->top; i++){
		d1 = dp1[i] - c1;
		if(i < b2->top)
			d2 = dp2[i] - c2;
		else
			d2 = 0;

		/* a borrow keeps propagating only while the digit wrapped */
		if(d1 != (mpdigit)-1)
			c1 = 0;
		if(d2 != (mpdigit)-1)
			c2 = 0;

		/* a carry keeps propagating only while the digit wrapped to 0 */
		d = d1 ^ d2;
		d += co;
		if(d != 0)
			co = 0;
		dpo[i] = d;
	}
	sum->top = i;
	if(co)
		dpo[sum->top++] = co;
	mpnorm(sum);
}

/*
 * r = b shifted right by n bits.
 *
 * Positive values, and shift counts <= 0, are passed straight to
 * mpright.  For negative b with n > 0 the result is computed as
 * ((b + 1) >> n) - 1.
 */
void
mpasr(mpint *b, int n, mpint *r)
{
	if(b->sign > 0 || n <= 0){
		mpright(b, n, r);
		return;
	}
	mpadd(b, mpone, r);
	mpright(r, n, r);
	mpsub(r, mpone, r);
}