use bnum::BInt; use bnum::cast::As; use typenum::{Sum,Unsigned}; use crate::fixed::Fixed; use fixed_wide_traits::wide::WideMul; use std::marker::PhantomData; macro_rules! impl_wide_mul { ($lhs: expr,$rhs: expr) => { impl<A,B> WideMul<Fixed<$rhs,B>> for Fixed<$lhs,A> where A:std::ops::Add<B>, B:Unsigned, { type Output=Fixed<{$lhs+$rhs},Sum<A,B>>; fn wide_mul(self,rhs:Fixed<$rhs,B>)->Self::Output{ Fixed{ bits:self.bits.as_::<BInt<{$lhs+$rhs}>>()*rhs.bits.as_::<BInt<{$lhs+$rhs}>>(), frac:PhantomData, } } } }; } macro_rules! impl_wide_mul_all { ($(($x:expr, $y:expr)),*) => { $( impl_wide_mul!($x, $y); )* }; } //const generics sidestepped wahoo impl_wide_mul_all!( (1,1),(2,1),(3,1),(4,1),(5,1),(6,1),(7,1),(8,1), (1,2),(2,2),(3,2),(4,2),(5,2),(6,2),(7,2),(8,2), (1,3),(2,3),(3,3),(4,3),(5,3),(6,3),(7,3),(8,3), (1,4),(2,4),(3,4),(4,4),(5,4),(6,4),(7,4),(8,4), (1,5),(2,5),(3,5),(4,5),(5,5),(6,5),(7,5),(8,5), (1,6),(2,6),(3,6),(4,6),(5,6),(6,6),(7,6),(8,6), (1,7),(2,7),(3,7),(4,7),(5,7),(6,7),(7,7),(8,7), (1,8),(2,8),(3,8),(4,8),(5,8),(6,8),(7,8),(8,8) ); impl<const SRC:usize,Frac> Fixed<SRC,Frac>{ pub fn widen<const DST:usize>(self)->Fixed<DST,Frac>{ Fixed{ bits:self.bits.as_::<BInt<DST>>(), frac:PhantomData, } } } impl<const CHUNKS:usize,Frac:Unsigned> Fixed<CHUNKS,Frac> where Fixed::<CHUNKS,Frac>:WideMul, <Fixed::<CHUNKS,Frac> as WideMul>::Output:Ord, { pub fn sqrt_unchecked(self)->Self{ //pow2 must be the minimum power of two which when squared is greater than self //the algorithm: //1. count "used" bits to the left of the decimal //2. add one //This is the power of two which is greater than self. //3. divide by 2 via >>1 //4. add on fractional offset //Voila //0001.0000 Fixed<u8,4> //sqrt //0110.0000 //pow2 = 0100.0000 let mut pow2=Self{ bits:BInt::<CHUNKS>::ONE.shl(((((CHUNKS as i32*64-Frac::I32-(self.bits.leading_zeros() as i32)+1)>>1)+Frac::I32) as u32).saturating_sub(1)), frac:PhantomData, }; let mut result=pow2; //cheat to make the types match let wide_self=self.wide_mul(Fixed::<CHUNKS,Frac>::ONE); loop{ if pow2==Self::ZERO{ break result; } //TODO: flip a single bit instead of adding a power of 2 let new_result=result+pow2; //note that the implicit truncation in the multiply //means that the algorithm can return a result which squares to a number greater than the input. match wide_self.cmp(&new_result.wide_mul(new_result)){ core::cmp::Ordering::Less=>(), core::cmp::Ordering::Equal=>break new_result, core::cmp::Ordering::Greater=>result=new_result, } pow2>>=1; } } pub fn sqrt(self)->Self{ if self<Self::ZERO{ panic!("Square root less than zero") }else{ self.sqrt_unchecked() } } pub fn sqrt_checked(self)->Option<Self>{ if self<Self::ZERO{ None }else{ Some(self.sqrt_unchecked()) } } }