From 3ed8b4a11757ea07d22a4b9683adf652debe6ffe Mon Sep 17 00:00:00 2001 From: Andrew Kaminer Date: Sun, 14 Jan 2024 14:46:17 -0500 Subject: [PATCH 01/17] Add docstrings for python api. Need to add deprecated docstrings for the deprecated api, as well as type hinting. Type hinting will be a game changer, as it makes the docstrings way more clear (IMO). --- scienceworld/scienceworld.py | 101 ++++++++++++++++-- scienceworld/utils.py | 2 + .../runtime/SimplifierProcessor.scala | 2 +- 3 files changed, 93 insertions(+), 12 deletions(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 995b4e34..41e83746 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -11,8 +11,18 @@ class ScienceWorldEnv: + """Python wrapper for the simulator written in Scala. The methods that are + being wrapped can be found in simulator/src/main/scala/scienceworld/runtime/AgentInterface.scala. + Please look at that for more information on the internals of the system. + """ def __init__(self, taskName=None, serverPath=None, envStepLimit=100): + '''Start the simulator. Sets up the interface between python and the JVM. + Also does basic init stuff. + :param taskName: The name of the task. Will be run through the infer_task method. Tasks can also be loaded by the load method. + :param serverPath: The filepath to the server. By default, it is just scienceworld.jar. + :param envStepLimit: The maximum number of steps taken in the environment. Defaults to 100. + ''' serverPath = serverPath or JAR_PATH # Use the builtin jar. # Launch the server and connect to the JVM. @@ -68,8 +78,16 @@ def __init__(self, taskName=None, serverPath=None, envStepLimit=100): # Ask the simulator to load an environment from a script def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath=False): - """ Load a given task and its variation. """ - + '''Load a valid task and its variations/simplifications, and set up the simulator + and any task-specific properties (electrical, etc). Can optionally have the + simulator generate a gold path. If it successfully does, it will set + self.goldPathGenerated to True. + + :param taskName: The name of the task. Will be modified by the infer_task function. + :param variationIdx: The index for the specific variation to use. Default is 0. + :param simplificationStr: The string of simplifications to use. Should be comma separated with no spaces. Defaults to "". For more, see get_possible_simplifications + :param generateGoldPath: Boolean var to generate gold path or not. Defaults to False. + ''' # Check loading arguments. # Validate task name. taskName = infer_task(taskName) @@ -106,8 +124,9 @@ def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath= # Keep track of whether the gold path was generated, to generate verbose error messages self.goldPathGenerated = generateGoldPath - # Ask the simulator to reset an environment back to it's initial state def reset(self): + ''' Resets the simulator back to the first move (the output of "look around" is returned) ''' + self.server.reset() # Reset last step score (used to calculate reward from current-previous score) @@ -121,9 +140,18 @@ def reset(self): # Simplifications def get_simplifications_used(self): + ''' Gets the simplifications being used by the simulator. ''' return self.server.getSimplificationsUsed() def get_possible_simplifications(self): + '''Gets the 6 possible simplifications. There are 6 simplifictions: + - teleportAction: Teleport action + - selfWateringFlowerPots: Self-watering flower pots + - openContainers: Containers open by default + - openDoors: Doors open by default + - noElectricalAction: Remove the electrical actions + - easy: use all 5 simplifications + ''' return self.server.getPossibleSimplifications().split(", ") @property @@ -133,66 +161,73 @@ def tasks(self): @property def task_names(self): - """ Get the name for the supported tasks in ScienceWorld. """ + ''' Get the name for the supported tasks in ScienceWorld. ''' return list(ID2TASK.values()) def get_task_names(self): - """ Get the name for the supported tasks in ScienceWorld. """ + ''' Get the name for the supported tasks in ScienceWorld. ''' return list(self.server.getTaskNames()) - # Get the maximum number of variations for this task def get_max_variations(self, task_name): + ''' Get the maximum number of variations for the tasks. ''' return self.server.getTaskMaxVariations(infer_task(task_name)) # Get possible actions def get_possible_actions(self): + ''' Get all possible actions in the current environment state. ''' return list(self.server.getPossibleActions()) # Get possible actions (and also include the template IDs for those actions) def get_possible_actions_with_IDs(self): + ''' Get a list of dictionaries that map "action_example" to the action template and "template_id" to the id.''' jsonStr = self.server.getPossibleActionsWithIDs() data = json.loads(jsonStr) return data - # Get possible objects def get_possible_objects(self): + ''' Get a list of all observable objects ''' return list(self.server.getPossibleObjects()) # Get a list of object_ids to unique referents def get_possible_object_referent_LUT(self): + ''' Returns lookup table (dict) mapping object IDs to their referents. ''' jsonStr = self.server.getPossibleObjectReferentLUTJSON() data = json.loads(jsonStr) return data # As above, but dictionary is referenced by object type ID def get_possible_object_referent_types_LUT(self): + ''' Returns lookup table (dict) mapping object type IDs to a dict of all objects of that type. ''' jsonStr = self.server.getPossibleObjectReferentTypesLUTJSON() data = json.loads(jsonStr) return data - # Get a list of *valid* agent-object combinations def get_valid_action_object_combinations(self): + ''' Get a list of all of the *valid* action-object combinations. ''' return list(self.server.getValidActionObjectCombinations()) def get_valid_action_object_combinations_with_templates(self): + ''' Returns list of dicts with keys "action", "template_id", and "obj_ids" ''' jsonStr = self.server.getValidActionObjectCombinationsJSON() data = json.loads(jsonStr) return data['validActions'] - # Get a LUT of object_id to type_id def get_all_object_types_LUTJSON(self): + ''' Returns look up table mapping object ids to type ids ''' jsonStr = self.server.getAllObjectTypesLUTJSON() data = json.loads(jsonStr) return data # Get a LUT of {object_id: {type_id, referent:[]} } tuples def get_all_object_ids_types_referents_LUTJSON(self): + ''' Returns look up table mapping object ids to objects with keys "type_id" and "referents" ''' jsonStr = self.server.getAllObjectIdsTypesReferentsLUTJSON() data = json.loads(jsonStr) return data # Get possible action/object combinations def get_possible_action_object_combinations(self): + ''' Get all *possible* action-object combinations, including invalid ones. ''' combinedJSON = self.server.getPossibleActionObjectCombinationsJSON() data = json.loads(combinedJSON) templates = data['templates'] @@ -200,14 +235,14 @@ def get_possible_action_object_combinations(self): return (templates, lookUpTable) - # Get a list of object types and their IDs def get_object_types(self): + ''' Get a dict mapping object names to the object id. The object name is the name of the actual file, for example "scienceworld.objects.containers.furniture.Chair". ''' jsonStr = self.server.getObjectTypesLUTJSON() data = json.loads(jsonStr) return data - # Get the vocabulary of the model (at the current state) def get_vocabulary(self): + ''' Get all words that currently have some sort of meaning to the simulator. ''' vocab = set() # Action vocabulary @@ -222,19 +257,27 @@ def get_vocabulary(self): return vocab def get_num_moves(self): + ''' Get the current number of moves. ''' return self.server.getNumMoves() def get_task_description(self): + ''' Get the description of the current task. ''' return self.server.getTaskDescription() # History def get_run_history(self): + ''' Get the run history ''' historyStr = self.server.getRunHistoryJSON() jsonOut = json.loads(historyStr) return jsonOut # History saving (provides an API to do this, so it's consistent across agents) def store_run_history(self, episode_idx_key, notes): + '''Store the run history, with notes. + + :param episode_idx_key: Episode index. Will be used as key. + :param notes: Notes on the run. + ''' packed = { 'episodeIdx': episode_idx_key, 'notes': notes, @@ -244,6 +287,10 @@ def store_run_history(self, episode_idx_key, notes): self.runHistories[episode_idx_key] = packed def save_run_histories(self, filename_out_prefix): + '''Save the run histories to a file. + + :param filename_out_prefix: The name of the file to write to. + ''' # Save history # Create verbose filename @@ -262,38 +309,58 @@ def save_run_histories(self, filename_out_prefix): json.dump(self.runHistories, outfile, sort_keys=True, indent=4) def get_run_history_size(self): + ''' Get the size of the run history ''' return len(self.runHistories) def clear_run_histories(self): + ''' Clear the run histories. ''' self.runHistories = {} # A one-stop function to handle saving. def save_run_histories_buffer_if_full(self, filename_out_prefix, max_per_file=1000, force_save=False): + '''One stop function for saving. + + If the histories buffer is full, saves to file and clears the buffer. + + :param filename_out_prefix: Name of the file to write to. + :param max_per_file: The max number of histories per file. Defaults to 1000. + :param force_save: Force the function to save, regardless of whether or not the buffer is full. Defaults to False. + ''' if ((self.get_run_history_size() >= max_per_file) or force_save): self.save_run_histories(filename_out_prefix) self.clear_run_histories() # Train/development/test sets def get_variations_train(self): + ''' Get the list of variations available for the training set. ''' return list(self.server.getVariationsTrain()) def get_variations_dev(self): + ''' Get the list of variations available for the development set. ''' return list(self.server.getVariationsDev()) def get_variations_test(self): + ''' Get the list of variations available for the testing set. ''' return list(self.server.getVariationsTest()) def get_random_variation_train(self): + ''' Get a single random variation from those available for the training set. ''' return self.server.getRandomVariationTrain() def get_random_variation_dev(self): + ''' Get a single random variation from those available for the development set. ''' return self.server.getRandomVariationDev() def get_random_variation_test(self): + ''' Get a single random variation from those available for the testing set. ''' return self.server.getRandomVariationTest() # Gold action sequence def get_gold_action_sequence(self): + '''Get the gold action sequence. + The gold action sequence is the optimal sequence of actions. This function returns that if it is generated. + If it is not generated, it generates an error. + ''' if (self.goldPathGenerated): return list(self.server.getGoldActionSequence()) else: @@ -301,6 +368,14 @@ def get_gold_action_sequence(self): # Step def step(self, input_str: str): + '''Take a step. + + This function takes one step in the typical state-action-reward cycle of RL. + :param input_str: The input string supplied to the simulator from an agent. + + Returns the observation, reward, completion status, and infos dict consisting of: + 'moves', 'score', 'reward', 'look', 'inv', 'taskDesc', 'valid', 'variationIdx', 'taskName', and 'simplificationStr'. + ''' observation = self.server.step(input_str) score = int(round(100 * self.server.getScore())) # Convert from 0-1 to 0-100 isCompleted = self.server.getCompleted() @@ -337,19 +412,23 @@ def step(self, input_str: str): # Special actions that are "free" (consume zero time) def look(self): + ''' Look around. This is a "free" action in that it consumes no time. ''' observation = self.server.freeActionLook() return observation def inventory(self): + ''' Check your inventory. This is a "free" action that consumes no time. ''' observation = self.server.freeActionInventory() return observation def taskdescription(self): + ''' Get the task description. This is a "free" action that consumes no time. ''' observation = self.server.freeActionTaskDesc() return observation # Goal progress def get_goal_progress(self): + ''' Get the progress to the goal. ''' goalStr = self.server.getGoalProgressStr() return goalStr diff --git a/scienceworld/utils.py b/scienceworld/utils.py index e917c4f1..a7692cb3 100644 --- a/scienceworld/utils.py +++ b/scienceworld/utils.py @@ -5,6 +5,8 @@ def infer_task(name_or_id): + ''' Takes a task name or task ID and processes it to produce a uniform task format. ''' + if name_or_id in NAME2ID: name_or_id = NAME2ID[name_or_id] diff --git a/simulator/src/main/scala/scienceworld/runtime/SimplifierProcessor.scala b/simulator/src/main/scala/scienceworld/runtime/SimplifierProcessor.scala index eec76d83..e265179c 100644 --- a/simulator/src/main/scala/scienceworld/runtime/SimplifierProcessor.scala +++ b/simulator/src/main/scala/scienceworld/runtime/SimplifierProcessor.scala @@ -46,7 +46,7 @@ class SimplificationOpenDoors extends Simplification(label = SIMPLIFICATION_OPEN } -// Simplification: Open all doors in the environment +// Simplification: Open all containers in the environment class SimplificationOpenContainers extends Simplification(label = SIMPLIFICATION_OPEN_CONTAINERS, description = "All containers are open by default.") { runAtInitialization = true From 0b3651f46cda36eb2225437d2087c10b58105be5 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer Date: Sun, 14 Jan 2024 15:01:30 -0500 Subject: [PATCH 02/17] Fix flake8 failures. --- scienceworld/scienceworld.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 41e83746..a8f9277e 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -19,7 +19,8 @@ class ScienceWorldEnv: def __init__(self, taskName=None, serverPath=None, envStepLimit=100): '''Start the simulator. Sets up the interface between python and the JVM. Also does basic init stuff. - :param taskName: The name of the task. Will be run through the infer_task method. Tasks can also be loaded by the load method. + :param taskName: The name of the task. Will be run through the infer_task method. + Tasks can also be loaded by the load method. :param serverPath: The filepath to the server. By default, it is just scienceworld.jar. :param envStepLimit: The maximum number of steps taken in the environment. Defaults to 100. ''' @@ -85,7 +86,8 @@ def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath= :param taskName: The name of the task. Will be modified by the infer_task function. :param variationIdx: The index for the specific variation to use. Default is 0. - :param simplificationStr: The string of simplifications to use. Should be comma separated with no spaces. Defaults to "". For more, see get_possible_simplifications + :param simplificationStr: The string of simplifications to use. Should be comma + separated with no spaces. Defaults to "". For more, see get_possible_simplifications :param generateGoldPath: Boolean var to generate gold path or not. Defaults to False. ''' # Check loading arguments. @@ -236,7 +238,9 @@ def get_possible_action_object_combinations(self): return (templates, lookUpTable) def get_object_types(self): - ''' Get a dict mapping object names to the object id. The object name is the name of the actual file, for example "scienceworld.objects.containers.furniture.Chair". ''' + '''Get a dict mapping object names to the object id. The object name is the name + of the actual file, for example "scienceworld.objects.containers.furniture.Chair". + ''' jsonStr = self.server.getObjectTypesLUTJSON() data = json.loads(jsonStr) return data @@ -321,10 +325,11 @@ def save_run_histories_buffer_if_full(self, filename_out_prefix, max_per_file=10 '''One stop function for saving. If the histories buffer is full, saves to file and clears the buffer. - + :param filename_out_prefix: Name of the file to write to. :param max_per_file: The max number of histories per file. Defaults to 1000. - :param force_save: Force the function to save, regardless of whether or not the buffer is full. Defaults to False. + :param force_save: Force the function to save, regardless of whether or not the + buffer is full. Defaults to False. ''' if ((self.get_run_history_size() >= max_per_file) or force_save): self.save_run_histories(filename_out_prefix) @@ -369,12 +374,13 @@ def get_gold_action_sequence(self): # Step def step(self, input_str: str): '''Take a step. - + This function takes one step in the typical state-action-reward cycle of RL. :param input_str: The input string supplied to the simulator from an agent. Returns the observation, reward, completion status, and infos dict consisting of: - 'moves', 'score', 'reward', 'look', 'inv', 'taskDesc', 'valid', 'variationIdx', 'taskName', and 'simplificationStr'. + 'moves', 'score', 'reward', 'look', 'inv', 'taskDesc', 'valid', 'variationIdx', 'taskName', + and 'simplificationStr'. ''' observation = self.server.step(input_str) score = int(round(100 * self.server.getScore())) # Convert from 0-1 to 0-100 From 1b15b5a8f414229c99f74150e99f686f1d3c0219 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer Date: Mon, 15 Jan 2024 16:46:53 -0500 Subject: [PATCH 03/17] Add type hinting in function returns --- scienceworld/scienceworld.py | 82 ++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index a8f9277e..249af1c1 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -16,7 +16,7 @@ class ScienceWorldEnv: Please look at that for more information on the internals of the system. """ - def __init__(self, taskName=None, serverPath=None, envStepLimit=100): + def __init__(self, taskName: str=None, serverPath: str=None, envStepLimit: int=100): '''Start the simulator. Sets up the interface between python and the JVM. Also does basic init stuff. :param taskName: The name of the task. Will be run through the infer_task method. @@ -78,7 +78,7 @@ def __init__(self, taskName=None, serverPath=None, envStepLimit=100): self.goldPathGenerated = False # Ask the simulator to load an environment from a script - def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath=False): + def load(self, taskName: str, variationIdx: int=0, simplificationStr: str="", generateGoldPath: bool=False) -> None: '''Load a valid task and its variations/simplifications, and set up the simulator and any task-specific properties (electrical, etc). Can optionally have the simulator generate a gold path. If it successfully does, it will set @@ -126,7 +126,7 @@ def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath= # Keep track of whether the gold path was generated, to generate verbose error messages self.goldPathGenerated = generateGoldPath - def reset(self): + def reset(self) -> tuple[str, dict[str]]: ''' Resets the simulator back to the first move (the output of "look around" is returned) ''' self.server.reset() @@ -141,11 +141,11 @@ def reset(self): return observation, info # Simplifications - def get_simplifications_used(self): + def get_simplifications_used(self) -> str: ''' Gets the simplifications being used by the simulator. ''' return self.server.getSimplificationsUsed() - def get_possible_simplifications(self): + def get_possible_simplifications(self) -> list[str]: '''Gets the 6 possible simplifications. There are 6 simplifictions: - teleportAction: Teleport action - selfWateringFlowerPots: Self-watering flower pots @@ -157,78 +157,78 @@ def get_possible_simplifications(self): return self.server.getPossibleSimplifications().split(", ") @property - def tasks(self): + def tasks(self) -> OrderedDict[str, str]: """ Get the supported tasks in ScienceWorld. """ return OrderedDict(ID2TASK) @property - def task_names(self): + def task_names(self) -> list[str]: ''' Get the name for the supported tasks in ScienceWorld. ''' return list(ID2TASK.values()) - def get_task_names(self): + def get_task_names(self) -> list[str]: ''' Get the name for the supported tasks in ScienceWorld. ''' return list(self.server.getTaskNames()) - def get_max_variations(self, task_name): + def get_max_variations(self, task_name) -> int: ''' Get the maximum number of variations for the tasks. ''' return self.server.getTaskMaxVariations(infer_task(task_name)) # Get possible actions - def get_possible_actions(self): + def get_possible_actions(self) -> list[str]: ''' Get all possible actions in the current environment state. ''' return list(self.server.getPossibleActions()) # Get possible actions (and also include the template IDs for those actions) - def get_possible_actions_with_IDs(self): + def get_possible_actions_with_IDs(self) -> list[dict[str]]: ''' Get a list of dictionaries that map "action_example" to the action template and "template_id" to the id.''' jsonStr = self.server.getPossibleActionsWithIDs() data = json.loads(jsonStr) return data - def get_possible_objects(self): + def get_possible_objects(self) -> list[str]: ''' Get a list of all observable objects ''' return list(self.server.getPossibleObjects()) # Get a list of object_ids to unique referents - def get_possible_object_referent_LUT(self): + def get_possible_object_referent_LUT(self) -> dict[str, str]: ''' Returns lookup table (dict) mapping object IDs to their referents. ''' jsonStr = self.server.getPossibleObjectReferentLUTJSON() data = json.loads(jsonStr) return data # As above, but dictionary is referenced by object type ID - def get_possible_object_referent_types_LUT(self): + def get_possible_object_referent_types_LUT(self) -> dict[str, dict[str, str]]: ''' Returns lookup table (dict) mapping object type IDs to a dict of all objects of that type. ''' jsonStr = self.server.getPossibleObjectReferentTypesLUTJSON() data = json.loads(jsonStr) return data - def get_valid_action_object_combinations(self): + def get_valid_action_object_combinations(self) -> list[str]: ''' Get a list of all of the *valid* action-object combinations. ''' return list(self.server.getValidActionObjectCombinations()) - def get_valid_action_object_combinations_with_templates(self): + def get_valid_action_object_combinations_with_templates(self) -> list[dict[str]] : ''' Returns list of dicts with keys "action", "template_id", and "obj_ids" ''' jsonStr = self.server.getValidActionObjectCombinationsJSON() data = json.loads(jsonStr) return data['validActions'] - def get_all_object_types_LUTJSON(self): + def get_all_object_types_LUTJSON(self) -> dict[str, str]: ''' Returns look up table mapping object ids to type ids ''' jsonStr = self.server.getAllObjectTypesLUTJSON() data = json.loads(jsonStr) return data # Get a LUT of {object_id: {type_id, referent:[]} } tuples - def get_all_object_ids_types_referents_LUTJSON(self): + def get_all_object_ids_types_referents_LUTJSON(self) -> dict[str, dict[str]]: ''' Returns look up table mapping object ids to objects with keys "type_id" and "referents" ''' jsonStr = self.server.getAllObjectIdsTypesReferentsLUTJSON() data = json.loads(jsonStr) return data # Get possible action/object combinations - def get_possible_action_object_combinations(self): + def get_possible_action_object_combinations(self) -> tuple[list[dict[str]], dict[str, str]]: ''' Get all *possible* action-object combinations, including invalid ones. ''' combinedJSON = self.server.getPossibleActionObjectCombinationsJSON() data = json.loads(combinedJSON) @@ -237,7 +237,7 @@ def get_possible_action_object_combinations(self): return (templates, lookUpTable) - def get_object_types(self): + def get_object_types(self) -> dict[str, int]: '''Get a dict mapping object names to the object id. The object name is the name of the actual file, for example "scienceworld.objects.containers.furniture.Chair". ''' @@ -245,7 +245,7 @@ def get_object_types(self): data = json.loads(jsonStr) return data - def get_vocabulary(self): + def get_vocabulary(self) -> set[str]: ''' Get all words that currently have some sort of meaning to the simulator. ''' vocab = set() @@ -260,23 +260,23 @@ def get_vocabulary(self): return vocab - def get_num_moves(self): + def get_num_moves(self) -> int: ''' Get the current number of moves. ''' return self.server.getNumMoves() - def get_task_description(self): + def get_task_description(self) -> str: ''' Get the description of the current task. ''' return self.server.getTaskDescription() # History - def get_run_history(self): + def get_run_history(self) -> dict[str]: ''' Get the run history ''' historyStr = self.server.getRunHistoryJSON() jsonOut = json.loads(historyStr) return jsonOut # History saving (provides an API to do this, so it's consistent across agents) - def store_run_history(self, episode_idx_key, notes): + def store_run_history(self, episode_idx_key, notes) -> None: '''Store the run history, with notes. :param episode_idx_key: Episode index. Will be used as key. @@ -290,7 +290,7 @@ def store_run_history(self, episode_idx_key, notes): self.runHistories[episode_idx_key] = packed - def save_run_histories(self, filename_out_prefix): + def save_run_histories(self, filename_out_prefix) -> None: '''Save the run histories to a file. :param filename_out_prefix: The name of the file to write to. @@ -312,16 +312,16 @@ def save_run_histories(self, filename_out_prefix): with open(filenameOut, 'w') as outfile: json.dump(self.runHistories, outfile, sort_keys=True, indent=4) - def get_run_history_size(self): + def get_run_history_size(self) -> int: ''' Get the size of the run history ''' return len(self.runHistories) - def clear_run_histories(self): + def clear_run_histories(self) -> None: ''' Clear the run histories. ''' self.runHistories = {} # A one-stop function to handle saving. - def save_run_histories_buffer_if_full(self, filename_out_prefix, max_per_file=1000, force_save=False): + def save_run_histories_buffer_if_full(self, filename_out_prefix, max_per_file=1000, force_save=False) -> None: '''One stop function for saving. If the histories buffer is full, saves to file and clears the buffer. @@ -336,32 +336,32 @@ def save_run_histories_buffer_if_full(self, filename_out_prefix, max_per_file=10 self.clear_run_histories() # Train/development/test sets - def get_variations_train(self): + def get_variations_train(self) -> list[int]: ''' Get the list of variations available for the training set. ''' return list(self.server.getVariationsTrain()) - def get_variations_dev(self): + def get_variations_dev(self) -> list[int]: ''' Get the list of variations available for the development set. ''' return list(self.server.getVariationsDev()) - def get_variations_test(self): + def get_variations_test(self) -> list[int]: ''' Get the list of variations available for the testing set. ''' return list(self.server.getVariationsTest()) - def get_random_variation_train(self): + def get_random_variation_train(self) -> int: ''' Get a single random variation from those available for the training set. ''' return self.server.getRandomVariationTrain() - def get_random_variation_dev(self): + def get_random_variation_dev(self) -> int: ''' Get a single random variation from those available for the development set. ''' return self.server.getRandomVariationDev() - def get_random_variation_test(self): + def get_random_variation_test(self) -> int: ''' Get a single random variation from those available for the testing set. ''' return self.server.getRandomVariationTest() # Gold action sequence - def get_gold_action_sequence(self): + def get_gold_action_sequence(self) -> list[str]: '''Get the gold action sequence. The gold action sequence is the optimal sequence of actions. This function returns that if it is generated. If it is not generated, it generates an error. @@ -372,7 +372,7 @@ def get_gold_action_sequence(self): return ["ERROR: Gold path was not generated. Set `generateGoldPath` flag to true when calling load()."] # Step - def step(self, input_str: str): + def step(self, input_str: str) -> tuple[str, int, bool, dict[str]]: '''Take a step. This function takes one step in the typical state-action-reward cycle of RL. @@ -417,23 +417,23 @@ def step(self, input_str: str): return observation, reward, isCompleted, infos # Special actions that are "free" (consume zero time) - def look(self): + def look(self) -> str: ''' Look around. This is a "free" action in that it consumes no time. ''' observation = self.server.freeActionLook() return observation - def inventory(self): + def inventory(self) -> str: ''' Check your inventory. This is a "free" action that consumes no time. ''' observation = self.server.freeActionInventory() return observation - def taskdescription(self): + def taskdescription(self) -> str: ''' Get the task description. This is a "free" action that consumes no time. ''' observation = self.server.freeActionTaskDesc() return observation # Goal progress - def get_goal_progress(self): + def get_goal_progress(self) -> str: ''' Get the progress to the goal. ''' goalStr = self.server.getGoalProgressStr() return goalStr From 21b0bb19411046d6a007598b6f816fa9eb4ae919 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer Date: Mon, 15 Jan 2024 16:54:00 -0500 Subject: [PATCH 04/17] Add type hinting to function parameters --- scienceworld/scienceworld.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 249af1c1..cd72476a 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -276,7 +276,7 @@ def get_run_history(self) -> dict[str]: return jsonOut # History saving (provides an API to do this, so it's consistent across agents) - def store_run_history(self, episode_idx_key, notes) -> None: + def store_run_history(self, episode_idx_key: int, notes: str) -> None: '''Store the run history, with notes. :param episode_idx_key: Episode index. Will be used as key. @@ -290,7 +290,7 @@ def store_run_history(self, episode_idx_key, notes) -> None: self.runHistories[episode_idx_key] = packed - def save_run_histories(self, filename_out_prefix) -> None: + def save_run_histories(self, filename_out_prefix: str) -> None: '''Save the run histories to a file. :param filename_out_prefix: The name of the file to write to. @@ -321,7 +321,7 @@ def clear_run_histories(self) -> None: self.runHistories = {} # A one-stop function to handle saving. - def save_run_histories_buffer_if_full(self, filename_out_prefix, max_per_file=1000, force_save=False) -> None: + def save_run_histories_buffer_if_full(self, filename_out_prefix: str, max_per_file: int=1000, force_save: bool=False) -> None: '''One stop function for saving. If the histories buffer is full, saves to file and clears the buffer. From 47d85f7a7b55f529ac5d52f98875d715bb146913 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer Date: Tue, 16 Jan 2024 08:20:58 -0500 Subject: [PATCH 05/17] Fix flake8 violations. Facepalm. --- scienceworld/scienceworld.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index cd72476a..627687d8 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -16,7 +16,7 @@ class ScienceWorldEnv: Please look at that for more information on the internals of the system. """ - def __init__(self, taskName: str=None, serverPath: str=None, envStepLimit: int=100): + def __init__(self, taskName: str = None, serverPath: str = None, envStepLimit: int = 100): '''Start the simulator. Sets up the interface between python and the JVM. Also does basic init stuff. :param taskName: The name of the task. Will be run through the infer_task method. @@ -78,7 +78,8 @@ def __init__(self, taskName: str=None, serverPath: str=None, envStepLimit: int=1 self.goldPathGenerated = False # Ask the simulator to load an environment from a script - def load(self, taskName: str, variationIdx: int=0, simplificationStr: str="", generateGoldPath: bool=False) -> None: + def load(self, taskName: str, variationIdx: int = 0, simplificationStr: str = "", + generateGoldPath: bool = False) -> None: '''Load a valid task and its variations/simplifications, and set up the simulator and any task-specific properties (electrical, etc). Can optionally have the simulator generate a gold path. If it successfully does, it will set @@ -208,7 +209,7 @@ def get_valid_action_object_combinations(self) -> list[str]: ''' Get a list of all of the *valid* action-object combinations. ''' return list(self.server.getValidActionObjectCombinations()) - def get_valid_action_object_combinations_with_templates(self) -> list[dict[str]] : + def get_valid_action_object_combinations_with_templates(self) -> list[dict[str]]: ''' Returns list of dicts with keys "action", "template_id", and "obj_ids" ''' jsonStr = self.server.getValidActionObjectCombinationsJSON() data = json.loads(jsonStr) @@ -321,7 +322,8 @@ def clear_run_histories(self) -> None: self.runHistories = {} # A one-stop function to handle saving. - def save_run_histories_buffer_if_full(self, filename_out_prefix: str, max_per_file: int=1000, force_save: bool=False) -> None: + def save_run_histories_buffer_if_full(self, filename_out_prefix: str, + max_per_file: int = 1000, force_save: bool = False) -> None: '''One stop function for saving. If the histories buffer is full, saves to file and clears the buffer. From 34066320d845400286e44f7d0bd30620e72eae2a Mon Sep 17 00:00:00 2001 From: Andrew Kaminer Date: Tue, 16 Jan 2024 08:34:30 -0500 Subject: [PATCH 06/17] Fix type hinting for python 3.8. Forgot that it's different --- scienceworld/scienceworld.py | 45 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 627687d8..fbf5a353 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -1,3 +1,4 @@ +from typing import List, Dict, Tuple, Set, Any import json import logging from collections import OrderedDict @@ -127,7 +128,7 @@ def load(self, taskName: str, variationIdx: int = 0, simplificationStr: str = "" # Keep track of whether the gold path was generated, to generate verbose error messages self.goldPathGenerated = generateGoldPath - def reset(self) -> tuple[str, dict[str]]: + def reset(self) -> Tuple[str, Dict[str, Any]]: ''' Resets the simulator back to the first move (the output of "look around" is returned) ''' self.server.reset() @@ -146,7 +147,7 @@ def get_simplifications_used(self) -> str: ''' Gets the simplifications being used by the simulator. ''' return self.server.getSimplificationsUsed() - def get_possible_simplifications(self) -> list[str]: + def get_possible_simplifications(self) -> List[str]: '''Gets the 6 possible simplifications. There are 6 simplifictions: - teleportAction: Teleport action - selfWateringFlowerPots: Self-watering flower pots @@ -163,11 +164,11 @@ def tasks(self) -> OrderedDict[str, str]: return OrderedDict(ID2TASK) @property - def task_names(self) -> list[str]: + def task_names(self) -> List[str]: ''' Get the name for the supported tasks in ScienceWorld. ''' return list(ID2TASK.values()) - def get_task_names(self) -> list[str]: + def get_task_names(self) -> List[str]: ''' Get the name for the supported tasks in ScienceWorld. ''' return list(self.server.getTaskNames()) @@ -176,60 +177,60 @@ def get_max_variations(self, task_name) -> int: return self.server.getTaskMaxVariations(infer_task(task_name)) # Get possible actions - def get_possible_actions(self) -> list[str]: + def get_possible_actions(self) -> List[str]: ''' Get all possible actions in the current environment state. ''' return list(self.server.getPossibleActions()) # Get possible actions (and also include the template IDs for those actions) - def get_possible_actions_with_IDs(self) -> list[dict[str]]: + def get_possible_actions_with_IDs(self) -> List[Dict[str, Any]]: ''' Get a list of dictionaries that map "action_example" to the action template and "template_id" to the id.''' jsonStr = self.server.getPossibleActionsWithIDs() data = json.loads(jsonStr) return data - def get_possible_objects(self) -> list[str]: + def get_possible_objects(self) -> List[str]: ''' Get a list of all observable objects ''' return list(self.server.getPossibleObjects()) # Get a list of object_ids to unique referents - def get_possible_object_referent_LUT(self) -> dict[str, str]: + def get_possible_object_referent_LUT(self) -> Dict[str, str]: ''' Returns lookup table (dict) mapping object IDs to their referents. ''' jsonStr = self.server.getPossibleObjectReferentLUTJSON() data = json.loads(jsonStr) return data # As above, but dictionary is referenced by object type ID - def get_possible_object_referent_types_LUT(self) -> dict[str, dict[str, str]]: + def get_possible_object_referent_types_LUT(self) -> Dict[str, Dict[str, str]]: ''' Returns lookup table (dict) mapping object type IDs to a dict of all objects of that type. ''' jsonStr = self.server.getPossibleObjectReferentTypesLUTJSON() data = json.loads(jsonStr) return data - def get_valid_action_object_combinations(self) -> list[str]: + def get_valid_action_object_combinations(self) -> List[str]: ''' Get a list of all of the *valid* action-object combinations. ''' return list(self.server.getValidActionObjectCombinations()) - def get_valid_action_object_combinations_with_templates(self) -> list[dict[str]]: + def get_valid_action_object_combinations_with_templates(self) -> List[Dict[str, Any]]: ''' Returns list of dicts with keys "action", "template_id", and "obj_ids" ''' jsonStr = self.server.getValidActionObjectCombinationsJSON() data = json.loads(jsonStr) return data['validActions'] - def get_all_object_types_LUTJSON(self) -> dict[str, str]: + def get_all_object_types_LUTJSON(self) -> Dict[str, str]: ''' Returns look up table mapping object ids to type ids ''' jsonStr = self.server.getAllObjectTypesLUTJSON() data = json.loads(jsonStr) return data # Get a LUT of {object_id: {type_id, referent:[]} } tuples - def get_all_object_ids_types_referents_LUTJSON(self) -> dict[str, dict[str]]: + def get_all_object_ids_types_referents_LUTJSON(self) -> Dict[str, Dict[str, Any]]: ''' Returns look up table mapping object ids to objects with keys "type_id" and "referents" ''' jsonStr = self.server.getAllObjectIdsTypesReferentsLUTJSON() data = json.loads(jsonStr) return data # Get possible action/object combinations - def get_possible_action_object_combinations(self) -> tuple[list[dict[str]], dict[str, str]]: + def get_possible_action_object_combinations(self) -> Tuple[List[Dict[str, Any]], Dict[str, str]]: ''' Get all *possible* action-object combinations, including invalid ones. ''' combinedJSON = self.server.getPossibleActionObjectCombinationsJSON() data = json.loads(combinedJSON) @@ -238,7 +239,7 @@ def get_possible_action_object_combinations(self) -> tuple[list[dict[str]], dict return (templates, lookUpTable) - def get_object_types(self) -> dict[str, int]: + def get_object_types(self) -> Dict[str, int]: '''Get a dict mapping object names to the object id. The object name is the name of the actual file, for example "scienceworld.objects.containers.furniture.Chair". ''' @@ -246,7 +247,7 @@ def get_object_types(self) -> dict[str, int]: data = json.loads(jsonStr) return data - def get_vocabulary(self) -> set[str]: + def get_vocabulary(self) -> Set[str]: ''' Get all words that currently have some sort of meaning to the simulator. ''' vocab = set() @@ -270,7 +271,7 @@ def get_task_description(self) -> str: return self.server.getTaskDescription() # History - def get_run_history(self) -> dict[str]: + def get_run_history(self) -> Dict[str, Any]: ''' Get the run history ''' historyStr = self.server.getRunHistoryJSON() jsonOut = json.loads(historyStr) @@ -338,15 +339,15 @@ def save_run_histories_buffer_if_full(self, filename_out_prefix: str, self.clear_run_histories() # Train/development/test sets - def get_variations_train(self) -> list[int]: + def get_variations_train(self) -> List[int]: ''' Get the list of variations available for the training set. ''' return list(self.server.getVariationsTrain()) - def get_variations_dev(self) -> list[int]: + def get_variations_dev(self) -> List[int]: ''' Get the list of variations available for the development set. ''' return list(self.server.getVariationsDev()) - def get_variations_test(self) -> list[int]: + def get_variations_test(self) -> List[int]: ''' Get the list of variations available for the testing set. ''' return list(self.server.getVariationsTest()) @@ -363,7 +364,7 @@ def get_random_variation_test(self) -> int: return self.server.getRandomVariationTest() # Gold action sequence - def get_gold_action_sequence(self) -> list[str]: + def get_gold_action_sequence(self) -> List[str]: '''Get the gold action sequence. The gold action sequence is the optimal sequence of actions. This function returns that if it is generated. If it is not generated, it generates an error. @@ -374,7 +375,7 @@ def get_gold_action_sequence(self) -> list[str]: return ["ERROR: Gold path was not generated. Set `generateGoldPath` flag to true when calling load()."] # Step - def step(self, input_str: str) -> tuple[str, int, bool, dict[str]]: + def step(self, input_str: str) -> Tuple[str, int, bool, Dict[str, Any]]: '''Take a step. This function takes one step in the typical state-action-reward cycle of RL. From fe59e3105d8f91ea88afb52f5f766d7f889c3400 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer Date: Tue, 16 Jan 2024 08:40:43 -0500 Subject: [PATCH 07/17] Forgot OrderedDict --- scienceworld/scienceworld.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index fbf5a353..31933a69 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Tuple, Set, Any +from typing import List, Dict, Tuple, Set, Any, OrderedDict import json import logging from collections import OrderedDict From 70afaa57d3c5aae1c042750e4e5a453778930f31 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer Date: Tue, 16 Jan 2024 08:46:03 -0500 Subject: [PATCH 08/17] Fixed naming conflict with OrderedDict --- scienceworld/scienceworld.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 31933a69..70359968 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -1,4 +1,5 @@ -from typing import List, Dict, Tuple, Set, Any, OrderedDict +from typing import List, Dict, Tuple, Set, Any +from typing import OrderedDict as OrderedDictType import json import logging from collections import OrderedDict @@ -159,7 +160,7 @@ def get_possible_simplifications(self) -> List[str]: return self.server.getPossibleSimplifications().split(", ") @property - def tasks(self) -> OrderedDict[str, str]: + def tasks(self) -> OrderedDictType[str, str]: """ Get the supported tasks in ScienceWorld. """ return OrderedDict(ID2TASK) From 4ce0e88207c78339d1731d579c1860c232208d30 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer <94922098+AndKaminer@users.noreply.github.com> Date: Tue, 23 Jan 2024 08:05:09 -0500 Subject: [PATCH 09/17] Update scienceworld/scienceworld.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marc-Alexandre Côté --- scienceworld/scienceworld.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 70359968..e2130f2e 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -149,7 +149,7 @@ def get_simplifications_used(self) -> str: return self.server.getSimplificationsUsed() def get_possible_simplifications(self) -> List[str]: - '''Gets the 6 possible simplifications. There are 6 simplifictions: + '''Gets the 6 possible simplifications. Those are: - teleportAction: Teleport action - selfWateringFlowerPots: Self-watering flower pots - openContainers: Containers open by default From 2acdc4379916b18d4128b7ff4711b660c7b85f7f Mon Sep 17 00:00:00 2001 From: Andrew Kaminer <94922098+AndKaminer@users.noreply.github.com> Date: Tue, 23 Jan 2024 08:05:16 -0500 Subject: [PATCH 10/17] Update scienceworld/scienceworld.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marc-Alexandre Côté --- scienceworld/scienceworld.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index e2130f2e..28c48caa 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -150,7 +150,7 @@ def get_simplifications_used(self) -> str: def get_possible_simplifications(self) -> List[str]: '''Gets the 6 possible simplifications. Those are: - - teleportAction: Teleport action + - teleportAction: Adds actions to teleport directly to any possible location - selfWateringFlowerPots: Self-watering flower pots - openContainers: Containers open by default - openDoors: Doors open by default From 1a46027dfaec7308a0721a411d11471feb88eb06 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer <94922098+AndKaminer@users.noreply.github.com> Date: Tue, 23 Jan 2024 08:05:23 -0500 Subject: [PATCH 11/17] Update scienceworld/scienceworld.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marc-Alexandre Côté --- scienceworld/scienceworld.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 28c48caa..71be2525 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -151,7 +151,7 @@ def get_simplifications_used(self) -> str: def get_possible_simplifications(self) -> List[str]: '''Gets the 6 possible simplifications. Those are: - teleportAction: Adds actions to teleport directly to any possible location - - selfWateringFlowerPots: Self-watering flower pots + - selfWateringFlowerPots: Flower pots will water themselves such that the plants won't die - openContainers: Containers open by default - openDoors: Doors open by default - noElectricalAction: Remove the electrical actions From 5cd8c92ca82e79ead5f2bcb310826ab3c44d1c52 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer <94922098+AndKaminer@users.noreply.github.com> Date: Tue, 23 Jan 2024 08:05:29 -0500 Subject: [PATCH 12/17] Update scienceworld/scienceworld.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marc-Alexandre Côté --- scienceworld/scienceworld.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 71be2525..8cc1ec14 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -152,7 +152,7 @@ def get_possible_simplifications(self) -> List[str]: '''Gets the 6 possible simplifications. Those are: - teleportAction: Adds actions to teleport directly to any possible location - selfWateringFlowerPots: Flower pots will water themselves such that the plants won't die - - openContainers: Containers open by default + - openContainers: Containers are open by default - openDoors: Doors open by default - noElectricalAction: Remove the electrical actions - easy: use all 5 simplifications From d2400482a0744adab049feea0454d384098a80b0 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer <94922098+AndKaminer@users.noreply.github.com> Date: Tue, 23 Jan 2024 08:05:35 -0500 Subject: [PATCH 13/17] Update scienceworld/scienceworld.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marc-Alexandre Côté --- scienceworld/scienceworld.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 8cc1ec14..70110d41 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -153,7 +153,7 @@ def get_possible_simplifications(self) -> List[str]: - teleportAction: Adds actions to teleport directly to any possible location - selfWateringFlowerPots: Flower pots will water themselves such that the plants won't die - openContainers: Containers are open by default - - openDoors: Doors open by default + - openDoors: Doors open are by default - noElectricalAction: Remove the electrical actions - easy: use all 5 simplifications ''' From ff93f5550f760a69e630f6281425532804d8885a Mon Sep 17 00:00:00 2001 From: Andrew Kaminer <94922098+AndKaminer@users.noreply.github.com> Date: Tue, 23 Jan 2024 08:05:43 -0500 Subject: [PATCH 14/17] Update scienceworld/scienceworld.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marc-Alexandre Côté --- scienceworld/scienceworld.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 70110d41..eb7fc43c 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -154,7 +154,7 @@ def get_possible_simplifications(self) -> List[str]: - selfWateringFlowerPots: Flower pots will water themselves such that the plants won't die - openContainers: Containers are open by default - openDoors: Doors open are by default - - noElectricalAction: Remove the electrical actions + - noElectricalAction: Remove all `connect X to Y` actions to reduce the action space - easy: use all 5 simplifications ''' return self.server.getPossibleSimplifications().split(", ") From 7251a1516308597a7bdbe724e2da531f4721bfd2 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer <94922098+AndKaminer@users.noreply.github.com> Date: Tue, 23 Jan 2024 08:05:50 -0500 Subject: [PATCH 15/17] Update scienceworld/scienceworld.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marc-Alexandre Côté --- scienceworld/scienceworld.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index eb7fc43c..130aba59 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -155,7 +155,7 @@ def get_possible_simplifications(self) -> List[str]: - openContainers: Containers are open by default - openDoors: Doors open are by default - noElectricalAction: Remove all `connect X to Y` actions to reduce the action space - - easy: use all 5 simplifications + - easy: use all above simplifications ''' return self.server.getPossibleSimplifications().split(", ") From fe5ce2af9cbe0c68d0ced895c945dcd08d97e433 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer <94922098+AndKaminer@users.noreply.github.com> Date: Tue, 23 Jan 2024 08:06:11 -0500 Subject: [PATCH 16/17] Update scienceworld/scienceworld.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Marc-Alexandre Côté --- scienceworld/scienceworld.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 130aba59..e86363b1 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -367,7 +367,7 @@ def get_random_variation_test(self) -> int: # Gold action sequence def get_gold_action_sequence(self) -> List[str]: '''Get the gold action sequence. - The gold action sequence is the optimal sequence of actions. This function returns that if it is generated. + The gold action sequence is a sequence of actions that leads to a winning state (there is no guarantee it is the optimal). This function returns that if it is generated. If it is not generated, it generates an error. ''' if (self.goldPathGenerated): From 11d5f18ecb3d7ab05b6772145d4df7dc6aa6d5c2 Mon Sep 17 00:00:00 2001 From: Andrew Kaminer Date: Tue, 23 Jan 2024 09:43:36 -0500 Subject: [PATCH 17/17] Fix flake8 --- scienceworld/scienceworld.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index e86363b1..c8f67a88 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -367,7 +367,8 @@ def get_random_variation_test(self) -> int: # Gold action sequence def get_gold_action_sequence(self) -> List[str]: '''Get the gold action sequence. - The gold action sequence is a sequence of actions that leads to a winning state (there is no guarantee it is the optimal). This function returns that if it is generated. + The gold action sequence is a sequence of actions that leads to a winning state + (there is no guarantee it is the optimal). This function returns that if it is generated. If it is not generated, it generates an error. ''' if (self.goldPathGenerated):