optimise root-finding functions

* use PrimArray to represent polynomials
  * add some strictness annotations
  * turn on some optimisation flags
  * use quadratic formula for quadratic polynomials
This commit is contained in:
sheaf 2020-09-19 00:01:02 +02:00
parent eb8e7012aa
commit 58ca70c1bd
9 changed files with 243 additions and 115 deletions

View file

@ -40,6 +40,8 @@ common common
>= 1.2.0.1 && < 2.0
, groups
^>= 0.4.1.0
, primitive
^>= 0.7.1.0
, transformers
^>= 0.5.6.2
@ -49,7 +51,11 @@ common common
ghc-options:
-O1
-fexpose-all-unfoldings
-funfolding-use-threshold=16
-fexcess-precision
-fspecialise-aggressively
-optc-O3
-optc-ffast-math
-Wall
-Wcompat
-fwarn-missing-local-signatures
@ -79,12 +85,14 @@ library
, Math.Vector2D
build-depends:
groups-generic
groups-generic
^>= 0.1.0.0
, hmatrix
^>= 0.20.0.0
, monad-par
^>= 0.3.5
, prim-instances
^>= 0.2
, vector
^>= 0.12.1.2

View file

@ -214,7 +214,7 @@ main = do
maxHistorySizeTVar <- STM.newTVarIO @Int 1000
fitParametersTVar <- STM.newTVarIO @FitParameters
( FitParameters
{ maxSubdiv = 10
{ maxSubdiv = 6
, nbSegments = 12
, dist_tol = 5e-3
, t_tol = 1e-4

View file

@ -1,3 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
@ -72,7 +73,7 @@ newtype UniqueSupply = UniqueSupply { uniqueSupplyTVar :: STM.TVar Unique }
freshUnique :: UniqueSupply -> STM Unique
freshUnique ( UniqueSupply { uniqueSupplyTVar } ) = do
uniq@( Unique i ) <- STM.readTVar uniqueSupplyTVar
uniq@( Unique !i ) <- STM.readTVar uniqueSupplyTVar
STM.writeTVar uniqueSupplyTVar ( Unique ( succ i ) )
pure uniq

View file

@ -1,4 +1,5 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
@ -57,6 +58,10 @@ import Data.Group
import Data.Group.Generics
()
-- primitive
import Data.Primitive.Types
( Prim )
-- MetaBrush
import qualified Math.Bezier.Quadratic as Quadratic
( Bezier(..), bezier )
@ -154,28 +159,30 @@ subdivide ( Bezier {..} ) t = ( Bezier p0 q1 q2 pt, Bezier pt r1 r2 p3 )
-- | Polynomial coefficients of the derivative of the distance to a cubic Bézier curve.
ddist :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> [ r ]
ddist ( Bezier {..} ) c = [ a0, a1, a2, a3, a4, a5 ]
ddist ( Bezier {..} ) c = [ a5, a4, a3, a2, a1, a0 ]
where
v, v', v'', v''' :: v
v = c --> p0
v' = p0 --> p1
v'' = p1 --> p0 ^+^ p1 --> p2
v''' = p0 --> p3 ^+^ 3 *^ ( p2 --> p1 )
!v = c --> p0
!v' = p0 --> p1
!v'' = p1 --> p0 ^+^ p1 --> p2
!v''' = p0 --> p3 ^+^ 3 *^ ( p2 --> p1 )
a0, a1, a2, a3, a4, a5 :: r
a0 = v ^.^ v'
a1 = 3 * squaredNorm v' + 2 * v ^.^ v''
a2 = 9 * v' ^.^ v'' + v ^.^ v'''
a3 = 6 * squaredNorm v'' + 4 * v' ^.^ v'''
a4 = 5 * v'' ^.^ v'''
a5 = squaredNorm v'''
!a0 = v ^.^ v'
!a1 = 3 * squaredNorm v' + 2 * v ^.^ v''
!a2 = 9 * v' ^.^ v'' + v ^.^ v'''
!a3 = 6 * squaredNorm v'' + 4 * v' ^.^ v'''
!a4 = 5 * v'' ^.^ v'''
!a5 = squaredNorm v'''
-- | Finds the closest point to a given point on a cubic Bézier curve.
closestPoint :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> ArgMin r ( r, p )
closestPoint
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r )
=> Bezier p -> p -> ArgMin r ( r, p )
closestPoint pts@( Bezier {..} ) c = pickClosest ( 0 :| 1 : roots ) -- todo: also include the self-intersection point if one exists
where
roots :: [ r ]
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots $ ddist @v pts c )
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 2000 $ ddist @v pts c )
pickClosest :: NonEmpty r -> ArgMin r ( r, p )
pickClosest ( s :| ss ) = go s q nm0 ss

