From b3941a2834bd56dee549fb546ec53dd5d7f5b575 Mon Sep 17 00:00:00 2001 From: sheaf Date: Wed, 26 Aug 2020 00:22:07 +0200 Subject: [PATCH] use hmatrix for least squares solving MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * this fixes the convergence problems of cubic Bézier curve fitting --- MetaBrush.cabal | 6 +- cabal.project | 7 ++ src/lib/Math/Bezier/Cubic/Fit.hs | 105 ++++++++++++++++-------------- src/lib/Math/Linear/SVD.hs | 107 ------------------------------- src/lib/Math/Linear/Solve.hs | 21 ++++++ src/lib/Math/Vector2D.hs | 28 ++++---- 6 files changed, 106 insertions(+), 168 deletions(-) delete mode 100644 src/lib/Math/Linear/SVD.hs create mode 100644 src/lib/Math/Linear/Solve.hs diff --git a/MetaBrush.cabal b/MetaBrush.cabal index b5276b3..d330c3a 100644 --- a/MetaBrush.cabal +++ b/MetaBrush.cabal @@ -68,7 +68,7 @@ library , Math.Bezier.Quadratic , Math.Bezier.Stroke , Math.Epsilon - , Math.Linear.SVD + , Math.Linear.Solve , Math.Module , Math.Roots , Math.Vector2D @@ -76,8 +76,12 @@ library build-depends: groups-generic ^>= 0.1.0.0 + , hmatrix + ^>= 0.20.0.0 , vector ^>= 0.12.1.2 + , QuickCheck + ^>= 2.14.1 executable MetaBrush diff --git a/cabal.project b/cabal.project index 6132e2d..1600002 100644 --- a/cabal.project +++ b/cabal.project @@ -9,3 +9,10 @@ source-repository-package location: https://github.com/thestr4ng3r/gi-cairo-render tag: 8727c43cdf91aeedffc9cb4c5575f56660a86399 subdir: gi-cairo-render + +-- latest version of hmatrix +source-repository-package + type: git + location: https://github.com/haskell-numerics/hmatrix + tag: 08138810946c7eae2254feeb33269cd962d5e0c8 + subdir: packages/base diff --git a/src/lib/Math/Bezier/Cubic/Fit.hs b/src/lib/Math/Bezier/Cubic/Fit.hs index 2086368..62a10a3 100644 --- a/src/lib/Math/Bezier/Cubic/Fit.hs +++ b/src/lib/Math/Bezier/Cubic/Fit.hs @@ -16,6 +16,10 @@ import Data.Complex ( Complex(..) ) import Data.Foldable ( for_ ) +import Data.Functor + ( ($>) ) +import Data.Semigroup + ( Arg(..), Max(..), ArgMax ) -- acts import Data.Act @@ -42,8 +46,8 @@ import Math.Bezier.Cubic ( Bezier(..), bezier, ddist ) import Math.Epsilon ( epsilon ) -import Math.Linear.SVD - ( lsolve ) +import Math.Linear.Solve + ( linearSolve ) import Math.Module ( Module((*^), (^-^)) , Inner((^.^)), quadrance @@ -63,34 +67,42 @@ import Math.Vector2D -- * ends at \( r \) with tangent \( \textrm{t}_r \), -- * best fits the intermediate sequence of points \( \left ( q_i \right )_{i=1}^n \). -- +-- This function also returns \( \textrm{ArgMax}\ t_\textrm{max}\ d^2_\textrm{max}: \) +-- the parameter and squared distance of the worst-fitting point. +-- It is guaranteed that all points to fit lie within the tubular neighbourhood +-- of radius \( d_\textrm{max} \) of the fitted curve. +-- -- /Note/: the order of the intermediate points is important. -- -- Proceeds by fitting a cubic Bézier curve \( B(t) \), \( 0 \leqslant t \leqslant 1 \), -- with given endpoints and tangents, which minimises the sum of squares functional -- --- \[ \sum_{i=1}^n \left \| C(t_i) - q_i \right \|^2. \] +-- \[ \sum_{i=1}^n \Big \| B(t_i) - q_i \Big \|^2. \] -- -- The values of the parameters \( \left ( t_i \right )_{i=1}^n \) are recursively estimated, --- starting from uniform parametrisation. +-- starting from uniform parametrisation (this will be the fit if `maxIters` is 0). -- -- The iteration ends when any of the following conditions are satisfied: -- --- * each new estimated parameter values \( t_i' \) differs from +-- * each new estimated parameter value \( t_i' \) differs from -- its previous value \( t_i \) by less than \( \texttt{t_tol} \), --- * each point \( C(t_i) \) is within squared distance \( \texttt{sq_dist_tol} \) --- of the point \( q_i \) it is associated with, --- * the maximum iteration limit \( \texttt{maxCount} \) has been reached. +-- * each on-curve point \( B(t_i) \) is within distance \( \texttt{dist_tol} \) +-- of its corresponding point to fit \( q_i \), +-- * the maximum iteration limit \( \texttt{maxIters} \) has been reached. fitPiece :: Double -- ^ \( \texttt{t_tol} \), the tolerance for the Bézier parameter - -> Double -- ^ \( \texttt{sq_dist_tol} \), tolerance for the squared distance - -> Int -- ^ \( \texttt{maxCount} \), maximum number of iterations + -> Double -- ^ \( \texttt{dist_tol} \), tolerance for the distance + -> Int -- ^ \( \texttt{maxIters} \), maximum number of iterations -> Point2D Double -- ^ \( p \), start point -> Vector2D Double -- ^ \( \textrm{t}_p \), start tangent vector (length is ignored) -> [ Point2D Double ] -- ^ \( \left ( q_i \right )_{i=1}^n \), points to fit -> Point2D Double -- ^ \( r \), end point -> Vector2D Double -- ^ \( \textrm{t}_r \), end tangent vector (length is ignored) - -> Bezier ( Point2D Double ) -fitPiece t_tol sq_dist_tol maxCount p tp qs r tr = piece + -> ( Bezier ( Point2D Double ), ArgMax Double Double ) +fitPiece t_tol dist_tol maxIters p tp qs r tr = + runST do + ts <- Unboxed.Vector.unsafeThaw ( Unboxed.Vector.generate n uniform ) + loop ts 0 where n :: Int n = length qs @@ -103,12 +115,7 @@ fitPiece t_tol sq_dist_tol maxCount p tp qs r tr = piece f2 t = h0 t *^ ( MkVector2D p ) f3 t = h3 t *^ ( MkVector2D r ) - piece :: Bezier ( Point2D Double ) - piece = runST do - ts <- Unboxed.Vector.unsafeThaw ( Unboxed.Vector.generate n uniform ) - loop ts 0 - - loop :: forall s. Unboxed.MVector s Double -> Int -> ST s ( Bezier ( Point2D Double ) ) + loop :: forall s. Unboxed.MVector s Double -> Int -> ST s ( Bezier ( Point2D Double ), ArgMax Double Double ) loop ts count = do let hermiteParameters :: Mat22 Double -> Vector2D Double -> Int -> [ Point2D Double ] -> ST s ( Vector2D Double ) @@ -129,7 +136,7 @@ fitPiece t_tol sq_dist_tol maxCount p tp qs r tr = piece b1' = b1 + ( q' ^.^ f0i ) b2' = b2 + ( q' ^.^ f1i ) hermiteParameters ( Mat22 a11' a12' a21' a22' ) ( Vector2D b1' b2' ) ( i + 1 ) rest - hermiteParameters a b _ [] = pure ( lsolve a b ) + hermiteParameters a b _ [] = pure ( linearSolve a b ) Vector2D s1 s2 <- hermiteParameters ( Mat22 0 0 0 0 ) ( Vector2D 0 0 ) 0 qs @@ -141,37 +148,37 @@ fitPiece t_tol sq_dist_tol maxCount p tp qs r tr = piece bez :: Bezier ( Point2D Double ) bez = Bezier p cp1 cp2 r - if count >= maxCount - then pure bez - else do + -- Run one iteration of Laguerre's method to improve the parameter values t_i, + -- so that t_i' is a better approximation of the parameter + -- at which the curve is closest to q_i. + ( dts_changed, argmax_sq_dist ) <- ( `execStateT` ( False, Max ( Arg 0 0 ) ) ) $ for_ ( zip qs [ 0 .. ] ) \( q, i ) -> do + ti <- lift ( Unboxed.MVector.unsafeRead ts i ) + let + poly :: [ Complex Double ] + poly = map (:+ 0) $ ddist @( Vector2D Double ) bez q + ti' <- case laguerre epsilon 1 poly ( ti :+ 0 ) of + x :+ y + | abs y > epsilon + || isNaN x + || isNaN y + -> modify' ( first ( const True ) ) $> ti + | otherwise + -> pure x + let + dt, sq_dist :: Double + dt = abs ( ti' - ti ) + sq_dist = quadrance @( Vector2D Double ) q ( bezier @( Vector2D Double ) bez ti' ) + when ( dt > t_tol ) + ( modify' ( first ( const True ) ) ) + modify' ( second ( <> Max ( Arg ti' sq_dist ) ) ) + lift ( Unboxed.MVector.unsafeWrite ts i ti' ) - -- Run one iteration of Laguerre's method to improve the parameter values t_i, - -- so that t_i' is a better approximation of the parameter - -- at which the curve is closest to q_i. - ( ts_ok, pts_ok ) <- ( `execStateT` ( True, True ) ) $ for_ ( zip qs [ 0 .. ] ) \( q, i ) -> do - ti <- lift ( Unboxed.MVector.unsafeRead ts i ) - let - poly :: [ Complex Double ] - poly = map (:+ 0) $ ddist @( Vector2D Double ) bez q - ti' :: Double - ti' = case laguerre epsilon 1 poly ( ti :+ 0 ) of - x :+ y - | abs y > epsilon - || isNaN x - || isNaN y - -> ti - | otherwise - -> x - - when ( abs ( ti' - ti ) > t_tol ) - ( modify' ( first ( const False ) ) ) - when ( quadrance @( Vector2D Double ) q ( bezier @( Vector2D Double ) bez ti' ) > sq_dist_tol ) - ( modify' ( second ( const False ) ) ) - lift ( Unboxed.MVector.unsafeWrite ts i ti' ) - - if ts_ok || pts_ok - then pure bez - else loop ts ( count + 1 ) + case argmax_sq_dist of + Max ( Arg _ max_sq_dist ) + | count < maxIters + && ( dts_changed || max_sq_dist > dist_tol ^ ( 2 :: Int ) ) + -> loop ts ( count + 1 ) + _ -> pure ( bez, argmax_sq_dist ) -- | Cubic Hermite polynomial. h0, h1, h2, h3 :: Num t => t -> t diff --git a/src/lib/Math/Linear/SVD.hs b/src/lib/Math/Linear/SVD.hs deleted file mode 100644 index 0ad48c0..0000000 --- a/src/lib/Math/Linear/SVD.hs +++ /dev/null @@ -1,107 +0,0 @@ -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE ScopedTypeVariables #-} - -module Math.Linear.SVD - ( SVD(..), svd - , pinv - , lsolve - ) - where - --- base -import Data.Complex - ( Complex(..), magnitude ) - --- MetaBrush -import Math.Epsilon - ( epsilon ) -import Math.Vector2D - ( Vector2D(..), Mat22(..) ) - --------------------------------------------------------------------------------- - -data SVD a - = SVD - { u :: !( Complex a ) - , sv1 :: !a - , sv2 :: !a - , v :: !( Complex a ) - } - deriving stock Show - --- | Singular value decomposition of a real 2x2 matrix. -svd :: forall a. ( RealFloat a, Show a ) => Mat22 a -> SVD a -svd ( Mat22 a b c d ) - | abs f < tol - || any isNaN [ c1, c2, s1, s2 ] - || magnitude u < 1 - tol - || magnitude v < 1 - tol - = SVD { u = 1, v = 1, .. } - | otherwise - = SVD { .. } - where - tol = sqrt epsilon - - det = a * d - b * c - q = a * a + b * b + c * c + d * d - - n1 = q + 2 * det - n2 = q - 2 * det - - r1 = sqrt n1 - r2 = sqrt n2 - - sv1 = 0.5 * ( r1 + r2 ) - sv2 = 0.5 * ( r1 - r2 ) - - k = a * a - d * d - l = b * b - c * c - f = n1 * n2 - - - i = 0.5 / sqrt f - - ip = ( k + l ) * i - im = ( k - l ) * i - - c1 = sqrt ( 0.5 + ip ) - c2 = sqrt ( 0.5 + im ) - s1 = signum ( a * c + b * d ) * sqrt ( 0.5 - ip ) - s2 = signum ( a * b + c * d ) * sqrt ( 0.5 - im ) - - u = c1 :+ s1 - v = c2 :+ s2 - --- | Pseudo-inverse of a real 2x2 matrix. -pinv :: forall a. ( RealFloat a, Show a ) => Mat22 a -> Mat22 a -pinv mat = case svd mat of - SVD { u = c1 :+ s1, sv1, sv2, v = c2 :+ s2 } -> - Mat22 - ( c1 * c2 * rsv1 + s1 * s2 * rsv2 ) ( s1 * c2 * rsv1 - c1 * s2 * rsv2 ) - ( c1 * s2 * rsv1 - s1 * c2 * rsv2 ) ( s1 * s2 * rsv1 + c1 * c2 * rsv2 ) - where - rsv1, rsv2 :: a - rsv1 - | sv1 < epsilon - = sv1 - | otherwise - = recip sv1 - rsv2 - | abs sv2 < epsilon - = sv2 - | otherwise - = recip sv2 - --- | Solve a 2x2 system of linear equations. -lsolve :: forall a. ( RealFloat a, Show a ) => Mat22 a -> Vector2D a -> Vector2D a -lsolve mat ( Vector2D x y ) = Vector2D x' y' - where - x', y', a11, a12, a21, a22 :: a - Mat22 - a11 a12 - a21 a22 - = pinv mat - x' = a11 * x + a12 * y - y' = a21 * x + a22 * y diff --git a/src/lib/Math/Linear/Solve.hs b/src/lib/Math/Linear/Solve.hs new file mode 100644 index 0000000..0eee033 --- /dev/null +++ b/src/lib/Math/Linear/Solve.hs @@ -0,0 +1,21 @@ +module Math.Linear.Solve + ( linearSolve ) + where + +-- hmatrix +import qualified Numeric.LinearAlgebra as LAPACK + ( linearSolveLS ) +import qualified Numeric.LinearAlgebra.Data as HMatrix + ( Matrix, col, matrix, atIndex ) + +-- MetaBrush +import Math.Vector2D + ( Vector2D(..), Mat22(..) ) + +-------------------------------------------------------------------------------- + +linearSolve :: Mat22 Double -> Vector2D Double -> Vector2D Double +linearSolve ( Mat22 a b c d ) ( Vector2D p q ) = Vector2D ( sol `HMatrix.atIndex` (0,0) ) ( sol `HMatrix.atIndex` (1,0) ) + where + sol :: HMatrix.Matrix Double + sol = LAPACK.linearSolveLS ( HMatrix.matrix 2 [a,b,c,d] ) ( HMatrix.col [p,q] ) diff --git a/src/lib/Math/Vector2D.hs b/src/lib/Math/Vector2D.hs index 72525c9..55ba685 100644 --- a/src/lib/Math/Vector2D.hs +++ b/src/lib/Math/Vector2D.hs @@ -1,9 +1,10 @@ -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} module Math.Vector2D ( Point2D(..), Vector2D(.., Vector2D), Mat22(..) @@ -15,7 +16,7 @@ module Math.Vector2D import Data.Monoid ( Sum(..) ) import GHC.Generics - ( Generic ) + ( Generic, Generic1 ) -- acts import Data.Act @@ -23,7 +24,7 @@ import Data.Act -- generic-data import Generic.Data - ( GenericProduct(..) ) + ( Generically1(..), GenericProduct(..) ) -- groups import Data.Group @@ -40,12 +41,15 @@ import Math.Module -------------------------------------------------------------------------------- data Point2D a = Point2D !a !a - deriving stock ( Show, Eq, Generic, Functor, Foldable, Traversable ) + deriving stock ( Show, Eq, Generic, Generic1, Functor, Foldable, Traversable ) deriving ( Act ( Vector2D a ), Torsor ( Vector2D a ) ) via Vector2D a + deriving Applicative + via Generically1 Point2D newtype Vector2D a = MkVector2D { tip :: Point2D a } - deriving stock ( Show, Eq, Functor, Foldable, Traversable ) + deriving stock ( Show, Generic, Generic1, Foldable, Traversable ) + deriving newtype ( Eq, Functor, Applicative ) deriving ( Semigroup, Monoid, Group ) via GenericProduct ( Point2D ( Sum a ) ) @@ -70,4 +74,6 @@ cross ( Vector2D x1 y1 ) ( Vector2D x2 y2 ) data Mat22 a = Mat22 !a !a !a !a - deriving stock ( Show, Eq, Generic, Functor, Foldable, Traversable ) + deriving stock ( Show, Eq, Generic, Generic1, Functor, Foldable, Traversable ) + deriving Applicative + via Generically1 Mat22