From 4182b78c7e1988349616a70d908e0185886f853f Mon Sep 17 00:00:00 2001 From: sheaf Date: Tue, 27 Apr 2021 01:11:43 +0200 Subject: [PATCH] fix incorrect deflation code --- src/lib/Math/Bezier/Cubic.hs | 2 +- src/lib/Math/Bezier/Cubic/Fit.hs | 2 +- src/lib/Math/Roots.hs | 157 ++++++++++++++++++++++--------- 3 files changed, 115 insertions(+), 46 deletions(-) diff --git a/src/lib/Math/Bezier/Cubic.hs b/src/lib/Math/Bezier/Cubic.hs index d7185e2..2942c6d 100644 --- a/src/lib/Math/Bezier/Cubic.hs +++ b/src/lib/Math/Bezier/Cubic.hs @@ -182,7 +182,7 @@ closestPoint closestPoint pts c = pickClosest ( 0 :| 1 : roots ) -- todo: also include the self-intersection point if one exists where roots :: [ r ] - roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 2000 $ ddist @v pts c ) + roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 500 $ ddist @v pts c ) pickClosest :: NonEmpty r -> ArgMin r ( r, p ) pickClosest ( s :| ss ) = go s q nm0 ss diff --git a/src/lib/Math/Bezier/Cubic/Fit.hs b/src/lib/Math/Bezier/Cubic/Fit.hs index 8a21d88..133997f 100644 --- a/src/lib/Math/Bezier/Cubic/Fit.hs +++ b/src/lib/Math/Bezier/Cubic/Fit.hs @@ -262,7 +262,7 @@ fitPiece dist_tol t_tol maxIters p tp qs r tr = let laguerreStepResult :: Complex Double laguerreStepResult = runST do - coeffs <- unsafeThawPrimArray . primArrayFromListN 6 . map (:+ 0) + coeffs <- unsafeThawPrimArray . primArrayFromListN 6 $ Cubic.ddist @( Vector2D Double ) bez q laguerre epsilon 1 coeffs ( ti :+ 0 ) ti' <- case laguerreStepResult of diff --git a/src/lib/Math/Roots.hs b/src/lib/Math/Roots.hs index 08a0d7b..ad7e82c 100644 --- a/src/lib/Math/Roots.hs +++ b/src/lib/Math/Roots.hs @@ -10,11 +10,11 @@ module Math.Roots where -- base import Control.Monad - ( unless ) + ( unless, when ) import Control.Monad.ST ( ST, runST ) import Data.Complex - ( Complex(..), magnitude ) + ( Complex(..), conjugate, magnitude, realPart, imagPart ) import Data.Maybe ( mapMaybe ) @@ -68,71 +68,139 @@ solveQuadratic a0 a1 a2 disc :: a disc = a1 * a1 - 4 * a0 * a2 --- | Find real roots of a polynomial. +-- | Find real roots of a polynomial with real coefficients. -- -- Coefficients are given in order of decreasing degree, e.g.: -- x² + 7 is given by [ 1, 0, 7 ]. realRoots :: forall a. ( RealFloat a, Prim a, NFData a ) => Int -> [ a ] -> [ a ] -realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0) coeffs ) ) +realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters coeffs ) where isReal :: Complex a -> Maybe a isReal ( a :+ b ) | abs b < epsilon = Just a | otherwise = Nothing --- | Compute all roots of a polynomial using Laguerre's method and (forward) deflation. +-- | Compute all roots of a polynomial with real coefficients using Laguerre's method and (forward) deflation. -- -- Polynomial coefficients are given in order of descending degree (e.g. constant coefficient last). -- -- N.B. The forward deflation process is only guaranteed to be numerically stable -- if Laguerre's method finds roots in increasing order of magnitude. -roots :: forall a. ( RealFloat a, Prim a, NFData a ) => a -> Int -> [ Complex a ] -> [ Complex a ] +roots :: forall a. ( RealFloat a, Prim a, NFData a ) => a -> Int -> [ a ] -> [ Complex a ] roots eps maxIters coeffs = runST do let - coeffPrimArray :: PrimArray ( Complex a ) + coeffPrimArray :: PrimArray a coeffPrimArray = primArrayFromList coeffs sz :: Int sz = sizeofPrimArray coeffPrimArray - ( p :: MutablePrimArray s ( Complex a ) ) <- unsafeThawPrimArray coeffPrimArray + ( p :: MutablePrimArray s a ) <- unsafeThawPrimArray coeffPrimArray let go :: Int -> [ Complex a ] -> ST s [ Complex a ] go !i rs = do - !r <- force <$> laguerre eps maxIters p 0 -- Start Laguerre's method at 0 for best chance of numerical stability. + -- Estimate the root with minimum absolute value in order to + -- improve numerical stability of Laguerre's method with forward deflation. + !z0 <- minAbsStartPoint p + !r <- force <$> laguerre eps maxIters p z0 if i <= 2 then pure ( r : rs ) - else do - deflate r p - go ( i - 1 ) ( r : rs ) + else + -- real root + if abs ( imagPart r ) < epsilon + then do + deflate ( realPart r ) p + go ( i - 1 ) ( r : rs ) + else do + deflateConjugatePair r p + go ( i - 2 ) ( r : conjugate r : rs ) go sz [] --- | Deflate a polynomial: factor out a root of the polynomial. --- --- The polynomial must have degree at least 2. +-- | Estimate the root with smallest absolute value. +-- +-- Polynomial coefficients are given in order of descending degree (e.g. constant coefficient last). +minAbsStartPoint :: forall a m s. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m ) => MutablePrimArray s a -> m ( Complex a ) +minAbsStartPoint p = do + n <- subtract 1 <$> getSizeofMutablePrimArray p + an <- readPrimArray p n + if abs an < epsilon + then pure 0 + else do + a0 <- readPrimArray p 0 + let + r, m0 :: Complex a + r = log ( abs an :+ 0 ) + m0 = exp $ ( r - log ( abs a0 :+ 0 ) ) / fromIntegral n + go :: Int -> Complex a -> m ( Complex a ) + go i m + | i >= n + = pure m + | otherwise + = do + ai <- readPrimArray p i + if abs ai < epsilon + then go ( i + 1 ) m + else do + let + mi :: Complex a + mi = exp $ ( r - log ( abs ai :+ 0 ) ) / fromIntegral ( n - i ) + if magnitude mi < magnitude m + then go ( i + 1 ) mi + else go ( i + 1 ) m + m <- go 1 m0 + pure ( 0.5 * m ) + +-- | Forward deflation of a polynomial by a root: factors out the root. deflate :: forall a m s. ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) => a -> MutablePrimArray s a -> m () deflate r p = do deg <- subtract 1 <$> getSizeofMutablePrimArray p - case compare deg 2 of - LT -> pure () - EQ -> shrinkMutablePrimArray p deg - GT -> do - shrinkMutablePrimArray p deg - let - go :: a -> Int -> m () - go !b !i = unless ( i >= deg ) do - ai <- readPrimArray p i - writePrimArray p i ( ai + r * b ) - go ai ( i + 1 ) - a0 <- readPrimArray p 0 - go a0 1 + when ( deg >= 2 ) do + shrinkMutablePrimArray p deg + let + go :: a -> Int -> m () + go !b !i = unless ( i >= deg ) do + ai <- readPrimArray p i + let + bi :: a + !bi = ai + r * b + writePrimArray p i bi + go bi ( i + 1 ) + a0 <- readPrimArray p 0 + go a0 1 + +-- | Forward deflation of a polynomial with real coefficients by a pair of complex-conjugate roots. +deflateConjugatePair :: forall a m s. ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) => Complex a -> MutablePrimArray s a -> m () +deflateConjugatePair ( x :+ y ) p = do + deg <- subtract 1 <$> getSizeofMutablePrimArray p + when ( deg >= 3 ) do + shrinkMutablePrimArray p ( deg - 1 ) + let + c1, c2 :: a + !c1 = 2 * x + !c2 = x * x + y * y + a0 <- readPrimArray p 0 + a1 <- readPrimArray p 1 + let + b1 :: a + !b1 = a1 + c1 * a0 + writePrimArray p 1 b1 + let + go :: a -> a -> Int -> m () + go !b !b' !i = unless ( i >= deg - 1 ) do + ai <- readPrimArray p i + let + bi :: a + !bi = ai + c1 * b - c2 * b' + writePrimArray p i bi + go bi b ( i + 1 ) + go b1 a0 2 -- | Laguerre's method. laguerre :: forall a m s . ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m ) - => a -- ^ error tolerance - -> Int -- ^ max number of iterations - -> MutablePrimArray s ( Complex a ) -- ^ polynomial - -> Complex a -- ^ initial point + => a -- ^ error tolerance + -> Int -- ^ max number of iterations + -> MutablePrimArray s a -- ^ polynomial + -> Complex a -- ^ initial point -> m ( Complex a ) laguerre eps maxIters p x0 = do p' <- derivative p @@ -150,11 +218,11 @@ laguerre eps maxIters p x0 = do laguerreStep :: forall a m s . ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m ) - => a -- ^ error tolerance - -> MutablePrimArray s ( Complex a ) -- ^ polynomial - -> MutablePrimArray s ( Complex a ) -- ^ first derivative of polynomial - -> MutablePrimArray s ( Complex a ) -- ^ second derivative of polynomial - -> Complex a -- ^ initial point + => a -- ^ error tolerance + -> MutablePrimArray s a -- ^ polynomial + -> MutablePrimArray s a -- ^ first derivative of polynomial + -> MutablePrimArray s a -- ^ second derivative of polynomial + -> Complex a -- ^ initial point -> m ( Complex a ) laguerreStep eps p p' p'' x = do n <- fromIntegral @_ @a <$> getSizeofMutablePrimArray p @@ -168,7 +236,8 @@ laguerreStep eps p p' p'' x = do g = p'x / px g² = g * g h = g² - p''x / px - delta = sqrt $ ( n - 1 ) *: ( n *: h - g² ) + mult = 1 --max 1 . min n . fromIntegral . round @_ @Integer . realPart $ log p'x / log ( px / p'x ) + delta = sqrt $ ( n - 1 ) *: ( ( n / mult ) *: h - g² ) gp = g + delta gm = g - delta denom @@ -182,24 +251,24 @@ laguerreStep eps p p' p'' x = do (*:) :: a -> Complex a -> Complex a r *: (u :+ v) = ( r * u ) :+ ( r * v ) --- | Evaluate a polynomial. +-- | Evaluate a polynomial with real coefficients at a complex number. eval :: forall a m s - . ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) - => MutablePrimArray s a -> a -> m a + . ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m ) + => MutablePrimArray s a -> Complex a -> m ( Complex a ) eval p x = do n <- getSizeofMutablePrimArray p let - go :: a -> Int -> m a + go :: Complex a -> Int -> m ( Complex a ) go !a !i | i >= n = pure a | otherwise = do b <- readPrimArray p i - go ( b + x * a ) ( i + 1 ) + go ( ( b :+ 0 ) + x * a ) ( i + 1 ) an <- readPrimArray p 0 - go an 1 + go ( an :+ 0 ) 1 -- | Derivative of a polynomial. derivative