Skip to content

Add docstrings #56 #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 142 additions & 53 deletions scienceworld/scienceworld.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Dict, Tuple, Set, Any
from typing import OrderedDict as OrderedDictType
import json
import logging
from collections import OrderedDict
Expand All @@ -11,8 +13,19 @@


class ScienceWorldEnv:

def __init__(self, taskName=None, serverPath=None, envStepLimit=100):
"""Python wrapper for the simulator written in Scala. The methods that are
being wrapped can be found in simulator/src/main/scala/scienceworld/runtime/AgentInterface.scala.
Please look at that for more information on the internals of the system.
"""

def __init__(self, taskName: str = None, serverPath: str = None, envStepLimit: int = 100):
'''Start the simulator. Sets up the interface between python and the JVM.
Also does basic init stuff.
:param taskName: The name of the task. Will be run through the infer_task method.
Tasks can also be loaded by the load method.
:param serverPath: The filepath to the server. By default, it is just scienceworld.jar.
:param envStepLimit: The maximum number of steps taken in the environment. Defaults to 100.
'''
serverPath = serverPath or JAR_PATH # Use the builtin jar.

# Launch the server and connect to the JVM.
Expand Down Expand Up @@ -67,9 +80,19 @@ def __init__(self, taskName=None, serverPath=None, envStepLimit=100):
self.goldPathGenerated = False

# Ask the simulator to load an environment from a script
def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath=False):
""" Load a given task and its variation. """

def load(self, taskName: str, variationIdx: int = 0, simplificationStr: str = "",
generateGoldPath: bool = False) -> None:
'''Load a valid task and its variations/simplifications, and set up the simulator
and any task-specific properties (electrical, etc). Can optionally have the
simulator generate a gold path. If it successfully does, it will set
self.goldPathGenerated to True.

:param taskName: The name of the task. Will be modified by the infer_task function.
:param variationIdx: The index for the specific variation to use. Default is 0.
:param simplificationStr: The string of simplifications to use. Should be comma
separated with no spaces. Defaults to "". For more, see get_possible_simplifications
:param generateGoldPath: Boolean var to generate gold path or not. Defaults to False.
'''
# Check loading arguments.
# Validate task name.
taskName = infer_task(taskName)
Expand Down Expand Up @@ -106,8 +129,9 @@ def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath=
# Keep track of whether the gold path was generated, to generate verbose error messages
self.goldPathGenerated = generateGoldPath

# Ask the simulator to reset an environment back to its initial state
def reset(self):
def reset(self) -> Tuple[str, Dict[str, Any]]:
''' Resets the simulator back to the first move (the output of "look around" is returned) '''

self.server.reset()

# Reset last step score (used to calculate reward from current-previous score)
Expand All @@ -120,94 +144,112 @@ def reset(self):
return observation, info

# Simplifications
def get_simplifications_used(self):
def get_simplifications_used(self) -> str:
''' Gets the simplifications being used by the simulator. '''
return self.server.getSimplificationsUsed()

def get_possible_simplifications(self):
def get_possible_simplifications(self) -> List[str]:
'''Gets the 6 possible simplifications. There are 6 simplifications:
- teleportAction: Teleport action
- selfWateringFlowerPots: Self-watering flower pots
- openContainers: Containers open by default
- openDoors: Doors open by default
- noElectricalAction: Remove the electrical actions
- easy: use all 5 simplifications
'''
return self.server.getPossibleSimplifications().split(", ")

@property
def tasks(self):
def tasks(self) -> OrderedDictType[str, str]:
""" Get the supported tasks in ScienceWorld. """
return OrderedDict(ID2TASK)

@property
def task_names(self):
""" Get the name for the supported tasks in ScienceWorld. """
def task_names(self) -> List[str]:
''' Get the name for the supported tasks in ScienceWorld. '''
return list(ID2TASK.values())

def get_task_names(self):
""" Get the name for the supported tasks in ScienceWorld. """
def get_task_names(self) -> List[str]:
''' Get the name for the supported tasks in ScienceWorld. '''
return list(self.server.getTaskNames())

# Get the maximum number of variations for this task
def get_max_variations(self, task_name):
def get_max_variations(self, task_name) -> int:
''' Get the maximum number of variations for the given task. '''
return self.server.getTaskMaxVariations(infer_task(task_name))

# Get possible actions
def get_possible_actions(self):
def get_possible_actions(self) -> List[str]:
''' Get all possible actions in the current environment state. '''
return list(self.server.getPossibleActions())

# Get possible actions (and also include the template IDs for those actions)
def get_possible_actions_with_IDs(self):
def get_possible_actions_with_IDs(self) -> List[Dict[str, Any]]:
''' Get a list of dictionaries that map "action_example" to the action template and "template_id" to the id.'''
jsonStr = self.server.getPossibleActionsWithIDs()
data = json.loads(jsonStr)
return data

