Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT CLOSE] Library TODOs and call for contributions #116

Open
1 of 5 tasks
matteobettini opened this issue Jun 26, 2024 · 1 comment
Open
1 of 5 tasks

[DO NOT CLOSE] Library TODOs and call for contributions #116

matteobettini opened this issue Jun 26, 2024 · 1 comment

Comments

@matteobettini
Copy link
Member

matteobettini commented Jun 26, 2024

Hello people!

In this issue I will list the things I would really like to have in VMAS and will tick them off as they are implemented!

These were previously in the README TODOs

It is also a really good place to find something you would like to contribute.

Features

Sensors

  • Implement 1D camera sensor
  • Implement 2D birds eye view camera sensor
@matteobettini matteobettini pinned this issue Jun 26, 2024
@19991006
Copy link

19991006 commented Apr 2, 2025

Hi,

Thanks for your vectorized MARL environment! It helps a lot!

I noticed that you mentioned a 2D bird's-eye-view sensor is needed. I have created a custom grid sensor which may meet your requirements. Here is the code.

Since I'm not a professional developer (sad), the code may need further debugging.

# -*- coding: utf-8 -*-
# @Time : 2025/3/17 10:41
# @Author : Dong Shaoqian
# @File : sensors.py
# @Software: PyCharm
import torch
from torch import Tensor

from vmas.simulator.core import World, Entity, Agent, Landmark
from vmas.simulator.utils import Color, X, Y
from vmas.simulator.sensors import Sensor
from vmas.simulator.rendering import Geom

from typing import List, Union, Tuple

from envs.simulator.utils import entity_size


def is_ally(agent: Agent, entity: Entity):
    """Return True when *entity* is an Agent on the same side as *agent*.

    Two agents are on the same side when their ``adversary`` flags are
    equal (their XOR is falsy). Anything that is not an ``Agent`` is never
    an ally. Note that *agent* itself satisfies this predicate; callers
    that want to exclude it must do so separately.
    """
    if not isinstance(entity, Agent):
        return False
    differs = agent.adversary ^ entity.adversary
    return not differs


def is_enemy(agent: Agent, entity: Entity):
    """Return True when *entity* is an Agent on the opposite side of *agent*.

    Two agents are on opposite sides when exactly one of them has its
    ``adversary`` flag set (their XOR is truthy). Anything that is not an
    ``Agent`` is never an enemy.
    """
    if not isinstance(entity, Agent):
        return False
    differs = agent.adversary ^ entity.adversary
    return bool(differs)


def is_obstacle(_, entity: Entity):
    """Return True when *entity* is a Landmark whose name contains 'obstacle'.

    The first positional argument (the observing agent) is ignored; it is
    accepted only so this predicate has the same call shape as the
    agent-relative filters.
    """
    return isinstance(entity, Landmark) and 'obstacle' in entity.name


def is_target(_, entity: Entity):
    """Return True when *entity* is a Landmark whose name contains 'target'.

    The first positional argument (the observing agent) is ignored; it is
    accepted only so this predicate has the same call shape as the
    agent-relative filters.
    """
    if not isinstance(entity, Landmark):
        return False
    return 'target' in entity.name


def out_of_sensor_range(resolution, coordinate):
    """Return True when *coordinate* lies outside the sensor grid.

    The grid is taken to be centred on the sensor, so along one axis the
    valid coordinates form the half-open interval
    ``[-resolution / 2, resolution / 2)``.

    The previous implementation compared against ``resolution / 2`` on both
    sides (a condition that is true for every real number) and had its
    return values swapped, so it always returned False.

    NOTE(review): the symmetric, centre-relative interval is inferred from
    the function name and the original comparison operators -- confirm it
    matches the intended coordinate convention.

    Args:
        resolution: Number of grid cells along one axis.
        coordinate: Centre-relative cell coordinate along that axis.

    Returns:
        True if the coordinate is outside the grid, False otherwise.
    """
    half = resolution / 2
    return not (-half <= coordinate < half)


