mirror of
https://gitlab.com/sheaf/metabrush.git
synced 2024-11-05 23:03:38 +00:00
optimise root-finding functions
* use PrimArray to represent polynomials * add some strictness annotations * turn on some optimisation flags * use quadratic formula for quadratic polynomials
This commit is contained in:
parent
eb8e7012aa
commit
58ca70c1bd
|
@ -40,6 +40,8 @@ common common
|
||||||
>= 1.2.0.1 && < 2.0
|
>= 1.2.0.1 && < 2.0
|
||||||
, groups
|
, groups
|
||||||
^>= 0.4.1.0
|
^>= 0.4.1.0
|
||||||
|
, primitive
|
||||||
|
^>= 0.7.1.0
|
||||||
, transformers
|
, transformers
|
||||||
^>= 0.5.6.2
|
^>= 0.5.6.2
|
||||||
|
|
||||||
|
@ -49,7 +51,11 @@ common common
|
||||||
ghc-options:
|
ghc-options:
|
||||||
-O1
|
-O1
|
||||||
-fexpose-all-unfoldings
|
-fexpose-all-unfoldings
|
||||||
|
-funfolding-use-threshold=16
|
||||||
|
-fexcess-precision
|
||||||
-fspecialise-aggressively
|
-fspecialise-aggressively
|
||||||
|
-optc-O3
|
||||||
|
-optc-ffast-math
|
||||||
-Wall
|
-Wall
|
||||||
-Wcompat
|
-Wcompat
|
||||||
-fwarn-missing-local-signatures
|
-fwarn-missing-local-signatures
|
||||||
|
@ -85,6 +91,8 @@ library
|
||||||
^>= 0.20.0.0
|
^>= 0.20.0.0
|
||||||
, monad-par
|
, monad-par
|
||||||
^>= 0.3.5
|
^>= 0.3.5
|
||||||
|
, prim-instances
|
||||||
|
^>= 0.2
|
||||||
, vector
|
, vector
|
||||||
^>= 0.12.1.2
|
^>= 0.12.1.2
|
||||||
|
|
||||||
|
|
|
@ -214,7 +214,7 @@ main = do
|
||||||
maxHistorySizeTVar <- STM.newTVarIO @Int 1000
|
maxHistorySizeTVar <- STM.newTVarIO @Int 1000
|
||||||
fitParametersTVar <- STM.newTVarIO @FitParameters
|
fitParametersTVar <- STM.newTVarIO @FitParameters
|
||||||
( FitParameters
|
( FitParameters
|
||||||
{ maxSubdiv = 10
|
{ maxSubdiv = 6
|
||||||
, nbSegments = 12
|
, nbSegments = 12
|
||||||
, dist_tol = 5e-3
|
, dist_tol = 5e-3
|
||||||
, t_tol = 1e-4
|
, t_tol = 1e-4
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
{-# LANGUAGE BangPatterns #-}
|
||||||
{-# LANGUAGE DerivingStrategies #-}
|
{-# LANGUAGE DerivingStrategies #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
|
@ -72,7 +73,7 @@ newtype UniqueSupply = UniqueSupply { uniqueSupplyTVar :: STM.TVar Unique }
|
||||||
|
|
||||||
freshUnique :: UniqueSupply -> STM Unique
|
freshUnique :: UniqueSupply -> STM Unique
|
||||||
freshUnique ( UniqueSupply { uniqueSupplyTVar } ) = do
|
freshUnique ( UniqueSupply { uniqueSupplyTVar } ) = do
|
||||||
uniq@( Unique i ) <- STM.readTVar uniqueSupplyTVar
|
uniq@( Unique !i ) <- STM.readTVar uniqueSupplyTVar
|
||||||
STM.writeTVar uniqueSupplyTVar ( Unique ( succ i ) )
|
STM.writeTVar uniqueSupplyTVar ( Unique ( succ i ) )
|
||||||
pure uniq
|
pure uniq
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||||
|
{-# LANGUAGE BangPatterns #-}
|
||||||
{-# LANGUAGE DeriveAnyClass #-}
|
{-# LANGUAGE DeriveAnyClass #-}
|
||||||
{-# LANGUAGE DeriveGeneric #-}
|
{-# LANGUAGE DeriveGeneric #-}
|
||||||
{-# LANGUAGE DeriveTraversable #-}
|
{-# LANGUAGE DeriveTraversable #-}
|
||||||
|
@ -57,6 +58,10 @@ import Data.Group
|
||||||
import Data.Group.Generics
|
import Data.Group.Generics
|
||||||
()
|
()
|
||||||
|
|
||||||
|
-- primitive
|
||||||
|
import Data.Primitive.Types
|
||||||
|
( Prim )
|
||||||
|
|
||||||
-- MetaBrush
|
-- MetaBrush
|
||||||
import qualified Math.Bezier.Quadratic as Quadratic
|
import qualified Math.Bezier.Quadratic as Quadratic
|
||||||
( Bezier(..), bezier )
|
( Bezier(..), bezier )
|
||||||
|
@ -154,28 +159,30 @@ subdivide ( Bezier {..} ) t = ( Bezier p0 q1 q2 pt, Bezier pt r1 r2 p3 )
|
||||||
|
|
||||||
-- | Polynomial coefficients of the derivative of the distance to a cubic Bézier curve.
|
-- | Polynomial coefficients of the derivative of the distance to a cubic Bézier curve.
|
||||||
ddist :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> [ r ]
|
ddist :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> [ r ]
|
||||||
ddist ( Bezier {..} ) c = [ a0, a1, a2, a3, a4, a5 ]
|
ddist ( Bezier {..} ) c = [ a5, a4, a3, a2, a1, a0 ]
|
||||||
where
|
where
|
||||||
v, v', v'', v''' :: v
|
v, v', v'', v''' :: v
|
||||||
v = c --> p0
|
!v = c --> p0
|
||||||
v' = p0 --> p1
|
!v' = p0 --> p1
|
||||||
v'' = p1 --> p0 ^+^ p1 --> p2
|
!v'' = p1 --> p0 ^+^ p1 --> p2
|
||||||
v''' = p0 --> p3 ^+^ 3 *^ ( p2 --> p1 )
|
!v''' = p0 --> p3 ^+^ 3 *^ ( p2 --> p1 )
|
||||||
|
|
||||||
a0, a1, a2, a3, a4, a5 :: r
|
a0, a1, a2, a3, a4, a5 :: r
|
||||||
a0 = v ^.^ v'
|
!a0 = v ^.^ v'
|
||||||
a1 = 3 * squaredNorm v' + 2 * v ^.^ v''
|
!a1 = 3 * squaredNorm v' + 2 * v ^.^ v''
|
||||||
a2 = 9 * v' ^.^ v'' + v ^.^ v'''
|
!a2 = 9 * v' ^.^ v'' + v ^.^ v'''
|
||||||
a3 = 6 * squaredNorm v'' + 4 * v' ^.^ v'''
|
!a3 = 6 * squaredNorm v'' + 4 * v' ^.^ v'''
|
||||||
a4 = 5 * v'' ^.^ v'''
|
!a4 = 5 * v'' ^.^ v'''
|
||||||
a5 = squaredNorm v'''
|
!a5 = squaredNorm v'''
|
||||||
|
|
||||||
-- | Finds the closest point to a given point on a cubic Bézier curve.
|
-- | Finds the closest point to a given point on a cubic Bézier curve.
|
||||||
closestPoint :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> ArgMin r ( r, p )
|
closestPoint
|
||||||
|
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r )
|
||||||
|
=> Bezier p -> p -> ArgMin r ( r, p )
|
||||||
closestPoint pts@( Bezier {..} ) c = pickClosest ( 0 :| 1 : roots ) -- todo: also include the self-intersection point if one exists
|
closestPoint pts@( Bezier {..} ) c = pickClosest ( 0 :| 1 : roots ) -- todo: also include the self-intersection point if one exists
|
||||||
where
|
where
|
||||||
roots :: [ r ]
|
roots :: [ r ]
|
||||||
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots $ ddist @v pts c )
|
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 2000 $ ddist @v pts c )
|
||||||
|
|
||||||
pickClosest :: NonEmpty r -> ArgMin r ( r, p )
|
pickClosest :: NonEmpty r -> ArgMin r ( r, p )
|
||||||
pickClosest ( s :| ss ) = go s q nm0 ss
|
pickClosest ( s :| ss ) = go s q nm0 ss
|
||||||
|
|
|
@ -46,6 +46,10 @@ import qualified Data.Sequence as Seq
|
||||||
import Control.DeepSeq
|
import Control.DeepSeq
|
||||||
( NFData )
|
( NFData )
|
||||||
|
|
||||||
|
-- primitive
|
||||||
|
import Data.Primitive.PrimArray
|
||||||
|
( primArrayFromListN, unsafeThawPrimArray )
|
||||||
|
|
||||||
-- transformers
|
-- transformers
|
||||||
import Control.Monad.Trans.State.Strict
|
import Control.Monad.Trans.State.Strict
|
||||||
( execStateT, modify' )
|
( execStateT, modify' )
|
||||||
|
@ -241,9 +245,12 @@ fitPiece dist_tol t_tol maxIters p tp qs r tr =
|
||||||
( dts_changed, argmax_sq_dist ) <- ( `execStateT` ( False, Max ( Arg 0 0 ) ) ) $ for_ ( zip qs [ 0 .. ] ) \( q, i ) -> do
|
( dts_changed, argmax_sq_dist ) <- ( `execStateT` ( False, Max ( Arg 0 0 ) ) ) $ for_ ( zip qs [ 0 .. ] ) \( q, i ) -> do
|
||||||
ti <- lift ( Unboxed.MVector.unsafeRead ts i )
|
ti <- lift ( Unboxed.MVector.unsafeRead ts i )
|
||||||
let
|
let
|
||||||
poly :: [ Complex Double ]
|
laguerreStepResult :: Complex Double
|
||||||
poly = map (:+ 0) $ Cubic.ddist @( Vector2D Double ) bez q
|
laguerreStepResult = runST do
|
||||||
ti' <- case laguerre epsilon 1 poly ( ti :+ 0 ) of
|
coeffs <- unsafeThawPrimArray . primArrayFromListN 6 . map (:+ 0)
|
||||||
|
$ Cubic.ddist @( Vector2D Double ) bez q
|
||||||
|
laguerre epsilon 1 coeffs ( ti :+ 0 )
|
||||||
|
ti' <- case laguerreStepResult of
|
||||||
x :+ y
|
x :+ y
|
||||||
| isNaN x
|
| isNaN x
|
||||||
|| isNaN y
|
|| isNaN y
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||||
|
{-# LANGUAGE BangPatterns #-}
|
||||||
{-# LANGUAGE DeriveAnyClass #-}
|
{-# LANGUAGE DeriveAnyClass #-}
|
||||||
{-# LANGUAGE DeriveGeneric #-}
|
{-# LANGUAGE DeriveGeneric #-}
|
||||||
{-# LANGUAGE DeriveTraversable #-}
|
{-# LANGUAGE DeriveTraversable #-}
|
||||||
|
@ -54,6 +55,10 @@ import Data.Group
|
||||||
import Data.Group.Generics
|
import Data.Group.Generics
|
||||||
()
|
()
|
||||||
|
|
||||||
|
-- primitive
|
||||||
|
import Data.Primitive.Types
|
||||||
|
( Prim )
|
||||||
|
|
||||||
-- MetaBrush
|
-- MetaBrush
|
||||||
import Math.Epsilon
|
import Math.Epsilon
|
||||||
( epsilon )
|
( epsilon )
|
||||||
|
@ -125,25 +130,27 @@ subdivide ( Bezier {..} ) t = ( Bezier p0 q1 pt, Bezier pt r1 p2 )
|
||||||
|
|
||||||
-- | Polynomial coefficients of the derivative of the distance to a quadratic Bézier curve.
|
-- | Polynomial coefficients of the derivative of the distance to a quadratic Bézier curve.
|
||||||
ddist :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> [ r ]
|
ddist :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> [ r ]
|
||||||
ddist ( Bezier {..} ) c = [ a0, a1, a2, a3 ]
|
ddist ( Bezier {..} ) c = [ a3, a2, a1, a0 ]
|
||||||
where
|
where
|
||||||
v, v', v'' :: v
|
v, v', v'' :: v
|
||||||
v = c --> p0
|
!v = c --> p0
|
||||||
v' = p0 --> p1
|
!v' = p0 --> p1
|
||||||
v'' = p1 --> p0 ^+^ p1 --> p2
|
!v'' = p1 --> p0 ^+^ p1 --> p2
|
||||||
|
|
||||||
a0, a1, a2, a3 :: r
|
a0, a1, a2, a3 :: r
|
||||||
a0 = v ^.^ v'
|
!a0 = v ^.^ v'
|
||||||
a1 = v ^.^ v'' + 2 * squaredNorm v'
|
!a1 = v ^.^ v'' + 2 * squaredNorm v'
|
||||||
a2 = 3 * v' ^.^ v''
|
!a2 = 3 * v' ^.^ v''
|
||||||
a3 = squaredNorm v''
|
!a3 = squaredNorm v''
|
||||||
|
|
||||||
-- | Finds the closest point to a given point on a quadratic Bézier curve.
|
-- | Finds the closest point to a given point on a quadratic Bézier curve.
|
||||||
closestPoint :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> ArgMin r ( r, p )
|
closestPoint
|
||||||
|
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r )
|
||||||
|
=> Bezier p -> p -> ArgMin r ( r, p )
|
||||||
closestPoint pts@( Bezier {..} ) c = pickClosest ( 0 :| 1 : roots )
|
closestPoint pts@( Bezier {..} ) c = pickClosest ( 0 :| 1 : roots )
|
||||||
where
|
where
|
||||||
roots :: [ r ]
|
roots :: [ r ]
|
||||||
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots $ ddist @v pts c )
|
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 2000 $ ddist @v pts c )
|
||||||
|
|
||||||
pickClosest :: NonEmpty r -> ArgMin r ( r, p )
|
pickClosest :: NonEmpty r -> ArgMin r ( r, p )
|
||||||
pickClosest ( s :| ss ) = go s q nm0 ss
|
pickClosest ( s :| ss ) = go s q nm0 ss
|
||||||
|
|
|
@ -78,7 +78,7 @@ import Math.Module
|
||||||
, lerp, squaredNorm
|
, lerp, squaredNorm
|
||||||
)
|
)
|
||||||
import Math.Roots
|
import Math.Roots
|
||||||
( realRoots )
|
( solveQuadratic )
|
||||||
import Math.Vector2D
|
import Math.Vector2D
|
||||||
( Point2D(..), Vector2D(..), cross )
|
( Point2D(..), Vector2D(..), cross )
|
||||||
|
|
||||||
|
@ -580,12 +580,12 @@ withTangent tgt ( spt0 :<| spt1 :<| spts ) =
|
||||||
| otherwise
|
| otherwise
|
||||||
= Nothing
|
= Nothing
|
||||||
in
|
in
|
||||||
case mapMaybe correctTangentParam $ realRoots [ c01, 2 * ( c12 - c01 ), c01 + c23 - 2 * c12 ] of
|
case mapMaybe correctTangentParam $ solveQuadratic c01 ( 2 * ( c12 - c01 ) ) ( c01 + c23 - 2 * c12 ) of
|
||||||
( t : _ )
|
( t : _ )
|
||||||
-> Offset i ( Just t ) ( MkVector2D $ Cubic.bezier @( Vector2D Double ) bez t )
|
-> Offset i ( Just t ) ( MkVector2D $ Cubic.bezier @( Vector2D Double ) bez t )
|
||||||
-- Fallback in case we couldn't solve the quadratic for some reason.
|
|
||||||
_
|
_
|
||||||
| Just s <- between tgt tgt0 tgt2
|
| Just s <- between tgt tgt0 tgt2
|
||||||
|
-- Fallback in case we couldn't solve the quadratic for some reason.
|
||||||
-> Offset i ( Just s ) ( MkVector2D $ Cubic.bezier @( Vector2D Double ) bez s )
|
-> Offset i ( Just s ) ( MkVector2D $ Cubic.bezier @( Vector2D Double ) bez s )
|
||||||
-- Otherwise: go to next piece of the curve.
|
-- Otherwise: go to next piece of the curve.
|
||||||
| otherwise
|
| otherwise
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
module Math.Epsilon
|
module Math.Epsilon
|
||||||
( epsilon )
|
( epsilon, nearZero )
|
||||||
where
|
where
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
@ -10,3 +10,6 @@ module Math.Epsilon
|
||||||
{-# SPECIALISE epsilon :: Double #-}
|
{-# SPECIALISE epsilon :: Double #-}
|
||||||
epsilon :: forall r. RealFloat r => r
|
epsilon :: forall r. RealFloat r => r
|
||||||
epsilon = encodeFloat 1 ( 5 - floatDigits ( 0 :: r ) )
|
epsilon = encodeFloat 1 ( 5 - floatDigits ( 0 :: r ) )
|
||||||
|
|
||||||
|
nearZero :: RealFloat r => r -> Bool
|
||||||
|
nearZero x = abs x < epsilon
|
||||||
|
|
|
@ -1,108 +1,172 @@
|
||||||
|
{-# LANGUAGE BangPatterns #-}
|
||||||
|
{-# LANGUAGE BlockArguments #-}
|
||||||
|
{-# LANGUAGE GADTs #-}
|
||||||
|
{-# LANGUAGE NamedWildCards #-}
|
||||||
|
{-# LANGUAGE PartialTypeSignatures #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE TypeApplications #-}
|
||||||
|
|
||||||
module Math.Roots
|
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
|
||||||
|
|
||||||
where
|
module Math.Roots where
|
||||||
|
|
||||||
-- base
|
-- base
|
||||||
|
import Control.Monad
|
||||||
|
( unless )
|
||||||
|
import Control.Monad.ST
|
||||||
|
( ST, runST )
|
||||||
import Data.Complex
|
import Data.Complex
|
||||||
( Complex(..), magnitude )
|
( Complex(..), magnitude )
|
||||||
import Data.List.NonEmpty
|
|
||||||
( NonEmpty(..), toList )
|
|
||||||
import Data.Maybe
|
import Data.Maybe
|
||||||
( mapMaybe )
|
( mapMaybe )
|
||||||
|
|
||||||
|
-- primitive
|
||||||
|
import Control.Monad.Primitive
|
||||||
|
( PrimMonad(PrimState) )
|
||||||
|
import Data.Primitive.PrimArray
|
||||||
|
( PrimArray, MutablePrimArray
|
||||||
|
, primArrayFromList
|
||||||
|
, getSizeofMutablePrimArray, sizeofPrimArray
|
||||||
|
, unsafeThawPrimArray, cloneMutablePrimArray
|
||||||
|
, shrinkMutablePrimArray, readPrimArray, writePrimArray
|
||||||
|
)
|
||||||
|
import Data.Primitive.Types
|
||||||
|
( Prim )
|
||||||
|
|
||||||
|
-- prim-instances
|
||||||
|
import Data.Primitive.Instances
|
||||||
|
() -- instance Prim a => Prim ( Complex a )
|
||||||
|
|
||||||
-- MetaBrush
|
-- MetaBrush
|
||||||
import Math.Epsilon
|
import Math.Epsilon
|
||||||
( epsilon )
|
( epsilon, nearZero )
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
-- | Find real roots of a polynomial.
|
-- | Real solutions to a quadratic equation.
|
||||||
|
solveQuadratic :: forall a. RealFloat a => a -> a -> a -> [ a ]
|
||||||
-- Coefficients are given in order of increasing degree, e.g.:
|
solveQuadratic a0 a1 a2
|
||||||
-- x² + 7 is given by [ 7, 0, 1 ].
|
| nearZero a1 && nearZero a2
|
||||||
realRoots :: forall r. RealFloat r => [ r ] -> [ r ]
|
= if nearZero a0
|
||||||
realRoots p = mapMaybe isReal ( roots epsilon 10000 ( map (:+ 0) p ) )
|
then [ 0, 0.5, 1 ] -- convention
|
||||||
|
else []
|
||||||
|
| nearZero ( a0 * a0 * a2 / ( a1 * a1 ) )
|
||||||
|
= [ - a0 / a1 ]
|
||||||
|
| disc < 0
|
||||||
|
= [] -- non-real solutions
|
||||||
|
| otherwise
|
||||||
|
= let
|
||||||
|
r :: a
|
||||||
|
r =
|
||||||
|
if a1 >= 0
|
||||||
|
then 2 * a0 / ( - a1 - sqrt disc )
|
||||||
|
else 0.5 * ( - a1 + sqrt disc) / a2
|
||||||
|
in [ r, -r - a1 ]
|
||||||
where
|
where
|
||||||
isReal :: Complex r -> Maybe r
|
disc :: a
|
||||||
|
disc = a1 * a1 - 4 * a0 * a2
|
||||||
|
|
||||||
|
-- | Find real roots of a polynomial.
|
||||||
|
--
|
||||||
|
-- Coefficients are given in order of decreasing degree, e.g.:
|
||||||
|
-- x² + 7 is given by [ 1, 0, 7 ].
|
||||||
|
realRoots :: forall a. ( RealFloat a, Prim a ) => Int -> [ a ] -> [ a ]
|
||||||
|
realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0) coeffs ) )
|
||||||
|
where
|
||||||
|
isReal :: Complex a -> Maybe a
|
||||||
isReal ( a :+ b )
|
isReal ( a :+ b )
|
||||||
| abs b < epsilon = Just a
|
| abs b < epsilon = Just a
|
||||||
| otherwise = Nothing
|
| otherwise = Nothing
|
||||||
|
|
||||||
-- | Compute all roots of a polynomial using Laguerre's method and (forward) deflation.
|
-- | Compute all roots of a polynomial using Laguerre's method and (forward) deflation.
|
||||||
--
|
--
|
||||||
-- Polynomial coefficients are given in order of ascending degree (e.g. constant coefficient first).
|
-- Polynomial coefficients are given in order of descending degree (e.g. constant coefficient last).
|
||||||
--
|
--
|
||||||
-- N.B. The forward deflation process is only guaranteed to be numerically stable
|
-- N.B. The forward deflation process is only guaranteed to be numerically stable
|
||||||
-- if Laguerre's method finds roots in increasing order of magnitude.
|
-- if Laguerre's method finds roots in increasing order of magnitude.
|
||||||
roots :: forall a. RealFloat a => a -> Int -> [ Complex a ] -> [ Complex a ]
|
roots :: forall a. ( RealFloat a, Prim a ) => a -> Int -> [ Complex a ] -> [ Complex a ]
|
||||||
roots eps maxIters p = go p []
|
roots eps maxIters coeffs = runST do
|
||||||
where
|
let
|
||||||
go :: [ Complex a ] -> [ Complex a ] -> [ Complex a ]
|
coeffPrimArray :: PrimArray ( Complex a )
|
||||||
go q rs
|
coeffPrimArray = primArrayFromList coeffs
|
||||||
| length q <= 2 = r : rs
|
sz :: Int
|
||||||
| otherwise = go ( deflate r q ) ( r : rs )
|
sz = sizeofPrimArray coeffPrimArray
|
||||||
where
|
p <- unsafeThawPrimArray coeffPrimArray
|
||||||
r :: Complex a
|
let
|
||||||
r = laguerre eps maxIters q 0
|
go :: Int -> [ Complex a ] -> ST _s [ Complex a ]
|
||||||
-- Start the iteration at 0 for best chance of numerical stability.
|
go i rs = do
|
||||||
|
!r <- laguerre eps maxIters p 0 -- Start Laguerre's method at 0 for best chance of numerical stability.
|
||||||
|
if i <= 2
|
||||||
|
then pure ( r : rs )
|
||||||
|
else do
|
||||||
|
deflate r p
|
||||||
|
go ( i - 1 ) ( r : rs )
|
||||||
|
go sz []
|
||||||
|
|
||||||
-- | Deflate a polynomial: factor out a root of the polynomial.
|
-- | Deflate a polynomial: factor out a root of the polynomial.
|
||||||
--
|
--
|
||||||
-- The polynomial must have degree at least 2.
|
-- The polynomial must have degree at least 2.
|
||||||
deflate :: forall a. Num a => a -> [ a ] -> [ a ]
|
deflate :: forall a m s. ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) => a -> MutablePrimArray s a -> m ()
|
||||||
deflate r ( _ : c : cs ) = toList $ go ( c :| cs )
|
deflate r p = do
|
||||||
where
|
deg <- subtract 1 <$> getSizeofMutablePrimArray p
|
||||||
go :: NonEmpty a -> NonEmpty a
|
case compare deg 2 of
|
||||||
go ( a :| [] ) = a :| []
|
LT -> pure ()
|
||||||
go ( a :| a' : as ) = case go ( a' :| as ) of
|
EQ -> shrinkMutablePrimArray p deg
|
||||||
( b' :| bs ) -> ( a + r * b' ) :| ( b' : bs )
|
GT -> do
|
||||||
deflate _ _ = error "deflate: polynomial of degree < 2"
|
shrinkMutablePrimArray p deg
|
||||||
|
let
|
||||||
|
go :: a -> Int -> m ()
|
||||||
|
go b i = unless ( i >= deg ) do
|
||||||
|
ai <- readPrimArray p i
|
||||||
|
writePrimArray p i ( ai + r * b )
|
||||||
|
go ai ( i + 1 )
|
||||||
|
a0 <- readPrimArray p 0
|
||||||
|
go a0 1
|
||||||
|
|
||||||
-- | Laguerre's method.
|
-- | Laguerre's method.
|
||||||
laguerre
|
laguerre
|
||||||
:: forall a. RealFloat a
|
:: forall a m s
|
||||||
|
. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m )
|
||||||
=> a -- ^ error tolerance
|
=> a -- ^ error tolerance
|
||||||
-> Int -- ^ max number of iterations
|
-> Int -- ^ max number of iterations
|
||||||
-> [ Complex a ] -- ^ polynomial
|
-> MutablePrimArray s ( Complex a ) -- ^ polynomial
|
||||||
-> Complex a -- ^ initial point
|
-> Complex a -- ^ initial point
|
||||||
-> Complex a
|
-> m ( Complex a )
|
||||||
laguerre eps maxIters p = go maxIters
|
laguerre eps maxIters p x0 = do
|
||||||
where
|
p' <- derivative p
|
||||||
p', p'' :: [ Complex a ]
|
p'' <- derivative p'
|
||||||
p' = derivative p
|
let
|
||||||
p'' = derivative p'
|
go :: Int -> Complex a -> m ( Complex a )
|
||||||
go :: Int -> Complex a -> Complex a
|
go iterationsLeft x = do
|
||||||
go iterationsLeft x
|
x' <- laguerreStep eps p p' p'' x
|
||||||
| iterationsLeft <= 1
|
if iterationsLeft <= 1 || magnitude ( x' - x ) < eps
|
||||||
|| magnitude ( x' - x ) < eps = x'
|
then pure x'
|
||||||
| otherwise = go ( iterationsLeft - 1 ) x'
|
else go ( iterationsLeft - 1 ) x'
|
||||||
where
|
go maxIters x0
|
||||||
x' :: Complex a
|
|
||||||
x' = laguerreStep eps p p' p'' x
|
|
||||||
|
|
||||||
-- | Take a single step in Laguerre's method.
|
-- | Take a single step in Laguerre's method.
|
||||||
laguerreStep
|
laguerreStep
|
||||||
:: forall a. RealFloat a
|
:: forall a m s
|
||||||
|
. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m )
|
||||||
=> a -- ^ error tolerance
|
=> a -- ^ error tolerance
|
||||||
-> [ Complex a ] -- ^ polynomial
|
-> MutablePrimArray s ( Complex a ) -- ^ polynomial
|
||||||
-> [ Complex a ] -- ^ first derivative of polynomial
|
-> MutablePrimArray s ( Complex a ) -- ^ first derivative of polynomial
|
||||||
-> [ Complex a ] -- ^ second derivative of polynomial
|
-> MutablePrimArray s ( Complex a ) -- ^ second derivative of polynomial
|
||||||
-> Complex a -- ^ initial point
|
-> Complex a -- ^ initial point
|
||||||
-> Complex a
|
-> m ( Complex a )
|
||||||
laguerreStep eps p p' p'' x
|
laguerreStep eps p p' p'' x = do
|
||||||
| magnitude px < eps = x
|
n <- fromIntegral @_ @a <$> getSizeofMutablePrimArray p
|
||||||
| otherwise = x - n / denom
|
px <- eval p x
|
||||||
where
|
if magnitude px < eps
|
||||||
n = fromIntegral ( length p )
|
then pure x
|
||||||
px = eval p x
|
else do
|
||||||
p'x = eval p' x
|
p'x <- eval p' x
|
||||||
p''x = eval p'' x
|
p''x <- eval p'' x
|
||||||
|
let
|
||||||
g = p'x / px
|
g = p'x / px
|
||||||
g² = g * g
|
g² = g * g
|
||||||
h = g² - p''x / px
|
h = g² - p''x / px
|
||||||
delta = sqrt $ ( n - 1 ) * ( n * h - g² )
|
delta = sqrt $ ( n - 1 ) *: ( n *: h - g² )
|
||||||
gp = g + delta
|
gp = g + delta
|
||||||
gm = g - delta
|
gm = g - delta
|
||||||
denom
|
denom
|
||||||
|
@ -110,11 +174,42 @@ laguerreStep eps p p' p'' x
|
||||||
= gm
|
= gm
|
||||||
| otherwise
|
| otherwise
|
||||||
= gp
|
= gp
|
||||||
|
pure $ x - n *: ( recip denom )
|
||||||
|
|
||||||
|
where
|
||||||
|
(*:) :: a -> Complex a -> Complex a
|
||||||
|
r *: (u :+ v) = ( r * u ) :+ ( r * v )
|
||||||
|
|
||||||
-- | Evaluate a polynomial.
|
-- | Evaluate a polynomial.
|
||||||
eval :: Num a => [ a ] -> a -> a
|
eval
|
||||||
eval as x = foldr ( \ a b -> a + x * b ) 0 as
|
:: forall a m s
|
||||||
|
. ( Num a, Prim a, PrimMonad m, s ~ PrimState m )
|
||||||
|
=> MutablePrimArray s a -> a -> m a
|
||||||
|
eval p x = do
|
||||||
|
n <- getSizeofMutablePrimArray p
|
||||||
|
let
|
||||||
|
go :: a -> Int -> m a
|
||||||
|
go !a i =
|
||||||
|
if i >= n
|
||||||
|
then pure a
|
||||||
|
else do
|
||||||
|
!b <- readPrimArray p i
|
||||||
|
go ( b + x * a ) ( i + 1 )
|
||||||
|
an <- readPrimArray p 0
|
||||||
|
go an 1
|
||||||
|
|
||||||
-- | Derivative of a polynomial.
|
-- | Derivative of a polynomial.
|
||||||
derivative :: Num a => [ a ] -> [ a ]
|
derivative
|
||||||
derivative as = zipWith ( \ i a -> fromIntegral i * a ) [ ( 1 :: Int ) .. ] ( tail as )
|
:: forall a m s
|
||||||
|
. ( Num a, Prim a, PrimMonad m, s ~ PrimState m )
|
||||||
|
=> MutablePrimArray s a -> m ( MutablePrimArray s a )
|
||||||
|
derivative p = do
|
||||||
|
deg <- subtract 1 <$> getSizeofMutablePrimArray p
|
||||||
|
p' <- cloneMutablePrimArray p 0 deg
|
||||||
|
let
|
||||||
|
go :: Int -> m ()
|
||||||
|
go i = unless ( i >= deg ) do
|
||||||
|
a <- readPrimArray p' i
|
||||||
|
writePrimArray p' i ( a * fromIntegral ( deg - i ) )
|
||||||
|
go 0
|
||||||
|
pure p'
|
||||||
|
|
Loading…
Reference in a new issue