View file

@ -46,6 +46,10 @@ import qualified Data.Sequence as Seq
import Control.DeepSeq
( NFData )
-- primitive
import Data.Primitive.PrimArray
( primArrayFromListN, unsafeThawPrimArray )
-- transformers
import Control.Monad.Trans.State.Strict
( execStateT, modify' )
@ -241,9 +245,12 @@ fitPiece dist_tol t_tol maxIters p tp qs r tr =
( dts_changed, argmax_sq_dist ) <- ( `execStateT` ( False, Max ( Arg 0 0 ) ) ) $ for_ ( zip qs [ 0 .. ] ) \( q, i ) -> do
ti <- lift ( Unboxed.MVector.unsafeRead ts i )
let
poly :: [ Complex Double ]
poly = map (:+ 0) $ Cubic.ddist @( Vector2D Double ) bez q
ti' <- case laguerre epsilon 1 poly ( ti :+ 0 ) of
laguerreStepResult :: Complex Double
laguerreStepResult = runST do
coeffs <- unsafeThawPrimArray . primArrayFromListN 6 . map (:+ 0)
$ Cubic.ddist @( Vector2D Double ) bez q
laguerre epsilon 1 coeffs ( ti :+ 0 )
ti' <- case laguerreStepResult of
x :+ y
| isNaN x
|| isNaN y

View file

@ -1,4 +1,5 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
@ -54,6 +55,10 @@ import Data.Group
import Data.Group.Generics
()
-- primitive
import Data.Primitive.Types
( Prim )
-- MetaBrush
import Math.Epsilon
( epsilon )
@ -125,25 +130,27 @@ subdivide ( Bezier {..} ) t = ( Bezier p0 q1 pt, Bezier pt r1 p2 )
-- | Polynomial coefficients of the derivative of the distance to a quadratic Bézier curve.
ddist :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> [ r ]
ddist ( Bezier {..} ) c = [ a0, a1, a2, a3 ]
ddist ( Bezier {..} ) c = [ a3, a2, a1, a0 ]
where
v, v', v'' :: v
v = c --> p0
v' = p0 --> p1
v'' = p1 --> p0 ^+^ p1 --> p2
!v = c --> p0
!v' = p0 --> p1
!v'' = p1 --> p0 ^+^ p1 --> p2
a0, a1, a2, a3 :: r
a0 = v ^.^ v'
a1 = v ^.^ v'' + 2 * squaredNorm v'
a2 = 3 * v' ^.^ v''
a3 = squaredNorm v''
!a0 = v ^.^ v'
!a1 = v ^.^ v'' + 2 * squaredNorm v'
!a2 = 3 * v' ^.^ v''
!a3 = squaredNorm v''
-- | Finds the closest point to a given point on a quadratic Bézier curve.
closestPoint :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> ArgMin r ( r, p )
closestPoint
:: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r )
=> Bezier p -> p -> ArgMin r ( r, p )
closestPoint pts@( Bezier {..} ) c = pickClosest ( 0 :| 1 : roots )
where
roots :: [ r ]
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots $ ddist @v pts c )
roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 2000 $ ddist @v pts c )
pickClosest :: NonEmpty r -> ArgMin r ( r, p )
pickClosest ( s :| ss ) = go s q nm0 ss