class GridSensor(Sensor):
    """Agent-centred occupancy-grid sensor.

    The sensor lays a square ``resolution x resolution`` grid of cells around
    the agent it is attached to and reports, for four entity categories
    (allies, enemies, obstacles, targets), which cell centres fall inside an
    entity of that category.

    Args:
        world: The world the sensor observes.
        resolution: Number of cells along each side of the grid.
        grid_size: Cell side length, expressed as a fraction of the mean of
            ``world.x_semidim`` and ``world.y_semidim``.
        render_color: Colour stored for rendering (currently unused because
            ``render`` returns no geometry).
        render: Rendering flag, stored but currently unused.
    """

    def __init__(
            self,
            world: World,
            resolution: int = 28,
            grid_size: float = 0.02,
            render_color: Union[Color, Tuple[float, float, float]] = (0.7, 0.7, 0.7),
            render: bool = False,
    ):
        super().__init__(world=world)
        # Convert the relative cell size into world units using the mean
        # semi-dimension. This requires both semidims to be numeric.
        self._grid_size = grid_size * 0.5 * (world.x_semidim + world.y_semidim)
        self._world = world
        self._render = render
        self._render_color = render_color
        self._resolution = resolution
        self._n_grids = resolution ** 2

        self.check_resolution()

    @property
    def world(self):
        """The world this sensor observes."""
        return self._world

    @property
    def grid_size(self):
        """Side length of one grid cell, in world units."""
        return self._grid_size

    @property
    def n_grids(self):
        """Total number of cells in the grid (``resolution ** 2``)."""
        return self._n_grids

    @property
    def resolution(self):
        """Number of cells along each side of the grid."""
        return self._resolution

    def check_resolution(self):
        """Assert that every entity currently in the world is detectable.

        An entity is accepted when its size (as reported by ``entity_size``)
        is at least ``1.414`` times the cell side length.

        Raises:
            AssertionError: If any entity is smaller than that threshold.
        """
        min_radius = self.grid_size * 1.414
        for entity in self.world.entities:
            radius = entity_size(entity)
            assert min_radius <= radius, \
                f'{entity.name} are too small to be detected'

    def measure_entities(self, entity_filter):
        """Build the occupancy mask for the entities accepted by a filter.

        Args:
            entity_filter: Callable invoked as
                ``entity_filter(self.agent, entity=entity)``; entities for
                which it returns a falsy value are ignored. The sensor's own
                agent is always ignored.

        Returns:
            A float32 tensor of shape ``(batch, resolution, resolution)``
            holding 1.0 where a cell centre lies within ``entity_size`` of a
            selected entity's position and 0.0 elsewhere. Axis 1 indexes the
            x offset and axis 2 the y offset (``indexing='ij'``).
        """
        res = self.resolution
        device = self.world.device

        # The cell-centre offsets relative to the agent do not depend on the
        # entity, so they are built once rather than once per entity.
        center_range = self.grid_size * (res / 2 - 0.5)
        mesh = torch.arange(
            -center_range,
            center_range + self.grid_size / 2,  # half-step pad keeps the last centre
            self.grid_size,
            device=device,
        )
        cell_x, cell_y = torch.meshgrid(mesh, mesh, indexing='ij')  # (res, res)

        # OR-accumulate into a single mask instead of stacking one mask per
        # entity; this also covers the "no matching entity" case naturally.
        occupied = torch.zeros(
            (self.world.batch_dim, res, res),
            dtype=torch.bool,
            device=device,
        )
        for entity in self.world.entities:
            if self.agent is entity or not entity_filter(self.agent, entity=entity):
                continue
            rel_pos = entity.state.pos - self.agent.state.pos  # (batch, 2)
            rel_x = rel_pos[:, X].reshape(-1, 1, 1)  # (batch, 1, 1)
            rel_y = rel_pos[:, Y].reshape(-1, 1, 1)  # (batch, 1, 1)

            # (res, res) broadcasts against (batch, 1, 1) -> (batch, res, res)
            dist = torch.sqrt((cell_x - rel_x) ** 2 + (cell_y - rel_y) ** 2)
            occupied |= dist <= entity_size(entity)

        return occupied.to(torch.float32)

    def measure(self) -> Tensor:
        """Return the four-channel occupancy observation.

        Returns:
            A float32 tensor of shape ``(batch, 4, resolution, resolution)``
            with channels ordered as allies, enemies, obstacles, targets.
        """
        channel_filters = (is_ally, is_enemy, is_obstacle, is_target)
        channels = [self.measure_entities(f) for f in channel_filters]
        return torch.stack(channels, dim=1)  # (batch, channels=4, res, res)

    def render(self, env_index: int = 0) -> "List[Geom]":
        """Return the geometries to draw for this sensor (currently none)."""
        return []

    def to(self, device: torch.device):
        """No-op: the sensor caches no tensors, so there is nothing to move."""

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants