diff --git a/examples/human.py b/examples/human.py
index 067a7ef..e300b89 100644
--- a/examples/human.py
+++ b/examples/human.py
@@ -13,7 +13,7 @@ def userConsole(args):
# Initialize environment
env = ScienceWorldEnv("", args['jar_path'], envStepLimit=args['env_step_limit'])
- taskNames = env.getTaskNames()
+ taskNames = env.get_task_names()
print("Task Names: " + str(taskNames))
# Choose task
@@ -30,41 +30,41 @@ def userConsole(args):
# (Many of these are similar to the Jericho API)
#
print("Task Names: " + str(taskNames))
- print("Possible actions: " + str(env.getPossibleActions()))
- print("Possible objects: " + str(env.getPossibleObjects()))
- templates, lut = env.getPossibleActionObjectCombinations()
+ print("Possible actions: " + str(env.get_possible_actions()))
+ print("Possible objects: " + str(env.get_possible_objects()))
+ templates, lut = env.get_possible_action_object_combinations()
print("Possible action/object combinations: " + str(templates))
# print("Object IDX to Object Referent LUT: " + str(lut))
- print("Vocabulary: " + str(env.getVocabulary()))
- print("Possible actions (with IDs): " + str(env.getPossibleActionsWithIDs()))
- print("Possible object types: " + str(env.getObjectTypes()))
+ print("Vocabulary: " + str(env.get_vocabulary()))
+ print("Possible actions (with IDs): " + str(env.get_possible_actions_with_IDs()))
+ print("Possible object types: " + str(env.get_object_types()))
print("Object IDX to Object Referent LUT: " + str(lut))
print("\n")
- print("Possible object referents LUT: " + str(env.getPossibleObjectReferentLUT()))
+ print("Possible object referents LUT: " + str(env.get_possible_object_referent_LUT()))
print("\n")
print("Valid action-object combinations: " +
- str(env.getValidActionObjectCombinations()))
+ str(env.get_valid_action_object_combinations()))
print("\n")
- print("Object_ids to type_ids: " + str(env.getAllObjectTypesLUTJSON()))
+ print("Object_ids to type_ids: " + str(env.get_all_object_types_LUTJSON()))
print("\n")
print("All objects, their ids, types, and referents: " +
- str(env.getAllObjectIdsTypesReferentsLUTJSON()))
+ str(env.get_all_object_ids_types_referents_LUTJSON()))
print("\n")
print("Valid action-object combinations (with templates): " +
- str(env.getValidActionObjectCombinationsWithTemplates()))
+ str(env.get_valid_action_object_combinations_with_templates()))
print("\n")
- print("Object Type LUT: " + str(env.getPossibleObjectReferentTypesLUT()))
- print("Variations (train): " + str(env.getVariationsTrain()))
+ print("Object Type LUT: " + str(env.get_possible_object_referent_types_LUT()))
+ print("Variations (train): " + str(env.get_variations_train()))
print("")
print("----------------------------------------------------------------------------------")
print("")
- print("Gold Path:" + str(env.getGoldActionSequence()))
+ print("Gold Path:" + str(env.get_gold_action_sequence()))
print("Task Name: " + taskName)
- print("Variation: " + str(args['var_num']) + " / " + str(env.getMaxVariations(taskName)))
- print("Task Description: " + str(env.getTaskDescription()))
+ print("Variation: " + str(args['var_num']) + " / " + str(env.get_max_variations(taskName)))
+ print("Task Description: " + str(env.get_task_description()))
#
# Main user input loop
@@ -73,20 +73,20 @@ def userConsole(args):
while (userInputStr not in exitCommands):
if (userInputStr == "help"):
print("Possible actions: ")
- for actionStr in env.getPossibleActions():
+ for actionStr in env.get_possible_actions():
print("\t" + str(actionStr))
elif (userInputStr == "objects"):
print("Possible objects (one referent listed per object): ")
- for actionStr in env.getPossibleObjects():
+ for actionStr in env.get_possible_objects():
print("\t" + str(actionStr))
elif (userInputStr == "valid"):
print("Valid action-object combinations:")
- print(env.getValidActionObjectCombinationsWithTemplates())
+ print(env.get_valid_action_object_combinations_with_templates())
elif (userInputStr == 'goals'):
- print(env.getGoalProgressStr())
+ print(env.get_goal_progress())
else:
# Send user input, get response
@@ -109,7 +109,7 @@ def userConsole(args):
userInputStr = userInputStr.lower().strip()
# Display run history
- runHistory = env.getRunHistory()
+ runHistory = env.get_run_history()
print("Run History:")
print(runHistory)
for item in runHistory:
@@ -117,7 +117,7 @@ def userConsole(args):
print("")
# Display subgoal progress
- print(env.getGoalProgressStr())
+ print(env.get_goal_progress())
print("Completed.")
diff --git a/examples/random_agent.py b/examples/random_agent.py
index 1bbf38f..bf81652 100644
--- a/examples/random_agent.py
+++ b/examples/random_agent.py
@@ -19,35 +19,35 @@ def randomModel(args):
# Initialize environment
env = ScienceWorldEnv("", args['jar_path'], envStepLimit=args['env_step_limit'])
- taskNames = env.getTaskNames()
+ taskNames = env.get_task_names()
print("Task Names: " + str(taskNames))
# Choose task
taskName = taskNames[taskIdx] # Just get first task
# Load the task, we we have access to some extra accessors e.g. get_random_variation_train()
env.load(taskName, 0, "")
- maxVariations = env.getMaxVariations(taskName)
+ maxVariations = env.get_max_variations(taskName)
print("Starting Task " + str(taskIdx) + ": " + taskName)
time.sleep(2)
# Start running episodes
for episodeIdx in range(0, numEpisodes):
# Pick a random task variation
- randVariationIdx = env.getRandomVariationTrain()
+ randVariationIdx = env.get_random_variation_train()
env.load(taskName, randVariationIdx, simplificationStr)
# Reset the environment
initialObs, initialDict = env.reset()
# Example accessors
- print("Possible actions: " + str(env.getPossibleActions()))
- print("Possible objects: " + str(env.getPossibleObjects()))
- templates, lut = env.getPossibleActionObjectCombinations()
+ print("Possible actions: " + str(env.get_possible_actions()))
+ print("Possible objects: " + str(env.get_possible_objects()))
+ templates, lut = env.get_possible_action_object_combinations()
print("Possible action/object combinations: " + str(templates))
print("Object IDX to Object Referent LUT: " + str(lut))
print("Task Name: " + taskName)
print("Task Variation: " + str(randVariationIdx) + " / " + str(maxVariations))
- print("Task Description: " + str(env.getTaskDescription()))
+ print("Task Description: " + str(env.get_task_description()))
print("look: " + str(env.look()))
print("inventory: " + str(env.inventory()))
print("taskdescription: " + str(env.taskdescription()))
@@ -79,7 +79,7 @@ def randomModel(args):
# Randomly select action
# Any action (valid or not)
- # templates, lut = env.getPossibleActionObjectCombinations()
+ # templates, lut = env.get_possible_action_object_combinations()
# print("Possible action/object combinations: " + str(templates))
# print("Object IDX to Object Referent LUT: " + str(lut))
# randomTemplate = random.choice( templates )
@@ -87,7 +87,7 @@ def randomModel(args):
# userInputStr = randomTemplate["action"]
# Only valid actions
- validActions = env.getValidActionObjectCombinationsWithTemplates()
+ validActions = env.get_valid_action_object_combinations_with_templates()
randomAction = random.choice(validActions)
print("Next random action: " + str(randomAction))
userInputStr = randomAction["action"]
@@ -102,7 +102,7 @@ def randomModel(args):
curIter += 1
print("Goal Progress:")
- print(env.getGoalProgressStr())
+ print(env.get_goal_progress())
time.sleep(1)
# Episode finished -- Record the final score
@@ -114,11 +114,11 @@ def randomModel(args):
# Save history -- and when we reach maxPerFile, export them to file
filenameOutPrefix = args['output_path_prefix'] + str(taskIdx)
- env.storeRunHistory(episodeIdx, notes={'text': 'my notes here'})
- env.saveRunHistoriesBufferIfFull(filenameOutPrefix, maxPerFile=args['max_episode_per_file'])
+ env.store_run_history(episodeIdx, notes={'text': 'my notes here'})
+ env.save_run_histories_buffer_if_full(filenameOutPrefix, max_per_file=args['max_episode_per_file'])
# Episodes are finished -- manually save any last histories still in the buffer
- env.saveRunHistoriesBufferIfFull(filenameOutPrefix, maxPerFile=args['max_episode_per_file'], forceSave=True)
+ env.save_run_histories_buffer_if_full(filenameOutPrefix, max_per_file=args['max_episode_per_file'], force_save=True)
# Show final episode scores to user
# Clip negative scores to 0 for average calculation
diff --git a/examples/scienceworld-web-server-example.py b/examples/scienceworld-web-server-example.py
index b5258af..a6952ae 100644
--- a/examples/scienceworld-web-server-example.py
+++ b/examples/scienceworld-web-server-example.py
@@ -101,8 +101,8 @@ def app():
htmlLog.addHeading("Science World (Text Simulation)")
htmlLog.addHorizontalRule()
- taskName = pywebio.input.select("Select a task:", env.getTaskNames())
- maxVariations = env.getMaxVariations(taskName)
+ taskName = pywebio.input.select("Select a task:", env.get_task_names())
+ maxVariations = env.get_max_variations(taskName)
# variationIdx = slider("Task Variation: ", min_value=0, max_value=(maxVariations-1))
variationIdx = pywebio.input.input('Enter the task variation (min = 0, max = ' + str(maxVariations) + "):")
@@ -118,11 +118,11 @@ def app():
# print("Possible action/object combinations: " + str(env.getPossibleActionObjectCombinations()))
pywebio_out.put_table([
- ["Task", env.getTaskDescription()],
+ ["Task", env.get_task_description()],
["Variation", str(variationIdx) + " / " + str(maxVariations)]
])
- htmlLog.addStr("Task: " + env.getTaskDescription() + "<br>")
+ htmlLog.addStr("Task: " + env.get_task_description() + "<br>")
htmlLog.addStr("Variation: " + str(variationIdx) + "<br>")
htmlLog.addHorizontalRule()
@@ -172,7 +172,7 @@ def app():
'isCompleted': isCompleted,
'userInput': userInputStr,
'taskName': taskName,
- 'taskDescription': env.getTaskDescription(),
+ 'taskDescription': env.get_task_description(),
'look': env.look(),
'inventory': env.inventory(),
'variationIdx': variationIdx,
diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py
index 2427078..995b4e3 100644
--- a/scienceworld/scienceworld.py
+++ b/scienceworld/scienceworld.py
@@ -61,7 +61,7 @@ def __init__(self, taskName=None, serverPath=None, envStepLimit=100):
self.envStepLimit = envStepLimit
# Clear the run histories
- self.clearRunHistories()
+ self.clear_run_histories()
# By default, set that the gold path was not generated unless the user asked for it
self.goldPathGenerated = False
@@ -73,15 +73,15 @@ def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath=
# Check loading arguments.
# Validate task name.
taskName = infer_task(taskName)
- if taskName not in self.getTaskNames():
+ if taskName not in self.get_task_names():
msg = "Unknown taskName: '{}'. ".format(taskName)
- msg += "Supported tasks are: {}".format(self.getTaskNames())
+ msg += "Supported tasks are: {}".format(self.get_task_names())
raise ValueError(msg)
self.taskName = taskName
# Validate simplification string.
- possible_simplifications = ["easy"] + self.getPossibleSimplifications()
+ possible_simplifications = ["easy"] + self.get_possible_simplifications()
for simplification in simplificationStr.split(","):
if simplification and simplification not in possible_simplifications:
msg = "Unknown simplification: '{}'. ".format(simplification)
@@ -349,7 +349,7 @@ def taskdescription(self):
return observation
# Goal progress
- def get_goal_progress_str(self):
+ def get_goal_progress(self):
goalStr = self.server.getGoalProgressStr()
return goalStr
diff --git a/scienceworld/utils.py b/scienceworld/utils.py
index 2b5f427..e917c4f 100644
--- a/scienceworld/utils.py
+++ b/scienceworld/utils.py
@@ -23,4 +23,4 @@ def infer_task(name_or_id):
def snake_case_deprecation_warning():
message = "You are using the camel case api. This feature is deprecated. Please migrate to the snake_case api."
formatted_message = f"\033[91m {message} \033[00m"
- warnings.warn(formatted_message, UserWarning, stacklevel=2)
+ warnings.warn(formatted_message, UserWarning, stacklevel=3)
diff --git a/tests/test_scienceworld.py b/tests/test_scienceworld.py
index 193087d..839ca34 100644
--- a/tests/test_scienceworld.py
+++ b/tests/test_scienceworld.py
@@ -47,11 +47,11 @@ def test_multiple_instances():
def test_variation_sets_are_disjoint():
env = ScienceWorldEnv()
- for task in env.getTaskNames():
+ for task in env.get_task_names():
env.load(task)
- train = set(env.getVariationsTrain())
- dev = set(env.getVariationsDev())
- test = set(env.getVariationsTest())
+ train = set(env.get_variations_train())
+ dev = set(env.get_variations_dev())
+ test = set(env.get_variations_test())
assert set.isdisjoint(train, dev)
assert set.isdisjoint(train, test)
assert set.isdisjoint(dev, test)
@@ -119,4 +119,4 @@ def test_load():
def test_consistent_task_names():
"""Verify that Scala and Python code use the same task names."""
env = ScienceWorldEnv()
- assert sorted(env.task_names) == sorted(env.getTaskNames())
+ assert sorted(env.task_names) == sorted(env.get_task_names())