# Get possible objects
def get_possible_objects(self):
def get_possible_objects(self) -> List[str]:
''' Get a list of all observable objects '''
return list(self.server.getPossibleObjects())

# Get a list of object_ids to unique referents
def get_possible_object_referent_LUT(self):
def get_possible_object_referent_LUT(self) -> Dict[str, str]:
''' Returns lookup table (dict) mapping object IDs to their referents. '''
jsonStr = self.server.getPossibleObjectReferentLUTJSON()
data = json.loads(jsonStr)
return data

# As above, but dictionary is referenced by object type ID
def get_possible_object_referent_types_LUT(self):
def get_possible_object_referent_types_LUT(self) -> Dict[str, Dict[str, str]]:
''' Returns lookup table (dict) mapping object type IDs to a dict of all objects of that type. '''
jsonStr = self.server.getPossibleObjectReferentTypesLUTJSON()
data = json.loads(jsonStr)
return data

# Get a list of *valid* agent-object combinations
def get_valid_action_object_combinations(self):
def get_valid_action_object_combinations(self) -> List[str]:
''' Get a list of all of the *valid* action-object combinations. '''
return list(self.server.getValidActionObjectCombinations())

def get_valid_action_object_combinations_with_templates(self):
def get_valid_action_object_combinations_with_templates(self) -> List[Dict[str, Any]]:
''' Returns list of dicts with keys "action", "template_id", and "obj_ids" '''
jsonStr = self.server.getValidActionObjectCombinationsJSON()
data = json.loads(jsonStr)
return data['validActions']

# Get a LUT of object_id to type_id
def get_all_object_types_LUTJSON(self):
def get_all_object_types_LUTJSON(self) -> Dict[str, str]:
''' Returns look up table mapping object ids to type ids '''
jsonStr = self.server.getAllObjectTypesLUTJSON()
data = json.loads(jsonStr)
return data

# Get a LUT of {object_id: {type_id, referent:[]} } tuples
def get_all_object_ids_types_referents_LUTJSON(self):
def get_all_object_ids_types_referents_LUTJSON(self) -> Dict[str, Dict[str, Any]]:
''' Returns look up table mapping object ids to objects with keys "type_id" and "referents" '''
jsonStr = self.server.getAllObjectIdsTypesReferentsLUTJSON()
data = json.loads(jsonStr)
return data

# Get possible action/object combinations
def get_possible_action_object_combinations(self):
def get_possible_action_object_combinations(self) -> Tuple[List[Dict[str, Any]], Dict[str, str]]:
''' Get all *possible* action-object combinations, including invalid ones. '''
combinedJSON = self.server.getPossibleActionObjectCombinationsJSON()
data = json.loads(combinedJSON)
templates = data['templates']
lookUpTable = data['lookUpTable']

return (templates, lookUpTable)

# Get a list of object types and their IDs
def get_object_types(self):
def get_object_types(self) -> Dict[str, int]:
'''Get a dict mapping object names to the object id. The object name is the name
of the actual file, for example "scienceworld.objects.containers.furniture.Chair".
'''
jsonStr = self.server.getObjectTypesLUTJSON()
data = json.loads(jsonStr)
return data

# Get the vocabulary of the model (at the current state)
def get_vocabulary(self):
def get_vocabulary(self) -> Set[str]:
''' Get all words that currently have some sort of meaning to the simulator. '''
vocab = set()

# Action vocabulary
Expand All @@ -221,20 +263,28 @@ def get_vocabulary(self):

return vocab

def get_num_moves(self):
def get_num_moves(self) -> int:
''' Get the current number of moves. '''
return self.server.getNumMoves()

def get_task_description(self):
def get_task_description(self) -> str:
''' Get the description of the current task. '''
return self.server.getTaskDescription()

# History
def get_run_history(self):
def get_run_history(self) -> Dict[str, Any]:
''' Get the run history '''
historyStr = self.server.getRunHistoryJSON()
jsonOut = json.loads(historyStr)
return jsonOut

