add chain rule R^n -> R -> R

This commit is contained in:
sheaf 2023-01-22 04:51:23 +01:00
parent 236055b4ca
commit bdcac18ab9
5 changed files with 350 additions and 70 deletions

View file

@ -411,9 +411,8 @@ instance Ring r => Ring ( D3𝔸4 r ) where
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Field & transcendental instances -- Field & transcendental instances
-- TODO!!
deriving newtype instance Field r => Field ( D𝔸0 r ) deriving newtype instance Field r => Field ( D𝔸0 r )
--instance Field r => Field ( D1𝔸1 r ) where --instance Field r => Field ( D1𝔸1 r ) where
--instance Field r => Field ( D1𝔸2 r ) where --instance Field r => Field ( D1𝔸2 r ) where
--instance Field r => Field ( D1𝔸3 r ) where --instance Field r => Field ( D1𝔸3 r ) where
@ -426,7 +425,37 @@ instance Field r => Field ( D3𝔸1 r ) where
instance Field r => Field ( D3𝔸2 r ) where instance Field r => Field ( D3𝔸2 r ) where
instance Field r => Field ( D3𝔸3 r ) where instance Field r => Field ( D3𝔸3 r ) where
instance Field r => Field ( D3𝔸4 r ) where instance Field r => Field ( D3𝔸4 r ) where
-- TODO
d1sin, d1cos :: Transcendental r => r -> D1𝔸1 r
d1sin x =
let !s = sin x
!c = cos x
in D11 s ( T c )
d1cos x =
let !s = sin x
!c = cos x
in D11 c ( T -s )
d2sin, d2cos :: Transcendental r => r -> D2𝔸1 r
d2sin x =
let !s = sin x
!c = cos x
in D21 s ( T c ) ( T -s )
d2cos x =
let !s = sin x
!c = cos x
in D21 c ( T -s ) ( T -c )
d3sin, d3cos :: Transcendental r => r -> D3𝔸1 r
d3sin x =
let !s = sin x
!c = cos x
in D31 s ( T c ) ( T -s ) ( T -c )
d3cos x =
let !s = sin x
!c = cos x
in D31 c ( T -s ) ( T -c ) ( T s )
deriving newtype instance Transcendental r => Transcendental ( D𝔸0 r ) deriving newtype instance Transcendental r => Transcendental ( D𝔸0 r )
--instance Transcendental r => Transcendental ( D1𝔸1 r ) where --instance Transcendental r => Transcendental ( D1𝔸1 r ) where
@ -434,38 +463,170 @@ deriving newtype instance Transcendental r => Transcendental ( D𝔸0 r )
--instance Transcendental r => Transcendental ( D1𝔸3 r ) where --instance Transcendental r => Transcendental ( D1𝔸3 r ) where
--instance Transcendental r => Transcendental ( D1𝔸4 r ) where --instance Transcendental r => Transcendental ( D1𝔸4 r ) where
instance Transcendental r => Transcendental ( D2𝔸1 r ) where instance Transcendental r => Transcendental ( D2𝔸1 r ) where
pi = konst @Double @2 @( 1 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2sin @r ( _D21_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2cos @r ( _D21_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D2𝔸2 r ) where instance Transcendental r => Transcendental ( D2𝔸2 r ) where
pi = konst @Double @2 @( 2 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2sin @r ( _D22_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2cos @r ( _D22_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D2𝔸3 r ) where instance Transcendental r => Transcendental ( D2𝔸3 r ) where
pi = konst @Double @2 @( 3 ) pi pi = konst @Double @2 @( 3 ) pi
sin ( D23 v ( T dx ) ( T dy ) ( T dz ) ( T ddx ) ( T dxdy ) ( T ddy ) ( T dxdz ) ( T dydz ) ( T ddz ) ) sin df =
= let !s = sin v let
!c = cos v fromInt = fromInteger @r
in D23 s add = (+) @r
( T $ c * dx ) ( T $ c * dy ) ( T $ c * dz ) times = (*) @r
( T $ 2 * c * ddx - s * ( dx ^ 2 ) ) pow = (^) @r
( T $ 2 * c * dxdy - 2 * s * dx * dy ) dg = d2sin @r ( _D23_v df )
( T $ 2 * c * ddy - s * ( dy ^ 2 ) ) in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
( T $ 2 * c * dxdz - 2 * s * dx * dz ) [|| df ||] [|| dg ||] )
( T $ 2 * c * dydz - 2 * s * dy * dz ) cos df =
( T $ 2 * c * ddz - s * ( dz ^ 2 ) ) let
fromInt = fromInteger @r
cos ( D23 v ( T dx ) ( T dy ) ( T dz ) ( T ddx ) ( T dxdy ) ( T ddy ) ( T dxdz ) ( T dydz ) ( T ddz ) ) add = (+) @r
= let !s = sin v times = (*) @r
!c = cos v pow = (^) @r
in D23 c dg = d2cos @r ( _D23_v df )
( T $ -s * dx ) ( T $ -s * dy ) ( T $ -s * dz ) in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
( T $ -2 * s * ddx - c * ( dx ^ 2 ) ) [|| df ||] [|| dg ||] )
( T $ -2 * s * dxdy - 2 * c * dx * dy )
( T $ -2 * s * ddy - c * ( dy ^ 2 ) )
( T $ -2 * s * dxdz - 2 * c * dx * dz )
( T $ -2 * s * dydz - 2 * c * dy * dz )
( T $ -2 * s * ddz - c * ( dz ^ 2 ) )
instance Transcendental r => Transcendental ( D2𝔸4 r ) where instance Transcendental r => Transcendental ( D2𝔸4 r ) where
pi = konst @Double @2 @( 4 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2sin @r ( _D24_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2cos @r ( _D24_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D3𝔸1 r ) where instance Transcendental r => Transcendental ( D3𝔸1 r ) where
pi = konst @Double @3 @( 1 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3sin @r ( _D31_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3cos @r ( _D31_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D3𝔸2 r ) where instance Transcendental r => Transcendental ( D3𝔸2 r ) where
pi = konst @Double @3 @( 2 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3sin @r ( _D32_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3cos @r ( _D32_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D3𝔸3 r ) where instance Transcendental r => Transcendental ( D3𝔸3 r ) where
pi = konst @Double @3 @( 3 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3sin @r ( _D33_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3cos @r ( _D33_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D3𝔸4 r ) where instance Transcendental r => Transcendental ( D3𝔸4 r ) where
pi = konst @Double @3 @( 4 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3sin @r ( _D34_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3cos @r ( _D34_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- HasChainRule instances. -- HasChainRule instances.
@ -513,7 +674,7 @@ instance HasChainRule Double 2 ( 1 ) where
let !o = origin @Double @( T w ) let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w ) !p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w ) !s = (^*) @Double @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -548,7 +709,7 @@ instance HasChainRule Double 3 ( 1 ) where
let !o = origin @Double @( T w ) let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w ) !p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w ) !s = (^*) @Double @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -583,7 +744,7 @@ instance HasChainRule Double 2 ( 2 ) where
let !o = origin @Double @( T w ) let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w ) !p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w ) !s = (^*) @Double @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -618,7 +779,7 @@ instance HasChainRule Double 3 ( 2 ) where
let !o = origin @Double @( T w ) let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w ) !p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w ) !s = (^*) @Double @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -653,7 +814,7 @@ instance HasChainRule Double 2 ( 3 ) where
let !o = origin @Double @( T w ) let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w ) !p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w ) !s = (^*) @Double @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -688,7 +849,7 @@ instance HasChainRule Double 3 ( 3 ) where
let !o = origin @Double @( T w ) let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w ) !p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w ) !s = (^*) @Double @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -723,7 +884,7 @@ instance HasChainRule Double 2 ( 4 ) where
let !o = origin @Double @( T w ) let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w ) !p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w ) !s = (^*) @Double @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -758,6 +919,6 @@ instance HasChainRule Double 3 ( 4 ) where
let !o = origin @Double @( T w ) let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w ) !p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w ) !s = (^*) @Double @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )

