Skip to content

Get rid of deprecation warnings in tests and examples #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions examples/human.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def userConsole(args):

# Initialize environment
env = ScienceWorldEnv("", args['jar_path'], envStepLimit=args['env_step_limit'])
taskNames = env.getTaskNames()
taskNames = env.get_task_names()
print("Task Names: " + str(taskNames))

# Choose task
Expand All @@ -30,41 +30,41 @@ def userConsole(args):
# (Many of these are similar to the Jericho API)
#
print("Task Names: " + str(taskNames))
print("Possible actions: " + str(env.getPossibleActions()))
print("Possible objects: " + str(env.getPossibleObjects()))
templates, lut = env.getPossibleActionObjectCombinations()
print("Possible actions: " + str(env.get_possible_actions()))
print("Possible objects: " + str(env.get_possible_objects()))
templates, lut = env.get_possible_action_object_combinations()
print("Possible action/object combinations: " + str(templates))
# print("Object IDX to Object Referent LUT: " + str(lut))
print("Vocabulary: " + str(env.getVocabulary()))
print("Possible actions (with IDs): " + str(env.getPossibleActionsWithIDs()))
print("Possible object types: " + str(env.getObjectTypes()))
print("Vocabulary: " + str(env.get_vocabulary()))
print("Possible actions (with IDs): " + str(env.get_possible_actions_with_IDs()))
print("Possible object types: " + str(env.get_object_types()))
print("Object IDX to Object Referent LUT: " + str(lut))
print("\n")
print("Possible object referents LUT: " + str(env.getPossibleObjectReferentLUT()))
print("Possible object referents LUT: " + str(env.get_possible_object_referent_LUT()))
print("\n")
print("Valid action-object combinations: " +
str(env.getValidActionObjectCombinations()))
str(env.get_valid_action_object_combinations()))
print("\n")
print("Object_ids to type_ids: " + str(env.getAllObjectTypesLUTJSON()))
print("Object_ids to type_ids: " + str(env.get_all_object_types_LUTJSON()))
print("\n")
print("All objects, their ids, types, and referents: " +
str(env.getAllObjectIdsTypesReferentsLUTJSON()))
str(env.get_all_object_ids_types_referents_LUTJSON()))
print("\n")
print("Valid action-object combinations (with templates): " +
str(env.getValidActionObjectCombinationsWithTemplates()))
str(env.get_valid_action_object_combinations_with_templates()))
print("\n")
print("Object Type LUT: " + str(env.getPossibleObjectReferentTypesLUT()))
print("Variations (train): " + str(env.getVariationsTrain()))
print("Object Type LUT: " + str(env.get_possible_object_referent_types_LUT()))
print("Variations (train): " + str(env.get_variations_train()))

print("")
print("----------------------------------------------------------------------------------")
print("")

print("Gold Path:" + str(env.getGoldActionSequence()))
print("Gold Path:" + str(env.get_gold_action_sequence()))

print("Task Name: " + taskName)
print("Variation: " + str(args['var_num']) + " / " + str(env.getMaxVariations(taskName)))
print("Task Description: " + str(env.getTaskDescription()))
print("Variation: " + str(args['var_num']) + " / " + str(env.get_max_variations(taskName)))
print("Task Description: " + str(env.get_task_description()))

#
# Main user input loop
Expand All @@ -73,20 +73,20 @@ def userConsole(args):
while (userInputStr not in exitCommands):
if (userInputStr == "help"):
print("Possible actions: ")
for actionStr in env.getPossibleActions():
for actionStr in env.get_possible_actions():
print("\t" + str(actionStr))

elif (userInputStr == "objects"):
print("Possible objects (one referent listed per object): ")
for actionStr in env.getPossibleObjects():
for actionStr in env.get_possible_objects():
print("\t" + str(actionStr))

elif (userInputStr == "valid"):
print("Valid action-object combinations:")
print(env.getValidActionObjectCombinationsWithTemplates())
print(env.get_valid_action_object_combinations_with_templates())

elif (userInputStr == 'goals'):
print(env.getGoalProgressStr())
print(env.get_goal_progress())

else:
# Send user input, get response
Expand All @@ -109,15 +109,15 @@ def userConsole(args):
userInputStr = userInputStr.lower().strip()

