From bdcac18ab988b4385ba4549b26e1f49b77716478 Mon Sep 17 00:00:00 2001 From: sheaf Date: Sun, 22 Jan 2023 04:51:23 +0100 Subject: [PATCH] add chain rule R^n -> R -> R --- src/splines/Math/Algebra/Dual.hs | 227 ++++++++++++++++++---- src/splines/Math/Algebra/Dual/Internal.hs | 79 ++++++-- src/splines/Math/Interval.hs | 18 +- src/splines/Math/Linear.hs | 11 +- src/splines/Math/Monomial.hs | 85 +++++++- 5 files changed, 350 insertions(+), 70 deletions(-) diff --git a/src/splines/Math/Algebra/Dual.hs b/src/splines/Math/Algebra/Dual.hs index f6f341d..ac8ba3c 100644 --- a/src/splines/Math/Algebra/Dual.hs +++ b/src/splines/Math/Algebra/Dual.hs @@ -411,9 +411,8 @@ instance Ring r => Ring ( D3๐”ธ4 r ) where -------------------------------------------------------------------------------- -- Field & transcendental instances --- TODO!! - deriving newtype instance Field r => Field ( D๐”ธ0 r ) + --instance Field r => Field ( D1๐”ธ1 r ) where --instance Field r => Field ( D1๐”ธ2 r ) where --instance Field r => Field ( D1๐”ธ3 r ) where @@ -426,7 +425,37 @@ instance Field r => Field ( D3๐”ธ1 r ) where instance Field r => Field ( D3๐”ธ2 r ) where instance Field r => Field ( D3๐”ธ3 r ) where instance Field r => Field ( D3๐”ธ4 r ) where +-- TODO +d1sin, d1cos :: Transcendental r => r -> D1๐”ธ1 r +d1sin x = + let !s = sin x + !c = cos x + in D11 s ( T c ) +d1cos x = + let !s = sin x + !c = cos x + in D11 c ( T -s ) + +d2sin, d2cos :: Transcendental r => r -> D2๐”ธ1 r +d2sin x = + let !s = sin x + !c = cos x + in D21 s ( T c ) ( T -s ) +d2cos x = + let !s = sin x + !c = cos x + in D21 c ( T -s ) ( T -c ) + +d3sin, d3cos :: Transcendental r => r -> D3๐”ธ1 r +d3sin x = + let !s = sin x + !c = cos x + in D31 s ( T c ) ( T -s ) ( T -c ) +d3cos x = + let !s = sin x + !c = cos x + in D31 c ( T -s ) ( T -c ) ( T s ) deriving newtype instance Transcendental r => Transcendental ( D๐”ธ0 r ) --instance Transcendental r => Transcendental ( D1๐”ธ1 r ) where @@ -434,38 +463,170 @@ deriving newtype instance Transcendental r => Transcendental ( D๐”ธ0 r ) --instance Transcendental r => Transcendental ( D1๐”ธ3 r ) where --instance Transcendental r => Transcendental ( D1๐”ธ4 r ) where instance Transcendental r => Transcendental ( D2๐”ธ1 r ) where + pi = konst @Double @2 @( โ„ 1 ) pi + sin df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d2sin @r ( _D21_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + cos df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d2cos @r ( _D21_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) instance Transcendental r => Transcendental ( D2๐”ธ2 r ) where + pi = konst @Double @2 @( โ„ 2 ) pi + sin df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d2sin @r ( _D22_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + cos df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d2cos @r ( _D22_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) instance Transcendental r => Transcendental ( D2๐”ธ3 r ) where pi = konst @Double @2 @( โ„ 3 ) pi - sin ( D23 v ( T dx ) ( T dy ) ( T dz ) ( T ddx ) ( T dxdy ) ( T ddy ) ( T dxdz ) ( T dydz ) ( T ddz ) ) - = let !s = sin v - !c = cos v - in D23 s - ( T $ c * dx ) ( T $ c * dy ) ( T $ c * dz ) - ( T $ 2 * c * ddx - s * ( dx ^ 2 ) ) - ( T $ 2 * c * dxdy - 2 * s * dx * dy ) - ( T $ 2 * c * ddy - s * ( dy ^ 2 ) ) - ( T $ 2 * c * dxdz - 2 * s * dx * dz ) - ( T $ 2 * c * dydz - 2 * s * dy * dz ) - ( T $ 2 * c * ddz - s * ( dz ^ 2 ) ) - - cos ( D23 v ( T dx ) ( T dy ) ( T dz ) ( T ddx ) ( T dxdy ) ( T ddy ) ( T dxdz ) ( T dydz ) ( T ddz ) ) - = let !s = sin v - !c = cos v - in D23 c - ( T $ -s * dx ) ( T $ -s * dy ) ( T $ -s * dz ) - ( T $ -2 * s * ddx - c * ( dx ^ 2 ) ) - ( T $ -2 * s * dxdy - 2 * c * dx * dy ) - ( T $ -2 * s * ddy - c * ( dy ^ 2 ) ) - ( T $ -2 * s * dxdz - 2 * c * dx * dz ) - ( T $ -2 * s * dydz - 2 * c * dy * dz ) - ( T $ -2 * s * ddz - c * ( dz ^ 2 ) ) + sin df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d2sin @r ( _D23_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + cos df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d2cos @r ( _D23_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) instance Transcendental r => Transcendental ( D2๐”ธ4 r ) where + pi = konst @Double @2 @( โ„ 4 ) pi + sin df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d2sin @r ( _D24_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + cos df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d2cos @r ( _D24_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + instance Transcendental r => Transcendental ( D3๐”ธ1 r ) where + pi = konst @Double @3 @( โ„ 1 ) pi + sin df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d3sin @r ( _D31_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + cos df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d3cos @r ( _D31_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + instance Transcendental r => Transcendental ( D3๐”ธ2 r ) where + pi = konst @Double @3 @( โ„ 2 ) pi + sin df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d3sin @r ( _D32_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + cos df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d3cos @r ( _D32_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + instance Transcendental r => Transcendental ( D3๐”ธ3 r ) where + pi = konst @Double @3 @( โ„ 3 ) pi + sin df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d3sin @r ( _D33_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + cos df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d3cos @r ( _D33_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + instance Transcendental r => Transcendental ( D3๐”ธ4 r ) where + pi = konst @Double @3 @( โ„ 4 ) pi + sin df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d3sin @r ( _D34_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) + cos df = + let + fromInt = fromInteger @r + add = (+) @r + times = (*) @r + pow = (^) @r + dg = d3cos @r ( _D34_v df ) + in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||] + [|| df ||] [|| dg ||] ) -------------------------------------------------------------------------------- -- HasChainRule instances. @@ -513,7 +674,7 @@ instance HasChainRule Double 2 ( โ„ 1 ) where let !o = origin @Double @( T w ) !p = (^+^) @Double @( T w ) !s = (^*) @Double @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -548,7 +709,7 @@ instance HasChainRule Double 3 ( โ„ 1 ) where let !o = origin @Double @( T w ) !p = (^+^) @Double @( T w ) !s = (^*) @Double @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -583,7 +744,7 @@ instance HasChainRule Double 2 ( โ„ 2 ) where let !o = origin @Double @( T w ) !p = (^+^) @Double @( T w ) !s = (^*) @Double @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -618,7 +779,7 @@ instance HasChainRule Double 3 ( โ„ 2 ) where let !o = origin @Double @( T w ) !p = (^+^) @Double @( T w ) !s = (^*) @Double @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -653,7 +814,7 @@ instance HasChainRule Double 2 ( โ„ 3 ) where let !o = origin @Double @( T w ) !p = (^+^) @Double @( T w ) !s = (^*) @Double @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -688,7 +849,7 @@ instance HasChainRule Double 3 ( โ„ 3 ) where let !o = origin @Double @( T w ) !p = (^+^) @Double @( T w ) !s = (^*) @Double @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -723,7 +884,7 @@ instance HasChainRule Double 2 ( โ„ 4 ) where let !o = origin @Double @( T w ) !p = (^+^) @Double @( T w ) !s = (^*) @Double @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -758,6 +919,6 @@ instance HasChainRule Double 3 ( โ„ 4 ) where let !o = origin @Double @( T w ) !p = (^+^) @Double @( T w ) !s = (^*) @Double @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) diff --git a/src/splines/Math/Algebra/Dual/Internal.hs b/src/splines/Math/Algebra/Dual/Internal.hs index f70b978..59b2e1e 100644 --- a/src/splines/Math/Algebra/Dual/Internal.hs +++ b/src/splines/Math/Algebra/Dual/Internal.hs @@ -12,7 +12,7 @@ module Math.Algebra.Dual.Internal , D1๐”ธ3(..), D2๐”ธ3(..), D3๐”ธ3(..) , D1๐”ธ4(..), D2๐”ธ4(..), D3๐”ธ4(..) - , chainRuleQ + , chainRule1NQ, chainRuleN1Q ) where -- base @@ -35,7 +35,9 @@ import Math.Linear ) import Math.Monomial ( Mon(..), MonomialBasis(..), Vars, Deg - , mons, faร , multiSubsetsSum, zeroMonomial + , mons, zeroMonomial, isZeroMonomial, totalDegree + , multiSubsetSumFaร , multiSubsetsSum + , partitionFaร , partitions ) import Math.Ring ( Ring ) @@ -179,21 +181,23 @@ data D3๐”ธ4 v = -------------------------------------------------------------------------------- --- | The chain rule, to be spliced in using Template Haskell. -chainRuleQ :: forall dr1 dv v r w d - . ( Ring r, RepresentableQ r v - , MonomialBasis dr1, Vars dr1 ~ 1 - , MonomialBasis dv , Vars dv ~ RepDim v - , Deg dr1 ~ Deg dv - , d ~ Vars dv, KnownNat d - ) - => CodeQ ( T w ) -- Module r ( T w ) - -> CodeQ ( T w -> T w -> T w ) -- - -> CodeQ ( T w -> r -> T w ) -- (circumvent TH constraint issue) - -> CodeQ ( dr1 v ) - -> CodeQ ( dv w ) - -> CodeQ ( dr1 w ) -chainRuleQ zero_w sum_w scale_w df dg = +-- | The chain rule for a composition \( \mathbb{R}^1 \to \mathbb{R}^n \to W \) +-- +-- (To be spliced in using Template Haskell.) +chainRule1NQ :: forall dr1 dv v r w d + . ( Ring r, RepresentableQ r v + , MonomialBasis dr1, Vars dr1 ~ 1 + , MonomialBasis dv , Vars dv ~ RepDim v + , Deg dr1 ~ Deg dv + , d ~ Vars dv, KnownNat d + ) + => CodeQ ( T w ) -- Module r ( T w ) + -> CodeQ ( T w -> T w -> T w ) -- + -> CodeQ ( T w -> r -> T w ) -- (circumvent TH constraint issue) + -> CodeQ ( dr1 v ) + -> CodeQ ( dv w ) + -> CodeQ ( dr1 w ) +chainRule1NQ zero_w sum_w scale_w df dg = monTabulate @dr1 \ ( Mon ( k `VS` _ ) ) -> case k of -- Set the value of the composition separately, @@ -203,7 +207,7 @@ chainRuleQ zero_w sum_w scale_w df dg = [|| unT $ $$( foldQ sum_w zero_w [ [|| $$scale_w ( T $$( monIndex @dv dg m_g ) ) $$( foldQ [|| (Ring.+) ||] [|| Ring.fromInteger ( 0 :: Integer ) ||] - [ [|| Ring.fromInteger $$( liftTyped $ fromIntegral $ faร  k is ) Ring.* + [ [|| Ring.fromInteger $$( liftTyped $ fromIntegral $ multiSubsetSumFaร  k is ) Ring.* $$( foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||] [ foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||] [ ( indexQ @r @v ( monIndex @dr1 df ( Mon $ f_deg `VS` VZ ) ) v_index ) @@ -223,6 +227,45 @@ chainRuleQ zero_w sum_w scale_w df dg = ] ) ||] +-- | The chain rule for a composition \( \mathbb{R}^n \to \mathbb{R} \to \mathbb{R} \) +-- +-- (To be spliced in using Template Haskell.) +chainRuleN1Q :: forall du dr1 r + . ( MonomialBasis du + , MonomialBasis dr1, Vars dr1 ~ 1 + ) + => CodeQ ( Integer -> r ) -- ^ fromInteger (circumvent TH constraint issue) + -> CodeQ ( r -> r -> r ) -- ^ (+) (circumvent TH constraint issue) + -> CodeQ ( r -> r -> r ) -- ^ (*) (circumvent TH constraint issue) + -> CodeQ ( r -> Word -> r ) -- ^ (^) (circumvent TH constraint issue) + -> CodeQ ( du r ) + -> CodeQ ( dr1 r ) + -> CodeQ ( du r ) +chainRuleN1Q fromInt add times pow df dg = + monTabulate @du \ mon -> + if + | isZeroMonomial mon + -- Set the value of the composition separately, + -- as that isn't handled by the Faร  di Bruno formula. + -> monIndex @dr1 dg zeroMonomial + | otherwise + -> foldQ add [|| $$fromInt ( 0 :: Integer ) ||] + [ [|| $$times $$( monIndex @dr1 dg ( Mon ( k `VS` VZ ) ) ) + $$( foldQ add [|| $$fromInt ( 0 :: Integer ) ||] + [ [|| $$times ( $$fromInt $$( liftTyped $ fromIntegral $ partitionFaร  mon is ) ) + $$( foldQ times [|| $$fromInt ( 1 :: Integer ) ||] + [ [|| $$pow $$( monIndex @du df f_mon ) p ||] + | ( f_mon, p ) <- is + ] + ) ||] + | is <- mss + ] + ) + ||] + | k <- [ 1 .. totalDegree mon ] + , let mss = partitions k mon + ] + -------------------------------------------------------------------------------- -- MonomialBasis instances follow (nothing else). diff --git a/src/splines/Math/Interval.hs b/src/splines/Math/Interval.hs index 84c0fc2..47f0900 100644 --- a/src/splines/Math/Interval.hs +++ b/src/splines/Math/Interval.hs @@ -35,7 +35,7 @@ import Data.Group.Generics -- splines import Math.Algebra.Dual import Math.Algebra.Dual.Internal - ( chainRuleQ ) + ( chainRule1NQ ) import Math.Interval.Internal ( ๐•€(..) ) import Math.Linear @@ -184,7 +184,7 @@ instance HasChainRule ( ๐•€ Double ) 2 ( ๐•€โ„ 1 ) where let !o = origin @( ๐•€ Double ) @( T w ) !p = (^+^) @( ๐•€ Double ) @( T w ) !s = (^*) @( ๐•€ Double ) @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -219,7 +219,7 @@ instance HasChainRule ( ๐•€ Double ) 3 ( ๐•€โ„ 1 ) where let !o = origin @( ๐•€ Double ) @( T w ) !p = (^+^) @( ๐•€ Double ) @( T w ) !s = (^*) @( ๐•€ Double ) @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -254,7 +254,7 @@ instance HasChainRule ( ๐•€ Double ) 2 ( ๐•€โ„ 2 ) where let !o = origin @( ๐•€ Double ) @( T w ) !p = (^+^) @( ๐•€ Double ) @( T w ) !s = (^*) @( ๐•€ Double ) @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -289,7 +289,7 @@ instance HasChainRule ( ๐•€ Double ) 3 ( ๐•€โ„ 2 ) where let !o = origin @( ๐•€ Double ) @( T w ) !p = (^+^) @( ๐•€ Double ) @( T w ) !s = (^*) @( ๐•€ Double ) @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -324,7 +324,7 @@ instance HasChainRule ( ๐•€ Double ) 2 ( ๐•€โ„ 3 ) where let !o = origin @( ๐•€ Double ) @( T w ) !p = (^+^) @( ๐•€ Double ) @( T w ) !s = (^*) @( ๐•€ Double ) @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -359,7 +359,7 @@ instance HasChainRule ( ๐•€ Double ) 3 ( ๐•€โ„ 3 ) where let !o = origin @( ๐•€ Double ) @( T w ) !p = (^+^) @( ๐•€ Double ) @( T w ) !s = (^*) @( ๐•€ Double ) @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -394,7 +394,7 @@ instance HasChainRule ( ๐•€ Double ) 2 ( ๐•€โ„ 4 ) where let !o = origin @( ๐•€ Double ) @( T w ) !p = (^+^) @( ๐•€ Double ) @( T w ) !s = (^*) @( ๐•€ Double ) @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) @@ -429,6 +429,6 @@ instance HasChainRule ( ๐•€ Double ) 3 ( ๐•€โ„ 4 ) where let !o = origin @( ๐•€ Double ) @( T w ) !p = (^+^) @( ๐•€ Double ) @( T w ) !s = (^*) @( ๐•€ Double ) @( T w ) - in $$( chainRuleQ + in $$( chainRule1NQ [|| o ||] [|| p ||] [|| s ||] [|| df ||] [|| dg ||] ) diff --git a/src/splines/Math/Linear.hs b/src/splines/Math/Linear.hs index 07e28fa..3f89801 100644 --- a/src/splines/Math/Linear.hs +++ b/src/splines/Math/Linear.hs @@ -28,6 +28,8 @@ import GHC.Show ( showSpace ) import GHC.TypeNats ( Nat, type (+) ) +import Unsafe.Coerce + ( unsafeCoerce ) -- acts import Data.Act @@ -118,7 +120,7 @@ infixr 5 `VS` type Vec :: Nat -> Type -> Type data Vec n a where VZ :: Vec 0 a - VS :: a -> Vec n a -> Vec ( 1 + n ) a + VS :: forall n a. a -> Vec n a -> Vec ( 1 + n ) a -- can't be strict, otherwise we can't conveniently -- unsafeCoerce from lists @@ -128,6 +130,13 @@ deriving stock instance Functor ( Vec n ) deriving stock instance Foldable ( Vec n ) deriving stock instance Traversable ( Vec n ) +instance Eq a => Eq ( Vec n a ) where + (==) = unsafeCoerce $ (==) @[a] + +instance Ord a => Ord ( Vec n a ) where + compare = unsafeCoerce $ compare @[a] + (<=) = unsafeCoerce $ (<=) @[a] + infixl 9 ! (!) :: forall l a. Vec l a -> Fin l -> a VS a _ ! Fin 1 = a diff --git a/src/splines/Math/Monomial.hs b/src/splines/Math/Monomial.hs index d0eff63..ec5162c 100644 --- a/src/splines/Math/Monomial.hs +++ b/src/splines/Math/Monomial.hs @@ -1,6 +1,7 @@ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE ParallelListComp #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE UndecidableInstances #-} @@ -9,11 +10,13 @@ module Math.Monomial ( Mon(..) , MonomialBasis(..), Deg, Vars - , zeroMonomial, isZeroMonomial, isLinear + , zeroMonomial, isZeroMonomial + , totalDegree, isLinear , split, mons - , faร , multiSubsetsSum, multiSubsetSum + , multiSubsetSumFaร , multiSubsetsSum, multiSubsetSum + , partitionFaร , partitions , prodRuleQ @@ -27,7 +30,7 @@ import Data.Kind import GHC.Exts ( proxy# ) import GHC.TypeNats - ( KnownNat, Nat, natVal' ) +-- ( KnownNat, Nat, natVal' ) import Unsafe.Coerce ( unsafeCoerce ) @@ -46,7 +49,7 @@ import TH.Utils -- | @Mon k n@ is the set of monomials in @n@ variables of degree less than or equal to @k@. type Mon :: Nat -> Nat -> Type newtype Mon k n = Mon { monDegs :: Vec n Word } -- sum <= k - deriving stock Show + deriving stock ( Show, Eq, Ord ) type Deg :: ( Type -> Type ) -> Nat type Vars :: ( Type -> Type ) -> Nat @@ -65,6 +68,10 @@ isZeroMonomial ( Mon ( i `VS` is ) ) | otherwise = False +totalDegree :: Mon k n -> Word +totalDegree ( Mon VZ ) = 0 +totalDegree ( Mon ( i `VS` is ) ) = i + totalDegree ( Mon is ) + isLinear :: Mon k n -> Maybe ( Fin n ) isLinear = fmap Fin . go 1 where @@ -99,12 +106,32 @@ mons' _ 0 = [ [] ] mons' 0 n = [ replicate ( fromIntegral n ) 0 ] mons' k n = [ i : is | i <- reverse [ 0 .. k ], is <- mons' ( k - i ) ( n - 1 ) ] +subs :: Mon k n -> [ ( Mon k n, Word ) ] +subs ( Mon VZ ) = [ ( Mon VZ, maxBound ) ] +subs ( Mon ( i `VS` is ) ) + = [ ( Mon ( j `VS` js ) + , if j == 0 then mult else min ( i `quot` j ) mult ) + | j <- [ 0 .. i ] + , ( Mon js, mult ) <- subs ( Mon is ) + ] + word :: forall n. KnownNat n => Word word = fromIntegral $ natVal' @n proxy# +-- | The factorial function \( n! = n \cdot (n-1) \cdot `ldots` \cdot 2 `cdot` 1 \). +factorial :: Word -> Word +factorial i = product [ 1 .. i ] + +vecFactorial :: Vec n Word -> Word +vecFactorial VZ = 1 +vecFactorial ( i `VS` is ) = factorial i * vecFactorial is + +-------------------------------------------------------------------------------- +-- Computations for the chain rule R^1 -> R^n -> R^m + -- | Faร  di Bruno coefficient (naive implementation). -faร  :: Word -> Vec n [ ( Word, Word ) ] -> Word -faร  k multisubsets = +multiSubsetSumFaร  :: Word -> Vec n [ ( Word, Word ) ] -> Word +multiSubsetSumFaร  k multisubsets = factorial k `div` product [ factorial p * ( factorial i ) ^ p | multisubset <- toList multisubsets @@ -146,9 +173,49 @@ multiSubsetsSum is = goMSS [] -> 0 _ -> max 0 $ minimum is --- | The factorial function \( n! = n \cdot (n-1) \cdot `ldots` \cdot 2 `cdot` 1 \). -factorial :: Word -> Word -factorial i = product [ 1 .. i ] +-------------------------------------------------------------------------------- +-- Computations for the chain rule R^n -> R^1 -> R^1 + +partitionFaร  :: Mon k n -> [ ( Mon k n, Word ) ] -> Word +partitionFaร  ( Mon mon ) multiIndexes = + vecFactorial mon `div` + product [ factorial p * ( vecFactorial i ) ^ p + | ( Mon i, p ) <- toList multiIndexes ] + +-- | @partitions p mon@ computes all partitions of the monomial @mon@ into +-- @p@ (non-zero) parts, allowing repetition. +partitions :: forall k n + . Word -- ^ number of parts + -> Mon k n -- ^ monomial to sum to + -> [ [ ( Mon k n, Word ) ] ] +partitions n_parts0 mon0 = go n_parts0 mon0 mon0 + where + go :: Word -> Mon k n -> Mon k n -> [ [ ( Mon k n , Word ) ] ] + go 0 _ mon + | isZeroMonomial mon + = [ [] ] + | otherwise + = [] + go 1 maxSub mon + | isZeroMonomial mon + || mon >= maxSub + = [] + | otherwise + = [ [ ( mon, 1 ) ] ] + go n_parts maxSub mon + = [ ( part, l ) : parts + | ( part, l_max ) <- subs mon + , not ( isZeroMonomial part ) -- parts must be non-empty + , part < maxSub -- use a total ordering on monomials to ensure uniqueness + , l <- [ 1 .. min n_parts l_max ] + , parts <- go ( n_parts - l ) part ( subPart l mon part ) + ] + +subPart :: Word -> Mon k n -> Mon k n -> Mon k n +subPart = unsafeCoerce subPart' + +subPart' :: Word -> [ Word ] -> [ Word ] -> [ Word ] +subPart' m = zipWith ( \ i j -> i - m * j ) --------------------------------------------------------------------------------