View file

@ -12,7 +12,7 @@ module Math.Algebra.Dual.Internal
, D1𝔸3(..), D2𝔸3(..), D3𝔸3(..) , D1𝔸3(..), D2𝔸3(..), D3𝔸3(..)
, D1𝔸4(..), D2𝔸4(..), D3𝔸4(..) , D1𝔸4(..), D2𝔸4(..), D3𝔸4(..)
, chainRuleQ , chainRule1NQ, chainRuleN1Q
) where ) where
-- base -- base
@ -35,7 +35,9 @@ import Math.Linear
) )
import Math.Monomial import Math.Monomial
( Mon(..), MonomialBasis(..), Vars, Deg ( Mon(..), MonomialBasis(..), Vars, Deg
, mons, faà, multiSubsetsSum, zeroMonomial , mons, zeroMonomial, isZeroMonomial, totalDegree
, multiSubsetSumFaà, multiSubsetsSum
, partitionFaà, partitions
) )
import Math.Ring import Math.Ring
( Ring ) ( Ring )
@ -179,8 +181,10 @@ data D3𝔸4 v =
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- | The chain rule, to be spliced in using Template Haskell. -- | The chain rule for a composition \( \mathbb{R}^1 \to \mathbb{R}^n \to W \)
chainRuleQ :: forall dr1 dv v r w d --
-- (To be spliced in using Template Haskell.)
chainRule1NQ :: forall dr1 dv v r w d
. ( Ring r, RepresentableQ r v . ( Ring r, RepresentableQ r v
, MonomialBasis dr1, Vars dr1 ~ 1 , MonomialBasis dr1, Vars dr1 ~ 1
, MonomialBasis dv , Vars dv ~ RepDim v , MonomialBasis dv , Vars dv ~ RepDim v
@ -193,7 +197,7 @@ chainRuleQ :: forall dr1 dv v r w d
-> CodeQ ( dr1 v ) -> CodeQ ( dr1 v )
-> CodeQ ( dv w ) -> CodeQ ( dv w )
-> CodeQ ( dr1 w ) -> CodeQ ( dr1 w )
chainRuleQ zero_w sum_w scale_w df dg = chainRule1NQ zero_w sum_w scale_w df dg =
monTabulate @dr1 \ ( Mon ( k `VS` _ ) ) -> monTabulate @dr1 \ ( Mon ( k `VS` _ ) ) ->
case k of case k of
-- Set the value of the composition separately, -- Set the value of the composition separately,
@ -203,7 +207,7 @@ chainRuleQ zero_w sum_w scale_w df dg =
[|| unT $ $$( foldQ sum_w zero_w [|| unT $ $$( foldQ sum_w zero_w
[ [|| $$scale_w ( T $$( monIndex @dv dg m_g ) ) [ [|| $$scale_w ( T $$( monIndex @dv dg m_g ) )
$$( foldQ [|| (Ring.+) ||] [|| Ring.fromInteger ( 0 :: Integer ) ||] $$( foldQ [|| (Ring.+) ||] [|| Ring.fromInteger ( 0 :: Integer ) ||]
[ [|| Ring.fromInteger $$( liftTyped $ fromIntegral $ fk is ) Ring.* [ [|| Ring.fromInteger $$( liftTyped $ fromIntegral $ multiSubsetSumFk is ) Ring.*
$$( foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||] $$( foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||]
[ foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||] [ foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||]
[ ( indexQ @r @v ( monIndex @dr1 df ( Mon $ f_deg `VS` VZ ) ) v_index ) [ ( indexQ @r @v ( monIndex @dr1 df ( Mon $ f_deg `VS` VZ ) ) v_index )
@ -223,6 +227,45 @@ chainRuleQ zero_w sum_w scale_w df dg =
] ]
) ||] ) ||]
-- | The chain rule for a composition \( \mathbb{R}^n \to \mathbb{R} \to \mathbb{R} \)
--
-- (To be spliced in using Template Haskell.)
chainRuleN1Q :: forall du dr1 r
. ( MonomialBasis du
, MonomialBasis dr1, Vars dr1 ~ 1
)
=> CodeQ ( Integer -> r ) -- ^ fromInteger (circumvent TH constraint issue)
-> CodeQ ( r -> r -> r ) -- ^ (+) (circumvent TH constraint issue)
-> CodeQ ( r -> r -> r ) -- ^ (*) (circumvent TH constraint issue)
-> CodeQ ( r -> Word -> r ) -- ^ (^) (circumvent TH constraint issue)
-> CodeQ ( du r )
-> CodeQ ( dr1 r )
-> CodeQ ( du r )
chainRuleN1Q fromInt add times pow df dg =
monTabulate @du \ mon ->
if
| isZeroMonomial mon
-- Set the value of the composition separately,
-- as that isn't handled by the Faà di Bruno formula.
-> monIndex @dr1 dg zeroMonomial
| otherwise
-> foldQ add [|| $$fromInt ( 0 :: Integer ) ||]
[ [|| $$times $$( monIndex @dr1 dg ( Mon ( k `VS` VZ ) ) )
$$( foldQ add [|| $$fromInt ( 0 :: Integer ) ||]
[ [|| $$times ( $$fromInt $$( liftTyped $ fromIntegral $ partitionFaà mon is ) )
$$( foldQ times [|| $$fromInt ( 1 :: Integer ) ||]
[ [|| $$pow $$( monIndex @du df f_mon ) p ||]
| ( f_mon, p ) <- is
]
) ||]
| is <- mss
]
)
||]
| k <- [ 1 .. totalDegree mon ]
, let mss = partitions k mon
]
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- MonomialBasis instances follow (nothing else). -- MonomialBasis instances follow (nothing else).

View file

@ -35,7 +35,7 @@ import Data.Group.Generics
-- splines -- splines
import Math.Algebra.Dual import Math.Algebra.Dual
import Math.Algebra.Dual.Internal import Math.Algebra.Dual.Internal
( chainRuleQ ) ( chainRule1NQ )
import Math.Interval.Internal import Math.Interval.Internal
( 𝕀(..) ) ( 𝕀(..) )
import Math.Linear import Math.Linear
@ -184,7 +184,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀 1 ) where
let !o = origin @( 𝕀 Double ) @( T w ) let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w ) !p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w ) !s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -219,7 +219,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀 1 ) where
let !o = origin @( 𝕀 Double ) @( T w ) let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w ) !p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w ) !s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -254,7 +254,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀 2 ) where
let !o = origin @( 𝕀 Double ) @( T w ) let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w ) !p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w ) !s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -289,7 +289,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀 2 ) where
let !o = origin @( 𝕀 Double ) @( T w ) let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w ) !p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w ) !s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -324,7 +324,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀 3 ) where
let !o = origin @( 𝕀 Double ) @( T w ) let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w ) !p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w ) !s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -359,7 +359,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀 3 ) where
let !o = origin @( 𝕀 Double ) @( T w ) let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w ) !p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w ) !s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -394,7 +394,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀 4 ) where
let !o = origin @( 𝕀 Double ) @( T w ) let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w ) !p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w ) !s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )
@ -429,6 +429,6 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀 4 ) where
let !o = origin @( 𝕀 Double ) @( T w ) let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w ) !p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w ) !s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||] [|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] ) [|| df ||] [|| dg ||] )

