diff --git a/examples/human.py b/examples/human.py index 067a7ef..e300b89 100644 --- a/examples/human.py +++ b/examples/human.py @@ -13,7 +13,7 @@ def userConsole(args): # Initialize environment env = ScienceWorldEnv("", args['jar_path'], envStepLimit=args['env_step_limit']) - taskNames = env.getTaskNames() + taskNames = env.get_task_names() print("Task Names: " + str(taskNames)) # Choose task @@ -30,41 +30,41 @@ def userConsole(args): # (Many of these are similar to the Jericho API) # print("Task Names: " + str(taskNames)) - print("Possible actions: " + str(env.getPossibleActions())) - print("Possible objects: " + str(env.getPossibleObjects())) - templates, lut = env.getPossibleActionObjectCombinations() + print("Possible actions: " + str(env.get_possible_actions())) + print("Possible objects: " + str(env.get_possible_objects())) + templates, lut = env.get_possible_action_object_combinations() print("Possible action/object combinations: " + str(templates)) # print("Object IDX to Object Referent LUT: " + str(lut)) - print("Vocabulary: " + str(env.getVocabulary())) - print("Possible actions (with IDs): " + str(env.getPossibleActionsWithIDs())) - print("Possible object types: " + str(env.getObjectTypes())) + print("Vocabulary: " + str(env.get_vocabulary())) + print("Possible actions (with IDs): " + str(env.get_possible_actions_with_IDs())) + print("Possible object types: " + str(env.get_object_types())) print("Object IDX to Object Referent LUT: " + str(lut)) print("\n") - print("Possible object referents LUT: " + str(env.getPossibleObjectReferentLUT())) + print("Possible object referents LUT: " + str(env.get_possible_object_referent_LUT())) print("\n") print("Valid action-object combinations: " + - str(env.getValidActionObjectCombinations())) + str(env.get_valid_action_object_combinations())) print("\n") - print("Object_ids to type_ids: " + str(env.getAllObjectTypesLUTJSON())) + print("Object_ids to type_ids: " + str(env.get_all_object_types_LUTJSON())) print("\n") print("All objects, their ids, types, and referents: " + - str(env.getAllObjectIdsTypesReferentsLUTJSON())) + str(env.get_all_object_ids_types_referents_LUTJSON())) print("\n") print("Valid action-object combinations (with templates): " + - str(env.getValidActionObjectCombinationsWithTemplates())) + str(env.get_valid_action_object_combinations_with_templates())) print("\n") - print("Object Type LUT: " + str(env.getPossibleObjectReferentTypesLUT())) - print("Variations (train): " + str(env.getVariationsTrain())) + print("Object Type LUT: " + str(env.get_possible_object_referent_types_LUT())) + print("Variations (train): " + str(env.get_variations_train())) print("") print("----------------------------------------------------------------------------------") print("") - print("Gold Path:" + str(env.getGoldActionSequence())) + print("Gold Path:" + str(env.get_gold_action_sequence())) print("Task Name: " + taskName) - print("Variation: " + str(args['var_num']) + " / " + str(env.getMaxVariations(taskName))) - print("Task Description: " + str(env.getTaskDescription())) + print("Variation: " + str(args['var_num']) + " / " + str(env.get_max_variations(taskName))) + print("Task Description: " + str(env.get_task_description())) # # Main user input loop @@ -73,20 +73,20 @@ def userConsole(args): while (userInputStr not in exitCommands): if (userInputStr == "help"): print("Possible actions: ") - for actionStr in env.getPossibleActions(): + for actionStr in env.get_possible_actions(): print("\t" + str(actionStr)) elif (userInputStr == "objects"): print("Possible objects (one referent listed per object): ") - for actionStr in env.getPossibleObjects(): + for actionStr in env.get_possible_objects(): print("\t" + str(actionStr)) elif (userInputStr == "valid"): print("Valid action-object combinations:") - print(env.getValidActionObjectCombinationsWithTemplates()) + print(env.get_valid_action_object_combinations_with_templates()) elif (userInputStr == 'goals'): - print(env.getGoalProgressStr()) + print(env.get_goal_progress()) else: # Send user input, get response @@ -109,7 +109,7 @@ def userConsole(args): userInputStr = userInputStr.lower().strip() # Display run history - runHistory = env.getRunHistory() + runHistory = env.get_run_history() print("Run History:") print(runHistory) for item in runHistory: @@ -117,7 +117,7 @@ def userConsole(args): print("") # Display subgoal progress - print(env.getGoalProgressStr()) + print(env.get_goal_progress_str()) print("Completed.") diff --git a/examples/random_agent.py b/examples/random_agent.py index 1bbf38f..bf81652 100644 --- a/examples/random_agent.py +++ b/examples/random_agent.py @@ -19,35 +19,35 @@ def randomModel(args): # Initialize environment env = ScienceWorldEnv("", args['jar_path'], envStepLimit=args['env_step_limit']) - taskNames = env.getTaskNames() + taskNames = env.get_task_names() print("Task Names: " + str(taskNames)) # Choose task taskName = taskNames[taskIdx] # Just get first task # Load the task, we we have access to some extra accessors e.g. get_random_variation_train() env.load(taskName, 0, "") - maxVariations = env.getMaxVariations(taskName) + maxVariations = env.get_max_variations(taskName) print("Starting Task " + str(taskIdx) + ": " + taskName) time.sleep(2) # Start running episodes for episodeIdx in range(0, numEpisodes): # Pick a random task variation - randVariationIdx = env.getRandomVariationTrain() + randVariationIdx = env.get_random_variation_train() env.load(taskName, randVariationIdx, simplificationStr) # Reset the environment initialObs, initialDict = env.reset() # Example accessors - print("Possible actions: " + str(env.getPossibleActions())) - print("Possible objects: " + str(env.getPossibleObjects())) - templates, lut = env.getPossibleActionObjectCombinations() + print("Possible actions: " + str(env.get_possible_actions())) + print("Possible objects: " + str(env.get_possible_objects())) + templates, lut = env.get_possible_action_object_combinations() print("Possible action/object combinations: " + str(templates)) print("Object IDX to Object Referent LUT: " + str(lut)) print("Task Name: " + taskName) print("Task Variation: " + str(randVariationIdx) + " / " + str(maxVariations)) - print("Task Description: " + str(env.getTaskDescription())) + print("Task Description: " + str(env.get_task_description())) print("look: " + str(env.look())) print("inventory: " + str(env.inventory())) print("taskdescription: " + str(env.taskdescription())) @@ -79,7 +79,7 @@ def randomModel(args): # Randomly select action # Any action (valid or not) - # templates, lut = env.getPossibleActionObjectCombinations() + # templates, lut = env.get_possible_action_object_combinations() # print("Possible action/object combinations: " + str(templates)) # print("Object IDX to Object Referent LUT: " + str(lut)) # randomTemplate = random.choice( templates ) @@ -87,7 +87,7 @@ def randomModel(args): # userInputStr = randomTemplate["action"] # Only valid actions - validActions = env.getValidActionObjectCombinationsWithTemplates() + validActions = env.get_valid_action_object_combinations_with_templates() randomAction = random.choice(validActions) print("Next random action: " + str(randomAction)) userInputStr = randomAction["action"] @@ -102,7 +102,7 @@ def randomModel(args): curIter += 1 print("Goal Progress:") - print(env.getGoalProgressStr()) + print(env.get_goal_progress_str()) time.sleep(1) # Episode finished -- Record the final score @@ -114,11 +114,11 @@ def randomModel(args): # Save history -- and when we reach maxPerFile, export them to file filenameOutPrefix = args['output_path_prefix'] + str(taskIdx) - env.storeRunHistory(episodeIdx, notes={'text': 'my notes here'}) - env.saveRunHistoriesBufferIfFull(filenameOutPrefix, maxPerFile=args['max_episode_per_file']) + env.store_run_history(episodeIdx, notes={'text': 'my notes here'}) + env.save_run_histories_buffer_if_full(filenameOutPrefix, max_per_file=args['max_episode_per_file']) # Episodes are finished -- manually save any last histories still in the buffer - env.saveRunHistoriesBufferIfFull(filenameOutPrefix, maxPerFile=args['max_episode_per_file'], forceSave=True) + env.save_run_histories_buffer_if_full(filenameOutPrefix, max_per_file=args['max_episode_per_file'], force_save=True) # Show final episode scores to user # Clip negative scores to 0 for average calculation diff --git a/examples/scienceworld-web-server-example.py b/examples/scienceworld-web-server-example.py index b5258af..a6952ae 100644 --- a/examples/scienceworld-web-server-example.py +++ b/examples/scienceworld-web-server-example.py @@ -101,8 +101,8 @@ def app(): htmlLog.addHeading("Science World (Text Simulation)") htmlLog.addHorizontalRule() - taskName = pywebio.input.select("Select a task:", env.getTaskNames()) - maxVariations = env.getMaxVariations(taskName) + taskName = pywebio.input.select("Select a task:", env.get_task_names()) + maxVariations = env.get_max_variations(taskName) # variationIdx = slider("Task Variation: ", min_value=0, max_value=(maxVariations-1)) variationIdx = pywebio.input.input('Enter the task variation (min = 0, max = ' + str(maxVariations) + "):") @@ -118,11 +118,11 @@ def app(): # print("Possible action/object combinations: " + str(env.getPossibleActionObjectCombinations())) pywebio_out.put_table([ - ["Task", env.getTaskDescription()], + ["Task", env.get_task_description()], ["Variation", str(variationIdx) + " / " + str(maxVariations)] ]) - htmlLog.addStr("Task: " + env.getTaskDescription() + "
") + htmlLog.addStr("Task: " + env.get_task_description() + "
") htmlLog.addStr("Variation: " + str(variationIdx) + "
") htmlLog.addHorizontalRule() @@ -172,7 +172,7 @@ def app(): 'isCompleted': isCompleted, 'userInput': userInputStr, 'taskName': taskName, - 'taskDescription': env.getTaskDescription(), + 'taskDescription': env.get_task_description(), 'look': env.look(), 'inventory': env.inventory(), 'variationIdx': variationIdx, diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 2427078..995b4e3 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -61,7 +61,7 @@ def __init__(self, taskName=None, serverPath=None, envStepLimit=100): self.envStepLimit = envStepLimit # Clear the run histories - self.clearRunHistories() + self.clear_run_histories() # By default, set that the gold path was not generated unless the user asked for it self.goldPathGenerated = False @@ -73,15 +73,15 @@ def load(self, taskName, variationIdx=0, simplificationStr="", generateGoldPath= # Check loading arguments. # Validate task name. taskName = infer_task(taskName) - if taskName not in self.getTaskNames(): + if taskName not in self.get_task_names(): msg = "Unknown taskName: '{}'. ".format(taskName) - msg += "Supported tasks are: {}".format(self.getTaskNames()) + msg += "Supported tasks are: {}".format(self.get_task_names()) raise ValueError(msg) self.taskName = taskName # Validate simplification string. - possible_simplifications = ["easy"] + self.getPossibleSimplifications() + possible_simplifications = ["easy"] + self.get_possible_simplifications() for simplification in simplificationStr.split(","): if simplification and simplification not in possible_simplifications: msg = "Unknown simplification: '{}'. ".format(simplification) @@ -349,7 +349,7 @@ def taskdescription(self): return observation # Goal progress - def get_goal_progress_str(self): + def get_goal_progress(self): goalStr = self.server.getGoalProgressStr() return goalStr diff --git a/scienceworld/utils.py b/scienceworld/utils.py index 2b5f427..e917c4f 100644 --- a/scienceworld/utils.py +++ b/scienceworld/utils.py @@ -23,4 +23,4 @@ def infer_task(name_or_id): def snake_case_deprecation_warning(): message = "You are using the camel case api. This feature is deprecated. Please migrate to the snake_case api." formatted_message = f"\033[91m {message} \033[00m" - warnings.warn(formatted_message, UserWarning, stacklevel=2) + warnings.warn(formatted_message, UserWarning, stacklevel=3) diff --git a/tests/test_scienceworld.py b/tests/test_scienceworld.py index 193087d..839ca34 100644 --- a/tests/test_scienceworld.py +++ b/tests/test_scienceworld.py @@ -47,11 +47,11 @@ def test_multiple_instances(): def test_variation_sets_are_disjoint(): env = ScienceWorldEnv() - for task in env.getTaskNames(): + for task in env.get_task_names(): env.load(task) - train = set(env.getVariationsTrain()) - dev = set(env.getVariationsDev()) - test = set(env.getVariationsTest()) + train = set(env.get_variations_train()) + dev = set(env.get_variations_dev()) + test = set(env.get_variations_test()) assert set.isdisjoint(train, dev) assert set.isdisjoint(train, test) assert set.isdisjoint(dev, test) @@ -119,4 +119,4 @@ def test_load(): def test_consistent_task_names(): """Verify that Scala and Python code use the same task names.""" env = ScienceWorldEnv() - assert sorted(env.task_names) == sorted(env.getTaskNames()) + assert sorted(env.task_names) == sorted(env.get_task_names())