fix(pt): set device for PT C++ #4261
Conversation
Fix deepmodeling#4171. Signed-off-by: Jinzhe Zeng <[email protected]>
📝 Walkthrough

The changes made to the

Changes
Assessment against linked issues
Possibly related PRs
Suggested reviewers
Signed-off-by: Jinzhe Zeng <[email protected]>
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files

```
@@            Coverage Diff            @@
##            devel    #4261   +/-   ##
=======================================
  Coverage   84.23%   84.23%
=======================================
  Files         549      549
  Lines       51462    51463    +1
  Branches     3051     3051
=======================================
+ Hits        43348    43349    +1
- Misses       7153     7154    +1
+ Partials      961      960    -1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (1)
source/api_cc/src/DeepPotPT.cc (1)
Line range hint 429-456: Consider documenting unimplemented methods.

While throwing exceptions for unimplemented methods is appropriate, it would be helpful to:
- Document why these methods are not implemented
- Provide guidance on alternatives in the error message
Consider enhancing the error message:
```diff
- throw deepmd::deepmd_exception("computew_mixed_type is not implemented");
+ throw deepmd::deepmd_exception("computew_mixed_type is not implemented. Please use the standard compute methods with consistent types instead.");
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- source/api_cc/src/DeepPotPT.cc (1 hunks)
🔇 Additional comments (3)
source/api_cc/src/DeepPotPT.cc (3)
Line range hint 15-27: LGTM: Comprehensive error handling implementation.

The error handling implementation properly catches different PyTorch exception types and provides clear context in error messages. This will help with debugging issues in the PyTorch backend.
Line range hint 391-428: LGTM: Well-implemented compute methods with proper error handling.

The compute implementations demonstrate good practices (a rough sketch follows this list):
- Type-safe template implementations
- Consistent error handling through translate_error
- Proper memory management for both CPU and GPU tensors
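As a rough illustration of the type-safe template point, a single templated body can serve both the float and double entry points by mapping the C++ value type to the matching libtorch dtype. The sketch below uses placeholder names and arguments and is not the real DeepPotPT interface:

```cpp
#include <type_traits>
#include <vector>

#include <torch/torch.h>

// Hypothetical sketch: one templated compute body shared by the float and
// double public methods; the tensor dtype follows the template parameter.
template <typename VALUETYPE>
void compute_impl(std::vector<VALUETYPE>& energy,
                  const std::vector<VALUETYPE>& coord) {
  const auto dtype = std::is_same<VALUETYPE, double>::value ? torch::kFloat64
                                                            : torch::kFloat32;
  // Wrap the caller's buffer without copying; the dtype matches VALUETYPE.
  torch::Tensor coord_t =
      torch::from_blob(const_cast<VALUETYPE*>(coord.data()),
                       {static_cast<int64_t>(coord.size())}, dtype);
  // ... the forward pass through the TorchScript module would go here ...
  energy.assign(1, static_cast<VALUETYPE>(0));  // placeholder output
}

// The public float and double methods would simply forward to this template.
template void compute_impl<double>(std::vector<double>&, const std::vector<double>&);
template void compute_impl<float>(std::vector<float>&, const std::vector<float>&);
```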
83-85: Verify VRAM usage improvement.

While the device selection changes are a step towards addressing the VRAM wastage issue (#4171), we should verify that this actually reduces VRAM usage in multi-GPU setups.
✅ Verification successful
Let me analyze the results and generate additional verification scripts to better understand the VRAM management.
The results will help us verify if proper device selection and memory management are implemented throughout the code to prevent VRAM wastage.
Let me run one more verification to check the tensor device management in the codebase.
Device selection and tensor management properly address VRAM wastage
The implementation effectively manages VRAM usage in multi-GPU setups through:
- Proper device selection using gpu_id = gpu_rank % gpu_num to distribute load across available GPUs (see the sketch below)
- Consistent device management, with all tensors explicitly moved to the selected device using .to(device)
- Clean initialization and fallback to CPU when GPU is unavailable
- Proper scoping of GPU tensors, ensuring memory is released after use
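To make the pattern above concrete, here is a minimal sketch against the libtorch C++ API. The function names and arguments are placeholders chosen for illustration; the actual DeepPotPT.cc code differs in detail.

```cpp
#include <vector>

#include <torch/torch.h>

// Hypothetical illustration of the pattern summarized above: pick one GPU per
// rank via gpu_id = gpu_rank % gpu_num, fall back to CPU when no GPU is
// visible, and keep every tensor on the selected device.
torch::Device select_device(int gpu_rank) {
  const int gpu_num = torch::cuda::is_available()
                          ? static_cast<int>(torch::cuda::device_count())
                          : 0;
  if (gpu_num == 0) {
    return torch::Device(torch::kCPU);  // no GPU available: stay on the CPU
  }
  const int gpu_id = gpu_rank % gpu_num;  // spread ranks over available GPUs
  return torch::Device(torch::kCUDA, gpu_id);
}

void run_on_selected_device(int gpu_rank, std::vector<double>& coord) {
  const torch::Device device = select_device(gpu_rank);
  // Moving the input explicitly avoids stray allocations on GPU 0.
  torch::Tensor coord_t =
      torch::from_blob(coord.data(), {static_cast<int64_t>(coord.size())},
                       torch::kFloat64)
          .to(device);
  // ... run the model here; coord_t goes out of scope afterwards, so its GPU
  // memory can be released back to the caching allocator.
}
```

The modulo mapping is a simple way to distribute ranks when there are more MPI ranks than visible GPUs.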
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Description: Look for other GPU memory management related code
# that might need similar device selection logic

# Search for CUDA memory operations
rg -l "cuda|gpu|device|memory" --type cpp

# Search for tensor creation patterns that might allocate GPU memory
ast-grep --pattern 'torch::from_blob($$$).to(device)'
```

Length of output: 5033
Script:
```bash
#!/bin/bash
# Check for device selection and memory management patterns in DeepPotPT.cc
rg -A 5 "DPSetDevice|cudaSetDevice" source/api_cc/src/DeepPotPT.cc

# Look for device initialization and cleanup
rg -A 5 "init|cleanup|delete|free" source/api_cc/src/DeepPotPT.cc

# Check if there are any other CUDA memory management functions
rg "cuda.*Malloc|cuda.*Free|cuda.*Memory" source/api_cc/src/DeepPotPT.cc
```

Length of output: 1414
Script:
```bash
#!/bin/bash
# Check tensor device management and memory operations
rg -B 2 -A 2 "\.to\(device\)" source/api_cc/src/DeepPotPT.cc

# Look for device variable initialization and usage
rg -B 2 "torch::Device" source/api_cc/src/DeepPotPT.cc

# Check GPU memory management in the constructor/initialization
rg -A 10 "gpu_rank|gpu_num" source/api_cc/src/DeepPotPT.cc
```

Length of output: 3804
Fix #4171.
Summary by CodeRabbit
New Features
Bug Fixes