32
32
import pytensor .tensor as pt
33
33
import scipy .sparse as sps
34
34
35
- from pytensor .compile import DeepCopyOp , Function , get_mode
35
+ from pytensor .compile import DeepCopyOp , Function , ProfileStats , get_mode
36
36
from pytensor .compile .sharedvalue import SharedVariable
37
37
from pytensor .graph .basic import Constant , Variable , ancestors , graph_inputs
38
38
from pytensor .tensor .random .op import RandomVariable
@@ -1657,7 +1657,15 @@ def compile_fn(
1657
1657
return PointFunc (fn )
1658
1658
return fn
1659
1659
1660
- def profile (self , outs , * , n = 1000 , point = None , profile = True , ** kwargs ):
1660
+ def profile (
1661
+ self ,
1662
+ outs ,
1663
+ * ,
1664
+ n = 1000 ,
1665
+ point = None ,
1666
+ profile = True ,
1667
+ ** compile_fn_kwargs ,
1668
+ ) -> ProfileStats :
1661
1669
"""Compile and profile a PyTensor function which returns ``outs`` and takes values of model vars as a dict as an argument.
1662
1670
1663
1671
Parameters
@@ -1668,16 +1676,22 @@ def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs):
1668
1676
point : Point
1669
1677
Point to pass to the function
1670
1678
profile : True or ProfileStats
1671
- args, kwargs
1672
- Compilation args
1679
+ compile_fn_kwargs
1680
+ Compilation kwargs for :func:`pymc.model.core.Model.compile_fn`
1673
1681
1674
1682
Returns
1675
1683
-------
1676
- ProfileStats
1684
+ pytensor.compile.profiling. ProfileStats
1677
1685
Use .summary() to print stats.
1678
1686
"""
1679
- kwargs .setdefault ("on_unused_input" , "ignore" )
1680
- f = self .compile_fn (outs , inputs = self .value_vars , point_fn = False , profile = profile , ** kwargs )
1687
+ compile_fn_kwargs .setdefault ("on_unused_input" , "ignore" )
1688
+ f = self .compile_fn (
1689
+ outs ,
1690
+ inputs = self .value_vars ,
1691
+ point_fn = False ,
1692
+ profile = profile ,
1693
+ ** compile_fn_kwargs ,
1694
+ )
1681
1695
if point is None :
1682
1696
point = self .initial_point ()
1683
1697
0 commit comments