From 01fdd9a1262e7063cf7d320cec2751e585671871 Mon Sep 17 00:00:00 2001 From: sheaf Date: Sun, 21 Apr 2024 14:19:37 +0200 Subject: [PATCH] Improve robustness of quadratic equation solver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on ideas from the paper "The Ins and Outs of Solving Quadratic Equations with Floating-Point Arithmetic" (Frédéric Goualard, 2023) Could still be improved further, but I think this is acceptable for now. --- brush-strokes/brush-strokes.cabal | 2 + brush-strokes/cabal.project | 1 + brush-strokes/src/lib/Math/Algebra/Dual.hs | 14 +-- .../src/lib/Math/Algebra/Dual/Internal.hs | 10 +- brush-strokes/src/lib/Math/Bezier/Cubic.hs | 8 +- brush-strokes/src/lib/Math/Linear/Internal.hs | 24 ++--- brush-strokes/src/lib/Math/Root/Isolation.hs | 28 ++++-- brush-strokes/src/lib/Math/Roots.hs | 97 +++++++++++++++---- cabal.project | 1 + 9 files changed, 133 insertions(+), 52 deletions(-) diff --git a/brush-strokes/brush-strokes.cabal b/brush-strokes/brush-strokes.cabal index 0bf66a7..cf5afde 100644 --- a/brush-strokes/brush-strokes.cabal +++ b/brush-strokes/brush-strokes.cabal @@ -153,6 +153,8 @@ library ^>= 3.3.7.0 , filepath >= 1.4 && < 1.6 + , fp-ieee + ^>= 0.1.0.4 , groups-generic ^>= 0.3.1.0 , parallel diff --git a/brush-strokes/cabal.project b/brush-strokes/cabal.project index 22cd7a7..adb02a2 100644 --- a/brush-strokes/cabal.project +++ b/brush-strokes/cabal.project @@ -2,6 +2,7 @@ packages: . constraints: acts -finitary, + fp-ieee +fma3, rounded-hw -pure-hs -c99 -avx512 +ghc-prim -x87-long-double tests: True diff --git a/brush-strokes/src/lib/Math/Algebra/Dual.hs b/brush-strokes/src/lib/Math/Algebra/Dual.hs index d32f2cd..fc1f342 100644 --- a/brush-strokes/src/lib/Math/Algebra/Dual.hs +++ b/brush-strokes/src/lib/Math/Algebra/Dual.hs @@ -123,9 +123,9 @@ chainRule ( D df ) ( D dg ) = uncurryD2 :: D 2 a ~ D 2 ( ℝ 1 ) => D 2 ( ℝ 1 ) ( C 2 a b ) -> a -> D 2 ( ℝ 2 ) b uncurryD2 ( D21 ( D b_t0 ) ( T ( D dbdt_t0 ) ) ( T ( D d2bdt2_t0 ) ) ) s0 = - let !( D21 b_t0s0 dbds_t0s0 d2bds2_t0s0 ) = b_t0 s0 - !( D21 dbdt_t0s0 d2bdtds_t0s0 _ ) = dbdt_t0 s0 - !( D21 d2bdt2_t0s0 _ _ ) = d2bdt2_t0 s0 + let D21 b_t0s0 dbds_t0s0 d2bds2_t0s0 = b_t0 s0 + D21 dbdt_t0s0 d2bdtds_t0s0 _ = dbdt_t0 s0 + D21 d2bdt2_t0s0 _ _ = d2bdt2_t0 s0 in D22 b_t0s0 ( T dbdt_t0s0 ) dbds_t0s0 @@ -134,10 +134,10 @@ uncurryD2 ( D21 ( D b_t0 ) ( T ( D dbdt_t0 ) ) ( T ( D d2bdt2_t0 ) ) ) s0 = uncurryD3 :: D 3 a ~ D 3 ( ℝ 1 ) => D 3 ( ℝ 1 ) ( C 3 a b ) -> a -> D 3 ( ℝ 2 ) b uncurryD3 ( D31 ( D b_t0 ) ( T ( D dbdt_t0 ) ) ( T ( D d2bdt2_t0 ) ) ( T ( D d3bdt3_t0 ) ) ) s0 = - let !( D31 b_t0s0 dbds_t0s0 d2bds2_t0s0 d3bds3_t0s0 ) = b_t0 s0 - !( D31 dbdt_t0s0 d2bdtds_t0s0 d3bdtds2_t0s0 _ ) = dbdt_t0 s0 - !( D31 d2bdt2_t0s0 d3bdt2ds_t0s0 _ _ ) = d2bdt2_t0 s0 - !( D31 d3bdt3_t0s0 _ _ _ ) = d3bdt3_t0 s0 + let D31 b_t0s0 dbds_t0s0 d2bds2_t0s0 d3bds3_t0s0 = b_t0 s0 + D31 dbdt_t0s0 d2bdtds_t0s0 d3bdtds2_t0s0 _ = dbdt_t0 s0 + D31 d2bdt2_t0s0 d3bdt2ds_t0s0 _ _ = d2bdt2_t0 s0 + D31 d3bdt3_t0s0 _ _ _ = d3bdt3_t0 s0 in D32 b_t0s0 ( T dbdt_t0s0 ) dbds_t0s0 diff --git a/brush-strokes/src/lib/Math/Algebra/Dual/Internal.hs b/brush-strokes/src/lib/Math/Algebra/Dual/Internal.hs index 787c03f..5652779 100644 --- a/brush-strokes/src/lib/Math/Algebra/Dual/Internal.hs +++ b/brush-strokes/src/lib/Math/Algebra/Dual/Internal.hs @@ -139,7 +139,7 @@ data D1𝔸3 v = data D2𝔸3 v = D23 { _D23_v :: v , _D23_dx, _D23_dy, _D23_dz :: ( T v ) - , _D23_dxdx, _D23_dxdy, _D23_dydy, _D23_dxdz, _D23_dydz, _D23_dzdz :: !( T v ) + , _D23_dxdx, _D23_dxdy, _D23_dydy, _D23_dxdz, _D23_dydz, _D23_dzdz :: ( T v ) } deriving stock ( Show, Eq, Functor, Foldable, Traversable, Generic, Generic1 ) deriving anyclass NFData @@ -150,9 +150,9 @@ data D2𝔸3 v = data D3𝔸3 v = D33 { _D33_v :: v , _D33_dx, _D33_dy, _D33_dz :: ( T v ) - , _D33_dxdx, _D33_dxdy, _D33_dydy, _D33_dxdz, _D33_dydz, _D33_dzdz :: !( T v ) + , _D33_dxdx, _D33_dxdy, _D33_dydy, _D33_dxdz, _D33_dydz, _D33_dzdz :: ( T v ) , _D33_dxdxdx, _D33_dxdxdy, _D33_dxdydy, _D33_dydydy - , _D33_dxdxdz, _D33_dxdydz, _D33_dxdzdz, _D33_dydydz, _D33_dydzdz, _D33_dzdzdz :: !( T v ) + , _D33_dxdxdz, _D33_dxdydz, _D33_dxdzdz, _D33_dydydz, _D33_dydzdz, _D33_dzdzdz :: ( T v ) } deriving stock ( Show, Eq, Functor, Foldable, Traversable, Generic, Generic1 ) deriving anyclass NFData @@ -370,11 +370,13 @@ instance MonomialBasis D1𝔸2 where _D12_dx = T $ f $ Mon ( Vec [ 1, 0 ] ) _D12_dy = T $ f $ Mon ( Vec [ 0, 1 ] ) in D12 { .. } + {-# INLINE monTabulate #-} monIndex d = \ case Mon ( Vec [ 1, 0 ] ) -> unT $ _D12_dx d Mon ( Vec [ 0, 1 ] ) -> unT $ _D12_dy d _ -> _D12_v d + {-# INLINE monIndex #-} type instance Deg D2𝔸2 = 2 type instance Vars D2𝔸2 = 2 @@ -453,12 +455,14 @@ instance MonomialBasis D1𝔸3 where !_D13_dy = T $ f ( Mon ( Vec [ 0, 1, 0 ] ) ) !_D13_dz = T $ f ( Mon ( Vec [ 0, 0, 1 ] ) ) in D13 { .. } + {-# INLINE monTabulate #-} monIndex d = \ case Mon ( Vec [ 1, 0, 0 ] ) -> unT $ _D13_dx d Mon ( Vec [ 0, 1, 0 ] ) -> unT $ _D13_dy d Mon ( Vec [ 0, 0, 1 ] ) -> unT $ _D13_dz d _ -> _D13_v d + {-# INLINE monIndex #-} type instance Deg D2𝔸3 = 2 type instance Vars D2𝔸3 = 3 diff --git a/brush-strokes/src/lib/Math/Bezier/Cubic.hs b/brush-strokes/src/lib/Math/Bezier/Cubic.hs index ba4fc6d..16b72e2 100644 --- a/brush-strokes/src/lib/Math/Bezier/Cubic.hs +++ b/brush-strokes/src/lib/Math/Bezier/Cubic.hs @@ -38,6 +38,10 @@ import Data.Act import Control.DeepSeq ( NFData, NFData1 ) +-- fp-ieee +import Numeric.Floating.IEEE.NaN + ( RealFloatNaN ) + -- groups import Data.Group ( Group ) @@ -287,10 +291,12 @@ selfIntersectionParameters ( Bezier {..} ) = solveQuadratic c0 c1 c2 c2 = f1 + f2 - f3 -- | Extremal values of the Bézier parameter for a cubic Bézier curve. -extrema :: RealFloat r => Bezier r -> [ r ] +extrema :: RealFloatNaN r => Bezier r -> [ r ] extrema ( Bezier {..} ) = solveQuadratic c b a where a = p3 - 3 * p2 + 3 * p1 - p0 b = 2 * ( p0 - 2 * p1 + p2 ) c = p1 - p0 {-# INLINEABLE extrema #-} +{-# SPECIALISE extrema :: Bezier Double -> [ Double ] #-} + diff --git a/brush-strokes/src/lib/Math/Linear/Internal.hs b/brush-strokes/src/lib/Math/Linear/Internal.hs index 0459a5e..e7dad04 100644 --- a/brush-strokes/src/lib/Math/Linear/Internal.hs +++ b/brush-strokes/src/lib/Math/Linear/Internal.hs @@ -148,23 +148,20 @@ instance RepresentableQ Double ( ℝ 0 ) where instance RepresentableQ Double ( ℝ 1 ) where tabulateQ f = [|| ℝ1 $$( f ( Fin 1 ) ) ||] indexQ p = \ case - Fin 1 -> [|| unℝ1 $$p ||] - Fin i -> error $ "invalid index for ℝ 1: " ++ show i + _ -> [|| unℝ1 $$p ||] instance RepresentableQ Double ( ℝ 2 ) where tabulateQ f = [|| ℝ2 $$( f ( Fin 1 ) ) $$( f ( Fin 2 ) ) ||] indexQ p = \ case Fin 1 -> [|| _ℝ2_x $$p ||] - Fin 2 -> [|| _ℝ2_y $$p ||] - Fin i -> error $ "invalid index for ℝ 2: " ++ show i + _ -> [|| _ℝ2_y $$p ||] instance RepresentableQ Double ( ℝ 3 ) where tabulateQ f = [|| ℝ3 $$( f ( Fin 1 ) ) $$( f ( Fin 2 ) ) $$( f ( Fin 3 ) ) ||] indexQ p = \ case Fin 1 -> [|| _ℝ3_x $$p ||] Fin 2 -> [|| _ℝ3_y $$p ||] - Fin 3 -> [|| _ℝ3_z $$p ||] - Fin i -> error $ "invalid index for ℝ 3: " ++ show i + _ -> [|| _ℝ3_z $$p ||] instance RepresentableQ Double ( ℝ 4 ) where tabulateQ f = [|| ℝ4 $$( f ( Fin 1 ) ) $$( f ( Fin 2 ) ) $$( f ( Fin 3 ) ) $$( f ( Fin 4 ) ) ||] @@ -172,8 +169,7 @@ instance RepresentableQ Double ( ℝ 4 ) where Fin 1 -> [|| _ℝ4_x $$p ||] Fin 2 -> [|| _ℝ4_y $$p ||] Fin 3 -> [|| _ℝ4_z $$p ||] - Fin 4 -> [|| _ℝ4_w $$p ||] - Fin i -> error $ "invalid index for ℝ 4: " ++ show i + _ -> [|| _ℝ4_w $$p ||] instance Representable Double ( ℝ 0 ) where tabulate _ = ℝ0 @@ -185,8 +181,7 @@ instance Representable Double ( ℝ 1 ) where tabulate f = ℝ1 ( f ( Fin 1 ) ) {-# INLINE tabulate #-} index p = \ case - Fin 1 -> unℝ1 p - Fin i -> error $ "invalid index for ℝ 1: " ++ show i + _ -> unℝ1 p {-# INLINE index #-} instance Representable Double ( ℝ 2 ) where @@ -194,8 +189,7 @@ instance Representable Double ( ℝ 2 ) where {-# INLINE tabulate #-} index p = \ case Fin 1 -> _ℝ2_x p - Fin 2 -> _ℝ2_y p - Fin i -> error $ "invalid index for ℝ 2: " ++ show i + _ -> _ℝ2_y p {-# INLINE index #-} instance Representable Double ( ℝ 3 ) where @@ -204,8 +198,7 @@ instance Representable Double ( ℝ 3 ) where index p = \ case Fin 1 -> _ℝ3_x p Fin 2 -> _ℝ3_y p - Fin 3 -> _ℝ3_z p - Fin i -> error $ "invalid index for ℝ 3: " ++ show i + _ -> _ℝ3_z p {-# INLINE index #-} instance Representable Double ( ℝ 4 ) where @@ -215,6 +208,5 @@ instance Representable Double ( ℝ 4 ) where Fin 1 -> _ℝ4_x p Fin 2 -> _ℝ4_y p Fin 3 -> _ℝ4_z p - Fin 4 -> _ℝ4_w p - Fin i -> error $ "invalid index for ℝ 4: " ++ show i + _ -> _ℝ4_w p {-# INLINE index #-} diff --git a/brush-strokes/src/lib/Math/Root/Isolation.hs b/brush-strokes/src/lib/Math/Root/Isolation.hs index 8c364b2..45dca47 100644 --- a/brush-strokes/src/lib/Math/Root/Isolation.hs +++ b/brush-strokes/src/lib/Math/Root/Isolation.hs @@ -246,16 +246,26 @@ defaultRootIsolationAlgorithms minWidth narrowAbs box history -- Otherwise, do a normal round. -- Currently: we try an interval Gauss–Seidel step followed by box(1)-consistency. _ -> GaussSeidel defaultGaussSeidelOptions - NE.:| [ Box1 box1Options ] + NE.:| [ Box1 _box1Options ] + where - box1Options :: Box1Options n d - box1Options = + _box1Options :: Box1Options n d + _box1Options = Box1Options { box1EpsEq = narrowAbs , box1CoordsToNarrow = toList $ universe @n -- [ Fin 1, Fin 2 ] , box1EqsToUse = toList $ universe @d } + _box2Options :: Box2Options n d + _box2Options = + Box2Options + { box2EpsEq = narrowAbs + , box2LambdaMin = 0.001 + , box2CoordsToNarrow = toList $ universe @n + , box2EqsToUse = toList $ universe @d + } + -- Did we reduce the box width by at least "narrowAbs" in at least one of the dimensions? sufficientlySmallerThan :: Box n -> Box n -> Bool b1 `sufficientlySmallerThan` b2 = @@ -288,13 +298,15 @@ defaultBisectionOptions _minWidth _narrowAbs box = in unT ( origin @Double ) `inside` iRange' -- box(1)-consistency - --not $ null $ makeBox1Consistent eqs _minWidth _narrowAbs box' + --let box1Options = Box1Options _narrowAbs ( toList $ universe @n ) ( toList $ universe @d ) + --in not $ null $ makeBox1Consistent _minWidth box1Options eqs box' -- box(2)-consistency - --let box'' = makeBox2Consistent eqs _minWidth _narrowAbs 0.2 box' + --let box2Options = Box2Options _narrowAbs 0.001 ( toList $ universe @n ) ( toList $ universe @d ) + -- box'' = makeBox2Consistent _minWidth box2Options eqs box' -- iRange'' :: Box d - -- iRange'' = value @Double @1 @( Box n ) $ eqs box'' - --in origin @Double `inside` iRange'' + -- iRange'' = eqs box'' `monIndex` zeroMonomial + --in unT ( origin @Double ) `inside` iRange'' , fallbackBisectionDim = \ _roundHist _prevRoundsHist eqs -> let df = eqs box @@ -308,6 +320,8 @@ defaultBisectionOptions _minWidth _narrowAbs box = -- First, check if the largest dimension is over 10 times larger -- than the smallest dimension; if so bisect along that coordinate. + -- + -- TODO: filter out dimensions smaller than minimum width. in case sortOnArg ( width . coordInterval ) datPerCoord of [] -> error "dimension 0" [Arg _ d] -> (coordIndex d, "") diff --git a/brush-strokes/src/lib/Math/Roots.hs b/brush-strokes/src/lib/Math/Roots.hs index 3dac503..5ef255c 100644 --- a/brush-strokes/src/lib/Math/Roots.hs +++ b/brush-strokes/src/lib/Math/Roots.hs @@ -33,6 +33,10 @@ import Data.Maybe import Control.DeepSeq ( NFData, force ) +-- fp-ieee +import Numeric.Floating.IEEE.NaN + ( RealFloatNaN(copySign) ) + -- primitive import Control.Monad.Primitive ( PrimMonad(PrimState) ) @@ -46,6 +50,10 @@ import Data.Primitive.PrimArray import Data.Primitive.Types ( Prim ) +-- rounded-hw +import Numeric.Rounded.Hardware.Internal + ( fusedMultiplyAdd ) + -- brush-strokes import Math.Epsilon ( epsilon, nearZero ) @@ -55,32 +63,85 @@ import Math.Epsilon -- | Real solutions to a quadratic equation. -- -- Coefficients are given in order of increasing degree. +-- +-- Implementation taken from https://pbr-book.org/4ed/Utilities/Mathematical_Infrastructure#Quadratic. solveQuadratic - :: forall a. RealFloat a + :: forall a. RealFloatNaN a => a -- ^ constant coefficient -> a -- ^ linear coefficient -> a -- ^ quadratic coefficient -> [ a ] solveQuadratic c b a - | nearZero b && nearZero a - = if nearZero c - then [ 0, 0.5, 1 ] -- convention - else [] - | nearZero ( c * c * a / ( b * b ) ) - = [ -c / b ] - | disc < 0 - = [] -- non-real solutions + | isNaN a || isNaN b || isNaN c + || isInfinite a || isInfinite b || isInfinite c + = [] | otherwise - = let - r :: a - r = - if b >= 0 - then 2 * c / ( -b - sqrt disc ) - else 0.5 * ( -b + sqrt disc ) / a - in [ r, c / ( a * r ) ] + = case (a == 0, c == 0) of + -- First, handle all cases in which a or c is zero. + + -- bx = 0 + (True , True ) + | b == 0 + -> [ 0, 0.5, 1 ] -- convention + | otherwise + -> [ 0 ] + -- bx + c = 0 + (True , False) + | b == 0 + -> [ ] + | otherwise + -> [ -c / b ] + -- ax² + bx = 0 + (False, True ) + | b == 0 + -> [ 0 ] + | signum a == signum b + -> [ -b / a, 0 ] + | otherwise + -> [ 0, -b / a ] + -- General case: ax² + bx + c = 0 + (False, False) + | discr < 0 + -> [] + | otherwise + -> let rootDiscr = sqrt discr + q = -0.5 * ( b + copySign rootDiscr b ) + x1 = q / a + x2 = c / q + in if x1 > x2 + then [ x2, x1 ] + else [ x1, x2 ] + where discr = discriminant a b c +{-# INLINEABLE solveQuadratic #-} +{-# SPECIALISE solveQuadratic :: Float -> Float -> Float -> [ Float ] #-} +{-# SPECIALISE solveQuadratic :: Double -> Double -> Double -> [ Double ] #-} +-- TODO: implement the version from the paper +-- "The Ins and Outs of Solving Quadratic Equations with Floating-Point Arithmetic" +-- which is even more robust. + +-- | Kahan's method for computing the discriminant \( b^2 - 4ac \), +-- using a fused multiply-add operation to avoid cancellation in the naive +-- formula (if \( b^2 \) and \( 4ac \) are close). +-- +-- From "The Ins and Outs of Solving Quadratic Equations with Floating-Point Arithmetic", +-- (Frédéric Goualard, 2023). +discriminant :: RealFloat a => a -> a -> a -> a +discriminant a b c + -- b² and 4ac are different enough that b² - 4ac gives a good answer + | 3 * abs d_naive >= b² - m4ac + = d_naive + | otherwise + = let dp = fma b b -b² + dq = fma ( 4 * a ) c m4ac + in d_naive + ( dp - dq ) where - disc :: a - disc = b * b - 4 * a * c + b² = b * b + m4ac = -4 * a * c + d_naive = b² + m4ac + fma = fusedMultiplyAdd +{-# INLINEABLE discriminant #-} +{-# SPECIALISE discriminant :: Float -> Float -> Float -> Float #-} +{-# SPECIALISE discriminant :: Double -> Double -> Double -> Double #-} -------------------------------------------------------------------------------- -- Root finding using Laguerre's method diff --git a/cabal.project b/cabal.project index d564475..38256ca 100644 --- a/cabal.project +++ b/cabal.project @@ -3,6 +3,7 @@ packages: ., brush-strokes constraints: acts -finitary, -- brush-strokes +use-fma, + fp-ieee +fma3, rounded-hw -pure-hs -c99 -avx512 +ghc-prim -x87-long-double, text -simdutf -- text +simdutf causes the "digit" package to fail to build with undefined symbol linker errors