mirror of
https://gitlab.com/sheaf/metabrush.git
synced 2024-11-05 14:53:37 +00:00
optimise root-finding functions
* use PrimArray to represent polynomials * add some strictness annotations * turn on some optimisation flags * use quadratic formula for quadratic polynomials
This commit is contained in:
parent
eb8e7012aa
commit
58ca70c1bd
|
@ -40,6 +40,8 @@ common common
|
|||
>= 1.2.0.1 && < 2.0
|
||||
, groups
|
||||
^>= 0.4.1.0
|
||||
, primitive
|
||||
^>= 0.7.1.0
|
||||
, transformers
|
||||
^>= 0.5.6.2
|
||||
|
||||
|
@ -49,7 +51,11 @@ common common
|
|||
ghc-options:
|
||||
-O1
|
||||
-fexpose-all-unfoldings
|
||||
-funfolding-use-threshold=16
|
||||
-fexcess-precision
|
||||
-fspecialise-aggressively
|
||||
-optc-O3
|
||||
-optc-ffast-math
|
||||
-Wall
|
||||
-Wcompat
|
||||
-fwarn-missing-local-signatures
|
||||
|
@ -79,12 +85,14 @@ library
|
|||
, Math.Vector2D
|
||||
|
||||
build-depends:
|
||||
groups-generic
|
||||
groups-generic
|
||||
^>= 0.1.0.0
|
||||
, hmatrix
|
||||
^>= 0.20.0.0
|
||||
, monad-par
|
||||
^>= 0.3.5
|
||||
, prim-instances
|
||||
^>= 0.2
|
||||
, vector
|
||||
^>= 0.12.1.2
|
||||
|
||||
|
|
|
@ -214,7 +214,7 @@ main = do
|
|||
maxHistorySizeTVar <- STM.newTVarIO @Int 1000
|
||||
fitParametersTVar <- STM.newTVarIO @FitParameters
|
||||
( FitParameters
|
||||
{ maxSubdiv = 10
|
||||
{ maxSubdiv = 6
|
||||
, nbSegments = 12
|
||||
, dist_tol = 5e-3
|
||||
, t_tol = 1e-4
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE DerivingStrategies #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
|
@ -72,7 +73,7 @@ newtype UniqueSupply = UniqueSupply { uniqueSupplyTVar :: STM.TVar Unique }
|
|||
|
||||
freshUnique :: UniqueSupply -> STM Unique
|
||||
freshUnique ( UniqueSupply { uniqueSupplyTVar } ) = do
|
||||
uniq@( Unique i ) <- STM.readTVar uniqueSupplyTVar
|
||||
uniq@( Unique !i ) <- STM.readTVar uniqueSupplyTVar
|
||||
STM.writeTVar uniqueSupplyTVar ( Unique ( succ i ) )
|
||||
pure uniq
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DeriveTraversable #-}
|
||||
|
@ -57,6 +58,10 @@ import Data.Group
|
|||
import Data.Group.Generics
|
||||
()
|
||||
|
||||
-- primitive
|
||||
import Data.Primitive.Types
|
||||
( Prim )
|
||||
|
||||
-- MetaBrush
|
||||
import qualified Math.Bezier.Quadratic as Quadratic
|
||||
( Bezier(..), bezier )
|
||||
|
@ -154,28 +159,30 @@ subdivide ( Bezier {..} ) t = ( Bezier p0 q1 q2 pt, Bezier pt r1 r2 p3 )
|
|||
|
||||
-- | Polynomial coefficients of the derivative of the distance to a cubic Bézier curve.
|
||||
ddist :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> [ r ]
|
||||
ddist ( Bezier {..} ) c = [ a0, a1, a2, a3, a4, a5 ]
|
||||
ddist ( Bezier {..} ) c = [ a5, a4, a3, a2, a1, a0 ]
|
||||
where
|
||||
v, v', v'', v''' :: v
|
||||
v = c --> p0
|
||||
v' = p0 --> p1
|
||||
v'' = p1 --> p0 ^+^ p1 --> p2
|
||||
v''' = p0 --> p3 ^+^ 3 *^ ( p2 --> p1 )
|
||||
!v = c --> p0
|
||||
!v' = p0 --> p1
|
||||
!v'' = p1 --> p0 ^+^ p1 --> p2
|
||||
!v''' = p0 --> p3 ^+^ 3 *^ ( p2 --> p1 )
|
||||
|
||||
a0, a1, a2, a3, a4, a5 :: r
|
||||
a0 = v ^.^ v'
|
||||
a1 = 3 * squaredNorm v' + 2 * v ^.^ v''
|
||||
a2 = 9 * v' ^.^ v'' + v ^.^ v'''
|
||||
a3 = 6 * squaredNorm v'' + 4 * v' ^.^ v'''
|
||||
a4 = 5 * v'' ^.^ v'''
|
||||
a5 = squaredNorm v'''
|
||||
!a0 = v ^.^ v'
|
||||
!a1 = 3 * squaredNorm v' + 2 * v ^.^ v''
|
||||
!a2 = 9 * v' ^.^ v'' + v ^.^ v'''
|
||||
!a3 = 6 * squaredNorm v'' + 4 * v' ^.^ v'''
|
||||
!a4 = 5 * v'' ^.^ v'''
|
||||
!a5 = squaredNorm v'''
|
||||
|
||||
-- | Finds the closest point to a given point on a cubic Bézier curve.
|
||||
closestPoint :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> ArgMin r ( r, p )
|
||||
closestPoint
|
||||
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r )
|
||||
=> Bezier p -> p -> ArgMin r ( r, p )
|
||||
closestPoint pts@( Bezier {..} ) c = pickClosest ( 0 :| 1 : roots ) -- todo: also include the self-intersection point if one exists
|
||||
where
|
||||
roots :: [ r ]
|
||||
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots $ ddist @v pts c )
|
||||
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 2000 $ ddist @v pts c )
|
||||
|
||||
pickClosest :: NonEmpty r -> ArgMin r ( r, p )
|
||||
pickClosest ( s :| ss ) = go s q nm0 ss
|
||||
|
|
|
@ -46,6 +46,10 @@ import qualified Data.Sequence as Seq
|
|||
import Control.DeepSeq
|
||||
( NFData )
|
||||
|
||||
-- primitive
|
||||
import Data.Primitive.PrimArray
|
||||
( primArrayFromListN, unsafeThawPrimArray )
|
||||
|
||||
-- transformers
|
||||
import Control.Monad.Trans.State.Strict
|
||||
( execStateT, modify' )
|
||||
|
@ -241,9 +245,12 @@ fitPiece dist_tol t_tol maxIters p tp qs r tr =
|
|||
( dts_changed, argmax_sq_dist ) <- ( `execStateT` ( False, Max ( Arg 0 0 ) ) ) $ for_ ( zip qs [ 0 .. ] ) \( q, i ) -> do
|
||||
ti <- lift ( Unboxed.MVector.unsafeRead ts i )
|
||||
let
|
||||
poly :: [ Complex Double ]
|
||||
poly = map (:+ 0) $ Cubic.ddist @( Vector2D Double ) bez q
|
||||
ti' <- case laguerre epsilon 1 poly ( ti :+ 0 ) of
|
||||
laguerreStepResult :: Complex Double
|
||||
laguerreStepResult = runST do
|
||||
coeffs <- unsafeThawPrimArray . primArrayFromListN 6 . map (:+ 0)
|
||||
$ Cubic.ddist @( Vector2D Double ) bez q
|
||||
laguerre epsilon 1 coeffs ( ti :+ 0 )
|
||||
ti' <- case laguerreStepResult of
|
||||
x :+ y
|
||||
| isNaN x
|
||||
|| isNaN y
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE DeriveAnyClass #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DeriveTraversable #-}
|
||||
|
@ -54,6 +55,10 @@ import Data.Group
|
|||
import Data.Group.Generics
|
||||
()
|
||||
|
||||
-- primitive
|
||||
import Data.Primitive.Types
|
||||
( Prim )
|
||||
|
||||
-- MetaBrush
|
||||
import Math.Epsilon
|
||||
( epsilon )
|
||||
|
@ -125,25 +130,27 @@ subdivide ( Bezier {..} ) t = ( Bezier p0 q1 pt, Bezier pt r1 p2 )
|
|||
|
||||
-- | Polynomial coefficients of the derivative of the distance to a quadratic Bézier curve.
|
||||
ddist :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> [ r ]
|
||||
ddist ( Bezier {..} ) c = [ a0, a1, a2, a3 ]
|
||||
ddist ( Bezier {..} ) c = [ a3, a2, a1, a0 ]
|
||||
where
|
||||
v, v', v'' :: v
|
||||
v = c --> p0
|
||||
v' = p0 --> p1
|
||||
v'' = p1 --> p0 ^+^ p1 --> p2
|
||||
!v = c --> p0
|
||||
!v' = p0 --> p1
|
||||
!v'' = p1 --> p0 ^+^ p1 --> p2
|
||||
|
||||
a0, a1, a2, a3 :: r
|
||||
a0 = v ^.^ v'
|
||||
a1 = v ^.^ v'' + 2 * squaredNorm v'
|
||||
a2 = 3 * v' ^.^ v''
|
||||
a3 = squaredNorm v''
|
||||
!a0 = v ^.^ v'
|
||||
!a1 = v ^.^ v'' + 2 * squaredNorm v'
|
||||
!a2 = 3 * v' ^.^ v''
|
||||
!a3 = squaredNorm v''
|
||||
|
||||
-- | Finds the closest point to a given point on a quadratic Bézier curve.
|
||||
closestPoint :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> ArgMin r ( r, p )
|
||||
closestPoint
|
||||
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r )
|
||||
=> Bezier p -> p -> ArgMin r ( r, p )
|
||||
closestPoint pts@( Bezier {..} ) c = pickClosest ( 0 :| 1 : roots )
|
||||
where
|
||||
roots :: [ r ]
|
||||
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots $ ddist @v pts c )
|
||||
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 2000 $ ddist @v pts c )
|
||||
|
||||
pickClosest :: NonEmpty r -> ArgMin r ( r, p )
|
||||
pickClosest ( s :| ss ) = go s q nm0 ss
|
||||
|
|
|
@ -78,7 +78,7 @@ import Math.Module
|
|||
, lerp, squaredNorm
|
||||
)
|
||||
import Math.Roots
|
||||
( realRoots )
|
||||
( solveQuadratic )
|
||||
import Math.Vector2D
|
||||
( Point2D(..), Vector2D(..), cross )
|
||||
|
||||
|
@ -580,14 +580,14 @@ withTangent tgt ( spt0 :<| spt1 :<| spts ) =
|
|||
| otherwise
|
||||
= Nothing
|
||||
in
|
||||
case mapMaybe correctTangentParam $ realRoots [ c01, 2 * ( c12 - c01 ), c01 + c23 - 2 * c12 ] of
|
||||
case mapMaybe correctTangentParam $ solveQuadratic c01 ( 2 * ( c12 - c01 ) ) ( c01 + c23 - 2 * c12 ) of
|
||||
( t : _ )
|
||||
-> Offset i ( Just t ) ( MkVector2D $ Cubic.bezier @( Vector2D Double ) bez t )
|
||||
-- Fallback in case we couldn't solve the quadratic for some reason.
|
||||
_
|
||||
| Just s <- between tgt tgt0 tgt2
|
||||
-- Fallback in case we couldn't solve the quadratic for some reason.
|
||||
-> Offset i ( Just s ) ( MkVector2D $ Cubic.bezier @( Vector2D Double ) bez s )
|
||||
-- Otherwise: go to next piece of the curve.
|
||||
-- Otherwise: go to next piece of the curve.
|
||||
| otherwise
|
||||
-> continue ( i + 3 ) tgt2 sp3 ps
|
||||
go _ _ _ _ _
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Math.Epsilon
|
||||
( epsilon )
|
||||
( epsilon, nearZero )
|
||||
where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
@ -10,3 +10,6 @@ module Math.Epsilon
|
|||
{-# SPECIALISE epsilon :: Double #-}
|
||||
epsilon :: forall r. RealFloat r => r
|
||||
epsilon = encodeFloat 1 ( 5 - floatDigits ( 0 :: r ) )
|
||||
|
||||
nearZero :: RealFloat r => r -> Bool
|
||||
nearZero x = abs x < epsilon
|
||||
|
|
|
@ -1,120 +1,215 @@
|
|||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE BlockArguments #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE NamedWildCards #-}
|
||||
{-# LANGUAGE PartialTypeSignatures #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeApplications #-}
|
||||
|
||||
module Math.Roots
|
||||
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
|
||||
|
||||
where
|
||||
module Math.Roots where
|
||||
|
||||
-- base
|
||||
import Control.Monad
|
||||
( unless )
|
||||
import Control.Monad.ST
|
||||
( ST, runST )
|
||||
import Data.Complex
|
||||
( Complex(..), magnitude )
|
||||
import Data.List.NonEmpty
|
||||
( NonEmpty(..), toList )
|
||||
import Data.Maybe
|
||||
( mapMaybe )
|
||||
|
||||
-- primitive
|
||||
import Control.Monad.Primitive
|
||||
( PrimMonad(PrimState) )
|
||||
import Data.Primitive.PrimArray
|
||||
( PrimArray, MutablePrimArray
|
||||
, primArrayFromList
|
||||
, getSizeofMutablePrimArray, sizeofPrimArray
|
||||
, unsafeThawPrimArray, cloneMutablePrimArray
|
||||
, shrinkMutablePrimArray, readPrimArray, writePrimArray
|
||||
)
|
||||
import Data.Primitive.Types
|
||||
( Prim )
|
||||
|
||||
-- prim-instances
|
||||
import Data.Primitive.Instances
|
||||
() -- instance Prim a => Prim ( Complex a )
|
||||
|
||||
-- MetaBrush
|
||||
import Math.Epsilon
|
||||
( epsilon )
|
||||
( epsilon, nearZero )
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
-- | Find real roots of a polynomial.
|
||||
|
||||
-- Coefficients are given in order of increasing degree, e.g.:
|
||||
-- x² + 7 is given by [ 7, 0, 1 ].
|
||||
realRoots :: forall r. RealFloat r => [ r ] -> [ r ]
|
||||
realRoots p = mapMaybe isReal ( roots epsilon 10000 ( map (:+ 0) p ) )
|
||||
-- | Real solutions to a quadratic equation.
|
||||
solveQuadratic :: forall a. RealFloat a => a -> a -> a -> [ a ]
|
||||
solveQuadratic a0 a1 a2
|
||||
| nearZero a1 && nearZero a2
|
||||
= if nearZero a0
|
||||
then [ 0, 0.5, 1 ] -- convention
|
||||
else []
|
||||
| nearZero ( a0 * a0 * a2 / ( a1 * a1 ) )
|
||||
= [ - a0 / a1 ]
|
||||
| disc < 0
|
||||
= [] -- non-real solutions
|
||||
| otherwise
|
||||
= let
|
||||
r :: a
|
||||
r =
|
||||
if a1 >= 0
|
||||
then 2 * a0 / ( - a1 - sqrt disc )
|
||||
else 0.5 * ( - a1 + sqrt disc) / a2
|
||||
in [ r, -r - a1 ]
|
||||
where
|
||||
isReal :: Complex r -> Maybe r
|
||||
disc :: a
|
||||
disc = a1 * a1 - 4 * a0 * a2
|
||||
|
||||
-- | Find real roots of a polynomial.
|
||||
--
|
||||
-- Coefficients are given in order of decreasing degree, e.g.:
|
||||
-- x² + 7 is given by [ 1, 0, 7 ].
|
||||
realRoots :: forall a. ( RealFloat a, Prim a ) => Int -> [ a ] -> [ a ]
|
||||
realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0) coeffs ) )
|
||||
where
|
||||
isReal :: Complex a -> Maybe a
|
||||
isReal ( a :+ b )
|
||||
| abs b < epsilon = Just a
|
||||
| otherwise = Nothing
|
||||
|
||||
-- | Compute all roots of a polynomial using Laguerre's method and (forward) deflation.
|
||||
--
|
||||
-- Polynomial coefficients are given in order of ascending degree (e.g. constant coefficient first).
|
||||
-- Polynomial coefficients are given in order of descending degree (e.g. constant coefficient last).
|
||||
--
|
||||
-- N.B. The forward deflation process is only guaranteed to be numerically stable
|
||||
-- if Laguerre's method finds roots in increasing order of magnitude.
|
||||
roots :: forall a. RealFloat a => a -> Int -> [ Complex a ] -> [ Complex a ]
|
||||
roots eps maxIters p = go p []
|
||||
where
|
||||
go :: [ Complex a ] -> [ Complex a ] -> [ Complex a ]
|
||||
go q rs
|
||||
| length q <= 2 = r : rs
|
||||
| otherwise = go ( deflate r q ) ( r : rs )
|
||||
where
|
||||
r :: Complex a
|
||||
r = laguerre eps maxIters q 0
|
||||
-- Start the iteration at 0 for best chance of numerical stability.
|
||||
roots :: forall a. ( RealFloat a, Prim a ) => a -> Int -> [ Complex a ] -> [ Complex a ]
|
||||
roots eps maxIters coeffs = runST do
|
||||
let
|
||||
coeffPrimArray :: PrimArray ( Complex a )
|
||||
coeffPrimArray = primArrayFromList coeffs
|
||||
sz :: Int
|
||||
sz = sizeofPrimArray coeffPrimArray
|
||||
p <- unsafeThawPrimArray coeffPrimArray
|
||||
let
|
||||
go :: Int -> [ Complex a ] -> ST _s [ Complex a ]
|
||||
go i rs = do
|
||||
!r <- laguerre eps maxIters p 0 -- Start Laguerre's method at 0 for best chance of numerical stability.
|
||||
if i <= 2
|
||||
then pure ( r : rs )
|
||||
else do
|
||||
deflate r p
|
||||
go ( i - 1 ) ( r : rs )
|
||||
go sz []
|
||||
|
||||
-- | Deflate a polynomial: factor out a root of the polynomial.
|
||||
--
|
||||
-- The polynomial must have degree at least 2.
|
||||
deflate :: forall a. Num a => a -> [ a ] -> [ a ]
|
||||
deflate r ( _ : c : cs ) = toList $ go ( c :| cs )
|
||||
where
|
||||
go :: NonEmpty a -> NonEmpty a
|
||||
go ( a :| [] ) = a :| []
|
||||
go ( a :| a' : as ) = case go ( a' :| as ) of
|
||||
( b' :| bs ) -> ( a + r * b' ) :| ( b' : bs )
|
||||
deflate _ _ = error "deflate: polynomial of degree < 2"
|
||||
deflate :: forall a m s. ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) => a -> MutablePrimArray s a -> m ()
|
||||
deflate r p = do
|
||||
deg <- subtract 1 <$> getSizeofMutablePrimArray p
|
||||
case compare deg 2 of
|
||||
LT -> pure ()
|
||||
EQ -> shrinkMutablePrimArray p deg
|
||||
GT -> do
|
||||
shrinkMutablePrimArray p deg
|
||||
let
|
||||
go :: a -> Int -> m ()
|
||||
go b i = unless ( i >= deg ) do
|
||||
ai <- readPrimArray p i
|
||||
writePrimArray p i ( ai + r * b )
|
||||
go ai ( i + 1 )
|
||||
a0 <- readPrimArray p 0
|
||||
go a0 1
|
||||
|
||||
-- | Laguerre's method.
|
||||
laguerre
|
||||
:: forall a. RealFloat a
|
||||
=> a -- ^ error tolerance
|
||||
-> Int -- ^ max number of iterations
|
||||
-> [ Complex a ] -- ^ polynomial
|
||||
-> Complex a -- ^ initial point
|
||||
-> Complex a
|
||||
laguerre eps maxIters p = go maxIters
|
||||
where
|
||||
p', p'' :: [ Complex a ]
|
||||
p' = derivative p
|
||||
p'' = derivative p'
|
||||
go :: Int -> Complex a -> Complex a
|
||||
go iterationsLeft x
|
||||
| iterationsLeft <= 1
|
||||
|| magnitude ( x' - x ) < eps = x'
|
||||
| otherwise = go ( iterationsLeft - 1 ) x'
|
||||
where
|
||||
x' :: Complex a
|
||||
x' = laguerreStep eps p p' p'' x
|
||||
:: forall a m s
|
||||
. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m )
|
||||
=> a -- ^ error tolerance
|
||||
-> Int -- ^ max number of iterations
|
||||
-> MutablePrimArray s ( Complex a ) -- ^ polynomial
|
||||
-> Complex a -- ^ initial point
|
||||
-> m ( Complex a )
|
||||
laguerre eps maxIters p x0 = do
|
||||
p' <- derivative p
|
||||
p'' <- derivative p'
|
||||
let
|
||||
go :: Int -> Complex a -> m ( Complex a )
|
||||
go iterationsLeft x = do
|
||||
x' <- laguerreStep eps p p' p'' x
|
||||
if iterationsLeft <= 1 || magnitude ( x' - x ) < eps
|
||||
then pure x'
|
||||
else go ( iterationsLeft - 1 ) x'
|
||||
go maxIters x0
|
||||
|
||||
-- | Take a single step in Laguerre's method.
|
||||
laguerreStep
|
||||
:: forall a. RealFloat a
|
||||
=> a -- ^ error tolerance
|
||||
-> [ Complex a ] -- ^ polynomial
|
||||
-> [ Complex a ] -- ^ first derivative of polynomial
|
||||
-> [ Complex a ] -- ^ second derivative of polynomial
|
||||
-> Complex a -- ^ initial point
|
||||
-> Complex a
|
||||
laguerreStep eps p p' p'' x
|
||||
| magnitude px < eps = x
|
||||
| otherwise = x - n / denom
|
||||
:: forall a m s
|
||||
. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m )
|
||||
=> a -- ^ error tolerance
|
||||
-> MutablePrimArray s ( Complex a ) -- ^ polynomial
|
||||
-> MutablePrimArray s ( Complex a ) -- ^ first derivative of polynomial
|
||||
-> MutablePrimArray s ( Complex a ) -- ^ second derivative of polynomial
|
||||
-> Complex a -- ^ initial point
|
||||
-> m ( Complex a )
|
||||
laguerreStep eps p p' p'' x = do
|
||||
n <- fromIntegral @_ @a <$> getSizeofMutablePrimArray p
|
||||
px <- eval p x
|
||||
if magnitude px < eps
|
||||
then pure x
|
||||
else do
|
||||
p'x <- eval p' x
|
||||
p''x <- eval p'' x
|
||||
let
|
||||
g = p'x / px
|
||||
g² = g * g
|
||||
h = g² - p''x / px
|
||||
delta = sqrt $ ( n - 1 ) *: ( n *: h - g² )
|
||||
gp = g + delta
|
||||
gm = g - delta
|
||||
denom
|
||||
| magnitude gm > magnitude gp
|
||||
= gm
|
||||
| otherwise
|
||||
= gp
|
||||
pure $ x - n *: ( recip denom )
|
||||
|
||||
where
|
||||
n = fromIntegral ( length p )
|
||||
px = eval p x
|
||||
p'x = eval p' x
|
||||
p''x = eval p'' x
|
||||
g = p'x / px
|
||||
g² = g * g
|
||||
h = g² - p''x / px
|
||||
delta = sqrt $ ( n - 1 ) * ( n * h - g² )
|
||||
gp = g + delta
|
||||
gm = g - delta
|
||||
denom
|
||||
| magnitude gm > magnitude gp
|
||||
= gm
|
||||
| otherwise
|
||||
= gp
|
||||
(*:) :: a -> Complex a -> Complex a
|
||||
r *: (u :+ v) = ( r * u ) :+ ( r * v )
|
||||
|
||||
-- | Evaluate a polynomial.
|
||||
eval :: Num a => [ a ] -> a -> a
|
||||
eval as x = foldr ( \ a b -> a + x * b ) 0 as
|
||||
eval
|
||||
:: forall a m s
|
||||
. ( Num a, Prim a, PrimMonad m, s ~ PrimState m )
|
||||
=> MutablePrimArray s a -> a -> m a
|
||||
eval p x = do
|
||||
n <- getSizeofMutablePrimArray p
|
||||
let
|
||||
go :: a -> Int -> m a
|
||||
go !a i =
|
||||
if i >= n
|
||||
then pure a
|
||||
else do
|
||||
!b <- readPrimArray p i
|
||||
go ( b + x * a ) ( i + 1 )
|
||||
an <- readPrimArray p 0
|
||||
go an 1
|
||||
|
||||
-- | Derivative of a polynomial.
|
||||
derivative :: Num a => [ a ] -> [ a ]
|
||||
derivative as = zipWith ( \ i a -> fromIntegral i * a ) [ ( 1 :: Int ) .. ] ( tail as )
|
||||
derivative
|
||||
:: forall a m s
|
||||
. ( Num a, Prim a, PrimMonad m, s ~ PrimState m )
|
||||
=> MutablePrimArray s a -> m ( MutablePrimArray s a )
|
||||
derivative p = do
|
||||
deg <- subtract 1 <$> getSizeofMutablePrimArray p
|
||||
p' <- cloneMutablePrimArray p 0 deg
|
||||
let
|
||||
go :: Int -> m ()
|
||||
go i = unless ( i >= deg ) do
|
||||
a <- readPrimArray p' i
|
||||
writePrimArray p' i ( a * fromIntegral ( deg - i ) )
|
||||
go 0
|
||||
pure p'
|
||||
|
|
Loading…
Reference in a new issue