|
13 | 13 | #include <limits>
|
14 | 14 | #include <string>
|
15 | 15 |
|
| 16 | +#include "ur_util.hpp" |
| 17 | + |
16 | 18 | namespace usm {
|
17 | 19 | constexpr auto operator""_B(unsigned long long x) -> size_t { return x; }
|
18 | 20 | constexpr auto operator""_KB(unsigned long long x) -> size_t {
|
@@ -70,157 +72,110 @@ DisjointPoolAllConfigs::DisjointPoolAllConfigs(int trace) {
|
70 | 72 | Configs[DisjointPoolMemType::SharedReadOnly].SlabMinSize = 2_MB;
|
71 | 73 | }
|
72 | 74 |
|
| 75 | +std::optional<size_t> stringToNumber(std::string_view s) { |
| 76 | + auto unitPos = s.find_first_of("kKmMgG"); |
| 77 | + size_t multiplier = 1; |
| 78 | + if (unitPos != std::string_view::npos) { |
| 79 | + switch (tolower(s[unitPos])) { |
| 80 | + case 'k': |
| 81 | + multiplier = 1_KB; |
| 82 | + break; |
| 83 | + case 'm': |
| 84 | + multiplier = 1_MB; |
| 85 | + break; |
| 86 | + case 'g': |
| 87 | + multiplier = 1_GB; |
| 88 | + break; |
| 89 | + } |
| 90 | + } |
| 91 | + |
| 92 | + try { |
| 93 | + return std::stoull(std::string(s.substr(0, unitPos))) * multiplier; |
| 94 | + } catch (...) { |
| 95 | + return std::nullopt; |
| 96 | + } |
| 97 | +} |
| 98 | + |
| 99 | +std::pair<std::optional<size_t>, std::string_view> |
| 100 | +maybeParseNumber(std::string_view s) { |
| 101 | + auto separator = ';'; |
| 102 | + auto separatorPos = s.find(separator); |
| 103 | + if (separatorPos == std::string_view::npos) { |
| 104 | + auto number = stringToNumber(s.substr(0, separatorPos)); |
| 105 | + if (number) |
| 106 | + return {number, std::string_view()}; |
| 107 | + else |
| 108 | + return {std::nullopt, s}; |
| 109 | + } |
| 110 | + |
| 111 | + return {stringToNumber(s.substr(0, separatorPos)), |
| 112 | + s.substr(separatorPos + 1)}; |
| 113 | +} |
| 114 | + |
| 115 | +DisjointPoolMemType parseMemType(std::string_view s) { |
| 116 | + if (s == "host") |
| 117 | + return DisjointPoolMemType::Host; |
| 118 | + if (s == "device") |
| 119 | + return DisjointPoolMemType::Device; |
| 120 | + if (s == "shared") |
| 121 | + return DisjointPoolMemType::Shared; |
| 122 | + if (s == "read_only_shared") |
| 123 | + return DisjointPoolMemType::SharedReadOnly; |
| 124 | + |
| 125 | + throw std::invalid_argument("Unknown memory type: " + std::string(s)); |
| 126 | +} |
| 127 | + |
73 | 128 | DisjointPoolAllConfigs parseDisjointPoolConfig(const std::string &config,
|
74 | 129 | int trace) {
|
75 | 130 | DisjointPoolAllConfigs AllConfigs;
|
76 | 131 |
|
77 |
| - // TODO: replace with UR ENV var parser and avoid creating a copy of 'config' |
78 |
| - auto GetValue = [](std::string &Param, size_t Length, size_t &Setting) { |
79 |
| - size_t Multiplier = 1; |
80 |
| - if (tolower(Param[Length - 1]) == 'k') { |
81 |
| - Length--; |
82 |
| - Multiplier = 1_KB; |
83 |
| - } |
84 |
| - if (tolower(Param[Length - 1]) == 'm') { |
85 |
| - Length--; |
86 |
| - Multiplier = 1_MB; |
87 |
| - } |
88 |
| - if (tolower(Param[Length - 1]) == 'g') { |
89 |
| - Length--; |
90 |
| - Multiplier = 1_GB; |
91 |
| - } |
92 |
| - std::string TheNumber = Param.substr(0, Length); |
93 |
| - if (TheNumber.find_first_not_of("0123456789") == std::string::npos) { |
94 |
| - Setting = std::stoi(TheNumber) * Multiplier; |
95 |
| - } |
96 |
| - }; |
| 132 | + std::optional<size_t> Buffers, MaxSize; |
| 133 | + std::string input; |
97 | 134 |
|
98 |
| - auto ParamParser = [GetValue](std::string &Params, size_t &Setting, |
99 |
| - bool &ParamWasSet) { |
100 |
| - bool More; |
101 |
| - if (Params.size() == 0) { |
102 |
| - ParamWasSet = false; |
103 |
| - return false; |
104 |
| - } |
105 |
| - size_t Pos = Params.find(','); |
106 |
| - if (Pos != std::string::npos) { |
107 |
| - if (Pos > 0) { |
108 |
| - GetValue(Params, Pos, Setting); |
109 |
| - ParamWasSet = true; |
110 |
| - } |
111 |
| - Params.erase(0, Pos + 1); |
112 |
| - More = true; |
113 |
| - } else { |
114 |
| - GetValue(Params, Params.size(), Setting); |
115 |
| - ParamWasSet = true; |
116 |
| - More = false; |
117 |
| - } |
118 |
| - return More; |
| 135 | + std::tie(Buffers, input) = maybeParseNumber(config); |
| 136 | + std::tie(MaxSize, input) = maybeParseNumber(input); |
| 137 | + |
| 138 | + auto setConfigValues = [](umf_disjoint_pool_config_t &config, |
| 139 | + std::vector<std::string> values) { |
| 140 | + if (values.size() > 0) |
| 141 | + config.MaxPoolableSize = stringToNumber(values[0]).value(); |
| 142 | + if (values.size() > 1) |
| 143 | + config.Capacity = stringToNumber(values[1]).value(); |
| 144 | + if (values.size() > 2) |
| 145 | + config.SlabMinSize = stringToNumber(values[2]).value(); |
119 | 146 | };
|
120 | 147 |
|
121 |
| - auto MemParser = [&AllConfigs, ParamParser](std::string &Params, |
122 |
| - DisjointPoolMemType memType = |
123 |
| - DisjointPoolMemType::All) { |
124 |
| - bool ParamWasSet; |
125 |
| - DisjointPoolMemType LM = memType; |
126 |
| - if (memType == DisjointPoolMemType::All) { |
127 |
| - LM = DisjointPoolMemType::Host; |
128 |
| - } |
| 148 | + try { |
| 149 | + // try to parse the string with per-type settings |
129 | 150 |
|
130 |
| - bool More = ParamParser(Params, AllConfigs.Configs[LM].MaxPoolableSize, |
131 |
| - ParamWasSet); |
132 |
| - if (ParamWasSet && memType == DisjointPoolMemType::All) { |
133 |
| - for (auto &Config : AllConfigs.Configs) { |
134 |
| - Config.MaxPoolableSize = AllConfigs.Configs[LM].MaxPoolableSize; |
135 |
| - } |
136 |
| - } |
137 |
| - if (More) { |
138 |
| - More = ParamParser(Params, AllConfigs.Configs[LM].Capacity, ParamWasSet); |
139 |
| - if (ParamWasSet && memType == DisjointPoolMemType::All) { |
140 |
| - for (auto &Config : AllConfigs.Configs) { |
141 |
| - Config.Capacity = AllConfigs.Configs[LM].Capacity; |
142 |
| - } |
143 |
| - } |
144 |
| - } |
145 |
| - if (More) { |
146 |
| - ParamParser(Params, AllConfigs.Configs[LM].SlabMinSize, ParamWasSet); |
147 |
| - if (ParamWasSet && memType == DisjointPoolMemType::All) { |
148 |
| - for (auto &Config : AllConfigs.Configs) { |
149 |
| - Config.SlabMinSize = AllConfigs.Configs[LM].SlabMinSize; |
150 |
| - } |
151 |
| - } |
152 |
| - } |
153 |
| - }; |
| 151 | + auto perTypeSettings = parse_string_to_map(input, false); |
| 152 | + for (auto &[type, values] : perTypeSettings) { |
| 153 | + DisjointPoolMemType memType = parseMemType(type); |
154 | 154 |
|
155 |
| - auto MemTypeParser = [MemParser](std::string &Params) { |
156 |
| - int Pos = 0; |
157 |
| - DisjointPoolMemType M(DisjointPoolMemType::All); |
158 |
| - if (Params.compare(0, 5, "host:") == 0) { |
159 |
| - Pos = 5; |
160 |
| - M = DisjointPoolMemType::Host; |
161 |
| - } else if (Params.compare(0, 7, "device:") == 0) { |
162 |
| - Pos = 7; |
163 |
| - M = DisjointPoolMemType::Device; |
164 |
| - } else if (Params.compare(0, 7, "shared:") == 0) { |
165 |
| - Pos = 7; |
166 |
| - M = DisjointPoolMemType::Shared; |
167 |
| - } else if (Params.compare(0, 17, "read_only_shared:") == 0) { |
168 |
| - Pos = 17; |
169 |
| - M = DisjointPoolMemType::SharedReadOnly; |
170 |
| - } |
171 |
| - if (Pos > 0) { |
172 |
| - Params.erase(0, Pos); |
| 155 | + auto &config = AllConfigs.Configs[memType]; |
| 156 | + if (values.size() > 3) |
| 157 | + throw std::invalid_argument("Too many values for memory type " + type); |
| 158 | + |
| 159 | + setConfigValues(config, values); |
173 | 160 | }
|
174 |
| - MemParser(Params, M); |
175 |
| - }; |
| 161 | + } catch (std::invalid_argument &) { |
| 162 | + // if parsing per-type failed, try to parse the string with settings for all |
| 163 | + // types |
| 164 | + |
| 165 | + auto allTypeSettings = parse_string_to_vec(input); |
| 166 | + if (allTypeSettings.size() > 3) |
| 167 | + throw std::invalid_argument("Too many values for memory type settings"); |
176 | 168 |
|
177 |
| - size_t MaxSize = (std::numeric_limits<size_t>::max)(); |
178 |
| - |
179 |
| - // Update pool settings if specified in environment. |
180 |
| - size_t EnableBuffers = 1; |
181 |
| - if (config != "") { |
182 |
| - std::string Params = config; |
183 |
| - size_t Pos = Params.find(';'); |
184 |
| - if (Pos != std::string::npos) { |
185 |
| - if (Pos > 0) { |
186 |
| - GetValue(Params, Pos, EnableBuffers); |
187 |
| - } |
188 |
| - Params.erase(0, Pos + 1); |
189 |
| - size_t Pos = Params.find(';'); |
190 |
| - if (Pos != std::string::npos) { |
191 |
| - if (Pos > 0) { |
192 |
| - GetValue(Params, Pos, MaxSize); |
193 |
| - } |
194 |
| - Params.erase(0, Pos + 1); |
195 |
| - do { |
196 |
| - size_t Pos = Params.find(';'); |
197 |
| - if (Pos != std::string::npos) { |
198 |
| - if (Pos > 0) { |
199 |
| - std::string MemParams = Params.substr(0, Pos); |
200 |
| - MemTypeParser(MemParams); |
201 |
| - } |
202 |
| - Params.erase(0, Pos + 1); |
203 |
| - if (Params.size() == 0) { |
204 |
| - break; |
205 |
| - } |
206 |
| - } else { |
207 |
| - MemTypeParser(Params); |
208 |
| - break; |
209 |
| - } |
210 |
| - } while (true); |
211 |
| - } else { |
212 |
| - // set MaxPoolSize for all configs |
213 |
| - GetValue(Params, Params.size(), MaxSize); |
214 |
| - } |
215 |
| - } else { |
216 |
| - GetValue(Params, Params.size(), EnableBuffers); |
| 169 | + for (auto &Config : AllConfigs.Configs) { |
| 170 | + setConfigValues(Config, allTypeSettings); |
217 | 171 | }
|
218 | 172 | }
|
219 | 173 |
|
220 |
| - AllConfigs.EnableBuffers = EnableBuffers; |
| 174 | + AllConfigs.EnableBuffers = Buffers.value_or(1); |
221 | 175 |
|
222 | 176 | AllConfigs.limits = std::shared_ptr<umf_disjoint_pool_shared_limits_t>(
|
223 |
| - umfDisjointPoolSharedLimitsCreate(MaxSize), |
| 177 | + umfDisjointPoolSharedLimitsCreate( |
| 178 | + MaxSize.value_or(std::numeric_limits<size_t>::max())), |
224 | 179 | umfDisjointPoolSharedLimitsDestroy);
|
225 | 180 |
|
226 | 181 | for (auto &Config : AllConfigs.Configs) {
|
@@ -268,10 +223,11 @@ DisjointPoolAllConfigs parseDisjointPoolConfig(const std::string &config,
|
268 | 223 | << std::setw(12)
|
269 | 224 | << AllConfigs.Configs[DisjointPoolMemType::SharedReadOnly].Capacity
|
270 | 225 | << std::endl;
|
271 |
| - std::cout << std::setw(15) << "MaxPoolSize" << std::setw(12) << MaxSize |
| 226 | + std::cout << std::setw(15) << "MaxPoolSize" << std::setw(12) |
| 227 | + << MaxSize.value_or(std::numeric_limits<size_t>::max()) |
272 | 228 | << std::endl;
|
273 | 229 | std::cout << std::setw(15) << "EnableBuffers" << std::setw(12)
|
274 |
| - << EnableBuffers << std::endl |
| 230 | + << AllConfigs.EnableBuffers << std::endl |
275 | 231 | << std::endl;
|
276 | 232 |
|
277 | 233 | return AllConfigs;
|
|
0 commit comments