diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 995b4e3..c8f67a8 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -1,3 +1,5 @@ +from typing import List, Dict, Tuple, Set, Any +from typing import OrderedDict as OrderedDictType import json import logging from collections import OrderedDict @@ -11,8 +13,19 @@ class ScienceWorldEnv: - - def __init__(self, taskName=None, serverPath=None, envStepLimit=100): + """Python wrapper for the simulator written in Scala. The methods that are + being wrapped can be found in simulator/src/main/scala/scienceworld/runtime/AgentInterface.scala. + Please look at that for more information on the internals of the system. + """ + + def __init__(self, taskName: str = None, serverPath: str = None, envStepLimit: int = 100): + '''Start the simulator. Sets up the interface between python and the JVM. + Also does basic init stuff. + :param taskName: The name of the task. Will be run through the infer_task method. + Tasks can also be loaded by the load method. + :param serverPath: The filepath to the server. By default, it is just scienceworld.jar. + :param envStepLimit: The maximum number of steps taken in the environment. Defaults to 100. + ''' serverPath = serverPath or JAR_PATH # Use the builtin jar. # Launch the server and connect to the JVM. @@ -67,9 +80,19 @@ def __init__(self, taskName=None, serverPath=None, envStepLimit=100): self.goldPathGenerated = False # Ask the simulator to load an environment from a script - def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath=False): - """ Load a given task and its variation. """ - + def load(self, taskName: str, variationIdx: int = 0, simplificationStr: str = "", + generateGoldPath: bool = False) -> None: + '''Load a valid task and its variations/simplifications, and set up the simulator + and any task-specific properties (electrical, etc). Can optionally have the + simulator generate a gold path. If it successfully does, it will set + self.goldPathGenerated to True. + + :param taskName: The name of the task. Will be modified by the infer_task function. + :param variationIdx: The index for the specific variation to use. Default is 0. + :param simplificationStr: The string of simplifications to use. Should be comma + separated with no spaces. Defaults to "". For more, see get_possible_simplifications + :param generateGoldPath: Boolean var to generate gold path or not. Defaults to False. + ''' # Check loading arguments. # Validate task name. taskName = infer_task(taskName) @@ -106,8 +129,9 @@ def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath= # Keep track of whether the gold path was generated, to generate verbose error messages self.goldPathGenerated = generateGoldPath - # Ask the simulator to reset an environment back to it's initial state - def reset(self): + def reset(self) -> Tuple[str, Dict[str, Any]]: + ''' Resets the simulator back to the first move (the output of "look around" is returned) ''' + self.server.reset() # Reset last step score (used to calculate reward from current-previous score) @@ -120,79 +144,95 @@ def reset(self): return observation, info # Simplifications - def get_simplifications_used(self): + def get_simplifications_used(self) -> str: + ''' Gets the simplifications being used by the simulator. ''' return self.server.getSimplificationsUsed() - def get_possible_simplifications(self): + def get_possible_simplifications(self) -> List[str]: + '''Gets the 6 possible simplifications. Those are: + - teleportAction: Adds actions to teleport directly to any possible location + - selfWateringFlowerPots: Flower pots will water themselves such that the plants won't die + - openContainers: Containers are open by default + - openDoors: Doors open are by default + - noElectricalAction: Remove all `connect X to Y` actions to reduce the action space + - easy: use all above simplifications + ''' return self.server.getPossibleSimplifications().split(", ") @property - def tasks(self): + def tasks(self) -> OrderedDictType[str, str]: """ Get the supported tasks in ScienceWorld. """ return OrderedDict(ID2TASK) @property - def task_names(self): - """ Get the name for the supported tasks in ScienceWorld. """ + def task_names(self) -> List[str]: + ''' Get the name for the supported tasks in ScienceWorld. ''' return list(ID2TASK.values()) - def get_task_names(self): - """ Get the name for the supported tasks in ScienceWorld. """ + def get_task_names(self) -> List[str]: + ''' Get the name for the supported tasks in ScienceWorld. ''' return list(self.server.getTaskNames()) - # Get the maximum number of variations for this task - def get_max_variations(self, task_name): + def get_max_variations(self, task_name) -> int: + ''' Get the maximum number of variations for the tasks. ''' return self.server.getTaskMaxVariations(infer_task(task_name)) # Get possible actions - def get_possible_actions(self): + def get_possible_actions(self) -> List[str]: + ''' Get all possible actions in the current environment state. ''' return list(self.server.getPossibleActions()) # Get possible actions (and also include the template IDs for those actions) - def get_possible_actions_with_IDs(self): + def get_possible_actions_with_IDs(self) -> List[Dict[str, Any]]: + ''' Get a list of dictionaries that map "action_example" to the action template and "template_id" to the id.''' jsonStr = self.server.getPossibleActionsWithIDs() data = json.loads(jsonStr) return data - # Get possible objects - def get_possible_objects(self): + def get_possible_objects(self) -> List[str]: + ''' Get a list of all observable objects ''' return list(self.server.getPossibleObjects()) # Get a list of object_ids to unique referents - def get_possible_object_referent_LUT(self): + def get_possible_object_referent_LUT(self) -> Dict[str, str]: + ''' Returns lookup table (dict) mapping object IDs to their referents. ''' jsonStr = self.server.getPossibleObjectReferentLUTJSON() data = json.loads(jsonStr) return data # As above, but dictionary is referenced by object type ID - def get_possible_object_referent_types_LUT(self): + def get_possible_object_referent_types_LUT(self) -> Dict[str, Dict[str, str]]: + ''' Returns lookup table (dict) mapping object type IDs to a dict of all objects of that type. ''' jsonStr = self.server.getPossibleObjectReferentTypesLUTJSON() data = json.loads(jsonStr) return data - # Get a list of *valid* agent-object combinations - def get_valid_action_object_combinations(self): + def get_valid_action_object_combinations(self) -> List[str]: + ''' Get a list of all of the *valid* action-object combinations. ''' return list(self.server.getValidActionObjectCombinations()) - def get_valid_action_object_combinations_with_templates(self): + def get_valid_action_object_combinations_with_templates(self) -> List[Dict[str, Any]]: + ''' Returns list of dicts with keys "action", "template_id", and "obj_ids" ''' jsonStr = self.server.getValidActionObjectCombinationsJSON() data = json.loads(jsonStr) return data['validActions'] - # Get a LUT of object_id to type_id - def get_all_object_types_LUTJSON(self): + def get_all_object_types_LUTJSON(self) -> Dict[str, str]: + ''' Returns look up table mapping object ids to type ids ''' jsonStr = self.server.getAllObjectTypesLUTJSON() data = json.loads(jsonStr) return data # Get a LUT of {object_id: {type_id, referent:[]} } tuples - def get_all_object_ids_types_referents_LUTJSON(self): + def get_all_object_ids_types_referents_LUTJSON(self) -> Dict[str, Dict[str, Any]]: + ''' Returns look up table mapping object ids to objects with keys "type_id" and "referents" ''' jsonStr = self.server.getAllObjectIdsTypesReferentsLUTJSON() data = json.loads(jsonStr) return data # Get possible action/object combinations - def get_possible_action_object_combinations(self): + def get_possible_action_object_combinations(self) -> Tuple[List[Dict[str, Any]], Dict[str, str]]: + ''' Get all *possible* action-object combinations, including invalid ones. ''' combinedJSON = self.server.getPossibleActionObjectCombinationsJSON() data = json.loads(combinedJSON) templates = data['templates'] @@ -200,14 +240,16 @@ def get_possible_action_object_combinations(self): return (templates, lookUpTable) - # Get a list of object types and their IDs - def get_object_types(self): + def get_object_types(self) -> Dict[str, int]: + '''Get a dict mapping object names to the object id. The object name is the name + of the actual file, for example "scienceworld.objects.containers.furniture.Chair". + ''' jsonStr = self.server.getObjectTypesLUTJSON() data = json.loads(jsonStr) return data - # Get the vocabulary of the model (at the current state) - def get_vocabulary(self): + def get_vocabulary(self) -> Set[str]: + ''' Get all words that currently have some sort of meaning to the simulator. ''' vocab = set() # Action vocabulary @@ -221,20 +263,28 @@ def get_vocabulary(self): return vocab - def get_num_moves(self): + def get_num_moves(self) -> int: + ''' Get the current number of moves. ''' return self.server.getNumMoves() - def get_task_description(self): + def get_task_description(self) -> str: + ''' Get the description of the current task. ''' return self.server.getTaskDescription() # History - def get_run_history(self): + def get_run_history(self) -> Dict[str, Any]: + ''' Get the run history ''' historyStr = self.server.getRunHistoryJSON() jsonOut = json.loads(historyStr) return jsonOut # History saving (provides an API to do this, so it's consistent across agents) - def store_run_history(self, episode_idx_key, notes): + def store_run_history(self, episode_idx_key: int, notes: str) -> None: + '''Store the run history, with notes. + + :param episode_idx_key: Episode index. Will be used as key. + :param notes: Notes on the run. + ''' packed = { 'episodeIdx': episode_idx_key, 'notes': notes, @@ -243,7 +293,11 @@ def store_run_history(self, episode_idx_key, notes): self.runHistories[episode_idx_key] = packed - def save_run_histories(self, filename_out_prefix): + def save_run_histories(self, filename_out_prefix: str) -> None: + '''Save the run histories to a file. + + :param filename_out_prefix: The name of the file to write to. + ''' # Save history # Create verbose filename @@ -261,46 +315,78 @@ def save_run_histories(self, filename_out_prefix): with open(filenameOut, 'w') as outfile: json.dump(self.runHistories, outfile, sort_keys=True, indent=4) - def get_run_history_size(self): + def get_run_history_size(self) -> int: + ''' Get the size of the run history ''' return len(self.runHistories) - def clear_run_histories(self): + def clear_run_histories(self) -> None: + ''' Clear the run histories. ''' self.runHistories = {} # A one-stop function to handle saving. - def save_run_histories_buffer_if_full(self, filename_out_prefix, max_per_file=1000, force_save=False): + def save_run_histories_buffer_if_full(self, filename_out_prefix: str, + max_per_file: int = 1000, force_save: bool = False) -> None: + '''One stop function for saving. + + If the histories buffer is full, saves to file and clears the buffer. + + :param filename_out_prefix: Name of the file to write to. + :param max_per_file: The max number of histories per file. Defaults to 1000. + :param force_save: Force the function to save, regardless of whether or not the + buffer is full. Defaults to False. + ''' if ((self.get_run_history_size() >= max_per_file) or force_save): self.save_run_histories(filename_out_prefix) self.clear_run_histories() # Train/development/test sets - def get_variations_train(self): + def get_variations_train(self) -> List[int]: + ''' Get the list of variations available for the training set. ''' return list(self.server.getVariationsTrain()) - def get_variations_dev(self): + def get_variations_dev(self) -> List[int]: + ''' Get the list of variations available for the development set. ''' return list(self.server.getVariationsDev()) - def get_variations_test(self): + def get_variations_test(self) -> List[int]: + ''' Get the list of variations available for the testing set. ''' return list(self.server.getVariationsTest()) - def get_random_variation_train(self): + def get_random_variation_train(self) -> int: + ''' Get a single random variation from those available for the training set. ''' return self.server.getRandomVariationTrain() - def get_random_variation_dev(self): + def get_random_variation_dev(self) -> int: + ''' Get a single random variation from those available for the development set. ''' return self.server.getRandomVariationDev() - def get_random_variation_test(self): + def get_random_variation_test(self) -> int: + ''' Get a single random variation from those available for the testing set. ''' return self.server.getRandomVariationTest() # Gold action sequence - def get_gold_action_sequence(self): + def get_gold_action_sequence(self) -> List[str]: + '''Get the gold action sequence. + The gold action sequence is a sequence of actions that leads to a winning state + (there is no guarantee it is the optimal). This function returns that if it is generated. + If it is not generated, it generates an error. + ''' if (self.goldPathGenerated): return list(self.server.getGoldActionSequence()) else: return ["ERROR: Gold path was not generated. Set `generateGoldPath` flag to true when calling load()."] # Step - def step(self, input_str: str): + def step(self, input_str: str) -> Tuple[str, int, bool, Dict[str, Any]]: + '''Take a step. + + This function takes one step in the typical state-action-reward cycle of RL. + :param input_str: The input string supplied to the simulator from an agent. + + Returns the observation, reward, completion status, and infos dict consisting of: + 'moves', 'score', 'reward', 'look', 'inv', 'taskDesc', 'valid', 'variationIdx', 'taskName', + and 'simplificationStr'. + ''' observation = self.server.step(input_str) score = int(round(100 * self.server.getScore())) # Convert from 0-1 to 0-100 isCompleted = self.server.getCompleted() @@ -336,20 +422,24 @@ def step(self, input_str: str): return observation, reward, isCompleted, infos # Special actions that are "free" (consume zero time) - def look(self): + def look(self) -> str: + ''' Look around. This is a "free" action in that it consumes no time. ''' observation = self.server.freeActionLook() return observation - def inventory(self): + def inventory(self) -> str: + ''' Check your inventory. This is a "free" action that consumes no time. ''' observation = self.server.freeActionInventory() return observation - def taskdescription(self): + def taskdescription(self) -> str: + ''' Get the task description. This is a "free" action that consumes no time. ''' observation = self.server.freeActionTaskDesc() return observation # Goal progress - def get_goal_progress(self): + def get_goal_progress(self) -> str: + ''' Get the progress to the goal. ''' goalStr = self.server.getGoalProgressStr() return goalStr diff --git a/scienceworld/utils.py b/scienceworld/utils.py index e917c4f..a7692cb 100644 --- a/scienceworld/utils.py +++ b/scienceworld/utils.py @@ -5,6 +5,8 @@ def infer_task(name_or_id): + ''' Takes a task name or task ID and processes it to produce a uniform task format. ''' + if name_or_id in NAME2ID: name_or_id = NAME2ID[name_or_id] diff --git a/simulator/src/main/scala/scienceworld/runtime/SimplifierProcessor.scala b/simulator/src/main/scala/scienceworld/runtime/SimplifierProcessor.scala index eec76d8..e265179 100644 --- a/simulator/src/main/scala/scienceworld/runtime/SimplifierProcessor.scala +++ b/simulator/src/main/scala/scienceworld/runtime/SimplifierProcessor.scala @@ -46,7 +46,7 @@ class SimplificationOpenDoors extends Simplification(label = SIMPLIFICATION_OPEN } -// Simplification: Open all doors in the environment +// Simplification: Open all containers in the environment class SimplificationOpenContainers extends Simplification(label = SIMPLIFICATION_OPEN_CONTAINERS, description = "All containers are open by default.") { runAtInitialization = true