mirror of
https://gitlab.com/sheaf/metabrush.git
synced 2024-11-23 15:34:06 +00:00
use hmatrix for least squares solving
* this fixes the convergence problems of cubic Bézier curve fitting
This commit is contained in:
parent
f16ac3fa93
commit
b3941a2834
|
@ -68,7 +68,7 @@ library
|
|||
, Math.Bezier.Quadratic
|
||||
, Math.Bezier.Stroke
|
||||
, Math.Epsilon
|
||||
, Math.Linear.SVD
|
||||
, Math.Linear.Solve
|
||||
, Math.Module
|
||||
, Math.Roots
|
||||
, Math.Vector2D
|
||||
|
@ -76,8 +76,12 @@ library
|
|||
build-depends:
|
||||
groups-generic
|
||||
^>= 0.1.0.0
|
||||
, hmatrix
|
||||
^>= 0.20.0.0
|
||||
, vector
|
||||
^>= 0.12.1.2
|
||||
, QuickCheck
|
||||
^>= 2.14.1
|
||||
|
||||
executable MetaBrush
|
||||
|
||||
|
|
|
@ -9,3 +9,10 @@ source-repository-package
|
|||
location: https://github.com/thestr4ng3r/gi-cairo-render
|
||||
tag: 8727c43cdf91aeedffc9cb4c5575f56660a86399
|
||||
subdir: gi-cairo-render
|
||||
|
||||
-- latest version of hmatrix
|
||||
source-repository-package
|
||||
type: git
|
||||
location: https://github.com/haskell-numerics/hmatrix
|
||||
tag: 08138810946c7eae2254feeb33269cd962d5e0c8
|
||||
subdir: packages/base
|
||||
|
|
|
@ -16,6 +16,10 @@ import Data.Complex
|
|||
( Complex(..) )
|
||||
import Data.Foldable
|
||||
( for_ )
|
||||
import Data.Functor
|
||||
( ($>) )
|
||||
import Data.Semigroup
|
||||
( Arg(..), Max(..), ArgMax )
|
||||
|
||||
-- acts
|
||||
import Data.Act
|
||||
|
@ -42,8 +46,8 @@ import Math.Bezier.Cubic
|
|||
( Bezier(..), bezier, ddist )
|
||||
import Math.Epsilon
|
||||
( epsilon )
|
||||
import Math.Linear.SVD
|
||||
( lsolve )
|
||||
import Math.Linear.Solve
|
||||
( linearSolve )
|
||||
import Math.Module
|
||||
( Module((*^), (^-^))
|
||||
, Inner((^.^)), quadrance
|
||||
|
@ -63,34 +67,42 @@ import Math.Vector2D
|
|||
-- * ends at \( r \) with tangent \( \textrm{t}_r \),
|
||||
-- * best fits the intermediate sequence of points \( \left ( q_i \right )_{i=1}^n \).
|
||||
--
|
||||
-- This function also returns \( \textrm{ArgMax}\ t_\textrm{max}\ d^2_\textrm{max}: \)
|
||||
-- the parameter and squared distance of the worst-fitting point.
|
||||
-- It is guaranteed that all points to fit lie within the tubular neighbourhood
|
||||
-- of radius \( d_\textrm{max} \) of the fitted curve.
|
||||
--
|
||||
-- /Note/: the order of the intermediate points is important.
|
||||
--
|
||||
-- Proceeds by fitting a cubic Bézier curve \( B(t) \), \( 0 \leqslant t \leqslant 1 \),
|
||||
-- with given endpoints and tangents, which minimises the sum of squares functional
|
||||
--
|
||||
-- \[ \sum_{i=1}^n \left \| C(t_i) - q_i \right \|^2. \]
|
||||
-- \[ \sum_{i=1}^n \Big \| B(t_i) - q_i \Big \|^2. \]
|
||||
--
|
||||
-- The values of the parameters \( \left ( t_i \right )_{i=1}^n \) are recursively estimated,
|
||||
-- starting from uniform parametrisation.
|
||||
-- starting from uniform parametrisation (this will be the fit if `maxIters` is 0).
|
||||
--
|
||||
-- The iteration ends when any of the following conditions are satisfied:
|
||||
--
|
||||
-- * each new estimated parameter values \( t_i' \) differs from
|
||||
-- * each new estimated parameter value \( t_i' \) differs from
|
||||
-- its previous value \( t_i \) by less than \( \texttt{t_tol} \),
|
||||
-- * each point \( C(t_i) \) is within squared distance \( \texttt{sq_dist_tol} \)
|
||||
-- of the point \( q_i \) it is associated with,
|
||||
-- * the maximum iteration limit \( \texttt{maxCount} \) has been reached.
|
||||
-- * each on-curve point \( B(t_i) \) is within distance \( \texttt{dist_tol} \)
|
||||
-- of its corresponding point to fit \( q_i \),
|
||||
-- * the maximum iteration limit \( \texttt{maxIters} \) has been reached.
|
||||
fitPiece
|
||||
:: Double -- ^ \( \texttt{t_tol} \), the tolerance for the Bézier parameter
|
||||
-> Double -- ^ \( \texttt{sq_dist_tol} \), tolerance for the squared distance
|
||||
-> Int -- ^ \( \texttt{maxCount} \), maximum number of iterations
|
||||
-> Double -- ^ \( \texttt{dist_tol} \), tolerance for the distance
|
||||
-> Int -- ^ \( \texttt{maxIters} \), maximum number of iterations
|
||||
-> Point2D Double -- ^ \( p \), start point
|
||||
-> Vector2D Double -- ^ \( \textrm{t}_p \), start tangent vector (length is ignored)
|
||||
-> [ Point2D Double ] -- ^ \( \left ( q_i \right )_{i=1}^n \), points to fit
|
||||
-> Point2D Double -- ^ \( r \), end point
|
||||
-> Vector2D Double -- ^ \( \textrm{t}_r \), end tangent vector (length is ignored)
|
||||
-> Bezier ( Point2D Double )
|
||||
fitPiece t_tol sq_dist_tol maxCount p tp qs r tr = piece
|
||||
-> ( Bezier ( Point2D Double ), ArgMax Double Double )
|
||||
fitPiece t_tol dist_tol maxIters p tp qs r tr =
|
||||
runST do
|
||||
ts <- Unboxed.Vector.unsafeThaw ( Unboxed.Vector.generate n uniform )
|
||||
loop ts 0
|
||||
where
|
||||
n :: Int
|
||||
n = length qs
|
||||
|
@ -103,12 +115,7 @@ fitPiece t_tol sq_dist_tol maxCount p tp qs r tr = piece
|
|||
f2 t = h0 t *^ ( MkVector2D p )
|
||||
f3 t = h3 t *^ ( MkVector2D r )
|
||||
|
||||
piece :: Bezier ( Point2D Double )
|
||||
piece = runST do
|
||||
ts <- Unboxed.Vector.unsafeThaw ( Unboxed.Vector.generate n uniform )
|
||||
loop ts 0
|
||||
|
||||
loop :: forall s. Unboxed.MVector s Double -> Int -> ST s ( Bezier ( Point2D Double ) )
|
||||
loop :: forall s. Unboxed.MVector s Double -> Int -> ST s ( Bezier ( Point2D Double ), ArgMax Double Double )
|
||||
loop ts count = do
|
||||
let
|
||||
hermiteParameters :: Mat22 Double -> Vector2D Double -> Int -> [ Point2D Double ] -> ST s ( Vector2D Double )
|
||||
|
@ -129,7 +136,7 @@ fitPiece t_tol sq_dist_tol maxCount p tp qs r tr = piece
|
|||
b1' = b1 + ( q' ^.^ f0i )
|
||||
b2' = b2 + ( q' ^.^ f1i )
|
||||
hermiteParameters ( Mat22 a11' a12' a21' a22' ) ( Vector2D b1' b2' ) ( i + 1 ) rest
|
||||
hermiteParameters a b _ [] = pure ( lsolve a b )
|
||||
hermiteParameters a b _ [] = pure ( linearSolve a b )
|
||||
|
||||
Vector2D s1 s2 <- hermiteParameters ( Mat22 0 0 0 0 ) ( Vector2D 0 0 ) 0 qs
|
||||
|
||||
|
@ -141,37 +148,37 @@ fitPiece t_tol sq_dist_tol maxCount p tp qs r tr = piece
|
|||
bez :: Bezier ( Point2D Double )
|
||||
bez = Bezier p cp1 cp2 r
|
||||
|
||||
if count >= maxCount
|
||||
then pure bez
|
||||
else do
|
||||
-- Run one iteration of Laguerre's method to improve the parameter values t_i,
|
||||
-- so that t_i' is a better approximation of the parameter
|
||||
-- at which the curve is closest to q_i.
|
||||
( dts_changed, argmax_sq_dist ) <- ( `execStateT` ( False, Max ( Arg 0 0 ) ) ) $ for_ ( zip qs [ 0 .. ] ) \( q, i ) -> do
|
||||
ti <- lift ( Unboxed.MVector.unsafeRead ts i )
|
||||
let
|
||||
poly :: [ Complex Double ]
|
||||
poly = map (:+ 0) $ ddist @( Vector2D Double ) bez q
|
||||
ti' <- case laguerre epsilon 1 poly ( ti :+ 0 ) of
|
||||
x :+ y
|
||||
| abs y > epsilon
|
||||
|| isNaN x
|
||||
|| isNaN y
|
||||
-> modify' ( first ( const True ) ) $> ti
|
||||
| otherwise
|
||||
-> pure x
|
||||
let
|
||||
dt, sq_dist :: Double
|
||||
dt = abs ( ti' - ti )
|
||||
sq_dist = quadrance @( Vector2D Double ) q ( bezier @( Vector2D Double ) bez ti' )
|
||||
when ( dt > t_tol )
|
||||
( modify' ( first ( const True ) ) )
|
||||
modify' ( second ( <> Max ( Arg ti' sq_dist ) ) )
|
||||
lift ( Unboxed.MVector.unsafeWrite ts i ti' )
|
||||
|
||||
-- Run one iteration of Laguerre's method to improve the parameter values t_i,
|
||||
-- so that t_i' is a better approximation of the parameter
|
||||
-- at which the curve is closest to q_i.
|
||||
( ts_ok, pts_ok ) <- ( `execStateT` ( True, True ) ) $ for_ ( zip qs [ 0 .. ] ) \( q, i ) -> do
|
||||
ti <- lift ( Unboxed.MVector.unsafeRead ts i )
|
||||
let
|
||||
poly :: [ Complex Double ]
|
||||
poly = map (:+ 0) $ ddist @( Vector2D Double ) bez q
|
||||
ti' :: Double
|
||||
ti' = case laguerre epsilon 1 poly ( ti :+ 0 ) of
|
||||
x :+ y
|
||||
| abs y > epsilon
|
||||
|| isNaN x
|
||||
|| isNaN y
|
||||
-> ti
|
||||
| otherwise
|
||||
-> x
|
||||
|
||||
when ( abs ( ti' - ti ) > t_tol )
|
||||
( modify' ( first ( const False ) ) )
|
||||
when ( quadrance @( Vector2D Double ) q ( bezier @( Vector2D Double ) bez ti' ) > sq_dist_tol )
|
||||
( modify' ( second ( const False ) ) )
|
||||
lift ( Unboxed.MVector.unsafeWrite ts i ti' )
|
||||
|
||||
if ts_ok || pts_ok
|
||||
then pure bez
|
||||
else loop ts ( count + 1 )
|
||||
case argmax_sq_dist of
|
||||
Max ( Arg _ max_sq_dist )
|
||||
| count < maxIters
|
||||
&& ( dts_changed || max_sq_dist > dist_tol ^ ( 2 :: Int ) )
|
||||
-> loop ts ( count + 1 )
|
||||
_ -> pure ( bez, argmax_sq_dist )
|
||||
|
||||
-- | Cubic Hermite polynomial.
|
||||
h0, h1, h2, h3 :: Num t => t -> t
|
||||
|
|
|
@ -1,107 +0,0 @@
|
|||
{-# LANGUAGE DerivingStrategies #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module Math.Linear.SVD
|
||||
( SVD(..), svd
|
||||
, pinv
|
||||
, lsolve
|
||||
)
|
||||
where
|
||||
|
||||
-- base
|
||||
import Data.Complex
|
||||
( Complex(..), magnitude )
|
||||
|
||||
-- MetaBrush
|
||||
import Math.Epsilon
|
||||
( epsilon )
|
||||
import Math.Vector2D
|
||||
( Vector2D(..), Mat22(..) )
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
data SVD a
|
||||
= SVD
|
||||
{ u :: !( Complex a )
|
||||
, sv1 :: !a
|
||||
, sv2 :: !a
|
||||
, v :: !( Complex a )
|
||||
}
|
||||
deriving stock Show
|
||||
|
||||
-- | Singular value decomposition of a real 2x2 matrix.
|
||||
svd :: forall a. ( RealFloat a, Show a ) => Mat22 a -> SVD a
|
||||
svd ( Mat22 a b c d )
|
||||
| abs f < tol
|
||||
|| any isNaN [ c1, c2, s1, s2 ]
|
||||
|| magnitude u < 1 - tol
|
||||
|| magnitude v < 1 - tol
|
||||
= SVD { u = 1, v = 1, .. }
|
||||
| otherwise
|
||||
= SVD { .. }
|
||||
where
|
||||
tol = sqrt epsilon
|
||||
|
||||
det = a * d - b * c
|
||||
q = a * a + b * b + c * c + d * d
|
||||
|
||||
n1 = q + 2 * det
|
||||
n2 = q - 2 * det
|
||||
|
||||
r1 = sqrt n1
|
||||
r2 = sqrt n2
|
||||
|
||||
sv1 = 0.5 * ( r1 + r2 )
|
||||
sv2 = 0.5 * ( r1 - r2 )
|
||||
|
||||
k = a * a - d * d
|
||||
l = b * b - c * c
|
||||
f = n1 * n2
|
||||
|
||||
|
||||
i = 0.5 / sqrt f
|
||||
|
||||
ip = ( k + l ) * i
|
||||
im = ( k - l ) * i
|
||||
|
||||
c1 = sqrt ( 0.5 + ip )
|
||||
c2 = sqrt ( 0.5 + im )
|
||||
s1 = signum ( a * c + b * d ) * sqrt ( 0.5 - ip )
|
||||
s2 = signum ( a * b + c * d ) * sqrt ( 0.5 - im )
|
||||
|
||||
u = c1 :+ s1
|
||||
v = c2 :+ s2
|
||||
|
||||
-- | Pseudo-inverse of a real 2x2 matrix.
|
||||
pinv :: forall a. ( RealFloat a, Show a ) => Mat22 a -> Mat22 a
|
||||
pinv mat = case svd mat of
|
||||
SVD { u = c1 :+ s1, sv1, sv2, v = c2 :+ s2 } ->
|
||||
Mat22
|
||||
( c1 * c2 * rsv1 + s1 * s2 * rsv2 ) ( s1 * c2 * rsv1 - c1 * s2 * rsv2 )
|
||||
( c1 * s2 * rsv1 - s1 * c2 * rsv2 ) ( s1 * s2 * rsv1 + c1 * c2 * rsv2 )
|
||||
where
|
||||
rsv1, rsv2 :: a
|
||||
rsv1
|
||||
| sv1 < epsilon
|
||||
= sv1
|
||||
| otherwise
|
||||
= recip sv1
|
||||
rsv2
|
||||
| abs sv2 < epsilon
|
||||
= sv2
|
||||
| otherwise
|
||||
= recip sv2
|
||||
|
||||
-- | Solve a 2x2 system of linear equations.
|
||||
lsolve :: forall a. ( RealFloat a, Show a ) => Mat22 a -> Vector2D a -> Vector2D a
|
||||
lsolve mat ( Vector2D x y ) = Vector2D x' y'
|
||||
where
|
||||
x', y', a11, a12, a21, a22 :: a
|
||||
Mat22
|
||||
a11 a12
|
||||
a21 a22
|
||||
= pinv mat
|
||||
x' = a11 * x + a12 * y
|
||||
y' = a21 * x + a22 * y
|
21
src/lib/Math/Linear/Solve.hs
Normal file
21
src/lib/Math/Linear/Solve.hs
Normal file
|
@ -0,0 +1,21 @@
|
|||
module Math.Linear.Solve
|
||||
( linearSolve )
|
||||
where
|
||||
|
||||
-- hmatrix
|
||||
import qualified Numeric.LinearAlgebra as LAPACK
|
||||
( linearSolveLS )
|
||||
import qualified Numeric.LinearAlgebra.Data as HMatrix
|
||||
( Matrix, col, matrix, atIndex )
|
||||
|
||||
-- MetaBrush
|
||||
import Math.Vector2D
|
||||
( Vector2D(..), Mat22(..) )
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
linearSolve :: Mat22 Double -> Vector2D Double -> Vector2D Double
|
||||
linearSolve ( Mat22 a b c d ) ( Vector2D p q ) = Vector2D ( sol `HMatrix.atIndex` (0,0) ) ( sol `HMatrix.atIndex` (1,0) )
|
||||
where
|
||||
sol :: HMatrix.Matrix Double
|
||||
sol = LAPACK.linearSolveLS ( HMatrix.matrix 2 [a,b,c,d] ) ( HMatrix.col [p,q] )
|
|
@ -1,9 +1,10 @@
|
|||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DeriveTraversable #-}
|
||||
{-# LANGUAGE DerivingVia #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
{-# LANGUAGE DeriveTraversable #-}
|
||||
{-# LANGUAGE DerivingVia #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
|
||||
module Math.Vector2D
|
||||
( Point2D(..), Vector2D(.., Vector2D), Mat22(..)
|
||||
|
@ -15,7 +16,7 @@ module Math.Vector2D
|
|||
import Data.Monoid
|
||||
( Sum(..) )
|
||||
import GHC.Generics
|
||||
( Generic )
|
||||
( Generic, Generic1 )
|
||||
|
||||
-- acts
|
||||
import Data.Act
|
||||
|
@ -23,7 +24,7 @@ import Data.Act
|
|||
|
||||
-- generic-data
|
||||
import Generic.Data
|
||||
( GenericProduct(..) )
|
||||
( Generically1(..), GenericProduct(..) )
|
||||
|
||||
-- groups
|
||||
import Data.Group
|
||||
|
@ -40,12 +41,15 @@ import Math.Module
|
|||
--------------------------------------------------------------------------------
|
||||
|
||||
data Point2D a = Point2D !a !a
|
||||
deriving stock ( Show, Eq, Generic, Functor, Foldable, Traversable )
|
||||
deriving stock ( Show, Eq, Generic, Generic1, Functor, Foldable, Traversable )
|
||||
deriving ( Act ( Vector2D a ), Torsor ( Vector2D a ) )
|
||||
via Vector2D a
|
||||
deriving Applicative
|
||||
via Generically1 Point2D
|
||||
|
||||
newtype Vector2D a = MkVector2D { tip :: Point2D a }
|
||||
deriving stock ( Show, Eq, Functor, Foldable, Traversable )
|
||||
deriving stock ( Show, Generic, Generic1, Foldable, Traversable )
|
||||
deriving newtype ( Eq, Functor, Applicative )
|
||||
deriving ( Semigroup, Monoid, Group )
|
||||
via GenericProduct ( Point2D ( Sum a ) )
|
||||
|
||||
|
@ -70,4 +74,6 @@ cross ( Vector2D x1 y1 ) ( Vector2D x2 y2 )
|
|||
|
||||
data Mat22 a
|
||||
= Mat22 !a !a !a !a
|
||||
deriving stock ( Show, Eq, Generic, Functor, Foldable, Traversable )
|
||||
deriving stock ( Show, Eq, Generic, Generic1, Functor, Foldable, Traversable )
|
||||
deriving Applicative
|
||||
via Generically1 Mat22
|
||||
|
|
Loading…
Reference in a new issue