# History saving (provides an API to do this, so it's consistent across agents)
def store_run_history(self, episode_idx_key, notes):
def store_run_history(self, episode_idx_key: int, notes: str) -> None:
'''Store the run history, with notes.

:param episode_idx_key: Episode index. Will be used as key.
:param notes: Notes on the run.
'''
packed = {
'episodeIdx': episode_idx_key,
'notes': notes,
Expand All @@ -243,7 +293,11 @@ def store_run_history(self, episode_idx_key, notes):

self.runHistories[episode_idx_key] = packed

def save_run_histories(self, filename_out_prefix):
def save_run_histories(self, filename_out_prefix: str) -> None:
'''Save the run histories to a file.

:param filename_out_prefix: The prefix of the output filename; a more verbose filename is built from it.
'''
# Save history

# Create verbose filename
Expand All @@ -261,46 +315,77 @@ def save_run_histories(self, filename_out_prefix):
with open(filenameOut, 'w') as outfile:
json.dump(self.runHistories, outfile, sort_keys=True, indent=4)

def get_run_history_size(self):
def get_run_history_size(self) -> int:
''' Get the size of the run history '''
return len(self.runHistories)

def clear_run_histories(self):
def clear_run_histories(self) -> None:
''' Clear the run histories. '''
self.runHistories = {}

# A one-stop function to handle saving.
def save_run_histories_buffer_if_full(self, filename_out_prefix, max_per_file=1000, force_save=False):
def save_run_histories_buffer_if_full(self, filename_out_prefix: str,
max_per_file: int = 1000, force_save: bool = False) -> None:
'''One stop function for saving.

If the histories buffer is full, saves to file and clears the buffer.

:param filename_out_prefix: Prefix of the filename to write to (passed on to save_run_histories).
:param max_per_file: The max number of histories per file. Defaults to 1000.
:param force_save: Force the function to save, regardless of whether or not the
buffer is full. Defaults to False.
'''
if ((self.get_run_history_size() >= max_per_file) or force_save):
self.save_run_histories(filename_out_prefix)
self.clear_run_histories()

# Train/development/test sets
def get_variations_train(self):
def get_variations_train(self) -> List[int]:
''' Get the list of variations available for the training set. '''
return list(self.server.getVariationsTrain())

def get_variations_dev(self):
def get_variations_dev(self) -> List[int]:
''' Get the list of variations available for the development set. '''
return list(self.server.getVariationsDev())

def get_variations_test(self):
def get_variations_test(self) -> List[int]:
''' Get the list of variations available for the testing set. '''
return list(self.server.getVariationsTest())

def get_random_variation_train(self):
def get_random_variation_train(self) -> int:
''' Get a single random variation from those available for the training set. '''
return self.server.getRandomVariationTrain()

def get_random_variation_dev(self):
def get_random_variation_dev(self) -> int:
''' Get a single random variation from those available for the development set. '''
return self.server.getRandomVariationDev()

def get_random_variation_test(self):
def get_random_variation_test(self) -> int:
''' Get a single random variation from those available for the testing set. '''
return self.server.getRandomVariationTest()

# Gold action sequence
def get_gold_action_sequence(self):
def get_gold_action_sequence(self) -> List[str]:
'''Get the gold action sequence.
The gold action sequence is the optimal sequence of actions. This function returns that if it is generated.
If it was not generated, a one-element list containing an error message is returned instead.
'''
if (self.goldPathGenerated):
return list(self.server.getGoldActionSequence())
else:
return ["ERROR: Gold path was not generated. Set `generateGoldPath` flag to true when calling load()."]

# Step
def step(self, input_str: str):
def step(self, input_str: str) -> Tuple[str, int, bool, Dict[str, Any]]:
'''Take a step.

This function takes one step in the typical state-action-reward cycle of RL.
:param input_str: The input string supplied to the simulator from an agent.

Returns the observation, reward, completion status, and infos dict consisting of:
'moves', 'score', 'reward', 'look', 'inv', 'taskDesc', 'valid', 'variationIdx', 'taskName',
and 'simplificationStr'.
'''
observation = self.server.step(input_str)
score = int(round(100 * self.server.getScore())) # Convert from 0-1 to 0-100
isCompleted = self.server.getCompleted()
Expand Down Expand Up @@ -336,20 +421,24 @@ def step(self, input_str: str):
return observation, reward, isCompleted, infos

# Special actions that are "free" (consume zero time)
def look(self):
def look(self) -> str:
''' Look around. This is a "free" action in that it consumes no time. '''
observation = self.server.freeActionLook()
return observation

def inventory(self):
def inventory(self) -> str:
''' Check your inventory. This is a "free" action that consumes no time. '''
observation = self.server.freeActionInventory()
return observation

def taskdescription(self):
def taskdescription(self) -> str:
''' Get the task description. This is a "free" action that consumes no time. '''
observation = self.server.freeActionTaskDesc()
return observation

# Goal progress
def get_goal_progress(self):
def get_goal_progress(self) -> str:
''' Get the progress to the goal. '''
goalStr = self.server.getGoalProgressStr()
return goalStr

Expand Down
2 changes: 2 additions & 0 deletions scienceworld/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@


def infer_task(name_or_id):
''' Takes a task name or task ID and processes it to produce a uniform task format. '''

if name_or_id in NAME2ID:
name_or_id = NAME2ID[name_or_id]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class SimplificationOpenDoors extends Simplification(label = SIMPLIFICATION_OPEN
}


// Simplification: Open all doors in the environment
// Simplification: Open all containers in the environment
class SimplificationOpenContainers extends Simplification(label = SIMPLIFICATION_OPEN_CONTAINERS, description = "All containers are open by default.") {
runAtInitialization = true

Expand Down