Fix sample from prior for ConstantMean #2042

Merged
merged 4 commits into master from constant_mean_prior on Jun 16, 2022

Conversation

dme65 (Collaborator) commented Jun 15, 2022

The use of fill_ doesn't play well with the shape of the priors.
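
(For illustration, a minimal sketch of the failure mode using stand-in tensors rather than the actual ConstantMean internals; the shapes and names here are assumptions.)

import torch

# Stand-ins for the raw constant parameter and a draw from its registered prior.
constant = torch.nn.Parameter(torch.zeros(1))
prior_sample = torch.distributions.Normal(0.0, 1.0).sample(constant.shape)  # shape (1,), not 0-dim

try:
    constant.data.fill_(prior_sample)  # roughly what the old fill_-based closure did
except RuntimeError as e:
    print(e)  # fill_ only supports 0-dimension value tensor ...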

dme65 requested a review from Balandat on June 15, 2022, 22:11
Balandat (Collaborator) left a comment

Thanks. What exactly goes wrong when using fill_ here? Is the issue that priors.sample() produces a tensor of some shape rather than a scalar and so fill_ does weird things? Would be good to add some type annotation to understand what _constant_closure is supposed to operate on.

Also, does that mean all the other modules should be updated in the same way?
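
(For illustration, a hedged sketch of what an annotated setter could look like; this is not the code in this PR, and the reshape/copy details are assumptions.)

from typing import Union

import torch

def _constant_closure(m, value: Union[float, torch.Tensor]) -> None:
    # Sketch only: `m` is the mean module; `value` may be a Python scalar or a
    # tensor such as a prior sample. Coerce it to a tensor on the parameter's
    # dtype/device, then copy it into the parameter instead of calling fill_.
    if not torch.is_tensor(value):
        value = torch.as_tensor(value).to(m.constant)
    m.constant.data.copy_(value.reshape(m.constant.shape))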

Comment on lines +21 to +22
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.constant)
Balandat (Collaborator)

is this gating necessary or could we just always apply torch.as_tensor (should be a nullop)?

dme65 (Collaborator, Author)

I'm just following what is done for other closures, for example: https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/kernels/kernel.py#L262.
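
(For context on the "should be a nullop" point: torch.as_tensor returns an existing tensor unchanged when no dtype/device conversion is requested, so always applying it costs nothing; the gating mainly decides whether the extra .to(self.constant) conversion runs.)

import torch

t = torch.randn(3)
assert torch.as_tensor(t) is t  # already a tensor: returned as-is, no copy

v = torch.as_tensor(2.5)        # Python scalar: becomes a 0-dim tensor
print(v.dim(), v.dtype)         # 0 and the default floating-point dtype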

dme65 (Collaborator, Author) commented Jun 16, 2022

Is the issue that priors.sample() produces a tensor of some shape rather than a scalar and so fill_ does weird things?

Yes, fill_ will fail if you call it with a tensor that has one or more dimensions. For example:

x = torch.randn(5)
y = torch.randn(1)
x.fill_(y)

results in RuntimeError: fill_ only supports 0-dimension value tensor but got tensor with 1 dimensions.
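
(For contrast, fill_ accepts Python scalars and 0-dim tensors; only a value tensor with one or more dimensions triggers this error.)

import torch

x = torch.randn(5)
x.fill_(1.0)                 # ok: Python scalar
x.fill_(torch.tensor(2.0))   # ok: 0-dim tensor
# x.fill_(torch.randn(1))    # raises: fill_ only supports 0-dimension value tensor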

dme65 merged commit 538648b into master Jun 16, 2022
dme65 deleted the constant_mean_prior branch June 16, 2022 16:23