mirror of
https://gitlab.com/sheaf/metabrush.git
synced 2024-11-27 09:24:08 +00:00
add chain rule R^n -> R -> R
This commit is contained in:
parent
236055b4ca
commit
bdcac18ab9
|
@ -411,9 +411,8 @@ instance Ring r => Ring ( D3𝔸4 r ) where
|
|||
--------------------------------------------------------------------------------
|
||||
-- Field & transcendental instances
|
||||
|
||||
-- TODO!!
|
||||
|
||||
deriving newtype instance Field r => Field ( D𝔸0 r )
|
||||
|
||||
--instance Field r => Field ( D1𝔸1 r ) where
|
||||
--instance Field r => Field ( D1𝔸2 r ) where
|
||||
--instance Field r => Field ( D1𝔸3 r ) where
|
||||
|
@ -426,7 +425,37 @@ instance Field r => Field ( D3𝔸1 r ) where
|
|||
instance Field r => Field ( D3𝔸2 r ) where
|
||||
instance Field r => Field ( D3𝔸3 r ) where
|
||||
instance Field r => Field ( D3𝔸4 r ) where
|
||||
-- TODO
|
||||
|
||||
d1sin, d1cos :: Transcendental r => r -> D1𝔸1 r
|
||||
d1sin x =
|
||||
let !s = sin x
|
||||
!c = cos x
|
||||
in D11 s ( T c )
|
||||
d1cos x =
|
||||
let !s = sin x
|
||||
!c = cos x
|
||||
in D11 c ( T -s )
|
||||
|
||||
d2sin, d2cos :: Transcendental r => r -> D2𝔸1 r
|
||||
d2sin x =
|
||||
let !s = sin x
|
||||
!c = cos x
|
||||
in D21 s ( T c ) ( T -s )
|
||||
d2cos x =
|
||||
let !s = sin x
|
||||
!c = cos x
|
||||
in D21 c ( T -s ) ( T -c )
|
||||
|
||||
d3sin, d3cos :: Transcendental r => r -> D3𝔸1 r
|
||||
d3sin x =
|
||||
let !s = sin x
|
||||
!c = cos x
|
||||
in D31 s ( T c ) ( T -s ) ( T -c )
|
||||
d3cos x =
|
||||
let !s = sin x
|
||||
!c = cos x
|
||||
in D31 c ( T -s ) ( T -c ) ( T s )
|
||||
|
||||
deriving newtype instance Transcendental r => Transcendental ( D𝔸0 r )
|
||||
--instance Transcendental r => Transcendental ( D1𝔸1 r ) where
|
||||
|
@ -434,38 +463,170 @@ deriving newtype instance Transcendental r => Transcendental ( D𝔸0 r )
|
|||
--instance Transcendental r => Transcendental ( D1𝔸3 r ) where
|
||||
--instance Transcendental r => Transcendental ( D1𝔸4 r ) where
|
||||
instance Transcendental r => Transcendental ( D2𝔸1 r ) where
|
||||
pi = konst @Double @2 @( ℝ 1 ) pi
|
||||
sin df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d2sin @r ( _D21_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
cos df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d2cos @r ( _D21_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
instance Transcendental r => Transcendental ( D2𝔸2 r ) where
|
||||
pi = konst @Double @2 @( ℝ 2 ) pi
|
||||
sin df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d2sin @r ( _D22_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
cos df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d2cos @r ( _D22_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
instance Transcendental r => Transcendental ( D2𝔸3 r ) where
|
||||
pi = konst @Double @2 @( ℝ 3 ) pi
|
||||
sin ( D23 v ( T dx ) ( T dy ) ( T dz ) ( T ddx ) ( T dxdy ) ( T ddy ) ( T dxdz ) ( T dydz ) ( T ddz ) )
|
||||
= let !s = sin v
|
||||
!c = cos v
|
||||
in D23 s
|
||||
( T $ c * dx ) ( T $ c * dy ) ( T $ c * dz )
|
||||
( T $ 2 * c * ddx - s * ( dx ^ 2 ) )
|
||||
( T $ 2 * c * dxdy - 2 * s * dx * dy )
|
||||
( T $ 2 * c * ddy - s * ( dy ^ 2 ) )
|
||||
( T $ 2 * c * dxdz - 2 * s * dx * dz )
|
||||
( T $ 2 * c * dydz - 2 * s * dy * dz )
|
||||
( T $ 2 * c * ddz - s * ( dz ^ 2 ) )
|
||||
|
||||
cos ( D23 v ( T dx ) ( T dy ) ( T dz ) ( T ddx ) ( T dxdy ) ( T ddy ) ( T dxdz ) ( T dydz ) ( T ddz ) )
|
||||
= let !s = sin v
|
||||
!c = cos v
|
||||
in D23 c
|
||||
( T $ -s * dx ) ( T $ -s * dy ) ( T $ -s * dz )
|
||||
( T $ -2 * s * ddx - c * ( dx ^ 2 ) )
|
||||
( T $ -2 * s * dxdy - 2 * c * dx * dy )
|
||||
( T $ -2 * s * ddy - c * ( dy ^ 2 ) )
|
||||
( T $ -2 * s * dxdz - 2 * c * dx * dz )
|
||||
( T $ -2 * s * dydz - 2 * c * dy * dz )
|
||||
( T $ -2 * s * ddz - c * ( dz ^ 2 ) )
|
||||
sin df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d2sin @r ( _D23_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
cos df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d2cos @r ( _D23_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
instance Transcendental r => Transcendental ( D2𝔸4 r ) where
|
||||
pi = konst @Double @2 @( ℝ 4 ) pi
|
||||
sin df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d2sin @r ( _D24_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
cos df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d2cos @r ( _D24_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
instance Transcendental r => Transcendental ( D3𝔸1 r ) where
|
||||
pi = konst @Double @3 @( ℝ 1 ) pi
|
||||
sin df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d3sin @r ( _D31_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
cos df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d3cos @r ( _D31_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
instance Transcendental r => Transcendental ( D3𝔸2 r ) where
|
||||
pi = konst @Double @3 @( ℝ 2 ) pi
|
||||
sin df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d3sin @r ( _D32_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
cos df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d3cos @r ( _D32_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
instance Transcendental r => Transcendental ( D3𝔸3 r ) where
|
||||
pi = konst @Double @3 @( ℝ 3 ) pi
|
||||
sin df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d3sin @r ( _D33_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
cos df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d3cos @r ( _D33_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
instance Transcendental r => Transcendental ( D3𝔸4 r ) where
|
||||
pi = konst @Double @3 @( ℝ 4 ) pi
|
||||
sin df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d3sin @r ( _D34_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
cos df =
|
||||
let
|
||||
fromInt = fromInteger @r
|
||||
add = (+) @r
|
||||
times = (*) @r
|
||||
pow = (^) @r
|
||||
dg = d3cos @r ( _D34_v df )
|
||||
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- HasChainRule instances.
|
||||
|
@ -513,7 +674,7 @@ instance HasChainRule Double 2 ( ℝ 1 ) where
|
|||
let !o = origin @Double @( T w )
|
||||
!p = (^+^) @Double @( T w )
|
||||
!s = (^*) @Double @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -548,7 +709,7 @@ instance HasChainRule Double 3 ( ℝ 1 ) where
|
|||
let !o = origin @Double @( T w )
|
||||
!p = (^+^) @Double @( T w )
|
||||
!s = (^*) @Double @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -583,7 +744,7 @@ instance HasChainRule Double 2 ( ℝ 2 ) where
|
|||
let !o = origin @Double @( T w )
|
||||
!p = (^+^) @Double @( T w )
|
||||
!s = (^*) @Double @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -618,7 +779,7 @@ instance HasChainRule Double 3 ( ℝ 2 ) where
|
|||
let !o = origin @Double @( T w )
|
||||
!p = (^+^) @Double @( T w )
|
||||
!s = (^*) @Double @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -653,7 +814,7 @@ instance HasChainRule Double 2 ( ℝ 3 ) where
|
|||
let !o = origin @Double @( T w )
|
||||
!p = (^+^) @Double @( T w )
|
||||
!s = (^*) @Double @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -688,7 +849,7 @@ instance HasChainRule Double 3 ( ℝ 3 ) where
|
|||
let !o = origin @Double @( T w )
|
||||
!p = (^+^) @Double @( T w )
|
||||
!s = (^*) @Double @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -723,7 +884,7 @@ instance HasChainRule Double 2 ( ℝ 4 ) where
|
|||
let !o = origin @Double @( T w )
|
||||
!p = (^+^) @Double @( T w )
|
||||
!s = (^*) @Double @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -758,6 +919,6 @@ instance HasChainRule Double 3 ( ℝ 4 ) where
|
|||
let !o = origin @Double @( T w )
|
||||
!p = (^+^) @Double @( T w )
|
||||
!s = (^*) @Double @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
|
|
@ -12,7 +12,7 @@ module Math.Algebra.Dual.Internal
|
|||
, D1𝔸3(..), D2𝔸3(..), D3𝔸3(..)
|
||||
, D1𝔸4(..), D2𝔸4(..), D3𝔸4(..)
|
||||
|
||||
, chainRuleQ
|
||||
, chainRule1NQ, chainRuleN1Q
|
||||
) where
|
||||
|
||||
-- base
|
||||
|
@ -35,7 +35,9 @@ import Math.Linear
|
|||
)
|
||||
import Math.Monomial
|
||||
( Mon(..), MonomialBasis(..), Vars, Deg
|
||||
, mons, faà, multiSubsetsSum, zeroMonomial
|
||||
, mons, zeroMonomial, isZeroMonomial, totalDegree
|
||||
, multiSubsetSumFaà, multiSubsetsSum
|
||||
, partitionFaà, partitions
|
||||
)
|
||||
import Math.Ring
|
||||
( Ring )
|
||||
|
@ -179,8 +181,10 @@ data D3𝔸4 v =
|
|||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
-- | The chain rule, to be spliced in using Template Haskell.
|
||||
chainRuleQ :: forall dr1 dv v r w d
|
||||
-- | The chain rule for a composition \( \mathbb{R}^1 \to \mathbb{R}^n \to W \)
|
||||
--
|
||||
-- (To be spliced in using Template Haskell.)
|
||||
chainRule1NQ :: forall dr1 dv v r w d
|
||||
. ( Ring r, RepresentableQ r v
|
||||
, MonomialBasis dr1, Vars dr1 ~ 1
|
||||
, MonomialBasis dv , Vars dv ~ RepDim v
|
||||
|
@ -193,7 +197,7 @@ chainRuleQ :: forall dr1 dv v r w d
|
|||
-> CodeQ ( dr1 v )
|
||||
-> CodeQ ( dv w )
|
||||
-> CodeQ ( dr1 w )
|
||||
chainRuleQ zero_w sum_w scale_w df dg =
|
||||
chainRule1NQ zero_w sum_w scale_w df dg =
|
||||
monTabulate @dr1 \ ( Mon ( k `VS` _ ) ) ->
|
||||
case k of
|
||||
-- Set the value of the composition separately,
|
||||
|
@ -203,7 +207,7 @@ chainRuleQ zero_w sum_w scale_w df dg =
|
|||
[|| unT $ $$( foldQ sum_w zero_w
|
||||
[ [|| $$scale_w ( T $$( monIndex @dv dg m_g ) )
|
||||
$$( foldQ [|| (Ring.+) ||] [|| Ring.fromInteger ( 0 :: Integer ) ||]
|
||||
[ [|| Ring.fromInteger $$( liftTyped $ fromIntegral $ faà k is ) Ring.*
|
||||
[ [|| Ring.fromInteger $$( liftTyped $ fromIntegral $ multiSubsetSumFaà k is ) Ring.*
|
||||
$$( foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||]
|
||||
[ foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||]
|
||||
[ ( indexQ @r @v ( monIndex @dr1 df ( Mon $ f_deg `VS` VZ ) ) v_index )
|
||||
|
@ -223,6 +227,45 @@ chainRuleQ zero_w sum_w scale_w df dg =
|
|||
]
|
||||
) ||]
|
||||
|
||||
-- | The chain rule for a composition \( \mathbb{R}^n \to \mathbb{R} \to \mathbb{R} \)
|
||||
--
|
||||
-- (To be spliced in using Template Haskell.)
|
||||
chainRuleN1Q :: forall du dr1 r
|
||||
. ( MonomialBasis du
|
||||
, MonomialBasis dr1, Vars dr1 ~ 1
|
||||
)
|
||||
=> CodeQ ( Integer -> r ) -- ^ fromInteger (circumvent TH constraint issue)
|
||||
-> CodeQ ( r -> r -> r ) -- ^ (+) (circumvent TH constraint issue)
|
||||
-> CodeQ ( r -> r -> r ) -- ^ (*) (circumvent TH constraint issue)
|
||||
-> CodeQ ( r -> Word -> r ) -- ^ (^) (circumvent TH constraint issue)
|
||||
-> CodeQ ( du r )
|
||||
-> CodeQ ( dr1 r )
|
||||
-> CodeQ ( du r )
|
||||
chainRuleN1Q fromInt add times pow df dg =
|
||||
monTabulate @du \ mon ->
|
||||
if
|
||||
| isZeroMonomial mon
|
||||
-- Set the value of the composition separately,
|
||||
-- as that isn't handled by the Faà di Bruno formula.
|
||||
-> monIndex @dr1 dg zeroMonomial
|
||||
| otherwise
|
||||
-> foldQ add [|| $$fromInt ( 0 :: Integer ) ||]
|
||||
[ [|| $$times $$( monIndex @dr1 dg ( Mon ( k `VS` VZ ) ) )
|
||||
$$( foldQ add [|| $$fromInt ( 0 :: Integer ) ||]
|
||||
[ [|| $$times ( $$fromInt $$( liftTyped $ fromIntegral $ partitionFaà mon is ) )
|
||||
$$( foldQ times [|| $$fromInt ( 1 :: Integer ) ||]
|
||||
[ [|| $$pow $$( monIndex @du df f_mon ) p ||]
|
||||
| ( f_mon, p ) <- is
|
||||
]
|
||||
) ||]
|
||||
| is <- mss
|
||||
]
|
||||
)
|
||||
||]
|
||||
| k <- [ 1 .. totalDegree mon ]
|
||||
, let mss = partitions k mon
|
||||
]
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- MonomialBasis instances follow (nothing else).
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ import Data.Group.Generics
|
|||
-- splines
|
||||
import Math.Algebra.Dual
|
||||
import Math.Algebra.Dual.Internal
|
||||
( chainRuleQ )
|
||||
( chainRule1NQ )
|
||||
import Math.Interval.Internal
|
||||
( 𝕀(..) )
|
||||
import Math.Linear
|
||||
|
@ -184,7 +184,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀ℝ 1 ) where
|
|||
let !o = origin @( 𝕀 Double ) @( T w )
|
||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -219,7 +219,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀ℝ 1 ) where
|
|||
let !o = origin @( 𝕀 Double ) @( T w )
|
||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -254,7 +254,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀ℝ 2 ) where
|
|||
let !o = origin @( 𝕀 Double ) @( T w )
|
||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -289,7 +289,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀ℝ 2 ) where
|
|||
let !o = origin @( 𝕀 Double ) @( T w )
|
||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -324,7 +324,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀ℝ 3 ) where
|
|||
let !o = origin @( 𝕀 Double ) @( T w )
|
||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -359,7 +359,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀ℝ 3 ) where
|
|||
let !o = origin @( 𝕀 Double ) @( T w )
|
||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -394,7 +394,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀ℝ 4 ) where
|
|||
let !o = origin @( 𝕀 Double ) @( T w )
|
||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
||||
|
@ -429,6 +429,6 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀ℝ 4 ) where
|
|||
let !o = origin @( 𝕀 Double ) @( T w )
|
||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||
in $$( chainRuleQ
|
||||
in $$( chainRule1NQ
|
||||
[|| o ||] [|| p ||] [|| s ||]
|
||||
[|| df ||] [|| dg ||] )
|
||||
|
|
|
@ -28,6 +28,8 @@ import GHC.Show
|
|||
( showSpace )
|
||||
import GHC.TypeNats
|
||||
( Nat, type (+) )
|
||||
import Unsafe.Coerce
|
||||
( unsafeCoerce )
|
||||
|
||||
-- acts
|
||||
import Data.Act
|
||||
|
@ -118,7 +120,7 @@ infixr 5 `VS`
|
|||
type Vec :: Nat -> Type -> Type
|
||||
data Vec n a where
|
||||
VZ :: Vec 0 a
|
||||
VS :: a -> Vec n a -> Vec ( 1 + n ) a
|
||||
VS :: forall n a. a -> Vec n a -> Vec ( 1 + n ) a
|
||||
-- can't be strict, otherwise we can't conveniently
|
||||
-- unsafeCoerce from lists
|
||||
|
||||
|
@ -128,6 +130,13 @@ deriving stock instance Functor ( Vec n )
|
|||
deriving stock instance Foldable ( Vec n )
|
||||
deriving stock instance Traversable ( Vec n )
|
||||
|
||||
instance Eq a => Eq ( Vec n a ) where
|
||||
(==) = unsafeCoerce $ (==) @[a]
|
||||
|
||||
instance Ord a => Ord ( Vec n a ) where
|
||||
compare = unsafeCoerce $ compare @[a]
|
||||
(<=) = unsafeCoerce $ (<=) @[a]
|
||||
|
||||
infixl 9 !
|
||||
(!) :: forall l a. Vec l a -> Fin l -> a
|
||||
VS a _ ! Fin 1 = a
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE QuantifiedConstraints #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE ParallelListComp #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
|
||||
|
@ -9,11 +10,13 @@
|
|||
module Math.Monomial
|
||||
( Mon(..)
|
||||
, MonomialBasis(..), Deg, Vars
|
||||
, zeroMonomial, isZeroMonomial, isLinear
|
||||
, zeroMonomial, isZeroMonomial
|
||||
, totalDegree, isLinear
|
||||
|
||||
, split, mons
|
||||
|
||||
, faà, multiSubsetsSum, multiSubsetSum
|
||||
, multiSubsetSumFaà, multiSubsetsSum, multiSubsetSum
|
||||
, partitionFaà, partitions
|
||||
|
||||
, prodRuleQ
|
||||
|
||||
|
@ -27,7 +30,7 @@ import Data.Kind
|
|||
import GHC.Exts
|
||||
( proxy# )
|
||||
import GHC.TypeNats
|
||||
( KnownNat, Nat, natVal' )
|
||||
-- ( KnownNat, Nat, natVal' )
|
||||
import Unsafe.Coerce
|
||||
( unsafeCoerce )
|
||||
|
||||
|
@ -46,7 +49,7 @@ import TH.Utils
|
|||
-- | @Mon k n@ is the set of monomials in @n@ variables of degree less than or equal to @k@.
|
||||
type Mon :: Nat -> Nat -> Type
|
||||
newtype Mon k n = Mon { monDegs :: Vec n Word } -- sum <= k
|
||||
deriving stock Show
|
||||
deriving stock ( Show, Eq, Ord )
|
||||
|
||||
type Deg :: ( Type -> Type ) -> Nat
|
||||
type Vars :: ( Type -> Type ) -> Nat
|
||||
|
@ -65,6 +68,10 @@ isZeroMonomial ( Mon ( i `VS` is ) )
|
|||
| otherwise
|
||||
= False
|
||||
|
||||
totalDegree :: Mon k n -> Word
|
||||
totalDegree ( Mon VZ ) = 0
|
||||
totalDegree ( Mon ( i `VS` is ) ) = i + totalDegree ( Mon is )
|
||||
|
||||
isLinear :: Mon k n -> Maybe ( Fin n )
|
||||
isLinear = fmap Fin . go 1
|
||||
where
|
||||
|
@ -99,12 +106,32 @@ mons' _ 0 = [ [] ]
|
|||
mons' 0 n = [ replicate ( fromIntegral n ) 0 ]
|
||||
mons' k n = [ i : is | i <- reverse [ 0 .. k ], is <- mons' ( k - i ) ( n - 1 ) ]
|
||||
|
||||
subs :: Mon k n -> [ ( Mon k n, Word ) ]
|
||||
subs ( Mon VZ ) = [ ( Mon VZ, maxBound ) ]
|
||||
subs ( Mon ( i `VS` is ) )
|
||||
= [ ( Mon ( j `VS` js )
|
||||
, if j == 0 then mult else min ( i `quot` j ) mult )
|
||||
| j <- [ 0 .. i ]
|
||||
, ( Mon js, mult ) <- subs ( Mon is )
|
||||
]
|
||||
|
||||
word :: forall n. KnownNat n => Word
|
||||
word = fromIntegral $ natVal' @n proxy#
|
||||
|
||||
-- | The factorial function \( n! = n \cdot (n-1) \cdot `ldots` \cdot 2 `cdot` 1 \).
|
||||
factorial :: Word -> Word
|
||||
factorial i = product [ 1 .. i ]
|
||||
|
||||
vecFactorial :: Vec n Word -> Word
|
||||
vecFactorial VZ = 1
|
||||
vecFactorial ( i `VS` is ) = factorial i * vecFactorial is
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Computations for the chain rule R^1 -> R^n -> R^m
|
||||
|
||||
-- | Faà di Bruno coefficient (naive implementation).
|
||||
faà :: Word -> Vec n [ ( Word, Word ) ] -> Word
|
||||
faà k multisubsets =
|
||||
multiSubsetSumFaà :: Word -> Vec n [ ( Word, Word ) ] -> Word
|
||||
multiSubsetSumFaà k multisubsets =
|
||||
factorial k `div`
|
||||
product [ factorial p * ( factorial i ) ^ p
|
||||
| multisubset <- toList multisubsets
|
||||
|
@ -146,9 +173,49 @@ multiSubsetsSum is = goMSS
|
|||
[] -> 0
|
||||
_ -> max 0 $ minimum is
|
||||
|
||||
-- | The factorial function \( n! = n \cdot (n-1) \cdot `ldots` \cdot 2 `cdot` 1 \).
|
||||
factorial :: Word -> Word
|
||||
factorial i = product [ 1 .. i ]
|
||||
--------------------------------------------------------------------------------
|
||||
-- Computations for the chain rule R^n -> R^1 -> R^1
|
||||
|
||||
partitionFaà :: Mon k n -> [ ( Mon k n, Word ) ] -> Word
|
||||
partitionFaà ( Mon mon ) multiIndexes =
|
||||
vecFactorial mon `div`
|
||||
product [ factorial p * ( vecFactorial i ) ^ p
|
||||
| ( Mon i, p ) <- toList multiIndexes ]
|
||||
|
||||
-- | @partitions p mon@ computes all partitions of the monomial @mon@ into
|
||||
-- @p@ (non-zero) parts, allowing repetition.
|
||||
partitions :: forall k n
|
||||
. Word -- ^ number of parts
|
||||
-> Mon k n -- ^ monomial to sum to
|
||||
-> [ [ ( Mon k n, Word ) ] ]
|
||||
partitions n_parts0 mon0 = go n_parts0 mon0 mon0
|
||||
where
|
||||
go :: Word -> Mon k n -> Mon k n -> [ [ ( Mon k n , Word ) ] ]
|
||||
go 0 _ mon
|
||||
| isZeroMonomial mon
|
||||
= [ [] ]
|
||||
| otherwise
|
||||
= []
|
||||
go 1 maxSub mon
|
||||
| isZeroMonomial mon
|
||||
|| mon >= maxSub
|
||||
= []
|
||||
| otherwise
|
||||
= [ [ ( mon, 1 ) ] ]
|
||||
go n_parts maxSub mon
|
||||
= [ ( part, l ) : parts
|
||||
| ( part, l_max ) <- subs mon
|
||||
, not ( isZeroMonomial part ) -- parts must be non-empty
|
||||
, part < maxSub -- use a total ordering on monomials to ensure uniqueness
|
||||
, l <- [ 1 .. min n_parts l_max ]
|
||||
, parts <- go ( n_parts - l ) part ( subPart l mon part )
|
||||
]
|
||||
|
||||
subPart :: Word -> Mon k n -> Mon k n -> Mon k n
|
||||
subPart = unsafeCoerce subPart'
|
||||
|
||||
subPart' :: Word -> [ Word ] -> [ Word ] -> [ Word ]
|
||||
subPart' m = zipWith ( \ i j -> i - m * j )
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
|
Loading…
Reference in a new issue