View file

@ -78,7 +78,7 @@ import Math.Module
, lerp, squaredNorm
)
import Math.Roots
( realRoots )
( solveQuadratic )
import Math.Vector2D
( Point2D(..), Vector2D(..), cross )
@ -580,14 +580,14 @@ withTangent tgt ( spt0 :<| spt1 :<| spts ) =
| otherwise
= Nothing
in
case mapMaybe correctTangentParam $ realRoots [ c01, 2 * ( c12 - c01 ), c01 + c23 - 2 * c12 ] of
case mapMaybe correctTangentParam $ solveQuadratic c01 ( 2 * ( c12 - c01 ) ) ( c01 + c23 - 2 * c12 ) of
( t : _ )
-> Offset i ( Just t ) ( MkVector2D $ Cubic.bezier @( Vector2D Double ) bez t )
-- Fallback in case we couldn't solve the quadratic for some reason.
_
| Just s <- between tgt tgt0 tgt2
-- Fallback in case we couldn't solve the quadratic for some reason.
-> Offset i ( Just s ) ( MkVector2D $ Cubic.bezier @( Vector2D Double ) bez s )
-- Otherwise: go to next piece of the curve.
-- Otherwise: go to next piece of the curve.
| otherwise
-> continue ( i + 3 ) tgt2 sp3 ps
go _ _ _ _ _

View file

