Memory usage of mbCG #1453
-
Hi everyone, I hope this is the right place to ask this question; if not, I can also open an issue. I have a question regarding the memory usage of mbCG, or maybe of GPyTorch in general.

In the appendix of the BBMM paper, Observation 1 states that CG has a space complexity of O(n). From my understanding this is not the case for mbCG, since the matrix-vector multiplication is replaced by a matrix-matrix multiplication. Is this true? If not, how can I make use of the memory efficiency? From some rudimentary tests I found that during a single computation of, say, the marginal likelihood, the kernel's forward method is called exactly once, which to my understanding hints at the whole covariance matrix being computed, and thus at a space complexity of O(n^2). Or is the column-wise access used in CG/MVM controlled by the lazy tensors?

I think you have created a masterful package, but trying to understand which parts of the paper are used at which point is often really hard, due to the very nested structure of the implementation. I'm looking forward to your answer!

Cheers and stay safe
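To illustrate what I mean, here is roughly how I probed the laziness of the kernel matrix -- a minimal sketch, assuming the GPyTorch 1.x API where calling a kernel returns a `LazyEvaluatedKernelTensor` and `.evaluate()` materializes it (newer versions renamed this to `.to_dense()`):

```python
import torch
import gpytorch

x = torch.randn(1000, 3)
kernel = gpytorch.kernels.RBFKernel()

# Calling the kernel returns a lazy wrapper, not a dense matrix:
K = kernel(x)
print(type(K))   # LazyEvaluatedKernelTensor -- no O(n^2) buffer allocated yet
print(K.shape)   # torch.Size([1000, 1000]) -- shape metadata only

# Only an explicit evaluation materializes the full n x n matrix:
K_dense = K.evaluate()  # this is where O(n^2) memory would actually be spent
```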
-
It's O(n) because the output of CG (a single vector or a small set of vectors) only needs that much space, unlike Cholesky, and you don't actually need to store the entire kernel matrix in memory at the same time to compute it -- you can compute MVMs in a map-reduce fashion.

The original GPyTorch paper exclusively dealt with the setting where we store the entire kernel matrix in memory to compute MVMs, but this paper directly centers around extending this to O(n) storage, and we have a few example notebooks that do this.
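As a rough illustration of the map-reduce idea, here is a minimal sketch in plain PyTorch (not GPyTorch's actual implementation -- the RBF kernel, `block_size`, and function name are made up for the example) that computes K @ v without ever holding the full n x n matrix:

```python
import torch

def blocked_rbf_mvm(x, v, lengthscale=1.0, block_size=256):
    # Computes (K @ v) for an RBF kernel K without materializing K.
    # "Map": build one block of rows of K at a time; "reduce": multiply it
    # into v and write the partial result. Peak memory is O(block_size * n).
    v = v.reshape(x.shape[0], -1)                    # (n, t)
    out = torch.empty(x.shape[0], v.shape[-1])       # (n, t)
    for start in range(0, x.shape[0], block_size):
        rows = x[start:start + block_size]           # (b, d)
        sq_dist = torch.cdist(rows, x).pow(2)        # (b, n) block of K's rows
        K_block = torch.exp(-0.5 * sq_dist / lengthscale ** 2)
        out[start:start + block_size] = K_block @ v  # partial MVM result
    return out

# Sanity check against the dense computation on a small problem:
x = torch.randn(500, 3)
v = torch.randn(500, 4)
K = torch.exp(-0.5 * torch.cdist(x, x).pow(2))
assert torch.allclose(blocked_rbf_mvm(x, v), K @ v, atol=1e-4)
```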
-
Hi everyone, I also had a look at the KeOps approach, but it appears I would have to rewrite my custom kernel to be able to use it. Would this help with the lack of GPU memory?

Cheers
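For context, this is the kind of swap I mean -- a minimal sketch, assuming the `gpytorch.kernels.keops` module, which as far as I can tell only ships KeOps versions of the built-in kernels, so my custom kernel has no such counterpart:

```python
import gpytorch
from gpytorch.kernels.keops import RBFKernel as KeOpsRBFKernel

# Stock kernel: the n x n matrix is (lazily) materialized for MVMs.
dense_kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

# KeOps-backed drop-in (requires pykeops): kernel entries are recomputed
# on the fly during each MVM, so the full matrix never sits in GPU memory.
keops_kernel = gpytorch.kernels.ScaleKernel(KeOpsRBFKernel())
```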