add chain rule R^n -> R -> R

This commit is contained in:
sheaf 2023-01-22 04:51:23 +01:00
parent 236055b4ca
commit bdcac18ab9
5 changed files with 350 additions and 70 deletions

View file

@ -411,9 +411,8 @@ instance Ring r => Ring ( D3𝔸4 r ) where
--------------------------------------------------------------------------------
-- Field & transcendental instances
-- TODO!!
deriving newtype instance Field r => Field ( D𝔸0 r )
--instance Field r => Field ( D1𝔸1 r ) where
--instance Field r => Field ( D1𝔸2 r ) where
--instance Field r => Field ( D1𝔸3 r ) where
@ -426,7 +425,37 @@ instance Field r => Field ( D3𝔸1 r ) where
instance Field r => Field ( D3𝔸2 r ) where
instance Field r => Field ( D3𝔸3 r ) where
instance Field r => Field ( D3𝔸4 r ) where
-- TODO
d1sin, d1cos :: Transcendental r => r -> D1𝔸1 r
d1sin x =
let !s = sin x
!c = cos x
in D11 s ( T c )
d1cos x =
let !s = sin x
!c = cos x
in D11 c ( T -s )
d2sin, d2cos :: Transcendental r => r -> D2𝔸1 r
d2sin x =
let !s = sin x
!c = cos x
in D21 s ( T c ) ( T -s )
d2cos x =
let !s = sin x
!c = cos x
in D21 c ( T -s ) ( T -c )
d3sin, d3cos :: Transcendental r => r -> D3𝔸1 r
d3sin x =
let !s = sin x
!c = cos x
in D31 s ( T c ) ( T -s ) ( T -c )
d3cos x =
let !s = sin x
!c = cos x
in D31 c ( T -s ) ( T -c ) ( T s )
deriving newtype instance Transcendental r => Transcendental ( D𝔸0 r )
--instance Transcendental r => Transcendental ( D1𝔸1 r ) where
@ -434,38 +463,170 @@ deriving newtype instance Transcendental r => Transcendental ( D𝔸0 r )
--instance Transcendental r => Transcendental ( D1𝔸3 r ) where
--instance Transcendental r => Transcendental ( D1𝔸4 r ) where
instance Transcendental r => Transcendental ( D2𝔸1 r ) where
pi = konst @Double @2 @( 1 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2sin @r ( _D21_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2cos @r ( _D21_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D2𝔸2 r ) where
pi = konst @Double @2 @( 2 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2sin @r ( _D22_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2cos @r ( _D22_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D2𝔸3 r ) where
pi = konst @Double @2 @( 3 ) pi
sin ( D23 v ( T dx ) ( T dy ) ( T dz ) ( T ddx ) ( T dxdy ) ( T ddy ) ( T dxdz ) ( T dydz ) ( T ddz ) )
= let !s = sin v
!c = cos v
in D23 s
( T $ c * dx ) ( T $ c * dy ) ( T $ c * dz )
( T $ 2 * c * ddx - s * ( dx ^ 2 ) )
( T $ 2 * c * dxdy - 2 * s * dx * dy )
( T $ 2 * c * ddy - s * ( dy ^ 2 ) )
( T $ 2 * c * dxdz - 2 * s * dx * dz )
( T $ 2 * c * dydz - 2 * s * dy * dz )
( T $ 2 * c * ddz - s * ( dz ^ 2 ) )
cos ( D23 v ( T dx ) ( T dy ) ( T dz ) ( T ddx ) ( T dxdy ) ( T ddy ) ( T dxdz ) ( T dydz ) ( T ddz ) )
= let !s = sin v
!c = cos v
in D23 c
( T $ -s * dx ) ( T $ -s * dy ) ( T $ -s * dz )
( T $ -2 * s * ddx - c * ( dx ^ 2 ) )
( T $ -2 * s * dxdy - 2 * c * dx * dy )
( T $ -2 * s * ddy - c * ( dy ^ 2 ) )
( T $ -2 * s * dxdz - 2 * c * dx * dz )
( T $ -2 * s * dydz - 2 * c * dy * dz )
( T $ -2 * s * ddz - c * ( dz ^ 2 ) )
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2sin @r ( _D23_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2cos @r ( _D23_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D2𝔸4 r ) where
pi = konst @Double @2 @( 4 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2sin @r ( _D24_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d2cos @r ( _D24_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D3𝔸1 r ) where
pi = konst @Double @3 @( 1 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3sin @r ( _D31_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3cos @r ( _D31_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D3𝔸2 r ) where
pi = konst @Double @3 @( 2 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3sin @r ( _D32_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3cos @r ( _D32_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D3𝔸3 r ) where
pi = konst @Double @3 @( 3 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3sin @r ( _D33_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3cos @r ( _D33_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
instance Transcendental r => Transcendental ( D3𝔸4 r ) where
pi = konst @Double @3 @( 4 ) pi
sin df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3sin @r ( _D34_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
cos df =
let
fromInt = fromInteger @r
add = (+) @r
times = (*) @r
pow = (^) @r
dg = d3cos @r ( _D34_v df )
in $$( chainRuleN1Q [|| fromInt ||] [|| add ||] [|| times ||] [|| pow ||]
[|| df ||] [|| dg ||] )
--------------------------------------------------------------------------------
-- HasChainRule instances.
@ -513,7 +674,7 @@ instance HasChainRule Double 2 ( 1 ) where
let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -548,7 +709,7 @@ instance HasChainRule Double 3 ( 1 ) where
let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -583,7 +744,7 @@ instance HasChainRule Double 2 ( 2 ) where
let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -618,7 +779,7 @@ instance HasChainRule Double 3 ( 2 ) where
let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -653,7 +814,7 @@ instance HasChainRule Double 2 ( 3 ) where
let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -688,7 +849,7 @@ instance HasChainRule Double 3 ( 3 ) where
let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -723,7 +884,7 @@ instance HasChainRule Double 2 ( 4 ) where
let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -758,6 +919,6 @@ instance HasChainRule Double 3 ( 4 ) where
let !o = origin @Double @( T w )
!p = (^+^) @Double @( T w )
!s = (^*) @Double @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )

View file

@ -12,7 +12,7 @@ module Math.Algebra.Dual.Internal
, D1𝔸3(..), D2𝔸3(..), D3𝔸3(..)
, D1𝔸4(..), D2𝔸4(..), D3𝔸4(..)
, chainRuleQ
, chainRule1NQ, chainRuleN1Q
) where
-- base
@ -35,7 +35,9 @@ import Math.Linear
)
import Math.Monomial
( Mon(..), MonomialBasis(..), Vars, Deg
, mons, faà, multiSubsetsSum, zeroMonomial
, mons, zeroMonomial, isZeroMonomial, totalDegree
, multiSubsetSumFaà, multiSubsetsSum
, partitionFaà, partitions
)
import Math.Ring
( Ring )
@ -179,21 +181,23 @@ data D3𝔸4 v =
--------------------------------------------------------------------------------
-- | The chain rule, to be spliced in using Template Haskell.
chainRuleQ :: forall dr1 dv v r w d
. ( Ring r, RepresentableQ r v
, MonomialBasis dr1, Vars dr1 ~ 1
, MonomialBasis dv , Vars dv ~ RepDim v
, Deg dr1 ~ Deg dv
, d ~ Vars dv, KnownNat d
)
=> CodeQ ( T w ) -- Module r ( T w )
-> CodeQ ( T w -> T w -> T w ) --
-> CodeQ ( T w -> r -> T w ) -- (circumvent TH constraint issue)
-> CodeQ ( dr1 v )
-> CodeQ ( dv w )
-> CodeQ ( dr1 w )
chainRuleQ zero_w sum_w scale_w df dg =
-- | The chain rule for a composition \( \mathbb{R}^1 \to \mathbb{R}^n \to W \)
--
-- (To be spliced in using Template Haskell.)
chainRule1NQ :: forall dr1 dv v r w d
. ( Ring r, RepresentableQ r v
, MonomialBasis dr1, Vars dr1 ~ 1
, MonomialBasis dv , Vars dv ~ RepDim v
, Deg dr1 ~ Deg dv
, d ~ Vars dv, KnownNat d
)
=> CodeQ ( T w ) -- Module r ( T w )
-> CodeQ ( T w -> T w -> T w ) --
-> CodeQ ( T w -> r -> T w ) -- (circumvent TH constraint issue)
-> CodeQ ( dr1 v )
-> CodeQ ( dv w )
-> CodeQ ( dr1 w )
chainRule1NQ zero_w sum_w scale_w df dg =
monTabulate @dr1 \ ( Mon ( k `VS` _ ) ) ->
case k of
-- Set the value of the composition separately,
@ -203,7 +207,7 @@ chainRuleQ zero_w sum_w scale_w df dg =
[|| unT $ $$( foldQ sum_w zero_w
[ [|| $$scale_w ( T $$( monIndex @dv dg m_g ) )
$$( foldQ [|| (Ring.+) ||] [|| Ring.fromInteger ( 0 :: Integer ) ||]
[ [|| Ring.fromInteger $$( liftTyped $ fromIntegral $ fk is ) Ring.*
[ [|| Ring.fromInteger $$( liftTyped $ fromIntegral $ multiSubsetSumFk is ) Ring.*
$$( foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||]
[ foldQ [|| (Ring.*) ||] [|| Ring.fromInteger ( 1 :: Integer ) ||]
[ ( indexQ @r @v ( monIndex @dr1 df ( Mon $ f_deg `VS` VZ ) ) v_index )
@ -223,6 +227,45 @@ chainRuleQ zero_w sum_w scale_w df dg =
]
) ||]
-- | The chain rule for a composition \( \mathbb{R}^n \to \mathbb{R} \to \mathbb{R} \)
--
-- (To be spliced in using Template Haskell.)
chainRuleN1Q :: forall du dr1 r
. ( MonomialBasis du
, MonomialBasis dr1, Vars dr1 ~ 1
)
=> CodeQ ( Integer -> r ) -- ^ fromInteger (circumvent TH constraint issue)
-> CodeQ ( r -> r -> r ) -- ^ (+) (circumvent TH constraint issue)
-> CodeQ ( r -> r -> r ) -- ^ (*) (circumvent TH constraint issue)
-> CodeQ ( r -> Word -> r ) -- ^ (^) (circumvent TH constraint issue)
-> CodeQ ( du r )
-> CodeQ ( dr1 r )
-> CodeQ ( du r )
chainRuleN1Q fromInt add times pow df dg =
monTabulate @du \ mon ->
if
| isZeroMonomial mon
-- Set the value of the composition separately,
-- as that isn't handled by the Faà di Bruno formula.
-> monIndex @dr1 dg zeroMonomial
| otherwise
-> foldQ add [|| $$fromInt ( 0 :: Integer ) ||]
[ [|| $$times $$( monIndex @dr1 dg ( Mon ( k `VS` VZ ) ) )
$$( foldQ add [|| $$fromInt ( 0 :: Integer ) ||]
[ [|| $$times ( $$fromInt $$( liftTyped $ fromIntegral $ partitionFaà mon is ) )
$$( foldQ times [|| $$fromInt ( 1 :: Integer ) ||]
[ [|| $$pow $$( monIndex @du df f_mon ) p ||]
| ( f_mon, p ) <- is
]
) ||]
| is <- mss
]
)
||]
| k <- [ 1 .. totalDegree mon ]
, let mss = partitions k mon
]
--------------------------------------------------------------------------------
-- MonomialBasis instances follow (nothing else).

View file

@ -35,7 +35,7 @@ import Data.Group.Generics
-- splines
import Math.Algebra.Dual
import Math.Algebra.Dual.Internal
( chainRuleQ )
( chainRule1NQ )
import Math.Interval.Internal
( 𝕀(..) )
import Math.Linear
@ -184,7 +184,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀 1 ) where
let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -219,7 +219,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀 1 ) where
let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -254,7 +254,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀 2 ) where
let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -289,7 +289,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀 2 ) where
let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -324,7 +324,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀 3 ) where
let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -359,7 +359,7 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀 3 ) where
let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -394,7 +394,7 @@ instance HasChainRule ( 𝕀 Double ) 2 ( 𝕀 4 ) where
let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )
@ -429,6 +429,6 @@ instance HasChainRule ( 𝕀 Double ) 3 ( 𝕀 4 ) where
let !o = origin @( 𝕀 Double ) @( T w )
!p = (^+^) @( 𝕀 Double ) @( T w )
!s = (^*) @( 𝕀 Double ) @( T w )
in $$( chainRuleQ
in $$( chainRule1NQ
[|| o ||] [|| p ||] [|| s ||]
[|| df ||] [|| dg ||] )

View file

@ -28,6 +28,8 @@ import GHC.Show
( showSpace )
import GHC.TypeNats
( Nat, type (+) )
import Unsafe.Coerce
( unsafeCoerce )
-- acts
import Data.Act
@ -118,7 +120,7 @@ infixr 5 `VS`
type Vec :: Nat -> Type -> Type
data Vec n a where
VZ :: Vec 0 a
VS :: a -> Vec n a -> Vec ( 1 + n ) a
VS :: forall n a. a -> Vec n a -> Vec ( 1 + n ) a
-- can't be strict, otherwise we can't conveniently
-- unsafeCoerce from lists
@ -128,6 +130,13 @@ deriving stock instance Functor ( Vec n )
deriving stock instance Foldable ( Vec n )
deriving stock instance Traversable ( Vec n )
instance Eq a => Eq ( Vec n a ) where
(==) = unsafeCoerce $ (==) @[a]
instance Ord a => Ord ( Vec n a ) where
compare = unsafeCoerce $ compare @[a]
(<=) = unsafeCoerce $ (<=) @[a]
infixl 9 !
(!) :: forall l a. Vec l a -> Fin l -> a
VS a _ ! Fin 1 = a

View file

@ -1,6 +1,7 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
@ -9,11 +10,13 @@
module Math.Monomial
( Mon(..)
, MonomialBasis(..), Deg, Vars
, zeroMonomial, isZeroMonomial, isLinear
, zeroMonomial, isZeroMonomial
, totalDegree, isLinear
, split, mons
, faà, multiSubsetsSum, multiSubsetSum
, multiSubsetSumFaà, multiSubsetsSum, multiSubsetSum
, partitionFaà, partitions
, prodRuleQ
@ -27,7 +30,7 @@ import Data.Kind
import GHC.Exts
( proxy# )
import GHC.TypeNats
( KnownNat, Nat, natVal' )
-- ( KnownNat, Nat, natVal' )
import Unsafe.Coerce
( unsafeCoerce )
@ -46,7 +49,7 @@ import TH.Utils
-- | @Mon k n@ is the set of monomials in @n@ variables of degree less than or equal to @k@.
type Mon :: Nat -> Nat -> Type
newtype Mon k n = Mon { monDegs :: Vec n Word } -- sum <= k
deriving stock Show
deriving stock ( Show, Eq, Ord )
type Deg :: ( Type -> Type ) -> Nat
type Vars :: ( Type -> Type ) -> Nat
@ -65,6 +68,10 @@ isZeroMonomial ( Mon ( i `VS` is ) )
| otherwise
= False
totalDegree :: Mon k n -> Word
totalDegree ( Mon VZ ) = 0
totalDegree ( Mon ( i `VS` is ) ) = i + totalDegree ( Mon is )
isLinear :: Mon k n -> Maybe ( Fin n )
isLinear = fmap Fin . go 1
where
@ -99,12 +106,32 @@ mons' _ 0 = [ [] ]
mons' 0 n = [ replicate ( fromIntegral n ) 0 ]
mons' k n = [ i : is | i <- reverse [ 0 .. k ], is <- mons' ( k - i ) ( n - 1 ) ]
subs :: Mon k n -> [ ( Mon k n, Word ) ]
subs ( Mon VZ ) = [ ( Mon VZ, maxBound ) ]
subs ( Mon ( i `VS` is ) )
= [ ( Mon ( j `VS` js )
, if j == 0 then mult else min ( i `quot` j ) mult )
| j <- [ 0 .. i ]
, ( Mon js, mult ) <- subs ( Mon is )
]
word :: forall n. KnownNat n => Word
word = fromIntegral $ natVal' @n proxy#
-- | The factorial function \( n! = n \cdot (n-1) \cdot `ldots` \cdot 2 `cdot` 1 \).
factorial :: Word -> Word
factorial i = product [ 1 .. i ]
vecFactorial :: Vec n Word -> Word
vecFactorial VZ = 1
vecFactorial ( i `VS` is ) = factorial i * vecFactorial is
--------------------------------------------------------------------------------
-- Computations for the chain rule R^1 -> R^n -> R^m
-- | Faà di Bruno coefficient (naive implementation).
faà :: Word -> Vec n [ ( Word, Word ) ] -> Word
faà k multisubsets =
multiSubsetSumF:: Word -> Vec n [ ( Word, Word ) ] -> Word
multiSubsetSumFk multisubsets =
factorial k `div`
product [ factorial p * ( factorial i ) ^ p
| multisubset <- toList multisubsets
@ -146,9 +173,49 @@ multiSubsetsSum is = goMSS
[] -> 0
_ -> max 0 $ minimum is
-- | The factorial function \( n! = n \cdot (n-1) \cdot `ldots` \cdot 2 `cdot` 1 \).
factorial :: Word -> Word
factorial i = product [ 1 .. i ]
--------------------------------------------------------------------------------
-- Computations for the chain rule R^n -> R^1 -> R^1
partitionFaà :: Mon k n -> [ ( Mon k n, Word ) ] -> Word
partitionFaà ( Mon mon ) multiIndexes =
vecFactorial mon `div`
product [ factorial p * ( vecFactorial i ) ^ p
| ( Mon i, p ) <- toList multiIndexes ]
-- | @partitions p mon@ computes all partitions of the monomial @mon@ into
-- @p@ (non-zero) parts, allowing repetition.
partitions :: forall k n
. Word -- ^ number of parts
-> Mon k n -- ^ monomial to sum to
-> [ [ ( Mon k n, Word ) ] ]
partitions n_parts0 mon0 = go n_parts0 mon0 mon0
where
go :: Word -> Mon k n -> Mon k n -> [ [ ( Mon k n , Word ) ] ]
go 0 _ mon
| isZeroMonomial mon
= [ [] ]
| otherwise
= []
go 1 maxSub mon
| isZeroMonomial mon
|| mon >= maxSub
= []
| otherwise
= [ [ ( mon, 1 ) ] ]
go n_parts maxSub mon
= [ ( part, l ) : parts
| ( part, l_max ) <- subs mon
, not ( isZeroMonomial part ) -- parts must be non-empty
, part < maxSub -- use a total ordering on monomials to ensure uniqueness
, l <- [ 1 .. min n_parts l_max ]
, parts <- go ( n_parts - l ) part ( subPart l mon part )
]
subPart :: Word -> Mon k n -> Mon k n -> Mon k n
subPart = unsafeCoerce subPart'
subPart' :: Word -> [ Word ] -> [ Word ] -> [ Word ]
subPart' m = zipWith ( \ i j -> i - m * j )
--------------------------------------------------------------------------------