add action_masks

This commit is contained in:
steam
2023-06-11 20:00:12 +03:00
parent fd940dbba2
commit afd54d39a5

View File

@@ -141,6 +141,9 @@ class BaseEnvironment(gym.Env):
Unique to the environment action count. Must be inherited. Unique to the environment action count. Must be inherited.
""" """
def action_masks(self) -> list[bool]:
return [self._is_valid(action.value) for action in self.actions]
def seed(self, seed: int = 1): def seed(self, seed: int = 1):
self.np_random, seed = seeding.np_random(seed) self.np_random, seed = seeding.np_random(seed)
return [seed] return [seed]