Skip to content

Commit 0bb3479

Browse files
committed
adding option to return ram state in info
1 parent f411fc2 commit 0bb3479

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

envpool/atari/atari_env.h

+23-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class AtariEnvFns {
5656
"img_height"_.Bind(84), "img_width"_.Bind(84),
5757
"task"_.Bind(std::string("pong")), "full_action_space"_.Bind(false),
5858
"repeat_action_probability"_.Bind(0.0f),
59-
"use_inter_area_resize"_.Bind(true), "gray_scale"_.Bind(true));
59+
"use_inter_area_resize"_.Bind(true), "gray_scale"_.Bind(true), "expose_ram"_.Bind(false));
6060
}
6161
template <typename Config>
6262
static decltype(auto) StateSpec(const Config& conf) {
@@ -66,7 +66,9 @@ class AtariEnvFns {
6666
{0, 255})),
6767
"info:lives"_.Bind(Spec<int>({-1})),
6868
"info:reward"_.Bind(Spec<float>({-1})),
69-
"info:terminated"_.Bind(Spec<int>({-1}, {0, 1})));
69+
"info:terminated"_.Bind(Spec<int>({-1}, {0, 1})),
70+
"info:ram"_.Bind(Spec<uint8_t>({128}, {0, 255}))
71+
);
7072
}
7173
template <typename Config>
7274
static decltype(auto) ActionSpec(const Config& conf) {
@@ -99,6 +101,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
99101
std::vector<Array> maxpool_buf_;
100102
Array resize_img_;
101103
std::uniform_int_distribution<> dist_noop_;
104+
bool expose_ram_{false};
102105
std::string rom_path_;
103106

104107
public:
@@ -121,6 +124,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
121124
spec.config["img_width"_]}),
122125
resize_img_(resize_spec_),
123126
dist_noop_(0, spec.config["noop_max"_] - 1),
127+
expose_ram_(spec.config["expose_ram"_]),
124128
rom_path_(GetRomPath(spec.config["base_path"_], spec.config["task"_])) {
125129
env_->setFloat("repeat_action_probability",
126130
spec.config["repeat_action_probability"_]);
@@ -247,6 +251,23 @@ class AtariEnv : public Env<AtariEnvSpec> {
247251
.Slice(gray_scale_ ? i : i * 3, gray_scale_ ? i + 1 : (i + 1) * 3)
248252
.Assign(stack_buf_[i]);
249253
}
254+
// Optionally add RAM state if expose_ram_ is true
255+
if (expose_ram_) {
256+
// const auto& ram = env_->getRAM(); // Get a reference to the RAM.
257+
// const size_t ram_size = ram.size(); // Obtain the size of the RAM.
258+
// const uint8_t* ram_data_ptr = ram.data();
259+
// std::vector<uint8_t> ram_data(ram_data_ptr, ram_data_ptr + ram_size);
260+
const size_t ram_size = env_->getRAM().size();
261+
std::vector<uint8_t> ram_data(ram_size);
262+
263+
// Assuming getRAM().array() gives direct access to the RAM data
264+
const uint8_t* ale_ram = env_->getRAM().array();
265+
std::copy(ale_ram, ale_ram + ram_size, ram_data.begin());
266+
state["info:ram"_].Assign(ale_ram, ram_size);
267+
// for (size_t i = 0; i < ram_size; ++i) {
268+
// state["ram"_].At(i) = ram[i]; // Directly write RAM data into state
269+
// }
270+
}
250271
}
251272

252273
/**

0 commit comments

Comments
 (0)