Skip to content

Commit 7c01ed6

Browse files
committed
parsing
1 parent ed09541 commit 7c01ed6

File tree

4 files changed

+203
-181
lines changed

4 files changed

+203
-181
lines changed

source/common/umf_pools/disjoint_pool_config_parser.cpp

Lines changed: 93 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include <limits>
1414
#include <string>
1515

16+
#include "ur_util.hpp"
17+
1618
namespace usm {
1719
constexpr auto operator""_B(unsigned long long x) -> size_t { return x; }
1820
constexpr auto operator""_KB(unsigned long long x) -> size_t {
@@ -70,157 +72,110 @@ DisjointPoolAllConfigs::DisjointPoolAllConfigs(int trace) {
7072
Configs[DisjointPoolMemType::SharedReadOnly].SlabMinSize = 2_MB;
7173
}
7274

75+
std::optional<size_t> stringToNumber(std::string_view s) {
76+
auto unitPos = s.find_first_of("kKmMgG");
77+
size_t multiplier = 1;
78+
if (unitPos != std::string_view::npos) {
79+
switch (tolower(s[unitPos])) {
80+
case 'k':
81+
multiplier = 1_KB;
82+
break;
83+
case 'm':
84+
multiplier = 1_MB;
85+
break;
86+
case 'g':
87+
multiplier = 1_GB;
88+
break;
89+
}
90+
}
91+
92+
try {
93+
return std::stoull(std::string(s.substr(0, unitPos))) * multiplier;
94+
} catch (...) {
95+
return std::nullopt;
96+
}
97+
}
98+
99+
std::pair<std::optional<size_t>, std::string_view>
100+
maybeParseNumber(std::string_view s) {
101+
auto separator = ';';
102+
auto separatorPos = s.find(separator);
103+
if (separatorPos == std::string_view::npos) {
104+
auto number = stringToNumber(s.substr(0, separatorPos));
105+
if (number)
106+
return {number, std::string_view()};
107+
else
108+
return {std::nullopt, s};
109+
}
110+
111+
return {stringToNumber(s.substr(0, separatorPos)),
112+
s.substr(separatorPos + 1)};
113+
}
114+
115+
DisjointPoolMemType parseMemType(std::string_view s) {
116+
if (s == "host")
117+
return DisjointPoolMemType::Host;
118+
if (s == "device")
119+
return DisjointPoolMemType::Device;
120+
if (s == "shared")
121+
return DisjointPoolMemType::Shared;
122+
if (s == "read_only_shared")
123+
return DisjointPoolMemType::SharedReadOnly;
124+
125+
throw std::invalid_argument("Unknown memory type: " + std::string(s));
126+
}
127+
73128
DisjointPoolAllConfigs parseDisjointPoolConfig(const std::string &config,
74129
int trace) {
75130
DisjointPoolAllConfigs AllConfigs;
76131

77-
// TODO: replace with UR ENV var parser and avoid creating a copy of 'config'
78-
auto GetValue = [](std::string &Param, size_t Length, size_t &Setting) {
79-
size_t Multiplier = 1;
80-
if (tolower(Param[Length - 1]) == 'k') {
81-
Length--;
82-
Multiplier = 1_KB;
83-
}
84-
if (tolower(Param[Length - 1]) == 'm') {
85-
Length--;
86-
Multiplier = 1_MB;
87-
}
88-
if (tolower(Param[Length - 1]) == 'g') {
89-
Length--;
90-
Multiplier = 1_GB;
91-
}
92-
std::string TheNumber = Param.substr(0, Length);
93-
if (TheNumber.find_first_not_of("0123456789") == std::string::npos) {
94-
Setting = std::stoi(TheNumber) * Multiplier;
95-
}
96-
};
132+
std::optional<size_t> Buffers, MaxSize;
133+
std::string input;
97134

98-
auto ParamParser = [GetValue](std::string &Params, size_t &Setting,
99-
bool &ParamWasSet) {
100-
bool More;
101-
if (Params.size() == 0) {
102-
ParamWasSet = false;
103-
return false;
104-
}
105-
size_t Pos = Params.find(',');
106-
if (Pos != std::string::npos) {
107-
if (Pos > 0) {
108-
GetValue(Params, Pos, Setting);
109-
ParamWasSet = true;
110-
}
111-
Params.erase(0, Pos + 1);
112-
More = true;
113-
} else {
114-
GetValue(Params, Params.size(), Setting);
115-
ParamWasSet = true;
116-
More = false;
117-
}
118-
return More;
135+
std::tie(Buffers, input) = maybeParseNumber(config);
136+
std::tie(MaxSize, input) = maybeParseNumber(input);
137+
138+
auto setConfigValues = [](umf_disjoint_pool_config_t &config,
139+
std::vector<std::string> values) {
140+
if (values.size() > 0)
141+
config.MaxPoolableSize = stringToNumber(values[0]).value();
142+
if (values.size() > 1)
143+
config.Capacity = stringToNumber(values[1]).value();
144+
if (values.size() > 2)
145+
config.SlabMinSize = stringToNumber(values[2]).value();
119146
};
120147

121-
auto MemParser = [&AllConfigs, ParamParser](std::string &Params,
122-
DisjointPoolMemType memType =
123-
DisjointPoolMemType::All) {
124-
bool ParamWasSet;
125-
DisjointPoolMemType LM = memType;
126-
if (memType == DisjointPoolMemType::All) {
127-
LM = DisjointPoolMemType::Host;
128-
}
148+
try {
149+
// try to parse the string with per-type settings
129150

130-
bool More = ParamParser(Params, AllConfigs.Configs[LM].MaxPoolableSize,
131-
ParamWasSet);
132-
if (ParamWasSet && memType == DisjointPoolMemType::All) {
133-
for (auto &Config : AllConfigs.Configs) {
134-
Config.MaxPoolableSize = AllConfigs.Configs[LM].MaxPoolableSize;
135-
}
136-
}
137-
if (More) {
138-
More = ParamParser(Params, AllConfigs.Configs[LM].Capacity, ParamWasSet);
139-
if (ParamWasSet && memType == DisjointPoolMemType::All) {
140-
for (auto &Config : AllConfigs.Configs) {
141-
Config.Capacity = AllConfigs.Configs[LM].Capacity;
142-
}
143-
}
144-
}
145-
if (More) {
146-
ParamParser(Params, AllConfigs.Configs[LM].SlabMinSize, ParamWasSet);
147-
if (ParamWasSet && memType == DisjointPoolMemType::All) {
148-
for (auto &Config : AllConfigs.Configs) {
149-
Config.SlabMinSize = AllConfigs.Configs[LM].SlabMinSize;
150-
}
151-
}
152-
}
153-
};
151+
auto perTypeSettings = parse_string_to_map(input, false);
152+
for (auto &[type, values] : perTypeSettings) {
153+
DisjointPoolMemType memType = parseMemType(type);
154154

155-
auto MemTypeParser = [MemParser](std::string &Params) {
156-
int Pos = 0;
157-
DisjointPoolMemType M(DisjointPoolMemType::All);
158-
if (Params.compare(0, 5, "host:") == 0) {
159-
Pos = 5;
160-
M = DisjointPoolMemType::Host;
161-
} else if (Params.compare(0, 7, "device:") == 0) {
162-
Pos = 7;
163-
M = DisjointPoolMemType::Device;
164-
} else if (Params.compare(0, 7, "shared:") == 0) {
165-
Pos = 7;
166-
M = DisjointPoolMemType::Shared;
167-
} else if (Params.compare(0, 17, "read_only_shared:") == 0) {
168-
Pos = 17;
169-
M = DisjointPoolMemType::SharedReadOnly;
170-
}
171-
if (Pos > 0) {
172-
Params.erase(0, Pos);
155+
auto &config = AllConfigs.Configs[memType];
156+
if (values.size() > 3)
157+
throw std::invalid_argument("Too many values for memory type " + type);
158+
159+
setConfigValues(config, values);
173160
}
174-
MemParser(Params, M);
175-
};
161+
} catch (std::invalid_argument &) {
162+
// if parsing per-type failed, try to parse the string with settings for all
163+
// types
164+
165+
auto allTypeSettings = parse_string_to_vec(input);
166+
if (allTypeSettings.size() > 3)
167+
throw std::invalid_argument("Too many values for memory type settings");
176168

177-
size_t MaxSize = (std::numeric_limits<size_t>::max)();
178-
179-
// Update pool settings if specified in environment.
180-
size_t EnableBuffers = 1;
181-
if (config != "") {
182-
std::string Params = config;
183-
size_t Pos = Params.find(';');
184-
if (Pos != std::string::npos) {
185-
if (Pos > 0) {
186-
GetValue(Params, Pos, EnableBuffers);
187-
}
188-
Params.erase(0, Pos + 1);
189-
size_t Pos = Params.find(';');
190-
if (Pos != std::string::npos) {
191-
if (Pos > 0) {
192-
GetValue(Params, Pos, MaxSize);
193-
}
194-
Params.erase(0, Pos + 1);
195-
do {
196-
size_t Pos = Params.find(';');
197-
if (Pos != std::string::npos) {
198-
if (Pos > 0) {
199-
std::string MemParams = Params.substr(0, Pos);
200-
MemTypeParser(MemParams);
201-
}
202-
Params.erase(0, Pos + 1);
203-
if (Params.size() == 0) {
204-
break;
205-
}
206-
} else {
207-
MemTypeParser(Params);
208-
break;
209-
}
210-
} while (true);
211-
} else {
212-
// set MaxPoolSize for all configs
213-
GetValue(Params, Params.size(), MaxSize);
214-
}
215-
} else {
216-
GetValue(Params, Params.size(), EnableBuffers);
169+
for (auto &Config : AllConfigs.Configs) {
170+
setConfigValues(Config, allTypeSettings);
217171
}
218172
}
219173

220-
AllConfigs.EnableBuffers = EnableBuffers;
174+
AllConfigs.EnableBuffers = Buffers.value_or(1);
221175

222176
AllConfigs.limits = std::shared_ptr<umf_disjoint_pool_shared_limits_t>(
223-
umfDisjointPoolSharedLimitsCreate(MaxSize),
177+
umfDisjointPoolSharedLimitsCreate(
178+
MaxSize.value_or(std::numeric_limits<size_t>::max())),
224179
umfDisjointPoolSharedLimitsDestroy);
225180

226181
for (auto &Config : AllConfigs.Configs) {
@@ -268,10 +223,11 @@ DisjointPoolAllConfigs parseDisjointPoolConfig(const std::string &config,
268223
<< std::setw(12)
269224
<< AllConfigs.Configs[DisjointPoolMemType::SharedReadOnly].Capacity
270225
<< std::endl;
271-
std::cout << std::setw(15) << "MaxPoolSize" << std::setw(12) << MaxSize
226+
std::cout << std::setw(15) << "MaxPoolSize" << std::setw(12)
227+
<< MaxSize.value_or(std::numeric_limits<size_t>::max())
272228
<< std::endl;
273229
std::cout << std::setw(15) << "EnableBuffers" << std::setw(12)
274-
<< EnableBuffers << std::endl
230+
<< AllConfigs.EnableBuffers << std::endl
275231
<< std::endl;
276232

277233
return AllConfigs;

source/common/umf_pools/disjoint_pool_config_parser.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class DisjointPoolAllConfigs {
5858
// EnableBuffers: Apply chunking/pooling to SYCL buffers.
5959
// Default 1.
6060
// MaxPoolSize: Limit on overall unfreed memory.
61-
// Default 16MB.
61+
// Default: no limit.
6262
// MaxPoolableSize: Maximum allocation size subject to chunking/pooling.
6363
// Default 2MB host, 4MB device and 0 shared.
6464
// Capacity: Maximum number of unfreed allocations in each bucket.

0 commit comments

Comments
 (0)