View file

@ -28,6 +28,8 @@ import GHC.Show
( showSpace ) ( showSpace )
import GHC.TypeNats import GHC.TypeNats
( Nat, type (+) ) ( Nat, type (+) )
import Unsafe.Coerce
( unsafeCoerce )
-- acts -- acts
import Data.Act import Data.Act
@ -118,7 +120,7 @@ infixr 5 `VS`
type Vec :: Nat -> Type -> Type type Vec :: Nat -> Type -> Type
data Vec n a where data Vec n a where
VZ :: Vec 0 a VZ :: Vec 0 a
VS :: a -> Vec n a -> Vec ( 1 + n ) a VS :: forall n a. a -> Vec n a -> Vec ( 1 + n ) a
-- can't be strict, otherwise we can't conveniently -- can't be strict, otherwise we can't conveniently
-- unsafeCoerce from lists -- unsafeCoerce from lists
@ -128,6 +130,13 @@ deriving stock instance Functor ( Vec n )
deriving stock instance Foldable ( Vec n ) deriving stock instance Foldable ( Vec n )
deriving stock instance Traversable ( Vec n ) deriving stock instance Traversable ( Vec n )
instance Eq a => Eq ( Vec n a ) where
(==) = unsafeCoerce $ (==) @[a]
instance Ord a => Ord ( Vec n a ) where
compare = unsafeCoerce $ compare @[a]
(<=) = unsafeCoerce $ (<=) @[a]
infixl 9 ! infixl 9 !
(!) :: forall l a. Vec l a -> Fin l -> a (!) :: forall l a. Vec l a -> Fin l -> a
VS a _ ! Fin 1 = a VS a _ ! Fin 1 = a

