fix some errors in polynomial code

This commit is contained in:
sheaf 2021-02-26 01:52:52 +01:00
parent 5e62937f16
commit a9adcba8eb
4 changed files with 23 additions and 19 deletions

View file

@ -177,7 +177,7 @@ ddist ( Bezier {..} ) c = [ a5, a4, a3, a2, a1, a0 ]
-- | Finds the closest point to a given point on a cubic Bézier curve. -- | Finds the closest point to a given point on a cubic Bézier curve.
closestPoint closestPoint
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r ) :: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r, NFData r )
=> Bezier p -> p -> ArgMin r ( r, p ) => Bezier p -> p -> ArgMin r ( r, p )
closestPoint pts c = pickClosest ( 0 :| 1 : roots ) -- todo: also include the self-intersection point if one exists closestPoint pts c = pickClosest ( 0 :| 1 : roots ) -- todo: also include the self-intersection point if one exists
where where

View file

@ -145,7 +145,7 @@ ddist ( Bezier {..} ) c = [ a3, a2, a1, a0 ]
-- | Finds the closest point to a given point on a quadratic Bézier curve. -- | Finds the closest point to a given point on a quadratic Bézier curve.
closestPoint closestPoint
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r ) :: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r, NFData r )
=> Bezier p -> p -> ArgMin r ( r, p ) => Bezier p -> p -> ArgMin r ( r, p )
closestPoint pts c = pickClosest ( 0 :| 1 : roots ) closestPoint pts c = pickClosest ( 0 :| 1 : roots )
where where

View file

@ -161,7 +161,7 @@ discardCache ( view ( typed @( CachedStroke s ) ) -> CachedStroke { cachedStroke
{-# INLINE invalidateCache #-} {-# INLINE invalidateCache #-}
invalidateCache :: forall crvData. HasType ( CachedStroke RealWorld ) crvData => crvData -> crvData invalidateCache :: forall crvData. HasType ( CachedStroke RealWorld ) crvData => crvData -> crvData
invalidateCache = runRW# \ s -> do invalidateCache = runRW# \ s ->
case newMutVar# Nothing s of case newMutVar# Nothing s of
(# _, mutVar #) -> (# _, mutVar #) ->
set ( typed @( CachedStroke RealWorld ) ) set ( typed @( CachedStroke RealWorld ) )

View file

@ -6,8 +6,6 @@
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
module Math.Roots where module Math.Roots where
-- base -- base
@ -20,6 +18,10 @@ import Data.Complex
import Data.Maybe import Data.Maybe
( mapMaybe ) ( mapMaybe )
-- deepseq
import Control.DeepSeq
( NFData, force )
-- primitive -- primitive
import Control.Monad.Primitive import Control.Monad.Primitive
( PrimMonad(PrimState) ) ( PrimMonad(PrimState) )
@ -70,7 +72,7 @@ solveQuadratic a0 a1 a2
-- --
-- Coefficients are given in order of decreasing degree, e.g.: -- Coefficients are given in order of decreasing degree, e.g.:
-- x² + 7 is given by [ 1, 0, 7 ]. -- x² + 7 is given by [ 1, 0, 7 ].
realRoots :: forall a. ( RealFloat a, Prim a ) => Int -> [ a ] -> [ a ] realRoots :: forall a. ( RealFloat a, Prim a, NFData a ) => Int -> [ a ] -> [ a ]
realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0) coeffs ) ) realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0) coeffs ) )
where where
isReal :: Complex a -> Maybe a isReal :: Complex a -> Maybe a
@ -84,18 +86,18 @@ realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0
-- --
-- N.B. The forward deflation process is only guaranteed to be numerically stable -- N.B. The forward deflation process is only guaranteed to be numerically stable
-- if Laguerre's method finds roots in increasing order of magnitude. -- if Laguerre's method finds roots in increasing order of magnitude.
roots :: forall a. ( RealFloat a, Prim a ) => a -> Int -> [ Complex a ] -> [ Complex a ] roots :: forall a. ( RealFloat a, Prim a, NFData a ) => a -> Int -> [ Complex a ] -> [ Complex a ]
roots eps maxIters coeffs = runST do roots eps maxIters coeffs = runST do
let let
coeffPrimArray :: PrimArray ( Complex a ) coeffPrimArray :: PrimArray ( Complex a )
coeffPrimArray = primArrayFromList coeffs coeffPrimArray = primArrayFromList coeffs
sz :: Int sz :: Int
sz = sizeofPrimArray coeffPrimArray sz = sizeofPrimArray coeffPrimArray
p <- unsafeThawPrimArray coeffPrimArray ( p :: MutablePrimArray s ( Complex a ) ) <- unsafeThawPrimArray coeffPrimArray
let let
go :: Int -> [ Complex a ] -> ST _s [ Complex a ] go :: Int -> [ Complex a ] -> ST s [ Complex a ]
go i rs = do go !i rs = do
!r <- laguerre eps maxIters p 0 -- Start Laguerre's method at 0 for best chance of numerical stability. !r <- force <$> laguerre eps maxIters p 0 -- Start Laguerre's method at 0 for best chance of numerical stability.
if i <= 2 if i <= 2
then pure ( r : rs ) then pure ( r : rs )
else do else do
@ -116,7 +118,7 @@ deflate r p = do
shrinkMutablePrimArray p deg shrinkMutablePrimArray p deg
let let
go :: a -> Int -> m () go :: a -> Int -> m ()
go b i = unless ( i >= deg ) do go !b !i = unless ( i >= deg ) do
ai <- readPrimArray p i ai <- readPrimArray p i
writePrimArray p i ( ai + r * b ) writePrimArray p i ( ai + r * b )
go ai ( i + 1 ) go ai ( i + 1 )
@ -137,7 +139,7 @@ laguerre eps maxIters p x0 = do
p'' <- derivative p' p'' <- derivative p'
let let
go :: Int -> Complex a -> m ( Complex a ) go :: Int -> Complex a -> m ( Complex a )
go iterationsLeft x = do go !iterationsLeft !x = do
x' <- laguerreStep eps p p' p'' x x' <- laguerreStep eps p p' p'' x
if iterationsLeft <= 1 || magnitude ( x' - x ) < eps if iterationsLeft <= 1 || magnitude ( x' - x ) < eps
then pure x' then pure x'
@ -189,11 +191,12 @@ eval p x = do
n <- getSizeofMutablePrimArray p n <- getSizeofMutablePrimArray p
let let
go :: a -> Int -> m a go :: a -> Int -> m a
go !a i = go !a !i
if i >= n | i >= n
then pure a = pure a
else do | otherwise
!b <- readPrimArray p i = do
b <- readPrimArray p i
go ( b + x * a ) ( i + 1 ) go ( b + x * a ) ( i + 1 )
an <- readPrimArray p 0 an <- readPrimArray p 0
go an 1 go an 1
@ -208,8 +211,9 @@ derivative p = do
p' <- cloneMutablePrimArray p 0 deg p' <- cloneMutablePrimArray p 0 deg
let let
go :: Int -> m () go :: Int -> m ()
go i = unless ( i >= deg ) do go !i = unless ( i >= deg - 1 ) do
a <- readPrimArray p' i a <- readPrimArray p' i
writePrimArray p' i ( a * fromIntegral ( deg - i ) ) writePrimArray p' i ( a * fromIntegral ( deg - i ) )
go ( i + 1 )
go 0 go 0
pure p' pure p'