use crate::{
arch::word::Word,
ibig::IBig,
math,
memory::{self, MemoryAllocation},
modular::{
modulo::{Modulo, ModuloLarge, ModuloRepr, ModuloSmall, ModuloSmallRaw},
modulo_ring::ModuloRingSmall,
},
primitive::{double_word, split_double_word, PrimitiveUnsigned, WORD_BITS, WORD_BITS_USIZE},
sign::Sign::*,
ubig::{Repr::*, UBig},
};
impl<'a> Modulo<'a> {
    /// Raises this value to the power `exp`.
    ///
    /// The work is delegated to the small or large implementation, depending on
    /// which representation `self.repr()` reports; the result is converted back
    /// into a [`Modulo`].
    #[inline]
    pub fn pow(&self, exp: &UBig) -> Modulo<'a> {
        match self.repr() {
            ModuloRepr::Small(small) => small.pow(exp).into(),
            ModuloRepr::Large(large) => large.pow(exp).into(),
        }
    }

    /// Raises this value to a signed power.
    ///
    /// When `exp.sign()` is `Positive` this is `self.pow(|exp|)`. When it is
    /// `Negative` the inverse of `self` is raised to `|exp|` instead.
    ///
    /// # Panics
    ///
    /// Panics if the exponent is negative and `self.inverse()` returns `None`.
    #[inline]
    pub fn pow_signed(&self, exp: &IBig) -> Modulo<'a> {
        let magnitude = exp.magnitude();
        match exp.sign() {
            Negative => {
                // A negative power is only defined for invertible values.
                if let Some(inverse) = self.inverse() {
                    inverse.pow(magnitude)
                } else {
                    panic!("Non-invertible Modulo taken to a negative power")
                }
            }
            Positive => self.pow(magnitude),
        }
    }
}
impl ModuloSmallRaw {
    /// Raises `self` to the single-word power `exp` within `ring`.
    ///
    /// An exponent of zero yields the ring element built from the word `1`.
    /// Otherwise the highest set bit of `exp` is accounted for by starting from
    /// `self`, and the remaining lower bits are handled by `pow_helper`.
    #[inline]
    pub(crate) const fn pow_word(self, exp: Word, ring: &ModuloRingSmall) -> ModuloSmallRaw {
        match exp {
            0 => ModuloSmallRaw::from_word(1, ring),
            _ => {
                // Index of the most significant set bit of `exp`.
                let top_bit = WORD_BITS - 1 - exp.leading_zeros();
                self.pow_helper(top_bit, self, exp, ring)
            }
        }
    }

