fix incorrect deflation code

This commit is contained in:
sheaf 2021-04-27 01:11:43 +02:00
parent d604d4120e
commit 4182b78c7e
3 changed files with 115 additions and 46 deletions

View file

@ -182,7 +182,7 @@ closestPoint
closestPoint pts c = pickClosest ( 0 :| 1 : roots ) -- todo: also include the self-intersection point if one exists
where
roots :: [ r ]
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 2000 $ ddist @v pts c )
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 500 $ ddist @v pts c )
pickClosest :: NonEmpty r -> ArgMin r ( r, p )
pickClosest ( s :| ss ) = go s q nm0 ss

View file

@ -262,7 +262,7 @@ fitPiece dist_tol t_tol maxIters p tp qs r tr =
let
laguerreStepResult :: Complex Double
laguerreStepResult = runST do
coeffs <- unsafeThawPrimArray . primArrayFromListN 6 . map (:+ 0)
coeffs <- unsafeThawPrimArray . primArrayFromListN 6
$ Cubic.ddist @( Vector2D Double ) bez q
laguerre epsilon 1 coeffs ( ti :+ 0 )
ti' <- case laguerreStepResult of

View file

@ -10,11 +10,11 @@ module Math.Roots where
-- base
import Control.Monad
( unless )
( unless, when )
import Control.Monad.ST
( ST, runST )
import Data.Complex
( Complex(..), magnitude )
( Complex(..), conjugate, magnitude, realPart, imagPart )
import Data.Maybe
( mapMaybe )
@ -68,71 +68,139 @@ solveQuadratic a0 a1 a2
disc :: a
disc = a1 * a1 - 4 * a0 * a2
-- | Find real roots of a polynomial.
-- | Find real roots of a polynomial with real coefficients.
--
-- Coefficients are given in order of decreasing degree, e.g.:
-- x² + 7 is given by [ 1, 0, 7 ].
realRoots :: forall a. ( RealFloat a, Prim a, NFData a ) => Int -> [ a ] -> [ a ]
realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0) coeffs ) )
realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters coeffs )
where
isReal :: Complex a -> Maybe a
isReal ( a :+ b )
| abs b < epsilon = Just a
| otherwise = Nothing
-- | Compute all roots of a polynomial using Laguerre's method and (forward) deflation.
-- | Compute all roots of a polynomial with real coefficients using Laguerre's method and (forward) deflation.
--
-- Polynomial coefficients are given in order of descending degree (e.g. constant coefficient last).
--
-- N.B. The forward deflation process is only guaranteed to be numerically stable
-- if Laguerre's method finds roots in increasing order of magnitude.
roots :: forall a. ( RealFloat a, Prim a, NFData a ) => a -> Int -> [ Complex a ] -> [ Complex a ]
roots :: forall a. ( RealFloat a, Prim a, NFData a ) => a -> Int -> [ a ] -> [ Complex a ]
roots eps maxIters coeffs = runST do
let
coeffPrimArray :: PrimArray ( Complex a )
coeffPrimArray :: PrimArray a
coeffPrimArray = primArrayFromList coeffs
sz :: Int
sz = sizeofPrimArray coeffPrimArray
( p :: MutablePrimArray s ( Complex a ) ) <- unsafeThawPrimArray coeffPrimArray
( p :: MutablePrimArray s a ) <- unsafeThawPrimArray coeffPrimArray
let
go :: Int -> [ Complex a ] -> ST s [ Complex a ]
go !i rs = do
!r <- force <$> laguerre eps maxIters p 0 -- Start Laguerre's method at 0 for best chance of numerical stability.
-- Estimate the root with minimum absolute value in order to
-- improve numerical stability of Laguerre's method with forward deflation.
!z0 <- minAbsStartPoint p
!r <- force <$> laguerre eps maxIters p z0
if i <= 2
then pure ( r : rs )
else do
deflate r p
go ( i - 1 ) ( r : rs )
else
-- real root
if abs ( imagPart r ) < epsilon
then do
deflate ( realPart r ) p
go ( i - 1 ) ( r : rs )
else do
deflateConjugatePair r p
go ( i - 2 ) ( r : conjugate r : rs )
go sz []
-- | Deflate a polynomial: factor out a root of the polynomial.
--
-- The polynomial must have degree at least 2.
-- | Estimate the root with smallest absolute value.
--
-- Polynomial coefficients are given in order of descending degree (e.g. constant coefficient last).
minAbsStartPoint :: forall a m s. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m ) => MutablePrimArray s a -> m ( Complex a )
minAbsStartPoint p = do
n <- subtract 1 <$> getSizeofMutablePrimArray p
an <- readPrimArray p n
if abs an < epsilon
then pure 0
else do
a0 <- readPrimArray p 0
let
r, m0 :: Complex a
r = log ( abs an :+ 0 )
m0 = exp $ ( r - log ( abs a0 :+ 0 ) ) / fromIntegral n
go :: Int -> Complex a -> m ( Complex a )
go i m
| i >= n
= pure m
| otherwise
= do
ai <- readPrimArray p i
if abs ai < epsilon
then go ( i + 1 ) m
else do
let
mi :: Complex a
mi = exp $ ( r - log ( abs ai :+ 0 ) ) / fromIntegral ( n - i )
if magnitude mi < magnitude m
then go ( i + 1 ) mi
else go ( i + 1 ) m
m <- go 1 m0
pure ( 0.5 * m )
-- | Forward deflation of a polynomial by a root: factors out the root.
deflate :: forall a m s. ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) => a -> MutablePrimArray s a -> m ()
deflate r p = do
deg <- subtract 1 <$> getSizeofMutablePrimArray p
case compare deg 2 of
LT -> pure ()
EQ -> shrinkMutablePrimArray p deg
GT -> do
shrinkMutablePrimArray p deg
let
go :: a -> Int -> m ()
go !b !i = unless ( i >= deg ) do
ai <- readPrimArray p i
writePrimArray p i ( ai + r * b )
go ai ( i + 1 )
a0 <- readPrimArray p 0
go a0 1
when ( deg >= 2 ) do
shrinkMutablePrimArray p deg
let
go :: a -> Int -> m ()
go !b !i = unless ( i >= deg ) do
ai <- readPrimArray p i
let
bi :: a
!bi = ai + r * b
writePrimArray p i bi
go bi ( i + 1 )
a0 <- readPrimArray p 0
go a0 1
-- | Forward deflation of a polynomial with real coefficients by a pair of complex-conjugate roots.
deflateConjugatePair :: forall a m s. ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) => Complex a -> MutablePrimArray s a -> m ()
deflateConjugatePair ( x :+ y ) p = do
deg <- subtract 1 <$> getSizeofMutablePrimArray p
when ( deg >= 3 ) do
shrinkMutablePrimArray p ( deg - 1 )
let
c1, c2 :: a
!c1 = 2 * x
!c2 = x * x + y * y
a0 <- readPrimArray p 0
a1 <- readPrimArray p 1
let
b1 :: a
!b1 = a1 + c1 * a0
writePrimArray p 1 b1
let
go :: a -> a -> Int -> m ()
go !b !b' !i = unless ( i >= deg - 1 ) do
ai <- readPrimArray p i
let
bi :: a
!bi = ai + c1 * b - c2 * b'
writePrimArray p i bi
go bi b ( i + 1 )
go b1 a0 2
-- | Laguerre's method.
laguerre
:: forall a m s
. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m )
=> a -- ^ error tolerance
-> Int -- ^ max number of iterations
-> MutablePrimArray s ( Complex a ) -- ^ polynomial
-> Complex a -- ^ initial point
=> a -- ^ error tolerance
-> Int -- ^ max number of iterations
-> MutablePrimArray s a -- ^ polynomial
-> Complex a -- ^ initial point
-> m ( Complex a )
laguerre eps maxIters p x0 = do
p' <- derivative p
@ -150,11 +218,11 @@ laguerre eps maxIters p x0 = do
laguerreStep
:: forall a m s
. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m )
=> a -- ^ error tolerance
-> MutablePrimArray s ( Complex a ) -- ^ polynomial
-> MutablePrimArray s ( Complex a ) -- ^ first derivative of polynomial
-> MutablePrimArray s ( Complex a ) -- ^ second derivative of polynomial
-> Complex a -- ^ initial point
=> a -- ^ error tolerance
-> MutablePrimArray s a -- ^ polynomial
-> MutablePrimArray s a -- ^ first derivative of polynomial
-> MutablePrimArray s a -- ^ second derivative of polynomial
-> Complex a -- ^ initial point
-> m ( Complex a )
laguerreStep eps p p' p'' x = do
n <- fromIntegral @_ @a <$> getSizeofMutablePrimArray p
@ -168,7 +236,8 @@ laguerreStep eps p p' p'' x = do
g = p'x / px
g² = g * g
h = g² - p''x / px
delta = sqrt $ ( n - 1 ) *: ( n *: h - g² )
mult = 1 --max 1 . min n . fromIntegral . round @_ @Integer . realPart $ log p'x / log ( px / p'x )
delta = sqrt $ ( n - 1 ) *: ( ( n / mult ) *: h - g² )
gp = g + delta
gm = g - delta
denom
@ -182,24 +251,24 @@ laguerreStep eps p p' p'' x = do
(*:) :: a -> Complex a -> Complex a
r *: (u :+ v) = ( r * u ) :+ ( r * v )
-- | Evaluate a polynomial.
-- | Evaluate a polynomial with real coefficients at a complex number.
eval
:: forall a m s
. ( Num a, Prim a, PrimMonad m, s ~ PrimState m )
=> MutablePrimArray s a -> a -> m a
. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m )
=> MutablePrimArray s a -> Complex a -> m ( Complex a )
eval p x = do
n <- getSizeofMutablePrimArray p
let
go :: a -> Int -> m a
go :: Complex a -> Int -> m ( Complex a )
go !a !i
| i >= n
= pure a
| otherwise
= do
b <- readPrimArray p i
go ( b + x * a ) ( i + 1 )
go ( ( b :+ 0 ) + x * a ) ( i + 1 )
an <- readPrimArray p 0
go an 1
go ( an :+ 0 ) 1
-- | Derivative of a polynomial.
derivative