From 58ca70c1bdeecbb5bc7b9d2a36876ba5bcd2356a Mon Sep 17 00:00:00 2001 From: sheaf Date: Sat, 19 Sep 2020 00:01:02 +0200 Subject: [PATCH] optimise root-finding functions * use PrimArray to represent polynomials * add some strictness annotations * turn on some optimisation flags * use quadratic formula for quadratic polynomials --- MetaBrush.cabal | 10 +- app/Main.hs | 2 +- src/app/MetaBrush/Unique.hs | 3 +- src/lib/Math/Bezier/Cubic.hs | 33 ++-- src/lib/Math/Bezier/Cubic/Fit.hs | 13 +- src/lib/Math/Bezier/Quadratic.hs | 27 ++-- src/lib/Math/Bezier/Stroke.hs | 8 +- src/lib/Math/Epsilon.hs | 5 +- src/lib/Math/Roots.hs | 257 +++++++++++++++++++++---------- 9 files changed, 243 insertions(+), 115 deletions(-) diff --git a/MetaBrush.cabal b/MetaBrush.cabal index 36757ae..c00e335 100644 --- a/MetaBrush.cabal +++ b/MetaBrush.cabal @@ -40,6 +40,8 @@ common common >= 1.2.0.1 && < 2.0 , groups ^>= 0.4.1.0 + , primitive + ^>= 0.7.1.0 , transformers ^>= 0.5.6.2 @@ -49,7 +51,11 @@ common common ghc-options: -O1 -fexpose-all-unfoldings + -funfolding-use-threshold=16 + -fexcess-precision -fspecialise-aggressively + -optc-O3 + -optc-ffast-math -Wall -Wcompat -fwarn-missing-local-signatures @@ -79,12 +85,14 @@ library , Math.Vector2D build-depends: - groups-generic + groups-generic ^>= 0.1.0.0 , hmatrix ^>= 0.20.0.0 , monad-par ^>= 0.3.5 + , prim-instances + ^>= 0.2 , vector ^>= 0.12.1.2 diff --git a/app/Main.hs b/app/Main.hs index 23e806e..24526f3 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -214,7 +214,7 @@ main = do maxHistorySizeTVar <- STM.newTVarIO @Int 1000 fitParametersTVar <- STM.newTVarIO @FitParameters ( FitParameters - { maxSubdiv = 10 + { maxSubdiv = 6 , nbSegments = 12 , dist_tol = 5e-3 , t_tol = 1e-4 diff --git a/src/app/MetaBrush/Unique.hs b/src/app/MetaBrush/Unique.hs index f222382..5c64a86 100644 --- a/src/app/MetaBrush/Unique.hs +++ b/src/app/MetaBrush/Unique.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -72,7 +73,7 @@ newtype UniqueSupply = UniqueSupply { uniqueSupplyTVar :: STM.TVar Unique } freshUnique :: UniqueSupply -> STM Unique freshUnique ( UniqueSupply { uniqueSupplyTVar } ) = do - uniq@( Unique i ) <- STM.readTVar uniqueSupplyTVar + uniq@( Unique !i ) <- STM.readTVar uniqueSupplyTVar STM.writeTVar uniqueSupplyTVar ( Unique ( succ i ) ) pure uniq diff --git a/src/lib/Math/Bezier/Cubic.hs b/src/lib/Math/Bezier/Cubic.hs index 0256b96..bebec72 100644 --- a/src/lib/Math/Bezier/Cubic.hs +++ b/src/lib/Math/Bezier/Cubic.hs @@ -1,4 +1,5 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveTraversable #-} @@ -57,6 +58,10 @@ import Data.Group import Data.Group.Generics () +-- primitive +import Data.Primitive.Types + ( Prim ) + -- MetaBrush import qualified Math.Bezier.Quadratic as Quadratic ( Bezier(..), bezier ) @@ -154,28 +159,30 @@ subdivide ( Bezier {..} ) t = ( Bezier p0 q1 q2 pt, Bezier pt r1 r2 p3 ) -- | Polynomial coefficients of the derivative of the distance to a cubic Bézier curve. ddist :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> [ r ] -ddist ( Bezier {..} ) c = [ a0, a1, a2, a3, a4, a5 ] +ddist ( Bezier {..} ) c = [ a5, a4, a3, a2, a1, a0 ] where v, v', v'', v''' :: v - v = c --> p0 - v' = p0 --> p1 - v'' = p1 --> p0 ^+^ p1 --> p2 - v''' = p0 --> p3 ^+^ 3 *^ ( p2 --> p1 ) + !v = c --> p0 + !v' = p0 --> p1 + !v'' = p1 --> p0 ^+^ p1 --> p2 + !v''' = p0 --> p3 ^+^ 3 *^ ( p2 --> p1 ) a0, a1, a2, a3, a4, a5 :: r - a0 = v ^.^ v' - a1 = 3 * squaredNorm v' + 2 * v ^.^ v'' - a2 = 9 * v' ^.^ v'' + v ^.^ v''' - a3 = 6 * squaredNorm v'' + 4 * v' ^.^ v''' - a4 = 5 * v'' ^.^ v''' - a5 = squaredNorm v''' + !a0 = v ^.^ v' + !a1 = 3 * squaredNorm v' + 2 * v ^.^ v'' + !a2 = 9 * v' ^.^ v'' + v ^.^ v''' + !a3 = 6 * squaredNorm v'' + 4 * v' ^.^ v''' + !a4 = 5 * v'' ^.^ v''' + !a5 = squaredNorm v''' -- | Finds the closest point to a given point on a cubic Bézier curve. -closestPoint :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> ArgMin r ( r, p ) +closestPoint + :: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r ) + => Bezier p -> p -> ArgMin r ( r, p ) closestPoint pts@( Bezier {..} ) c = pickClosest ( 0 :| 1 : roots ) -- todo: also include the self-intersection point if one exists where roots :: [ r ] - roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots $ ddist @v pts c ) + roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 2000 $ ddist @v pts c ) pickClosest :: NonEmpty r -> ArgMin r ( r, p ) pickClosest ( s :| ss ) = go s q nm0 ss diff --git a/src/lib/Math/Bezier/Cubic/Fit.hs b/src/lib/Math/Bezier/Cubic/Fit.hs index 1f84864..c47a651 100644 --- a/src/lib/Math/Bezier/Cubic/Fit.hs +++ b/src/lib/Math/Bezier/Cubic/Fit.hs @@ -46,6 +46,10 @@ import qualified Data.Sequence as Seq import Control.DeepSeq ( NFData ) +-- primitive +import Data.Primitive.PrimArray + ( primArrayFromListN, unsafeThawPrimArray ) + -- transformers import Control.Monad.Trans.State.Strict ( execStateT, modify' ) @@ -241,9 +245,12 @@ fitPiece dist_tol t_tol maxIters p tp qs r tr = ( dts_changed, argmax_sq_dist ) <- ( `execStateT` ( False, Max ( Arg 0 0 ) ) ) $ for_ ( zip qs [ 0 .. ] ) \( q, i ) -> do ti <- lift ( Unboxed.MVector.unsafeRead ts i ) let - poly :: [ Complex Double ] - poly = map (:+ 0) $ Cubic.ddist @( Vector2D Double ) bez q - ti' <- case laguerre epsilon 1 poly ( ti :+ 0 ) of + laguerreStepResult :: Complex Double + laguerreStepResult = runST do + coeffs <- unsafeThawPrimArray . primArrayFromListN 6 . map (:+ 0) + $ Cubic.ddist @( Vector2D Double ) bez q + laguerre epsilon 1 coeffs ( ti :+ 0 ) + ti' <- case laguerreStepResult of x :+ y | isNaN x || isNaN y diff --git a/src/lib/Math/Bezier/Quadratic.hs b/src/lib/Math/Bezier/Quadratic.hs index aca0caa..3a0621a 100644 --- a/src/lib/Math/Bezier/Quadratic.hs +++ b/src/lib/Math/Bezier/Quadratic.hs @@ -1,4 +1,5 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveTraversable #-} @@ -54,6 +55,10 @@ import Data.Group import Data.Group.Generics () +-- primitive +import Data.Primitive.Types + ( Prim ) + -- MetaBrush import Math.Epsilon ( epsilon ) @@ -125,25 +130,27 @@ subdivide ( Bezier {..} ) t = ( Bezier p0 q1 pt, Bezier pt r1 p2 ) -- | Polynomial coefficients of the derivative of the distance to a quadratic Bézier curve. ddist :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> [ r ] -ddist ( Bezier {..} ) c = [ a0, a1, a2, a3 ] +ddist ( Bezier {..} ) c = [ a3, a2, a1, a0 ] where v, v', v'' :: v - v = c --> p0 - v' = p0 --> p1 - v'' = p1 --> p0 ^+^ p1 --> p2 + !v = c --> p0 + !v' = p0 --> p1 + !v'' = p1 --> p0 ^+^ p1 --> p2 a0, a1, a2, a3 :: r - a0 = v ^.^ v' - a1 = v ^.^ v'' + 2 * squaredNorm v' - a2 = 3 * v' ^.^ v'' - a3 = squaredNorm v'' + !a0 = v ^.^ v' + !a1 = v ^.^ v'' + 2 * squaredNorm v' + !a2 = 3 * v' ^.^ v'' + !a3 = squaredNorm v'' -- | Finds the closest point to a given point on a quadratic Bézier curve. -closestPoint :: forall v r p. ( Torsor v p, Inner r v, RealFloat r ) => Bezier p -> p -> ArgMin r ( r, p ) +closestPoint + :: forall v r p. ( Torsor v p, Inner r v, RealFloat r, Prim r ) + => Bezier p -> p -> ArgMin r ( r, p ) closestPoint pts@( Bezier {..} ) c = pickClosest ( 0 :| 1 : roots ) where roots :: [ r ] - roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots $ ddist @v pts c ) + roots = filter ( \ r -> r > 0 && r < 1 ) ( realRoots 2000 $ ddist @v pts c ) pickClosest :: NonEmpty r -> ArgMin r ( r, p ) pickClosest ( s :| ss ) = go s q nm0 ss diff --git a/src/lib/Math/Bezier/Stroke.hs b/src/lib/Math/Bezier/Stroke.hs index 9a1e6e9..ddbd986 100644 --- a/src/lib/Math/Bezier/Stroke.hs +++ b/src/lib/Math/Bezier/Stroke.hs @@ -78,7 +78,7 @@ import Math.Module , lerp, squaredNorm ) import Math.Roots - ( realRoots ) + ( solveQuadratic ) import Math.Vector2D ( Point2D(..), Vector2D(..), cross ) @@ -580,14 +580,14 @@ withTangent tgt ( spt0 :<| spt1 :<| spts ) = | otherwise = Nothing in - case mapMaybe correctTangentParam $ realRoots [ c01, 2 * ( c12 - c01 ), c01 + c23 - 2 * c12 ] of + case mapMaybe correctTangentParam $ solveQuadratic c01 ( 2 * ( c12 - c01 ) ) ( c01 + c23 - 2 * c12 ) of ( t : _ ) -> Offset i ( Just t ) ( MkVector2D $ Cubic.bezier @( Vector2D Double ) bez t ) - -- Fallback in case we couldn't solve the quadratic for some reason. _ | Just s <- between tgt tgt0 tgt2 + -- Fallback in case we couldn't solve the quadratic for some reason. -> Offset i ( Just s ) ( MkVector2D $ Cubic.bezier @( Vector2D Double ) bez s ) - -- Otherwise: go to next piece of the curve. + -- Otherwise: go to next piece of the curve. | otherwise -> continue ( i + 3 ) tgt2 sp3 ps go _ _ _ _ _ diff --git a/src/lib/Math/Epsilon.hs b/src/lib/Math/Epsilon.hs index 1194a9d..918725c 100644 --- a/src/lib/Math/Epsilon.hs +++ b/src/lib/Math/Epsilon.hs @@ -1,7 +1,7 @@ {-# LANGUAGE ScopedTypeVariables #-} module Math.Epsilon - ( epsilon ) + ( epsilon, nearZero ) where -------------------------------------------------------------------------------- @@ -10,3 +10,6 @@ module Math.Epsilon {-# SPECIALISE epsilon :: Double #-} epsilon :: forall r. RealFloat r => r epsilon = encodeFloat 1 ( 5 - floatDigits ( 0 :: r ) ) + +nearZero :: RealFloat r => r -> Bool +nearZero x = abs x < epsilon diff --git a/src/lib/Math/Roots.hs b/src/lib/Math/Roots.hs index acc580c..0cebab4 100644 --- a/src/lib/Math/Roots.hs +++ b/src/lib/Math/Roots.hs @@ -1,120 +1,215 @@ -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE NamedWildCards #-} +{-# LANGUAGE PartialTypeSignatures #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} -module Math.Roots +{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-} - where +module Math.Roots where -- base +import Control.Monad + ( unless ) +import Control.Monad.ST + ( ST, runST ) import Data.Complex ( Complex(..), magnitude ) -import Data.List.NonEmpty - ( NonEmpty(..), toList ) import Data.Maybe ( mapMaybe ) +-- primitive +import Control.Monad.Primitive + ( PrimMonad(PrimState) ) +import Data.Primitive.PrimArray + ( PrimArray, MutablePrimArray + , primArrayFromList + , getSizeofMutablePrimArray, sizeofPrimArray + , unsafeThawPrimArray, cloneMutablePrimArray + , shrinkMutablePrimArray, readPrimArray, writePrimArray + ) +import Data.Primitive.Types + ( Prim ) + +-- prim-instances +import Data.Primitive.Instances + () -- instance Prim a => Prim ( Complex a ) + -- MetaBrush import Math.Epsilon - ( epsilon ) + ( epsilon, nearZero ) -------------------------------------------------------------------------------- --- | Find real roots of a polynomial. - --- Coefficients are given in order of increasing degree, e.g.: --- x² + 7 is given by [ 7, 0, 1 ]. -realRoots :: forall r. RealFloat r => [ r ] -> [ r ] -realRoots p = mapMaybe isReal ( roots epsilon 10000 ( map (:+ 0) p ) ) +-- | Real solutions to a quadratic equation. +solveQuadratic :: forall a. RealFloat a => a -> a -> a -> [ a ] +solveQuadratic a0 a1 a2 + | nearZero a1 && nearZero a2 + = if nearZero a0 + then [ 0, 0.5, 1 ] -- convention + else [] + | nearZero ( a0 * a0 * a2 / ( a1 * a1 ) ) + = [ - a0 / a1 ] + | disc < 0 + = [] -- non-real solutions + | otherwise + = let + r :: a + r = + if a1 >= 0 + then 2 * a0 / ( - a1 - sqrt disc ) + else 0.5 * ( - a1 + sqrt disc) / a2 + in [ r, -r - a1 ] where - isReal :: Complex r -> Maybe r + disc :: a + disc = a1 * a1 - 4 * a0 * a2 + +-- | Find real roots of a polynomial. +-- +-- Coefficients are given in order of decreasing degree, e.g.: +-- x² + 7 is given by [ 1, 0, 7 ]. +realRoots :: forall a. ( RealFloat a, Prim a ) => Int -> [ a ] -> [ a ] +realRoots maxIters coeffs = mapMaybe isReal ( roots epsilon maxIters ( map (:+ 0) coeffs ) ) + where + isReal :: Complex a -> Maybe a isReal ( a :+ b ) | abs b < epsilon = Just a | otherwise = Nothing -- | Compute all roots of a polynomial using Laguerre's method and (forward) deflation. -- --- Polynomial coefficients are given in order of ascending degree (e.g. constant coefficient first). +-- Polynomial coefficients are given in order of descending degree (e.g. constant coefficient last). -- -- N.B. The forward deflation process is only guaranteed to be numerically stable -- if Laguerre's method finds roots in increasing order of magnitude. -roots :: forall a. RealFloat a => a -> Int -> [ Complex a ] -> [ Complex a ] -roots eps maxIters p = go p [] - where - go :: [ Complex a ] -> [ Complex a ] -> [ Complex a ] - go q rs - | length q <= 2 = r : rs - | otherwise = go ( deflate r q ) ( r : rs ) - where - r :: Complex a - r = laguerre eps maxIters q 0 - -- Start the iteration at 0 for best chance of numerical stability. +roots :: forall a. ( RealFloat a, Prim a ) => a -> Int -> [ Complex a ] -> [ Complex a ] +roots eps maxIters coeffs = runST do + let + coeffPrimArray :: PrimArray ( Complex a ) + coeffPrimArray = primArrayFromList coeffs + sz :: Int + sz = sizeofPrimArray coeffPrimArray + p <- unsafeThawPrimArray coeffPrimArray + let + go :: Int -> [ Complex a ] -> ST _s [ Complex a ] + go i rs = do + !r <- laguerre eps maxIters p 0 -- Start Laguerre's method at 0 for best chance of numerical stability. + if i <= 2 + then pure ( r : rs ) + else do + deflate r p + go ( i - 1 ) ( r : rs ) + go sz [] -- | Deflate a polynomial: factor out a root of the polynomial. -- -- The polynomial must have degree at least 2. -deflate :: forall a. Num a => a -> [ a ] -> [ a ] -deflate r ( _ : c : cs ) = toList $ go ( c :| cs ) - where - go :: NonEmpty a -> NonEmpty a - go ( a :| [] ) = a :| [] - go ( a :| a' : as ) = case go ( a' :| as ) of - ( b' :| bs ) -> ( a + r * b' ) :| ( b' : bs ) -deflate _ _ = error "deflate: polynomial of degree < 2" +deflate :: forall a m s. ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) => a -> MutablePrimArray s a -> m () +deflate r p = do + deg <- subtract 1 <$> getSizeofMutablePrimArray p + case compare deg 2 of + LT -> pure () + EQ -> shrinkMutablePrimArray p deg + GT -> do + shrinkMutablePrimArray p deg + let + go :: a -> Int -> m () + go b i = unless ( i >= deg ) do + ai <- readPrimArray p i + writePrimArray p i ( ai + r * b ) + go ai ( i + 1 ) + a0 <- readPrimArray p 0 + go a0 1 -- | Laguerre's method. laguerre - :: forall a. RealFloat a - => a -- ^ error tolerance - -> Int -- ^ max number of iterations - -> [ Complex a ] -- ^ polynomial - -> Complex a -- ^ initial point - -> Complex a -laguerre eps maxIters p = go maxIters - where - p', p'' :: [ Complex a ] - p' = derivative p - p'' = derivative p' - go :: Int -> Complex a -> Complex a - go iterationsLeft x - | iterationsLeft <= 1 - || magnitude ( x' - x ) < eps = x' - | otherwise = go ( iterationsLeft - 1 ) x' - where - x' :: Complex a - x' = laguerreStep eps p p' p'' x + :: forall a m s + . ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m ) + => a -- ^ error tolerance + -> Int -- ^ max number of iterations + -> MutablePrimArray s ( Complex a ) -- ^ polynomial + -> Complex a -- ^ initial point + -> m ( Complex a ) +laguerre eps maxIters p x0 = do + p' <- derivative p + p'' <- derivative p' + let + go :: Int -> Complex a -> m ( Complex a ) + go iterationsLeft x = do + x' <- laguerreStep eps p p' p'' x + if iterationsLeft <= 1 || magnitude ( x' - x ) < eps + then pure x' + else go ( iterationsLeft - 1 ) x' + go maxIters x0 -- | Take a single step in Laguerre's method. laguerreStep - :: forall a. RealFloat a - => a -- ^ error tolerance - -> [ Complex a ] -- ^ polynomial - -> [ Complex a ] -- ^ first derivative of polynomial - -> [ Complex a ] -- ^ second derivative of polynomial - -> Complex a -- ^ initial point - -> Complex a -laguerreStep eps p p' p'' x - | magnitude px < eps = x - | otherwise = x - n / denom + :: forall a m s + . ( RealFloat a, Prim a, PrimMonad m, s ~ PrimState m ) + => a -- ^ error tolerance + -> MutablePrimArray s ( Complex a ) -- ^ polynomial + -> MutablePrimArray s ( Complex a ) -- ^ first derivative of polynomial + -> MutablePrimArray s ( Complex a ) -- ^ second derivative of polynomial + -> Complex a -- ^ initial point + -> m ( Complex a ) +laguerreStep eps p p' p'' x = do + n <- fromIntegral @_ @a <$> getSizeofMutablePrimArray p + px <- eval p x + if magnitude px < eps + then pure x + else do + p'x <- eval p' x + p''x <- eval p'' x + let + g = p'x / px + g² = g * g + h = g² - p''x / px + delta = sqrt $ ( n - 1 ) *: ( n *: h - g² ) + gp = g + delta + gm = g - delta + denom + | magnitude gm > magnitude gp + = gm + | otherwise + = gp + pure $ x - n *: ( recip denom ) + where - n = fromIntegral ( length p ) - px = eval p x - p'x = eval p' x - p''x = eval p'' x - g = p'x / px - g² = g * g - h = g² - p''x / px - delta = sqrt $ ( n - 1 ) * ( n * h - g² ) - gp = g + delta - gm = g - delta - denom - | magnitude gm > magnitude gp - = gm - | otherwise - = gp + (*:) :: a -> Complex a -> Complex a + r *: (u :+ v) = ( r * u ) :+ ( r * v ) -- | Evaluate a polynomial. -eval :: Num a => [ a ] -> a -> a -eval as x = foldr ( \ a b -> a + x * b ) 0 as +eval + :: forall a m s + . ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) + => MutablePrimArray s a -> a -> m a +eval p x = do + n <- getSizeofMutablePrimArray p + let + go :: a -> Int -> m a + go !a i = + if i >= n + then pure a + else do + !b <- readPrimArray p i + go ( b + x * a ) ( i + 1 ) + an <- readPrimArray p 0 + go an 1 -- | Derivative of a polynomial. -derivative :: Num a => [ a ] -> [ a ] -derivative as = zipWith ( \ i a -> fromIntegral i * a ) [ ( 1 :: Int ) .. ] ( tail as ) +derivative + :: forall a m s + . ( Num a, Prim a, PrimMonad m, s ~ PrimState m ) + => MutablePrimArray s a -> m ( MutablePrimArray s a ) +derivative p = do + deg <- subtract 1 <$> getSizeofMutablePrimArray p + p' <- cloneMutablePrimArray p 0 deg + let + go :: Int -> m () + go i = unless ( i >= deg ) do + a <- readPrimArray p' i + writePrimArray p' i ( a * fromIntegral ( deg - i ) ) + go 0 + pure p'