@ -1,7 +1,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
module Math.Epsilon
( epsilon )
( epsilon, nearZero )
where
--------------------------------------------------------------------------------
@ -10,3 +10,6 @@ module Math.Epsilon
{-# SPECIALISE epsilon :: Double #-}
epsilon :: forall r. RealFloat r => r
epsilon = encodeFloat 1 ( 5 - floatDigits ( 0 :: r ) )
nearZero :: RealFloat r => r -> Bool
nearZero x = abs x < epsilon

View file

@ -1,120 +1,215 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedWildCards #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Math.Roots
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
where
module Math.Roots where
-- base
import Control.Monad
( unless )
import Control.Monad.ST
( ST, runST )
import Data.Complex
( Complex(..), magnitude )
import Data.List.NonEmpty
( NonEmpty(..), toList )
import Data.Maybe
( mapMaybe )
-- primitive
import Control.Monad.Primitive
( PrimMonad(PrimState) )
import Data.Primitive.PrimArray
( PrimArray, MutablePrimArray
, primArrayFromList
, getSizeofMutablePrimArray, sizeofPrimArray
, unsafeThawPrimArray, cloneMutablePrimArray
, shrinkMutablePrimArray, readPrimArray, writePrimArray
)
import Data.Primitive.Types
( Prim )
-- prim-instances
import Data.Primitive.Instances
() -- instance Prim a => Prim ( Complex a )
-- MetaBrush
import Math.Epsilon
( epsilon )
( epsilon, nearZero )
--------------------------------------------------------------------------------
-- | Find real roots of a polynomial.
-- Coefficients are given in order of increasing degree, e.g.:
-- x² + 7 is given by [ 7, 0, 1 ].
realRoots :: forall r. RealFloat r => [ r ] -> [ r ]
realRoots p = mapMaybe isReal ( roots epsilon 10000 ( map (:+ 0) p ) )
-- | Real solutions to a quadratic equation.
solveQuadratic :: forall a. RealFloat a => a -> a -> a -> [ a ]
solveQuadratic a0 a1 a2
| nearZero a1 && nearZero a2
= if nearZero a0
then [ 0, 0.5, 1 ] -- convention
else []
| nearZero ( a0 * a0 * a2 / ( a1 * a1 ) )
= [ - a0 / a1 ]
| disc < 0
= [] -- non-real solutions
| otherwise
= let
r :: a
r =
if a1 >= 0
then 2 * a0 / ( - a1 - sqrt disc )
else 0.5 * ( - a1 + sqrt disc) / a2
in [ r, -r - a1 ]
where
isReal :: Complex r -> Maybe r
disc :: a
disc = a1 * a1 - 4 * a0 * a2
-- | Find real roots of a polynomial.
--
-- Coefficients are given in order of decreasing degree, e.g.:
-- x² + 7 is given by [ 1, 0, 7 ].
realRoots :: forall a. ( RealFloat a, Prim a ) => Int -> [ a ] -> [ a ]
realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0) coeffs ) )
where
isReal :: Complex a -> Maybe a
isReal ( a :+ b )
| abs b < epsilon = Just a
| otherwise = Nothing
-- | Compute all roots of a polynomial using Laguerre's method and (forward) deflation.
--
-- Polynomial coefficients are given in order of ascending degree (e.g. constant coefficient first).
-- Polynomial coefficients are given in order of descending degree (e.g. constant coefficient last).
--
-- N.B. The forward deflation process is only guaranteed to be numerically stable
-- if Laguerre's method finds roots in increasing order of magnitude.
roots :: forall a. RealFloat a => a -> Int -> [ Complex a ] -> [ Complex a ]
roots eps maxIters p = go p []
where
go :: [ Complex a ] -> [ Complex a ] -> [ Complex a ]
go q rs
| length q <= 2 = r : rs
| otherwise = go ( deflate r q ) ( r : rs )
where
r :: Complex a
r = laguerre eps maxIters q 0
-- Start the iteration at 0 for best chance of numerical stability.
roots :: forall a. ( RealFloat a, Prim a ) => a -> Int -> [ Complex a ] -> [ Complex a ]
roots eps maxIters coeffs = runST do
let
coeffPrimArray :: PrimArray ( Complex a )
coeffPrimArray = primArrayFromList coeffs
sz :: Int
sz = sizeofPrimArray coeffPrimArray
p <- unsafeThawPrimArray coeffPrimArray
let
go :: Int -> [ Complex a ] -> ST _s [ Complex a ]
go i rs = do
!r <- laguerre eps maxIters p 0 -- Start Laguerre's method at 0 for best chance of numerical stability.
if i <= 2
then pure ( r : rs )
else do
deflate r p
go ( i - 1 ) ( r : rs )
go sz []
-- | Deflate a polynomial: factor out a root of the polynomial.
--
-- The polynomial must have degree at least 2.
deflate :: forall a. Num a => a -> [ a ] -> [ a ]
deflate r ( _ : c : cs ) = toList $ go ( c :| cs )
where
go :: NonEmpty a -> NonEmpty a
go ( a :| [] ) = a :| []
go ( a :| a' : as ) = case go ( a' :| as ) of
( b' :| bs ) -> ( a + r * b' ) :| ( b' : bs )
deflate _ _ = error "deflate: polynomial of degree < 2"
deflate :: forall a m s. ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) => a -> MutablePrimArray s a -> m ()
deflate r p = do
deg <- subtract 1 <$> getSizeofMutablePrimArray p
case compare deg 2 of
LT -> pure ()
EQ -> shrinkMutablePrimArray p deg
GT -> do
shrinkMutablePrimArray p deg
let
go :: a -> Int -> m ()
go b i = unless ( i >= deg ) do
ai <- readPrimArray p i
writePrimArray p i ( ai + r * b )
go ai ( i + 1 )
a0 <- readPrimArray p 0
go a0 1
-- | Laguerre's method.
laguerre
:: forall a. RealFloat a
=> a -- ^ error tolerance
-> Int -- ^ max number of iterations
-> [ Complex a ] -- ^ polynomial
-> Complex a -- ^ initial point
-> Complex a
laguerre eps maxIters p = go maxIters
where
p', p'' :: [ Complex a ]
p' = derivative p
p'' = derivative p'
go :: Int -> Complex a -> Complex a
go iterationsLeft x
| iterationsLeft <= 1
|| magnitude ( x' - x ) < eps = x'
| otherwise = go ( iterationsLeft - 1 ) x'
where
x' :: Complex a
x' = laguerreStep eps p p' p'' x
:: forall a m s
. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m )
=> a -- ^ error tolerance
-> Int -- ^ max number of iterations
-> MutablePrimArray s ( Complex a ) -- ^ polynomial
-> Complex a -- ^ initial point
-> m ( Complex a )
laguerre eps maxIters p x0 = do
p' <- derivative p
p'' <- derivative p'
let
go :: Int -> Complex a -> m ( Complex a )
go iterationsLeft x = do
x' <- laguerreStep eps p p' p'' x
if iterationsLeft <= 1 || magnitude ( x' - x ) < eps
then pure x'
else go ( iterationsLeft - 1 ) x'
go maxIters x0
-- | Take a single step in Laguerre's method.
laguerreStep
:: forall a. RealFloat a
=> a -- ^ error tolerance
-> [ Complex a ] -- ^ polynomial
-> [ Complex a ] -- ^ first derivative of polynomial
-> [ Complex a ] -- ^ second derivative of polynomial
-> Complex a -- ^ initial point
-> Complex a
laguerreStep eps p p' p'' x
| magnitude px < eps = x
| otherwise = x - n / denom
:: forall a m s
. ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m )
=> a -- ^ error tolerance
-> MutablePrimArray s ( Complex a ) -- ^ polynomial
-> MutablePrimArray s ( Complex a ) -- ^ first derivative of polynomial
-> MutablePrimArray s ( Complex a ) -- ^ second derivative of polynomial
-> Complex a -- ^ initial point
-> m ( Complex a )
laguerreStep eps p p' p'' x = do
n <- fromIntegral @_ @a <$> getSizeofMutablePrimArray p
px <- eval p x
if magnitude px < eps
then pure x
else do
p'x <- eval p' x
p''x <- eval p'' x
let
g = p'x / px
g² = g * g
h = g² - p''x / px
delta = sqrt $ ( n - 1 ) *: ( n *: h - g² )
gp = g + delta
gm = g - delta
denom
| magnitude gm > magnitude gp
= gm
| otherwise
= gp
pure $ x - n *: ( recip denom )
where
n = fromIntegral ( length p )
px = eval p x
p'x = eval p' x
p''x = eval p'' x
g = p'x / px
g² = g * g
h = g² - p''x / px
delta = sqrt $ ( n - 1 ) * ( n * h - g² )
gp = g + delta
gm = g - delta
denom
| magnitude gm > magnitude gp
= gm
| otherwise
= gp
(*:) :: a -> Complex a -> Complex a
r *: (u :+ v) = ( r * u ) :+ ( r * v )
-- | Evaluate a polynomial.
eval :: Num a => [ a ] -> a -> a
eval as x = foldr ( \ a b -> a + x * b ) 0 as
eval
:: forall a m s
. ( Num a, Prim a, PrimMonad m, s ~ PrimState m )
=> MutablePrimArray s a -> a -> m a
eval p x = do
n <- getSizeofMutablePrimArray p
let
go :: a -> Int -> m a
go !a i =
if i >= n
then pure a
else do
!b <- readPrimArray p i
go ( b + x * a ) ( i + 1 )
an <- readPrimArray p 0
go an 1
-- | Derivative of a polynomial.
derivative :: Num a => [ a ] -> [ a ]
derivative as = zipWith ( \ i a -> fromIntegral i * a ) [ ( 1 :: Int ) .. ] ( tail as )
derivative
:: forall a m s
. ( Num a, Prim a, PrimMonad m, s ~ PrimState m )
=> MutablePrimArray s a -> m ( MutablePrimArray s a )
derivative p = do
deg <- subtract 1 <$> getSizeofMutablePrimArray p
p' <- cloneMutablePrimArray p 0 deg
let
go :: Int -> m ()
go i = unless ( i >= deg ) do
a <- readPrimArray p' i
writePrimArray p' i ( a * fromIntegral ( deg - i ) )
go 0
pure p'