fix some errors in polynomial code

This commit is contained in:
sheaf 2021-02-26 01:52:52 +01:00
parent 5e62937f16
commit a9adcba8eb
4 changed files with 23 additions and 19 deletions

View file

@ -177,7 +177,7 @@ ddist ( Bezier {..} ) c = [ a5, a4, a3, a2, a1, a0 ]
-- | Finds the closest point to a given point on a cubic Bézier curve.
closestPoint
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r )
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r, NFData r )
=> Bezier p -> p -> ArgMin r ( r, p )
closestPoint pts c = pickClosest ( 0 :| 1 : roots ) -- todo: also include the self-intersection point if one exists
where

View file

@ -145,7 +145,7 @@ ddist ( Bezier {..} ) c = [ a3, a2, a1, a0 ]
-- | Finds the closest point to a given point on a quadratic Bézier curve.
closestPoint
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r )
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r, NFData r )
=> Bezier p -> p -> ArgMin r ( r, p )
closestPoint pts c = pickClosest ( 0 :| 1 : roots )
where

View file

@ -161,7 +161,7 @@ discardCache ( view ( typed @( CachedStroke s ) ) -> CachedStroke { cachedStroke
{-# INLINE invalidateCache #-}
invalidateCache :: forall crvData. HasType ( CachedStroke RealWorld ) crvData => crvData -> crvData
invalidateCache = runRW# \ s -> do
invalidateCache = runRW# \ s ->
case newMutVar# Nothing s of
(# _, mutVar #) ->
set ( typed @( CachedStroke RealWorld ) )

View file

@ -6,8 +6,6 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
module Math.Roots where
-- base
@ -20,6 +18,10 @@ import Data.Complex
import Data.Maybe
( mapMaybe )
-- deepseq
import Control.DeepSeq
( NFData, force )
-- primitive
import Control.Monad.Primitive
( PrimMonad(PrimState) )
@ -70,7 +72,7 @@ solveQuadratic a0 a1 a2
--
-- Coefficients are given in order of decreasing degree, e.g.:
-- x² + 7 is given by [ 1, 0, 7 ].
realRoots :: forall a. ( RealFloat a, Prim a ) => Int -> [ a ] -> [ a ]
realRoots :: forall a. ( RealFloat a, Prim a, NFData a ) => Int -> [ a ] -> [ a ]
realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0) coeffs ) )
where
isReal :: Complex a -> Maybe a
@ -84,18 +86,18 @@ realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0
--
-- N.B. The forward deflation process is only guaranteed to be numerically stable
-- if Laguerre's method finds roots in increasing order of magnitude.
roots :: forall a. ( RealFloat a, Prim a ) => a -> Int -> [ Complex a ] -> [ Complex a ]
roots :: forall a. ( RealFloat a, Prim a, NFData a ) => a -> Int -> [ Complex a ] -> [ Complex a ]
roots eps maxIters coeffs = runST do
let
coeffPrimArray :: PrimArray ( Complex a )
coeffPrimArray = primArrayFromList coeffs
sz :: Int
sz = sizeofPrimArray coeffPrimArray
p <- unsafeThawPrimArray coeffPrimArray
( p :: MutablePrimArray s ( Complex a ) ) <- unsafeThawPrimArray coeffPrimArray
let
go :: Int -> [ Complex a ] -> ST _s [ Complex a ]
go i rs = do
!r <- laguerre eps maxIters p 0 -- Start Laguerre's method at 0 for best chance of numerical stability.
go :: Int -> [ Complex a ] -> ST s [ Complex a ]
go !i rs = do
!r <- force <$> laguerre eps maxIters p 0 -- Start Laguerre's method at 0 for best chance of numerical stability.
if i <= 2
then pure ( r : rs )
else do
@ -116,7 +118,7 @@ deflate r p = do
shrinkMutablePrimArray p deg
let
go :: a -> Int -> m ()
go b i = unless ( i >= deg ) do
go !b !i = unless ( i >= deg ) do
ai <- readPrimArray p i
writePrimArray p i ( ai + r * b )
go ai ( i + 1 )
@ -137,7 +139,7 @@ laguerre eps maxIters p x0 = do
p'' <- derivative p'
let
go :: Int -> Complex a -> m ( Complex a )
go iterationsLeft x = do
go !iterationsLeft !x = do
x' <- laguerreStep eps p p' p'' x
if iterationsLeft <= 1 || magnitude ( x' - x ) < eps
then pure x'
@ -189,11 +191,12 @@ eval p x = do
n <- getSizeofMutablePrimArray p
let
go :: a -> Int -> m a
go !a i =
if i >= n
then pure a
else do
!b <- readPrimArray p i
go !a !i
| i >= n
= pure a
| otherwise
= do
b <- readPrimArray p i
go ( b + x * a ) ( i + 1 )
an <- readPrimArray p 0
go an 1
@ -208,8 +211,9 @@ derivative p = do
p' <- cloneMutablePrimArray p 0 deg
let
go :: Int -> m ()
go i = unless ( i >= deg ) do
go !i = unless ( i >= deg - 1 ) do
a <- readPrimArray p' i
writePrimArray p' i ( a * fromIntegral ( deg - i ) )
go ( i + 1 )
go 0
pure p'