-
Notifications
You must be signed in to change notification settings - Fork 529
/
Copy pathCargo.toml
127 lines (111 loc) · 4.32 KB
/
Cargo.toml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
[package]
name = "burn-core"
version.workspace = true
# NOTE(review): the address below reads "[email protected]", which looks like an
# email-obfuscation placeholder rather than a real address — confirm against upstream.
authors = ["nathanielsimard <[email protected]>"]
categories = ["science", "no-std", "embedded", "wasm"]
documentation = "https://docs.rs/burn-core"
# `<field>.workspace = true` inherits the value from the workspace root manifest.
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-core"
description = "Flexible and Comprehensive Deep Learning Framework in Rust"
[features]
# Syntax note: "crate?/feat" enables `feat` on an optional dependency only if that
# dependency is already active (e.g. via the `dataset` feature); "crate/feat"
# without `?` also activates the optional dependency itself.
dataset = ["burn-dataset"]
default = [
    "std",
    "burn-common/default",
    "burn-dataset?/default",
    "burn-tensor/default",
]
# Feature set used when building docs on docs.rs (see [package.metadata.docs.rs]).
doc = [
    "std",
    "dataset",
    "audio",
    "vision",
    # Doc features of the underlying burn crates
    "burn-common/doc",
    "burn-dataset/doc",
    "burn-tensor/doc",
]
network = ["burn-common/network"]
sqlite = ["burn-dataset?/sqlite"]
sqlite-bundled = ["burn-dataset?/sqlite-bundled"]
std = [
    "bincode/std",
    "burn-common/std",
    "burn-tensor/std",
    "flate2",
    "half/std",
    "log",
    "rand/std",
    "rmp-serde",
    "serde/std",
    "serde_json/std",
    "num-traits/std",
]
vision = ["burn-dataset?/vision", "burn-common/network"]
audio = ["burn-dataset?/audio"]
# Custom deserializer for Record, useful when importing data such as PyTorch .pt files.
record-item-custom-serde = ["thiserror", "regex"]
# Forwards burn-tensor's `experimental-named-tensor` feature.
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
# Alternative backends for the test suite; without any of these, tests use ndarray.
# These entries omit `?`, so each feature also activates its optional backend crate.
test-cuda = ["burn-cuda/default"]
test-hip = ["burn-hip/default"]
test-tch = ["burn-tch/default"]
test-wgpu = ["burn-wgpu/default"]
# Same as test-wgpu, additionally enabling burn-wgpu's `vulkan` feature.
test-wgpu-spirv = ["burn-wgpu/default", "burn-wgpu/vulkan"]
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
burn-common = { path = "../burn-common", version = "0.17.0", default-features = false }
burn-dataset = { path = "../burn-dataset", version = "0.17.0", optional = true, default-features = false }
burn-derive = { path = "../burn-derive", version = "0.17.0" }
burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false }
data-encoding = { workspace = true }
derive-new = { workspace = true }
# Same HashMap implementation as std, but usable in no_std (only `alloc` is required).
hashbrown = { workspace = true, features = ["serde"] }
log = { workspace = true, optional = true }
# rand's default features enable std; std is meant to come only from this crate's
# `std` feature ("rand/std") — assumes the workspace entry disables default features.
rand = { workspace = true, features = ["std_rng"] }
uuid = { workspace = true }
# Serialization / deserialization and supporting utilities
ahash = { workspace = true }
bincode = { workspace = true }
flate2 = { workspace = true, optional = true }
half = { workspace = true }
num-traits = { workspace = true }
regex = { workspace = true, optional = true }
rmp-serde = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
# Default features enable std; only `alloc` is requested here, std comes via the `std` feature.
serde_json = { workspace = true, features = ["alloc"] }
# Used in place of std::sync::Mutex when std is disabled.
spin = { workspace = true }
thiserror = { workspace = true, optional = true }
# FOR TESTING: optional backend crates, switched on by the test-* features.
# NOTE(review): burn-remote and burn-router are not referenced by any entry in
# [features]; they can only be enabled through their implicit same-named features.
burn-cuda = { path = "../burn-cuda", version = "0.17.0", optional = true, default-features = false }
burn-hip = { path = "../burn-hip", version = "0.17.0", optional = true, default-features = false }
burn-remote = { path = "../burn-remote", version = "0.17.0", optional = true, default-features = false }
burn-router = { path = "../burn-router", version = "0.17.0", optional = true, default-features = false }
burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", optional = true, default-features = false }
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
# Only pulled in on targets that lack pointer-sized atomics.
portable-atomic-util.workspace = true
[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.17.0" }
# Tests additionally get burn-dataset's `fake` feature.
burn-dataset = { path = "../burn-dataset", version = "0.17.0", features = ["fake"] }
# ndarray is the default backend for tests (see the test-* features for alternatives).
burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" }
[package.metadata.docs.rs]
# docs.rs passes `--cfg docsrs` to rustdoc and builds with this crate's `doc` feature.
rustdoc-args = ["--cfg", "docsrs"]
features = ["doc"]