Skip to content

Commit c8d83ca

Browse files
williambdeanricardoV94
authored andcommitted
profile docstring and type hint
1 parent 618634b commit c8d83ca

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

pymc/model/core.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import pytensor.tensor as pt
3333
import scipy.sparse as sps
3434

35-
from pytensor.compile import DeepCopyOp, Function, get_mode
35+
from pytensor.compile import DeepCopyOp, Function, ProfileStats, get_mode
3636
from pytensor.compile.sharedvalue import SharedVariable
3737
from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs
3838
from pytensor.tensor.random.op import RandomVariable
@@ -1657,7 +1657,15 @@ def compile_fn(
16571657
return PointFunc(fn)
16581658
return fn
16591659

1660-
def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs):
1660+
def profile(
1661+
self,
1662+
outs,
1663+
*,
1664+
n=1000,
1665+
point=None,
1666+
profile=True,
1667+
**compile_fn_kwargs,
1668+
) -> ProfileStats:
16611669
"""Compile and profile a PyTensor function which returns ``outs`` and takes values of model vars as a dict as an argument.
16621670
16631671
Parameters
@@ -1668,16 +1676,22 @@ def profile(self, outs, *, n=1000, point=None, profile=True, **kwargs):
16681676
point : Point
16691677
Point to pass to the function
16701678
profile : True or ProfileStats
1671-
args, kwargs
1672-
Compilation args
1679+
compile_fn_kwargs
1680+
Compilation kwargs for :func:`pymc.model.core.Model.compile_fn`
16731681
16741682
Returns
16751683
-------
1676-
ProfileStats
1684+
pytensor.compile.profiling.ProfileStats
16771685
Use .summary() to print stats.
16781686
"""
1679-
kwargs.setdefault("on_unused_input", "ignore")
1680-
f = self.compile_fn(outs, inputs=self.value_vars, point_fn=False, profile=profile, **kwargs)
1687+
compile_fn_kwargs.setdefault("on_unused_input", "ignore")
1688+
f = self.compile_fn(
1689+
outs,
1690+
inputs=self.value_vars,
1691+
point_fn=False,
1692+
profile=profile,
1693+
**compile_fn_kwargs,
1694+
)
16811695
if point is None:
16821696
point = self.initial_point()
16831697

0 commit comments

Comments
 (0)