Skip to content

Commit 905e112

Browse files
authored
Fix handling of device when compiled for but disabled nccl (#1227)
1 parent 7306205 commit 905e112

File tree

4 files changed

+7
-6
lines changed

4 files changed

+7
-6
lines changed

mistralrs-bench/src/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ fn main() -> anyhow::Result<()> {
370370
#[cfg(feature = "metal")]
371371
let device = Device::new_metal(0)?;
372372
#[cfg(not(feature = "metal"))]
373-
let device = if cfg!(feature = "nccl") {
373+
let device = if mistralrs_core::distributed::use_nccl() {
374374
Device::Cpu
375375
} else {
376376
Device::cuda_if_available(0)?
@@ -433,7 +433,7 @@ fn main() -> anyhow::Result<()> {
433433
DeviceMapSetting::Auto(auto_device_map_params)
434434
};
435435

436-
let no_paged_attn = if device.is_cuda() || cfg!(feature = "nccl") {
436+
let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
437437
args.no_paged_attn
438438
} else if device.is_metal() {
439439
!args.paged_attn

mistralrs-core/src/distributed.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use core::ffi::c_char;
44
use interprocess::local_socket::traits::{Listener, Stream};
55
use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
66
use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
7+
pub use mistralrs_quant::distributed::use_nccl;
78
use mistralrs_quant::{ShardedSafeTensors, ShardedVarBuilder};
89
use serde::{Deserialize, Serialize};
910
use serde_big_array::BigArray;

mistralrs-pyo3/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ static DEVICE: OnceLock<Result<Device>> = OnceLock::new();
4444
#[cfg(not(feature = "metal"))]
4545
fn get_device(seed: Option<u64>) -> &'static Result<Device> {
4646
DEVICE.get_or_init(|| {
47-
let device = if cfg!(feature = "nccl") {
47+
let device = if mistralrs_core::distributed::use_nccl() {
4848
Device::Cpu
4949
} else {
5050
Device::cuda_if_available(0)?
@@ -652,7 +652,7 @@ impl Runner {
652652
None => DeviceMapSetting::Auto(auto_map_params),
653653
};
654654

655-
let no_paged_attn = if device.is_cuda() || cfg!(feature = "nccl") {
655+
let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
656656
no_paged_attn
657657
} else if device.is_metal() {
658658
!paged_attn

mistralrs-server/src/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ async fn main() -> Result<()> {
317317
let device = if args.cpu {
318318
args.no_paged_attn = true;
319319
Device::Cpu
320-
} else if cfg!(feature = "nccl") {
320+
} else if mistralrs_core::distributed::use_nccl() {
321321
Device::Cpu
322322
} else {
323323
Device::cuda_if_available(0)?
@@ -379,7 +379,7 @@ async fn main() -> Result<()> {
379379
DeviceMapSetting::Auto(auto_device_map_params)
380380
};
381381

382-
let no_paged_attn = if device.is_cuda() || cfg!(feature = "nccl") {
382+
let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
383383
args.no_paged_attn
384384
} else if device.is_metal() {
385385
!args.paged_attn

0 commit comments

Comments
 (0)