# Display run history
runHistory = env.getRunHistory()
runHistory = env.get_run_history()
print("Run History:")
print(runHistory)
for item in runHistory:
print(item)
print("")

# Display subgoal progress
print(env.getGoalProgressStr())
print(env.get_goal_progress_str())

print("Completed.")

Expand Down
26 changes: 13 additions & 13 deletions examples/random_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,35 @@ def randomModel(args):
# Initialize environment
env = ScienceWorldEnv("", args['jar_path'], envStepLimit=args['env_step_limit'])

taskNames = env.getTaskNames()
taskNames = env.get_task_names()
print("Task Names: " + str(taskNames))

# Choose task
taskName = taskNames[taskIdx] # Just get first task
# Load the task, we we have access to some extra accessors e.g. get_random_variation_train()
env.load(taskName, 0, "")
maxVariations = env.getMaxVariations(taskName)
maxVariations = env.get_max_variations(taskName)
print("Starting Task " + str(taskIdx) + ": " + taskName)
time.sleep(2)

# Start running episodes
for episodeIdx in range(0, numEpisodes):
# Pick a random task variation
randVariationIdx = env.getRandomVariationTrain()
randVariationIdx = env.get_random_variation_train()
env.load(taskName, randVariationIdx, simplificationStr)

# Reset the environment
initialObs, initialDict = env.reset()

# Example accessors
print("Possible actions: " + str(env.getPossibleActions()))
print("Possible objects: " + str(env.getPossibleObjects()))
templates, lut = env.getPossibleActionObjectCombinations()
print("Possible actions: " + str(env.get_possible_actions()))
print("Possible objects: " + str(env.get_possible_objects()))
templates, lut = env.get_possible_action_object_combinations()
print("Possible action/object combinations: " + str(templates))
print("Object IDX to Object Referent LUT: " + str(lut))
print("Task Name: " + taskName)
print("Task Variation: " + str(randVariationIdx) + " / " + str(maxVariations))
print("Task Description: " + str(env.getTaskDescription()))
print("Task Description: " + str(env.get_task_description()))
print("look: " + str(env.look()))
print("inventory: " + str(env.inventory()))
print("taskdescription: " + str(env.taskdescription()))
Expand Down Expand Up @@ -79,15 +79,15 @@ def randomModel(args):
# Randomly select action

# Any action (valid or not)
# templates, lut = env.getPossibleActionObjectCombinations()
# templates, lut = env.get_possible_action_object_combinations()
# print("Possible action/object combinations: " + str(templates))
# print("Object IDX to Object Referent LUT: " + str(lut))
# randomTemplate = random.choice( templates )
# print("Next random action: " + str(randomTemplate))
# userInputStr = randomTemplate["action"]

# Only valid actions
validActions = env.getValidActionObjectCombinationsWithTemplates()
validActions = env.get_valid_action_object_combinations_with_templates()
randomAction = random.choice(validActions)
print("Next random action: " + str(randomAction))
userInputStr = randomAction["action"]
Expand All @@ -102,7 +102,7 @@ def randomModel(args):
curIter += 1

print("Goal Progress:")
print(env.getGoalProgressStr())
print(env.get_goal_progress_str())
time.sleep(1)

# Episode finished -- Record the final score
Expand All @@ -114,11 +114,11 @@ def randomModel(args):

# Save history -- and when we reach maxPerFile, export them to file
filenameOutPrefix = args['output_path_prefix'] + str(taskIdx)
env.storeRunHistory(episodeIdx, notes={'text': 'my notes here'})
env.saveRunHistoriesBufferIfFull(filenameOutPrefix, maxPerFile=args['max_episode_per_file'])
env.store_run_history(episodeIdx, notes={'text': 'my notes here'})
env.save_run_histories_buffer_if_full(filenameOutPrefix, max_per_file=args['max_episode_per_file'])

# Episodes are finished -- manually save any last histories still in the buffer
env.saveRunHistoriesBufferIfFull(filenameOutPrefix, maxPerFile=args['max_episode_per_file'], forceSave=True)
env.save_run_histories_buffer_if_full(filenameOutPrefix, max_per_file=args['max_episode_per_file'], force_save=True)

# Show final episode scores to user
# Clip negative scores to 0 for average calculation
Expand Down
10 changes: 5 additions & 5 deletions examples/scienceworld-web-server-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def app():
htmlLog.addHeading("Science World (Text Simulation)")
htmlLog.addHorizontalRule()