View file

@ -1,6 +1,7 @@
{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UndecidableInstances #-}
@ -9,11 +10,13 @@
module Math.Monomial module Math.Monomial
( Mon(..) ( Mon(..)
, MonomialBasis(..), Deg, Vars , MonomialBasis(..), Deg, Vars
, zeroMonomial, isZeroMonomial, isLinear , zeroMonomial, isZeroMonomial
, totalDegree, isLinear
, split, mons , split, mons
, faà, multiSubsetsSum, multiSubsetSum , multiSubsetSumFaà, multiSubsetsSum, multiSubsetSum
, partitionFaà, partitions
, prodRuleQ , prodRuleQ
@ -27,7 +30,7 @@ import Data.Kind
import GHC.Exts import GHC.Exts
( proxy# ) ( proxy# )
import GHC.TypeNats import GHC.TypeNats
( KnownNat, Nat, natVal' ) -- ( KnownNat, Nat, natVal' )
import Unsafe.Coerce import Unsafe.Coerce
( unsafeCoerce ) ( unsafeCoerce )
@ -46,7 +49,7 @@ import TH.Utils
-- | @Mon k n@ is the set of monomials in @n@ variables of degree less than or equal to @k@. -- | @Mon k n@ is the set of monomials in @n@ variables of degree less than or equal to @k@.
type Mon :: Nat -> Nat -> Type type Mon :: Nat -> Nat -> Type
newtype Mon k n = Mon { monDegs :: Vec n Word } -- sum <= k newtype Mon k n = Mon { monDegs :: Vec n Word } -- sum <= k
deriving stock Show deriving stock ( Show, Eq, Ord )
type Deg :: ( Type -> Type ) -> Nat type Deg :: ( Type -> Type ) -> Nat
type Vars :: ( Type -> Type ) -> Nat type Vars :: ( Type -> Type ) -> Nat
@ -65,6 +68,10 @@ isZeroMonomial ( Mon ( i `VS` is ) )
| otherwise | otherwise
= False = False
totalDegree :: Mon k n -> Word
totalDegree ( Mon VZ ) = 0
totalDegree ( Mon ( i `VS` is ) ) = i + totalDegree ( Mon is )
isLinear :: Mon k n -> Maybe ( Fin n ) isLinear :: Mon k n -> Maybe ( Fin n )
isLinear = fmap Fin . go 1 isLinear = fmap Fin . go 1
where where
@ -99,12 +106,32 @@ mons' _ 0 = [ [] ]
mons' 0 n = [ replicate ( fromIntegral n ) 0 ] mons' 0 n = [ replicate ( fromIntegral n ) 0 ]
mons' k n = [ i : is | i <- reverse [ 0 .. k ], is <- mons' ( k - i ) ( n - 1 ) ] mons' k n = [ i : is | i <- reverse [ 0 .. k ], is <- mons' ( k - i ) ( n - 1 ) ]
subs :: Mon k n -> [ ( Mon k n, Word ) ]
subs ( Mon VZ ) = [ ( Mon VZ, maxBound ) ]
subs ( Mon ( i `VS` is ) )
= [ ( Mon ( j `VS` js )
, if j == 0 then mult else min ( i `quot` j ) mult )
| j <- [ 0 .. i ]
, ( Mon js, mult ) <- subs ( Mon is )
]
word :: forall n. KnownNat n => Word word :: forall n. KnownNat n => Word
word = fromIntegral $ natVal' @n proxy# word = fromIntegral $ natVal' @n proxy#
-- | The factorial function \( n! = n \cdot (n-1) \cdot `ldots` \cdot 2 `cdot` 1 \).
factorial :: Word -> Word
factorial i = product [ 1 .. i ]
vecFactorial :: Vec n Word -> Word
vecFactorial VZ = 1
vecFactorial ( i `VS` is ) = factorial i * vecFactorial is
--------------------------------------------------------------------------------
-- Computations for the chain rule R^1 -> R^n -> R^m
-- | Faà di Bruno coefficient (naive implementation). -- | Faà di Bruno coefficient (naive implementation).
faà :: Word -> Vec n [ ( Word, Word ) ] -> Word multiSubsetSumF:: Word -> Vec n [ ( Word, Word ) ] -> Word
faà k multisubsets = multiSubsetSumFk multisubsets =
factorial k `div` factorial k `div`
product [ factorial p * ( factorial i ) ^ p product [ factorial p * ( factorial i ) ^ p
| multisubset <- toList multisubsets | multisubset <- toList multisubsets
@ -146,9 +173,49 @@ multiSubsetsSum is = goMSS
[] -> 0 [] -> 0
_ -> max 0 $ minimum is _ -> max 0 $ minimum is
-- | The factorial function \( n! = n \cdot (n-1) \cdot `ldots` \cdot 2 `cdot` 1 \). --------------------------------------------------------------------------------
factorial :: Word -> Word -- Computations for the chain rule R^n -> R^1 -> R^1
factorial i = product [ 1 .. i ]
partitionFaà :: Mon k n -> [ ( Mon k n, Word ) ] -> Word
partitionFaà ( Mon mon ) multiIndexes =
vecFactorial mon `div`
product [ factorial p * ( vecFactorial i ) ^ p
| ( Mon i, p ) <- toList multiIndexes ]
-- | @partitions p mon@ computes all partitions of the monomial @mon@ into
-- @p@ (non-zero) parts, allowing repetition.
partitions :: forall k n
. Word -- ^ number of parts
-> Mon k n -- ^ monomial to sum to
-> [ [ ( Mon k n, Word ) ] ]
partitions n_parts0 mon0 = go n_parts0 mon0 mon0
where
go :: Word -> Mon k n -> Mon k n -> [ [ ( Mon k n , Word ) ] ]
go 0 _ mon
| isZeroMonomial mon
= [ [] ]
| otherwise
= []
go 1 maxSub mon
| isZeroMonomial mon
|| mon >= maxSub
= []
| otherwise
= [ [ ( mon, 1 ) ] ]
go n_parts maxSub mon
= [ ( part, l ) : parts
| ( part, l_max ) <- subs mon
, not ( isZeroMonomial part ) -- parts must be non-empty
, part < maxSub -- use a total ordering on monomials to ensure uniqueness
, l <- [ 1 .. min n_parts l_max ]
, parts <- go ( n_parts - l ) part ( subPart l mon part )
]
subPart :: Word -> Mon k n -> Mon k n -> Mon k n
subPart = unsafeCoerce subPart'
subPart' :: Word -> [ Word ] -> [ Word ] -> [ Word ]
subPart' m = zipWith ( \ i j -> i - m * j )
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------