mirror of
https://gitlab.com/sheaf/metabrush.git
synced 2024-11-05 23:03:38 +00:00
add chain rule R^n -> R -> R
This commit is contained in:
parent
236055b4ca
commit
bdcac18ab9
|
@ -411,9 +411,8 @@ instance Ring r => Ring ( D3𝔸4 r ) where
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- Field & transcendental instances
|
-- Field & transcendental instances
|
||||||
|
|
||||||
-- TODO!!
|
|
||||||
|
|
||||||
deriving newtype instance Field r => Field ( D𝔸0 r )
|
deriving newtype instance Field r => Field ( D𝔸0 r )
|
||||||
|
|
||||||
--instance Field r => Field ( D1𝔸1 r ) where
|
--instance Field r => Field ( D1𝔸1 r ) where
|
||||||
--instance Field r => Field ( D1𝔸2 r ) where
|
--instance Field r => Field ( D1𝔸2 r ) where
|
||||||
--instance Field r => Field ( D1𝔸3 r ) where
|
--instance Field r => Field ( D1𝔸3 r ) where
|
||||||
|
@ -426,7 +425,37 @@ instance Field r => Field ( D3𝔸1 r ) where
|
||||||
instance Field r => Field ( D3𝔸2 r ) where
|
instance Field r => Field ( D3𝔸2 r ) where
|
||||||
instance Field r => Field ( D3𝔸3 r ) where
|
instance Field r => Field ( D3𝔸3 r ) where
|
||||||
instance Field r => Field ( D3𝔸4 r ) where
|
instance Field r => Field ( D3𝔸4 r ) where
|
||||||
|
-- TODO
|
||||||
|
|
||||||
|
d1sin, d1cos :: Transcendental r => r -> D1𝔸1 r
|
||||||
|
d1sin x =
|
||||||
|
let !s = sin x
|
||||||
|
!c = cos x
|
||||||
|
in D11 s ( T c )
|
||||||
|
d1cos x =
|
||||||
|
let !s = sin x
|
||||||
|
!c = cos x
|
||||||
|
in D11 c ( T -s )
|
||||||
|
|
||||||
|
d2sin, d2cos :: Transcendental r => r -> D2𝔸1 r
|
||||||
|
d2sin x =
|
||||||
|
let !s = sin x
|
||||||
|
!c = cos x
|
||||||
|
in D21 s ( T c ) ( T -s )
|
||||||
|
d2cos x =
|
||||||
|
let !s = sin x
|
||||||
|
!c = cos x
|
||||||
|
in D21 c ( T -s ) ( T -c )
|
||||||
|
|
||||||
|
d3sin, d3cos :: Transcendental r => r -> D3𝔸1 r
|
||||||
|
d3sin x =
|
||||||
|
let !s = sin x
|
||||||
|
!c = cos x
|
||||||
|
in D31 s ( T c ) ( T -s ) ( T -c )
|
||||||
|
d3cos x =
|
||||||
|
let !s = sin x
|
||||||
|
!c = cos x
|
||||||
|
in D31 c ( T -s ) ( T -c ) ( T s )
|
||||||
|
|
||||||
deriving newtype instance Transcendental r => Transcendental ( D𝔸0 r )
|
deriving newtype instance Transcendental r => Transcendental ( D𝔸0 r )
|
||||||
--instance Transcendental r => Transcendental ( D1𝔸1 r ) where
|
--instance Transcendental r => Transcendental ( D1𝔸1 r ) where
|
||||||
|
@ -434,38 +463,170 @@ deriving newtype instance Transcendental r => Transcendental ( D𝔸0 r )
|
||||||
--instance Transcendental r => Transcendental ( D1𝔸3 r ) where
|
--instance Transcendental r => Transcendental ( D1𝔸3 r ) where
|
||||||
--instance Transcendental r => Transcendental ( D1𝔸4 r ) where
|
--instance Transcendental r => Transcendental ( D1𝔸4 r ) where
|
||||||
instance Transcendental r => Transcendental ( D2𝔸1 r ) where
|
instance Transcendental r => Transcendental ( D2𝔸1 r ) where
|
||||||
|
pi = konst @Double @2 @( ℝ 1 ) pi
|
||||||
|
sin df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d2sin @r ( _D21_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
cos df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d2cos @r ( _D21_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
instance Transcendental r => Transcendental ( D2𝔸2 r ) where
|
instance Transcendental r => Transcendental ( D2𝔸2 r ) where
|
||||||
|
pi = konst @Double @2 @( ℝ 2 ) pi
|
||||||
|
sin df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d2sin @r ( _D22_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
cos df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d2cos @r ( _D22_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
instance Transcendental r => Transcendental ( D2𝔸3 r ) where
|
instance Transcendental r => Transcendental ( D2𝔸3 r ) where
|
||||||
pi = konst @Double @2 @( ℝ 3 ) pi
|
pi = konst @Double @2 @( ℝ 3 ) pi
|
||||||
sin ( D23 v ( T dx ) ( T dy ) ( T dz ) ( T ddx ) ( T dxdy ) ( T ddy ) ( T dxdz ) ( T dydz ) ( T ddz ) )
|
sin df =
|
||||||
= let !s = sin v
|
let
|
||||||
!c = cos v
|
fromInt = fromInteger @r
|
||||||
in D23 s
|
add = (+) @r
|
||||||
( T $ c * dx ) ( T $ c * dy ) ( T $ c * dz )
|
times = (*) @r
|
||||||
( T $ 2 * c * ddx - s * ( dx ^ 2 ) )
|
pow = (^) @r
|
||||||
( T $ 2 * c * dxdy - 2 * s * dx * dy )
|
dg = d2sin @r ( _D23_v df )
|
||||||
( T $ 2 * c * ddy - s * ( dy ^ 2 ) )
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
( T $ 2 * c * dxdz - 2 * s * dx * dz )
|
[|| df ||] [|| dg ||] )
|
||||||
( T $ 2 * c * dydz - 2 * s * dy * dz )
|
cos df =
|
||||||
( T $ 2 * c * ddz - s * ( dz ^ 2 ) )
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
cos ( D23 v ( T dx ) ( T dy ) ( T dz ) ( T ddx ) ( T dxdy ) ( T ddy ) ( T dxdz ) ( T dydz ) ( T ddz ) )
|
add = (+) @r
|
||||||
= let !s = sin v
|
times = (*) @r
|
||||||
!c = cos v
|
pow = (^) @r
|
||||||
in D23 c
|
dg = d2cos @r ( _D23_v df )
|
||||||
( T $ -s * dx ) ( T $ -s * dy ) ( T $ -s * dz )
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
( T $ -2 * s * ddx - c * ( dx ^ 2 ) )
|
[|| df ||] [|| dg ||] )
|
||||||
( T $ -2 * s * dxdy - 2 * c * dx * dy )
|
|
||||||
( T $ -2 * s * ddy - c * ( dy ^ 2 ) )
|
|
||||||
( T $ -2 * s * dxdz - 2 * c * dx * dz )
|
|
||||||
( T $ -2 * s * dydz - 2 * c * dy * dz )
|
|
||||||
( T $ -2 * s * ddz - c * ( dz ^ 2 ) )
|
|
||||||
|
|
||||||
instance Transcendental r => Transcendental ( D2𝔸4 r ) where
|
instance Transcendental r => Transcendental ( D2𝔸4 r ) where
|
||||||
|
pi = konst @Double @2 @( ℝ 4 ) pi
|
||||||
|
sin df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d2sin @r ( _D24_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
cos df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d2cos @r ( _D24_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
instance Transcendental r => Transcendental ( D3𝔸1 r ) where
|
instance Transcendental r => Transcendental ( D3𝔸1 r ) where
|
||||||
|
pi = konst @Double @3 @( ℝ 1 ) pi
|
||||||
|
sin df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d3sin @r ( _D31_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
cos df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d3cos @r ( _D31_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
instance Transcendental r => Transcendental ( D3𝔸2 r ) where
|
instance Transcendental r => Transcendental ( D3𝔸2 r ) where
|
||||||
|
pi = konst @Double @3 @( ℝ 2 ) pi
|
||||||
|
sin df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d3sin @r ( _D32_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
cos df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d3cos @r ( _D32_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
instance Transcendental r => Transcendental ( D3𝔸3 r ) where
|
instance Transcendental r => Transcendental ( D3𝔸3 r ) where
|
||||||
|
pi = konst @Double @3 @( ℝ 3 ) pi
|
||||||
|
sin df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d3sin @r ( _D33_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
cos df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d3cos @r ( _D33_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
instance Transcendental r => Transcendental ( D3𝔸4 r ) where
|
instance Transcendental r => Transcendental ( D3𝔸4 r ) where
|
||||||
|
pi = konst @Double @3 @( ℝ 4 ) pi
|
||||||
|
sin df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d3sin @r ( _D34_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
cos df =
|
||||||
|
let
|
||||||
|
fromInt = fromInteger @r
|
||||||
|
add = (+) @r
|
||||||
|
times = (*) @r
|
||||||
|
pow = (^) @r
|
||||||
|
dg = d3cos @r ( _D34_v df )
|
||||||
|
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
|
||||||
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- HasChainRule instances.
|
-- HasChainRule instances.
|
||||||
|
@ -513,7 +674,7 @@ instance HasChainRule Double 2 ( ℝ 1 ) where
|
||||||
let !o = origin @Double @( T w )
|
let !o = origin @Double @( T w )
|
||||||
!p = (^+^) @Double @( T w )
|
!p = (^+^) @Double @( T w )
|
||||||
!s = (^*) @Double @( T w )
|
!s = (^*) @Double @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -548,7 +709,7 @@ instance HasChainRule Double 3 ( ℝ 1 ) where
|
||||||
let !o = origin @Double @( T w )
|
let !o = origin @Double @( T w )
|
||||||
!p = (^+^) @Double @( T w )
|
!p = (^+^) @Double @( T w )
|
||||||
!s = (^*) @Double @( T w )
|
!s = (^*) @Double @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -583,7 +744,7 @@ instance HasChainRule Double 2 ( ℝ 2 ) where
|
||||||
let !o = origin @Double @( T w )
|
let !o = origin @Double @( T w )
|
||||||
!p = (^+^) @Double @( T w )
|
!p = (^+^) @Double @( T w )
|
||||||
!s = (^*) @Double @( T w )
|
!s = (^*) @Double @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -618,7 +779,7 @@ instance HasChainRule Double 3 ( ℝ 2 ) where
|
||||||
let !o = origin @Double @( T w )
|
let !o = origin @Double @( T w )
|
||||||
!p = (^+^) @Double @( T w )
|
!p = (^+^) @Double @( T w )
|
||||||
!s = (^*) @Double @( T w )
|
!s = (^*) @Double @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -653,7 +814,7 @@ instance HasChainRule Double 2 ( ℝ 3 ) where
|
||||||
let !o = origin @Double @( T w )
|
let !o = origin @Double @( T w )
|
||||||
!p = (^+^) @Double @( T w )
|
!p = (^+^) @Double @( T w )
|
||||||
!s = (^*) @Double @( T w )
|
!s = (^*) @Double @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -688,7 +849,7 @@ instance HasChainRule Double 3 ( ℝ 3 ) where
|
||||||
let !o = origin @Double @( T w )
|
let !o = origin @Double @( T w )
|
||||||
!p = (^+^) @Double @( T w )
|
!p = (^+^) @Double @( T w )
|
||||||
!s = (^*) @Double @( T w )
|
!s = (^*) @Double @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -723,7 +884,7 @@ instance HasChainRule Double 2 ( ℝ 4 ) where
|
||||||
let !o = origin @Double @( T w )
|
let !o = origin @Double @( T w )
|
||||||
!p = (^+^) @Double @( T w )
|
!p = (^+^) @Double @( T w )
|
||||||
!s = (^*) @Double @( T w )
|
!s = (^*) @Double @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -758,6 +919,6 @@ instance HasChainRule Double 3 ( ℝ 4 ) where
|
||||||
let !o = origin @Double @( T w )
|
let !o = origin @Double @( T w )
|
||||||
!p = (^+^) @Double @( T w )
|
!p = (^+^) @Double @( T w )
|
||||||
!s = (^*) @Double @( T w )
|
!s = (^*) @Double @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
|
@ -12,7 +12,7 @@ module Math.Algebra.Dual.Internal
|
||||||
, D1𝔸3(..), D2𝔸3(..), D3𝔸3(..)
|
, D1𝔸3(..), D2𝔸3(..), D3𝔸3(..)
|
||||||
, D1𝔸4(..), D2𝔸4(..), D3𝔸4(..)
|
, D1𝔸4(..), D2𝔸4(..), D3𝔸4(..)
|
||||||
|
|
||||||
, chainRuleQ
|
, chainRule1NQ, chainRuleN1Q
|
||||||
) where
|
) where
|
||||||
|
|
||||||
-- base
|
-- base
|
||||||
|
@ -35,7 +35,9 @@ import Math.Linear
|
||||||
)
|
)
|
||||||
import Math.Monomial
|
import Math.Monomial
|
||||||
( Mon(..), MonomialBasis(..), Vars, Deg
|
( Mon(..), MonomialBasis(..), Vars, Deg
|
||||||
, mons, faà, multiSubsetsSum, zeroMonomial
|
, mons, zeroMonomial, isZeroMonomial, totalDegree
|
||||||
|
, multiSubsetSumFaà, multiSubsetsSum
|
||||||
|
, partitionFaà, partitions
|
||||||
)
|
)
|
||||||
import Math.Ring
|
import Math.Ring
|
||||||
( Ring )
|
( Ring )
|
||||||
|
@ -179,8 +181,10 @@ data D3𝔸4 v =
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
-- | The chain rule, to be spliced in using Template Haskell.
|
-- | The chain rule for a composition \( \mathbb{R}^1 \to \mathbb{R}^n \to W \)
|
||||||
chainRuleQ :: forall dr1 dv v r w d
|
--
|
||||||
|
-- (To be spliced in using Template Haskell.)
|
||||||
|
chainRule1NQ :: forall dr1 dv v r w d
|
||||||
. ( Ring r, RepresentableQ r v
|
. ( Ring r, RepresentableQ r v
|
||||||
, MonomialBasis dr1, Vars dr1 ~ 1
|
, MonomialBasis dr1, Vars dr1 ~ 1
|
||||||
, MonomialBasis dv , Vars dv ~ RepDim v
|
, MonomialBasis dv , Vars dv ~ RepDim v
|
||||||
|
@ -193,7 +197,7 @@ chainRuleQ :: forall dr1 dv v r w d
|
||||||
-> CodeQ ( dr1 v )
|
-> CodeQ ( dr1 v )
|
||||||
-> CodeQ ( dv w )
|
-> CodeQ ( dv w )
|
||||||
-> CodeQ ( dr1 w )
|
-> CodeQ ( dr1 w )
|
||||||
chainRuleQ zero_w sum_w scale_w df dg =
|
chainRule1NQ zero_w sum_w scale_w df dg =
|
||||||
monTabulate @dr1 \ ( Mon ( k `VS` _ ) ) ->
|
monTabulate @dr1 \ ( Mon ( k `VS` _ ) ) ->
|
||||||
case k of
|
case k of
|
||||||
-- Set the value of the composition separately,
|
-- Set the value of the composition separately,
|
||||||
|
@ -203,7 +207,7 @@ chainRuleQ zero_w sum_w scale_w df dg =
|
||||||
[|| unT $ $$( foldQ sum_w zero_w
|
[|| unT $ $$( foldQ sum_w zero_w
|
||||||
[ [|| $$scale_w ( T $$( monIndex @dv dg m_g ) )
|
[ [|| $$scale_w ( T $$( monIndex @dv dg m_g ) )
|
||||||
$$( foldQ [|| (Ring.+) ||] [|| Ring.fromInteger ( 0 :: Integer ) ||]
|
$$( foldQ [|| (Ring.+) ||] [|| Ring.fromInteger ( 0 :: Integer ) ||]
|
||||||
[ [|| Ring.fromInteger $$( liftTyped $ fromIntegral $ faà k is ) Ring.*
|
[ [|| Ring.fromInteger $$( liftTyped $ fromIntegral $ multiSubsetSumFaà k is ) Ring.*
|
||||||
$$( foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||]
|
$$( foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||]
|
||||||
[ foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||]
|
[ foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||]
|
||||||
[ ( indexQ @r @v ( monIndex @dr1 df ( Mon $ f_deg `VS` VZ ) ) v_index )
|
[ ( indexQ @r @v ( monIndex @dr1 df ( Mon $ f_deg `VS` VZ ) ) v_index )
|
||||||
|
@ -223,6 +227,45 @@ chainRuleQ zero_w sum_w scale_w df dg =
|
||||||
]
|
]
|
||||||
) ||]
|
) ||]
|
||||||
|
|
||||||
|
-- | The chain rule for a composition \( \mathbb{R}^n \to \mathbb{R} \to \mathbb{R} \)
|
||||||
|
--
|
||||||
|
-- (To be spliced in using Template Haskell.)
|
||||||
|
chainRuleN1Q :: forall du dr1 r
|
||||||
|
. ( MonomialBasis du
|
||||||
|
, MonomialBasis dr1, Vars dr1 ~ 1
|
||||||
|
)
|
||||||
|
=> CodeQ ( Integer -> r ) -- ^ fromInteger (circumvent TH constraint issue)
|
||||||
|
-> CodeQ ( r -> r -> r ) -- ^ (+) (circumvent TH constraint issue)
|
||||||
|
-> CodeQ ( r -> r -> r ) -- ^ (*) (circumvent TH constraint issue)
|
||||||
|
-> CodeQ ( r -> Word -> r ) -- ^ (^) (circumvent TH constraint issue)
|
||||||
|
-> CodeQ ( du r )
|
||||||
|
-> CodeQ ( dr1 r )
|
||||||
|
-> CodeQ ( du r )
|
||||||
|
chainRuleN1Q fromInt add times pow df dg =
|
||||||
|
monTabulate @du \ mon ->
|
||||||
|
if
|
||||||
|
| isZeroMonomial mon
|
||||||
|
-- Set the value of the composition separately,
|
||||||
|
-- as that isn't handled by the Faà di Bruno formula.
|
||||||
|
-> monIndex @dr1 dg zeroMonomial
|
||||||
|
| otherwise
|
||||||
|
-> foldQ add [|| $$fromInt ( 0 :: Integer ) ||]
|
||||||
|
[ [|| $$times $$( monIndex @dr1 dg ( Mon ( k `VS` VZ ) ) )
|
||||||
|
$$( foldQ add [|| $$fromInt ( 0 :: Integer ) ||]
|
||||||
|
[ [|| $$times ( $$fromInt $$( liftTyped $ fromIntegral $ partitionFaà mon is ) )
|
||||||
|
$$( foldQ times [|| $$fromInt ( 1 :: Integer ) ||]
|
||||||
|
[ [|| $$pow $$( monIndex @du df f_mon ) p ||]
|
||||||
|
| ( f_mon, p ) <- is
|
||||||
|
]
|
||||||
|
) ||]
|
||||||
|
| is <- mss
|
||||||
|
]
|
||||||
|
)
|
||||||
|
||]
|
||||||
|
| k <- [ 1 .. totalDegree mon ]
|
||||||
|
, let mss = partitions k mon
|
||||||
|
]
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- MonomialBasis instances follow (nothing else).
|
-- MonomialBasis instances follow (nothing else).
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ import Data.Group.Generics
|
||||||
-- splines
|
-- splines
|
||||||
import Math.Algebra.Dual
|
import Math.Algebra.Dual
|
||||||
import Math.Algebra.Dual.Internal
|
import Math.Algebra.Dual.Internal
|
||||||
( chainRuleQ )
|
( chainRule1NQ )
|
||||||
import Math.Interval.Internal
|
import Math.Interval.Internal
|
||||||
( 𝕀(..) )
|
( 𝕀(..) )
|
||||||
import Math.Linear
|
import Math.Linear
|
||||||
|
@ -184,7 +184,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀ℝ 1 ) where
|
||||||
let !o = origin @( 𝕀 Double ) @( T w )
|
let !o = origin @( 𝕀 Double ) @( T w )
|
||||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -219,7 +219,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀ℝ 1 ) where
|
||||||
let !o = origin @( 𝕀 Double ) @( T w )
|
let !o = origin @( 𝕀 Double ) @( T w )
|
||||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -254,7 +254,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀ℝ 2 ) where
|
||||||
let !o = origin @( 𝕀 Double ) @( T w )
|
let !o = origin @( 𝕀 Double ) @( T w )
|
||||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -289,7 +289,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀ℝ 2 ) where
|
||||||
let !o = origin @( 𝕀 Double ) @( T w )
|
let !o = origin @( 𝕀 Double ) @( T w )
|
||||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -324,7 +324,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀ℝ 3 ) where
|
||||||
let !o = origin @( 𝕀 Double ) @( T w )
|
let !o = origin @( 𝕀 Double ) @( T w )
|
||||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -359,7 +359,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀ℝ 3 ) where
|
||||||
let !o = origin @( 𝕀 Double ) @( T w )
|
let !o = origin @( 𝕀 Double ) @( T w )
|
||||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -394,7 +394,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀ℝ 4 ) where
|
||||||
let !o = origin @( 𝕀 Double ) @( T w )
|
let !o = origin @( 𝕀 Double ) @( T w )
|
||||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
||||||
|
@ -429,6 +429,6 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀ℝ 4 ) where
|
||||||
let !o = origin @( 𝕀 Double ) @( T w )
|
let !o = origin @( 𝕀 Double ) @( T w )
|
||||||
!p = (^+^) @( 𝕀 Double ) @( T w )
|
!p = (^+^) @( 𝕀 Double ) @( T w )
|
||||||
!s = (^*) @( 𝕀 Double ) @( T w )
|
!s = (^*) @( 𝕀 Double ) @( T w )
|
||||||
in $$( chainRuleQ
|
in $$( chainRule1NQ
|
||||||
[|| o ||] [|| p ||] [|| s ||]
|
[|| o ||] [|| p ||] [|| s ||]
|
||||||
[|| df ||] [|| dg ||] )
|
[|| df ||] [|| dg ||] )
|
||||||
|
|
|
@ -28,6 +28,8 @@ import GHC.Show
|
||||||
( showSpace )
|
( showSpace )
|
||||||
import GHC.TypeNats
|
import GHC.TypeNats
|
||||||
( Nat, type (+) )
|
( Nat, type (+) )
|
||||||
|
import Unsafe.Coerce
|
||||||
|
( unsafeCoerce )
|
||||||
|
|
||||||
-- acts
|
-- acts
|
||||||
import Data.Act
|
import Data.Act
|
||||||
|
@ -118,7 +120,7 @@ infixr 5 `VS`
|
||||||
type Vec :: Nat -> Type -> Type
|
type Vec :: Nat -> Type -> Type
|
||||||
data Vec n a where
|
data Vec n a where
|
||||||
VZ :: Vec 0 a
|
VZ :: Vec 0 a
|
||||||
VS :: a -> Vec n a -> Vec ( 1 + n ) a
|
VS :: forall n a. a -> Vec n a -> Vec ( 1 + n ) a
|
||||||
-- can't be strict, otherwise we can't conveniently
|
-- can't be strict, otherwise we can't conveniently
|
||||||
-- unsafeCoerce from lists
|
-- unsafeCoerce from lists
|
||||||
|
|
||||||
|
@ -128,6 +130,13 @@ deriving stock instance Functor ( Vec n )
|
||||||
deriving stock instance Foldable ( Vec n )
|
deriving stock instance Foldable ( Vec n )
|
||||||
deriving stock instance Traversable ( Vec n )
|
deriving stock instance Traversable ( Vec n )
|
||||||
|
|
||||||
|
instance Eq a => Eq ( Vec n a ) where
|
||||||
|
(==) = unsafeCoerce $ (==) @[a]
|
||||||
|
|
||||||
|
instance Ord a => Ord ( Vec n a ) where
|
||||||
|
compare = unsafeCoerce $ compare @[a]
|
||||||
|
(<=) = unsafeCoerce $ (<=) @[a]
|
||||||
|
|
||||||
infixl 9 !
|
infixl 9 !
|
||||||
(!) :: forall l a. Vec l a -> Fin l -> a
|
(!) :: forall l a. Vec l a -> Fin l -> a
|
||||||
VS a _ ! Fin 1 = a
|
VS a _ ! Fin 1 = a
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||||
{-# LANGUAGE QuantifiedConstraints #-}
|
{-# LANGUAGE QuantifiedConstraints #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE ParallelListComp #-}
|
||||||
{-# LANGUAGE TemplateHaskell #-}
|
{-# LANGUAGE TemplateHaskell #-}
|
||||||
{-# LANGUAGE UndecidableInstances #-}
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
|
|
||||||
|
@ -9,11 +10,13 @@
|
||||||
module Math.Monomial
|
module Math.Monomial
|
||||||
( Mon(..)
|
( Mon(..)
|
||||||
, MonomialBasis(..), Deg, Vars
|
, MonomialBasis(..), Deg, Vars
|
||||||
, zeroMonomial, isZeroMonomial, isLinear
|
, zeroMonomial, isZeroMonomial
|
||||||
|
, totalDegree, isLinear
|
||||||
|
|
||||||
, split, mons
|
, split, mons
|
||||||
|
|
||||||
, faà, multiSubsetsSum, multiSubsetSum
|
, multiSubsetSumFaà, multiSubsetsSum, multiSubsetSum
|
||||||
|
, partitionFaà, partitions
|
||||||
|
|
||||||
, prodRuleQ
|
, prodRuleQ
|
||||||
|
|
||||||
|
@ -27,7 +30,7 @@ import Data.Kind
|
||||||
import GHC.Exts
|
import GHC.Exts
|
||||||
( proxy# )
|
( proxy# )
|
||||||
import GHC.TypeNats
|
import GHC.TypeNats
|
||||||
( KnownNat, Nat, natVal' )
|
-- ( KnownNat, Nat, natVal' )
|
||||||
import Unsafe.Coerce
|
import Unsafe.Coerce
|
||||||
( unsafeCoerce )
|
( unsafeCoerce )
|
||||||
|
|
||||||
|
@ -46,7 +49,7 @@ import TH.Utils
|
||||||
-- | @Mon k n@ is the set of monomials in @n@ variables of degree less than or equal to @k@.
|
-- | @Mon k n@ is the set of monomials in @n@ variables of degree less than or equal to @k@.
|
||||||
type Mon :: Nat -> Nat -> Type
|
type Mon :: Nat -> Nat -> Type
|
||||||
newtype Mon k n = Mon { monDegs :: Vec n Word } -- sum <= k
|
newtype Mon k n = Mon { monDegs :: Vec n Word } -- sum <= k
|
||||||
deriving stock Show
|
deriving stock ( Show, Eq, Ord )
|
||||||
|
|
||||||
type Deg :: ( Type -> Type ) -> Nat
|
type Deg :: ( Type -> Type ) -> Nat
|
||||||
type Vars :: ( Type -> Type ) -> Nat
|
type Vars :: ( Type -> Type ) -> Nat
|
||||||
|
@ -65,6 +68,10 @@ isZeroMonomial ( Mon ( i `VS` is ) )
|
||||||
| otherwise
|
| otherwise
|
||||||
= False
|
= False
|
||||||
|
|
||||||
|
totalDegree :: Mon k n -> Word
|
||||||
|
totalDegree ( Mon VZ ) = 0
|
||||||
|
totalDegree ( Mon ( i `VS` is ) ) = i + totalDegree ( Mon is )
|
||||||
|
|
||||||
isLinear :: Mon k n -> Maybe ( Fin n )
|
isLinear :: Mon k n -> Maybe ( Fin n )
|
||||||
isLinear = fmap Fin . go 1
|
isLinear = fmap Fin . go 1
|
||||||
where
|
where
|
||||||
|
@ -99,12 +106,32 @@ mons' _ 0 = [ [] ]
|
||||||
mons' 0 n = [ replicate ( fromIntegral n ) 0 ]
|
mons' 0 n = [ replicate ( fromIntegral n ) 0 ]
|
||||||
mons' k n = [ i : is | i <- reverse [ 0 .. k ], is <- mons' ( k - i ) ( n - 1 ) ]
|
mons' k n = [ i : is | i <- reverse [ 0 .. k ], is <- mons' ( k - i ) ( n - 1 ) ]
|
||||||
|
|
||||||
|
subs :: Mon k n -> [ ( Mon k n, Word ) ]
|
||||||
|
subs ( Mon VZ ) = [ ( Mon VZ, maxBound ) ]
|
||||||
|
subs ( Mon ( i `VS` is ) )
|
||||||
|
= [ ( Mon ( j `VS` js )
|
||||||
|
, if j == 0 then mult else min ( i `quot` j ) mult )
|
||||||
|
| j <- [ 0 .. i ]
|
||||||
|
, ( Mon js, mult ) <- subs ( Mon is )
|
||||||
|
]
|
||||||
|
|
||||||
word :: forall n. KnownNat n => Word
|
word :: forall n. KnownNat n => Word
|
||||||
word = fromIntegral $ natVal' @n proxy#
|
word = fromIntegral $ natVal' @n proxy#
|
||||||
|
|
||||||
|
-- | The factorial function \( n! = n \cdot (n-1) \cdot `ldots` \cdot 2 `cdot` 1 \).
|
||||||
|
factorial :: Word -> Word
|
||||||
|
factorial i = product [ 1 .. i ]
|
||||||
|
|
||||||
|
vecFactorial :: Vec n Word -> Word
|
||||||
|
vecFactorial VZ = 1
|
||||||
|
vecFactorial ( i `VS` is ) = factorial i * vecFactorial is
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Computations for the chain rule R^1 -> R^n -> R^m
|
||||||
|
|
||||||
-- | Faà di Bruno coefficient (naive implementation).
|
-- | Faà di Bruno coefficient (naive implementation).
|
||||||
faà :: Word -> Vec n [ ( Word, Word ) ] -> Word
|
multiSubsetSumFaà :: Word -> Vec n [ ( Word, Word ) ] -> Word
|
||||||
faà k multisubsets =
|
multiSubsetSumFaà k multisubsets =
|
||||||
factorial k `div`
|
factorial k `div`
|
||||||
product [ factorial p * ( factorial i ) ^ p
|
product [ factorial p * ( factorial i ) ^ p
|
||||||
| multisubset <- toList multisubsets
|
| multisubset <- toList multisubsets
|
||||||
|
@ -146,9 +173,49 @@ multiSubsetsSum is = goMSS
|
||||||
[] -> 0
|
[] -> 0
|
||||||
_ -> max 0 $ minimum is
|
_ -> max 0 $ minimum is
|
||||||
|
|
||||||
-- | The factorial function \( n! = n \cdot (n-1) \cdot `ldots` \cdot 2 `cdot` 1 \).
|
--------------------------------------------------------------------------------
|
||||||
factorial :: Word -> Word
|
-- Computations for the chain rule R^n -> R^1 -> R^1
|
||||||
factorial i = product [ 1 .. i ]
|
|
||||||
|
partitionFaà :: Mon k n -> [ ( Mon k n, Word ) ] -> Word
|
||||||
|
partitionFaà ( Mon mon ) multiIndexes =
|
||||||
|
vecFactorial mon `div`
|
||||||
|
product [ factorial p * ( vecFactorial i ) ^ p
|
||||||
|
| ( Mon i, p ) <- toList multiIndexes ]
|
||||||
|
|
||||||
|
-- | @partitions p mon@ computes all partitions of the monomial @mon@ into
|
||||||
|
-- @p@ (non-zero) parts, allowing repetition.
|
||||||
|
partitions :: forall k n
|
||||||
|
. Word -- ^ number of parts
|
||||||
|
-> Mon k n -- ^ monomial to sum to
|
||||||
|
-> [ [ ( Mon k n, Word ) ] ]
|
||||||
|
partitions n_parts0 mon0 = go n_parts0 mon0 mon0
|
||||||
|
where
|
||||||
|
go :: Word -> Mon k n -> Mon k n -> [ [ ( Mon k n , Word ) ] ]
|
||||||
|
go 0 _ mon
|
||||||
|
| isZeroMonomial mon
|
||||||
|
= [ [] ]
|
||||||
|
| otherwise
|
||||||
|
= []
|
||||||
|
go 1 maxSub mon
|
||||||
|
| isZeroMonomial mon
|
||||||
|
|| mon >= maxSub
|
||||||
|
= []
|
||||||
|
| otherwise
|
||||||
|
= [ [ ( mon, 1 ) ] ]
|
||||||
|
go n_parts maxSub mon
|
||||||
|
= [ ( part, l ) : parts
|
||||||
|
| ( part, l_max ) <- subs mon
|
||||||
|
, not ( isZeroMonomial part ) -- parts must be non-empty
|
||||||
|
, part < maxSub -- use a total ordering on monomials to ensure uniqueness
|
||||||
|
, l <- [ 1 .. min n_parts l_max ]
|
||||||
|
, parts <- go ( n_parts - l ) part ( subPart l mon part )
|
||||||
|
]
|
||||||
|
|
||||||
|
subPart :: Word -> Mon k n -> Mon k n -> Mon k n
|
||||||
|
subPart = unsafeCoerce subPart'
|
||||||
|
|
||||||
|
subPart' :: Word -> [ Word ] -> [ Word ] -> [ Word ]
|
||||||
|
subPart' m = zipWith ( \ i j -> i - m * j )
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue