@@ -56,7 +56,7 @@ class AtariEnvFns {
56
56
" img_height" _.Bind (84 ), " img_width" _.Bind (84 ),
57
57
" task" _.Bind (std::string (" pong" )), " full_action_space" _.Bind (false ),
58
58
" repeat_action_probability" _.Bind (0 .0f ),
59
- " use_inter_area_resize" _.Bind (true ), " gray_scale" _.Bind (true ));
59
+ " use_inter_area_resize" _.Bind (true ), " gray_scale" _.Bind (true ), " expose_ram " _. Bind ( false ) );
60
60
}
61
61
template <typename Config>
62
62
static decltype (auto ) StateSpec(const Config& conf) {
@@ -66,7 +66,9 @@ class AtariEnvFns {
66
66
{0 , 255 })),
67
67
" info:lives" _.Bind (Spec<int >({-1 })),
68
68
" info:reward" _.Bind (Spec<float >({-1 })),
69
- " info:terminated" _.Bind (Spec<int >({-1 }, {0 , 1 })));
69
+ " info:terminated" _.Bind (Spec<int >({-1 }, {0 , 1 })),
70
+ " info:ram" _.Bind (Spec<uint8_t >({128 }, {0 , 255 }))
71
+ );
70
72
}
71
73
template <typename Config>
72
74
static decltype (auto ) ActionSpec(const Config& conf) {
@@ -99,6 +101,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
99
101
std::vector<Array> maxpool_buf_;
100
102
Array resize_img_;
101
103
std::uniform_int_distribution<> dist_noop_;
104
+ bool expose_ram_{false };
102
105
std::string rom_path_;
103
106
104
107
public:
@@ -121,6 +124,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
121
124
spec.config [" img_width" _]}),
122
125
resize_img_ (resize_spec_),
123
126
dist_noop_ (0 , spec.config[" noop_max" _] - 1 ),
127
+ expose_ram_ (spec.config[" expose_ram" _]),
124
128
rom_path_ (GetRomPath(spec.config[" base_path" _], spec.config[" task" _])) {
125
129
env_->setFloat (" repeat_action_probability" ,
126
130
spec.config [" repeat_action_probability" _]);
@@ -247,6 +251,23 @@ class AtariEnv : public Env<AtariEnvSpec> {
247
251
.Slice (gray_scale_ ? i : i * 3 , gray_scale_ ? i + 1 : (i + 1 ) * 3 )
248
252
.Assign (stack_buf_[i]);
249
253
}
254
+ // Optionally add RAM state if expose_ram_ is true
255
+ if (expose_ram_) {
256
+ // const auto& ram = env_->getRAM(); // Get a reference to the RAM.
257
+ // const size_t ram_size = ram.size(); // Obtain the size of the RAM.
258
+ // const uint8_t* ram_data_ptr = ram.data();
259
+ // std::vector<uint8_t> ram_data(ram_data_ptr, ram_data_ptr + ram_size);
260
+ const size_t ram_size = env_->getRAM ().size ();
261
+ std::vector<uint8_t > ram_data (ram_size);
262
+
263
+ // Assuming getRAM().array() gives direct access to the RAM data
264
+ const uint8_t * ale_ram = env_->getRAM ().array ();
265
+ std::copy (ale_ram, ale_ram + ram_size, ram_data.begin ());
266
+ state[" info:ram" _].Assign (ale_ram, ram_size);
267
+ // for (size_t i = 0; i < ram_size; ++i) {
268
+ // state["ram"_].At(i) = ram[i]; // Directly write RAM data into state
269
+ // }
270
+ }
250
271
}
251
272
252
273
/* *
0 commit comments