Skip to content

Commit aaac78d

Browse files
refine code
1 parent 2f4cb60 commit aaac78d

File tree

3 files changed

+4
-35
lines changed

3 files changed

+4
-35
lines changed

python/paddle/distribution/exponential.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def sample(self, shape=()):
8484
Tensor, A tensor with prepended dimensions shape.The data type is float32.
8585
8686
"""
87-
return self.rsample(shape)
87+
with paddle.no_grad():
88+
return self.rsample(shape)
8889

8990
def rsample(self, shape=()):
9091
"""Generate reparameterized samples of the specified shape.

python/paddle/distribution/gamma.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -139,25 +139,11 @@ def entropy(self):
139139
)
140140

141141
def sample(self, shape=()):
142-
"""Generate samples of the specified shape.
143-
144-
Args:
145-
shape (Sequence[int], optional): Shape of the generated samples.
146-
147-
Returns:
148-
Tensor, A tensor with prepended dimensions shape.The data type is float32.
149-
"""
142+
"""Generate samples of the specified shape."""
150143
raise NotImplementedError
151144

152145
def rsample(self, shape=()):
153-
"""Generate reparameterized samples of the specified shape.
154-
155-
Args:
156-
shape (Sequence[int], optional): Shape of the generated samples.
157-
158-
Returns:
159-
Tensor: A tensor with prepended dimensions shape. The data type is float32.
160-
"""
146+
"""Generate reparameterized samples of the specified shape."""
161147
raise NotImplementedError
162148

163149
def kl_divergence(self, other):

test/distribution/test_distribution_gamma.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,6 @@ def test_entropy(self):
170170
@parameterize.parameterize_cls(
171171
(parameterize.TEST_CASE_NAME, 'concentration', 'rate'),
172172
[
173-
(
174-
'0-dim',
175-
0.5,
176-
0.5,
177-
),
178173
(
179174
'one-dim',
180175
parameterize.xrand(
@@ -201,19 +196,6 @@ def test_entropy(self):
201196
min=np.finfo(dtype='float32').tiny,
202197
),
203198
),
204-
(
205-
'broadcast',
206-
parameterize.xrand(
207-
(2, 1),
208-
dtype='float32',
209-
min=np.finfo(dtype='float32').tiny,
210-
),
211-
parameterize.xrand(
212-
(2, 3),
213-
dtype='float32',
214-
min=np.finfo(dtype='float32').tiny,
215-
),
216-
),
217199
],
218200
)
219201
class TestGammaSample(unittest.TestCase):

0 commit comments

Comments
 (0)