taskName = pywebio.input.select("Select a task:", env.getTaskNames())
maxVariations = env.getMaxVariations(taskName)
taskName = pywebio.input.select("Select a task:", env.get_task_names())
maxVariations = env.get_max_variations(taskName)

# variationIdx = slider("Task Variation: ", min_value=0, max_value=(maxVariations-1))
variationIdx = pywebio.input.input('Enter the task variation (min = 0, max = ' + str(maxVariations) + "):")
Expand All @@ -118,11 +118,11 @@ def app():
# print("Possible action/object combinations: " + str(env.getPossibleActionObjectCombinations()))

pywebio_out.put_table([
["Task", env.getTaskDescription()],
["Task", env.get_task_description()],
["Variation", str(variationIdx) + " / " + str(maxVariations)]
])

htmlLog.addStr("<b>Task:</b> " + env.getTaskDescription() + "<br>")
htmlLog.addStr("<b>Task:</b> " + env.get_task_description() + "<br>")
htmlLog.addStr("<b>Variation:</b> " + str(variationIdx) + "<br>")
htmlLog.addHorizontalRule()

Expand Down Expand Up @@ -172,7 +172,7 @@ def app():
'isCompleted': isCompleted,
'userInput': userInputStr,
'taskName': taskName,
'taskDescription': env.getTaskDescription(),
'taskDescription': env.get_task_description(),
'look': env.look(),
'inventory': env.inventory(),
'variationIdx': variationIdx,
Expand Down
10 changes: 5 additions & 5 deletions scienceworld/scienceworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, taskName=None, serverPath=None, envStepLimit=100):
self.envStepLimit = envStepLimit

# Clear the run histories
self.clearRunHistories()
self.clear_run_histories()

# By default, set that the gold path was not generated unless the user asked for it
self.goldPathGenerated = False
Expand All @@ -73,15 +73,15 @@ def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath=
# Check loading arguments.
# Validate task name.
taskName = infer_task(taskName)
if taskName not in self.getTaskNames():
if taskName not in self.get_task_names():
msg = "Unknown taskName: '{}'. ".format(taskName)
msg += "Supported tasks are: {}".format(self.getTaskNames())
msg += "Supported tasks are: {}".format(self.get_task_names())
raise ValueError(msg)

self.taskName = taskName

# Validate simplification string.
possible_simplifications = ["easy"] + self.getPossibleSimplifications()
possible_simplifications = ["easy"] + self.get_possible_simplifications()
for simplification in simplificationStr.split(","):
if simplification and simplification not in possible_simplifications:
msg = "Unknown simplification: '{}'. ".format(simplification)
Expand Down Expand Up @@ -349,7 +349,7 @@ def taskdescription(self):
return observation

# Goal progress
def get_goal_progress_str(self):
def get_goal_progress(self):
goalStr = self.server.getGoalProgressStr()
return goalStr

Expand Down
2 changes: 1 addition & 1 deletion scienceworld/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ def infer_task(name_or_id):
def snake_case_deprecation_warning():
message = "You are using the camel case api. This feature is deprecated. Please migrate to the snake_case api."
formatted_message = f"\033[91m {message} \033[00m"
warnings.warn(formatted_message, UserWarning, stacklevel=2)
warnings.warn(formatted_message, UserWarning, stacklevel=3)
10 changes: 5 additions & 5 deletions tests/test_scienceworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def test_multiple_instances():
def test_variation_sets_are_disjoint():
env = ScienceWorldEnv()

for task in env.getTaskNames():
for task in env.get_task_names():
env.load(task)
train = set(env.getVariationsTrain())
dev = set(env.getVariationsDev())
test = set(env.getVariationsTest())
train = set(env.get_variations_train())
dev = set(env.get_variations_dev())
test = set(env.get_variations_test())
assert set.isdisjoint(train, dev)
assert set.isdisjoint(train, test)
assert set.isdisjoint(dev, test)
Expand Down Expand Up @@ -119,4 +119,4 @@ def test_load():
def test_consistent_task_names():
"""Verify that Scala and Python code use the same task names."""
env = ScienceWorldEnv()
assert sorted(env.task_names) == sorted(env.getTaskNames())
assert sorted(env.task_names) == sorted(env.get_task_names())