    /// Square-and-multiply over the lowest `bits` bits of `exp`.
    ///
    /// Starting from `self`, each step squares the accumulator and, if the
    /// corresponding bit of `exp` (scanned from bit `bits - 1` down to bit 0)
    /// is set, multiplies by `base`. The result is therefore
    /// `self^(2^bits) * base^(exp mod 2^bits)`.
    #[inline]
    const fn pow_helper(
        self,
        mut bits: u32,
        base: ModuloSmallRaw,
        exp: Word,
        ring: &ModuloRingSmall,
    ) -> ModuloSmallRaw {
        let mut acc = self;
        loop {
            if bits == 0 {
                return acc;
            }
            bits -= 1;
            acc = acc.mul(acc, ring);
            // Fold in `base` when bit number `bits` of the exponent is set.
            if (exp >> bits) & 1 == 1 {
                acc = acc.mul(base, ring);
            }
        }
    }
}
impl<'a> ModuloSmall<'a> {
    /// Raises `self` to the power `exp`.
    ///
    /// Exponents 0, 1 and 2 are answered directly (one, a clone, and a single
    /// multiplication respectively); everything else goes through
    /// `pow_nontrivial`.
    #[inline]
    fn pow(&self, exp: &UBig) -> ModuloSmall<'a> {
        let ring = self.ring();
        match exp.repr() {
            // self^0 == 1
            Small(0) => {
                let one = UBig::from_word(1);
                ModuloSmall::from_ubig(&one, ring)
            }
            // self^1 == self
            Small(1) => self.clone(),
            // self^2 == self * self
            Small(2) => {
                let base = self.raw();
                ModuloSmall::new(base.mul(base, ring), ring)
            }
            _ => self.pow_nontrivial(exp),
        }
    }

    /// General exponentiation for `exp >= 3` (checked in debug builds only).
    ///
    /// The last element of `exp.as_words()` is treated as the most significant
    /// word: it is handled with `pow_word`, and every remaining word, walking
    /// towards index 0, is folded in with a full `WORD_BITS`-step `pow_helper`.
    fn pow_nontrivial(&self, exp: &UBig) -> ModuloSmall<'a> {
        debug_assert!(*exp >= UBig::from_word(3));

        let ring = self.ring();
        let base = self.raw();
        let words = exp.as_words();
        let top = words.len() - 1;

        let mut acc = base.pow_word(words[top], ring);
        for &word in words[..top].iter().rev() {
            // acc = acc^(2^WORD_BITS) * base^word
            acc = acc.pow_helper(WORD_BITS, base, word, ring);
        }
        ModuloSmall::new(acc, ring)
    }
}
impl<'a> ModuloLarge<'a> {
    /// Raises `self` to the power `exp`.
    ///
    /// Exponents 0 and 1 are answered directly; every other exponent
    /// (including 2) is handled by `pow_nontrivial`.
    fn pow(&self, exp: &UBig) -> ModuloLarge<'a> {
        match exp.repr() {
            // self^1 == self
            Small(1) => self.clone(),
            // self^0 == 1
            Small(0) => {
                let one = UBig::from_word(1);
                ModuloLarge::from_ubig(one, self.ring())
            }
            _ => self.pow_nontrivial(exp),
        }
    }

    /// Sliding-window exponentiation for `exp >= 2` (checked in debug builds only).
    ///
    /// A table of the odd powers `self^3, self^5, ..., self^(2^window_len - 1)`
    /// is precomputed in a single allocation that also provides the scratch
    /// space for multiplication. The exponent is then scanned from its second
    /// highest bit downwards: zero bits cost one squaring, and each set bit
    /// starts a window that is consumed with a run of squarings followed by a
    /// single multiplication by a table entry (or by `self` for a window of 1).
    fn pow_nontrivial(&self, exp: &UBig) -> ModuloLarge<'a> {
        debug_assert!(*exp >= UBig::from_word(2));

        let modulus_len = self.ring().normalized_modulus().len();
        let exp_bits = exp.bit_len();
        let window_len = ModuloLarge::choose_pow_window_len(exp_bits);

        // Number of table entries; `self^1` is not stored, the table starts at `self^3`.
        let table_entries = (1usize << (window_len - 1)) - 1;
        #[allow(clippy::redundant_closure)]
        let table_words = table_entries
            .checked_mul(modulus_len)
            .unwrap_or_else(|| memory::panic_out_of_memory());

        // One allocation holds both the table and the multiplication scratch space.
        let layout = memory::add_layout(
            memory::array_layout::<Word>(table_words),
            self.ring().mul_memory_requirement(),
        );
        let mut allocation = MemoryAllocation::new(layout);
        let mut arena = allocation.memory();
        let (odd_powers, mut scratch) = arena.allocate_slice_fill::<Word>(table_words, 0);

        // acc = self^2
        let mut acc = self.clone();
        acc.mul_in_place(self, &mut scratch);

        // Entry i-1 of the table is self^(2i+1) = self^(2i-1) * self^2.
        for i in 1..=table_entries {
            let (prev, cur) = if i == 1 {
                (self.normalized_value(), &mut odd_powers[0..modulus_len])
            } else {
                let (prev, cur) = (&mut odd_powers[(i - 2) * modulus_len..i * modulus_len])
                    .split_at_mut(modulus_len);
                (&*prev, cur)
            };
            cur.copy_from_slice(self.ring().mul_normalized(
                prev,
                acc.normalized_value(),
                &mut scratch,
            ));
        }

        let exp_words = exp.as_words();
        // `acc` already holds self^2, which covers the top exponent bit followed by
        // one squaring, so scanning starts at the second highest bit.
        let mut bit = exp_bits - 2;
        loop {
            let word_idx = bit / WORD_BITS_USIZE;
            let bit_idx = (bit % WORD_BITS_USIZE) as u32;
            let cur_word = exp_words[word_idx];

            if cur_word & (1 << bit_idx) != 0 {
                // The window may straddle into the next lower word; below word 0
                // it is padded with zeros.
                let next_word = match word_idx {
                    0 => 0,
                    _ => exp_words[word_idx - 1],
                };
                // Bring the `window_len` bits whose top bit is `bit` down to the
                // bottom of the double word, then mask them out.
                // NOTE(review): relies on `double_word(low, high)`,
                // `split_double_word -> (low, high)` and `math::ones` being a
                // low-bit mask; those helpers are defined elsewhere.
                let shift = bit_idx + 1 + WORD_BITS - window_len;
                let (mut window, _) =
                    split_double_word(double_word(next_word, cur_word) >> shift);
                window &= math::ones::<Word>(window_len);

                // Drop trailing zero bits so that the window value is odd.
                let num_bits = window_len - window.trailing_zeros();
                window >>= window_len - num_bits;

                // One squaring for `bit` has already happened; do the other num_bits - 1.
                for _ in 0..num_bits - 1 {
                    acc.square_in_place(&mut scratch);
                }
                bit -= (num_bits as usize) - 1;

                debug_assert!(window & 1 == 1);
                // window == 2 * entry_idx + 1; index 0 means `self` itself.
                let entry_idx = (window >> 1) as usize;
                let entry = match entry_idx.checked_sub(1) {
                    None => self.normalized_value(),
                    Some(table_idx) => {
                        &odd_powers[table_idx * modulus_len..entry_idx * modulus_len]
                    }
                };
                acc.mul_normalized_in_place(entry, &mut scratch);
            }

            if bit == 0 {
                break;
            }
            bit -= 1;
            acc.square_in_place(&mut scratch);
        }
        acc
    }

    /// Picks a window length for an `n`-bit exponent.
    ///
    /// The estimated cost of a window of length `w` is the table size
    /// `2^(w-1) - 1` plus `n / (w + 1)`. Starting from 1, the length is
    /// increased while doing so strictly lowers that estimate, and it always
    /// stays below `min(WORD_BITS, usize::BIT_SIZE)`.
    fn choose_pow_window_len(n: usize) -> u32 {
        let cost = |len: u32| -> usize { (1usize << (len - 1)) - 1 + n / (len as usize + 1) };
        let limit = WORD_BITS.min(usize::BIT_SIZE);

        let mut best_len = 1;
        let mut best_cost = cost(best_len);
        loop {
            let candidate = best_len + 1;
            if candidate >= limit {
                break;
            }
            let candidate_cost = cost(candidate);
            if best_cost <= candidate_cost {
                break;
            }
            best_len = candidate;
            best_cost = candidate_cost;
        }
        best_len
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `pow_word` in the ring of integers modulo 100 with base 17.
    #[test]
    fn test_pow_word() {
        let ring = ModuloRingSmall::new(100);
        let base = ModuloSmallRaw::from_word(17, &ring);

        // 17^0 == 1
        let zeroth = base.pow_word(0, &ring);
        assert_eq!(zeroth.residue(&ring), 1);

        // 17^15 == 93 (mod 100)
        let fifteenth = base.pow_word(15, &ring);
        assert_eq!(fifteenth.residue(&ring), 93);
    }
}