diff --git a/crowdsourcing/quests/run_task.py b/crowdsourcing/quests/run_task.py index e186a0519..c07e9341c 100644 --- a/crowdsourcing/quests/run_task.py +++ b/crowdsourcing/quests/run_task.py @@ -7,6 +7,7 @@ import os import time import shlex +import asyncio from mephisto.abstractions.databases.local_database import LocalMephistoDB from mephisto.operations.operator import Operator from mephisto.operations.utils import get_root_dir @@ -167,10 +168,10 @@ def construct_tasks(num_tasks): builder = StarspaceBuilder(ldb, opt=opt) random.seed(88) while len(tasks) < num_tasks: - g, world = builder.get_graph() + g, world = asyncio.run(builder.get_graph()) while len(world.oo_graph.agents) == 0: print("no agents in room") - g, world = builder.get_graph() + g, world = asyncio.run(builder.get_graph()) possible_agents = list(world.oo_graph.agents.values()) random.shuffle(possible_agents) for character in possible_agents: diff --git a/deploy/web/configs/dev/config b/deploy/web/configs/dev/config index d8e87beca..9b4f46be1 100644 --- a/deploy/web/configs/dev/config +++ b/deploy/web/configs/dev/config @@ -1,16 +1,16 @@ --data-model-db /home/ubuntu/data/database.db --hostname -www.light-rpg.ai +dev.light-rpg.ai --light-model-root /home/ubuntu/data/models/ --password LetsPlay --port 8088 ---builder-model -starspace/angela_starspace/model4 ---dialog-model -dialog_gen/model ---acting-model -main_act/model +--db-backend +aws-postgres +--is-logging +True +--safety-list +'' diff --git a/deploy/web/configs/dev/config.js b/deploy/web/configs/dev/config.js index 22cd84307..abc50b79e 100644 --- a/deploy/web/configs/dev/config.js +++ b/deploy/web/configs/dev/config.js @@ -5,9 +5,9 @@ */ const DEV = { - host: "https://www.light-rpg.ai", - hostname: "www.light-rpg.ai", - port: "8088", + host: "https://dev.light-rpg.ai", + hostname: "dev.light-rpg.ai", + port: "80", }; export default DEV; diff --git a/deploy/web/configs/devfair-no-models/config b/deploy/web/configs/devfair-no-models/config index 779bb4406..8760f32b7 100644 --- a/deploy/web/configs/devfair-no-models/config +++ b/deploy/web/configs/devfair-no-models/config @@ -10,5 +10,19 @@ LetsPlay 35496 --safety-list '' +--safety-model-opt-file +'' +--dialog-model-opt-file +'' +--action-model-opt-file +'' +--roleplaying-score-opt-file +'' +--generic-act-opt-file +'' +--parser-opt-file +'' --disable-builder True +--db-backend +local diff --git a/deploy/web/configs/devfair/config b/deploy/web/configs/devfair/config index 02b173449..2fb8700b6 100644 --- a/deploy/web/configs/devfair/config +++ b/deploy/web/configs/devfair/config @@ -10,13 +10,5 @@ LetsPlay 35496 --safety-list /checkpoint/light/data/safety/reddit_and_beathehobbot_lists/OffensiveLanguage.txt ---dialog-model -game2021/gen_dialog_model/model.checkpoint ---acting-model -main_act/model ---parser-model-file -/checkpoint/jase/projects/light/parser/parser3/34c_jobid=1/model ---roleplaying-score-model-file -/checkpoint/light/models/game2020/roleplay_scorer/model ---generic-act-model-file -/checkpoint/light/models/game2021/act_model/model +--db-backend +local diff --git a/deploy/web/configs/local-no-models/config b/deploy/web/configs/local-no-models/config index f1585ece2..449fb8c3f 100644 --- a/deploy/web/configs/local-no-models/config +++ b/deploy/web/configs/local-no-models/config @@ -8,3 +8,19 @@ localhost LetsPlay --port 35494 +--safety-model-opt-file +'' +--dialog-model-opt-file +'' +--action-model-opt-file +'' +--roleplaying-score-opt-file +'' +--generic-act-opt-file +'' +--parser-opt-file +'' +--db-backend +local +--is-logging +True diff --git a/deploy/web/configs/local/config b/deploy/web/configs/local/config index e9e78b8dd..f0099ee54 100644 --- a/deploy/web/configs/local/config +++ b/deploy/web/configs/local/config @@ -8,9 +8,5 @@ localhost LetsPlay --port 35494 ---builder-model -starspace/angela_starspace/model4 ---dialog-model -dialog_gen/model ---acting-model -main_act/model +--db-backend +local diff --git a/deploy/web/configs/prod/config b/deploy/web/configs/prod/config index 82224858b..f02d3cc8d 100644 --- a/deploy/web/configs/prod/config +++ b/deploy/web/configs/prod/config @@ -8,19 +8,9 @@ www.light-rpg.ai LetsPlay --port 8080 ---builder-model -starspace/angela_starspace/model4 ---dialog-model -dialog/model.checkpoint ---acting-model -main_act/model ---parser-model-file -/home/ubuntu/data/models/parser/model ---roleplaying-score-model-file -/home/ubuntu/data/models/scoring/model ---generic-act-model-file -/home/ubuntu/data/models/acting/model --disable-builder True --is-logging True +--safety-list +'' diff --git a/deploy/web/deploy.sh b/deploy/web/deploy.sh index b9f74deb8..3882facf0 100755 --- a/deploy/web/deploy.sh +++ b/deploy/web/deploy.sh @@ -26,19 +26,4 @@ fi CONF_FN=$WEBDIR"/configs/"$1"/config" -python $SERVER_FILE @$CONF_FN - - .ipynb_checkpoints/ - Env Database Merge Workbook.ipynb - Orig Episode Database Merge Workbook.ipynb - Quest Database Merge Notebook.ipynb - Wild Episode Database Merge.ipynb - crowdsourcing/environment/world_builder/ - crowdsourcing/filtering/is_safe_is_light/data/ - deploy/MODEL_SERVER_SETUP.sh - deploy/WORLD_SERVER_SETUP.sh - hydra_configs/ - json-builder-respawns - models/ - scripts/examples/complex_world_scrubbed.json - test_db/ +cat $CONF_FN | python $SERVER_FILE `xargs -0` diff --git a/deploy/web/gameapp/src/WebSockets/useWSDataSource.js b/deploy/web/gameapp/src/WebSockets/useWSDataSource.js index 4031afdcc..7a03075ff 100644 --- a/deploy/web/gameapp/src/WebSockets/useWSDataSource.js +++ b/deploy/web/gameapp/src/WebSockets/useWSDataSource.js @@ -24,7 +24,7 @@ function uuidv4() { // MESSAGE REDUCER const reducer = (state, msg) => { - window.top.postMessage(JSON.stringify(msg), "*"); + window.parent.postMessage(JSON.stringify(msg), "*"); if ( msg.text && msg.text.indexOf("You mumble something incomprehensible") >= 0 @@ -133,6 +133,7 @@ export function useWSDataSource(url) { const [persona, setPersona] = useState(null); const [location, setLocation] = useState(null); const [agents, setAgents] = useState({}); + const [aliveInterval, setAliveInterval] = useState(null); /*---------------REFS----------------*/ const websocket = useRef(); const agentList = useRef(agents); @@ -225,6 +226,11 @@ export function useWSDataSource(url) { websocket.current.onopen = () => { setConnected(true); + const hb = JSON.stringify({ command: "hb", data: {} }); + var interval = window.setInterval(() => { + websocket.current.send(hb); + }, 10000); + setAliveInterval(interval); }; websocket.current.onerror = websocket.current.onclose = (e) => { @@ -232,6 +238,7 @@ export function useWSDataSource(url) { setConnected(false); setErrored(true); websocket.current = null; + window.clearInterval(aliveInterval); }; } const disconnectFromSession = () => { diff --git a/deploy/web/gameapp/src/features/api/Messages.ts b/deploy/web/gameapp/src/features/api/Messages.ts index 1757768ef..a26264f20 100644 --- a/deploy/web/gameapp/src/features/api/Messages.ts +++ b/deploy/web/gameapp/src/features/api/Messages.ts @@ -57,4 +57,5 @@ export const api = createApi({ }), }); -export const { useGetMessagesQuery } = api; +// TODO @Justin this is unused, what are we doing with it? +// export const { useGetMessagesQuery } = api; diff --git a/deploy/web/landingapp/src/pages/AboutPage/index.js b/deploy/web/landingapp/src/pages/AboutPage/index.js index a6405ebc0..dac644b10 100644 --- a/deploy/web/landingapp/src/pages/AboutPage/index.js +++ b/deploy/web/landingapp/src/pages/AboutPage/index.js @@ -52,7 +52,7 @@ const AboutPage = (props) => { collected from within LIGHT, with the goal of enabling other researchers to extend upon our work, and this will be available for download from the project page. The complete source code for the - project is available on our github. + project will be made available on our github.

diff --git a/deploy/web/server/README.md b/deploy/web/server/README.md new file mode 100644 index 000000000..0b3e7482f --- /dev/null +++ b/deploy/web/server/README.md @@ -0,0 +1 @@ +# LIGHT Server Architecture doc diff --git a/deploy/web/server/builder_server.py b/deploy/web/server/builder_server.py index c375686e9..fc6e5d48b 100644 --- a/deploy/web/server/builder_server.py +++ b/deploy/web/server/builder_server.py @@ -10,6 +10,7 @@ import inspect import time import tornado.web +import asyncio from tornado.ioloop import IOLoop from tornado import locks from tornado import gen @@ -625,7 +626,7 @@ def initialize(self, database): self.builder = get_builder(database) @tornado.web.authenticated - def get(self, type, source): + async def get(self, type, source): if type not in ["room", "object", "character"]: raise AppException(reason="Type is not valid. ", status_code=400) with self.db as ldb: @@ -635,11 +636,13 @@ def get(self, type, source): return source_obj = source_objs[0] if type == "room": - items = builder.get_neighbor_rooms(source_obj["id"]) + items = await builder.get_neighbor_rooms(source_obj["id"]) elif type == "object": - items = builder.get_contained_items(source_obj["id"], source_obj["type"]) + items = await builder.get_contained_items( + source_obj["id"], source_obj["type"] + ) elif type == "character": - items = builder.get_contained_characters(source_obj["id"]) + items = await builder.get_contained_characters(source_obj["id"]) with self.db as ldb: result_items = [dict(ldb.get_id(id=x.db_id, expand=True)[0]) for x in items] self.write(json.dumps(result_items)) diff --git a/deploy/web/server/game_instance.py b/deploy/web/server/game_instance.py index afec15d3b..de66fe178 100644 --- a/deploy/web/server/game_instance.py +++ b/deploy/web/server/game_instance.py @@ -1,8 +1,9 @@ +#!/usr/bin/env python3 + # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - from light.graph.builders.starspace_all import StarspaceBuilder from light.graph.builders.map_json_builder import MapJsonBuilder from light.graph.builders.tutorial_builder import TutorialWorldBuilder @@ -16,6 +17,13 @@ import os.path import time +import asyncio + +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from light.data_model.db.environment import EpisodeDB + from light.world.world import WorldConfig # TODO specify the models to be using USE_MODELS = True @@ -100,48 +108,63 @@ class that extends the Player class, which itself extends Agent. Players def __init__( self, game_id, - ldb, + ldb, # TODO remove this DB g=None, opt=None, + world_config: Optional["WorldConfig"] = None, # TODO make this required ): - if g is None: - if opt["builder_model"] is not None: - _, world = StarspaceBuilder( - ldb, - debug=False, - opt=opt, - ).get_graph() # TODO: what are the args that are needed - self.world = world - else: - opt["load_map"] = os.path.expanduser( - "~/LIGHT/scripts/examples/complex_world.json" - ) - world_builder = MapJsonBuilder("", debug=False, opt=opt) - _, self.world = world_builder.get_graph() - else: + self.world = None + if g is not None: self.world = g + self.world_config = world_config + self.opt = opt self.db = ldb self.game_id = game_id self.players = [] self.providers = [] self.last_connection = time.time() + @classmethod + async def get( + cls, + game_id, + ldb, # TODO remove this DB + g=None, + opt=None, + world_config: Optional["WorldConfig"] = None, # TODO make this required + ) -> "GameInstance": + instance = cls(game_id, ldb, g=g, opt=opt, world_config=world_config) + await instance._init_world() + return instance + + async def _init_world(self): + if self.opt["builder_model"] is not None: + _, self.world = await StarspaceBuilder( + self.ldb, + debug=False, + opt=self.world_config.opt, + ).get_graph() # TODO: what are the args that are needed + else: + self.world_config.opt["load_map"] = os.path.expanduser( + "~/LIGHT/scripts/examples/complex_world.json" + ) + world_builder = MapJsonBuilder( + episode_db=self.world_config.episode_db, opt=self.world_config.opt + ) + _, self.world = await world_builder.get_graph( + world_config=self.world_config + ) + def fill_souls(self, FLAGS, model_resources): purgatory = self.world.purgatory - if FLAGS.dialog_model is None: + if len(FLAGS.dialog_model_opt_file) <= 3: purgatory.register_filler_soul_provider("repeat", RepeatSoul, lambda: []) else: purgatory.register_filler_soul_provider( "model", GenerativeHeuristicModelSoul, - lambda: [model_resources["shared_model_content"]], - ) - if model_resources.get("rpg_model") is not None: - purgatory.register_shared_args("rpg_model", model_resources["rpg_model"]) - if model_resources.get("shared_action_model") is not None: - purgatory.register_shared_args( - "generic_act_model", model_resources["generic_act_model"] + lambda: [], ) for empty_agent in self.world.oo_graph.agents.values(): purgatory.fill_soul(empty_agent) @@ -149,7 +172,7 @@ def fill_souls(self, FLAGS, model_resources): def register_provider(self, provider): self.providers.append(provider) - def run_graph_step(self): + async def run_graph_step(self): world = self.world # Clear disconnected players @@ -157,13 +180,13 @@ def run_graph_step(self): for player in left_players: if player.player_soul is not None: node_to_clean = player.player_soul.target_node - self.world.purgatory.clear_soul(node_to_clean) + await self.world.purgatory.clear_soul(node_to_clean) self.world.purgatory.fill_soul(node_to_clean) self.players.remove(player) self.last_connection = time.time() # clear corpses and respawn - ags = self.world.clean_corpses_and_respawn() + ags = await self.world.clean_corpses_and_respawn() for ag in ags: self.world.purgatory.fill_soul(ag) @@ -173,31 +196,44 @@ class TutorialInstance(GameInstance): Version of the game meant to run tutorials, not for general play """ - def __init__(self, game_id, ldb, opt=None): - _, tutorial_world = TutorialWorldBuilder(ldb, opt).get_graph() + def __init__( + self, + game_id, + ldb, + g=None, + opt=None, + world_config: Optional["WorldConfig"] = None, + ): self.db = ldb self._created_time = time.time() + super().__init__(game_id, ldb, opt=opt, world_config=world_config) + self._should_shutdown = False + self._did_complete = True + + async def _init_world(self): + _, tutorial_world = await TutorialWorldBuilder( + self.db, + opt=self.world_config.opt, + ).get_graph(world_config=self.world_config) + self.world = tutorial_world self._player_node = tutorial_world.oo_graph.find_nodes_by_name("You")[0] self._target_destination = tutorial_world.oo_graph.find_nodes_by_name( "Ethereal Mist" )[0] - super().__init__(game_id, ldb, g=tutorial_world, opt=opt) - self._should_shutdown = False - self._did_complete = True def fill_souls(self, _FLAGS, model_resources): """Tutorials directly register the tutorial to the DM""" self.world.purgatory.register_filler_soul_provider( "tutorial", TutorialModelSoul, - lambda: [model_resources["shared_model_content"]], + lambda: [], ) dm_agent = list(self.world.oo_graph.agents.values())[1] assert dm_agent.name == "Dungeon Master", "Did not find DM!" self.world.purgatory.fill_soul(dm_agent, "tutorial") - def run_graph_step(self): - super().run_graph_step() + async def run_graph_step(self): + await super().run_graph_step() self._did_complete = self._player_node.get_room() == self._target_destination self._should_shutdown = ( len(self.players) == 0 and time.time() - self._created_time > 60 diff --git a/deploy/web/server/model_server.py b/deploy/web/server/model_server.py index 7b5742921..24c13bce6 100644 --- a/deploy/web/server/model_server.py +++ b/deploy/web/server/model_server.py @@ -1,130 +1,205 @@ +#!/usr/bin/env python3 + # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import argparse -import socket - -DEFAULT_HOSTNAME = "localhost" -DEFAULT_PORT = 35497 +"""Application specifically for hosting a model for remote access""" -def send_to_connection(c, txt): - txt = txt.rstrip("\n").lstrip(" ").lstrip("\n") - if len(txt) > 0: - txt += "\n" - c[0].send(str.encode(txt)) +import argparse +import json +import logging +import os +import traceback +import asyncio +from typing import TYPE_CHECKING, Any, Dict, Optional + +import tornado.auth +import tornado.escape +import tornado.ioloop +import tornado.web +import tornado.websocket + +from light import LIGHT_DIR +from light.registry.model_pool import ALL_LOADERS, ModelPool, ModelTypeName +from light.registry.models.acting_score_model import ( + ParlAIPolyencoderActingScoreModelConfig, +) + +# Temporary imports pre Hydra +from light.registry.parlai_model import ParlAIModelConfig + + +if TYPE_CHECKING: + from parlai.core.agents import Agent + +tornado_settings = { + "autoescape": None, + "compiled_template_cache": False, +} +DEFAULT_HOSTNAME = "localhost" +DEFAULT_PORT = 40000 -class TelnetClient: - def __init__(self, model, client_id, connection_details): +class ModelServer(tornado.web.Application): + def __init__(self, model: "Agent", given_tornado_settings=None): self.model = model - self.c = connection_details - self.text = "" - self.alive = True - self.client_id = client_id - - def act(self): - """ - Pull an action stored from the last alive check - """ - if self.text != "": - agent_id = str(self.client_id) - print(agent_id + ":" + str(self.text)) - # self.model.parse_exec(agent_id, self.text) - self.observe() - self.text = "" - - def observe(self): - """ - Send any observed content to the client. - This method should query the graph for what it needs, and should - clear the graph content when this happens. - """ - agent_id = self.client_id - txt = "blah!" - send_to_connection(self.c, txt) - - def is_alive(self): - """ - As alive checks are called every tick, we both check liveliness and - store the last action if one existed - """ - # import pdb; pdb.set_trace() - try: - data = self.c[0].recv(1024) - if data != b"": - try: - self.text = data.decode() - print(self.text) - except UnicodeDecodeError: - self.text = "" - else: - # dead connection, unspawn the client - self.alive = False - print("[" + str(self.client_id) + " has disconnected]") - except BlockingIOError: - pass - - return self.alive - - -class TelnetClientProvider: - def __init__(self, model, ip="127.0.0.1", port=35496): - self.ip = ip - self.port = port - self._setup_socket() - self._cnt = 0 + + super(ModelServer, self).__init__(self.get_handlers(), **given_tornado_settings) + + def get_handlers(self): + return [ + (r"/model_request", ResponseHandler, {"model": self.model}), + (r"/is_alive", AliveHandler, {}), + ] + + +class BaseHandler(tornado.web.RequestHandler): + def __init__(self, *request, **kwargs): + self.include_host = False + super(BaseHandler, self).__init__(*request, **kwargs) + + def set_default_headers(self): + self.set_header("Access-Control-Allow-Origin", "*") + self.set_header("Access-Control-Allow-Headers", "*") + + def write_error(self, status_code, **kwargs): + logging.error("ERROR: %s: %s" % (status_code, kwargs)) + if "exc_info" in kwargs: + logging.info( + "Traceback: {}".format(traceback.format_exception(*kwargs["exc_info"])) + ) + exc_info = kwargs["exc_info"] + try: + params = { + "error": str(exc_info[1]), + "trace_info": traceback.format_exception(*exc_info), + "request": str(self.request.__dict__), + } + self.write(json.dumps(params)) + except Exception as e: + logging.error(e) + + +class ResponseHandler(BaseHandler): + """ + Handler to pass a post response along to the model, then + return a result + """ + + def initialize(self, model): self.model = model - def _setup_socket(self): - server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind((self.ip, self.port)) - print("Server socket bound with with ip {} port {}".format(self.ip, self.port)) - server_socket.listen() - server_socket.settimeout(0.0) - self.server_socket = server_socket - - def get_new_clients(self): - """ - Should check the potential source of clients for new clients. If - a client exists, this should instantiate a relevant Client object - for each potential new client and return them. - - This particular implementation only checks for one client at a time - """ + async def post(self): + # Process the data to extract the act + data = tornado.escape.json_decode(self.request.body) + message = data["observation"] + # Pass the act to the model + self.model.observe(message) + # return the response + response = await self.model.act() + if "metrics" in response: + del response["metrics"] + if "sorted_scores" in response and not isinstance( + response["sorted_scored"], list + ): + response["sorted_scores"].force_set(response["sorted_scores"].tolist()) try: - (clientConnection, clientAddress) = self.server_socket.accept() - if clientConnection: - self._cnt += 1 - client_id = self._cnt - c = (clientConnection, clientAddress, client_id) - print("added a connection to model server!: " + str(c)) - c[0].settimeout(0.0) - if client_id == -1: - send_to_connection(c, "Sorry the model server is full!") - return [] - new_client = TelnetClient(self.model, client_id, c) - return [new_client] - except BlockingIOError: - pass - return [] + self.write(json.dumps({"act": response})) + except TypeError: + print("JSON encoding failed:") + print(response.keys()) + print(response) + raise + + +class AliveHandler(BaseHandler): + """ + Handler to pass a post response along to the model, then + return a result + """ + + def initialize(self): + pass + + def post(self): + # Process the data to extract the act + self.write(json.dumps({"alive": True})) + + +def _run_server( + given_tornado_settings: Dict[str, Any], hostname: str, port: int, model: "Agent" +): + """ + Run the model server with the given setup configuration + """ + my_loop = tornado.ioloop.IOLoop() + + app = ModelServer( + model=model, + given_tornado_settings=given_tornado_settings, + ) + app.listen(port, max_buffer_size=1024 ** 3) + print("Model Server Started") + + try: + my_loop.start() + except KeyboardInterrupt: + my_loop.stop() + print("Exiting server") + + +def _init_model(model_opt_file: str, model_loader: str) -> "Agent": + """Initialize a model for serving""" + + pool = ModelPool() + # Temporary mapping that allows us to get things running before Hydra + cfg = None + if model_loader == "ParlAI": + cfg = ParlAIModelConfig(opt_file=model_opt_file) + elif model_loader == "ParlAIActingScore": + cfg = ParlAIPolyencoderActingScoreModelConfig(opt_file=model_opt_file) + else: + raise NotImplementedError(f"Unsupported model loader {model_loader}") + + pool.register_model(cfg, [ModelTypeName.SERVED]) + model = pool.get_model(ModelTypeName.SERVED) + # Try to clear up some memory + del pool._model_loaders[ModelTypeName.SERVED] + import gc + + gc.collect() + return model def main(): import random import numpy - parser = argparse.ArgumentParser(description="Start the telnet server.") + parser = argparse.ArgumentParser(description="Start the model server.") parser.add_argument( "--light-model-root", type=str, - default="/Users/jju/Desktop/LIGHT/", - help="models path. For local setup, use: /checkpoint/jase/projects/light/dialog/", + default=os.path.join(LIGHT_DIR, "models/"), + help="Path to the models", + ) + parser.add_argument( + "--model-opt-file", + type=str, + default=os.path.join( + LIGHT_DIR, "light/registry/models/config/baseline_generative.opt" + ), + help="Opt file to load a model from", + ) + parser.add_argument( + "--model-loader", + type=str, + default="ParlAI", + help="ModelConfig to load alongside the given opt file", ) parser.add_argument( - "-port", + "--port", metavar="port", type=int, default=DEFAULT_PORT, @@ -139,25 +214,9 @@ def main(): ) FLAGS = parser.parse_args() - random.seed(6) - numpy.random.seed(6) - model = [] - - provider = TelnetClientProvider(model, FLAGS.hostname, FLAGS.port) - clients = [] - while True: - # try to get new clients - clients += provider.get_new_clients() - - # Clear disconnected clients - left_clients = [p for p in clients if not p.is_alive()] - for client in left_clients: - clients.remove(client) - - # Check existing clients - for client in clients: - # import pdb; pdb.set_trace() - act = client.act() + os.environ["LIGHT_MODEL_ROOT"] = FLAGS.light_model_root + model = _init_model(FLAGS.model_opt_file, FLAGS.model_loader) + _run_server(tornado_settings, FLAGS.hostname, FLAGS.port, model) if __name__ == "__main__": diff --git a/deploy/web/server/registry.py b/deploy/web/server/registry.py index 5033405b6..9a5ddbf05 100644 --- a/deploy/web/server/registry.py +++ b/deploy/web/server/registry.py @@ -1,11 +1,13 @@ +#!/usr/bin/env python3 + # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - import json import time import uuid +import asyncio import tornado.web from tornado.routing import ( PathMatches, @@ -15,6 +17,14 @@ from deploy.web.server.game_instance import GameInstance, TutorialInstance from deploy.web.server.tornado_server import TornadoPlayerFactory from light.graph.builders.user_world_builder import UserWorldBuilder +from light.data_model.db.users import PlayerStatus +from light.world.world import WorldConfig + +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from light.data_model.db.episodes import EpisodeDB + from light.data_model.db.users import UserDB def get_rand_id(): @@ -28,24 +38,35 @@ class RegistryApplication(tornado.web.Application): - Assign to a random (or default) game based on some load balancing """ - def __init__(self, FLAGS, ldb, model_resources, tornado_settings): + def __init__( + self, + FLAGS, + ldb, # TODO remove! + model_pool, + tornado_settings, + episode_db: Optional["EpisodeDB"] = None, + user_db: Optional["UserDB"] = None, + ): self.game_instances = {} self.step_callbacks = {} self.tutorial_map = {} # Player ID to game ID - self.model_resources = model_resources + self.model_pool = model_pool self.FLAGS = FLAGS self.ldb = ldb + self.episode_db = episode_db + self.user_db = user_db super(RegistryApplication, self).__init__( - self.get_handlers(FLAGS, ldb, tornado_settings), **tornado_settings + self.get_handlers(FLAGS, user_db, tornado_settings), **tornado_settings ) + self.opt = vars(self.FLAGS) - def get_handlers(self, FLAGS, ldb, tornado_settings): + def get_handlers(self, FLAGS, user_db, tornado_settings): self.tornado_provider = TornadoPlayerFactory( self, FLAGS.hostname, FLAGS.port, given_tornado_settings=tornado_settings, - db=ldb, + user_db=user_db, ) self.router = RuleRouter( [ @@ -83,15 +104,22 @@ def cleanup_games(self): del self.step_callbacks[game_id] del self.game_instances[game_id] - def run_new_game(self, game_id, ldb, player_id=None, world_id=None): + async def run_new_game(self, game_id, ldb, player_id=None, world_id=None): if world_id is not None and player_id is not None: builder = UserWorldBuilder(ldb, player_id=player_id, world_id=world_id) - _, world = builder.get_graph() - game = GameInstance(game_id, ldb, g=world, opt=vars(self.FLAGS)) + _, world = await builder.get_graph() + game = await GameInstance.get(game_id, ldb, g=world, opt=self.opt) else: - game = GameInstance(game_id, ldb, opt=vars(self.FLAGS)) + world_config = WorldConfig( + episode_db=self.episode_db, + model_pool=self.model_pool, + opt=self.opt, + ) + game = await GameInstance.get( + game_id, ldb, opt=self.opt, world_config=world_config + ) world = game.world - game.fill_souls(self.FLAGS, self.model_resources) + game.fill_souls(self.FLAGS, []) self.game_instances[game_id] = game game.register_provider(self.tornado_provider) @@ -101,21 +129,27 @@ def run_new_game(self, game_id, ldb, player_id=None, world_id=None): self.step_callbacks[game_id].start() return game - def run_tutorial(self, user_id, on_complete): + async def run_tutorial(self, user_id, on_complete): game_id = get_rand_id() - game = TutorialInstance(game_id, self.ldb, opt=vars(self.FLAGS)) - game.fill_souls(self.FLAGS, self.model_resources) + world_config = WorldConfig( + episode_db=self.episode_db, + model_pool=self.model_pool, + opt=self.opt, + ) + game = await TutorialInstance.get( + game_id, self.ldb, opt=self.opt, world_config=world_config + ) + game.fill_souls(self.FLAGS, []) world = game.world - def run_or_cleanup_world(): - game.run_graph_step() + async def run_or_cleanup_world(): + await game.run_graph_step() if game._should_shutdown or game._did_complete: - if game._did_complete: - with self.ldb as ldb: - flags = ldb.get_user_flags(user_id) - flags.completed_onboarding = True - ldb.set_user_flags(user_id, flags) + if ( + game._did_complete and self.user_db is not None + ): # TODO should always be set + self.user_db.update_player_status(user_id, PlayerStatus.STANDARD) on_complete() self.step_callbacks[game_id].stop() del self.step_callbacks[game_id] @@ -164,7 +198,7 @@ def initialize(self, app): self.game_instances = app.game_instances @tornado.web.authenticated - def post(self, game_id): + async def post(self, game_id): """ Registers a new TornadoProvider at the game_id endpoint """ @@ -175,14 +209,15 @@ def post(self, game_id): world_id = self.get_argument("world_id", None, True) if world_id is not None: username = tornado.escape.xhtml_escape(self.current_user) - with self.app.ldb as db: - player = db.get_user_id(username) - if not db.is_world_owned_by(world_id, player): - self.set_status(403) - return - game = self.app.run_new_game(game_id, self.app.ldb, player, world_id) + with self.app.user_db as user_db: + player = user_db.get_user_id(username) + # TODO update with the env DB + # if not user_db.is_world_owned_by(world_id, player): + # self.set_status(403) + # return + game = await self.app.run_new_game(game_id, self.app.ldb, player, world_id) else: - game = self.app.run_new_game(game_id, self.app.ldb) + game = await self.app.run_new_game(game_id, self.app.ldb) # Create game_provider here print("Registering: ", game_id) diff --git a/deploy/web/server/run_server.py b/deploy/web/server/run_server.py index 4c455560c..7de2346a2 100644 --- a/deploy/web/server/run_server.py +++ b/deploy/web/server/run_server.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. @@ -18,13 +20,41 @@ from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop import inspect -import os.path +import os +import asyncio from light.data_model.light_database import LIGHTDatabase -from light.graph.events.graph_events import init_safety_classifier -from light.world.souls.models.generative_heuristic_model_soul import ( - GenerativeHeuristicModelSoul, +from light.data_model.db.base import LightDBConfig, LightAWSDBConfig +from light.data_model.db.episodes import EpisodeDB +from light.data_model.db.users import UserDB +from light.world.world import WorldConfig +from light.registry.model_pool import ModelPool, ModelTypeName +from light.registry.parlai_model import ParlAIModelConfig +from light.registry.parlai_remote_model import ParlAIRemoteModelConfig +from light.registry.models.acting_score_model import ( + ParlAIPolyencoderActingScoreModelConfig, +) +from light.data_model.db.base import LightDBConfig +from light.data_model.db.episodes import EpisodeDB +from light.data_model.db.users import UserDB +from light.world.world import WorldConfig +from light.registry.model_pool import ModelPool, ModelTypeName +from light.registry.parlai_model import ParlAIModelConfig +from light.registry.models.acting_score_model import ( + ParlAIPolyencoderActingScoreModelConfig, ) +from light import LIGHT_DIR + +CONFIG_DIR = os.path.join(LIGHT_DIR, "light/registry/models/config") + +from light import LIGHT_DIR + +CONFIG_DIR = os.path.join(LIGHT_DIR, "light/registry/models/config") + +from light import LIGHT_DIR + +CONFIG_DIR = os.path.join(LIGHT_DIR, "light/registry/models/config") + here = os.path.abspath(os.path.dirname(__file__)) @@ -72,12 +102,25 @@ def read_secrets(): tornado_settings["facebook_secret"] = SECRETS["facebook_secret"] -def make_app(FLAGS, ldb, model_resources): +def make_app(FLAGS, ldb, model_pool: ModelPool): worldBuilderApp = BuildApplication(get_handlers(ldb), tornado_settings) + db_config = LightDBConfig(backend=FLAGS.db_backend, file_root=FLAGS.db_root) + episode_db = EpisodeDB(db_config) + user_db = UserDB(db_config) landingApp = LandingApplication( - ldb, FLAGS.hostname, FLAGS.password, tornado_settings + user_db=user_db, + hostname=FLAGS.hostname, + password=FLAGS.password, + given_tornado_settings=tornado_settings, + ) + registryApp = RegistryApplication( + FLAGS, + ldb, + model_pool, + tornado_settings, + episode_db=episode_db, + user_db=user_db, ) - registryApp = RegistryApplication(FLAGS, ldb, model_resources, tornado_settings) rules = [] if FLAGS.disable_builder is None: rules.append(Rule(PathMatches("/builder.*"), worldBuilderApp)) @@ -92,14 +135,9 @@ def make_app(FLAGS, ldb, model_resources): return registryApp -def start_default_game(ldb, registryApp): - _ = registryApp.run_new_game("", ldb) - - -def _run_server(FLAGS, ldb, model_resources): - my_loop = IOLoop.current() +async def _run_server(FLAGS, ldb, model_resources): registry_app = make_app(FLAGS, ldb, model_resources) - my_loop.call_later(1, start_default_game, ldb, registry_app) + _ = await registry_app.run_new_game("", ldb) print( "\nYou can connect to the game at http://%s:%s/" % (FLAGS.hostname, FLAGS.port) @@ -108,41 +146,74 @@ def _run_server(FLAGS, ldb, model_resources): "You can connect to the worldbuilder at http://%s:%s/builder/ \n" % (FLAGS.hostname, FLAGS.port) ) - try: - my_loop.start() - except KeyboardInterrupt: - my_loop.stop() + while True: + await asyncio.sleep(30) -# Update this to load _all_ models for the full game, fix "shared_model_content" -def init_model_resources(FLAGS): +def init_model_pool(FLAGS) -> "ModelPool": light_model_root = FLAGS.light_model_root - dialog_model = FLAGS.dialog_model - act_model = FLAGS.acting_model - scoring_model = FLAGS.roleplaying_score_model_file - generic_act_model = FLAGS.generic_act_model_file - - if dialog_model is None: - return {"shared_model_content": {}} + if light_model_root.endswith("/"): + light_model_root = os.path.expanduser(light_model_root[:-1]) + os.environ["LIGHT_MODEL_ROOT"] = light_model_root - # dialog gen is at `dialog_gen`, other is at `game_speech1`? - shared_model_content = GenerativeHeuristicModelSoul.load_models( - light_model_root + dialog_model, + safety_model_opt_file = FLAGS.safety_model_opt_file.replace( + "LIGHT_MODEL_ROOT", light_model_root + ) + dialog_model_opt_file = FLAGS.dialog_model_opt_file.replace( + "LIGHT_MODEL_ROOT", light_model_root + ) + action_model_opt_file = FLAGS.action_model_opt_file.replace( + "LIGHT_MODEL_ROOT", light_model_root + ) + roleplaying_score_opt_file = FLAGS.roleplaying_score_opt_file.replace( + "LIGHT_MODEL_ROOT", light_model_root + ) + generic_act_opt_file = FLAGS.generic_act_opt_file.replace( + "LIGHT_MODEL_ROOT", light_model_root + ) + parser_opt_file = FLAGS.parser_opt_file.replace( + "LIGHT_MODEL_ROOT", light_model_root ) - resources = {"shared_model_content": shared_model_content} - - if scoring_model is not None: - resources["rpg_model"] = BaseSoul.load_roleplaying_score_model(scoring_model) - shared_model_content["rpg_model"] = resources["rpg_model"] - - if generic_act_model is not None: - generic_act_model_content = BaseSoul.load_generic_act_model(generic_act_model) - resources["generic_act_model"] = generic_act_model_content.share() - shared_model_content["shared_action_model"] = resources["generic_act_model"] - - init_safety_classifier(FLAGS.safety_list) - return resources + model_pool = ModelPool() + + # Register Models + + if len(safety_model_opt_file) > 3: + model_pool.register_model( + ParlAIModelConfig(opt_file=safety_model_opt_file), + [ModelTypeName.SAFETY], + ) + if len(dialog_model_opt_file) > 3: + model_pool.register_model( + ParlAIModelConfig(opt_file=dialog_model_opt_file), + [ModelTypeName.DIALOG], + ) + if len(roleplaying_score_opt_file) > 3: + model_pool.register_model( + ParlAIPolyencoderActingScoreModelConfig( + opt_file=roleplaying_score_opt_file + ), + [ModelTypeName.SCORING], + ) + if len(action_model_opt_file) > 3: + model_pool.register_model( + ParlAIModelConfig(opt_file=action_model_opt_file), + [ModelTypeName.ACTION], + ) + if len(generic_act_opt_file) > 3: + model_pool.register_model( + ParlAIModelConfig(opt_file=generic_act_opt_file), + [ModelTypeName.GENERIC_ACTS], + ) + if len(parser_opt_file) > 3: + model_pool.register_model( + ParlAIModelConfig(opt_file=parser_opt_file), + [ModelTypeName.PARSER], + ) + FLAGS.safety_classifier_path = FLAGS.safety_list + + return model_pool def main(): @@ -205,11 +276,16 @@ def str2bool(v): help="port to run the server on.", ) parser.add_argument( - "--safety-list", - metavar="safety_list", + "--db-root", type=str, - default=os.path.expanduser("~/data/safety/OffensiveLanguage.txt"), - help="Where to find the offensive language list.", + default=here + "/../../../logs/db_root", + ) + parser.add_argument( + "--disable-builder", + metavar="disable_builder", + type=str, + default=None, + help="flag to disable the builder, omit to enable", ) parser.add_argument( "--builder-model", @@ -219,40 +295,42 @@ def str2bool(v): help="Builder model to be loading", ) parser.add_argument( - "--dialog-model", - metavar="dialog_model", + "--safety-list", type=str, - default=None, - help="dialog model to be loading", + default=os.path.expanduser("~/data/safety/OffensiveLanguage.txt"), + help="Where to find the offensive language list.", ) parser.add_argument( - "--acting-model", - metavar="acting_model", + "--safety-model-opt-file", type=str, - default=None, - help="acting model to be loading", + default=os.path.join(CONFIG_DIR, "baseline_adversarial_safety.opt"), + help="Where to find the offensive language list.", ) parser.add_argument( - "--disable-builder", - metavar="disable_builder", + "--dialog-model-opt-file", type=str, - default=None, - help="flag to disable the builder, omit to enable", + default=os.path.join(CONFIG_DIR, "baseline_generative.opt"), + help="dialog model to be loading", ) parser.add_argument( - "--parser-model-file", + "--roleplaying-score-opt-file", type=str, - default="", + default=os.path.join(CONFIG_DIR, "baseline_roleplaying_scorer.opt"), ) parser.add_argument( - "--roleplaying-score-model-file", + "--action-model-opt-file", type=str, - default="", + default=os.path.join(CONFIG_DIR, "baseline_main_act_model.opt"), ) parser.add_argument( - "--generic-act-model-file", + "--generic-act-opt-file", type=str, - default="", + default=os.path.join(CONFIG_DIR, "generic_act_model.opt"), + ) + parser.add_argument( + "--parser-opt-file", + type=str, + default=os.path.join(CONFIG_DIR, "baseline_parser.opt"), ) parser.add_argument( "--is-logging", @@ -260,15 +338,20 @@ def str2bool(v): default=False, help="flag to enable storing logs of interactions", ) + parser.add_argument( + "--db-backend", + type=str, + default="test", + ) FLAGS, _unknown = parser.parse_known_args() print(FLAGS) random.seed(6) numpy.random.seed(6) - model_resources = init_model_resources(FLAGS) + model_pool = init_model_pool(FLAGS) ldb = LIGHTDatabase(FLAGS.data_model_db) - _run_server(FLAGS, ldb, model_resources) + asyncio.run(_run_server(FLAGS, ldb, model_pool)) if __name__ == "__main__": diff --git a/deploy/web/server/telnet_server.py b/deploy/web/server/telnet_server.py index c34bbd805..7aedca713 100644 --- a/deploy/web/server/telnet_server.py +++ b/deploy/web/server/telnet_server.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. @@ -11,6 +13,7 @@ import argparse import socket +import asyncio DEFAULT_HOSTNAME = "localhost" DEFAULT_PORT = 35495 @@ -37,7 +40,7 @@ def act(self): if self.text != "": agent_id = self.get_agent_id() print(agent_id + ":" + str(self.text)) - self.g.parse_exec(agent_id, self.text) + asyncio.run(self.g.parse_exec(agent_id, self.text)) self.text = "" def observe(self): @@ -56,7 +59,7 @@ def init_observe(self): only be called the first time this player is initialized. """ agent_id = self.get_agent_id() - self.g.parse_exec(agent_id, "look") + asyncio.run(self.g.parse_exec(agent_id, "look")) self.observe() def is_alive(self): @@ -152,7 +155,7 @@ def main(): random.seed(6) numpy.random.seed(6) - game = GameInstance() + game = asyncio.run(GameInstance.get()) graph = game.world provider = TelnetPlayerProvider(graph, FLAGS.hostname, FLAGS.port) game.register_provider(provider) diff --git a/deploy/web/server/tests/test_tornado_server.py b/deploy/web/server/tests/test_tornado_server.py index 59e2c0273..334662f8c 100644 --- a/deploy/web/server/tests/test_tornado_server.py +++ b/deploy/web/server/tests/test_tornado_server.py @@ -8,6 +8,7 @@ import re import os import ast +import asyncio from tornado import gen, httpclient, ioloop, testing, escape from tornado.testing import AsyncHTTPTestCase, gen_test from tornado.ioloop import IOLoop @@ -67,6 +68,12 @@ URL = f"http://localhost:{PORT}" +def async_return(result): + f = asyncio.Future() + f.set_result(result) + return f + + class MockFlags: def __init__(self, hostname, port): self.hostname = hostname @@ -113,7 +120,7 @@ def test_game_socket(self, mocked_auth, MockStarSpace): @mock.patch( "deploy.web.server.registry.RegistryApplication.run_new_game", - return_value="test", + return_value=async_return("test"), ) @gen_test def test_new_game(self, mocked_auth, MockStarSpace, mocked_method): @@ -598,6 +605,7 @@ def test_landing_page_redirect(self, mocked_auth): @gen_test def test_logout(self, mocked_auth): + self.skipTest("Middle of refactor") """Test that logout clears cookie and redirects""" headers = {"Content-Type": "application/json"} with self.assertRaises(httpclient.HTTPClientError) as cm: diff --git a/deploy/web/server/tornado_server.py b/deploy/web/server/tornado_server.py index 987ab9865..c3f500e56 100644 --- a/deploy/web/server/tornado_server.py +++ b/deploy/web/server/tornado_server.py @@ -1,16 +1,17 @@ +#!/usr/bin/env python3 + # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - from deploy.web.server.game_instance import ( Player, GameInstance, ) -from light.data_model.light_database import LIGHTDatabase +from light.data_model.db.users import PlayerStatus from light.world.player_provider import PlayerProvider from light.world.quest_loader import QuestLoader -from light.graph.events.graph_events import init_safety_classifier, RewardEvent +from light.graph.events.graph_events import RewardEvent from light.world.souls.tutorial_player_soul import TutorialPlayerSoul import argparse @@ -26,15 +27,11 @@ import asyncio import hashlib from collections import defaultdict -from zmq.eventloop import ioloop - -ioloop.install() # Needs to happen before any tornado imports! - -import tornado.ioloop # noqa E402: gotta install ioloop first -import tornado.web # noqa E402: gotta install ioloop first -import tornado.auth # noqa E402: gotta install ioloop first -import tornado.websocket # noqa E402: gotta install ioloop first -import tornado.escape # noqa E402: gotta install ioloop first +import tornado.ioloop as ioloop +import tornado.web +import tornado.auth +import tornado.websocket +import tornado.escape from light.graph.events.graph_events import ( SoulSpawnEvent, SystemMessageEvent, @@ -46,6 +43,13 @@ if TYPE_CHECKING: from light.graph.elements.graph_nodes import GraphAgent from light.world.world import World + from light.data_model.db.users import UserDB + +# Monkeypatch to allow samesite for iframe usage +from http.cookies import Morsel + +Morsel._reserved["samesite"] = "SameSite" + DEFAULT_PORT = 35496 DEFAULT_HOSTNAME = "localhost" @@ -147,14 +151,14 @@ def get_path(filename): class Application(tornado.web.Application): - def __init__(self, given_tornado_settings=None, db=None): + def __init__(self, given_tornado_settings=None, user_db: Optional["UserDB"] = None): global tornado_settings use_tornado_settings = tornado_settings if given_tornado_settings is not None: use_tornado_settings = given_tornado_settings self.subs = {} self.new_subs = defaultdict(list) - self.db = db + self.user_db = user_db self.user_node_map: Dict[str, Optional["GraphAgent"]] = {} self.world: Optional["World"] = None super(Application, self).__init__(self.get_handlers(), **use_tornado_settings) @@ -165,9 +169,13 @@ def get_handlers(self): # hit in the top level RuleRouter from run_server.py in case this application # is run standalone for some reason. return [ - (r"/game/api/(.*)", ApiHandler, {"app": self, "database": self.db}), - (r"/game(.*)/socket", SocketHandler, {"app": self, "database": self.db}), - (r"/play", GameHandler, {"app": self, "database": self.db}), + (r"/game/api/(.*)", ApiHandler, {"app": self, "user_db": self.user_db}), + ( + r"/game(.*)/socket", + SocketHandler, + {"app": self, "user_db": self.user_db}, + ), + (r"/play", GameHandler, {"app": self, "user_db": self.user_db}), (r"/(.*)", StaticUIHandler, {"path": path_to_build}), ] @@ -185,8 +193,8 @@ def parse_url_path(self, url_path): class SocketHandler(tornado.websocket.WebSocketHandler): - def initialize(self, app, database): - self.db = database + def initialize(self, app, user_db): + self.user_db = user_db self.app = app self.subs = app.subs self.new_subs = app.new_subs @@ -195,7 +203,7 @@ def initialize(self, app, database): self.actions = [] self.player = None self.sid = get_rand_id() - self.db = app.db + self.user_db = app.user_db def safe_write_message(self, msg): try: @@ -210,11 +218,10 @@ def set_player(self, player): self.player = player def user_should_do_tutorial(self, user_id): - with self.db as ldb: - flags = ldb.get_user_flags(user_id) - return not flags.completed_onboarding + player = self.user_db.get_player(user_id) + return player.account_status == PlayerStatus.TUTORIAL - def launch_game_for_user(self, user_id, game_id): + async def launch_game_for_user(self, user_id, game_id): # Check for custom game world if game_id not in self.app.registry.game_instances: self.close() @@ -225,22 +232,23 @@ def launch_game_for_user(self, user_id, game_id): new_player = TornadoPlayerProvider( self, graph_purgatory, - db=self.db, - user=user_id, + user_db=self.user_db, + user_id=user_id, ) - new_player.init_soul() + await new_player.init_soul() self.app.registry.game_instances[game_id].players.append(new_player) - def open(self, game_id): + async def open(self, game_id): """ Open a websocket, validated either by a valid user cookie or by a validated preauth. """ preauth_context = self.get_secure_cookie("preauth_context") - user = None + user_id = None if preauth_context is not None: # If there is any preauth preauth = self.get_secure_cookie("preauth") - user = json.loads(preauth) + hashed_user_id = json.loads(preauth) + user_id = self.user_db.get_player(hashed_user_id).db_id # See if the context matches our generated hash context_token = json.loads(self.get_secure_cookie("context_token")) @@ -253,28 +261,40 @@ def open(self, game_id): return else: user_json = self.get_secure_cookie("user") - if user_json is not None: - user = json.loads(user_json) + if user_json is not None and user_json != "": + user_id = json.loads(user_json) - print("Requesting for user", user) - if user is not None: + print("Requesting for user", user_id) + if user_id is not None: logging.info("Opened new socket from ip: {}".format(self.request.remote_ip)) logging.info("For game: {}".format(game_id)) + + loop = tornado.ioloop.IOLoop.current() + # First check for tutorials - if self.user_should_do_tutorial(user): + if self.user_should_do_tutorial(user_id): # Spawn a tutorial world for this user, or inject them into # their existing world - if user in self.app.registry.tutorial_map: - game_id = self.app.registry.tutorial_map[user] + if user_id in self.app.registry.tutorial_map: + game_id = self.app.registry.tutorial_map[user_id] else: orig_game_id = game_id def on_complete(): time.sleep(TRANSITION_AFTER_TUTORIAL) - self.launch_game_for_user(user, orig_game_id) + loop.spawn_callback( + self.launch_game_for_user, user_id, orig_game_id + ) + + async def create_and_run_tutorial(): + game_id = await self.app.registry.run_tutorial( + user_id, on_complete + ) + await self.launch_game_for_user(user_id, game_id) - game_id = self.app.registry.run_tutorial(user, on_complete) - self.launch_game_for_user(user, game_id) + loop.spawn_callback(create_and_run_tutorial) + else: + loop.spawn_callback(self.launch_game_for_user, user_id, game_id) else: self.close() self.redirect("/#/login") @@ -283,7 +303,7 @@ def send_alive(self): self.safe_write_message(json.dumps({"command": "register", "data": self.sid})) self.alive_sent = True - def on_message(self, message): + async def on_message(self, message): logging.info("from web client: {}".format(message)) msg = tornado.escape.json_decode(tornado.escape.to_basestring(message)) cmd = msg.get("command") @@ -291,9 +311,11 @@ def on_message(self, message): return if cmd == "act": data = msg["data"] - self.player.act(data["text"], data["event_id"]) + await self.player.act(data["text"], data["event_id"]) + elif cmd == "hb": + pass # heartbeats else: - print("THESE COMMANDS HAVE BEEN DEPRICATED") + logging.warning(f"THESE COMMANDS HAVE BEEN DEPRICATED: {msg}") def on_close(self): self.alive = False @@ -304,8 +326,8 @@ def __init__(self, *request, **kwargs): self.include_host = False super(BaseHandler, self).__init__(*request, **kwargs) - def initialize(self, database): - self.db = database + def initialize(self, user_db): + self.user_db = user_db def get_login_url(self): return "/#/login" @@ -314,22 +336,22 @@ def get_current_user(self): user_json = self.get_secure_cookie( "user" ) # Need to refactor into 'get_identity', then have base and preauth handler implementations - if user_json: + if user_json is not None and len(user_json) != 0: user_decoded = tornado.escape.json_decode(user_json) if len(user_decoded) == 0: return None try: - with self.db as ldb: - user_id = ldb.get_user_id(user_decoded) + user = self.user_db.get_player(user_decoded) + user_id = user.db_id except Exception as e: - # User id does not exist in the database, either + # User id does not exist in the user_db, either # we've updated the user table or someone # is fishing :/ # Also can be caused when auth is refreshed print(f"User {user_decoded} tried to log in, but was rejected.") return None - print(f"User {user_decoded, user_id} logged in.") - return user_decoded + print(f"User {user.extern_id, user_id} logged in.") + return user_id else: return None @@ -369,15 +391,15 @@ def write_error(self, status_code, **kwargs): class ApiHandler(BaseHandler): - def initialize(self, app, database): - self.db = database + def initialize(self, app, user_db): + self.user_db = user_db self.app = app @tornado.web.authenticated def get(self, *args): print("THE ARGS", *args) user_json = self.get_secure_cookie("user") - if user_json: + if user_json is not None and user_json != "": user_decoded = tornado.escape.json_decode(user_json) split_inputs = args[0].split("/") @@ -405,7 +427,7 @@ def post(self, *args): data = tornado.escape.json_decode(self.request.body) user_json = self.get_secure_cookie("user") print(data) - if user_json: + if user_json is not None and user_json != "": user_decoded = tornado.escape.json_decode(user_json) split_inputs = args[0].split("/") @@ -441,65 +463,65 @@ def post(self, *args): class LandingApplication(tornado.web.Application): def __init__( self, - database, + user_db: "UserDB", hostname=DEFAULT_HOSTNAME, password="LetsPlay", given_tornado_settings=None, ): - self.db = database + self.user_db = user_db global tornado_settings tornado_settings = given_tornado_settings super(LandingApplication, self).__init__( - self.get_handlers(database, hostname, password), **tornado_settings + self.get_handlers(user_db, hostname, password), **tornado_settings ) - def get_handlers(self, database, hostname=DEFAULT_HOSTNAME, password="LetsPlay"): + def get_handlers(self, user_db, hostname=DEFAULT_HOSTNAME, password="LetsPlay"): return [ - (r"/", LandingHandler, {"database": database}), - (r"/#(.*)", LandingHandler, {"database": database}), - (r"/#/login", LandingHandler, {"database": database}), - (r"/#/error", NotFoundHandler, {"database": database}), + (r"/", LandingHandler, {"user_db": user_db}), + (r"/#(.*)", LandingHandler, {"user_db": user_db}), + (r"/#/login", LandingHandler, {"user_db": user_db}), + (r"/#/error", NotFoundHandler, {"user_db": user_db}), ( r"/preauth/(.*)/(.*)/(.*)/", PreauthGameHandler, - {"database": database, "hostname": hostname}, + {"user_db": user_db, "hostname": hostname}, ), - (r"/play", GameHandler, {"database": database}), - (r"/play/?id=.*", GameHandler, {"database": database}), - (r"/play/*", GameHandler, {"database": database}), - (r"/build", BuildHandler, {"database": database}), + (r"/play", GameHandler, {"user_db": user_db}), + (r"/play/?id=.*", GameHandler, {"user_db": user_db}), + (r"/play/*", GameHandler, {"user_db": user_db}), + (r"/build", BuildHandler, {"user_db": user_db}), ( r"/login", LoginHandler, - {"database": database, "hostname": hostname, "password": password}, + {"user_db": user_db, "hostname": hostname, "password": password}, ), ( r"/auth/fblogin", FacebookOAuth2LoginHandler, - {"database": database, "hostname": hostname, "app": self}, + {"user_db": user_db, "hostname": hostname, "app": self}, ), - (r"/logout", LogoutHandler, {"database": database}), + (r"/logout", LogoutHandler, {"hostname": hostname}), ( r"/terms", StaticPageHandler, - {"database": database, "target": "/html/terms.html"}, + {"user_db": user_db, "target": "/html/terms.html"}, ), ( r"/#/bye", LandingHandler, - {"database": database}, + {"user_db": user_db}, ), ( r"/about", StaticLoggedInPageHandler, - {"database": database, "target": "/html/about.html"}, + {"user_db": user_db, "target": "/html/about.html"}, ), ( r"/profile", StaticLoggedInPageHandler, - {"database": database, "target": "/html/profile.html"}, + {"user_db": user_db, "target": "/html/profile.html"}, ), - (r"/report", ReportHandler, {"database": database}), + (r"/report", ReportHandler, {"user_db": user_db}), (r"/(.*)", StaticUIHandler, {"path": here + "/../build/"}), ] @@ -523,10 +545,10 @@ def get(self): class PreauthGameHandler(BaseHandler): def initialize( self, - database, + user_db, hostname=DEFAULT_HOSTNAME, ): - self.db = database + self.user_db = user_db self.hostname = hostname def validate_login_details(self, user_id, context_id, auth_token) -> bool: @@ -552,14 +574,15 @@ def get(self, user_id, context_id, auth_token): user_hash = get_salted_hash(user_id) context_hash = get_salted_hash(context_id) hashed_user_id = f"preauth-{user_hash}" - with self.db as ldb: - _ = ldb.create_user(hashed_user_id) + self.user_db.create_user(extern_id=hashed_user_id, is_preauth=True) self.set_secure_cookie( "preauth", tornado.escape.json_encode(hashed_user_id), expires_days=1, domain=self.hostname, httponly=True, + samesite=None, + secure=True, ) self.set_secure_cookie( "preauth_context", @@ -567,6 +590,8 @@ def get(self, user_id, context_id, auth_token): expires_days=1, domain=self.hostname, httponly=True, + samesite=None, + secure=True, ) self.set_secure_cookie( "context_token", @@ -574,6 +599,8 @@ def get(self, user_id, context_id, auth_token): expires_days=1, domain=self.hostname, httponly=True, + samesite=None, + secure=True, ) self.render(here + "/../build/game.html") else: @@ -586,9 +613,9 @@ def get(self): class StaticPageHandler(BaseHandler): - def initialize(self, target, database): + def initialize(self, target, user_db): self.target_page = here + target - self.db = database + self.user_db = user_db def get(self): self.render(self.target_page) @@ -607,24 +634,40 @@ class FacebookOAuth2LoginHandler(BaseHandler, tornado.auth.FacebookGraphMixin): def initialize( self, - database, + user_db, hostname, app, ): self.app = app - self.db = database + self.user_db = user_db self.hostname = hostname + async def get_privacy_restricted_user_id(self, redirect_url) -> str: + """ + While we already don't request user input for our API key, + this method ensures that we're only getting the `id` key. + + DO NOT CHANGE THIS METHOD + """ + fb_user = await self.get_authenticated_user( + redirect_uri=redirect_url, + client_id=self.app.settings["facebook_api_key"], + client_secret=self.app.settings["facebook_secret"], + code=self.get_argument("code"), + ) + return fb_user["id"] + async def get(self): redirect = "https://" + self.request.host + "/auth/fblogin" if self.get_argument("code", False): - fb_user = await self.get_authenticated_user( - redirect_uri=redirect, - client_id=self.app.settings["facebook_api_key"], - client_secret=self.app.settings["facebook_secret"], - code=self.get_argument("code"), + fb_app_scoped_id = await self.get_privacy_restricted_user_id( + redirect_url=redirect, + ) + + user_id = self.user_db.create_user( + extern_id=fb_app_scoped_id, is_preauth=False ) - self.set_current_user(fb_user["id"]) + self.set_current_user(user_id) self.redirect("/play/") return self.authorize_redirect( @@ -632,29 +675,33 @@ async def get(self): client_id=self.app.settings["facebook_api_key"], ) - def set_current_user(self, user): - if user: - with self.db as ldb: - _ = ldb.create_user(user) + def set_current_user(self, user_id): + if user_id: self.set_secure_cookie( "user", - tornado.escape.json_encode(user), + tornado.escape.json_encode(user_id), domain=self.hostname, secure=True, httponly=True, ) else: - self.clear_cookie("user") + self.set_secure_cookie( + "user", + "", + domain=self.hostname, + secure=True, + httponly=True, + ) class LoginHandler(BaseHandler): def initialize( self, - database, + user_db, hostname=DEFAULT_HOSTNAME, password="LetsPlay", ): - self.db = database + self.user_db = user_db self.hostname = hostname self.password = password @@ -666,31 +713,48 @@ def post(self): name = self.get_argument("name", "") password = self.get_argument("password", "") if password == self.password: - with self.db as ldb: - _ = ldb.create_user(name) - self.set_current_user(name) + user_id = self.user_db.create_user(extern_id=name, is_preauth=False) + self.set_current_user(user_id) # self.redirect(self.get_argument("next", "/")) self.redirect("/play/") else: error_msg = "?error=" + tornado.escape.url_escape("incorrect") self.redirect("/#/login" + error_msg) - def set_current_user(self, user): - if user: + def set_current_user(self, user_id): + if user_id: self.set_secure_cookie( "user", - tornado.escape.json_encode(user), + tornado.escape.json_encode(user_id), domain=self.hostname, # secure=True, login handler is for local testing httponly=True, ) else: - self.clear_cookie("user") + self.set_secure_cookie( + "user", + "", + domain=self.hostname, + secure=True, + httponly=True, + ) class LogoutHandler(BaseHandler): + def initialize( + self, + hostname=DEFAULT_HOSTNAME, + ): + self.hostname = hostname + def get(self): - self.clear_cookie("user") + self.set_secure_cookie( + "user", + "", + domain=self.hostname, + secure=True, + httponly=True, + ) self.redirect("/#/bye") @@ -715,7 +779,7 @@ class TornadoPlayerProvider(PlayerProvider): Player Provider for the web app """ - def __init__(self, socket, purgatory, db=None, user=None, context=None): + def __init__(self, socket, purgatory, user_db=None, user_id=None, context=None): self.socket = socket self.player_soul = None self.purgatory = purgatory @@ -723,8 +787,8 @@ def __init__(self, socket, purgatory, db=None, user=None, context=None): self.quest_loader = quest_loader socket.set_player(self) socket.send_alive() - self.db = db - self.user = user + self.user_db = user_db + self.user_id = user_id self.context = context # TODO a TornadoPlayerProvider refactor is likely desired, combining # the APIs for socket and HTTP requests to use logged in user @@ -735,13 +799,20 @@ def __init__(self, socket, purgatory, db=None, user=None, context=None): def register_soul(self, soul: "PlayerSoul"): """Save the soul as a local player soul""" self.player_soul = soul - if self.user is not None: - if self.db is not None: - with self.db as ldb: - ldb.initialize_agent_score(soul.target_node, self.user) - self.app.user_node_map[self.user] = soul.target_node - - def player_observe_event(self, soul: "PlayerSoul", event: "GraphEvent"): + if self.user_id is not None: + if self.user_db is not None: + base_score = self.user_db.get_agent_score(self.user_id) + # TODO refactor into elsewhere + target_node = soul.target_node + target_node.xp = base_score.score + target_node.reward_xp = base_score.reward_xp + target_node._base_class_experience = 0 + target_node._num_turns = 0 + target_node._base_experience = target_node.xp + target_node._base_reward_points = target_node.reward_xp + self.app.user_node_map[self.user_id] = soul.target_node + + async def player_observe_event(self, soul: "PlayerSoul", event: "GraphEvent"): """ Send observation forward to the player in whatever format the player expects it to be. @@ -762,18 +833,18 @@ def player_observe_event(self, soul: "PlayerSoul", event: "GraphEvent"): isinstance(event, DeathEvent) and event.actor.node_id == soul.target_node.node_id ): - self.purgatory.clear_soul(soul.target_node) + await self.purgatory.clear_soul(soul.target_node) - def act(self, action_data, event_id: Optional[str] = None): + async def act(self, action_data, event_id: Optional[str] = None): if self.player_soul is not None and self.player_soul.is_reaped: self.player_soul = None if self.player_soul is None: - self.init_soul() + await self.init_soul() return - player_agent = self.player_soul.handle_act(action_data, event_id) + player_agent = await self.player_soul.handle_act(action_data, event_id) - def init_soul(self): - self.purgatory.get_soul_for_player(self) + async def init_soul(self): + await self.purgatory.get_soul_for_player(self) if self.player_soul is None: dat = {"text": "Could not find a soul for you, sorry"} self.socket.safe_write_message( @@ -784,14 +855,14 @@ def init_soul(self): SoulSpawnEvent(soul_id, self.player_soul.target_node).execute( self.purgatory.world ) - self.player_soul.handle_act("look") - self.player_soul.target_node.user_id = self.user + await self.player_soul.handle_act("look") + self.player_soul.target_node.user_id = self.user_id self.player_soul.target_node.context_id = self.context def is_alive(self): return self.socket.alive - def on_reap_soul(self, soul): + async def on_reap_soul(self, soul): action = SystemMessageEvent( soul.target_node, [], @@ -808,11 +879,26 @@ def on_reap_soul(self, soul): self.socket.safe_write_message( json.dumps({"command": "actions", "data": [dat]}) ) - if self.user is not None: - if self.db is not None and not isinstance(soul, TutorialPlayerSoul): - with self.db as ldb: - ldb.store_agent_score(soul.target_node, self.user) - self.app.user_node_map[self.user] = None + if self.user_id is not None: + if self.user_db is not None and not isinstance(soul, TutorialPlayerSoul): + # TODO refactor out from server logic + target_node = soul.target_node + gained_experience = target_node.xp - target_node._base_experience + net_reward_points = ( + target_node.reward_xp - target_node._base_reward_points + ) + db_id = target_node.db_id if target_node.db_id is not None else "" + self.user_db.update_agent_score( + player_id=self.user_id, + agent_name_id=db_id, + points=gained_experience, + num_turns=target_node._num_turns, + reward_change=net_reward_points, + ) + target_node._num_turns = 0 + target_node._base_experience = target_node.xp + target_node._base_reward_points = target_node.reward_xp + self.app.user_node_map[self.user_id] = None class TornadoPlayerFactory: @@ -827,13 +913,13 @@ def __init__( registry, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, - db=None, + user_db=None, listening=False, given_tornado_settings=None, ): self.registry = registry self.app = None - self.db = db + self.user_db = user_db def _run_server(): nonlocal listening @@ -842,7 +928,7 @@ def _run_server(): nonlocal port self.my_loop = ioloop.IOLoop() self.app = Application( - given_tornado_settings=given_tornado_settings, db=self.db + given_tornado_settings=given_tornado_settings, user_db=self.user_db ) self.app.registry = self.registry if listening: @@ -897,8 +983,9 @@ def main(): help="port to run the server on.", ) FLAGS = parser.parse_args() - - init_safety_classifier(os.path.expanduser("~/data/safety/OffensiveLanguage.txt")) + FLAGS.safety_classifier_path = os.path.expanduser( + "~/data/safety/OffensiveLanguage.txt" + ) random.seed(6) numpy.random.seed(6) @@ -908,7 +995,7 @@ def main(): None, FLAGS.hostname, FLAGS.port, listening=True ) else: - game = GameInstance(game_id=0, ldb=ldb) + game = asyncio.run(GameInstance(game_id=0, ldb=ldb)) graph = game.world provider = TornadoPlayerFactory( graph, FLAGS.hostname, FLAGS.port, db=ldb, listening=True diff --git a/light/data_model/db/__init__.py b/light/data_model/db/__init__.py new file mode 100644 index 000000000..c022998c1 --- /dev/null +++ b/light/data_model/db/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/light/data_model/db/base.py b/light/data_model/db/base.py new file mode 100644 index 000000000..39cd34b31 --- /dev/null +++ b/light/data_model/db/base.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from abc import ABC, abstractmethod +from omegaconf import MISSING, DictConfig +from sqlalchemy import create_engine +from enum import Enum +from typing import Optional, Union, Dict, Any, Type +from uuid import uuid4 +from dataclasses import dataclass +from tempfile import mkdtemp +import shutil +import os +import json + +from hydra.core.config_store import ConfigStore + +DEFAULT_LOG_PATH = "".join( + [os.path.abspath(os.path.dirname(__file__)), "/../../../logs"] +) + + +@dataclass +class LightDBConfig: + backend: str = "test" + file_root: Optional[str] = DEFAULT_LOG_PATH + + +@dataclass +class LightAWSDBConfig(LightDBConfig): + backend: str = "aws-postgres" + file_root: str = MISSING + db_address: str = MISSING + db_user: str = MISSING + db_pass: str = MISSING + + +cs = ConfigStore.instance() +cs.store(name="db/base", node=LightDBConfig) +cs.store(name="db/aws-postgres", node=LightAWSDBConfig) + + +class DBStatus(Enum): + """Current review status for contents""" + + REVIEW = "unreviewed" + PRODUCTION = "production" + REJECTED = "rejected" + QUESTIONABLE = "questionable" # For low quality, or borderline content + ACCEPTED = "accepted" + + +class DBSplitType(Enum): + """Splits in the LIGHT Environment DB""" + + UNSET = "no_split_set" + TRAIN = "train" + TEST = "test" + VALID = "valid" + UNSEEN = "unseen_test" + + +class HasDBIDMixin: + """Simple mixin for classes that define their own DBID schema""" + + ID_PREFIX: str # ID prefix should be 3 characters max. + + @classmethod + def get_id(cls: Type["HasDBIDMixin"]) -> str: + """Create an ID for this class""" + return f"{cls.ID_PREFIX}-{uuid4()}" + + @classmethod + def is_id(cls: Type["HasDBIDMixin"], test_id: str) -> bool: + """Check if a given ID refers to this class""" + return test_id.startswith(f"{cls.ID_PREFIX}-") + + +class BaseDB(ABC): + """ + Core database class underneath the LIGHT datamodel that allows for + linking to production MySQL on RDS when live, and SQLite when testing + or using LIGHT locally. Also abstracts away file reading and writing, + which can be done with either buckets or local file manipulation. + + Output conversions of production dbs to local copies done + currently with: https://github.com/dumblob/mysql2sqlite + """ + + DB_TYPE: str + + def __init__(self, config: "DictConfig"): + """ + Create this database, either connecting to a remote host or local + files and instances. + """ + self.backend = config.backend + if config.backend == "test": + self.engine = create_engine("sqlite+pysqlite:///:memory:", future=True) + self.made_temp_dir = config.file_root is None + if self.made_temp_dir: + self.file_root = mkdtemp() + else: + self.file_root = config.file_root + elif config.backend == "local": + self.file_root = config.file_root + db_path = os.path.join(self.file_root, f"{self.DB_TYPE}.db") + self.engine = create_engine(f"sqlite:////{db_path}") + elif config.backend == "aws-postgres": + try: + import psycopg2 + import boto3 + except ImportError: + print( + "For aws-postgres usage, you must also `pip install mysqlclient boto3 psycopg2-binary" + ) + raise + # Get DB registered and functioning + self.db_address = config.db_address + db_address = config.db_address + login_user = config.db_user + login_pass = config.db_pass + self.engine = create_engine( + f"postgresql://{login_user}:{login_pass}@{db_address}:5432/postgres" + ) + + # Connect to the s3 filestore + self.file_root = config.file_root # file root is a s3 bucket address + s3 = boto3.resource("s3") + self.bucket = s3.Bucket(self.file_root) + else: + raise NotImplementedError( + f"Provided backend {config.backend} doens't exist" + ) + self._complete_init(config) + + @abstractmethod + def _complete_init(self, config: "DictConfig"): + """ + Complete implementation-specific initialization + """ + + @abstractmethod + def _validate_init(self): + """ + Ensure that this database is initialized correctly + """ + + def _enforce_get_first(self, session, stmt, error_text) -> Any: + """ + Enforce getting the first element using stmt, raise a key_error + with error_text if fails to find + """ + result = session.scalars(stmt).first() + if result is None: + raise KeyError(error_text) + return result + + def file_path_exists(self, file_path: str) -> bool: + """ + Determine if the given file path exists on this storage + """ + if self.backend in ["test", "local"]: + full_path = os.path.join(self.file_root, file_path) + return os.path.exists(full_path) + elif self.backend in ["aws-postgres"]: + import botocore + + try: + self.bucket.Object(file_path).load() + return True + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "404": + # The object does not exist. + return False + else: + # Something else has gone wrong. + raise + else: + raise NotImplementedError(f"Backend {self.backend} is not implemented") + + def write_data_to_file( + self, data: Union[str, Dict[str, Any]], filename: str, json_encode: bool = False + ) -> None: + """ + Write the given data to the provided filename + in the correct storage location (local or remote) + """ + if self.backend in ["test", "local"]: + full_path = os.path.join(self.file_root, filename) + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, "w+") as target_file: + if json_encode: + json.dump(data, target_file) + else: + target_file.write(data) + elif self.backend in ["aws-postgres"]: + if json_encode: + data = json.dumps(data) + self.bucket.Object(filename).put(Body=data) + else: + raise NotImplementedError(f"Backend {self.backend} is not implemented") + + def read_data_from_file( + self, filename: str, json_encoded: bool = False + ) -> Union[str, Dict[str, Any]]: + """ + Read the data from the given filename from wherever it + is currently stored (local or remote) + """ + if self.backend in ["test", "local"]: + full_path = os.path.join(self.file_root, filename) + with open(full_path, "r") as target_file: + if json_encoded: + return json.load(target_file) + else: + return target_file.read() + elif self.backend in ["aws-postgres"]: + data = self.bucket.Object(filename).get()["Body"] + if json_encoded: + return json.loads(data) + else: + return data + else: + raise NotImplementedError(f"Backend {self.backend} is not implemented") + + def shutdown(self): + if self.backend == "test": + if self.made_temp_dir: + shutil.rmtree(self.file_root) diff --git a/light/data_model/db/environment.py b/light/data_model/db/environment.py new file mode 100644 index 000000000..79ffbfdf9 --- /dev/null +++ b/light/data_model/db/environment.py @@ -0,0 +1,1986 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from light.data_model.db.base import BaseDB, DBStatus, DBSplitType, HasDBIDMixin +from light.data_model.db.users import DBPlayer +from light.graph.structured_graph import OOGraph +from omegaconf import MISSING, DictConfig +from typing import ( + Optional, + List, + Tuple, + Union, + Dict, + Any, + Set, + Type, + cast, + TYPE_CHECKING, +) +from sqlalchemy import ( + insert, + select, + Enum, + Column, + Integer, + String, + Float, + ForeignKey, + Boolean, + UniqueConstraint, +) +from sqlalchemy.orm import declarative_base, relationship, Session, join +import sqlalchemy.exc + +import enum +import os +import time + +SQLBase = declarative_base() + +FILE_PATH_KEY = "env" +GRAPH_PATH_KEY = "graphs" +QUEST_PATH_KEY = "quests" +USR_KEY = DBPlayer.ID_PREFIX + +SCRUBBED_USER_ID = "scrubbed_user" +MAX_RETENTION = 60 * 60 * 24 * 60 # 60 days + +BASE_NAME_LENGTH_CAP = 96 +WORLD_NAME_LENGTH_CAP = 128 +EDGE_LABEL_LENGTH_CAP = 64 +PERSONA_LENGTH_CAP = 512 +DESCRIPTION_LENGTH_CAP = 512 +NAME_PREFIX_LENGTH = 32 +ID_STRING_LENGTH = 40 +QUEST_MOTIVATION_LENGTH = 128 +REPORT_REASON_LENGTH = 1024 +FILE_PATH_LENGTH_CAP = 96 + + +# Name Key Components - Should be text searchable + + +class DBNameKey(HasDBIDMixin): + """ + Class for the shared db base components, as all have just an + id and a name + """ + + db_id = Column(String(ID_STRING_LENGTH), primary_key=True) + name = Column(String(BASE_NAME_LENGTH_CAP), nullable=False, index=True, unique=True) + status = Column(Enum(DBStatus), nullable=False, index=True) + split = Column(Enum(DBSplitType), nullable=False, index=True) + + +class DBAgentName(DBNameKey, SQLBase): + """ + Class containing the expected elements for an agent name, + with any supporting methods + """ + + __tablename__ = "agent_names" + ID_PREFIX = "AGN" + + def __repr__(self): + return f"DBAgentName({self.db_id!r}| {self.name})" + + +class DBObjectName(DBNameKey, SQLBase): + """ + Class containing the expected elements for an object name, + with any supporting methods + """ + + __tablename__ = "object_names" + ID_PREFIX = "OBN" + + def __repr__(self): + return f"DBObjectName({self.db_id!r}| {self.name})" + + +class DBRoomName(DBNameKey, SQLBase): + """ + Class containing the expected elements for a room name, + with any supporting methods + """ + + __tablename__ = "room_names" + ID_PREFIX = "RMN" + + def __repr__(self): + return f"DBRoomName({self.db_id!r}| {self.name})" + + +# Graph nodes + + +class DBElem(HasDBIDMixin): + """Class for shared attributes for all graph model components""" + + db_id = Column(String(ID_STRING_LENGTH), primary_key=True) + name = Column(String(BASE_NAME_LENGTH_CAP), nullable=False, index=True) + built_occurrences = Column(Integer, nullable=False, default=0) + status = Column(Enum(DBStatus), nullable=False, index=True) + create_timestamp = Column(Float, nullable=False) + creator_id = Column( + String(ID_STRING_LENGTH) + ) # temp retain the creator ID for new things + + _text_edges: Optional[List["DBTextEdge"]] = None + _node_edges: Optional[List["DBEdge"]] = None + _attributes: Optional[List["DBNodeAttribute"]] = None + + @property + def text_edges(self) -> List["DBTextEdge"]: + """Return the cached text edges, if available""" + if self._text_edges is not None: + return self._text_edges + + use_session = Session.object_session(self) + assert ( + use_session is not None + ), "Must be in-session if not cached. Otherwise call `load_edges` first" + stmt = select(DBTextEdge).where(DBTextEdge.parent_id == self.db_id) + text_edges = use_session.scalars(stmt).all() + self._text_edges = text_edges + return text_edges + + @property + def node_edges(self) -> List["DBEdge"]: + """Return the cached node edges, if available""" + if self._node_edges is not None: + return self._node_edges + + use_session = Session.object_session(self) + assert ( + use_session is not None + ), "Must be in-session if not cached. Otherwise call `load_edges` first" + stmt = select(DBEdge).where(DBEdge.parent_id == self.db_id) + node_edges = use_session.scalars(stmt).all() + self._node_edges = node_edges + for node_edge in node_edges: + # Force load the children + assert node_edge.child is not None + return node_edges + + def load_edges(self, db: "EnvDB", skip_cache=False) -> None: + """Expand the node and text edges for this entity""" + if db._cache is not None and not skip_cache: + # Load the edges from the cache + assert self.db_id is not None + node = db._cache["all"][self.db_id] + self._text_edges = node.text_edges.copy() + self._node_edges = node.node_edges.copy() + else: + # Load everything in a session + with Session(db.engine) as session: + session.add(self) + assert self.text_edges is not None + assert self.node_edges is not None + session.expunge_all() + + @property + def attributes(self) -> List["DBNodeAttribute"]: + if self._attributes is not None: + return self._attributes + + use_session = Session.object_session(self) + assert ( + use_session is not None + ), "Must be in-session if not cached. Otherwise call `load_attributes` first" + stmt = select(DBNodeAttribute).where(DBNodeAttribute.target_id == self.db_id) + attributes = use_session.scalars(stmt).all() + self._attributes = attributes + return attributes + + def load_attributes(self, db: "EnvDB", skip_cache=False) -> None: + """Expand arbitrary attributes for this node""" + if db._cache is not None and not skip_cache: + # Load the edges from the cache + assert self.db_id is not None + node = db._cache["all"][self.db_id] + self._attributes = node.attributes.copy() + else: + # Load everything in a session + with Session(db.engine) as session: + session.add(self) + assert self.attributes is not None + session.expunge_all() + + +class DBAgent(DBElem, SQLBase): + """ + Class containing the expected elements for an agent, + with any supporting methods + """ + + __tablename__ = "agents" + __table_args__ = ( + UniqueConstraint( + "name", "persona", "physical_description", name="text_characteristics" + ), + ) + ID_PREFIX = "AGE" + + base_id: str = Column(ForeignKey("agent_names.db_id"), nullable=False) + persona = Column(String(PERSONA_LENGTH_CAP), nullable=False, index=True) + physical_description = Column( + String(DESCRIPTION_LENGTH_CAP), nullable=False, index=True + ) + name_prefix = Column(String(NAME_PREFIX_LENGTH), nullable=False) + is_plural = Column(Boolean) + size = Column(Integer) + contain_size = Column(Integer) + constitution = Column(Float) + charisma = Column(Float) + strength = Column(Float) + dexterity = Column(Float) + intelligence = Column(Float) + wisdom = Column(Float) + base_name: List["DBAgent"] = relationship( + "DBAgentName", backref="agents", foreign_keys=[base_id] + ) + + def __repr__(self): + return f"DBAgent({self.db_id!r}| {self.name})" + + +class DBObject(DBElem, SQLBase): + """ + Class containing the expected elements for an object, + with any supporting methods + """ + + __tablename__ = "objects" + __table_args__ = ( + UniqueConstraint("name", "physical_description", name="text_characteristics"), + ) + ID_PREFIX = "OBE" + + base_id: str = Column(ForeignKey("object_names.db_id"), nullable=False) + physical_description = Column( + String(DESCRIPTION_LENGTH_CAP), nullable=False, index=True + ) + is_container = Column(Float) + is_drink = Column(Float) + is_food = Column(Float) + is_gettable = Column(Float) + is_surface = Column(Float) + is_wearable = Column(Float) + is_weapon = Column(Float) + name_prefix = Column(String(NAME_PREFIX_LENGTH), nullable=False) + is_plural = Column(Boolean) + size = Column(Integer) + contain_size = Column(Integer) + value = Column(Float) + rarity = Column(Float) + base_name: List["DBObject"] = relationship( + "DBObjectName", backref="objects", foreign_keys=[base_id] + ) + + def __repr__(self): + return f"DBObject({self.db_id!r}| {self.name})" + + +class DBRoomInsideType(enum.Enum): + """Types of indoor or outdoor statuses for rooms""" + + INDOORS = "indoors" + ENCLOSED = "enclosed" + COVERED = "covered" + OUTSIDE = "outside" + HYBRID = "hybrid" + MULTI_ROOM = "multi_room" + OTHER = "other" + UNKNOWN = "unknown" + + +class DBRoom(DBElem, SQLBase): + """ + Class containing the expected elements for a room, + with any supporting methods + """ + + __tablename__ = "rooms" + __table_args__ = ( + UniqueConstraint( + "name", "description", "backstory", name="text_characteristics" + ), + ) + ID_PREFIX = "RME" + + base_id: str = Column(ForeignKey("room_names.db_id"), nullable=False) + description = Column(String(DESCRIPTION_LENGTH_CAP), nullable=False, index=True) + backstory = Column(String(DESCRIPTION_LENGTH_CAP), nullable=False, index=True) + size = Column(Integer) + indoor_status = Column(Enum(DBRoomInsideType), nullable=False) + rarity = Column(Float) + base_name: List["DBRoom"] = relationship( + "DBRoomName", backref="rooms", foreign_keys=[base_id] + ) + + def __repr__(self): + return f"DBRoom({self.db_id!r}| {self.name})" + + +class DBNodeAttribute(HasDBIDMixin, SQLBase): + """ + Class containing unique attribute values for specific element instances + """ + + __tablename__ = "node_attributes" + __table_args__ = ( + UniqueConstraint( + "target_id", "attribute_name", "attribute_value_string", name="att_details" + ), + ) + ID_PREFIX = "ATT" + + db_id = Column(String(ID_STRING_LENGTH), primary_key=True) + target_id = Column(String(ID_STRING_LENGTH), nullable=False, index=True) + attribute_name = Column(String(EDGE_LABEL_LENGTH_CAP), nullable=False, index=True) + attribute_value_string = Column(String(EDGE_LABEL_LENGTH_CAP), nullable=False) + status: DBStatus = Column(Enum(DBStatus), nullable=False, index=True) + creator_id = Column( + String(ID_STRING_LENGTH) + ) # temp retain the creator ID for new things + create_timestamp = Column(Float, nullable=False) + + +# Graph edges and attributes + + +class DBEdgeType(enum.Enum): + """Edges in the LIGHT Environment DB""" + + CONTAINS = "contains" + MAY_CONTAIN = "may_contain" + WEARING = "wearing" # only agent-object + MAY_WEAR = "may_wear" # only agent-object + WIELDING = "wielding" # only agent-object + MAY_WIELD = "may_wield" # only agent-object + CONTAINED_IN = "contained_in" # only outward text edge + MAY_BE_CONTAINED_IN = "may_be_contained_in" # only outward text edge + NEIGHBOR = "neighboring" # Only room-room + MAY_BE_NEIGHBOR = "may_be_neighboring" # Only room-room + + +class DBEdgeBase(HasDBIDMixin): + """Base attributes for an edge as stored in the environment DB""" + + db_id = Column(String(ID_STRING_LENGTH), primary_key=True) + parent_id = Column(String(ID_STRING_LENGTH), nullable=False) + edge_type = Column(Enum(DBEdgeType), nullable=False) + status = Column(Enum(DBStatus), nullable=False, index=True) + edge_label = Column(String(EDGE_LABEL_LENGTH_CAP), nullable=False) + create_timestamp = Column(Float, nullable=False) + creator_id = Column( + String(ID_STRING_LENGTH) + ) # temp retain the creator ID for new things + + +class DBEdge(DBEdgeBase, SQLBase): + """Class for edges between two GraphNodes registered in the DB""" + + __tablename__ = "edges" + __table_args__ = ( + UniqueConstraint( + "parent_id", "child_id", "edge_type", "edge_label", name="edge_details" + ), + ) + ID_PREFIX = "NED" + + child_id = Column(String(ID_STRING_LENGTH), nullable=False) + built_occurrences = Column(Integer, nullable=False, default=0) + + _child: Optional[DBElem] = None + + @property + def child(self) -> DBElem: + """Follow this edge and load the child node""" + if self._child is not None: + return self._child + + use_session = Session.object_session(self) + assert ( + use_session is not None + ), "Must be in-session if not cached. Otherwise call `expand_edge` first" + # Determine return type + # This may be better wrapped as a utility of EnvDB, + # but that's not in-scope here + assert self.child_id is not None + TargetClass: Type[DBElem] + if DBAgent.is_id(self.child_id): + TargetClass = DBAgent + stmt = select(DBAgent) + elif DBObject.is_id(self.child_id): + TargetClass = DBObject + stmt = select(DBObject) + elif DBRoom.is_id(self.child_id): + TargetClass = DBRoom + stmt = select(DBRoom) + else: + raise AssertionError("Edge type was none of Agent, room, or object") + stmt = select(TargetClass).where(TargetClass.db_id == self.child_id) + child = use_session.scalars(stmt).one() + self._child = child + return child + + def expand_edge(self, db: "EnvDB") -> None: + """Expand the node and text edges for this entity""" + if db._cache is not None: + # Load the edges from the cache + assert self.child_id is not None + node = db._cache["all"][self.child_id] + self._child = node + else: + # Load everything in a session + with Session(db.engine) as session: + session.add(self) + assert self.child is not None + session.expunge_all() + + def __repr__(self): + return ( + f"DBEdge({self.db_id!r}| {self.parent_id}-{self.edge_type}-{self.child_id})" + ) + + +class DBTextEdge(DBEdgeBase, SQLBase): + """Class for edges between a GraphNodes and a new entity in the DB""" + + __tablename__ = "text_edges" + __table_args__ = ( + UniqueConstraint( + "parent_id", "child_text", "edge_type", "edge_label", name="edge_details" + ), + ) + ID_PREFIX = "TED" + + child_text = Column(String(BASE_NAME_LENGTH_CAP), nullable=False, index=True) + + def __repr__(self): + return f"DBTextEdge({self.db_id!r}| {self.parent_id}-{self.edge_type}-{self.child_text})" + + +# Other + + +class DBEdit(SQLBase, HasDBIDMixin): + """Suggested change to some DBElem content""" + + __tablename__ = "edits" + ID_PREFIX = "EDT" + + db_id = Column(String(ID_STRING_LENGTH), primary_key=True) + editor_id = Column(String(ID_STRING_LENGTH)) # temp retain the associated user ID + node_id = Column( + String(ID_STRING_LENGTH), nullable=False, index=True + ) # Id of entry in table + field = Column(String(ID_STRING_LENGTH), nullable=False) # name of field in table + status = Column(Enum(DBStatus), nullable=False, index=True) + old_value = Column(String(DESCRIPTION_LENGTH_CAP), nullable=False, index=True) + new_value = Column(String(DESCRIPTION_LENGTH_CAP), nullable=False, index=True) + create_timestamp = Column(Float, nullable=False) + + def accept_and_apply(self, db: "EnvDB") -> None: + """Accept and apply the given edit""" + # TODO Implement + raise NotImplementedError + + def reject_edit(self, db: "EnvDB") -> None: + """Reject the given edit""" + with Session(db.engine) as session: + session.add(self) + self.status = DBStatus.REJECTED + session.flush() + session.commit() + session.expunge_all() + + def __repr__(self): + return f"DBEdit({self.db_id!r}| {self.node_id}-{self.field}-{self.status})" + + +class DBFlagTargetType(enum.Enum): + """Types of flags""" + + FLAG_USER = "user_flag" # Something wrong about a user's behavior + FLAG_UTTERANCE = "utterance_flag" # Something specifically wrong about + FLAG_ENVIRONMENT = "env_flag" # Flag something inappropriate in the environment + + +class DBFlag(HasDBIDMixin, SQLBase): + """User-flagged content of some type""" + + __tablename__ = "flags" + ID_PREFIX = "FLG" + + db_id = Column(String(ID_STRING_LENGTH), primary_key=True) + flag_type = Column(Enum(DBFlagTargetType), nullable=False) + user_id = Column(String(ID_STRING_LENGTH), nullable=False, index=True) + target_id = Column(String(ID_STRING_LENGTH), nullable=False, index=True) + reason = Column(String(REPORT_REASON_LENGTH)) + status = Column(Enum(DBStatus), nullable=False, index=True) + create_timestamp = Column(Float, nullable=False) + + def __repr__(self): + return f"DBFlag({self.db_id!r}| {self.target_id}-{self.flag_type})" + + +class DBQuestTargetType(enum.Enum): + """Types of quest targets""" + + TEXT_ONLY = "text_only" # only a map from character to motivation + TARGET_ACTION = "target_action" # map from motivation to target action + + +class DBQuest(SQLBase, HasDBIDMixin): + """Stores quest information for breaking down motivations""" + + __tablename__ = "quests" + ID_PREFIX = "QST" + + db_id = Column(String(ID_STRING_LENGTH), primary_key=True) + agent_id: str = Column(ForeignKey("agents.db_id"), nullable=False) + parent_id: Optional[str] = Column( + ForeignKey("quests.db_id") + ) # Map to possible parent + text_motivation = Column(String(QUEST_MOTIVATION_LENGTH), nullable=False) + target_type = Column(Enum(DBQuestTargetType), nullable=False) + target = Column(String(QUEST_MOTIVATION_LENGTH)) + status = Column(Enum(DBStatus), nullable=False, index=True) + origin_filepath = Column(String(FILE_PATH_LENGTH_CAP)) + position = Column(Integer) # If subgoal of a parent, which substep? + creator_id = Column( + String(ID_STRING_LENGTH) + ) # temp retain the creator ID for new things + create_timestamp = Column(Float, nullable=False) + + _subgoals: Optional[List["DBQuest"]] = None + _parent_chain: Optional[List["DBQuest"]] = None + + @property + def subgoals(self) -> List["DBQuest"]: + """ + Return the list of DBQuests that are a direct + subgoal of this one + """ + if self._subgoals is not None: + return self._subgoals + + use_session = Session.object_session(self) + assert ( + use_session is not None + ), "Must be in-session if not cached. Otherwise call `load_relations` first" + + subgoals = ( + use_session.query(DBQuest).where(DBQuest.parent_id == self.db_id).all() + ) + subgoals = sorted(subgoals, key=lambda x: x.position) + self._subgoals = subgoals + return subgoals + + @property + def parent_chain(self) -> List["DBQuest"]: + """ + Return the chain of quests/motivations above this level, + starting from the highest down to this one + """ + if self._parent_chain is not None: + return self._parent_chain + + use_session = Session.object_session(self) + assert ( + use_session is not None + ), "Must be in-session if not cached. Otherwise call `load_relations` first" + + parent_chain = [self] + curr_item = self + while curr_item.parent_id is not None: + parent_item = use_session.query(DBQuest).get(curr_item.parent_id) + assert parent_item is not None + parent_chain.append(parent_item) + curr_item = parent_item + parent_chain = list(reversed(parent_chain)) + + self._parent_chain = parent_chain + return parent_chain + + def load_relations(self, db: "EnvDB") -> None: + """Expand the parent chain and subgoals for this item""" + # Load everything in a session + with Session(db.engine) as session: + session.add(self) + assert self.parent_chain is not None + # Recurse through subgoals to load entire chain + subgoals_to_check = self.subgoals.copy() + while len(subgoals_to_check) > 0: + next_goal = subgoals_to_check.pop() + session.add(next_goal) + subgoals_to_check += next_goal.subgoals.copy() + session.expunge_all() + + def __repr__(self): + return f"DBQuest({self.db_id!r}| {self.agent_id}-{self.target_type})" + + +class DBGraph(SQLBase, HasDBIDMixin): + """Manifest entry for a user-saved or created graph""" + + __tablename__ = "saved_graphs" + ID_PREFIX = "UGR" + + db_id = Column(String(ID_STRING_LENGTH), primary_key=True) + graph_name = Column(String(WORLD_NAME_LENGTH_CAP), nullable=False, index=True) + creator_id = Column( + String(ID_STRING_LENGTH), nullable=False, index=True + ) # retain the creator ID, they own this + file_path = Column(String(FILE_PATH_LENGTH_CAP), nullable=False) + status = Column(Enum(DBStatus), nullable=False, index=True) + create_timestamp = Column(Float, nullable=False) + + def get_graph(self, db: "EnvDB") -> OOGraph: + """Get an OOGraph for this DBGraph, loading from file""" + assert self.file_path is not None + graph_json = db.read_data_from_file(self.file_path, json_encoded=False) + assert isinstance(graph_json, str) + graph = OOGraph.from_json(graph_json) + graph.db_id = self.db_id + return graph + + def __repr__(self): + return f"DBGraph({self.db_id!r}| {self.graph_name})" + + +class EnvDB(BaseDB): + """ + Environment database for LIGHT, containing accessors for the + LIGHT environment, nodes, attributes, additional annotations, + quests, world-interactions, and more. + """ + + DB_TYPE = "environment" + + def _complete_init(self, config: "DictConfig"): + """ + Initialize any specific environment-related paths + """ + SQLBase.metadata.create_all(self.engine) + self._cache: Optional[Dict[str, Dict[str, Any]]] = None + + def _validate_init(self): + """ + Ensure that the environment manifest exists, and that key + paths are properly initialized + """ + # TODO Check the table for any possible consistency issues + # and ensure that all listed files actually exist + + def create_node_cache(self) -> None: + """ + Create a local cached version of the environment nodes and + relationships, to use for rapid construction of things + without needing repeated queries + """ + all_rooms: List[Any] = self.find_rooms() + all_agents: List[Any] = self.find_agents() + all_objects: List[Any] = self.find_objects() + all_nodes: List[Any] = all_rooms + all_agents + all_objects + all_node_edges: List[Any] = self.get_edges() + all_text_edges: List[Any] = self.get_text_edges() + all_entities: List[Any] = all_nodes + all_node_edges + all_text_edges + self._cache = { + "rooms": {r.db_id: r for r in all_rooms}, + "nodes": {a.db_id: a for a in all_agents}, + "objects": {o.db_id: o for o in all_objects}, + "node_edges": {ne.db_id: ne for ne in all_node_edges}, + "text_edges": {te.db_id: te for te in all_text_edges}, + "all": {a.db_id: a for a in all_entities}, + } + for node in all_nodes: + # Load the edges skipping the cache, then resolving + # children from the cache + node.load_edges(self, skip_cache=True) + node._attributes = [] + + # manually link the attributes in a single pass + all_attributes = self.get_attributes() + for attribute in all_attributes: + assert attribute.target_id is not None + self._cache["all"][attribute.target_id]._attributes.append(attribute) + + def _create_name_key( + self, + KeyClass: Type[DBNameKey], + name: str, + ) -> str: + """Idempotently create a name key for the given class""" + try: + db_id = self._get_name_key(KeyClass, name=name).db_id + assert db_id is not None + return db_id + except KeyError: + with Session(self.engine) as session: + db_id = KeyClass.get_id() + name_key = KeyClass( # type: ignore + db_id=db_id, + name=name, + status=DBStatus.REVIEW, + split=DBSplitType.UNSET, + ) + session.add(name_key) + session.flush() + session.commit() + return db_id + + def _get_name_key( + self, + KeyClass: Type[DBNameKey], + name: Optional[str] = None, + db_id: Optional[str] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + ) -> DBNameKey: + """Get a specific name key, assert that it exists""" + assert ( + name is not None or db_id is not None + ), "Must provide one of name or db_id" + stmt = select(KeyClass) + if name is not None: + stmt = stmt.where(KeyClass.name == name) + if db_id is not None: + assert KeyClass.is_id(db_id), "Provided ID is not for this key type" + stmt = stmt.where(KeyClass.db_id == db_id) + if status is not None: + stmt = stmt.where(KeyClass.status == status) + if split is not None: + stmt = stmt.where(KeyClass.split == split) + with Session(self.engine) as session: + db_name_key = self._enforce_get_first( + session, stmt, "Matching key didn't exist." + ) + session.expunge_all() + return db_name_key + + def _find_name_keys( + self, + KeyClass: Type[DBNameKey], + name: Optional[str] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + ) -> List[DBNameKey]: + """Find all matching name keys""" + with Session(self.engine) as session: + if name is None and status is None and split is None: + # Empty query + name_keys = session.query(KeyClass).all() + session.expunge_all() + return name_keys + stmt = select(KeyClass) + if name is not None: + stmt = stmt.where(KeyClass.name.like(f"%{name}%")) + if status is not None: + stmt = stmt.where(KeyClass.status == status) + if split is not None: + stmt = stmt.where(KeyClass.split == split) + + name_keys = session.scalars(stmt).all() + session.expunge_all() + return name_keys + + def _resolve_id_to_db_elem( + self, + db_id: str, + ) -> DBElem: + """Query for the correct DBElem given the provided db_id""" + TargetClass: Type[DBElem] + if DBAgent.is_id(db_id): + TargetClass = DBAgent + stmt = select(DBAgent) + elif DBObject.is_id(db_id): + TargetClass = DBObject + stmt = select(DBObject) + elif DBRoom.is_id(db_id): + TargetClass = DBRoom + stmt = select(DBRoom) + else: + raise AssertionError("Edge type was none of Agent, room, or object") + return self._get_elem_for_class(TargetClass, db_id) + + def _get_elem_for_class( + self, + ElemClass: Type[DBElem], + db_id: str, + ) -> DBElem: + """Get a specific element of the given class by ID, asserting that it exists""" + assert ElemClass.is_id(db_id), f"Given id {db_id} not for {ElemClass}" + if self._cache is not None: + return self._cache["all"][db_id] + + stmt = select(ElemClass).where(ElemClass.db_id == db_id) + with Session(self.engine) as session: + db_elem = self._enforce_get_first( + session, stmt, f"No {ElemClass} by given key {db_id}" + ) + session.expunge_all() + return db_elem + + # Agents + + def create_agent_name(self, name: str) -> str: + """Create a new agent name in the database""" + return self._create_name_key(DBAgentName, name) + + def find_agent_names( + self, + name: Optional[str] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + ) -> List[DBAgentName]: + """Find all matching agent name keys""" + return [ + cast(DBAgentName, a_name) + for a_name in self._find_name_keys( + KeyClass=DBAgentName, + name=name, + status=status, + split=split, + ) + ] + + def get_agent_name( + self, + name: Optional[str] = None, + db_id: Optional[str] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + ) -> DBAgentName: + """Get a specific agent name, assert that it exists""" + return cast( + DBAgentName, + self._get_name_key( + KeyClass=DBAgentName, + name=name, + db_id=db_id, + status=status, + split=split, + ), + ) + + def create_agent_entry( + self, + name: str, + base_name: str, + persona: str, + physical_description: str, + name_prefix: Optional[str] = None, + is_plural: Optional[bool] = None, + size: Optional[int] = None, + contain_size: Optional[int] = None, + constitution: Optional[int] = None, + charisma: Optional[int] = None, + strength: Optional[int] = None, + dexterity: Optional[int] = None, + intelligence: Optional[int] = None, + wisdom: Optional[int] = None, + status: DBStatus = DBStatus.REVIEW, + creator_id: Optional[str] = None, + ) -> str: + """Create this agent, making an agent name first if required""" + if name_prefix is None: + name_prefix = "an" if name[0] in "aeiou" else "a" + base_id = self.create_agent_name(base_name) + with Session(self.engine) as session: + db_id = DBAgent.get_id() + agent = DBAgent( + db_id=db_id, + base_id=base_id, + status=status, + creator_id=creator_id, + create_timestamp=time.time(), + name=name, + persona=persona, + physical_description=physical_description, + name_prefix=name_prefix, + is_plural=is_plural, + size=size, + contain_size=contain_size, + constitution=constitution, + charisma=charisma, + strength=strength, + dexterity=dexterity, + intelligence=intelligence, + wisdom=wisdom, + ) + session.add(agent) + session.flush() + session.commit() + return db_id + + def find_agents( + self, + base_id: Optional[str] = None, + name: Optional[str] = None, + persona: Optional[str] = None, + physical_description: Optional[str] = None, + name_prefix: Optional[str] = None, + is_plural: Optional[bool] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + creator_id: Optional[str] = None, + ) -> List[DBAgent]: + """Return all agents matching the given parameters""" + # Empty query first + query_args = locals().copy() + filtered_args = list(filter(lambda x: x is not None, query_args.values())) + if len(filtered_args) == 1: + # Only self argument + with Session(self.engine) as session: + agents = session.query(DBAgent).all() + session.expunge_all() + return agents + + # Construct query + stmt = select(DBAgent) + if base_id is not None: + stmt = stmt.where(DBAgent.base_id == base_id) + if name is not None: + stmt = stmt.where(DBAgent.name.like(f"%{name}%")) + if persona is not None: + stmt = stmt.where(DBAgent.persona.like(f"%{persona}%")) + if physical_description is not None: + stmt = stmt.where( + DBAgent.physical_description.like(f"%{physical_description}%") + ) + if name_prefix is not None: + stmt = stmt.where(DBAgent.name_prefix == name_prefix) + if is_plural is not None: + stmt = stmt.where(DBAgent.is_plural == is_plural) + if status is not None: + stmt = stmt.where(DBAgent.status == status) + if split is not None: + # Need to join up to parent for split query + stmt = stmt.where(DBAgent.base_name.has(split=split)) + if creator_id is not None: + stmt = stmt.where(DBAgent.creator_id == creator_id) + + # Do query + with Session(self.engine) as session: + agents = session.scalars(stmt).all() + session.expunge_all() + return agents + + def get_agent(self, db_id: str) -> DBAgent: + """Return the given agent, raise an exception if non-existing""" + return cast(DBAgent, self._get_elem_for_class(DBAgent, db_id)) + + # Objects + + def create_object_name(self, name: str) -> str: + """Create a new object name in the database""" + return self._create_name_key(DBObjectName, name) + + def get_object_name( + self, + name: Optional[str] = None, + db_id: Optional[str] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + ) -> DBObjectName: + """Get a specific object name, assert that it exists""" + return cast( + DBObjectName, + self._get_name_key( + KeyClass=DBObjectName, + name=name, + db_id=db_id, + status=status, + split=split, + ), + ) + + def find_object_names( + self, + name: Optional[str] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + ) -> List[DBObjectName]: + """Find all matching agent name keys""" + return [ + cast(DBObjectName, o_name) + for o_name in self._find_name_keys( + KeyClass=DBObjectName, + name=name, + status=status, + split=split, + ) + ] + + def create_object_entry( + self, + name: str, + base_name: str, + physical_description: str, + is_container: float, + is_drink: float, + is_food: float, + is_gettable: float, + is_surface: float, + is_wearable: float, + is_weapon: float, + name_prefix: Optional[str] = None, + is_plural: Optional[bool] = None, + size: Optional[int] = None, + contain_size: Optional[int] = None, + value: Optional[float] = None, + rarity: Optional[float] = None, + status: DBStatus = DBStatus.REVIEW, + creator_id: Optional[str] = None, + ) -> str: + """Create a new object, making a object_name first if required""" + if name_prefix is None: + name_prefix = "an" if name[0] in "aeiou" else "a" + base_id = self.create_object_name(base_name) + with Session(self.engine) as session: + db_id = DBObject.get_id() + agent = DBObject( + db_id=db_id, + base_id=base_id, + status=status, + creator_id=creator_id, + create_timestamp=time.time(), + name=name, + physical_description=physical_description, + is_container=is_container, + is_drink=is_drink, + is_food=is_food, + is_gettable=is_gettable, + is_surface=is_surface, + is_wearable=is_wearable, + is_weapon=is_weapon, + name_prefix=name_prefix, + is_plural=is_plural, + size=size, + contain_size=contain_size, + value=value, + rarity=rarity, + ) + session.add(agent) + session.flush() + session.commit() + return db_id + + def find_objects( + self, + base_id: Optional[str] = None, + name: Optional[str] = None, + physical_description: Optional[str] = None, + is_container: Optional[bool] = None, + is_drink: Optional[bool] = None, + is_food: Optional[bool] = None, + is_gettable: Optional[bool] = None, + is_surface: Optional[bool] = None, + is_wearable: Optional[bool] = None, + is_weapon: Optional[bool] = None, + name_prefix: Optional[str] = None, + is_plural: Optional[bool] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + creator_id: Optional[str] = None, + ) -> List["DBObject"]: + """Return all objects matching the given parameters""" + # Empty query first + query_args = locals().copy() + filtered_args = list(filter(lambda x: x is not None, query_args.values())) + if len(filtered_args) == 1: + # Only self argument + with Session(self.engine) as session: + objects = session.query(DBObject).all() + session.expunge_all() + return objects + + FLOAT_TRUE_THRESHOLD = 0.5 + # Construct query + stmt = select(DBObject) + if base_id is not None: + stmt = stmt.where(DBObject.base_id.like(f"%{base_id}%")) + if name is not None: + stmt = stmt.where(DBObject.name.like(f"%{name}%")) + if physical_description is not None: + stmt = stmt.where( + DBObject.physical_description.like(f"%{physical_description}%") + ) + if is_container is not None: + if is_container: + stmt = stmt.where(DBObject.is_container >= FLOAT_TRUE_THRESHOLD) + else: + stmt = stmt.where(DBObject.is_container < FLOAT_TRUE_THRESHOLD) + if is_drink is not None: + if is_drink: + stmt = stmt.where(DBObject.is_drink >= FLOAT_TRUE_THRESHOLD) + else: + stmt = stmt.where(DBObject.is_drink < FLOAT_TRUE_THRESHOLD) + if is_food is not None: + if is_food: + stmt = stmt.where(DBObject.is_food >= FLOAT_TRUE_THRESHOLD) + else: + stmt = stmt.where(DBObject.is_food < FLOAT_TRUE_THRESHOLD) + if is_gettable is not None: + if is_gettable: + stmt = stmt.where(DBObject.is_gettable >= FLOAT_TRUE_THRESHOLD) + else: + stmt = stmt.where(DBObject.is_gettable < FLOAT_TRUE_THRESHOLD) + if is_surface is not None: + if is_surface: + stmt = stmt.where(DBObject.is_surface >= FLOAT_TRUE_THRESHOLD) + else: + stmt = stmt.where(DBObject.is_surface < FLOAT_TRUE_THRESHOLD) + if is_wearable is not None: + if is_wearable: + stmt = stmt.where(DBObject.is_wearable >= FLOAT_TRUE_THRESHOLD) + else: + stmt = stmt.where(DBObject.is_wearable < FLOAT_TRUE_THRESHOLD) + if is_weapon is not None: + if is_weapon: + stmt = stmt.where(DBObject.is_weapon >= FLOAT_TRUE_THRESHOLD) + else: + stmt = stmt.where(DBObject.is_weapon < FLOAT_TRUE_THRESHOLD) + if name_prefix is not None: + stmt = stmt.where(DBObject.name_prefix == name_prefix) + if is_plural is not None: + stmt = stmt.where(DBObject.is_plural == is_plural) + if status is not None: + stmt = stmt.where(DBObject.status == status) + if split is not None: + # Need to join up to parent for split query + stmt = stmt.where(DBObject.base_name.has(split=split)) + if creator_id is not None: + stmt = stmt.where(DBObject.creator_id == creator_id) + # Do query + with Session(self.engine) as session: + objects = session.scalars(stmt).all() + session.expunge_all() + return objects + + def get_object(self, db_id: str) -> DBObject: + """Return the given object, raise exception if non-existing""" + return cast(DBObject, self._get_elem_for_class(DBObject, db_id)) + + # Rooms + + def create_room_name(self, name: str) -> str: + """Create a new room name in the database""" + return self._create_name_key(DBRoomName, name) + + def get_room_name( + self, + name: Optional[str] = None, + db_id: Optional[str] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + ) -> DBRoomName: + """Get a specific room name, assert that it exists""" + return cast( + DBRoomName, + self._get_name_key( + KeyClass=DBRoomName, + name=name, + db_id=db_id, + status=status, + split=split, + ), + ) + + def find_room_names( + self, + name: Optional[str] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + ) -> List[DBRoomName]: + """Find all matching agent name keys""" + return [ + cast(DBRoomName, r_name) + for r_name in self._find_name_keys( + KeyClass=DBRoomName, + name=name, + status=status, + split=split, + ) + ] + + def create_room_entry( + self, + name: str, + base_name: str, + description: str, + backstory: str, + indoor_status: DBRoomInsideType = DBRoomInsideType.UNKNOWN, + size: Optional[int] = None, + rarity: Optional[float] = None, + status: DBStatus = DBStatus.REVIEW, + creator_id: Optional[str] = None, + ) -> str: + """Create a new room, making a room name first if required""" + base_id = self.create_room_name(base_name) + with Session(self.engine) as session: + db_id = DBRoom.get_id() + room = DBRoom( + db_id=db_id, + base_id=base_id, + status=status, + creator_id=creator_id, + create_timestamp=time.time(), + name=name, + description=description, + backstory=backstory, + size=size, + indoor_status=indoor_status, + rarity=rarity, + ) + session.add(room) + session.flush() + session.commit() + return db_id + + def find_rooms( + self, + base_id: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = None, + backstory: Optional[str] = None, + indoor_status: Optional[str] = None, + status: Optional[DBStatus] = None, + split: Optional[DBSplitType] = None, + creator_id: Optional[str] = None, + ) -> List["DBRoom"]: + """Return all rooms matching the given parameters""" + # Empty query first + query_args = locals().copy() + filtered_args = list(filter(lambda x: x is not None, query_args.values())) + if len(filtered_args) == 1: + # Only self argument + with Session(self.engine) as session: + rooms = session.query(DBRoom).all() + session.expunge_all() + return rooms + + # Construct query + stmt = select(DBRoom) + if base_id is not None: + stmt = stmt.where(DBRoom.base_id == base_id) + if name is not None: + stmt = stmt.where(DBRoom.name.like(f"%{name}%")) + if description is not None: + stmt = stmt.where(DBRoom.description.like(f"%{description}%")) + if backstory is not None: + stmt = stmt.where(DBRoom.backstory.like(f"%{backstory}%")) + if indoor_status is not None: + stmt = stmt.where(DBRoom.indoor_status == indoor_status) + if status is not None: + stmt = stmt.where(DBRoom.status == status) + if split is not None: + # Need to join up to parent for split query + stmt = stmt.where(DBRoom.base_name.has(split=split)) + if creator_id is not None: + stmt = stmt.where(DBRoom.creator_id == creator_id) + # Do query + with Session(self.engine) as session: + rooms = session.scalars(stmt).all() + session.expunge_all() + return rooms + + def get_room(self, db_id: str) -> DBRoom: + """Get a specific room, assert that it exists""" + return cast(DBRoom, self._get_elem_for_class(DBRoom, db_id)) + + # Attributes + + def create_arbitrary_attribute( + self, + target_id: str, + attribute_name: str, + attribute_value_string: str, + status: DBStatus = DBStatus.REVIEW, + creator_id: Optional[str] = None, + ) -> str: + """Create an arbitrary attribute entry for the target node""" + try: + with Session(self.engine) as session: + db_id = DBNodeAttribute.get_id() + attribute = DBNodeAttribute( + db_id=db_id, + target_id=target_id, + attribute_name=attribute_name, + attribute_value_string=attribute_value_string, + status=status, + creator_id=creator_id, + create_timestamp=time.time(), + ) + session.add(attribute) + session.flush() + session.commit() + return db_id + except sqlalchemy.exc.IntegrityError: + # Duplicate, grab the existing + attributes = self.get_attributes( + target_id=target_id, + attribute_name=attribute_name, + attribute_value_string=attribute_value_string, + ) + assert len(attributes) == 1 + assert attributes[0].db_id is not None + db_id = attributes[0].db_id + return db_id + + def get_attributes( + self, + target_id: Optional[str] = None, + attribute_name: Optional[str] = None, + attribute_value_string: Optional[str] = None, + status: Optional[DBStatus] = None, + creator_id: Optional[str] = None, + ) -> List[DBNodeAttribute]: + """Return the list of all attributes stored that match the given filters""" + # Empty query first + query_args = locals().copy() + filtered_args = list(filter(lambda x: x is not None, query_args.values())) + if len(filtered_args) == 1: + # Only self argument + with Session(self.engine) as session: + attributes = session.query(DBNodeAttribute).all() + session.expunge_all() + return attributes + + # Construct query + stmt = select(DBNodeAttribute) + if target_id is not None: + stmt = stmt.where(DBNodeAttribute.target_id == target_id) + if attribute_name is not None: + stmt = stmt.where(DBNodeAttribute.attribute_name == attribute_name) + if attribute_value_string is not None: + stmt = stmt.where( + DBNodeAttribute.attribute_value_string == attribute_value_string + ) + if status is not None: + stmt = stmt.where(DBNodeAttribute.status == status) + if creator_id is not None: + stmt = stmt.where(DBNodeAttribute.creator_id == creator_id) + # Do query + with Session(self.engine) as session: + attributes = session.scalars(stmt).all() + session.expunge_all() + return attributes + + # Edges + + def create_edge( + self, + parent_id: str, + child_id: str, + edge_type: DBEdgeType, + edge_label: str = "", + status: DBStatus = DBStatus.REVIEW, + creator_id: Optional[str] = None, + ) -> str: + """Create an edge between two nodes, idempotent""" + try: + with Session(self.engine) as session: + db_id = DBEdge.get_id() + edge = DBEdge( + db_id=db_id, + parent_id=parent_id, + edge_type=edge_type, + edge_label=edge_label, + status=status, + creator_id=creator_id, + create_timestamp=time.time(), + child_id=child_id, + ) + session.add(edge) + session.flush() + session.commit() + return db_id + except sqlalchemy.exc.IntegrityError: + # Duplicate, grab the existing + edges = self.get_edges( + parent_id=parent_id, + child_id=child_id, + edge_type=edge_type, + edge_label=edge_label, + ) + assert len(edges) == 1 + assert edges[0].db_id is not None + db_id = edges[0].db_id + return db_id + + def get_edges( + self, + parent_id: Optional[str] = None, + child_id: Optional[str] = None, + edge_type: Optional[DBEdgeType] = None, + edge_label: Optional[str] = None, + status: Optional[DBStatus] = None, + creator_id: Optional[str] = None, + min_strength: Optional[float] = None, + ) -> List[DBEdge]: + """Return all edges matching the given parameters""" + # Empty query first + query_args = locals().copy() + filtered_args = list(filter(lambda x: x is not None, query_args.values())) + if len(filtered_args) == 1: + # Only self argument + with Session(self.engine) as session: + edges = session.query(DBEdge).all() + session.expunge_all() + return edges + + # Construct query + stmt = select(DBEdge) + if parent_id is not None: + stmt = stmt.where(DBEdge.parent_id == parent_id) + if child_id is not None: + stmt = stmt.where(DBEdge.child_id == child_id) + if edge_type is not None: + stmt = stmt.where(DBEdge.edge_type == edge_type) + if edge_label is not None: + stmt = stmt.where(DBEdge.edge_label == edge_label) + if status is not None: + stmt = stmt.where(DBEdge.status == status) + if creator_id is not None: + stmt = stmt.where(DBEdge.creator_id == creator_id) + # Do query + with Session(self.engine) as session: + edges = session.scalars(stmt).all() + if min_strength is not None: + # Need to post-filter out things below the min strength, where + # strength is defined as the proportion of edge occurrences to + # parent occurrences + filtered_edges = [] + for edge in edges: + edge_occurrences = edge.built_occurrences + if edge_occurrences == 0: + continue # No occurrences of edge + db_elem = self._resolve_id_to_db_elem(edge.parent_id) + elem_occurrences = db_elem.built_occurrences + if elem_occurrences == 0: + continue # No occurrences of elem + if edge_occurrences / elem_occurrences >= min_strength: + filtered_edges.append(edge) + edges = filtered_edges + session.expunge_all() + return edges + + def create_text_edge( + self, + parent_id: str, + child_text: str, + edge_type: DBEdgeType, + edge_label: str = "", + status: DBStatus = DBStatus.REVIEW, + creator_id: Optional[str] = None, + ) -> str: + """Create an edge between a node and the name of a possible leaf""" + try: + with Session(self.engine) as session: + db_id = DBTextEdge.get_id() + edge = DBTextEdge( + db_id=db_id, + parent_id=parent_id, + edge_type=edge_type, + edge_label=edge_label, + status=status, + creator_id=creator_id, + create_timestamp=time.time(), + child_text=child_text, + ) + session.add(edge) + session.flush() + session.commit() + return db_id + except sqlalchemy.exc.IntegrityError: + # Duplicate, grab the existing + edges = self.get_text_edges( + parent_id=parent_id, + child_text=child_text, + edge_type=edge_type, + edge_label=edge_label, + ) + assert len(edges) == 1 + assert edges[0].db_id is not None + db_id = edges[0].db_id + return db_id + + def get_text_edges( + self, + parent_id: Optional[str] = None, + child_text: Optional[str] = None, + edge_type: Optional[DBEdgeType] = None, + edge_label: Optional[str] = None, + status: Optional[DBStatus] = None, + creator_id: Optional[str] = None, + ) -> List[DBTextEdge]: + """Return all text edges matching the given parameters""" + # Empty query first + query_args = locals().copy() + filtered_args = list(filter(lambda x: x is not None, query_args.values())) + if len(filtered_args) == 1: + # Only self argument + with Session(self.engine) as session: + text_edges = session.query(DBTextEdge).all() + session.expunge_all() + return text_edges + + # Construct query + stmt = select(DBTextEdge) + if parent_id is not None: + stmt = stmt.where(DBTextEdge.parent_id == parent_id) + if child_text is not None: + stmt = stmt.where(DBTextEdge.child_text == child_text) + if edge_type is not None: + stmt = stmt.where(DBTextEdge.edge_type == edge_type) + if edge_label is not None: + stmt = stmt.where(DBTextEdge.edge_label == edge_label) + if status is not None: + stmt = stmt.where(DBTextEdge.status == status) + if creator_id is not None: + stmt = stmt.where(DBTextEdge.creator_id == creator_id) + # Do query + with Session(self.engine) as session: + edges = session.scalars(stmt).all() + session.expunge_all() + return edges + + # Flags and edits + + def create_edit( + self, + editor_id: str, + node_id: str, + field: str, + old_value: str, + new_value: str, + status: Optional[DBStatus] = DBStatus.REVIEW, + ) -> str: + """Write a potential edit to db. Return the edit db_id""" + with Session(self.engine) as session: + db_id = DBEdit.get_id() + edit = DBEdit( + db_id=db_id, + editor_id=editor_id, + node_id=node_id, + field=field, + old_value=old_value, + new_value=new_value, + status=status, + create_timestamp=time.time(), + ) + session.add(edit) + session.flush() + session.commit() + return db_id + + def get_edits( + self, + editor_id: Optional[str] = None, + node_id: Optional[str] = None, + field: Optional[str] = None, + old_value: Optional[str] = None, + new_value: Optional[str] = None, + status: Optional[DBStatus] = None, + ) -> List[DBEdit]: + """Return all edits matching the given parameters""" + # Empty query first + query_args = locals().copy() + filtered_args = list(filter(lambda x: x is not None, query_args.values())) + if len(filtered_args) == 1: + # Only self argument + with Session(self.engine) as session: + edits = session.query(DBEdit).all() + session.expunge_all() + return edits + + # Construct query + stmt = select(DBEdit) + if editor_id is not None: + stmt = stmt.where(DBEdit.editor_id == editor_id) + if node_id is not None: + stmt = stmt.where(DBEdit.node_id == node_id) + if field is not None: + stmt = stmt.where(DBEdit.field == field) + if old_value is not None: + stmt = stmt.where(DBEdit.old_value == old_value) + if new_value is not None: + stmt = stmt.where(DBEdit.new_value == new_value) + if status is not None: + stmt = stmt.where(DBEdit.status == status) + # Do query + with Session(self.engine) as session: + edits = session.scalars(stmt).all() + session.expunge_all() + return edits + + def flag_entry( + self, + user_id: str, + flag_type: DBFlagTargetType, + target_id: str, + reason: str, + status: Optional[DBStatus] = DBStatus.REVIEW, + ) -> str: + """ + Write a potential flag to db, return the flag id + """ + with Session(self.engine) as session: + db_id = DBFlag.get_id() + flag = DBFlag( + db_id=db_id, + user_id=user_id, + flag_type=flag_type, + target_id=target_id, + reason=reason, + status=status, + create_timestamp=time.time(), + ) + session.add(flag) + session.flush() + session.commit() + # TODO enough flags could perhaps move node to review status + return db_id + + def get_flags( + self, + user_id: Optional[str] = None, + flag_type: Optional[DBFlagTargetType] = None, + target_id: Optional[str] = None, + reason: Optional[str] = None, + status: Optional[DBStatus] = None, + ) -> List[DBFlag]: + """Return all flags matching the given parameters""" + # Empty query first + query_args = locals().copy() + filtered_args = list(filter(lambda x: x is not None, query_args.values())) + if len(filtered_args) == 1: + # Only self argument + with Session(self.engine) as session: + flags = session.query(DBFlag).all() + session.expunge_all() + return flags + + # Construct query + stmt = select(DBFlag) + if user_id is not None: + stmt = stmt.where(DBFlag.user_id == user_id) + if flag_type is not None: + stmt = stmt.where(DBFlag.flag_type == flag_type) + if target_id is not None: + stmt = stmt.where(DBFlag.target_id == target_id) + if reason is not None: + stmt = stmt.where(DBFlag.reason == reason) + if status is not None: + stmt = stmt.where(DBFlag.status == status) + # Do query + with Session(self.engine) as session: + flags = session.scalars(stmt).all() + session.expunge_all() + return flags + + # Quests + + def create_quest( + self, + agent_id: str, + text_motivation: str, + target_type: DBQuestTargetType, + target: str, + position: int = 0, + origin_filepath: Optional[str] = None, + parent_id: Optional[str] = None, + status: DBStatus = DBStatus.REVIEW, + creator_id: Optional[str] = None, + ) -> str: + """ + Creates a Quest, which is a mapping from character and motivation + text to a desired action or list of subquests + """ + with Session(self.engine) as session: + db_id = DBQuest.get_id() + quest = DBQuest( + db_id=db_id, + agent_id=agent_id, + parent_id=parent_id, + position=position, + text_motivation=text_motivation, + target_type=target_type, + target=target, + origin_filepath=origin_filepath, + status=status, + creator_id=creator_id, + create_timestamp=time.time(), + ) + session.add(quest) + session.flush() + session.commit() + return db_id + + def find_quests( + self, + agent_id: Optional[str] = None, + parent_id: Optional[str] = None, + text_motivation: Optional[str] = None, + target_type: Optional[DBQuestTargetType] = None, + target: Optional[str] = None, + status: Optional[DBStatus] = None, + creator_id: Optional[str] = None, + origin_filepath: Optional[str] = None, + ) -> List[DBQuest]: + """Return all text edges matching the given parameters""" + # Empty query first + query_args = locals().copy() + filtered_args = list(filter(lambda x: x is not None, query_args.values())) + if len(filtered_args) == 1: + # Only self argument + with Session(self.engine) as session: + quests = session.query(DBQuest).all() + session.expunge_all() + return quests + + # Construct query + stmt = select(DBQuest) + if agent_id is not None: + stmt = stmt.where(DBQuest.agent_id == agent_id) + if parent_id is not None: + stmt = stmt.where(DBQuest.parent_id == parent_id) + if text_motivation is not None: + stmt = stmt.where(DBQuest.text_motivation == text_motivation) + if target_type is not None: + stmt = stmt.where(DBQuest.target_type == target_type) + if target is not None: + stmt = stmt.where(DBQuest.target == target) + if status is not None: + stmt = stmt.where(DBQuest.status == status) + if creator_id is not None: + stmt = stmt.where(DBQuest.creator_id == creator_id) + if origin_filepath is not None: + stmt = stmt.where(DBQuest.origin_filepath == origin_filepath) + # Do query + with Session(self.engine) as session: + quests = session.scalars(stmt).all() + session.expunge_all() + return quests + + # Graphs + + def save_graph(self, graph: "OOGraph", creator_id: str) -> str: + """Save this graph to a file for the given user""" + # Find or assign a db_id for this graph + if graph.db_id is not None: + db_id = graph.db_id + assert DBGraph.is_id(db_id), f"Provided Graph ID invalid: {db_id}" + else: + db_id = DBGraph.get_id() + graph.db_id = db_id + + dump_file_path = os.path.join(FILE_PATH_KEY, GRAPH_PATH_KEY, f"{db_id}.json") + + # Create or update the graph + with Session(self.engine) as session: + db_graph = session.query(DBGraph).get(db_id) + if db_graph is not None: + # Update old graph, ensure same creator + assert db_graph.creator_id == creator_id, ( + f"Creator ID mismatch on {db_id}, current " + f"{db_graph.creator_id} and new {creator_id}" + ) + self.write_data_to_file( + graph.to_json(), dump_file_path, json_encode=False + ) + db_graph.status = DBStatus.REVIEW + else: + # New graph + db_graph = DBGraph( + db_id=db_id, + graph_name=graph.title, + creator_id=creator_id, + file_path=dump_file_path, + status=DBStatus.REVIEW, + create_timestamp=time.time(), + ) + session.add(db_graph) + self.write_data_to_file( + graph.to_json(), dump_file_path, json_encode=False + ) + session.flush() + session.commit() + return db_id + + def load_graph(self, graph_id: str) -> DBGraph: + """Return the queried graph, raising if nonexistent""" + with Session(self.engine) as session: + db_graph = session.query(DBGraph).get(graph_id) + if db_graph is None: + raise KeyError(f"Graph key {graph_id} didn't exist!") + session.expunge_all() + return db_graph + + def find_graphs( + self, + graph_name: Optional[str] = None, + creator_id: Optional[str] = None, + # ... TODO can add other search attributes? + ) -> List[DBGraph]: + """Return all graphs matching the provided parameters""" + # Empty query first + query_args = locals().copy() + filtered_args = list(filter(lambda x: x is not None, query_args.values())) + if len(filtered_args) == 1: + # Only self argument + with Session(self.engine) as session: + graphs = session.query(DBGraph).all() + session.expunge_all() + return graphs + + # Construct query + stmt = select(DBGraph) + if graph_name is not None: + stmt = stmt.where(DBGraph.graph_name == graph_name) + if creator_id is not None: + stmt = stmt.where(DBGraph.creator_id == creator_id) + # Do query + with Session(self.engine) as session: + db_graphs = session.scalars(stmt).all() + session.expunge_all() + return db_graphs + + def count_built_occurrences(self) -> None: + """ + Iterate through all of the graphs to populate the strengths of + all of the edges. + """ + raise NotImplementedError + + # release functionality + + def scrub_creators(self, start_time: Optional[int] = None) -> int: + """ + Remove creators from anything in the dataset longer than 60 days + """ + changed_count = 0 + current_time = time.time() if start_time is None else start_time + cutoff_time = current_time - MAX_RETENTION + with Session(self.engine) as session: + for target_type in [ + DBAgent, + DBObject, + DBRoom, + DBNodeAttribute, + DBEdge, + DBTextEdge, + DBQuest, + ]: + stmt = select(target_type) + stmt = stmt.where(target_type.creator_id.startswith(USR_KEY)) + stmt = stmt.where(target_type.create_timestamp < cutoff_time) + elems = session.scalars(stmt).all() + for elem in elems: + changed_count += 1 + elem.creator_id = SCRUBBED_USER_ID + + stmt = select(DBFlag) + stmt = stmt.where(DBFlag.user_id.startswith(USR_KEY)) + stmt = stmt.where(DBFlag.create_timestamp < cutoff_time) + flags = session.scalars(stmt).all() + for flag in flags: + changed_count += 1 + flag.user_id = SCRUBBED_USER_ID + + stmt = select(DBEdit) + stmt = stmt.where(DBEdit.editor_id.startswith(USR_KEY)) + stmt = stmt.where(DBEdit.create_timestamp < cutoff_time) + edits = session.scalars(stmt).all() + for edit in edits: + changed_count += 1 + edit.editor_id = SCRUBBED_USER_ID + + session.commit() + return changed_count + + def clear_player_graphs( + self, player_id: Optional[str] = None, scrub_all: Optional[bool] = False + ) -> None: + """ + Find graphs with this player_id as creator + and then scrub the association. + """ + if player_id is not None: + assert scrub_all is not True, "Cannot scrub all if providing player id" + with Session(self.engine) as session: + stmt = select(DBGraph) + if not scrub_all: + stmt = stmt.where(DBGraph.creator_id == player_id) + graphs = session.scalars(stmt).all() + for graph in graphs: + graph.creator_id = SCRUBBED_USER_ID + session.commit() + + def dissociate_graph(self, graph_id: str) -> None: + with Session(self.engine) as session: + stmt = select(DBGraph).where(DBGraph.db_id == graph_id) + graph = session.scalars(stmt).one() + graph.creator_id = SCRUBBED_USER_ID + session.commit() + + def export(self, config: "DictConfig") -> "EnvDB": + """ + Create a scrubbed version of this database for use in releases + """ + assert config.file_root != self.file_root, "Cannot copy DB to same location!" + new_db = EnvDB(config) + + SKIPPED_TABLES = [t.__tablename__ for t in [DBFlag, DBEdit]] + + for table_name, table_obj in SQLBase.metadata.tables.items(): + # Skip tables that should not be public + if table_name in SKIPPED_TABLES: + continue + with self.engine.connect() as orig_conn: + with new_db.engine.connect() as new_conn: + all_data = [ + dict(row) for row in orig_conn.execute(select(table_obj.c)) + ] + if len(all_data) == 0: + continue + new_conn.execute(table_obj.insert().values(all_data)) + new_conn.commit() + + new_db.clear_player_graphs(scrub_all=True) + new_db.scrub_creators( + start_time=time.time() + MAX_RETENTION + ) # Scrub _all_ creator ids + + with Session(self.engine) as session: + # Copy the graphs to the new DB + stmt = select(DBGraph) + graphs = session.scalars(stmt).all() + for graph in graphs: + graph_data = self.read_data_from_file( + graph.file_path, json_encoded=False + ) + new_db.write_data_to_file( + graph_data, graph.file_path, json_encoded=False + ) + + # Copy the quests to the new DB + stmt = select(DBQuest) + quests = session.scalars(stmt).all() + for quest in quests: + file_path = quest.origin_file_path + if file_path is None: + continue # no quest file + quest_data = self.read_data_from_file(file_path, json_encoded=False) + new_db.write_data_to_file(quest_data, file_path, json_encoded=False) + + return new_db diff --git a/light/data_model/db/episodes.py b/light/data_model/db/episodes.py new file mode 100644 index 000000000..3cb38ca9e --- /dev/null +++ b/light/data_model/db/episodes.py @@ -0,0 +1,407 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from light.data_model.db.base import BaseDB, DBStatus, DBSplitType, HasDBIDMixin +from light.data_model.db.users import DBPlayer +from omegaconf import MISSING, DictConfig +from typing import Optional, List, Tuple, Union, Dict, Any, Set, TYPE_CHECKING +from sqlalchemy import insert, select, Enum, Column, Integer, String, Float, ForeignKey +from sqlalchemy.orm import declarative_base, relationship, Session +from light.graph.events.base import GraphEvent +import time +import enum +import os +import hashlib + +if TYPE_CHECKING: + from light.graph.structured_graph import OOGraph + +SQLBase = declarative_base() +FILE_PATH_KEY = "episodes" +ID_STRING_LENGTH = 40 +USR_KEY = DBPlayer.ID_PREFIX + + +class DBGroupName(enum.Enum): + """Data Releases in the LIGHT episode DB""" + + ORIG = "orig" + WILD = "wild" + MULTIPARTY = "multiparty" + PRE_LAUNCH = "crowdsourced" + PRE_LAUNCH_TUTORIAL = "crowdsourced_tutorial" + RELEASE_Q4_22 = "full_release_Q4_22" + + +class EpisodeLogType(enum.Enum): + """Types of episodes in LIGHT""" + + ROOM = "room" + AGENT = "agent" + FULL = "full" + + +class DBEpisode(HasDBIDMixin, SQLBase): + """Class containing the expected elements for an episode as stored in the db""" + + __tablename__ = "episodes" + + ID_PREFIX = "EPI" + + id = Column(String(ID_STRING_LENGTH), primary_key=True) + group = Column(Enum(DBGroupName), nullable=False, index=True) + split = Column(Enum(DBSplitType), nullable=False, index=True) + status = Column(Enum(DBStatus), nullable=False, index=True) + actors = Column( + String + ) # Comma separated list of actor IDs. Cleared on release data + dump_file_path = Column(String(90), nullable=False) # Path to data + turn_count = Column(Integer, nullable=False) + human_count = Column(Integer, nullable=False) + action_count = Column(Integer, nullable=False) + timestamp = Column(Float, nullable=False) + log_type = Column(Enum(EpisodeLogType), nullable=False) + first_graph_id = Column(ForeignKey("graphs.id")) + final_graph_id = Column(ForeignKey("graphs.id")) + + _cached_map = None + + def get_actors(self) -> List[str]: + """Return the actors in this episode""" + if len(self.actors.strip()) == 0: + return [] + return self.actors.split(",") + + def get_parsed_events( + self, db: "EpisodeDB" + ) -> List[Tuple[str, List["GraphEvent"]]]: + """ + Return all of the actions and turns from this episode, + split by the graph key ID relevant to those actions + """ + # Import deferred as World imports loggers which import the EpisodeDB + from light.world.world import World, WorldConfig + + events = db.read_data_from_file(self.dump_file_path, json_encoded=True)[ + "events" + ] + graph_grouped_events: List[Tuple[str, List["GraphEvent"]]] = [] + current_graph_events = None + curr_graph_key = None + curr_graph = None + tmp_world = None + # Extract events to the correct related graphs, initializing the graphs + # as necessary + for event_turn in events: + # See if we've moved onto an event in a new graph + if event_turn["graph_key"] != curr_graph_key: + if current_graph_events is not None: + # There was old state, so lets push it to the list + graph_grouped_events.append((curr_graph_key, current_graph_events)) + # We're on a new graph, have to reset the current graph state + curr_graph_key = event_turn["graph_key"] + current_graph_events: List["GraphEvent"] = [] + curr_graph = self.get_graph(curr_graph_key, db) + tmp_world = World(WorldConfig()) + tmp_world.oo_graph = curr_graph + # The current turn is part of the current graph's events, add + current_graph_events.append( + GraphEvent.from_json(event_turn["event_json"], tmp_world) + ) + if current_graph_events is not None: + # Push the last graph's events, which weren't yet added + graph_grouped_events.append((curr_graph_key, current_graph_events)) + return graph_grouped_events + + def get_before_graph(self, db: "EpisodeDB") -> "OOGraph": + """Return the state of the graph before this episode""" + return self.get_graph(self.first_graph_id, db) + + def get_graph(self, id_or_key: str, db: "EpisodeDB") -> "OOGraph": + """Return a specific graph by id or key""" + with Session(db.engine) as session: + session.add(self) + return self.get_graph_map()[id_or_key].get_graph(db) + + def get_after_graph(self, db: "EpisodeDB") -> "OOGraph": + """Return the state of the graph after this episode""" + return self.get_graph(self.final_graph_id, db) + + def get_graph_map(self): + """Return a mapping from both graph keys and graph ids to their graph""" + if self._cached_map is None: + key_map = {graph.graph_key_id: graph for graph in self.graphs} + id_map = {graph.id: graph for graph in self.graphs} + key_map.update(id_map) + self._cached_map = key_map + return self._cached_map + + def __repr__(self): + return f"DBEpisode(ids:[{self.id!r}] group/split:[{self.group.value!r}/{self.split.value!r}] File:[{self.dump_file_path!r}])" + + +class DBEpisodeGraph(HasDBIDMixin, SQLBase): + """Class containing expected elements for a stored graph""" + + __tablename__ = "graphs" + + ID_PREFIX = "EPG" + + id = Column(String(ID_STRING_LENGTH), primary_key=True) + episode_id = Column(Integer, ForeignKey("episodes.id"), nullable=False, index=True) + full_path = Column(String(80), nullable=False) + graph_key_id = Column(String(60), nullable=False, index=True) + episode = relationship("DBEpisode", backref="graphs", foreign_keys=[episode_id]) + + def get_graph(self, db: "EpisodeDB") -> "OOGraph": + """Return the initialized graph based on this file""" + from light.graph.structured_graph import OOGraph + + graph_json = db.read_data_from_file(self.full_path) + graph = OOGraph.from_json(graph_json) + return graph + + def __repr__(self): + return f"DBEpisodeGraph(ids:[{self.id!r},{self.graph_key_id!r}], episode:{self.episode_id!r})" + + +class EpisodeDB(BaseDB): + """ + Episode dataset database for LIGHT, containing accessors for all + of the recorded LIGHT episodes, including previous dataset dumps. + + Used by InteractionLoggers to write new entries, and by ParlAI to + create teachers for datasets. + """ + + DB_TYPE = "episode" + + def _complete_init(self, config: "DictConfig"): + """ + Initialize any specific episode-related paths. Populate + the list of available splits and datasets. + """ + SQLBase.metadata.create_all(self.engine) + + def _validate_init(self): + """ + Ensure that the episode directory is properly loaded + """ + # TODO Check the table for any possible consistency issues + # and ensure that the episode directories for listed splits exist + + def write_episode( + self, + graphs: List[Dict[str, str]], + events: Tuple[str, List[Dict[str, str]]], + log_type: EpisodeLogType, + action_count: int, + players: Set[str], + group: DBGroupName, + ) -> str: + """ + Create an entry given the current argument data, store it + to file on the database + """ + actor_string = ",".join(list(players)) + event_filename = events[0] + event_list = events[1] + + # Trim the filename from the left if too long + event_filename = event_filename[-70:] + + dump_file_path = os.path.join( + FILE_PATH_KEY, group.value, log_type.value, event_filename + ) + graph_dump_root = os.path.join( + FILE_PATH_KEY, + group.value, + log_type.value, + "graphs", + ) + + # File writes + self.write_data_to_file( + {"events": event_list}, dump_file_path, json_encode=True + ) + for graph_info in graphs: + graph_full_path = os.path.join(graph_dump_root, graph_info["filename"]) + self.write_data_to_file(graph_info["graph_json"], graph_full_path) + + # DB Writes + episode_id = DBEpisode.get_id() + with Session(self.engine) as session: + episode = DBEpisode( + id=episode_id, + group=group, + split=DBSplitType.UNSET, + status=DBStatus.REVIEW, + actors=actor_string, + dump_file_path=dump_file_path, + turn_count=len(event_list), + human_count=len(players), + action_count=action_count, + timestamp=time.time(), + log_type=log_type, + ) + first_graph = None + for idx, graph_info in enumerate(graphs): + graph_full_path = os.path.join(graph_dump_root, graph_info["filename"]) + db_graph = DBEpisodeGraph( + id=DBEpisodeGraph.get_id(), + graph_key_id=graph_info["key"], + full_path=graph_full_path, + ) + if idx == 0: + first_graph = db_graph + episode.graphs.append(db_graph) + session.add(episode) + session.flush() + episode.first_graph_id = first_graph.id + episode.final_graph_id = db_graph.id + session.commit() + + return episode_id + + def get_episode(self, episode_id: str) -> "DBEpisode": + """ + Return a specific episode by id, raising an issue if it doesnt exist + """ + stmt = select(DBEpisode).where(DBEpisode.id == episode_id) + with Session(self.engine) as session: + episode = self._enforce_get_first(session, stmt, "Episode did not exist") + for graph in episode.graphs: + # Load all the graph keys + assert graph.id is not None + session.expunge_all() + return episode + + def get_episodes( + self, + group: Optional[DBGroupName] = None, + split: Optional[DBSplitType] = None, + min_turns: Optional[int] = None, + min_humans: Optional[int] = None, + min_actions: Optional[int] = None, + status: Optional[DBStatus] = None, + user_id: Optional[str] = None, + min_creation_time: Optional[float] = None, + max_creation_time: Optional[float] = None, + log_type: Optional[EpisodeLogType] = None, + # ... other args + ) -> List["DBEpisode"]: + """ + Return all matching episodes + """ + stmt = select(DBEpisode) + if group is not None: + stmt = stmt.where(DBEpisode.group == group) + if split is not None: + stmt = stmt.where(DBEpisode.split == split) + if min_turns is not None: + stmt = stmt.where(DBEpisode.turn_count >= min_turns) + if min_humans is not None: + stmt = stmt.where(DBEpisode.human_count >= min_humans) + if min_actions is not None: + stmt = stmt.where(DBEpisode.action_count >= min_actions) + if status is not None: + stmt = stmt.where(DBEpisode.status == status) + if user_id is not None: + stmt = stmt.where(DBEpisode.actors.contains(user_id)) + if log_type is not None: + stmt = stmt.where(DBEpisode.log_type == log_type) + if min_creation_time is not None: + stmt = stmt.where(DBEpisode.timestamp >= min_creation_time) + if max_creation_time is not None: + stmt = stmt.where(DBEpisode.timestamp <= max_creation_time) + with Session(self.engine) as session: + episodes = session.scalars(stmt).all() + session.expunge_all() + return episodes + + def anonymize_group(self, group: DBGroupName) -> bool: + """ + Run anonymization on the split to remove any link to the + long-term user. All data within a quarter's dataset + can be linked (for long-term memory analysis) but cannot be + tracked cross-quarters. + + Return true on success + """ + hashing_time = time.time() + sha = hashlib.sha256() + + def rehash(curr_name): + if not curr_name.startswith(USR_KEY): + return curr_name # already hashed + + # Adding a hashtime to make unique + hash_name = f"{curr_name}-{hashing_time}" + sha.update(hash_name.encode()) + return str(sha.hexdigest()[:30]) + + with Session(self.engine) as session: + stmt = select(DBEpisode).where(DBEpisode.group == group) + episodes = session.scalars(stmt).all() + for episode in episodes: + actors_string = episode.actors + actors = actors_string.split(",") + processed_actors = [rehash(a) for a in actors] + episode.actors = ",".join(processed_actors) + # Rewrite the graphs and events too + def replace_all_actors(in_data: str) -> str: + out_data = in_data + for i in range(len(actors)): + out_data = out_data.replace(actors[i], processed_actors[i]) + return out_data + + graphs = episode.graphs + for graph in graphs: + graph_data = self.read_data_from_file(graph.full_path) + anon_graph_data = replace_all_actors(graph_data) + self.write_data_to_file(anon_graph_data, graph.full_path) + event_data = self.read_data_from_file(episode.dump_file_path) + anon_event_data = replace_all_actors(event_data) + self.write_data_to_file(anon_event_data, episode.dump_file_path) + session.commit() + return True + + def export(self, config: "DictConfig") -> "EpisodeDB": + """ + Create a scrubbed version of this database for use in releases + """ + assert config.file_root != self.file_root, "Cannot copy DB to same location!" + new_db = EpisodeDB(config) + + # Copy all the basic content + for table_name, table_obj in SQLBase.metadata.tables.items(): + with self.engine.connect() as orig_conn: + with new_db.engine.connect() as new_conn: + all_data = [ + dict(row) for row in orig_conn.execute(select(table_obj.c)) + ] + if len(all_data) == 0: + continue + new_conn.execute(table_obj.insert().values(all_data)) + new_conn.commit() + + with Session(self.engine) as session: + stmt = select(DBEpisode) + episodes = session.scalars(stmt).all() + for episode in episodes: + graphs = episode.graphs + for graph in graphs: + # Copy the graphs to the new DB + graph_data = self.read_data_from_file(graph.full_path) + new_db.write_data_to_file(graph_data, graph.full_path) + # Copy the events to the new DB + event_data = self.read_data_from_file(episode.dump_file_path) + new_db.write_data_to_file(event_data, episode.dump_file_path) + + for group in DBGroupName: + new_db.anonymize_group(group=group) + + return new_db diff --git a/light/data_model/db/users.py b/light/data_model/db/users.py new file mode 100644 index 000000000..74bfd0dc5 --- /dev/null +++ b/light/data_model/db/users.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from light.data_model.db.base import BaseDB, HasDBIDMixin +from omegaconf import MISSING, DictConfig +from typing import Optional, Union, Dict, Any, TYPE_CHECKING +from sqlalchemy import ( + insert, + select, + delete, + Enum, + Column, + Integer, + String, + Boolean, + ForeignKey, +) +from sqlalchemy.orm import declarative_base, relationship, Session +import enum + +if TYPE_CHECKING: + from light.data_model.db.environment import EnvDB + +SQLBase = declarative_base() + + +class PlayerStatus(enum.Enum): + STANDARD = "standard" + BLOCKED = "blocked" + TUTORIAL = "in_tutorial" + ADMIN = "admin" + + +class DBPlayer(HasDBIDMixin, SQLBase): + """Class containing the expected elements for a Player as stored in the db""" + + __tablename__ = "user_accounts" + ID_PREFIX = "USR" + + db_id = Column(String(40), primary_key=True) + extern_id = Column(String(60), nullable=False, index=True, unique=True) + is_preauth = Column(Boolean, nullable=False) + flag_count = Column(Integer, nullable=False) + safety_trigger_count = Column(Integer, nullable=False) + total_messages = Column(Integer, nullable=False) + account_status = Column(Enum(PlayerStatus), nullable=False) + scores = relationship("DBScoreEntry") + + def __repr__(self): + return f"DBPlayer(ids:[{self.db_id!r},{self.extern_id!r}], preauth:{self.is_preauth!r}, status:{self.account_status.value!r})" + + +class DBScoreEntry(SQLBase): + """Class containing score entries per player and character, as stored in the DB""" + + __tablename__ = "user_scores" + + id = Column(Integer, primary_key=True) + user_id = Column( + String, ForeignKey("user_accounts.db_id"), nullable=False, index=True + ) + agent_name_id = Column( + String(40), index=True + ) # Null for overall score for an agent + score = Column(Integer, nullable=False) + count = Column(Integer, nullable=False) + reward_xp = Column(Integer) + + def __repr__(self): + if self.agent_name_id is None: + return f"DBScoreEntry(ids:[{self.id!r},{self.user_id!r}] score:{self.score!r}, count:{self.count!r})" + return f"DBScoreEntry(ids:[{self.id!r},{self.user_id!r}], agent:{self.agent_name_id!r}, score:{self.score!r}, count:{self.count!r})" + + +class UserDB(BaseDB): + """ + User database for the core LIGHT game. Tracks people's progress in the + game, as associated with a given id. + """ + + DB_TYPE = "users" + + def _complete_init(self, config: "DictConfig"): + """ + Initialize any specific interaction-related paths. Populate + the list of available splits and datasets. + """ + SQLBase.metadata.create_all(self.engine) + + def _validate_init(self): + """ + Ensure that the interaction directory is properly loaded + """ + # TODO Check the table for any possible consistency issues + + def create_user( + self, + extern_id: str, + is_preauth: bool, + ) -> int: + """Create the specified player, idempotently""" + try: + user = self.get_player_by_extern_id(extern_id) + return user.db_id + except KeyError: + pass # Create a new user! + with Session(self.engine) as session: + player_id = DBPlayer.get_id() + player = DBPlayer( + db_id=player_id, + extern_id=extern_id, + is_preauth=is_preauth, + flag_count=0, + safety_trigger_count=0, + total_messages=0, + account_status=PlayerStatus.TUTORIAL, + ) + base_score = DBScoreEntry( + score=0, + count=0, + reward_xp=0, + ) + player.scores.append(base_score) + session.add(player) + session.commit() + return player_id + + def get_player(self, player_id: str) -> DBPlayer: + """Find the specified player, raise exception if non-existent""" + stmt = select(DBPlayer).where(DBPlayer.db_id == player_id) + with Session(self.engine) as session: + player = self._enforce_get_first(session, stmt, "Player not found") + session.expunge_all() + return player + + def get_player_by_extern_id(self, extern_id: str) -> DBPlayer: + """Find the specified player, raise exception if non-existent""" + stmt = select(DBPlayer).where(DBPlayer.extern_id == extern_id) + with Session(self.engine) as session: + player = self._enforce_get_first(session, stmt, "Player not found") + session.expunge_all() + return player + + def get_agent_score( + self, player_id: str, agent_name_id: Optional[str] = None + ) -> DBScoreEntry: + """Get the specific agent score. Supply None for total score""" + stmt = ( + select(DBScoreEntry) + .where(DBScoreEntry.user_id == player_id) + .where(DBScoreEntry.agent_name_id == agent_name_id) + ) + with Session(self.engine) as session: + score_entry = self._enforce_get_first( + session, stmt, "Player or agent not found" + ) + session.expunge_all() + return score_entry + + def update_agent_score( + self, + player_id: str, + agent_name_id: str, + points: int, + num_turns: int, + reward_change: int, + ): + """Add to both the base agent score and total score for a player""" + player_stmt = select(DBPlayer).where(DBPlayer.db_id == player_id) + base_stmt = ( + select(DBScoreEntry) + .where(DBScoreEntry.user_id == player_id) + .where(DBScoreEntry.agent_name_id == None) + ) + specific_stmt = ( + select(DBScoreEntry) + .where(DBScoreEntry.user_id == player_id) + .where(DBScoreEntry.agent_name_id == agent_name_id) + ) + + with Session(self.engine) as session: + player = self._enforce_get_first(session, player_stmt, "Player not found") + player.total_messages += num_turns + + base_score = session.scalars(base_stmt).first() + if base_score is None: + # we should never fail to get the basic agent score + raise AssertionError("No default score for player, corruption issue") + base_score.score += points + base_score.count += 1 + base_score.reward_xp += reward_change + + agent_score = session.scalars(specific_stmt).first() + print(agent_score, agent_name_id) + if agent_score is None: + # User has not played this character before, we'll need to initialize it + agent_score = DBScoreEntry( + agent_name_id=agent_name_id, + score=points, + count=1, + ) + player.scores.append(agent_score) + session.add(agent_score) + else: + agent_score.score += points + agent_score.count += 1 + + session.commit() + + def mark_flag(self, player_id: str) -> None: + """Mark that a player has been flagged""" + get_player = select(DBPlayer).where(DBPlayer.db_id == player_id) + with Session(self.engine) as session: + player = self._enforce_get_first(session, get_player, "Player not found") + player.flag_count += 1 + session.commit() + + def mark_safety_trigger(self, player_id: str) -> None: + """mark that a specific player has triggered the safety""" + get_player = select(DBPlayer).where(DBPlayer.db_id == player_id) + with Session(self.engine) as session: + player = self._enforce_get_first(session, get_player, "Player not found") + player.safety_trigger_count += 1 + session.commit() + + def update_player_status(self, player_id: str, new_status: PlayerStatus) -> None: + """Update the status for a given player""" + get_player = select(DBPlayer).where(DBPlayer.db_id == player_id) + with Session(self.engine) as session: + player = self._enforce_get_first(session, get_player, "Player not found") + player.account_status = new_status + session.commit() + + def delete_player(self, player_id: str, env_db: "EnvDB") -> None: + """ + Delete a player from the database, removing all + of their personal game data and clearing + association for their graphs + """ + get_player = select(DBPlayer).where(DBPlayer.db_id == player_id) + with Session(self.engine) as session: + # Ensure player exists first + _player = self._enforce_get_first(session, get_player, "Player not found") + session.execute(delete(DBPlayer).where(DBPlayer.db_id == player_id)) + session.execute( + delete(DBScoreEntry).where(DBScoreEntry.user_id == player_id) + ) + session.commit() + + env_db.clear_player_graphs(player_id) diff --git a/light/data_model/environment_checkpoint_parser.py b/light/data_model/environment_checkpoint_parser.py index c77b96da6..14cd9bab2 100644 --- a/light/data_model/environment_checkpoint_parser.py +++ b/light/data_model/environment_checkpoint_parser.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree.abs +# LICENSE file in the root directory of this source tree. import pickle import os diff --git a/light/data_model/filler_rooms.py b/light/data_model/filler_rooms.py index b1ea63615..3916413c0 100644 --- a/light/data_model/filler_rooms.py +++ b/light/data_model/filler_rooms.py @@ -137,7 +137,7 @@ def build_filler_rooms_from_categories(category_set): "This is a run down [FILL], it looks [OBSTACLE].", "There's a bit of [CHARACTERISTIC].", ] - background = [ + backgrounds = [ "Sometimes things are just what they seem. There's nothing interesting here." ] @@ -160,6 +160,8 @@ def build_filler_rooms_from_categories(category_set): description = description.replace( "[CHARACTERISTIC]", random.choice(characteristics) ) - build_room = FillerRoom(category, name, description, background) + build_room = FillerRoom( + category, name, description, random.choice(backgrounds) + ) filler_rooms[category].append(build_room) return filler_rooms, set(filler_room_names) diff --git a/light/data_model/tests/test_db.py b/light/data_model/tests/test_db.py index ceda6ac29..2cd4e07da 100644 --- a/light/data_model/tests/test_db.py +++ b/light/data_model/tests/test_db.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree.abs +# LICENSE file in the root directory of this source tree. import shutil, tempfile import sqlite3 diff --git a/light/data_model/tests/test_environment_db.py b/light/data_model/tests/test_environment_db.py new file mode 100644 index 000000000..63d663c0c --- /dev/null +++ b/light/data_model/tests/test_environment_db.py @@ -0,0 +1,1975 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +import shutil, tempfile +from omegaconf import OmegaConf +import os +import json +import time +import sqlalchemy + +from light.graph.structured_graph import OOGraph +from light.data_model.db.environment import ( + EnvDB, + MAX_RETENTION, + SCRUBBED_USER_ID, + DBRoomInsideType, + DBEdgeType, + DBFlagTargetType, + DBQuestTargetType, + DBNodeAttribute, + DBAgent, + DBAgentName, + DBObject, + DBObjectName, + DBRoom, + DBRoomName, + DBEdge, + DBTextEdge, + DBFlag, + DBGraph, + DBEdit, + DBQuest, + HasDBIDMixin, +) +from typing import List +from light.data_model.db.base import LightDBConfig, DBStatus, DBSplitType +from sqlalchemy.orm import Session + +TEST_USER_ID = "USR-test" + + +class TestEnvironmentDB(unittest.TestCase): + """ + Unit tests for the LIGHT EnvDB. + Builds simple test cases with standard inserts, but generally is a + set of monolithic tests for each table in the EnvDB + """ + + def setUp(self): + self.data_dir = tempfile.mkdtemp() + self.config = LightDBConfig(backend="test", file_root=self.data_dir) + self.data_dir_copy = tempfile.mkdtemp() + self.config_2 = LightDBConfig(backend="test", file_root=self.data_dir_copy) + + def tearDown(self): + shutil.rmtree(self.data_dir) + shutil.rmtree(self.data_dir_copy) + + def set_up_some_nodes(self, db: EnvDB): + # Create some test entries in the env DB + agent_ids: List[str] = [] + room_ids: List[str] = [] + object_ids: List[str] = [] + for x in range(5): + agent_ids.append( + db.create_agent_entry( + name=f"test_agent_{x}", + base_name="agent", + persona="agent_persona", + physical_description="agent_description", + ) + ) + room_ids.append( + db.create_room_entry( + name=f"test_room_{x}", + base_name="room", + description="room_description", + backstory="room_backstory", + ) + ) + object_ids.append( + db.create_object_entry( + name=f"test_object_{x}", + base_name="object", + physical_description="object_description", + is_container=0, + is_drink=0, + is_food=0, + is_gettable=1, + is_surface=0, + is_wearable=0, + is_weapon=0, + ) + ) + return agent_ids, room_ids, object_ids + + def test_initialize_env_db(self): + """Ensure it's possible to initialize the db""" + db = EnvDB(self.config) + + def test_create_load_inspect_agents(self): + """Ensure it's possible to create and load agents""" + # Create three agents, assert they have unique IDs but base_ids map + db = EnvDB(self.config) + BASE_NAME_1 = "king" + BASE_NAME_2 = "queen" + FULL_NAME_1 = "king of the orcs" + FULL_NAME_2 = "elder king of the rats" + FULL_NAME_3 = "queen of the land" + + # First agent should test mostly default values + TEST_PERSONA_1 = "test_persona_1" + TEST_DESC_1 = "test_desc_1" + agent_1_id = db.create_agent_entry( + name=FULL_NAME_1, + base_name=BASE_NAME_1, + persona=TEST_PERSONA_1, + physical_description=TEST_DESC_1, + ) + + # Ensure id created is correct + self.assertIsNotNone(agent_1_id) + self.assertTrue( + DBAgent.is_id(agent_1_id), f"Created ID {agent_1_id} not DBAgent ID" + ) + self.assertFalse( + DBObject.is_id(agent_1_id), f"Created ID {agent_1_id} passes as DBObject ID" + ) + + # Ensure agent created and matches defaults and provided + agent_1 = db.get_agent(agent_1_id) + base_id_1 = agent_1.base_id + self.assertTrue(DBAgentName.is_id(base_id_1), "Base ID not correct format") + self.assertEqual( + agent_1.db_id, agent_1_id, "Marked db_id differs from initially returned id" + ) + self.assertEqual(agent_1.persona, TEST_PERSONA_1) + self.assertEqual(agent_1.name, FULL_NAME_1) + self.assertEqual(agent_1.physical_description, TEST_DESC_1) + self.assertEqual(agent_1.built_occurrences, 0) + self.assertEqual(agent_1.name_prefix, "a") + self.assertEqual(agent_1.status, DBStatus.REVIEW) + self.assertIsNone(agent_1.charisma) + self.assertIsNone(agent_1.constitution) + self.assertIsNone(agent_1.strength) + self.assertIsNone(agent_1.dexterity) + self.assertIsNone(agent_1.intelligence) + self.assertIsNone(agent_1.wisdom) + self.assertIsNone(agent_1.is_plural) + self.assertIsNone(agent_1.size) + self.assertIsNone(agent_1.contain_size) + self.assertIsNone(agent_1.creator_id) + self.assertIsNotNone(agent_1.create_timestamp) + + # Ensure base agent created and matches values + base_agent_1 = db.get_agent_name(db_id=agent_1.base_id) + self.assertEqual(base_agent_1.name, BASE_NAME_1) + self.assertEqual(base_agent_1.db_id, agent_1.base_id) + self.assertEqual(base_agent_1.status, DBStatus.REVIEW) + self.assertEqual(base_agent_1.split, DBSplitType.UNSET) + + # Ensure that the link exists between base and agent + with Session(db.engine) as session: + session.add(base_agent_1) + self.assertEqual( + len(base_agent_1.agents), 1, "Should have one linked agent" + ) + session.expunge_all() + + # Should only be one agent + self.assertEqual(len(db.find_agents()), 1) + self.assertEqual(len(db.find_agent_names()), 1) + + # Duplicate create should fail + with self.assertRaises(sqlalchemy.exc.IntegrityError): + agent_1_id = db.create_agent_entry( + name=FULL_NAME_1, + base_name=BASE_NAME_1, + persona=TEST_PERSONA_1, + physical_description=TEST_DESC_1, + ) + + # Should only be one agent + self.assertEqual(len(db.find_agents()), 1) + self.assertEqual(len(db.find_agent_names()), 1) + + # Create a second agent sharing the first base class + TEST_PERSONA_2 = "test_persona_2" + TEST_DESC_2 = "test_desc_2" + agent_2_id = db.create_agent_entry( + name=FULL_NAME_2, + base_name=BASE_NAME_1, + persona=TEST_PERSONA_2, + physical_description=TEST_DESC_2, + ) + + # Ensure agent exists now, and that the base class is correct + agent_2 = db.get_agent(agent_2_id) + self.assertEqual( + agent_2.db_id, agent_2_id, "Marked db_id differs from initially returned id" + ) + self.assertEqual(agent_2.persona, TEST_PERSONA_2) + self.assertEqual(agent_2.name, FULL_NAME_2) + self.assertEqual(agent_2.physical_description, TEST_DESC_2) + self.assertEqual(agent_2.base_id, agent_2.base_id) + self.assertEqual(agent_2.name_prefix, "an") + + # Ensure only one base class, but two agents + self.assertEqual(len(db.find_agents()), 2) + self.assertEqual(len(db.find_agent_names()), 1) + + # Create a third agent, with all custom attributes + TEST_PERSONA_3 = "test_persona_3" + TEST_DESC_3 = "test_desc_3" + agent_3_id = db.create_agent_entry( + name=FULL_NAME_3, + base_name=BASE_NAME_2, + persona=TEST_PERSONA_3, + physical_description=TEST_DESC_3, + name_prefix="hello", + status=DBStatus.ACCEPTED, + charisma=1, + constitution=2, + strength=3, + dexterity=4, + intelligence=5, + wisdom=6, + is_plural=True, + size=7, + contain_size=8, + creator_id=TEST_USER_ID, + ) + + # Ensure id created is correct + self.assertIsNotNone(agent_3_id) + self.assertTrue( + DBAgent.is_id(agent_3_id), f"Created ID {agent_3_id} not DBAgent ID" + ) + self.assertFalse( + DBObject.is_id(agent_3_id), f"Created ID {agent_3_id} passes as DBObject ID" + ) + + # Ensure that the custom attributes all work + agent_3 = db.get_agent(agent_3_id) + base_id_3 = agent_3.base_id + self.assertNotEqual(base_id_3, base_id_1) + self.assertTrue(DBAgentName.is_id(base_id_3), "Base ID not correct format") + self.assertEqual( + agent_3.db_id, agent_3_id, "Marked db_id differs from initially returned id" + ) + self.assertEqual(agent_3.persona, TEST_PERSONA_3) + self.assertEqual(agent_3.name, FULL_NAME_3) + self.assertEqual(agent_3.physical_description, TEST_DESC_3) + self.assertEqual(agent_3.built_occurrences, 0) + self.assertEqual(agent_3.name_prefix, "hello") + self.assertEqual(agent_3.status, DBStatus.ACCEPTED) + self.assertEqual(agent_3.charisma, 1) + self.assertEqual(agent_3.constitution, 2) + self.assertEqual(agent_3.strength, 3) + self.assertEqual(agent_3.dexterity, 4) + self.assertEqual(agent_3.intelligence, 5) + self.assertEqual(agent_3.wisdom, 6) + self.assertTrue(agent_3.is_plural) + self.assertEqual(agent_3.size, 7) + self.assertEqual(agent_3.contain_size, 8) + self.assertEqual(agent_3.creator_id, TEST_USER_ID) + self.assertIsNotNone(agent_3.create_timestamp) + + # Ensure base agent created and matches values + base_agent_2 = db.get_agent_name(db_id=agent_3.base_id) + self.assertEqual(base_agent_2.name, BASE_NAME_2) + self.assertEqual(base_agent_2.db_id, agent_3.base_id) + self.assertEqual(base_agent_2.status, DBStatus.REVIEW) + self.assertEqual(base_agent_2.split, DBSplitType.UNSET) + + # Ensure two base classes, and three agents + self.assertEqual(len(db.find_agents()), 3) + self.assertEqual(len(db.find_agent_names()), 2) + + base_agent_1 = db.get_agent_name(db_id=agent_1.base_id) + # Ensure the base classes properly link to the agents + with Session(db.engine) as session: + session.add(base_agent_1) + self.assertEqual( + len(base_agent_1.agents), 2, "Base 1 Should have two linked agents" + ) + session.add(base_agent_2) + self.assertEqual( + len(base_agent_2.agents), 1, "Base 2 Should have one linked agent" + ) + session.expunge_all() + + # Ensure that all agents base names are present when in session + with Session(db.engine) as session: + session.add(agent_1) + session.add(agent_2) + session.add(agent_3) + self.assertEqual(agent_1.base_name.name, agent_2.base_name.name) + self.assertNotEqual(agent_1.base_name.name, agent_3.base_name.name) + + # assert that getting agent names fail on all invalid cases + with self.assertRaises(AssertionError): + base_agent_1 = db.get_agent_name(db_id="FAK-fake") + with self.assertRaises(KeyError): + base_agent_1 = db.get_agent_name(db_id="AGN-fake") + with self.assertRaises(KeyError): + base_agent_1 = db.get_agent_name(name="fake") + with self.assertRaises(KeyError): + base_agent_1 = db.get_agent_name( + db_id=agent_1.base_id, status=DBStatus.ACCEPTED + ) + with self.assertRaises(KeyError): + base_agent_1 = db.get_agent_name( + db_id=agent_1.base_id, split=DBSplitType.TRAIN + ) + + # Advanced agent name searches + matched_status = db.find_agent_names(status=DBStatus.REVIEW) + self.assertEqual(len(matched_status), 2) + unmatched_status = db.find_agent_names(status=DBStatus.ACCEPTED) + self.assertEqual(len(unmatched_status), 0) + matched_split = db.find_agent_names(split=DBSplitType.UNSET) + self.assertEqual(len(matched_split), 2) + unmatched_split = db.find_agent_names(split=DBSplitType.TRAIN) + self.assertEqual(len(unmatched_split), 0) + name_exact_match = db.find_agent_names(name=BASE_NAME_1) + self.assertEqual(len(name_exact_match), 1) + name_partial_match_1 = db.find_agent_names(name="qu") + self.assertEqual(len(name_partial_match_1), 1) + name_partial_match_2 = db.find_agent_names(name="n") + self.assertEqual(len(name_partial_match_2), 2) + name_no_match = db.find_agent_names(name="zzz") + self.assertEqual(len(name_no_match), 0) + + # Advanced agent searches + base_id_match_0 = db.find_agents(base_id="AGN-fake") + self.assertEqual(len(base_id_match_0), 0) + base_id_match_1 = db.find_agents(base_id=base_agent_2.db_id) + self.assertEqual(len(base_id_match_1), 1) + base_id_match_2 = db.find_agents(base_id=base_agent_1.db_id) + self.assertEqual(len(base_id_match_2), 2) + name_exact_match = db.find_agents(name=FULL_NAME_1) + self.assertEqual(len(name_exact_match), 1) + name_match_0 = db.find_agents(name="zzzzz") + self.assertEqual(len(name_match_0), 0) + name_match_1 = db.find_agents(name="orcs") + self.assertEqual(len(name_match_1), 1) + name_match_2 = db.find_agents(name="king") + self.assertEqual(len(name_match_2), 2) + persona_exact_match = db.find_agents(persona=TEST_PERSONA_3) + self.assertEqual(len(persona_exact_match), 1) + persona_match_0 = db.find_agents(persona="zzz") + self.assertEqual(len(persona_match_0), 0) + persona_match_1 = db.find_agents(persona="3") + self.assertEqual(len(persona_match_1), 1) + persona_match_3 = db.find_agents(persona="test") + self.assertEqual(len(persona_match_3), 3) + description_exact_match = db.find_agents(physical_description=TEST_DESC_1) + self.assertEqual(len(description_exact_match), 1) + description_match_0 = db.find_agents(physical_description="zzz") + self.assertEqual(len(description_match_0), 0) + description_match_1 = db.find_agents(physical_description="3") + self.assertEqual(len(description_match_1), 1) + description_match_3 = db.find_agents(physical_description="test") + self.assertEqual(len(description_match_3), 3) + name_prefix_match_0 = db.find_agents(name_prefix="test") + self.assertEqual(len(name_prefix_match_0), 0) + name_prefix_match_1 = db.find_agents(name_prefix="hello") + self.assertEqual(len(name_prefix_match_1), 1) + name_prefix_match_a = db.find_agents(name_prefix="a") + self.assertEqual(len(name_prefix_match_a), 1) + is_plural_match_0 = db.find_agents(is_plural=False) + self.assertEqual(len(is_plural_match_0), 0) + is_plural_match_1 = db.find_agents(is_plural=True) + self.assertEqual(len(is_plural_match_1), 1) + status_match_0 = db.find_agents(status=DBStatus.QUESTIONABLE) + self.assertEqual(len(status_match_0), 0) + status_match_1 = db.find_agents(status=DBStatus.ACCEPTED) + self.assertEqual(len(status_match_1), 1) + status_match_2 = db.find_agents(status=DBStatus.REVIEW) + self.assertEqual(len(status_match_2), 2) + split_match_0 = db.find_agents(split=DBSplitType.UNSEEN) + self.assertEqual(len(split_match_0), 0) + split_match_3 = db.find_agents(split=DBSplitType.UNSET) + self.assertEqual(len(split_match_3), 3) + creator_id_match_0 = db.find_agents(creator_id="fake") + self.assertEqual(len(creator_id_match_0), 0) + creator_id_match_1 = db.find_agents(creator_id=TEST_USER_ID) + self.assertEqual(len(creator_id_match_1), 1) + + # Test base scrub doesn't scrub anything + scrub_count = db.scrub_creators() + self.assertEqual( + scrub_count, 0, "Nothing exceeded retention time, should be no scrubs" + ) + + def test_create_load_inspect_rooms(self): + """Ensure it's possible to create and load rooms""" + # Create three rooms, assert they have unique IDs but base_ids map + db = EnvDB(self.config) + BASE_NAME_1 = "bedroom" + BASE_NAME_2 = "forest" + FULL_NAME_1 = "master bedroom" + FULL_NAME_2 = "dingy bedroom" + FULL_NAME_3 = "fairy forest" + + # First room should test mostly default values + TEST_STORY_1 = "test_story_1" + TEST_DESC_1 = "test_desc_1" + room_1_id = db.create_room_entry( + name=FULL_NAME_1, + base_name=BASE_NAME_1, + description=TEST_DESC_1, + backstory=TEST_STORY_1, + ) + + # Ensure id created is correct + self.assertIsNotNone(room_1_id) + self.assertTrue( + DBRoom.is_id(room_1_id), f"Created ID {room_1_id} not DBRoom ID" + ) + self.assertFalse( + DBObject.is_id(room_1_id), f"Created ID {room_1_id} passes as DBObject ID" + ) + + # Ensure room created and matches defaults and provided + room_1 = db.get_room(room_1_id) + base_id_1 = room_1.base_id + self.assertTrue(DBRoomName.is_id(base_id_1), "Base ID not correct format") + self.assertEqual( + room_1.db_id, room_1_id, "Marked db_id differs from initially returned id" + ) + self.assertEqual(room_1.name, FULL_NAME_1) + self.assertEqual(room_1.description, TEST_DESC_1) + self.assertEqual(room_1.backstory, TEST_STORY_1) + self.assertEqual(room_1.built_occurrences, 0) + self.assertEqual(room_1.status, DBStatus.REVIEW) + self.assertIsNone(room_1.rarity) + self.assertEqual(room_1.indoor_status, DBRoomInsideType.UNKNOWN) + self.assertIsNone(room_1.size) + self.assertIsNone(room_1.creator_id) + self.assertIsNotNone(room_1.create_timestamp) + + # Ensure base room created and matches values + base_room_1 = db.get_room_name(db_id=room_1.base_id) + self.assertEqual(base_room_1.name, BASE_NAME_1) + self.assertEqual(base_room_1.db_id, room_1.base_id) + self.assertEqual(base_room_1.status, DBStatus.REVIEW) + self.assertEqual(base_room_1.split, DBSplitType.UNSET) + + # Ensure that the link exists between base and room + with Session(db.engine) as session: + session.add(base_room_1) + self.assertEqual(len(base_room_1.rooms), 1, "Should have one linked room") + session.expunge_all() + + # Should only be one room + self.assertEqual(len(db.find_rooms()), 1) + self.assertEqual(len(db.find_room_names()), 1) + + # Duplicate create should fail + with self.assertRaises(sqlalchemy.exc.IntegrityError): + room_1_id = db.create_room_entry( + name=FULL_NAME_1, + base_name=BASE_NAME_1, + description=TEST_DESC_1, + backstory=TEST_STORY_1, + ) + + # Should only be one room + self.assertEqual(len(db.find_rooms()), 1) + self.assertEqual(len(db.find_room_names()), 1) + + # Create a second room sharing the first base class + TEST_STORY_2 = "test_story_2" + TEST_DESC_2 = "test_desc_2" + room_2_id = db.create_room_entry( + name=FULL_NAME_2, + base_name=BASE_NAME_1, + description=TEST_DESC_2, + backstory=TEST_STORY_2, + ) + + # Ensure room exists now, and that the base class is correct + room_2 = db.get_room(room_2_id) + self.assertEqual( + room_2.db_id, room_2_id, "Marked db_id differs from initially returned id" + ) + self.assertEqual(room_2.name, FULL_NAME_2) + self.assertEqual(room_2.description, TEST_DESC_2) + self.assertEqual(room_2.backstory, TEST_STORY_2) + self.assertEqual(room_2.base_id, room_2.base_id) + self.assertEqual(room_2.indoor_status, DBRoomInsideType.UNKNOWN) + + # Ensure only one base class, but two rooms + self.assertEqual(len(db.find_rooms()), 2) + self.assertEqual(len(db.find_room_names()), 1) + + # Create a third room, with all custom attributes + TEST_STORY_3 = "test_story_3" + TEST_DESC_3 = "test_desc_3" + room_3_id = db.create_room_entry( + name=FULL_NAME_3, + base_name=BASE_NAME_2, + description=TEST_DESC_3, + backstory=TEST_STORY_3, + indoor_status=DBRoomInsideType.OUTSIDE, + rarity=1, + size=2, + status=DBStatus.ACCEPTED, + creator_id=TEST_USER_ID, + ) + + # Ensure id created is correct + self.assertIsNotNone(room_3_id) + self.assertTrue( + DBRoom.is_id(room_3_id), f"Created ID {room_3_id} not DBRoom ID" + ) + self.assertFalse( + DBObject.is_id(room_3_id), f"Created ID {room_3_id} passes as DBObject ID" + ) + + # Ensure that the custom attributes all work + room_3 = db.get_room(room_3_id) + base_id_3 = room_3.base_id + self.assertNotEqual(base_id_3, base_id_1) + self.assertTrue(DBRoomName.is_id(base_id_3), "Base ID not correct format") + self.assertEqual( + room_3.db_id, room_3_id, "Marked db_id differs from initially returned id" + ) + self.assertEqual(room_3.name, FULL_NAME_3) + self.assertEqual(room_3.description, TEST_DESC_3) + self.assertEqual(room_3.backstory, TEST_STORY_3) + self.assertEqual(room_3.built_occurrences, 0) + self.assertEqual(room_3.status, DBStatus.ACCEPTED) + self.assertEqual(room_3.rarity, 1) + self.assertEqual(room_3.size, 2) + self.assertEqual(room_3.indoor_status, DBRoomInsideType.OUTSIDE) + self.assertEqual(room_3.creator_id, TEST_USER_ID) + self.assertIsNotNone(room_3.create_timestamp) + + # Ensure base room created and matches values + base_room_2 = db.get_room_name(db_id=room_3.base_id) + self.assertEqual(base_room_2.name, BASE_NAME_2) + self.assertEqual(base_room_2.db_id, room_3.base_id) + self.assertEqual(base_room_2.status, DBStatus.REVIEW) + self.assertEqual(base_room_2.split, DBSplitType.UNSET) + + # Ensure two base classes, and three rooms + self.assertEqual(len(db.find_rooms()), 3) + self.assertEqual(len(db.find_room_names()), 2) + + base_room_1 = db.get_room_name(db_id=room_1.base_id) + # Ensure the base classes properly link to the rooms + with Session(db.engine) as session: + session.add(base_room_1) + self.assertEqual( + len(base_room_1.rooms), 2, "Base 1 Should have two linked rooms" + ) + session.add(base_room_2) + self.assertEqual( + len(base_room_2.rooms), 1, "Base 2 Should have one linked room" + ) + session.expunge_all() + + # Ensure that all rooms base names are present when in session + with Session(db.engine) as session: + session.add(room_1) + session.add(room_2) + session.add(room_3) + self.assertEqual(room_1.base_name.name, room_2.base_name.name) + self.assertNotEqual(room_1.base_name.name, room_3.base_name.name) + + # assert that getting room names fail on all invalid cases + with self.assertRaises(AssertionError): + base_room_1 = db.get_room_name(db_id="FAK-fake") + with self.assertRaises(KeyError): + base_room_1 = db.get_room_name(db_id="RMN-fake") + with self.assertRaises(KeyError): + base_room_1 = db.get_room_name(name="fake") + with self.assertRaises(KeyError): + base_room_1 = db.get_room_name( + db_id=room_1.base_id, status=DBStatus.ACCEPTED + ) + with self.assertRaises(KeyError): + base_room_1 = db.get_room_name( + db_id=room_1.base_id, split=DBSplitType.TRAIN + ) + + # Advanced room name searches + matched_status = db.find_room_names(status=DBStatus.REVIEW) + self.assertEqual(len(matched_status), 2) + unmatched_status = db.find_room_names(status=DBStatus.ACCEPTED) + self.assertEqual(len(unmatched_status), 0) + matched_split = db.find_room_names(split=DBSplitType.UNSET) + self.assertEqual(len(matched_split), 2) + unmatched_split = db.find_room_names(split=DBSplitType.TRAIN) + self.assertEqual(len(unmatched_split), 0) + name_exact_match = db.find_room_names(name=BASE_NAME_1) + self.assertEqual(len(name_exact_match), 1) + name_partial_match_1 = db.find_room_names(name="bed") + self.assertEqual(len(name_partial_match_1), 1) + name_partial_match_2 = db.find_room_names(name="r") + self.assertEqual(len(name_partial_match_2), 2) + name_no_match = db.find_room_names(name="zzz") + self.assertEqual(len(name_no_match), 0) + + # Advanced room searches + base_id_match_0 = db.find_rooms(base_id="AGN-fake") + self.assertEqual(len(base_id_match_0), 0) + base_id_match_1 = db.find_rooms(base_id=base_room_2.db_id) + self.assertEqual(len(base_id_match_1), 1) + base_id_match_2 = db.find_rooms(base_id=base_room_1.db_id) + self.assertEqual(len(base_id_match_2), 2) + name_exact_match = db.find_rooms(name=FULL_NAME_1) + self.assertEqual(len(name_exact_match), 1) + name_match_0 = db.find_rooms(name="zzzzz") + self.assertEqual(len(name_match_0), 0) + name_match_1 = db.find_rooms(name="dingy") + self.assertEqual(len(name_match_1), 1) + name_match_2 = db.find_rooms(name="bed") + self.assertEqual(len(name_match_2), 2) + backstory_exact_match = db.find_rooms(backstory=TEST_STORY_3) + self.assertEqual(len(backstory_exact_match), 1) + backstory_match_0 = db.find_rooms(backstory="zzz") + self.assertEqual(len(backstory_match_0), 0) + backstory_match_1 = db.find_rooms(backstory="3") + self.assertEqual(len(backstory_match_1), 1) + backstory_match_3 = db.find_rooms(backstory="test") + self.assertEqual(len(backstory_match_3), 3) + description_exact_match = db.find_rooms(description=TEST_DESC_1) + self.assertEqual(len(description_exact_match), 1) + description_match_0 = db.find_rooms(description="zzz") + self.assertEqual(len(description_match_0), 0) + description_match_1 = db.find_rooms(description="3") + self.assertEqual(len(description_match_1), 1) + description_match_3 = db.find_rooms(description="test") + self.assertEqual(len(description_match_3), 3) + status_match_0 = db.find_rooms(status=DBStatus.QUESTIONABLE) + self.assertEqual(len(status_match_0), 0) + status_match_1 = db.find_rooms(status=DBStatus.ACCEPTED) + self.assertEqual(len(status_match_1), 1) + status_match_2 = db.find_rooms(status=DBStatus.REVIEW) + self.assertEqual(len(status_match_2), 2) + split_match_0 = db.find_rooms(split=DBSplitType.UNSEEN) + self.assertEqual(len(split_match_0), 0) + split_match_3 = db.find_rooms(split=DBSplitType.UNSET) + self.assertEqual(len(split_match_3), 3) + indoor_status_match_0 = db.find_rooms(indoor_status=DBRoomInsideType.MULTI_ROOM) + self.assertEqual(len(indoor_status_match_0), 0) + indoor_status_match_1 = db.find_rooms(indoor_status=DBRoomInsideType.OUTSIDE) + self.assertEqual(len(indoor_status_match_1), 1) + indoor_status_match_2 = db.find_rooms(indoor_status=DBRoomInsideType.UNKNOWN) + self.assertEqual(len(indoor_status_match_2), 2) + creator_id_match_0 = db.find_rooms(creator_id="fake") + self.assertEqual(len(creator_id_match_0), 0) + creator_id_match_1 = db.find_rooms(creator_id=TEST_USER_ID) + self.assertEqual(len(creator_id_match_1), 1) + + # Ensure duplicating works, but with creator IDs scrubbed + new_db = db.export(self.config_2) + self.assertEqual(len(new_db.find_rooms()), 3) + self.assertEqual(len(new_db.find_room_names()), 2) + self.assertEqual(len(new_db.find_rooms(creator_id=TEST_USER_ID)), 0) + self.assertEqual(len(new_db.find_rooms(creator_id=SCRUBBED_USER_ID)), 1) + + def test_create_load_inspect_objects(self): + """Ensure it's possible to create and load objects""" + # Create three objects, assert they have unique IDs but base_ids map + db = EnvDB(self.config) + BASE_NAME_1 = "ball" + BASE_NAME_2 = "shovel" + FULL_NAME_1 = "azure ball" + FULL_NAME_2 = "metal ball" + FULL_NAME_3 = "garden shovel" + + # First object should test mostly default values + TEST_DESC_1 = "test_desc_1" + object_1_id = db.create_object_entry( + name=FULL_NAME_1, + base_name=BASE_NAME_1, + physical_description=TEST_DESC_1, + is_container=0, + is_drink=0, + is_food=0, + is_gettable=1, + is_surface=0, + is_wearable=0, + is_weapon=0, + ) + + # Ensure id created is correct + self.assertIsNotNone(object_1_id) + self.assertTrue( + DBObject.is_id(object_1_id), f"Created ID {object_1_id} not DBObject ID" + ) + self.assertFalse( + DBRoom.is_id(object_1_id), f"Created ID {object_1_id} passes as DBRoom ID" + ) + + # Ensure object created and matches defaults and provided + object_1 = db.get_object(object_1_id) + base_id_1 = object_1.base_id + self.assertTrue(DBObjectName.is_id(base_id_1), "Base ID not correct format") + self.assertEqual( + object_1.db_id, + object_1_id, + "Marked db_id differs from initially returned id", + ) + self.assertEqual(object_1.name, FULL_NAME_1) + self.assertEqual(object_1.physical_description, TEST_DESC_1) + self.assertEqual(object_1.is_container, 0) + self.assertEqual(object_1.is_drink, 0) + self.assertEqual(object_1.is_food, 0) + self.assertEqual(object_1.is_gettable, 1) + self.assertEqual(object_1.is_surface, 0) + self.assertEqual(object_1.is_wearable, 0) + self.assertEqual(object_1.is_weapon, 0) + self.assertEqual(object_1.built_occurrences, 0) + self.assertEqual(object_1.name_prefix, "an") + self.assertEqual(object_1.status, DBStatus.REVIEW) + self.assertIsNone(object_1.is_plural) + self.assertIsNone(object_1.size) + self.assertIsNone(object_1.contain_size) + self.assertIsNone(object_1.value) + self.assertIsNone(object_1.rarity) + self.assertIsNone(object_1.creator_id) + self.assertIsNotNone(object_1.create_timestamp) + + # Ensure base object created and matches values + base_object_1 = db.get_object_name(db_id=object_1.base_id) + self.assertEqual(base_object_1.name, BASE_NAME_1) + self.assertEqual(base_object_1.db_id, object_1.base_id) + self.assertEqual(base_object_1.status, DBStatus.REVIEW) + self.assertEqual(base_object_1.split, DBSplitType.UNSET) + + # Ensure that the link exists between base and object + with Session(db.engine) as session: + session.add(base_object_1) + self.assertEqual( + len(base_object_1.objects), 1, "Should have one linked object" + ) + session.expunge_all() + + # Should only be one object + self.assertEqual(len(db.find_objects()), 1) + self.assertEqual(len(db.find_object_names()), 1) + + # Duplicate create should fail + with self.assertRaises(sqlalchemy.exc.IntegrityError): + object_1_id = db.create_object_entry( + name=FULL_NAME_1, + base_name=BASE_NAME_1, + physical_description=TEST_DESC_1, + is_container=0, + is_drink=0, + is_food=0, + is_gettable=1, + is_surface=0, + is_wearable=0, + is_weapon=0, + ) + + # Should only be one object + self.assertEqual(len(db.find_objects()), 1) + self.assertEqual(len(db.find_object_names()), 1) + + # Create a second object sharing the first base class + TEST_DESC_2 = "test_desc_2" + object_2_id = db.create_object_entry( + name=FULL_NAME_2, + base_name=BASE_NAME_1, + physical_description=TEST_DESC_2, + is_container=0, + is_drink=0, + is_food=0, + is_gettable=1, + is_surface=0, + is_wearable=0, + is_weapon=0, + ) + + # Ensure object exists now, and that the base class is correct + object_2 = db.get_object(object_2_id) + self.assertEqual( + object_2.db_id, + object_2_id, + "Marked db_id differs from initially returned id", + ) + self.assertEqual(object_2.name, FULL_NAME_2) + self.assertEqual(object_2.physical_description, TEST_DESC_2) + self.assertEqual(object_2.base_id, object_2.base_id) + self.assertEqual(object_2.name_prefix, "a") + + # Ensure only one base class, but two objects + self.assertEqual(len(db.find_objects()), 2) + self.assertEqual(len(db.find_object_names()), 1) + + # Create a third object, with all custom attributes + TEST_DESC_3 = "test_desc_3" + object_3_id = db.create_object_entry( + name=FULL_NAME_3, + base_name=BASE_NAME_2, + physical_description=TEST_DESC_3, + is_container=0, + is_drink=0, + is_food=0, + is_gettable=1, + is_surface=0, + is_wearable=0, + is_weapon=0, + name_prefix="hello", + is_plural=True, + size=1, + contain_size=2, + value=3, + rarity=4, + status=DBStatus.ACCEPTED, + creator_id=TEST_USER_ID, + ) + + # Ensure id created is correct + self.assertIsNotNone(object_3_id) + self.assertTrue( + DBObject.is_id(object_3_id), f"Created ID {object_3_id} not DBObject ID" + ) + self.assertFalse( + DBRoom.is_id(object_3_id), f"Created ID {object_3_id} passes as DBRoom ID" + ) + + # Ensure that the custom attributes all work + object_3 = db.get_object(object_3_id) + base_id_3 = object_3.base_id + self.assertNotEqual(base_id_3, base_id_1) + self.assertTrue(DBObjectName.is_id(base_id_3), "Base ID not correct format") + self.assertEqual( + object_3.db_id, + object_3_id, + "Marked db_id differs from initially returned id", + ) + self.assertEqual(object_3.name, FULL_NAME_3) + self.assertEqual(object_3.physical_description, TEST_DESC_3) + self.assertEqual(object_3.built_occurrences, 0) + self.assertEqual(object_3.is_container, 0) + self.assertEqual(object_3.is_drink, 0) + self.assertEqual(object_3.is_food, 0) + self.assertEqual(object_3.is_gettable, 1) + self.assertEqual(object_3.is_surface, 0) + self.assertEqual(object_3.is_wearable, 0) + self.assertEqual(object_3.is_weapon, 0) + self.assertEqual(object_3.built_occurrences, 0) + self.assertEqual(object_3.name_prefix, "hello") + self.assertEqual(object_3.status, DBStatus.ACCEPTED) + self.assertEqual(object_3.is_plural, True) + self.assertEqual(object_3.size, 1) + self.assertEqual(object_3.contain_size, 2) + self.assertEqual(object_3.value, 3) + self.assertEqual(object_3.rarity, 4) + self.assertEqual(object_3.creator_id, TEST_USER_ID) + self.assertIsNotNone(object_3.create_timestamp) + + # Ensure base object created and matches values + base_object_2 = db.get_object_name(db_id=object_3.base_id) + self.assertEqual(base_object_2.name, BASE_NAME_2) + self.assertEqual(base_object_2.db_id, object_3.base_id) + self.assertEqual(base_object_2.status, DBStatus.REVIEW) + self.assertEqual(base_object_2.split, DBSplitType.UNSET) + + # Ensure two base classes, and three objects + self.assertEqual(len(db.find_objects()), 3) + self.assertEqual(len(db.find_object_names()), 2) + + base_object_1 = db.get_object_name(db_id=object_1.base_id) + # Ensure the base classes properly link to the objects + with Session(db.engine) as session: + session.add(base_object_1) + self.assertEqual( + len(base_object_1.objects), 2, "Base 1 Should have two linked objects" + ) + session.add(base_object_2) + self.assertEqual( + len(base_object_2.objects), 1, "Base 2 Should have one linked object" + ) + session.expunge_all() + + # Ensure that all objects base names are present when in session + with Session(db.engine) as session: + session.add(object_1) + session.add(object_2) + session.add(object_3) + self.assertEqual(object_1.base_name.name, object_2.base_name.name) + self.assertNotEqual(object_1.base_name.name, object_3.base_name.name) + + # assert that getting object names fail on all invalid cases + with self.assertRaises(AssertionError): + base_object_1 = db.get_object_name(db_id="FAK-fake") + with self.assertRaises(KeyError): + base_object_1 = db.get_object_name(db_id="OBN-fake") + with self.assertRaises(KeyError): + base_object_1 = db.get_object_name(name="fake") + with self.assertRaises(KeyError): + base_object_1 = db.get_object_name( + db_id=object_1.base_id, status=DBStatus.ACCEPTED + ) + with self.assertRaises(KeyError): + base_object_1 = db.get_object_name( + db_id=object_1.base_id, split=DBSplitType.TRAIN + ) + + # Advanced object name searches + matched_status = db.find_object_names(status=DBStatus.REVIEW) + self.assertEqual(len(matched_status), 2) + unmatched_status = db.find_object_names(status=DBStatus.ACCEPTED) + self.assertEqual(len(unmatched_status), 0) + matched_split = db.find_object_names(split=DBSplitType.UNSET) + self.assertEqual(len(matched_split), 2) + unmatched_split = db.find_object_names(split=DBSplitType.TRAIN) + self.assertEqual(len(unmatched_split), 0) + name_exact_match = db.find_object_names(name=BASE_NAME_1) + self.assertEqual(len(name_exact_match), 1) + name_partial_match_1 = db.find_object_names(name="vel") + self.assertEqual(len(name_partial_match_1), 1) + name_partial_match_2 = db.find_object_names(name="l") + self.assertEqual(len(name_partial_match_2), 2) + name_no_match = db.find_object_names(name="zzz") + self.assertEqual(len(name_no_match), 0) + + # Advanced object searches + base_id_match_0 = db.find_objects(base_id="OBN-fake") + self.assertEqual(len(base_id_match_0), 0) + base_id_match_1 = db.find_objects(base_id=base_object_2.db_id) + self.assertEqual(len(base_id_match_1), 1) + base_id_match_2 = db.find_objects(base_id=base_object_1.db_id) + self.assertEqual(len(base_id_match_2), 2) + name_exact_match = db.find_objects(name=FULL_NAME_1) + self.assertEqual(len(name_exact_match), 1) + name_match_0 = db.find_objects(name="zzzzz") + self.assertEqual(len(name_match_0), 0) + name_match_1 = db.find_objects(name="metal") + self.assertEqual(len(name_match_1), 1) + name_match_2 = db.find_objects(name="ball") + self.assertEqual(len(name_match_2), 2) + description_exact_match = db.find_objects(physical_description=TEST_DESC_1) + self.assertEqual(len(description_exact_match), 1) + description_match_0 = db.find_objects(physical_description="zzz") + self.assertEqual(len(description_match_0), 0) + description_match_1 = db.find_objects(physical_description="3") + self.assertEqual(len(description_match_1), 1) + description_match_3 = db.find_objects(physical_description="test") + self.assertEqual(len(description_match_3), 3) + name_prefix_match_0 = db.find_objects(name_prefix="test") + self.assertEqual(len(name_prefix_match_0), 0) + name_prefix_match_1 = db.find_objects(name_prefix="hello") + self.assertEqual(len(name_prefix_match_1), 1) + name_prefix_match_a = db.find_objects(name_prefix="a") + self.assertEqual(len(name_prefix_match_a), 1) + name_prefix_match_an = db.find_objects(name_prefix="an") + self.assertEqual(len(name_prefix_match_an), 1) + is_plural_match_0 = db.find_objects(is_plural=False) + self.assertEqual(len(is_plural_match_0), 0) + is_plural_match_1 = db.find_objects(is_plural=True) + self.assertEqual(len(is_plural_match_1), 1) + status_match_0 = db.find_objects(status=DBStatus.QUESTIONABLE) + self.assertEqual(len(status_match_0), 0) + status_match_1 = db.find_objects(status=DBStatus.ACCEPTED) + self.assertEqual(len(status_match_1), 1) + status_match_2 = db.find_objects(status=DBStatus.REVIEW) + self.assertEqual(len(status_match_2), 2) + split_match_0 = db.find_objects(split=DBSplitType.UNSEEN) + self.assertEqual(len(split_match_0), 0) + split_match_3 = db.find_objects(split=DBSplitType.UNSET) + self.assertEqual(len(split_match_3), 3) + creator_id_match_0 = db.find_objects(creator_id="fake") + self.assertEqual(len(creator_id_match_0), 0) + creator_id_match_1 = db.find_objects(creator_id=TEST_USER_ID) + self.assertEqual(len(creator_id_match_1), 1) + is_container_match_0 = db.find_objects(is_container=True) + self.assertEqual(len(is_container_match_0), 0) + is_container_match_3 = db.find_objects(is_container=False) + self.assertEqual(len(is_container_match_3), 3) + is_drink_match_0 = db.find_objects(is_drink=True) + self.assertEqual(len(is_drink_match_0), 0) + is_drink_match_3 = db.find_objects(is_drink=False) + self.assertEqual(len(is_drink_match_3), 3) + is_food_match_0 = db.find_objects(is_food=True) + self.assertEqual(len(is_food_match_0), 0) + is_food_match_3 = db.find_objects(is_food=False) + self.assertEqual(len(is_food_match_3), 3) + is_gettable_match_0 = db.find_objects(is_gettable=False) + self.assertEqual(len(is_gettable_match_0), 0) + is_gettable_match_3 = db.find_objects(is_gettable=True) + self.assertEqual(len(is_gettable_match_3), 3) + is_surface_match_0 = db.find_objects(is_surface=True) + self.assertEqual(len(is_surface_match_0), 0) + is_surface_match_3 = db.find_objects(is_surface=False) + self.assertEqual(len(is_surface_match_3), 3) + is_wearable_match_0 = db.find_objects(is_wearable=True) + self.assertEqual(len(is_wearable_match_0), 0) + is_wearable_match_3 = db.find_objects(is_wearable=False) + self.assertEqual(len(is_wearable_match_3), 3) + is_weapon_match_0 = db.find_objects(is_weapon=True) + self.assertEqual(len(is_weapon_match_0), 0) + is_weapon_match_3 = db.find_objects(is_weapon=False) + self.assertEqual(len(is_weapon_match_3), 3) + + # Run scrub + scrub_count = db.scrub_creators(start_time=time.time() + MAX_RETENTION) + self.assertEqual(scrub_count, 1, "Should have scrubbed 1 object") + # Can't find old user IDs + match_user = db.find_objects(creator_id="USR-test_editor") + self.assertEqual(len(match_user), 0) + # Can find scrub + match_user = db.find_objects(creator_id=SCRUBBED_USER_ID) + self.assertEqual(len(match_user), 1) + + def test_create_load_edges(self): + """Ensure it's possible to create edges, and load them from DBElems""" + db = EnvDB(self.config) + + # get some things to use + agent_ids, room_ids, object_ids = self.set_up_some_nodes(db) + agent_1_id = agent_ids[0] + agent_2_id = agent_ids[1] + agent_3_id = agent_ids[2] + object_1_id = object_ids[0] + object_2_id = object_ids[1] + object_3_id = object_ids[2] + room_1_id = room_ids[0] + room_2_id = room_ids[1] + + # Create first edge + edge_1_id = db.create_edge( + parent_id=room_1_id, + child_id=agent_1_id, + edge_type=DBEdgeType.CONTAINS, + ) + self.assertTrue(DBEdge.is_id(edge_1_id)) + + # Ensure edge exists correctly + edges = db.get_edges() + self.assertEqual(len(edges), 1) + edge_1 = edges[0] + self.assertEqual(edge_1.db_id, edge_1_id) + self.assertEqual(edge_1.parent_id, room_1_id) + self.assertEqual(edge_1.child_id, agent_1_id) + self.assertEqual(edge_1.built_occurrences, 0) + self.assertEqual(edge_1.edge_type, DBEdgeType.CONTAINS) + self.assertEqual(edge_1.status, DBStatus.REVIEW) + self.assertEqual(edge_1.edge_label, "") + self.assertIsNone(edge_1.creator_id) + self.assertIsNotNone(edge_1.create_timestamp) + + # Note no duplicate edge possible + edge_1_id_2 = db.create_edge( + parent_id=room_1_id, + child_id=agent_1_id, + edge_type=DBEdgeType.CONTAINS, + ) + self.assertEqual(edge_1_id, edge_1_id_2) + edges = db.get_edges() + self.assertEqual(len(edges), 1) + edge_1 = edges[0] + + # Try expanding edge + with self.assertRaises(AssertionError): + _test_child = edge_1.child() + edge_1.expand_edge(db) + self.assertIsInstance(edge_1.child, DBAgent) + self.assertEqual(edge_1.child.db_id, agent_1_id) + + # Create more edges + edge_2_id = db.create_edge( + parent_id=room_1_id, + child_id=agent_2_id, + edge_type=DBEdgeType.CONTAINS, + ) + edge_3_id = db.create_edge( + parent_id=room_1_id, + child_id=agent_3_id, + edge_type=DBEdgeType.MAY_CONTAIN, + ) + edge_4_id = db.create_edge( + parent_id=agent_1_id, + child_id=object_2_id, + edge_type=DBEdgeType.CONTAINS, + ) + edge_5_id = db.create_edge( + parent_id=agent_1_id, + child_id=object_3_id, + edge_type=DBEdgeType.WEARING, + ) + edge_6_id = db.create_edge( + parent_id=room_1_id, + child_id=object_1_id, + edge_type=DBEdgeType.MAY_CONTAIN, + ) + edge_7_id = db.create_edge( + parent_id=agent_3_id, + child_id=object_3_id, + edge_type=DBEdgeType.MAY_WEAR, + ) + edge_8_id = db.create_edge( + parent_id=agent_2_id, + child_id=object_3_id, + edge_type=DBEdgeType.WIELDING, + ) + edge_9_id = db.create_edge( + parent_id=agent_3_id, + child_id=object_1_id, + edge_type=DBEdgeType.MAY_WIELD, + status=DBStatus.REJECTED, + ) + edge_10_id = db.create_edge( + parent_id=room_1_id, + child_id=room_2_id, + edge_type=DBEdgeType.NEIGHBOR, + edge_label="a path to", + ) + edge_11_id = db.create_edge( + parent_id=room_2_id, + child_id=room_1_id, + edge_type=DBEdgeType.MAY_BE_NEIGHBOR, + creator_id=TEST_USER_ID, + ) + edge_12_id = db.create_edge( + parent_id=object_1_id, + child_id=object_2_id, + edge_type=DBEdgeType.MAY_CONTAIN, + ) + + # Try expanding other edges + edge_2 = db.get_edges(parent_id=room_1_id, child_id=room_2_id)[0] + edge_2.expand_edge(db) + self.assertIsInstance(edge_2.child, DBRoom) + self.assertEqual(edge_2.child.db_id, room_2_id) + edge_3 = db.get_edges(parent_id=room_1_id, child_id=object_1_id)[0] + edge_3.expand_edge(db) + self.assertIsInstance(edge_3.child, DBObject) + self.assertEqual(edge_3.child.db_id, object_1_id) + + # Query the edges + edges = db.get_edges() + self.assertEqual(len(edges), 12) + no_matching_pair = db.get_edges(parent_id=room_1_id, child_id=object_3_id) + self.assertEqual(len(no_matching_pair), 0) + no_matching_type = db.get_edges( + parent_id=room_1_id, + child_id=object_1_id, + edge_type=DBEdgeType.MAY_BE_NEIGHBOR, + ) + self.assertEqual(len(no_matching_type), 0) + room_1_edges = db.get_edges(parent_id=room_1_id) + self.assertEqual(len(room_1_edges), 5) + agent_1_edges = db.get_edges(parent_id=agent_1_id) + self.assertEqual(len(agent_1_edges), 2) + object_1_edges = db.get_edges(parent_id=object_1_id) + self.assertEqual(len(object_1_edges), 1) + neighbor_edges = db.get_edges(edge_type=DBEdgeType.NEIGHBOR) + self.assertEqual(len(neighbor_edges), 1) + contains_edges = db.get_edges(edge_type=DBEdgeType.CONTAINS) + self.assertEqual(len(contains_edges), 3) + matching_edge_label = db.get_edges(edge_label="") + self.assertEqual(len(matching_edge_label), 11) + special_edge_label = db.get_edges(edge_label="a path to") + self.assertEqual(len(special_edge_label), 1) + no_matching_edge_label = db.get_edges(edge_label="zzzzzz") + self.assertEqual(len(no_matching_edge_label), 0) + matching_status = db.get_edges(status=DBStatus.REVIEW) + self.assertEqual(len(matching_status), 11) + special_matching_status = db.get_edges(status=DBStatus.REJECTED) + self.assertEqual(len(special_matching_status), 1) + no_matching_status = db.get_edges(status=DBStatus.ACCEPTED) + self.assertEqual(len(no_matching_status), 0) + + # Test edge strength filtering + room_1 = db.get_room(room_1_id) + edge_2 = db.get_edges(parent_id=room_1_id, child_id=agent_2_id)[0] + with Session(db.engine) as session: + session.add(room_1) + session.add(edge_1) + session.add(edge_2) + room_1.built_occurrences = 3 + edge_1.built_occurrences = 1 + edge_2.built_occurrences = 2 + session.flush() + session.commit() + session.expunge_all() + more_than_quarter = db.get_edges(min_strength=0.25) + self.assertEqual(len(more_than_quarter), 2) + more_than_half = db.get_edges(min_strength=0.5) + self.assertEqual(len(more_than_half), 1) + more_than_top = db.get_edges(min_strength=0.75) + self.assertEqual(len(more_than_top), 0) + + # Create first text edge + text_edge_1_id = db.create_text_edge( + parent_id=room_1_id, + child_text="unknown object", + edge_type=DBEdgeType.MAY_CONTAIN, + ) + self.assertTrue(DBTextEdge.is_id(text_edge_1_id)) + + # Ensure edge exists correctly + text_edges = db.get_text_edges() + self.assertEqual(len(text_edges), 1) + text_edge_1 = text_edges[0] + self.assertEqual(text_edge_1.db_id, text_edge_1_id) + self.assertEqual(text_edge_1.parent_id, room_1_id) + self.assertEqual(text_edge_1.child_text, "unknown object") + self.assertEqual(text_edge_1.edge_type, DBEdgeType.MAY_CONTAIN) + self.assertEqual(text_edge_1.status, DBStatus.REVIEW) + self.assertEqual(text_edge_1.edge_label, "") + self.assertIsNone(text_edge_1.creator_id) + self.assertIsNotNone(text_edge_1.create_timestamp) + + # Note no duplicate edge possible + text_edge_1_id_2 = db.create_text_edge( + parent_id=room_1_id, + child_text="unknown object", + edge_type=DBEdgeType.MAY_CONTAIN, + ) + self.assertEqual(text_edge_1_id, text_edge_1_id_2) + text_edges = db.get_text_edges() + self.assertEqual(len(text_edges), 1) + + # More text edges + text_edge_2_id = db.create_text_edge( + parent_id=agent_1_id, + child_text="unknown room", + edge_type=DBEdgeType.MAY_BE_CONTAINED_IN, + ) + text_edge_3_id = db.create_text_edge( + parent_id=object_1_id, + child_text="unknown agent", + edge_type=DBEdgeType.CONTAINED_IN, + ) + text_edge_4_id = db.create_text_edge( + parent_id=room_1_id, + child_text="unknown room", + edge_type=DBEdgeType.MAY_BE_NEIGHBOR, + edge_label="a path to", + creator_id=TEST_USER_ID, + status=DBStatus.ACCEPTED, + ) + + # Query text edges + text_edges = db.get_text_edges() + self.assertEqual(len(text_edges), 4) + text_no_matching_pair = db.get_text_edges( + parent_id=object_1_id, child_text="unknown room" + ) + self.assertEqual(len(text_no_matching_pair), 0) + text_no_matching_parent = db.get_text_edges(parent_id=agent_2_id) + self.assertEqual(len(text_no_matching_parent), 0) + text_no_matching_child = db.get_text_edges(child_text="something random") + self.assertEqual(len(text_no_matching_child), 0) + text_matching_child = db.get_text_edges(child_text="unknown room") + self.assertEqual(len(text_matching_child), 2) + text_matching_parent = db.get_text_edges(parent_id=room_1_id) + self.assertEqual(len(text_matching_parent), 2) + text_matching_type = db.get_text_edges(edge_type=DBEdgeType.CONTAINED_IN) + self.assertEqual(len(text_matching_type), 1) + text_no_matching_type = db.get_text_edges(edge_type=DBEdgeType.NEIGHBOR) + self.assertEqual(len(text_no_matching_type), 0) + text_matching_status = db.get_text_edges(status=DBStatus.REVIEW) + self.assertEqual(len(text_matching_status), 3) + text_special_status = db.get_text_edges(status=DBStatus.ACCEPTED) + self.assertEqual(len(text_special_status), 1) + text_no_matching_status = db.get_text_edges(status=DBStatus.REJECTED) + self.assertEqual(len(text_no_matching_status), 0) + text_matching_label = db.get_text_edges(edge_label="") + self.assertEqual(len(text_matching_label), 3) + text_special_label = db.get_text_edges(edge_label="a path to") + self.assertEqual(len(text_special_label), 1) + text_no_matching_label = db.get_text_edges(edge_label="zzzzzz") + self.assertEqual(len(text_no_matching_label), 0) + + # Query edges for DBElems + room_1 = db.get_room(room_1_id) + agent_1 = db.get_agent(agent_1_id) + agent_2 = db.get_agent(agent_2_id) + object_1 = db.get_object(object_1_id) + + # Try expanding edge + # All edges fail when not loading first + with self.assertRaises(AssertionError): + _test_text_edges = room_1.text_edges + with self.assertRaises(AssertionError): + _test_text_edges = agent_1.text_edges + with self.assertRaises(AssertionError): + _test_text_edges = agent_2.text_edges + with self.assertRaises(AssertionError): + _test_text_edges = object_1.text_edges + with self.assertRaises(AssertionError): + _test_node_edges = room_1.node_edges + with self.assertRaises(AssertionError): + _test_node_edges = agent_1.node_edges + with self.assertRaises(AssertionError): + _test_node_edges = agent_2.node_edges + with self.assertRaises(AssertionError): + _test_node_edges = object_1.node_edges + + room_1.load_edges(db) + agent_1.load_edges(db) + agent_2.load_edges(db) + object_1.load_edges(db) + + text_edges = db.get_text_edges() + self.assertEqual(len(text_edges), 4) + + self.assertEqual(len(room_1.node_edges), 5) + self.assertEqual(len(room_1.text_edges), 2) + self.assertEqual(len(agent_1.node_edges), 2) + self.assertEqual(len(agent_1.text_edges), 1) + self.assertEqual(len(agent_2.node_edges), 1) + self.assertEqual(len(agent_2.text_edges), 0) + self.assertEqual(len(object_1.node_edges), 1) + self.assertEqual(len(object_1.text_edges), 1) + + # Ensure that each of the edges is valid + for node in [room_1, agent_1, agent_2, object_1]: + for node_edge in node.node_edges: + self.assertEqual(node_edge.child.db_id, node_edge.child_id) + for text_edge in node.text_edges: + self.assertIsNotNone(text_edge.child_text) + + # Try creating the cache and reloading from that state + db.create_node_cache() + + # Query edges for DBElems + room_1 = db.get_room(room_1_id) + agent_1 = db.get_agent(agent_1_id) + agent_2 = db.get_agent(agent_2_id) + object_1 = db.get_object(object_1_id) + + # Cached edges can be directly accessed + self.assertEqual(len(room_1.node_edges), 5) + self.assertEqual(len(room_1.text_edges), 2) + self.assertEqual(len(agent_1.node_edges), 2) + self.assertEqual(len(agent_1.text_edges), 1) + self.assertEqual(len(agent_2.node_edges), 1) + self.assertEqual(len(agent_2.text_edges), 0) + self.assertEqual(len(object_1.node_edges), 1) + self.assertEqual(len(object_1.text_edges), 1) + + # Ensure that each of the edges is valid + for node in [room_1, agent_1, agent_2, object_1]: + for node_edge in node.node_edges: + self.assertEqual(node_edge.child.db_id, node_edge.child_id) + for text_edge in node.text_edges: + self.assertIsNotNone(text_edge.child_text) + + def test_arbitrary_attributes(self): + """Ensure the arbitrary attributes are created properly""" + db = EnvDB(self.config) + + # get some things to use + agent_ids, room_ids, object_ids = self.set_up_some_nodes(db) + agent_1_id = agent_ids[0] + object_1_id = object_ids[0] + room_1_id = room_ids[0] + + # create first attribute + attribute_1_id = db.create_arbitrary_attribute( + target_id=agent_1_id, + attribute_name="tested", + attribute_value_string="true", + ) + + self.assertTrue(DBNodeAttribute.is_id(attribute_1_id)) + attributes = db.get_attributes(target_id=agent_1_id) + self.assertEqual(len(attributes), 1) + attribute_1 = attributes[0] + + # Make sure it looks right + self.assertEqual(attribute_1.db_id, attribute_1_id) + self.assertEqual(attribute_1.target_id, agent_1_id) + self.assertEqual(attribute_1.attribute_name, "tested") + self.assertEqual(attribute_1.attribute_value_string, "true") + self.assertEqual(attribute_1.status, DBStatus.REVIEW) + self.assertIsNone(attribute_1.creator_id, agent_1_id) + + # Ensure we can't duplicate + attribute_1_id_2 = db.create_arbitrary_attribute( + target_id=agent_1_id, + attribute_name="tested", + attribute_value_string="true", + ) + self.assertEqual(attribute_1_id, attribute_1_id_2) + attributes = db.get_attributes(target_id=agent_1_id) + self.assertEqual(len(attributes), 1) + + # Create more of them + attribute_2_id = db.create_arbitrary_attribute( + target_id=object_1_id, + attribute_name="tested", + attribute_value_string="true", + ) + attribute_3_id = db.create_arbitrary_attribute( + target_id=room_1_id, + attribute_name="tested", + attribute_value_string="true", + status=DBStatus.ACCEPTED, + creator_id=TEST_USER_ID, + ) + attribute_4_id = db.create_arbitrary_attribute( + target_id=agent_1_id, + attribute_name="tried", + attribute_value_string="false", + ) + attributes = db.get_attributes() + self.assertEqual(len(attributes), 4) + + # Query for arbitrary attributes + target_matches = db.get_attributes(target_id=agent_1_id) + self.assertEqual(len(target_matches), 2) + target_no_match = db.get_attributes(target_id="RME-fake") + self.assertEqual(len(target_no_match), 0) + attribute_match_3 = db.get_attributes(attribute_name="tested") + self.assertEqual(len(attribute_match_3), 3) + attribute_match_1 = db.get_attributes(attribute_name="tried") + self.assertEqual(len(attribute_match_1), 1) + attribute_match_0 = db.get_attributes(attribute_name="zzzzz") + self.assertEqual(len(attribute_match_0), 0) + value_match_3 = db.get_attributes(attribute_value_string="true") + self.assertEqual(len(value_match_3), 3) + value_match_1 = db.get_attributes(attribute_value_string="false") + self.assertEqual(len(value_match_1), 1) + value_match_0 = db.get_attributes(attribute_value_string="zzzzz") + self.assertEqual(len(value_match_0), 0) + status_match_3 = db.get_attributes(status=DBStatus.REVIEW) + self.assertEqual(len(status_match_3), 3) + status_match_1 = db.get_attributes(status=DBStatus.ACCEPTED) + self.assertEqual(len(status_match_1), 1) + status_match_0 = db.get_attributes(status=DBStatus.REJECTED) + self.assertEqual(len(status_match_0), 0) + creator_match_1 = db.get_attributes(creator_id=TEST_USER_ID) + self.assertEqual(len(creator_match_1), 1) + creator_match_0 = db.get_attributes(creator_id="zzzz") + self.assertEqual(len(creator_match_0), 0) + + # see if we can load the attributes from the elem + room_1 = db.get_room(room_1_id) + agent_1 = db.get_agent(agent_1_id) + object_1 = db.get_object(object_1_id) + + # Try expanding attributes + # All attributes fail when not loading first + with self.assertRaises(AssertionError): + _test_attributes = room_1.attributes + with self.assertRaises(AssertionError): + _test_attributes = agent_1.attributes + with self.assertRaises(AssertionError): + _test_attributes = object_1.attributes + + room_1.load_attributes(db) + agent_1.load_attributes(db) + object_1.load_attributes(db) + + self.assertEqual(len(room_1.attributes), 1) + self.assertEqual(len(agent_1.attributes), 2) + self.assertEqual(len(object_1.attributes), 1) + + # Ensure that each of the attributes is valid + for node in [room_1, agent_1, object_1]: + for attribute in node.attributes: + self.assertIsNotNone(attribute.attribute_value_string) + self.assertEqual(attribute.target_id, node.db_id) + + # Try creating the cache and reloading from that state + db.create_node_cache() + + # Query edges for DBElems + room_1 = db.get_room(room_1_id) + agent_1 = db.get_agent(agent_1_id) + object_1 = db.get_object(object_1_id) + + # Cached attributes should load no problem + self.assertEqual(len(room_1.attributes), 1) + self.assertEqual(len(agent_1.attributes), 2) + self.assertEqual(len(object_1.attributes), 1) + + # Ensure that each of the attributes is valid + for node in [room_1, agent_1, object_1]: + for attribute in node.attributes: + self.assertIsNotNone(attribute.attribute_value_string) + self.assertEqual(attribute.target_id, node.db_id) + + def test_create_load_edits(self): + """Ensure it's possible to create, load, and reject edits""" + db = EnvDB(self.config) + + # get some things to use + agent_ids, room_ids, object_ids = self.set_up_some_nodes(db) + agent_1_id = agent_ids[0] + + # Create an edit + edit_1_id = db.create_edit( + editor_id="USR-test_editor", + node_id=agent_1_id, + field="persona", + old_value="agent_persona", + new_value="edited_agent_persona", + ) + self.assertTrue(DBEdit.is_id(edit_1_id)) + + # load the edit + edits = db.get_edits() + self.assertEqual(len(edits), 1) + edit_1 = edits[0] + + # Assert fields are set + self.assertEqual(edit_1.db_id, edit_1_id) + self.assertEqual(edit_1.editor_id, "USR-test_editor") + self.assertEqual(edit_1.node_id, agent_1_id) + self.assertEqual(edit_1.field, "persona") + self.assertEqual(edit_1.old_value, "agent_persona") + self.assertEqual(edit_1.new_value, "edited_agent_persona") + self.assertEqual(edit_1.status, DBStatus.REVIEW) + self.assertIsNotNone(edit_1.create_timestamp) + + # reject the edit + edit_1.reject_edit(db) + edits = db.get_edits() + self.assertEqual(len(edits), 1) + edit_1 = edits[0] + self.assertEqual(edit_1.status, DBStatus.REJECTED) + + # create two more edits + edit_2_id = db.create_edit( + editor_id="USR-test_editor", + node_id=agent_1_id, + field="name", + old_value="test_agent_1", + new_value="test_agent_0", + status=DBStatus.QUESTIONABLE, + ) + edit_3_id = db.create_edit( + editor_id="ADMIN", + node_id=agent_1_id, + field="persona", + old_value="agent_persona", + new_value="edited_agent_persona_2", + ) + edits = db.get_edits() + self.assertEqual(len(edits), 3) + + # query the various edits + match_editor_2 = db.get_edits(editor_id="USR-test_editor") + self.assertEqual(len(match_editor_2), 2) + match_editor_1 = db.get_edits(editor_id="ADMIN") + self.assertEqual(len(match_editor_1), 1) + match_editor_0 = db.get_edits(editor_id="USR-test_editor_2") + self.assertEqual(len(match_editor_0), 0) + match_node_id_3 = db.get_edits(node_id=agent_1_id) + self.assertEqual(len(match_node_id_3), 3) + match_node_id_0 = db.get_edits(node_id="test") + self.assertEqual(len(match_node_id_0), 0) + match_field_2 = db.get_edits(field="persona") + self.assertEqual(len(match_field_2), 2) + match_field_1 = db.get_edits(field="name") + self.assertEqual(len(match_field_1), 1) + match_field_0 = db.get_edits(field="physical_description") + self.assertEqual(len(match_field_0), 0) + match_old_value_2 = db.get_edits(old_value="agent_persona") + self.assertEqual(len(match_old_value_2), 2) + match_old_value_1 = db.get_edits(old_value="test_agent_1") + self.assertEqual(len(match_old_value_1), 1) + match_old_value_0 = db.get_edits(old_value="zzzzz") + self.assertEqual(len(match_old_value_0), 0) + match_new_value = db.get_edits(new_value="test_agent_0") + self.assertEqual(len(match_new_value), 1) + no_match_new_value = db.get_edits(new_value="zzzzz") + self.assertEqual(len(no_match_new_value), 0) + match_status_standard = db.get_edits(status=DBStatus.REVIEW) + self.assertEqual(len(match_status_standard), 1) + match_status_reject = db.get_edits(status=DBStatus.REJECTED) + self.assertEqual(len(match_status_reject), 1) + match_status_special = db.get_edits(status=DBStatus.QUESTIONABLE) + self.assertEqual(len(match_status_special), 1) + match_status_0 = db.get_edits(status=DBStatus.ACCEPTED) + self.assertEqual(len(match_status_0), 0) + + # TODO accept an edit + + # Run scrub + scrub_count = db.scrub_creators(start_time=time.time() + MAX_RETENTION) + self.assertEqual(scrub_count, 2, "Should have scrubbed 2 edits") + # Can't find old user IDs + match_user = db.get_edits(editor_id="USR-test_editor") + self.assertEqual(len(match_user), 0) + # Can find special creator IDs + match_user = db.get_edits(editor_id="ADMIN") + self.assertEqual(len(match_user), 1) + # Can find scrub + match_user = db.get_edits(editor_id=SCRUBBED_USER_ID) + self.assertEqual(len(match_user), 2) + + def test_create_load_flags(self): + """Ensure it's possible to create and load flags""" + db = EnvDB(self.config) + + # get some things to use + agent_ids, room_ids, object_ids = self.set_up_some_nodes(db) + agent_1_id = agent_ids[0] + + # Create a flag + flag_1_id = db.flag_entry( + user_id="USR-flagger_id", + flag_type=DBFlagTargetType.FLAG_USER, + target_id="bad_user", + reason="some_reason", + ) + self.assertTrue(DBFlag.is_id(flag_1_id)) + + # load the flag + flags = db.get_flags() + self.assertEqual(len(flags), 1) + flag_1 = flags[0] + self.assertEqual(flag_1.db_id, flag_1_id) + self.assertEqual(flag_1.user_id, "USR-flagger_id") + self.assertEqual(flag_1.flag_type, DBFlagTargetType.FLAG_USER) + self.assertEqual(flag_1.target_id, "bad_user") + self.assertEqual(flag_1.reason, "some_reason") + self.assertEqual(flag_1.status, DBStatus.REVIEW) + self.assertIsNotNone(flag_1.create_timestamp) + + # create two more flags + flag_2_id = db.flag_entry( + user_id="USR-flagger_id", + flag_type=DBFlagTargetType.FLAG_ENVIRONMENT, + target_id=agent_ids[0], + reason="some_other_reason", + ) + flag_3_id = db.flag_entry( + user_id="USR-flagger_id", + flag_type=DBFlagTargetType.FLAG_UTTERANCE, + target_id="model_id", + reason="some_reason", + status=DBStatus.ACCEPTED, + ) + flags = db.get_flags() + self.assertEqual(len(flags), 3) + + # query the various flags + match_user = db.get_flags(user_id="USR-flagger_id") + self.assertEqual(len(match_user), 3) + no_match_user = db.get_flags(user_id="random_id") + self.assertEqual(len(no_match_user), 0) + match_type_env = db.get_flags(flag_type=DBFlagTargetType.FLAG_ENVIRONMENT) + self.assertEqual(len(match_type_env), 1) + match_type_utt = db.get_flags(flag_type=DBFlagTargetType.FLAG_UTTERANCE) + self.assertEqual(len(match_type_utt), 1) + match_type_user = db.get_flags(flag_type=DBFlagTargetType.FLAG_USER) + self.assertEqual(len(match_type_user), 1) + match_target = db.get_flags(target_id=agent_ids[0]) + self.assertEqual(len(match_target), 1) + no_match_target = db.get_flags(target_id=agent_ids[1]) + self.assertEqual(len(no_match_target), 0) + match_reason = db.get_flags(reason="some_reason") + self.assertEqual(len(match_reason), 2) + no_match_reason = db.get_flags(reason="fake_reason") + self.assertEqual(len(no_match_reason), 0) + match_status = db.get_flags(status=DBStatus.REVIEW) + self.assertEqual(len(match_status), 2) + match_other_status = db.get_flags(status=DBStatus.ACCEPTED) + self.assertEqual(len(match_other_status), 1) + no_match_status = db.get_flags(status=DBStatus.QUESTIONABLE) + self.assertEqual(len(no_match_status), 0) + + # Run duplicate, ensure flags aren't copied + new_db = db.export(self.config_2) + self.assertEqual(len(new_db.get_flags()), 0) + + # Run scrub + scrub_count = db.scrub_creators(start_time=time.time() + MAX_RETENTION) + self.assertEqual(scrub_count, 3, "Should have scrubbed 3 flags") + # Can't find old user IDs + match_user = db.get_flags(user_id="USR-flagger_id") + self.assertEqual(len(match_user), 0) + # Can find scrubbed ID + match_user = db.get_flags(user_id=SCRUBBED_USER_ID) + self.assertEqual(len(match_user), 3) + + def test_create_load_link_quests(self): + """Ensure that quests are saving and loading as expected""" + db = EnvDB(self.config) + + # get some things to use + agent_ids, room_ids, object_ids = self.set_up_some_nodes(db) + + # Create first quest + quest_1_id = db.create_quest( + agent_id=agent_ids[0], + text_motivation="top_text_motivation", + target_type=DBQuestTargetType.TEXT_ONLY, + target="", + ) + self.assertTrue(DBQuest.is_id(quest_1_id)) + + # Ensure init looks good + quests = db.find_quests() + self.assertEqual(len(quests), 1) + quest_1 = quests[0] + self.assertEqual(quest_1.db_id, quest_1_id) + self.assertEqual(quest_1.agent_id, agent_ids[0]) + self.assertEqual(quest_1.text_motivation, "top_text_motivation") + self.assertEqual(quest_1.target_type, DBQuestTargetType.TEXT_ONLY) + self.assertEqual(quest_1.target, "") + self.assertEqual(quest_1.status, DBStatus.REVIEW) + self.assertEqual(quest_1.position, 0) + self.assertIsNone(quest_1.parent_id) + self.assertIsNone(quest_1.origin_filepath) + self.assertIsNone(quest_1.creator_id) + self.assertIsNotNone(quest_1.create_timestamp) + + # Create quest tree + quest_2_id = db.create_quest( + agent_id=agent_ids[0], + text_motivation="big_text_motivation", + target_type=DBQuestTargetType.TEXT_ONLY, + target="", + parent_id=quest_1_id, + ) + quest_3_id = db.create_quest( + agent_id=agent_ids[0], + text_motivation="mid_text_motivation", + target_type=DBQuestTargetType.TEXT_ONLY, + target="", + parent_id=quest_2_id, + ) + quest_4_id = db.create_quest( + agent_id=agent_ids[0], + text_motivation="mid_text_motivation", + target_type=DBQuestTargetType.TEXT_ONLY, + target="", + parent_id=quest_2_id, + position=1, + ) + quest_5_id = db.create_quest( + agent_id=agent_ids[0], + text_motivation="low_goal_1", + target_type=DBQuestTargetType.TARGET_ACTION, + target="get thing", + parent_id=quest_3_id, + position=1, + ) + quest_6_id = db.create_quest( + agent_id=agent_ids[0], + text_motivation="low_goal_2", + target_type=DBQuestTargetType.TARGET_ACTION, + target="do something", + origin_filepath="test/file/path.json", + status=DBStatus.REJECTED, + creator_id="bad_creator", + parent_id=quest_3_id, + ) + quests = db.find_quests() + self.assertEqual(len(quests), 6) + + # Query more elements + agent_match_6 = db.find_quests(agent_id=agent_ids[0]) + self.assertEqual(len(agent_match_6), 6) + agent_match_0 = db.find_quests(agent_id=agent_ids[1]) + self.assertEqual(len(agent_match_0), 0) + motivation_match_2 = db.find_quests(text_motivation="mid_text_motivation") + self.assertEqual(len(motivation_match_2), 2) + motivation_match_1 = db.find_quests(text_motivation="low_goal_1") + self.assertEqual(len(motivation_match_1), 1) + motivation_match_0 = db.find_quests(text_motivation="fake_goal") + self.assertEqual(len(motivation_match_0), 0) + target_type_match_4 = db.find_quests(target_type=DBQuestTargetType.TEXT_ONLY) + self.assertEqual(len(target_type_match_4), 4) + target_type_match_2 = db.find_quests( + target_type=DBQuestTargetType.TARGET_ACTION + ) + self.assertEqual(len(target_type_match_2), 2) + target_match_4 = db.find_quests(target="") + self.assertEqual(len(target_match_4), 4) + target_match_1 = db.find_quests(target="get thing") + self.assertEqual(len(target_match_1), 1) + target_match_0 = db.find_quests(target="sleep") + self.assertEqual(len(target_match_0), 0) + parent_id_match_2 = db.find_quests(parent_id=quest_2_id) + self.assertEqual(len(parent_id_match_2), 2) + parent_id_match_1 = db.find_quests(parent_id=quest_1_id) + self.assertEqual(len(parent_id_match_1), 1) + parent_id_match_0 = db.find_quests(parent_id=quest_6_id) + self.assertEqual(len(parent_id_match_0), 0) + creator_id_match_1 = db.find_quests(creator_id="bad_creator") + self.assertEqual(len(creator_id_match_1), 1) + creator_id_match_0 = db.find_quests(creator_id=TEST_USER_ID) + self.assertEqual(len(creator_id_match_0), 0) + origin_filepath_match_1 = db.find_quests(origin_filepath="test/file/path.json") + self.assertEqual(len(origin_filepath_match_1), 1) + origin_filepath_match_0 = db.find_quests(origin_filepath="fake/file/path.json") + self.assertEqual(len(origin_filepath_match_0), 0) + status_match_5 = db.find_quests(status=DBStatus.REVIEW) + self.assertEqual(len(status_match_5), 5) + status_match_1 = db.find_quests(status=DBStatus.REJECTED) + self.assertEqual(len(status_match_1), 1) + status_match_0 = db.find_quests(status=DBStatus.QUESTIONABLE) + self.assertEqual(len(status_match_0), 0) + + # Check that inter-node references work + with Session(db.engine) as session: + quest_6 = session.query(DBQuest).get(quest_6_id) + session.expunge_all() + + # Loads should fail outside of session + with self.assertRaises(AssertionError): + _test_parent_chain = quest_1.parent_chain + with self.assertRaises(AssertionError): + _test_parent_chain = quest_6.parent_chain + with self.assertRaises(AssertionError): + _test_subgoals = quest_1.subgoals + with self.assertRaises(AssertionError): + _test_subgoals = quest_6.subgoals + + quest_1.load_relations(db) + quest_6.load_relations(db) + + # Subgoals are correct length + self.assertEqual(len(quest_1.subgoals), 1) + self.assertEqual(len(quest_6.subgoals), 0) + + # Parent chains are correct + self.assertEqual(len(quest_1.parent_chain), 1) + self.assertEqual(len(quest_6.parent_chain), 4) + + # Subgoals are all loaded + quest_2 = quest_1.subgoals[0] + self.assertEqual(quest_2.db_id, quest_2_id) + self.assertEqual(len(quest_2.subgoals), 2) + quest_4 = quest_2.subgoals[1] + self.assertEqual(quest_4.db_id, quest_4_id) + self.assertEqual(len(quest_4.subgoals), 0) + quest_3 = quest_2.subgoals[0] + self.assertEqual(quest_3.db_id, quest_3_id) + self.assertEqual(len(quest_3.subgoals), 2) + quest_5 = quest_3.subgoals[1] # test order swap by position + quest_6 = quest_3.subgoals[0] + self.assertEqual(quest_5.db_id, quest_5_id) + self.assertEqual(len(quest_5.subgoals), 0) + self.assertEqual(quest_6.db_id, quest_6_id) + self.assertEqual(len(quest_6.subgoals), 0) + + def test_create_load_graphs(self): + """Ensure that graph loading is functioning as expected""" + + db = EnvDB(self.config) + + # Create a test graph + test_graph_1 = OOGraph({}) + agent_node = test_graph_1.add_agent("My test agent", {}) + room_node = test_graph_1.add_room("test room", {}) + agent_node.force_move_to(room_node) + + # Save the test graph + graph_id_1 = db.save_graph(test_graph_1, creator_id="tester") + self.assertTrue(DBGraph.is_id(graph_id_1)) + self.assertEqual(test_graph_1.db_id, graph_id_1) + + # Ensure that the graph is set up as expected + graphs = db.find_graphs() + self.assertEqual(len(graphs), 1) + db_graph_1 = graphs[0] + self.assertEqual(db_graph_1.db_id, graph_id_1) + self.assertEqual(db_graph_1.graph_name, "untitled") + self.assertEqual(db_graph_1.creator_id, "tester") + self.assertTrue( + db.file_path_exists(db_graph_1.file_path), + f"Output path {db_graph_1.file_path} doesn't seem to exist in the db", + ) + self.assertEqual(db_graph_1.status, DBStatus.REVIEW) + self.assertIsNotNone(db_graph_1.create_timestamp) + + # Make changes to the graph, then re-save + room_node_2 = test_graph_1.add_room("test room 2", {}) + # Assert same graph, not new + graph_id_1_2 = db.save_graph(test_graph_1, creator_id="tester") + self.assertEqual(graph_id_1, graph_id_1_2) + graphs = db.find_graphs() + self.assertEqual(len(graphs), 1) + + # Load the graph directly + db_graph_1 = db.load_graph(graph_id_1) + + # Try to pull the underlying graph from file + oo_graph = db_graph_1.get_graph(db) + self.assertEqual( + oo_graph.to_json(), test_graph_1.to_json(), "Graphs are not equal!" + ) + + # Save a second graph, this time titled with an ID too + test_graph_2 = OOGraph({"title": "Test Graph", "db_id": "UGR-TEST"}) + agent_node = test_graph_2.add_agent("My test agent", {}) + room_node = test_graph_2.add_room("test room", {}) + agent_node.force_move_to(room_node) + + # Save the second graph + graph_id_2 = db.save_graph(test_graph_2, creator_id="tester") + self.assertTrue(DBGraph.is_id(graph_id_2)) + self.assertEqual(test_graph_2.db_id, graph_id_2) + self.assertEqual(test_graph_2.db_id, "UGR-TEST") + + # Do some queries + graphs = db.find_graphs() + self.assertEqual(len(graphs), 2) + graph_default_name_1 = db.find_graphs(graph_name="untitled") + self.assertEqual(len(graph_default_name_1), 1) + graph_custom_name_1 = db.find_graphs(graph_name="Test Graph") + self.assertEqual(len(graph_custom_name_1), 1) + graph_name_0 = db.find_graphs(graph_name="nonexisting") + self.assertEqual(len(graph_name_0), 0) + graph_creator_2 = db.find_graphs(creator_id="tester") + self.assertEqual(len(graph_creator_2), 2) + graph_creator_0 = db.find_graphs(creator_id="nonexisting") + self.assertEqual(len(graph_creator_0), 0) + + # Ensure main graph save failures + with self.assertRaises(AssertionError): + # Can't save one graph as a different creator + _graph_id_1_3 = db.save_graph(test_graph_1, creator_id="not_tester") + test_graph_1.db_id = "bad-db-id" + with self.assertRaises(AssertionError): + # Can't save a graph with an invalid DBGraph ID + _graph_id_1_3 = db.save_graph(test_graph_1, creator_id="not_tester") + + def _get_all_dbid_mixin(self): + """Pull all the dbid mixin classes""" + check_classes = [HasDBIDMixin] + has_dbid_subclasses = set() + while len(check_classes) > 0: + curr_class = check_classes.pop() + subclasses = curr_class.__subclasses__() + filtered_subclasses = [ + c for c in subclasses if c not in has_dbid_subclasses + ] + check_classes += filtered_subclasses + has_dbid_subclasses = has_dbid_subclasses.union(filtered_subclasses) + + return [dbsc for dbsc in has_dbid_subclasses if hasattr(dbsc, "ID_PREFIX")] + + def test_dbid_mixin(self): + """Ensure that all dbid mixin subclasses are valid""" + has_dbid_subclasses = self._get_all_dbid_mixin() + assert DBObject in has_dbid_subclasses + + for idx in range(len(has_dbid_subclasses)): + # Assert has correct key length + curr_class = has_dbid_subclasses[idx] + self.assertLessEqual( + len(curr_class.ID_PREFIX), + 3, + f"{curr_class} prefix {curr_class.ID_PREFIX} greater than 3 characters", + ) + # Assert creation passes self but fails others + for idx_2 in range(len(has_dbid_subclasses)): + test_id = curr_class.get_id() + if idx == idx_2: + self.assertTrue( + curr_class.is_id(test_id), + f"ID {test_id} generated by {curr_class} get_id not accepted by is_id", + ) + else: + other_class = has_dbid_subclasses[idx_2] + self.assertFalse( + other_class.is_id(test_id), + f"ID {test_id} generated by {curr_class} wrongly passes is_id of {other_class}", + ) diff --git a/light/data_model/tests/test_episode_db.py b/light/data_model/tests/test_episode_db.py new file mode 100644 index 000000000..4351c429f --- /dev/null +++ b/light/data_model/tests/test_episode_db.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +import shutil, tempfile +from omegaconf import OmegaConf +import os +import json +import time + +from light.graph.elements.graph_nodes import GraphAgent +from light.graph.structured_graph import OOGraph +from light.world.world import World, WorldConfig +from light.graph.events.graph_events import ArriveEvent, LeaveEvent, GoEvent, LookEvent +from light.world.content_loggers import AgentInteractionLogger, RoomInteractionLogger +from light.world.utils.json_utils import read_event_logs +from light.data_model.db.episodes import EpisodeDB, EpisodeLogType +from light.data_model.db.base import LightDBConfig + +TEST_USER_ID = "USR-test" + + +class TestEpisodesDB(unittest.TestCase): + """Unit tests for the EpisodeDB. Leverages Interaction Loggers to generate episodes""" + + def setUp(self): + self.maxDiff = 10000 + self.data_dir = tempfile.mkdtemp() + self.config = LightDBConfig(backend="test", file_root=self.data_dir) + self.data_dir_copy = tempfile.mkdtemp() + self.config_2 = LightDBConfig(backend="test", file_root=self.data_dir_copy) + + def tearDown(self): + shutil.rmtree(self.data_dir) + + def setUp_single_room_graph(self, episode_db=None): + # Set up the graph + test_graph = OOGraph() + agent_node = test_graph.add_agent("My test agent", {}) + room_node = test_graph.add_room("test room", {}) + agent_node.force_move_to(room_node) + test_world = World(WorldConfig(is_logging=True, episode_db=episode_db), True) + test_world.oo_graph = test_graph + return (test_graph, test_world, agent_node, room_node) + + def test_initialize_episode_db(self): + """Ensure it's possible to initialize the db""" + db = EpisodeDB(self.config) + + def test_simple_room_logger_saves_and_loads_init_graph(self): + """ + Test that the room logger properly saves and reloads the initial + graph + """ + # Set up the graph + pre_time = time.time() + episode_db = EpisodeDB(self.config) + initial = self.setUp_single_room_graph(episode_db) + test_graph, test_world, agent_node, room_node = initial + room_logger = test_graph.room_id_to_loggers[room_node.node_id] + room_logger.episode_db = episode_db + + # Push a json episode out to the db + test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) + room_logger._begin_meta_episode() + room_logger._end_meta_episode() + + # Mark the end time to test queries later + episode_id = room_logger._last_episode_logged + post_time = time.time() + + # Ensure an episode was created properly + self.assertIsNotNone(episode_id) + episode = episode_db.get_episode(episode_id) + graph_map = episode.get_graph_map() + self.assertEqual(len(episode.graphs), 2, "Expected an init and final graph") + self.assertIsNotNone(episode.group) + self.assertIsNotNone(episode.split) + self.assertIsNotNone(episode.status) + self.assertEqual( + len(episode.actors), 0, f"No actors expected, found {episode.actors}" + ) + self.assertEqual( + len(episode.get_actors()), + 0, + f"No actors expected, found {episode.get_actors()}", + ) + self.assertEqual( + episode.turn_count, 0, f"No turns excpected, found {episode.turn_count}" + ) + self.assertEqual( + episode.human_count, 0, f"No humans expected, found {episode.human_count}" + ) + self.assertEqual( + episode.action_count, + 0, + f"No actions expected, found {episode.action_count}", + ) + self.assertIn( + episode.first_graph_id, graph_map, f"First graph not present in map" + ) + self.assertIn( + episode.final_graph_id, graph_map, f"Final graph not present in map" + ) + + # Test repr + episode.__repr__() + + # Check graph equivalence + before_graph = episode.get_before_graph(episode_db) + before_graph_json = before_graph.to_json_rv(room_node.node_id) + self.assertEqual(test_init_json, before_graph_json) + + after_graph = episode.get_after_graph(episode_db) + after_graph_json = after_graph.to_json_rv(room_node.node_id) + self.assertEqual(test_init_json, after_graph_json) + + # Check the parsed episode + events = episode.get_parsed_events(episode_db) + self.assertEqual(len(events), 0, f"Expected no events, found {events}") + + # Do some episode queries + episodes = episode_db.get_episodes() + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes( + min_creation_time=pre_time, max_creation_time=post_time + ) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(max_creation_time=pre_time) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + episodes = episode_db.get_episodes(min_creation_time=post_time) + self.assertEqual(len(episodes), 0, f"Expected 0 episode, found {episodes}") + episodes = episode_db.get_episodes(min_turns=0) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(min_turns=1) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + episodes = episode_db.get_episodes(min_humans=0) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(min_humans=1) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + episodes = episode_db.get_episodes(min_actions=0) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(min_actions=1) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + episodes = episode_db.get_episodes(log_type=EpisodeLogType.ROOM) + self.assertEqual(len(episodes), 1, f"Expected 1 episodes, found {episodes}") + episodes = episode_db.get_episodes(log_type=EpisodeLogType.AGENT) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + + def test_simple_room_logger_saves_and_loads_event(self): + """ + Test that the room logger properly saves and reloads an event + """ + # Set up the graph + episode_db = EpisodeDB(self.config) + initial = self.setUp_single_room_graph(episode_db) + test_graph, test_world, agent_node, room_node = initial + agent_node.is_player = True + agent_node.user_id = TEST_USER_ID + room2_node = test_graph.add_room("test room2", {}) + room_logger = test_graph.room_id_to_loggers[room_node.node_id] + test_world.oo_graph = test_graph # refresh logger + + # Check an event json was done correctly + test_event = ArriveEvent(agent_node, text_content="") + test_init_json = test_world.oo_graph.to_json_rv(agent_node.get_room().node_id) + room_logger.observe_event(test_event) + test_event2 = LookEvent(agent_node) + room_logger.observe_event(test_event2) + room_logger._end_meta_episode() + + ref_json = test_event2.to_json() + episode_id = room_logger._last_episode_logged + + # Ensure an episode was created properly + self.assertIsNotNone(episode_id) + episode = episode_db.get_episode(episode_id) + events = episode.get_parsed_events(episode_db) + + event_graph = events[0][0] + event_list = events[0][1] + loaded_event = event_list[0] + + # Assert the loaded event is the same as the executed one + self.assertEqual(loaded_event.to_json(), ref_json) + + # Assert that episode queries with users + self.assertEqual(episode.human_count, 1, "Expected one human") + self.assertEqual(episode.get_actors(), [TEST_USER_ID], "Expected one actor") + episodes = episode_db.get_episodes(min_humans=1) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(user_id=TEST_USER_ID) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(user_id="nonexist") + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + + def test_simple_agent_logger_saves_and_loads_init_graph(self): + """ + Test that the agent logger properly saves and reloads the initial + graph + """ + # Set up the graph + episode_db = EpisodeDB(self.config) + initial = self.setUp_single_room_graph(episode_db) + test_graph, test_world, agent_node, room_node = initial + + # Check the graph json was done correctly from agent's room + test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) + agent_logger = AgentInteractionLogger(test_world, agent_node) + agent_logger._begin_meta_episode() + agent_logger._end_meta_episode() + + # Mark the end time to test queries later + episode_id = agent_logger._last_episode_logged + post_time = time.time() + + # Ensure an episode was created properly + self.assertIsNotNone(episode_id) + episode = episode_db.get_episode(episode_id) + graph_map = episode.get_graph_map() + self.assertEqual(len(episode.graphs), 2, "Expected an init and final graph") + self.assertIsNotNone(episode.group) + self.assertIsNotNone(episode.split) + self.assertIsNotNone(episode.status) + self.assertEqual( + len(episode.actors), 0, f"No actors expected, found {episode.actors}" + ) + self.assertEqual( + len(episode.get_actors()), + 0, + f"No actors expected, found {episode.get_actors()}", + ) + self.assertEqual( + episode.turn_count, 0, f"No turns excpected, found {episode.turn_count}" + ) + self.assertEqual( + episode.human_count, 0, f"No humans expected, found {episode.human_count}" + ) + self.assertEqual( + episode.action_count, + 0, + f"No actions expected, found {episode.action_count}", + ) + self.assertIn( + episode.first_graph_id, graph_map, f"First graph not present in map" + ) + self.assertIn( + episode.final_graph_id, graph_map, f"Final graph not present in map" + ) + + # Test repr + episode.__repr__() + + # Check graph equivalence + before_graph = episode.get_before_graph(episode_db) + before_graph_json = before_graph.to_json_rv(room_node.node_id) + self.assertEqual(test_init_json, before_graph_json) + + after_graph = episode.get_after_graph(episode_db) + after_graph_json = after_graph.to_json_rv(room_node.node_id) + self.assertEqual(test_init_json, after_graph_json) + + # Check the parsed episode + events = episode.get_parsed_events(episode_db) + self.assertEqual(len(events), 0, f"Expected no events, found {events}") + + # Check some episode queries + episodes = episode_db.get_episodes(log_type=EpisodeLogType.AGENT) + self.assertEqual(len(episodes), 1, f"Expected 1 episodes, found {episodes}") + episodes = episode_db.get_episodes(log_type=EpisodeLogType.ROOM) + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + + def test_simple_agent_logger_saves_and_loads_event(self): + """ + Test that the agent logger properly saves and reloads an event + """ + # Set up the graph + episode_db = EpisodeDB(self.config) + initial = self.setUp_single_room_graph(episode_db) + test_graph, test_world, agent_node, room_node = initial + agent_node.is_player = True + agent_node.user_id = TEST_USER_ID + room2_node = test_graph.add_room("test room2", {}) + test_world.oo_graph = test_graph # refresh logger + room_logger = test_graph.room_id_to_loggers[room_node.node_id] + room_logger.episode_db = episode_db + room_logger.players.add(agent_node.user_id) + + # Check an event json was done correctly + test_event = ArriveEvent(agent_node, text_content="") + test_init_json = test_world.oo_graph.to_json_rv(agent_node.get_room().node_id) + agent_logger = AgentInteractionLogger(test_world, agent_node) + agent_logger._begin_meta_episode() + agent_logger.observe_event(test_event) + test_event2 = LookEvent(agent_node) + agent_logger.observe_event(test_event2) + agent_logger._end_meta_episode() + ref_json = test_event2.to_json() + + episode_id = agent_logger._last_episode_logged + + # Ensure an episode was created properly + self.assertIsNotNone(episode_id) + episode = episode_db.get_episode(episode_id) + events = episode.get_parsed_events(episode_db) + self.assertEqual(len(events), 1, f"Expected 1 graph type, found {events}") + + event_graph = events[0][0] + event_list = events[0][1] + self.assertEqual( + len(event_list), 2, f"Expected 2 logged events, found {event_list}" + ) + loaded_event = event_list[1] + + # Assert the loaded event is the same as the executed one + self.assertEqual(loaded_event.to_json(), ref_json) + + # Assert that episode queries with users + self.assertEqual(episode.human_count, 1, "Expected one human") + self.assertEqual(episode.get_actors(), [TEST_USER_ID], "Expected one actor") + episodes = episode_db.get_episodes(min_humans=1) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(user_id=TEST_USER_ID) + self.assertEqual(len(episodes), 1, f"Expected one episode, found {episodes}") + episodes = episode_db.get_episodes(user_id="nonexist") + self.assertEqual(len(episodes), 0, f"Expected 0 episodes, found {episodes}") + + def test_simple_room_logger_e2e(self): + """ + Test that the room logger properly saves and reloads the graph and events + """ + # Set up the graph + episode_db = EpisodeDB(self.config) + initial = self.setUp_single_room_graph(episode_db) + test_graph, test_world, agent_node, room_node = initial + agent_node.is_player = True + agent_node.user_id = TEST_USER_ID + room_node2 = test_graph.add_room("test room2", {}) + test_graph.add_paths_between( + room_node, room_node2, "a path to the north", "a path to the south" + ) + test_world.oo_graph = test_graph # refresh logger + + # Check the room and event json was done correctly for room_node + event_room_node_observed = LeaveEvent( + agent_node, target_nodes=[room_node2] + ).to_json() + test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) + + GoEvent(agent_node, target_nodes=[room_node2]).execute(test_world) + + room_logger = test_graph.room_id_to_loggers[room_node.node_id] + + episode_id = room_logger._last_episode_logged + episode = episode_db.get_episode(episode_id) + events = episode.get_parsed_events(episode_db) + self.assertEqual(len(events), 1, f"Expected 1 graph type, found {events}") + + event_graph = events[0][0] + + event_list = events[0][1] + self.assertEqual( + len(event_list), 2, f"Expected 2 logged events, found {event_list}" + ) + loaded_event = event_list[1] + + ref_json = json.loads(event_room_node_observed) + event_ref = json.loads(loaded_event.to_json()) + for k in ref_json: + if k == "event_id": + continue + elif k == "target_nodes": + self.assertEqual(ref_json[k][0]["names"], event_ref[k][0]["names"]) + else: + self.assertEqual( + ref_json[k], + event_ref[k], + f"Event Json should match for LeaveEvent, misses on {k}", + ) + + # assert export works + copy_db = episode_db.export(self.config_2) + copy_episode = copy_db.get_episode(episode_id) + copy_events = copy_episode.get_parsed_events(copy_db) + self.assertEqual(len(copy_events), 1, f"Expected 1 graph type, found {events}") + + # assert user id is present in the temp dataset + self.assertIn(agent_node.user_id, episode.actors) + all_data = str(events) + for key in episode.get_graph_map().keys(): + graph = episode.get_graph(key, episode_db) + all_data += str(graph.to_json()) + self.assertIn(agent_node.user_id, all_data) + + # assert user data is scrubbed after scrub + episode_db.anonymize_group(episode.group) + episode = episode_db.get_episode(episode_id) + events = episode.get_parsed_events(episode_db) + self.assertNotIn(agent_node.user_id, episode.actors) + self.assertNotIn(agent_node.user_id, str(events)) + for key in episode.get_graph_map().keys(): + graph = episode.get_graph(key, episode_db) + self.assertNotIn(agent_node.user_id, str(graph.to_json())) + + # Assert user data is scrubbed from new table too + episode = copy_db.get_episode(episode_id) + events = episode.get_parsed_events(copy_db) + self.assertNotIn(agent_node.user_id, episode.actors) + self.assertNotIn(agent_node.user_id, str(events)) + for key in episode.get_graph_map().keys(): + graph = episode.get_graph(key, copy_db) + self.assertNotIn(agent_node.user_id, str(graph.to_json())) diff --git a/light/data_model/tests/test_user_db.py b/light/data_model/tests/test_user_db.py new file mode 100644 index 000000000..4d70a3262 --- /dev/null +++ b/light/data_model/tests/test_user_db.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from omegaconf import OmegaConf +from sqlalchemy.exc import IntegrityError +from light.data_model.db.base import LightDBConfig +from light.data_model.db.users import UserDB, PlayerStatus +from light.data_model.db.environment import EnvDB + +config = LightDBConfig(backend="test", file_root="unused") + + +class TestUserDB(unittest.TestCase): + """Test cases for setting up a structured graph""" + + def test_init(self): + """Ensure we can initialize a UserDB successfully""" + hydra_config = OmegaConf.structured(config) + db = UserDB(config) + self.assertIsNotNone(db) + self.assertIsNotNone(db.engine) + + def test_create_find_users(self): + """Ensure we can initialize players properly and find them""" + hydra_config = OmegaConf.structured(config) + db = UserDB(config) + + extern_id_1 = "TEST_LOGIN" + extern_id_2 = "TEST_PREAUTH" + + # Create players + + player_id_1 = db.create_user(extern_id_1, is_preauth=False) + player_id_2 = db.create_user(extern_id_2, is_preauth=True) + + # Assert created players as expected + + self.assertIsNotNone(player_id_1, "No player ID returned for first") + self.assertIsNotNone(player_id_2, "No player ID returned for second") + + # Assert duplicates return same id + + player_id_3 = db.create_user(extern_id_1, is_preauth=False) + self.assertEqual(player_id_1, player_id_3) + player_id_4 = db.create_user(extern_id_1, is_preauth=True) + self.assertEqual(player_id_1, player_id_4) + + # Assert can find given players, and that their values are initialized + + player_1_by_id = db.get_player(player_id_1) + player_2_by_id = db.get_player(player_id_2) + player_1_by_extern = db.get_player_by_extern_id(extern_id_1) + player_2_by_extern = db.get_player_by_extern_id(extern_id_2) + + self.assertEqual( + player_1_by_id.db_id, player_id_1, "Gotten player by ID mismatch" + ) + self.assertEqual( + player_1_by_extern.db_id, player_id_1, "Gotten player by extern mismatch" + ) + self.assertEqual( + player_1_by_id.extern_id, extern_id_1, "Gotten player by ID mismatch" + ) + self.assertEqual( + player_1_by_extern.extern_id, + extern_id_1, + "Gotten player by extern mismatch", + ) + self.assertEqual( + player_1_by_id.is_preauth, False, "Did not retain preauth status" + ) + self.assertEqual(player_1_by_id.flag_count, 0, "Did not initialize flags to 0") + self.assertEqual( + player_1_by_id.safety_trigger_count, + 0, + "Did not initialize safety triggers to 0", + ) + self.assertEqual( + player_1_by_id.total_messages, 0, "Did not intialize messages to 0" + ) + self.assertEqual( + player_1_by_id.account_status, + PlayerStatus.TUTORIAL, + "Did not initialize to tutorial", + ) + + self.assertEqual( + player_2_by_id.db_id, player_id_2, "Gotten player by ID mismatch" + ) + self.assertEqual( + player_2_by_extern.db_id, player_id_2, "Gotten player by extern mismatch" + ) + self.assertEqual( + player_2_by_id.extern_id, extern_id_2, "Gotten player by ID mismatch" + ) + self.assertEqual( + player_2_by_extern.extern_id, + extern_id_2, + "Gotten player by extern mismatch", + ) + self.assertEqual( + player_2_by_id.is_preauth, True, "Did not retain preauth status" + ) + self.assertEqual(player_2_by_id.flag_count, 0, "Did not initialize flags to 0") + self.assertEqual( + player_2_by_id.safety_trigger_count, + 0, + "Did not initialize safety triggers to 0", + ) + self.assertEqual( + player_2_by_id.total_messages, 0, "Did not intialize messages to 0" + ) + self.assertEqual( + player_2_by_id.account_status, + PlayerStatus.TUTORIAL, + "Did not initialize to tutorial", + ) + + # Assert cannot find non-existent players + + with self.assertRaises(KeyError): + player_5 = db.get_player(-1) + with self.assertRaises(KeyError): + player_5 = db.get_player_by_extern_id("FakePlayer") + + def test_update_scores(self): + """Ensure we can increment scores successfully""" + hydra_config = OmegaConf.structured(config) + db = UserDB(config) + + extern_id_1 = "TEST_LOGIN" + player_id_1 = db.create_user(extern_id_1, is_preauth=False) + agent_id_1 = 1 + agent_id_2 = 2 + + # Check default score is present and 0 + base_score = db.get_agent_score(player_id_1) + self.assertEqual(base_score.score, 0, "Default score not 0") + self.assertEqual(base_score.count, 0, "Default count not 0") + + # Check that querying non-existent score fails + with self.assertRaises(KeyError, msg="Able to find nonexisting score"): + base_score = db.get_agent_score(player_id_1, -1) + with self.assertRaises( + KeyError, msg="Able to find score for nonexisting player" + ): + base_score = db.get_agent_score(-1) + + # Add a few scores for at least 2 different agent names + db.update_agent_score(player_id_1, agent_id_1, 1, 4, 5) + db.update_agent_score(player_id_1, agent_id_2, 2, 5, -4) + db.update_agent_score(player_id_1, agent_id_2, 3, 6, 2) + + # Ensure all of the values add up as expected + base_score = db.get_agent_score(player_id_1) + score_1 = db.get_agent_score(player_id_1, agent_id_1) + score_2 = db.get_agent_score(player_id_1, agent_id_2) + + self.assertEqual(base_score.score, 6, "Scores did not add to 6") + self.assertEqual(base_score.count, 3, "Other than 3 episodes marked") + self.assertEqual(base_score.reward_xp, 3, "Reward xp not summed") + self.assertEqual(score_1.score, 1, "Expected 1 score for agent 1") + self.assertEqual(score_1.count, 1, "Expected one episode for agent 1") + self.assertEqual(score_2.score, 5, "Expected 5 score for agent 2") + self.assertEqual(score_2.count, 2, "Expected two episodes for agent 2") + + # Ensure that counts propogate up to player + player = db.get_player(player_id_1) + self.assertEqual(player.total_messages, 15, "Expected 15 actions") + + # Assert can delete player + env_db = EnvDB(config) + db.delete_player(player_id_1, env_db) + + with self.assertRaises(KeyError): + player_1_by_id = db.get_player(player_id_1) + with self.assertRaises(KeyError): + base_score = db.get_agent_score(player_id_1) + with self.assertRaises(KeyError): + score_1 = db.get_agent_score(player_id_1, agent_id_1) + with self.assertRaises(KeyError): + score_2 = db.get_agent_score(player_id_1, agent_id_2) + + def test_flag_scores(self): + """Ensure we can flag players successfully""" + hydra_config = OmegaConf.structured(config) + db = UserDB(config) + + extern_id_1 = "TEST_LOGIN" + player_id_1 = db.create_user(extern_id_1, is_preauth=False) + + # Ensure we can flag and safety trigger users + db.mark_flag(player_id_1) + db.mark_flag(player_id_1) + db.mark_safety_trigger(player_id_1) + db.mark_safety_trigger(player_id_1) + db.mark_safety_trigger(player_id_1) + + # Ensure the values add up as expected + player = db.get_player(player_id_1) + self.assertEqual(player.flag_count, 2, "Expected 2 flags") + self.assertEqual(player.safety_trigger_count, 3, "Expected 3 safety triggers") + + # Ensure we cannot flag or trigger non-existing users + with self.assertRaises(KeyError, msg="Could mark non-existing player"): + db.mark_flag(-1) + with self.assertRaises(KeyError, msg="Could mark non-existing player"): + db.mark_safety_trigger(-1) + + def test_update_player_status(self): + """Ensure we can flag players successfully""" + hydra_config = OmegaConf.structured(config) + db = UserDB(config) + + extern_id_1 = "TEST_LOGIN" + player_id_1 = db.create_user(extern_id_1, is_preauth=False) + + # Ensure we can cycle through all statuses + for player_status in [ + PlayerStatus.STANDARD, + PlayerStatus.BLOCKED, + PlayerStatus.TUTORIAL, + PlayerStatus.ADMIN, + ]: + db.update_player_status(player_id_1, player_status) + player = db.get_player(player_id_1) + self.assertEqual( + player.account_status, + player_status, + f"Did not find expected status {player_status}, instead {player.account_status}", + ) + + # Ensure update on nonexisting player fails + with self.assertRaises(KeyError, msg="Could change existing player status"): + db.update_player_status(-1, PlayerStatus.STANDARD) + + +if __name__ == "__main__": + unittest.main() diff --git a/light/graph/builders/base.py b/light/graph/builders/base.py index 11ffe44a4..7fa126717 100644 --- a/light/graph/builders/base.py +++ b/light/graph/builders/base.py @@ -3,6 +3,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import asyncio import random, copy from light.graph.builders.db_utils import id_is_usable from light.data_model.light_database import ( @@ -23,6 +24,13 @@ DBCharacter, ) +from typing import Any, Dict, TYPE_CHECKING + +if TYPE_CHECKING: + from light.registry.model_pool import ModelPool + from parlai.core.message import Message + from parlai.core.agents import Agent as ParlAIAgent + # Possible new entrances for add_new_random_agent POSSIBLE_NEW_ENTRANCES = [ "somewhere you can't see", @@ -45,11 +53,11 @@ def __init__(self): set other parameters as required to build graphs using this builder""" raise NotImplementedError - def add_random_new_agent_to_graph(self, target_graph): + async def add_random_new_agent_to_graph(self, target_graph): """Add an agent to the graph in a random room somewhere""" raise NotImplementedError - def get_graph(self): + async def get_graph(self): """Return an OOGraph built by this builder""" raise NotImplementedError @@ -210,24 +218,23 @@ class SingleSuggestionGraphBuilder(object): """Abstract class that defines methods to obtain suggestions from models for building LIGHT worlds and related graphs""" - def __init__(self, opt, model_path=""): + def __init__(self, model_pool: "ModelPool"): """Initalize SingleSuggestionGraphBuilder to access suggestion models""" - self.agents = {} - self.model_path = model_path - self.opt = copy.deepcopy(opt) + self.agents: Dict[str, "ParlAIAgent"] = {} + self.model_pool = model_pool - def load_models(self): + def load_models(self) -> None: """abstract method for loading models for suggestions""" raise NotImplementedError - def agent_recommend(self, observation, agent_type): + async def agent_recommend(self, observation, agent_type) -> "Message": """Return a response when querying a specific type of agent and return the model response""" assert agent_type in self.agents, "Agent type not found in existing agents" self.agents[agent_type].reset() msg = {"text": observation, "episode_done": True} self.agents[agent_type].observe(msg) - response = self.agents[agent_type].act() + response = await self.agents[agent_type].act() return response def get_description(self, txt_feat, element_type, num_results=5): diff --git a/light/graph/builders/external_map_json_builder.py b/light/graph/builders/external_map_json_builder.py index 5b0b6b6ef..799a14ce1 100644 --- a/light/graph/builders/external_map_json_builder.py +++ b/light/graph/builders/external_map_json_builder.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import json +import asyncio import random, copy from light.graph.structured_graph import OOGraph from light.graph.builders.base import ( @@ -11,7 +12,7 @@ SingleSuggestionGraphBuilder, POSSIBLE_NEW_ENTRANCES, ) -from light.world.world import World +from light.world.world import World, WorldConfig class ExternalMapJsonBuilder(DBGraphBuilder): @@ -22,8 +23,8 @@ def __init__(self, ldb, debug, opt): self.opt = opt self._no_npc_models = True - def get_graph(self): + async def get_graph(self): g = OOGraph.from_worldbuilder_json(self.opt["load_map"]) - world = World(self.opt, self) + world = World(WorldConfig(opt=self.opt, graph_builder=self)) world.oo_graph = g return g, world diff --git a/light/graph/builders/map_json_builder.py b/light/graph/builders/map_json_builder.py index 71962d4b6..5dbc218f4 100644 --- a/light/graph/builders/map_json_builder.py +++ b/light/graph/builders/map_json_builder.py @@ -5,14 +5,15 @@ # LICENSE file in the root directory of this source tree. import json import random, copy +import asyncio from light.graph.structured_graph import OOGraph from light.graph.builders.base import ( - DBGraphBuilder, + GraphBuilder, SingleSuggestionGraphBuilder, POSSIBLE_NEW_ENTRANCES, ) from light.graph.events.graph_events import ArriveEvent -from light.world.world import World +from light.world.world import World, WorldConfig from typing import TYPE_CHECKING, List, Dict, Tuple, Any, Optional @@ -23,29 +24,48 @@ NodeProps, GraphAgent, ) + from light.data_model.db.episodes import EpisodeDB -class MapJsonBuilder(DBGraphBuilder): +class MapJsonBuilder(GraphBuilder): """Loads maps exported from the structured_graph to_json method.""" - def __init__(self, ldb, debug, opt): - self.db = ldb - self.opt = opt + def __init__( + self, episode_db: Optional["EpisodeDB"], opt: Optional[Dict[str, Any]] + ): + """Store initialization options""" + self.opt = opt if opt is not None else {} + self.episode_db = episode_db self.original_agents: Dict[str, Tuple["GraphRoom", "NodeProps"]] = {} self._no_npc_models = True - def get_graph(self): + def _get_attached_config( + self, world_config: Optional[WorldConfig] = None, opt: Dict[str, Any] = None + ) -> WorldConfig: + """ + Get a copy of the given world config attached to this builder + """ + if opt is None: + opt = self.opt + if world_config is None: + return WorldConfig(episode_db=self.episode_db, opt=opt, graph_builder=self) + else: + world_config = world_config.copy() + world_config.graph_builder = self + return world_config + + async def get_graph(self, world_config: Optional[WorldConfig] = None): input_json = self.opt["load_map"] f = open(input_json, "r") data = f.read() f.close() - g = OOGraph.from_json(data) + g = OOGraph.from_json(data, self.opt) g._opt = self.opt self.original_agents = { agent.name: (agent.get_room(), agent.get_props()) for agent in g.agents.values() } - world = World(self.opt, self) + world = World(self._get_attached_config(world_config)) world.oo_graph = g return g, world @@ -72,7 +92,7 @@ def _spawn_agent_in_room( ) arrival_event.execute(world) - def add_random_new_agent_to_graph(self, world) -> Optional["GraphAgent"]: + async def add_random_new_agent_to_graph(self, world) -> Optional["GraphAgent"]: """ Add an agent from the stored original_agents list that isn't currently present in the world, if such an agent exists. diff --git a/light/graph/builders/one_room_builder.py b/light/graph/builders/one_room_builder.py index 1df76187b..107e7513d 100644 --- a/light/graph/builders/one_room_builder.py +++ b/light/graph/builders/one_room_builder.py @@ -4,9 +4,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from parlai.core.params import ParlaiParser -from parlai.core.agents import create_agent, create_agent_from_shared -from light.world.world import World +from light.registry.models.starspace_model import MapStarspaceModelConfig +from light.world.world import World, WorldConfig from light.graph.structured_graph import OOGraph from light.graph.builders.base import ( DBGraphBuilder, @@ -35,6 +34,15 @@ import random import copy import time +import asyncio + +from dataclasses import dataclass, field +from omegaconf import MISSING, DictConfig +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from light.data_model.light_database import LIGHTDatabase + from light.registry.model_pool import ModelPool MAX_EXTRA_AGENTS_PER_ROOM = 2 INV_DIR = {"east": "west", "west": "east", "north": "south", "south": "north"} @@ -63,40 +71,63 @@ ] +@dataclass +class OneRoomChatBuilderConfig: # BuilderConfig(): + # TODO create builder config parent + model_loader_config: MapStarspaceModelConfig = MapStarspaceModelConfig() + suggestion_type: str = field( + default="model", + metadata={ + "help": ("Input 'model', 'human', or 'hybrid', for the suggestion type") + }, + ) + hybridity_prob: float = field( + default=0.5, + metadata={ + "help": ("Set probability how often ex-object or character is skipped") + }, + ) + use_best_match_model: bool = field( + default=False, + metadata={ + "help": ( + "use human suggestions for predicting placement of objects, characters, and room" + ) + }, + ) + # TODO move to elsewhere + light_db_file: str = field( + default="/checkpoint/light/data/database3.db", + metadata={"help": ("specific path for light database")}, + ) + + class OneRoomChatBuilder(DBGraphBuilder, SingleSuggestionGraphBuilder): """Builds a one-room light Graph using a StarSpace model to connect everything.""" - def __init__(self, ldb, debug=True, opt=None): + def __init__( + self, + ldb: "LIGHTDatabase", # LIGHT database, TODO replace with EnvDB + model_pool: "ModelPool", # Models this builder can use + builder_config: "DictConfig", # Configuration for this builder + graph_opt=None, + ): """Initializes required models and parameters for this graph builder""" - if opt is None: - parser = ParlaiParser( - True, True, "Arguments for building a LIGHT room with Starspace" - ) - self.add_parser_arguments(parser) - opt, _unknown = parser.parse_and_process_known_args() + self.graph_opt = {} if graph_opt is None else graph_opt # Setup correct path - db_path = opt.get("db_path") - if db_path is None: - parlai_datapath = opt["datapath"] - db_path = os.path.join(parlai_datapath, "light", "database3.db") - self.db_path = db_path - self.dpath = os.path.expanduser( - os.path.join(opt["db_path"], "../..", "light_maps") - ) + self.db_path = builder_config.get("light_db_file") + self.dpath = os.path.expanduser("~/ParlAI/data/light_maps/") model_path = opt.get("model_path") if model_path is None: model_path = opt.get("light_model_root") self.model_path = model_path - self.ldb = ldb DBGraphBuilder.__init__(self, ldb) - SingleSuggestionGraphBuilder.__init__(self, opt, model_path=self.model_path) - self.debug = debug + SingleSuggestionGraphBuilder.__init__(self, model_pool=model_pool) - self._no_npc_models = True self.load_models() - self.use_best_match = False - self.suggestion_type = self.opt.get("suggestion_type", "hybrid") + self.use_best_match = builder_config.use_best_match_model + self.suggestion_type = builder_config.suggestion_type # Cache for retrieved room/ char/ obj dicts from the database self.roomid_to_feats = {} self.feats_to_roomid = {} @@ -105,75 +136,26 @@ def __init__(self, ldb, debug=True, opt=None): self.charid_to_feats = {} self.feats_to_charid = {} # paramter to control the hybridity of the model - self.prob_skip_ex_objects = self.opt.get("hybridity_prob", 0.5) - self.prob_skip_ex_char = self.opt.get("hybridity_prob", 0.5) + self.prob_skip_ex_objects = builder_config.hybridity_prob + self.prob_skip_ex_char = builder_config.hybridity_prob self.allowed_characters = None self.banned_rooms = [] self.room = None self.neighbors = [] - @staticmethod - def add_parser_arguments(parser): - """ - Add arguments to a parser to be able to set the required options for - this builder - """ - parser.add_argument( - "--suggestion-type", - type=str, - default="model", - help="Input 'model', 'human', or 'hybrid', for the suggestion type", - ) - parser.add_argument( - "--hybridity-prob", - type=float, - default=0.5, - help="Set probability how often ex-object or character is skipped", - ) - parser.add_argument( - "--use-best-match-model", - type="bool", - default=False, - help="use human suggestions for predicting placement of objects, characters, and room", + def load_models(self) -> None: + """Load starspace models for building the map""" + # self.model_pool.register_model(self.config.model_loader_config, "map_starspace") + self.agents["room"] = self.model_pool.get_model( + "map_starspace", {"target_type": "room"} ) - parser.add_argument( - "--light-db-file", - type=str, - default="/checkpoint/light/data/database3.db", - help="specific path for light database", + self.agents["object"] = self.model_pool.get_model( + "map_starspace", {"target_type": "object"} ) - parser.add_argument( - "--light-model-root", - type=str, - default="/checkpoint/light/models/", - help="specific path for light models", + self.agents["character"] = self.model_pool.get_model( + "map_starspace", {"target_type": "character"} ) - def load_models(self): - """Load starspace models for building the map""" - # TODO load from zoo when launched - opt = copy.deepcopy(self.opt) - mf = os.path.join(self.model_path, "starspace/angela_starspace/model4") - opt["model_file"] = mf - # Create room agent - opt["fixed_candidates_file"] = self.dpath + "/room_full_cands.txt" - opt["override"] = {"fixed_candidates_file": opt["fixed_candidates_file"]} - self.agents["room"] = create_agent(opt, requireModelExists=True) - # Model Params are added as new fields to opt dict, Are there better ways around this? - opt = self.agents["room"].opt.copy() - opt["fixed_candidates_file"] = self.dpath + "/object_full_cands.txt" - opt["override"] = {"fixed_candidates_file": opt["fixed_candidates_file"]} - share_dict = self.agents["room"].share() - share_dict["opt"] = opt - self.agents["object"] = create_agent_from_shared(share_dict) - opt = self.agents["room"].opt.copy() - opt["fixed_candidates_file"] = self.dpath + "/character_full_cands.txt" - opt["override"] = {"fixed_candidates_file": opt["fixed_candidates_file"]} - share_dict = self.agents["room"].share() - share_dict["opt"] = opt - self.agents["character"] = create_agent_from_shared(share_dict) - self.agent = self.agents["room"] - def _props_from_obj(self, obj): """Given a DBObject representing an object in the world, extract the required props to create that object in the world @@ -258,7 +240,7 @@ def _heuristic_name_cleaning(self, use_desc): use_desc = use_desc[4:] return use_desc - def add_object_to_graph(self, g, obj, container_node, extra_props=None): + async def add_object_to_graph(self, g, obj, container_node, extra_props=None): """Adds a particular DBObject to the given OOgraph, adding to the specific container node. Returns the newly created object node""" if obj is None: @@ -284,7 +266,9 @@ def add_object_to_graph(self, g, obj, container_node, extra_props=None): obj.name, obj.db_id, ) - contained_objs = self.get_contained_items(obj.db_id, DB_TYPE_OBJ, 3)[1:] + contained_objs = await self.get_contained_items( + obj.db_id, DB_TYPE_OBJ, 3 + )[1:] for o in contained_objs: if self._name_not_in_graph(g, o.name): self._add_object_to_graph(g, o, obj_node) @@ -313,7 +297,7 @@ def _add_object_to_graph(self, g, obj, container_node, extra_props=None): obj_node.move_to(container_node) return obj_node - def add_new_agent_to_graph(self, g, char, room_node): + async def add_new_agent_to_graph(self, g, char, room_node): """Add the given DBcharacter to the given room (room_node) in the given OOFraph. Return the new agent node on success, and None on failure""" if char is None: @@ -350,18 +334,18 @@ def add_new_agent_to_graph(self, g, char, room_node): obj.db_id: ( "equipped" if obj.is_wearable or obj.is_weapon else "carrying" ) - for obj in self.get_contained_items(char.db_id, DB_TYPE_CHAR) + for obj in await self.get_contained_items(char.db_id, DB_TYPE_CHAR) } if self.suggestion_type == "hybrid" and len(objs) == 0: objs = { obj.db_id: ( "equipped" if obj.is_weapon or obj.is_wearable else "carrying" ) - for obj in self.get_contained_items(char.db_id, DB_TYPE_CHAR, 2) + for obj in await self.get_contained_items(char.db_id, DB_TYPE_CHAR, 2) } for obj in objs: - obj_node = self.add_object_to_graph( + obj_node = await self.add_object_to_graph( g, self.get_obj_from_id(obj), agent_node ) if obj_node is not None: @@ -369,11 +353,11 @@ def add_new_agent_to_graph(self, g, char, room_node): obj_node.set_prop("equipped", True) return agent_node - def add_random_new_agent_to_graph(self, world): + async def add_random_new_agent_to_graph(self, world): """Add a random agent to the OOGraph at a random room node""" raise Exception("There shouldn't be any random additions") - def add_neighbors(self, room): + async def add_neighbors(self, room): """Try to add all possible exits to a given room""" if self.use_best_match: neighbors = room.get_text_edges(DB_EDGE_NEIGHBOR) @@ -381,12 +365,14 @@ def add_neighbors(self, room): # Not using best match model but the starspace model for model prediction neighbors = [ e.setting - for e in self.get_neighbor_rooms(room_id=room.db_id, banned_rooms=[]) + for e in await self.get_neighbor_rooms( + room_id=room.db_id, banned_rooms=[] + ) ] return neighbors ##########For best match model################### - def get_similar_element(self, txt_feats, element_type): + async def get_similar_element(self, txt_feats, element_type): """Given a text feature, and the corresponding Database type return an DBElement of the DB type""" agent_type = None @@ -410,7 +396,7 @@ def get_similar_element(self, txt_feats, element_type): self.agents[agent_type].reset() msg = {"text": txt_feats, "episode_done": True} self.agents[agent_type].observe(msg) - response = self.agents[agent_type].act() + response = await self.agents[agent_type].act() ind = 0 while ind < len(response["text_candidates"]): key = response["text_candidates"][ind] @@ -420,24 +406,24 @@ def get_similar_element(self, txt_feats, element_type): ind = ind + 1 return None - def get_similar_room(self, txt_feats): + async def get_similar_room(self, txt_feats): """Find a similar room to the text room given based on a starspace prediction""" - return self.get_similar_element(txt_feats, DB_TYPE_ROOM) + return await self.get_similar_element(txt_feats, DB_TYPE_ROOM) - def get_similar_object(self, txt_feats): + async def get_similar_object(self, txt_feats): """Find a similar object to the text given based on starspace prediciton""" - return self.get_similar_element(txt_feats, DB_TYPE_OBJ) + return await self.get_similar_element(txt_feats, DB_TYPE_OBJ) - def get_similar_character(self, txt_feats): + async def get_similar_character(self, txt_feats): """Find a similar object to the text given based on starspace prediciton""" - return self.get_similar_element(txt_feats, DB_TYPE_CHAR) + return await self.get_similar_element(txt_feats, DB_TYPE_CHAR) ################################################### - def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): + async def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): """get prediction of neighbor room with StarSpaceModel, return DBRoom Object """ if banned_rooms is None: banned_rooms = [room_id] @@ -446,7 +432,7 @@ def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): # This is added due to the new model prediction for neighbors else: txt_feats = self.roomid_to_feats[room_id] - response = self.agent_recommend(txt_feats, "room") + response = await self.agent_recommend(txt_feats, "room") ind = 0 results = [] while len(results) < num_results: @@ -465,10 +451,10 @@ def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): return results return results - def get_graph_from_quest(self, quest): + async def get_graph_from_quest(self, quest): graph_json = quest["data"]["graph"] g = OOGraph.from_json(graph_json) - world = World(self.opt, self) + world = World(WorldConfig(opt=self.graph_opt, graph_builder=self)) world.oo_graph = g base_room = list(g.rooms.values())[0] @@ -476,7 +462,7 @@ def get_graph_from_quest(self, quest): if db_id is None: neighbors = [self.get_random_room(), self.get_random_room()] else: - neighbors = self.get_neighbor_rooms(db_id) + neighbors = await self.get_neighbor_rooms(db_id) for neighbor_room in neighbors: if neighbor_room is None: continue @@ -496,7 +482,7 @@ def get_graph_from_quest(self, quest): ) return g, world - def _get_constrained_graph(self, location=None, player=None, num_partners=1): + async def _get_constrained_graph(self, location=None, player=None, num_partners=1): """ Location is of the form "Location Name. location description" player is of the form "Player Name. player persona" @@ -506,7 +492,7 @@ def _get_constrained_graph(self, location=None, player=None, num_partners=1): else: set_room = self.get_room_from_id(self.roomfeats_to_id(location)) - g = OOGraph(self.opt) + g = OOGraph(self.graph_opt) room_node = g.add_room( set_room.setting, { @@ -523,7 +509,7 @@ def _get_constrained_graph(self, location=None, player=None, num_partners=1): ) set_room.g_id = room_node.node_id - possible_chars = self.get_contained_characters( + possible_chars = await self.get_contained_characters( room_id=set_room.db_id, num_results=5 ) if "db" in set_room.ex_characters: @@ -538,23 +524,29 @@ def _get_constrained_graph(self, location=None, player=None, num_partners=1): if player is None: player_char = random.choice(possible_chars) possible_chars.remove(player_char) - self_char_id = self.add_new_agent_to_graph(g, player_char, room_node) + self_char_id = await self.add_new_agent_to_graph(g, player_char, room_node) while self_char_id is None: player_char = random.choice(possible_chars) possible_chars.remove(player_char) - self_char_id = self.add_new_agent_to_graph(g, player_char, room_node) + self_char_id = await self.add_new_agent_to_graph( + g, player_char, room_node + ) else: player_char = self.get_char_from_id(self.charfeats_to_id(player.lower())) - self_char_id = self.add_new_agent_to_graph(g, player_char, room_node) + self_char_id = await self.add_new_agent_to_graph(g, player_char, room_node) for _ in range(num_partners): partner_char = random.choice(possible_chars) possible_chars.remove(partner_char) - parner_char_id = self.add_new_agent_to_graph(g, partner_char, room_node) + parner_char_id = await self.add_new_agent_to_graph( + g, partner_char, room_node + ) while parner_char_id is None: partner_char = random.choice(possible_chars) possible_chars.remove(partner_char) - parner_char_id = self.add_new_agent_to_graph(g, partner_char, room_node) + parner_char_id = await self.add_new_agent_to_graph( + g, partner_char, room_node + ) if "db" in set_room.ex_objects: for item_id in set_room.ex_objects["db"]: @@ -570,7 +562,7 @@ def _get_constrained_graph(self, location=None, player=None, num_partners=1): if obj is not None: self.add_object_to_graph(g, obj, room_node, props) if self.suggestion_type == "model": - predicted_objects = self.get_contained_items( + predicted_objects = await self.get_contained_items( container_id=set_room.db_id, container_type=DB_TYPE_ROOM ) for o in predicted_objects: @@ -578,7 +570,7 @@ def _get_constrained_graph(self, location=None, player=None, num_partners=1): room_node = g.get_node(set_room.g_id) self.add_object_to_graph(g, o, room_node) - neighbors = self.get_neighbor_rooms(set_room.db_id) + neighbors = await self.get_neighbor_rooms(set_room.db_id) for neighbor_room in neighbors: if neighbor_room is None: continue @@ -597,11 +589,11 @@ def _get_constrained_graph(self, location=None, player=None, num_partners=1): db_id=neighbor_room.db_id, ) - world = World(self.opt, self) + world = World(WorldConfig(opt=self.graph_opt, graph_builder=self)) world.oo_graph = g return g, world - def get_constrained_graph(self, location=None, player=None, num_partners=1): + async def get_constrained_graph(self, location=None, player=None, num_partners=1): """Take a few attempts to get the graph meeting the given constraints""" attempts = 9 graph = None @@ -609,7 +601,7 @@ def get_constrained_graph(self, location=None, player=None, num_partners=1): while graph is None and attempts > 0: try: random.seed(time.time()) - graph, world = self._get_constrained_graph( + graph, world = await self._get_constrained_graph( location=location, player=player, num_partners=num_partners, @@ -619,13 +611,13 @@ def get_constrained_graph(self, location=None, player=None, num_partners=1): attempts -= 1 return graph, world - def get_graph(self): + async def get_graph(self): """Construct a graph using the grid created with build_world after selecting new characters and objects to place within. """ - return self.get_constrained_graph(None, None, num_partners=1) + return await self.get_constrained_graph(None, None, num_partners=1) - def get_contained_items( + async def get_contained_items( self, container_id, container_type, num_results=5, banned_items=None ): """ @@ -651,7 +643,7 @@ def get_contained_items( txt_feats = self.charid_to_feats[container_id] else: txt_feats = self.get_text_features(self.get_char_from_id(container_id)) - response = self.agent_recommend(txt_feats, "object") + response = await self.agent_recommend(txt_feats, "object") ind = 0 results = [] while len(results) < num_results and ind < len(response["text_candidates"]): @@ -667,7 +659,9 @@ def get_contained_items( ind = ind + 1 return results - def get_contained_characters(self, room_id, num_results=5, banned_characters=None): + async def get_contained_characters( + self, room_id, num_results=5, banned_characters=None + ): """ Get prediction of contained characters in given room_id from StarSpace model.""" if banned_characters is None: banned_characters = [] @@ -678,7 +672,7 @@ def get_contained_characters(self, room_id, num_results=5, banned_characters=Non txt_feats = self.roomid_to_feats[room_id] else: txt_feats = self.get_text_features(self.get_room_from_id(room_id)) - response = self.agent_recommend(txt_feats, "character") + response = await self.agent_recommend(txt_feats, "character") ind = 0 results = [] while len(results) < num_results: diff --git a/light/graph/builders/starspace_all.py b/light/graph/builders/starspace_all.py index 094b371e3..8753f4931 100644 --- a/light/graph/builders/starspace_all.py +++ b/light/graph/builders/starspace_all.py @@ -4,7 +4,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import sys from parlai.core.params import ParlaiParser from parlai.core.agents import create_agent, create_agent_from_shared @@ -50,6 +49,8 @@ import random import copy import numpy as np +import sys +import asyncio random.seed(6) np.random.seed(6) @@ -73,6 +74,7 @@ ] +# TODO port similarly to OneRoomChatBuilder class StarspaceBuilder(DBGraphBuilder, SingleSuggestionGraphBuilder): """Builds a LIGHT map using a StarSpace model to connect everything.""" @@ -300,7 +302,7 @@ def _heuristic_name_cleaning(self, use_desc): use_desc = use_desc[4:] return use_desc - def add_object_to_graph(self, g, obj, container_node, extra_props={}): + async def add_object_to_graph(self, g, obj, container_node, extra_props={}): """Adds a particular DBObject to the given OOgraph, adding to the specific container node. Returns the newly created object node""" obj.description = obj.description.capitalize() @@ -322,7 +324,9 @@ def add_object_to_graph(self, g, obj, container_node, extra_props={}): obj.name, obj.db_id, ) - contained_objs = self.get_contained_items(obj.db_id, DB_TYPE_OBJ, 3)[1:] + contained_objs = await self.get_contained_items( + obj.db_id, DB_TYPE_OBJ, 3 + )[1:] for o in contained_objs: if self._name_not_in_graph(g, o.name): self._add_object_to_graph(g, o, obj_node) @@ -349,7 +353,7 @@ def _add_object_to_graph(self, g, obj, container_node, extra_props={}): obj_node.move_to(container_node) return obj_node - def add_new_agent_to_graph(self, g, char, room_node): + async def add_new_agent_to_graph(self, g, char, room_node): """Add the given DBcharacter to the given room (room_node) in the given OOFraph. Return the new agent node on success, and None on failure""" if "is_banned" in vars(char): @@ -384,18 +388,18 @@ def add_new_agent_to_graph(self, g, char, room_node): obj.db_id: ( "equipped" if obj.is_wearable or obj.is_weapon else "carrying" ) - for obj in self.get_contained_items(char.db_id, DB_TYPE_CHAR) + for obj in await self.get_contained_items(char.db_id, DB_TYPE_CHAR) } if self.suggestion_type == "hybrid" and len(objs) == 0: objs = { obj.db_id: ( "equipped" if obj.is_weapon or obj.is_wearable else "carrying" ) - for obj in self.get_contained_items(char.db_id, DB_TYPE_CHAR, 2) + for obj in await self.get_contained_items(char.db_id, DB_TYPE_CHAR, 2) } for obj in objs: - obj_node = self.add_object_to_graph( + obj_node = await self.add_object_to_graph( g, self.get_obj_from_id(obj), agent_node ) if obj_node is not None: @@ -403,7 +407,7 @@ def add_new_agent_to_graph(self, g, char, room_node): obj_node.set_prop("equipped", True) return agent_node - def add_random_new_agent_to_graph(self, world): + async def add_random_new_agent_to_graph(self, world): """Add a random agent to the OOGraph at a random room node""" g = world.oo_graph pos_rooms = [x for x in g.rooms.keys()] @@ -418,7 +422,7 @@ def add_random_new_agent_to_graph(self, world): if len(chars) == 0: return char = self.get_random_char() - agent = self.add_new_agent_to_graph(g, char, g.get_node(pos_room.g_id)) + agent = await self.add_new_agent_to_graph(g, char, g.get_node(pos_room.g_id)) if agent is None: return @@ -428,7 +432,7 @@ def add_random_new_agent_to_graph(self, world): ) arrival_event.execute(world) - def construct_grid(self, html_visualization_filename="/tmp/gridtmp.html"): + async def construct_grid(self, html_visualization_filename="/tmp/gridtmp.html"): """Create a new stitched together environment from an empty grid""" # Initialize for a new grid setup self.grid = {} @@ -458,7 +462,7 @@ def construct_grid(self, html_visualization_filename="/tmp/gridtmp.html"): self.stack.append(r) while len(self.stack) > 0: r1 = self.stack.pop() - self.add_exits(r1) + await self.add_exits(r1) generate_html_map(html_visualization_filename, self.grid) def new_grid_position(self, loc): @@ -499,7 +503,7 @@ def new_grid_position(self, loc): else: return None, None - def room_similarity(self, loc1, loc2): + async def room_similarity(self, loc1, loc2): """Determine how similar the starspace model thinks two given rooms are""" room_1 = self.grid[loc1] room_2 = self.grid[loc2] @@ -519,7 +523,7 @@ def room_similarity(self, loc1, loc2): msg = {"text": txt_feats, "episode_done": True} self.agents["room"].reset() self.agents["room"].observe(msg) - response = self.agents["room"].act() + response = await self.agents["room"].act() score = 100000 for i, k in enumerate(response["text_candidates"]): if k == sim_feats: @@ -527,7 +531,7 @@ def room_similarity(self, loc1, loc2): break return score - def possibly_connect_to_neighbor(self, loc1, loc2, src_dir): + async def possibly_connect_to_neighbor(self, loc1, loc2, src_dir): """Connect two rooms if the model thinks they're similar enough""" # TODO rather than connecting if two rooms are similar, perhaps # we should be connecting if a room is similar to another @@ -541,7 +545,7 @@ def possibly_connect_to_neighbor(self, loc1, loc2, src_dir): return else: # compute similarity of rooms: - sim = self.room_similarity(loc1, loc2) + sim = await self.room_similarity(loc1, loc2) if sim > 100: # if not in the top 100 most similar rooms, no connection. return @@ -551,24 +555,32 @@ def possibly_connect_to_neighbor(self, loc1, loc2, src_dir): self.grid[loc2].possible_connections[INV_DIR[src_dir] + "*"] = True self.grid[loc1].possible_connections[src_dir + "*"] = True - def possibly_connect_to_neighbors(self, loc): + async def possibly_connect_to_neighbors(self, loc): """Try to connect a room to all of its possible neighbors""" - self.possibly_connect_to_neighbor(loc, (loc[0] - 1, loc[1], loc[2]), "west") - self.possibly_connect_to_neighbor(loc, (loc[0] + 1, loc[1], loc[2]), "east") - self.possibly_connect_to_neighbor(loc, (loc[0], loc[1] - 1, loc[2]), "north") - self.possibly_connect_to_neighbor(loc, (loc[0], loc[1] + 1, loc[2]), "south") + await self.possibly_connect_to_neighbor( + loc, (loc[0] - 1, loc[1], loc[2]), "west" + ) + await self.possibly_connect_to_neighbor( + loc, (loc[0] + 1, loc[1], loc[2]), "east" + ) + await self.possibly_connect_to_neighbor( + loc, (loc[0], loc[1] - 1, loc[2]), "north" + ) + await self.possibly_connect_to_neighbor( + loc, (loc[0], loc[1] + 1, loc[2]), "south" + ) - def add_room(self, room, loc, src_loc, src_dir): + async def add_room(self, room, loc, src_loc, src_dir): """Add a room as the neighbor of the room at src_loc""" self.grid[loc] = room room.loc = loc self.grid[loc].possible_connections[INV_DIR[src_dir]] = True self.grid[src_loc].possible_connections[src_dir] = True self.banned_rooms.add(room.db_id) - self.possibly_connect_to_neighbors(loc) + await self.possibly_connect_to_neighbors(loc) self.stack.append(room) - def add_exits(self, r): + async def add_exits(self, r): """Try to add all possible exits to a given room""" if type(r) is FillerRoom: # This is needed as neighbors used to be the field in filler_room @@ -579,7 +591,7 @@ def add_exits(self, r): # Not using best match model but the starspace model for model prediction neighbors = [ e.setting - for e in self.get_neighbor_rooms( + for e in await self.get_neighbor_rooms( room_id=r.db_id, banned_rooms=self.banned_rooms ) ] @@ -587,7 +599,7 @@ def add_exits(self, r): l1, src_dir = self.new_grid_position(r.loc) if l1 is not None: exit_text = e + " " + r.category - r1 = self.get_similar_room(exit_text) + r1 = await self.get_similar_room(exit_text) if r1 is not None: if self.debug: print( @@ -617,12 +629,12 @@ def add_exits(self, r): # FillerRoom is using the db_id of the room it substitute's filler_room.neighbors = r1.get_text_edges(DB_EDGE_NEIGHBOR) # Filler rooms inherit neighbors from the real room they replace - self.add_room(filler_room, l1, r.loc, src_dir) + await self.add_room(filler_room, l1, r.loc, src_dir) else: - self.add_room(r1, l1, r.loc, src_dir) + await self.add_room(r1, l1, r.loc, src_dir) ##########For best match model################### - def get_similar_element(self, txt_feats, element_type): + async def get_similar_element(self, txt_feats, element_type): """Given a text feature, and the corresponding Database type return an DBElement of the DB type""" agent_type = None @@ -646,7 +658,7 @@ def get_similar_element(self, txt_feats, element_type): self.agents[agent_type].reset() msg = {"text": txt_feats, "episode_done": True} self.agents[agent_type].observe(msg) - response = self.agents[agent_type].act() + response = await self.agents[agent_type].act() ind = 0 while ind < len(response["text_candidates"]): key = response["text_candidates"][ind] @@ -656,24 +668,24 @@ def get_similar_element(self, txt_feats, element_type): ind = ind + 1 return None - def get_similar_room(self, txt_feats): + async def get_similar_room(self, txt_feats): """Find a similar room to the text room given based on a starspace prediction""" - return self.get_similar_element(txt_feats, DB_TYPE_ROOM) + return await self.get_similar_element(txt_feats, DB_TYPE_ROOM) - def get_similar_object(self, txt_feats): + async def get_similar_object(self, txt_feats): """Find a similar object to the text given based on starspace prediciton""" - return self.get_similar_element(txt_feats, DB_TYPE_OBJ) + return await self.get_similar_element(txt_feats, DB_TYPE_OBJ) - def get_similar_character(self, txt_feats): + async def get_similar_character(self, txt_feats): """Find a similar object to the text given based on starspace prediciton""" - return self.get_similar_element(txt_feats, DB_TYPE_CHAR) + return await self.get_similar_element(txt_feats, DB_TYPE_CHAR) ################################################### - def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): + async def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): """get prediction of neighbor room with StarSpaceModel, return DBRoom Object """ if banned_rooms is None: banned_rooms = [room_id] @@ -682,7 +694,7 @@ def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): # This is added due to the new model prediction for neighbors else: txt_feats = self.roomid_to_feats[room_id] - response = self.agent_recommend(txt_feats, "room") + response = await self.agent_recommend(txt_feats, "room") ind = 0 results = [] while len(results) < num_results: @@ -701,11 +713,11 @@ def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): return results return results - def get_graph(self): + async def get_graph(self): """Construct a graph using the grid created with build_world after selecting new characters and objects to place within. """ - self.construct_grid() + await self.construct_grid() g = OOGraph(self.opt) self.g = g room_ids = [] @@ -715,7 +727,7 @@ def get_graph(self): for grid_loc, pos_room in self.grid.items(): if pos_room.setting == "EMPTY": continue - pos_room.g_id = g.add_room( + pos_room.g_id = await g.add_room( pos_room.setting, { "room": True, @@ -789,7 +801,7 @@ def get_graph(self): if obj is not None: room_node = g.get_node(pos_room.g_id) no_human_suggestions_obj = False - objid = self.add_object_to_graph(g, obj, room_node) + objid = await self.add_object_to_graph(g, obj, room_node) if "db" in pos_room.in_objects: for item_id in pos_room.in_objects["db"]: obj = self.get_obj_from_id(item_id) @@ -799,7 +811,7 @@ def get_graph(self): if obj is not None: room_node = g.get_node(pos_room.g_id) no_human_suggestions_obj = False - obj_node = self.add_object_to_graph( + obj_node = await self.add_object_to_graph( g, obj, room_node, props ) if "db" in pos_room.ex_characters: @@ -814,7 +826,7 @@ def get_graph(self): if char is not None: room_node = g.get_node(pos_room.g_id) no_human_suggestions_char = False - self.add_new_agent_to_graph(g, char, room_node) + await self.add_new_agent_to_graph(g, char, room_node) cnt += 1 if "db" in pos_room.in_characters: for char_id in pos_room.in_characters["db"]: @@ -832,30 +844,30 @@ def get_graph(self): if self.suggestion_type != "human": # For model suggestions and hybrid if self.suggestion_type == "model" or no_human_suggestions_obj: - predicted_objects = self.get_contained_items( + predicted_objects = await self.get_contained_items( container_id=pos_room.db_id, container_type=DB_TYPE_ROOM ) for o in predicted_objects: if o is not None: room_node = g.get_node(pos_room.g_id) - self.add_object_to_graph(g, o, room_node) + await self.add_object_to_graph(g, o, room_node) if self.suggestion_type == "model" or no_human_suggestions_char: - predicted_characters = self.get_contained_characters( + predicted_characters = await self.get_contained_characters( room_id=pos_room.db_id, num_results=2 ) for c in predicted_characters: if c is not None: room_node = g.get_node(pos_room.g_id) - self.add_new_agent_to_graph(g, c, room_node) + await self.add_new_agent_to_graph(g, c, room_node) for room in g.rooms: g.room_id_to_loggers[room] = RoomInteractionLogger(g, room) - world = World(self.opt, self) + world = World(WorldConfig(opt=self.opt, graph_builder=self)) world.oo_graph = g return g, world - def get_contained_items( + async def get_contained_items( self, container_id, container_type, num_results=5, banned_items=[] ): """ @@ -879,7 +891,7 @@ def get_contained_items( txt_feats = self.charid_to_feats[container_id] else: txt_feats = self.get_text_features(self.get_char_from_id(container_id)) - response = self.agent_recommend(txt_feats, "object") + response = await self.agent_recommend(txt_feats, "object") ind = 0 results = [] while len(results) < num_results and ind < len(response["text_candidates"]): @@ -895,7 +907,9 @@ def get_contained_items( ind = ind + 1 return results - def get_contained_characters(self, room_id, num_results=5, banned_characters=[]): + async def get_contained_characters( + self, room_id, num_results=5, banned_characters=[] + ): """ Get prediction of contained characters in given room_id from StarSpace model.""" if type(room_id) is str and room_id[0] == "f": # To check for filler_rooms, if it is filler_room, using the db_id of the room it has replaced @@ -904,7 +918,7 @@ def get_contained_characters(self, room_id, num_results=5, banned_characters=[]) txt_feats = self.roomid_to_feats[room_id] else: txt_feats = self.get_text_features(self.get_room_from_id(room_id)) - response = self.agent_recommend(txt_feats, "character") + response = await self.agent_recommend(txt_feats, "character") ind = 0 results = [] while len(results) < num_results: @@ -923,33 +937,35 @@ def get_contained_characters(self, room_id, num_results=5, banned_characters=[]) return results return results - def get_description(self, txt_feat, element_type, num_results=5): + async def get_description(self, txt_feat, element_type, num_results=5): """Get description of element, given the txt_feature title""" - response = self.agent_recommend(txt_feat, "node_desc") + response = await self.agent_recommend(txt_feat, "node_desc") text_results = [r["text_candidates"] for r in response][ : min(num_results, len(response)) ] return text_results - def get_object_affordance(self, txt_feat, num_results=5): + async def get_object_affordance(self, txt_feat, num_results=5): """Given a text representation of an object, return its affordance such as wearable, gettable, wiedable etc""" - response = self.agent_recommend(txt_feat, "obj_afford") + response = await self.agent_recommend(txt_feat, "obj_afford") text_results = [r["text_candidates"] for r in response][ : min(num_results, len(response)) ] return text_results - def get_character_object_relation(self, txt_feat, affordance_type, num_results=5): + async def get_character_object_relation( + self, txt_feat, affordance_type, num_results=5 + ): """Get the text based object given the character name and affordance of the object""" query_type = "char_" + affordance_type - response = self.agent_recommend(txt_feat, query_type) + response = await self.agent_recommend(txt_feat, query_type) text_results = [r["text_candidates"] for r in response][ : min(num_results, len(response)) ] return text_results - def get_element_relationship( + async def get_element_relationship( self, element_txt, element_type, @@ -973,28 +989,28 @@ def get_element_relationship( return closest_match = None if element_type == DB_TYPE_OBJ: - closest_match = self.get_similar_object(element_txt) + closest_match = await self.get_similar_object(element_txt) elif relationship_type == CHAR_CONTAINING: - closest_match = self.get_similar_room(element_txt) + closest_match = await self.get_similar_room(element_txt) elif element_type == DB_TYPE_CHAR: - closest_match = self.get_similar_character(element_txt) + closest_match = await self.get_similar_character(element_txt) elif element_type == DB_TYPE_ROOM: - closest_match = self.get_similar_room(element_txt) + closest_match = await self.get_similar_room(element_txt) banned_items.extend(self.banned_rooms) if closest_match is None: return if relationship_type == NEIGHBOR: return [ r - for r in self.get_neighbor_rooms(closest_match.db_id, num_results) + for r in await self.get_neighbor_rooms(closest_match.db_id, num_results) if r not in banned_items ] elif relationship_type == CONTAINING: - return self.get_contained_items( + return await self.get_contained_items( closest_match.db_id, element_type, num_results ) elif relationship_type == CHAR_CONTAINING: - return self.get_contained_characters(closest_match.db_id, num_results) + return await self.get_contained_characters(closest_match.db_id, num_results) def get_text_features(self, c, full=False): """Return text feature given a candidate and return and cache the text feature""" diff --git a/light/graph/builders/starspace_assisted.py b/light/graph/builders/starspace_assisted.py index 0ada92680..54bb992c0 100644 --- a/light/graph/builders/starspace_assisted.py +++ b/light/graph/builders/starspace_assisted.py @@ -4,11 +4,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import sys from parlai.core.params import ParlaiParser from parlai.core.agents import create_agent, create_agent_from_shared -from light.world.world import World +from light.world.world import World, WorldConfig from light.graph.viz.html_map import generate_html_map from light.graph.structured_graph import OOGraph from light.graph.elements.graph_nodes import GraphNode @@ -50,6 +49,8 @@ import random import copy import numpy as np +import sys +import asyncio random.seed(6) np.random.seed(6) @@ -76,6 +77,7 @@ ] +# TODO port similarly to OneRoomChatBuilder class StarspaceBuilder(DBGraphBuilder, SingleSuggestionGraphBuilder): """Builds a LIGHT map using a StarSpace model to connect everything.""" @@ -307,7 +309,7 @@ def _heuristic_name_cleaning(self, use_desc): use_desc = use_desc[4:] return use_desc - def add_object_to_graph(self, g, obj, container_node, extra_props={}): + async def add_object_to_graph(self, g, obj, container_node, extra_props={}): """Adds a particular DBObject to the given OOgraph, adding to the specific container node. Returns the newly created object node""" obj.description = obj.description.capitalize() @@ -329,7 +331,9 @@ def add_object_to_graph(self, g, obj, container_node, extra_props={}): obj.name, obj.db_id, ) - contained_objs = self.get_contained_items(obj.db_id, DB_TYPE_OBJ, 3)[1:] + contained_objs = await self.get_contained_items( + obj.db_id, DB_TYPE_OBJ, 3 + )[1:] for o in contained_objs: if self._name_not_in_graph(g, o.name): self._add_object_to_graph(g, o, obj_node) @@ -356,7 +360,7 @@ def _add_object_to_graph(self, g, obj, container_node, extra_props={}): obj_node.move_to(container_node) return obj_node - def add_new_agent_to_graph(self, g, char, room_node): + async def add_new_agent_to_graph(self, g, char, room_node): """Add the given DBcharacter to the given room (room_node) in the given OOFraph. Return the new agent node on success, and None on failure""" if "is_banned" in vars(char): @@ -391,18 +395,18 @@ def add_new_agent_to_graph(self, g, char, room_node): obj.db_id: ( "equipped" if obj.is_wearable or obj.is_weapon else "carrying" ) - for obj in self.get_contained_items(char.db_id, DB_TYPE_CHAR) + for obj in await self.get_contained_items(char.db_id, DB_TYPE_CHAR) } if self.suggestion_type == "hybrid" and len(objs) == 0: objs = { obj.db_id: ( "equipped" if obj.is_weapon or obj.is_wearable else "carrying" ) - for obj in self.get_contained_items(char.db_id, DB_TYPE_CHAR, 2) + for obj in await self.get_contained_items(char.db_id, DB_TYPE_CHAR, 2) } for obj in objs: - obj_node = self.add_object_to_graph( + obj_node = await self.add_object_to_graph( g, self.get_obj_from_id(obj), agent_node ) if obj_node is not None: @@ -410,7 +414,7 @@ def add_new_agent_to_graph(self, g, char, room_node): obj_node.set_prop("equipped", True) return agent_node - def add_random_new_agent_to_graph(self, world): + async def add_random_new_agent_to_graph(self, world): """Add a random agent to the OOGraph at a random room node""" g = world.oo_graph pos_rooms = [x for x in g.rooms.keys()] @@ -425,7 +429,7 @@ def add_random_new_agent_to_graph(self, world): if len(chars) == 0: return char = self.get_random_char() - agent = self.add_new_agent_to_graph(g, char, g.get_node(pos_room.g_id)) + agent = await self.add_new_agent_to_graph(g, char, g.get_node(pos_room.g_id)) if agent is None: return @@ -435,7 +439,7 @@ def add_random_new_agent_to_graph(self, world): ) arrival_event.execute(world) - def construct_grid(self, html_visualization_filename="/tmp/gridtmp.html"): + async def construct_grid(self, html_visualization_filename="/tmp/gridtmp.html"): """Create a new stitched together environment from an empty grid""" # Initialize for a new grid setup self.grid = {} @@ -465,7 +469,7 @@ def construct_grid(self, html_visualization_filename="/tmp/gridtmp.html"): self.stack.append(r) while len(self.stack) > 0: r1 = self.stack.pop() - self.add_exits(r1) + await self.add_exits(r1) generate_html_map(html_visualization_filename, self.grid) def new_grid_position(self, loc): @@ -506,7 +510,7 @@ def new_grid_position(self, loc): else: return None, None - def room_similarity(self, loc1, loc2): + async def room_similarity(self, loc1, loc2): """Determine how similar the starspace model thinks two given rooms are""" room_1 = self.grid[loc1] room_2 = self.grid[loc2] @@ -526,7 +530,7 @@ def room_similarity(self, loc1, loc2): msg = {"text": txt_feats, "episode_done": True} self.agents["room"].reset() self.agents["room"].observe(msg) - response = self.agents["room"].act() + response = await self.agents["room"].act() score = 100000 for i, k in enumerate(response["text_candidates"]): if k == sim_feats: @@ -534,7 +538,7 @@ def room_similarity(self, loc1, loc2): break return score - def possibly_connect_to_neighbor(self, loc1, loc2, src_dir): + async def possibly_connect_to_neighbor(self, loc1, loc2, src_dir): """Connect two rooms if the model thinks they're similar enough""" # TODO rather than connecting if two rooms are similar, perhaps # we should be connecting if a room is similar to another @@ -548,7 +552,7 @@ def possibly_connect_to_neighbor(self, loc1, loc2, src_dir): return else: # compute similarity of rooms: - sim = self.room_similarity(loc1, loc2) + sim = await self.room_similarity(loc1, loc2) if sim > 100: # if not in the top 100 most similar rooms, no connection. return @@ -558,24 +562,32 @@ def possibly_connect_to_neighbor(self, loc1, loc2, src_dir): self.grid[loc2].possible_connections[INV_DIR[src_dir] + "*"] = True self.grid[loc1].possible_connections[src_dir + "*"] = True - def possibly_connect_to_neighbors(self, loc): + async def possibly_connect_to_neighbors(self, loc): """Try to connect a room to all of its possible neighbors""" - self.possibly_connect_to_neighbor(loc, (loc[0] - 1, loc[1], loc[2]), "west") - self.possibly_connect_to_neighbor(loc, (loc[0] + 1, loc[1], loc[2]), "east") - self.possibly_connect_to_neighbor(loc, (loc[0], loc[1] - 1, loc[2]), "north") - self.possibly_connect_to_neighbor(loc, (loc[0], loc[1] + 1, loc[2]), "south") + await self.possibly_connect_to_neighbor( + loc, (loc[0] - 1, loc[1], loc[2]), "west" + ) + await self.possibly_connect_to_neighbor( + loc, (loc[0] + 1, loc[1], loc[2]), "east" + ) + await self.possibly_connect_to_neighbor( + loc, (loc[0], loc[1] - 1, loc[2]), "north" + ) + await self.possibly_connect_to_neighbor( + loc, (loc[0], loc[1] + 1, loc[2]), "south" + ) - def add_room(self, room, loc, src_loc, src_dir): + async def add_room(self, room, loc, src_loc, src_dir): """Add a room as the neighbor of the room at src_loc""" self.grid[loc] = room room.loc = loc self.grid[loc].possible_connections[INV_DIR[src_dir]] = True self.grid[src_loc].possible_connections[src_dir] = True self.banned_rooms.add(room.db_id) - self.possibly_connect_to_neighbors(loc) + await self.possibly_connect_to_neighbors(loc) self.stack.append(room) - def add_exits(self, r): + async def add_exits(self, r): """Try to add all possible exits to a given room""" if type(r) is FillerRoom: # This is needed as neighbors used to be the field in filler_room @@ -586,7 +598,7 @@ def add_exits(self, r): # Not using best match model but the starspace model for model prediction neighbors = [ e.setting - for e in self.get_neighbor_rooms( + for e in await self.get_neighbor_rooms( room_id=r.db_id, banned_rooms=self.banned_rooms ) ] @@ -594,7 +606,7 @@ def add_exits(self, r): l1, src_dir = self.new_grid_position(r.loc) if l1 is not None: exit_text = e + " " + r.category - r1 = self.get_similar_room(exit_text) + r1 = await self.get_similar_room(exit_text) if r1 is not None: if self.debug: print( @@ -624,12 +636,12 @@ def add_exits(self, r): # FillerRoom is using the db_id of the room it substitute's filler_room.neighbors = r1.get_text_edges(DB_EDGE_NEIGHBOR) # Filler rooms inherit neighbors from the real room they replace - self.add_room(filler_room, l1, r.loc, src_dir) + await self.add_room(filler_room, l1, r.loc, src_dir) else: - self.add_room(r1, l1, r.loc, src_dir) + await self.add_room(r1, l1, r.loc, src_dir) ##########For best match model################### - def get_similar_element(self, txt_feats, element_type): + async def get_similar_element(self, txt_feats, element_type): """Given a text feature, and the corresponding Database type return an DBElement of the DB type""" agent_type = None @@ -653,7 +665,7 @@ def get_similar_element(self, txt_feats, element_type): self.agents[agent_type].reset() msg = {"text": txt_feats, "episode_done": True} self.agents[agent_type].observe(msg) - response = self.agents[agent_type].act() + response = await self.agents[agent_type].act() ind = 0 while ind < len(response["text_candidates"]): key = response["text_candidates"][ind] @@ -663,24 +675,24 @@ def get_similar_element(self, txt_feats, element_type): ind = ind + 1 return None - def get_similar_room(self, txt_feats): + async def get_similar_room(self, txt_feats): """Find a similar room to the text room given based on a starspace prediction""" - return self.get_similar_element(txt_feats, DB_TYPE_ROOM) + return await self.get_similar_element(txt_feats, DB_TYPE_ROOM) - def get_similar_object(self, txt_feats): + async def get_similar_object(self, txt_feats): """Find a similar object to the text given based on starspace prediciton""" - return self.get_similar_element(txt_feats, DB_TYPE_OBJ) + return await self.get_similar_element(txt_feats, DB_TYPE_OBJ) - def get_similar_character(self, txt_feats): + async def get_similar_character(self, txt_feats): """Find a similar object to the text given based on starspace prediciton""" - return self.get_similar_element(txt_feats, DB_TYPE_CHAR) + return await self.get_similar_element(txt_feats, DB_TYPE_CHAR) ################################################### - def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): + async def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): """get prediction of neighbor room with StarSpaceModel, return DBRoom Object """ if banned_rooms is None: banned_rooms = [room_id] @@ -689,7 +701,7 @@ def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): # This is added due to the new model prediction for neighbors else: txt_feats = self.roomid_to_feats[room_id] - response = self.agent_recommend(txt_feats, "room") + response = await self.agent_recommend(txt_feats, "room") ind = 0 results = [] while len(results) < num_results: @@ -708,11 +720,11 @@ def get_neighbor_rooms(self, room_id, num_results=5, banned_rooms=None): return results return results - def get_graph(self): + async def get_graph(self): """Construct a graph using the grid created with build_world after selecting new characters and objects to place within. """ - self.construct_grid() + await self.construct_grid() g = OOGraph(self.opt) self.g = g room_ids = [] @@ -722,7 +734,7 @@ def get_graph(self): for grid_loc, pos_room in self.grid.items(): if pos_room.setting == "EMPTY": continue - pos_room.g_id = g.add_room( + pos_room.g_id = await g.add_room( pos_room.setting, { "room": True, @@ -796,7 +808,7 @@ def get_graph(self): if obj is not None: room_node = g.get_node(pos_room.g_id) no_human_suggestions_obj = False - objid = self.add_object_to_graph(g, obj, room_node) + objid = await self.add_object_to_graph(g, obj, room_node) if "db" in pos_room.in_objects: for item_id in pos_room.in_objects["db"]: obj = self.get_obj_from_id(item_id) @@ -806,7 +818,7 @@ def get_graph(self): if obj is not None: room_node = g.get_node(pos_room.g_id) no_human_suggestions_obj = False - obj_node = self.add_object_to_graph( + obj_node = await self.add_object_to_graph( g, obj, room_node, props ) if "db" in pos_room.ex_characters: @@ -821,7 +833,7 @@ def get_graph(self): if char is not None: room_node = g.get_node(pos_room.g_id) no_human_suggestions_char = False - self.add_new_agent_to_graph(g, char, room_node) + await self.add_new_agent_to_graph(g, char, room_node) cnt += 1 if "db" in pos_room.in_characters: for char_id in pos_room.in_characters["db"]: @@ -839,30 +851,30 @@ def get_graph(self): if self.suggestion_type != "human": # For model suggestions and hybrid if self.suggestion_type == "model" or no_human_suggestions_obj: - predicted_objects = self.get_contained_items( + predicted_objects = await self.get_contained_items( container_id=pos_room.db_id, container_type=DB_TYPE_ROOM ) for o in predicted_objects: if o is not None: room_node = g.get_node(pos_room.g_id) - self.add_object_to_graph(g, o, room_node) + await self.add_object_to_graph(g, o, room_node) if self.suggestion_type == "model" or no_human_suggestions_char: - predicted_characters = self.get_contained_characters( + predicted_characters = await self.get_contained_characters( room_id=pos_room.db_id, num_results=2 ) for c in predicted_characters: if c is not None: room_node = g.get_node(pos_room.g_id) - self.add_new_agent_to_graph(g, c, room_node) + await self.add_new_agent_to_graph(g, c, room_node) for room in g.rooms: g.room_id_to_loggers[room] = RoomInteractionLogger(g, room) - world = World(self.opt, self) + world = World(WorldConfig(opt=self.opt, graph_builder=self)) world.oo_graph = g return g, world - def get_contained_items( + async def get_contained_items( self, container_id, container_type, num_results=5, banned_items=[] ): """ @@ -886,7 +898,7 @@ def get_contained_items( txt_feats = self.charid_to_feats[container_id] else: txt_feats = self.get_text_features(self.get_char_from_id(container_id)) - response = self.agent_recommend(txt_feats, "object") + response = await self.agent_recommend(txt_feats, "object") ind = 0 results = [] if len(self.banned) == 0: @@ -938,7 +950,9 @@ def load_banned(self): with open("/tmp/banned.json") as json_file: self.banned = json.load(json_file) - def get_contained_characters(self, room_id, num_results=5, banned_characters=[]): + async def get_contained_characters( + self, room_id, num_results=5, banned_characters=[] + ): """ Get prediction of contained characters in given room_id from StarSpace model.""" if type(room_id) is str and room_id[0] == "f": # To check for filler_rooms, if it is filler_room, using the db_id of the room it has replaced @@ -947,7 +961,7 @@ def get_contained_characters(self, room_id, num_results=5, banned_characters=[]) txt_feats = self.roomid_to_feats[room_id] else: txt_feats = self.get_text_features(self.get_room_from_id(room_id)) - response = self.agent_recommend(txt_feats, "character") + response = await self.agent_recommend(txt_feats, "character") ind = 0 results = [] if len(self.banned) == 0: @@ -992,33 +1006,35 @@ def get_contained_characters(self, room_id, num_results=5, banned_characters=[]) ind = ind + 1 return results - def get_description(self, txt_feat, element_type, num_results=5): + async def get_description(self, txt_feat, element_type, num_results=5): """Get description of element, given the txt_feature title""" - response = self.agent_recommend(txt_feat, "node_desc") + response = await self.agent_recommend(txt_feat, "node_desc") text_results = [r["text_candidates"] for r in response][ : min(num_results, len(response)) ] return text_results - def get_object_affordance(self, txt_feat, num_results=5): + async def get_object_affordance(self, txt_feat, num_results=5): """Given a text representation of an object, return its affordance such as wearable, gettable, wiedable etc""" - response = self.agent_recommend(txt_feat, "obj_afford") + response = await self.agent_recommend(txt_feat, "obj_afford") text_results = [r["text_candidates"] for r in response][ : min(num_results, len(response)) ] return text_results - def get_character_object_relation(self, txt_feat, affordance_type, num_results=5): + async def get_character_object_relation( + self, txt_feat, affordance_type, num_results=5 + ): """Get the text based object given the character name and affordance of the object""" query_type = "char_" + affordance_type - response = self.agent_recommend(txt_feat, query_type) + response = await self.agent_recommend(txt_feat, query_type) text_results = [r["text_candidates"] for r in response][ : min(num_results, len(response)) ] return text_results - def get_element_relationship( + async def get_element_relationship( self, element_txt, element_type, @@ -1042,28 +1058,28 @@ def get_element_relationship( return closest_match = None if element_type == DB_TYPE_OBJ: - closest_match = self.get_similar_object(element_txt) + closest_match = await self.get_similar_object(element_txt) elif relationship_type == CHAR_CONTAINING: - closest_match = self.get_similar_room(element_txt) + closest_match = await self.get_similar_room(element_txt) elif element_type == DB_TYPE_CHAR: - closest_match = self.get_similar_character(element_txt) + closest_match = await self.get_similar_character(element_txt) elif element_type == DB_TYPE_ROOM: - closest_match = self.get_similar_room(element_txt) + closest_match = await self.get_similar_room(element_txt) banned_items.extend(self.banned_rooms) if closest_match is None: return if relationship_type == NEIGHBOR: return [ r - for r in self.get_neighbor_rooms(closest_match.db_id, num_results) + for r in await self.get_neighbor_rooms(closest_match.db_id, num_results) if r not in banned_items ] elif relationship_type == CONTAINING: - return self.get_contained_items( + return await self.get_contained_items( closest_match.db_id, element_type, num_results ) elif relationship_type == CHAR_CONTAINING: - return self.get_contained_characters(closest_match.db_id, num_results) + return await self.get_contained_characters(closest_match.db_id, num_results) def get_text_features(self, c, full=False): """Return text feature given a candidate and return and cache the text feature""" diff --git a/light/graph/builders/starspace_neighbor.py b/light/graph/builders/starspace_neighbor.py index 7aef0b435..0eacc087a 100644 --- a/light/graph/builders/starspace_neighbor.py +++ b/light/graph/builders/starspace_neighbor.py @@ -22,11 +22,13 @@ import random import copy import numpy as np +import asyncio random.seed(6) np.random.seed(6) +# TODO deprecate? class StarspaceNeighborBuilder(GraphBuilder): """Old builder that used starspace to connect rooms and db entries to fill them""" @@ -209,7 +211,7 @@ def add_new_agent_to_graph(self, g, char, room_id): g.set_prop(obj_id, "equipped") return agent_id - def add_random_new_agent_to_graph(self, g): + async def add_random_new_agent_to_graph(self, g): # pick a random room while True: id = random.choice(list(g.oo_graph.rooms.keys())) @@ -436,7 +438,7 @@ def get_similar_room(self, txt_feats): if len(response["text_candidates"]) <= ind: return None - def get_graph(self): + async def get_graph(self): g = OOGraph(self.opt) g.npc_models._no_npc_models = self._no_npc_models g.db = self.db diff --git a/light/graph/builders/tests/test_StarSpaceBuilder.py b/light/graph/builders/tests/test_StarSpaceBuilder.py index db1d31698..e6df7d2d5 100644 --- a/light/graph/builders/tests/test_StarSpaceBuilder.py +++ b/light/graph/builders/tests/test_StarSpaceBuilder.py @@ -8,6 +8,7 @@ from parlai.core.params import ParlaiParser import parlai.utils.misc as parlai_utils import pytest +import asyncio sys.modules["parlai.core.utils"] = parlai_utils from light.graph.structured_graph import OOGraph @@ -33,6 +34,7 @@ @pytest.mark.slow +@pytest.mark.skip(reason="Need to update starspace builders to use Model Pool") class TestStarspaceBuilder(unittest.TestCase): def setUp(self): random.seed(20) @@ -45,7 +47,7 @@ def setUp(self): self.testBuilder = StarspaceBuilder( ldb, ) - self.testGraph, _ = self.testBuilder.get_graph() + self.testGraph, _ = asyncio.run(self.testBuilder.get_graph()) def test_arg_parser(self): parser = ParlaiParser() diff --git a/light/graph/builders/tests/test_StarspaceNeighborBuilder.py b/light/graph/builders/tests/test_StarspaceNeighborBuilder.py index 54d69dfc3..ae4feafa5 100644 --- a/light/graph/builders/tests/test_StarspaceNeighborBuilder.py +++ b/light/graph/builders/tests/test_StarspaceNeighborBuilder.py @@ -11,6 +11,7 @@ @pytest.mark.slow +@pytest.mark.skip(reason="Need to update starspace builders to use Model Pool") class TestStarspaceNeighborBuilder(unittest.TestCase): def setUp(self): parser = ParlaiParser() diff --git a/light/graph/builders/tests/test_user.py b/light/graph/builders/tests/test_user.py index ca51e9836..26a58de2d 100644 --- a/light/graph/builders/tests/test_user.py +++ b/light/graph/builders/tests/test_user.py @@ -10,6 +10,7 @@ import os import pickle import random +import asyncio from light.graph.builders.base import DBGraphBuilder from light.graph.builders.user_world_builder import UserWorldBuilder @@ -103,7 +104,7 @@ def setUp(self): self.graphBuilder = UserWorldBuilder( self.ldb, self.world_id, self.player_id, True, opt ) - self.testGraph, self.testWorld = self.graphBuilder.get_graph() + self.testGraph, self.testWorld = asyncio.run(self.graphBuilder.get_graph()) def tearDown(self): shutil.rmtree(self.data_dir) @@ -115,7 +116,7 @@ def test_builder_adds_random_agent_to_graph_adds_another_to_some_room(self): dbid_to_g = {val: key for key, val in self.graphBuilder.roomid_to_db.items()} gRoomId = dbid_to_g[self.roomID] gRoomId2 = dbid_to_g[self.roomID2] - self.graphBuilder.add_random_new_agent_to_graph(self.testWorld) + asyncio.run(self.graphBuilder.add_random_new_agent_to_graph(self.testWorld)) self.assertEqual(len(self.testGraph.agents), 3) contained_room_1 = len(self.testGraph.get_node(gRoomId).contained_nodes) contained_room_2 = len(self.testGraph.get_node(gRoomId2).contained_nodes) diff --git a/light/graph/builders/tutorial_builder.py b/light/graph/builders/tutorial_builder.py index c934fb6bc..74d67dddc 100644 --- a/light/graph/builders/tutorial_builder.py +++ b/light/graph/builders/tutorial_builder.py @@ -1,12 +1,15 @@ +#!/usr/bin/env python3 + # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from light.graph.builders.map_json_builder import MapJsonBuilder -from light.world.world import World +from light.world.world import World, WorldConfig from light.world.purgatory import TutorialPurgatory from light.graph.structured_graph import OOGraph +import asyncio from typing import Dict, Optional, Any, TYPE_CHECKING if TYPE_CHECKING: @@ -19,12 +22,7 @@ class TutorialWorldBuilder(MapJsonBuilder): made to run tutorials. Generally like a single room builder. """ - def __init__(self, db: "LIGHTDatabase", opt: Dict[str, Any] = None): - """Store initialization options""" - self.db = db - self.opt = opt if opt is not None else {} - - def add_random_new_agent_to_graph(self, target_graph): + async def add_random_new_agent_to_graph(self, target_graph): """Add an agent to the graph in a random room somewhere""" raise Exception("Agents should not be added to tutorials!") @@ -32,7 +30,9 @@ def build_new_graph(self): """ Create a tutorial graph, not from file """ - graph = OOGraph(self.opt) + opt = self.opt.copy() + opt["tutorial"] = True + graph = OOGraph(opt) room_node = graph.add_room( "Impossible Tavern", { @@ -67,7 +67,8 @@ def build_new_graph(self): "name_prefix": "", "persona": "You are, well, yourself... a wandering soul who has yet to " "become someone in the full LIGHT world. Perhaps you may be " - "granted admission by the dungeon master? ", + "granted admission by the dungeon master?\nYour Mission: Find out how to " + "get to LIGHT, then get in to play.", "mission": "Find out how to get to LIGHT, then get in to play.", }, ) @@ -141,16 +142,20 @@ def build_new_graph(self): None, "You feel as if this portal leads somewhere unusual.", ) + dungeon_master_node.block(agent_node) return graph - def get_graph(self): + async def get_graph(self, world_config: Optional[WorldConfig] = None): """Create and return a tutorial graph""" - if self.opt.get("load_map", None) is not None: - graph, _ = super().get_graph() + if self.opt.get("load_tutorial_map", None) is not None: + graph, _ = await super().get_graph() else: graph = self.build_new_graph() + opt = self.opt.copy() + opt["tutorial"] = True + world = World(self._get_attached_config(world_config, opt)) - world = World(self.opt, self) world.oo_graph = graph + # Force the logging mode to tutorial PRE_LAUNCH_TUTORIAL world.purgatory = TutorialPurgatory(world) return graph, world diff --git a/light/graph/builders/user_world_builder.py b/light/graph/builders/user_world_builder.py index 2d563bfea..c15cce4ae 100644 --- a/light/graph/builders/user_world_builder.py +++ b/light/graph/builders/user_world_builder.py @@ -5,13 +5,14 @@ # LICENSE file in the root directory of this source tree. from parlai.core.params import ParlaiParser import random +import asyncio from light.graph.structured_graph import OOGraph from light.graph.events.graph_events import ArriveEvent from light.graph.builders.base import ( DBGraphBuilder, POSSIBLE_NEW_ENTRANCES, ) -from light.world.world import World +from light.world.world import World, WorldConfig # TODO: Refactor common functionality between builders! @@ -123,7 +124,7 @@ def add_new_agent_to_graph(self, g, char, pos_room): obj_node.set_prop("equipped", True) return agent - def add_random_new_agent_to_graph(self, world): + async def add_random_new_agent_to_graph(self, world): # pick a random room g = world.oo_graph id = random.choice(list(g.rooms.keys())) @@ -139,7 +140,7 @@ def add_random_new_agent_to_graph(self, world): ) arrival_event.execute(world) - def get_graph(self): + async def get_graph(self): """Return an OOGraph built by this builder""" g = OOGraph(self.opt) self.g = g @@ -159,7 +160,7 @@ def get_graph(self): self.add_nodes(g, resources, db_to_g, node_to_g) self.add_edges(g, edge_list, node_to_g) - world = World(self.opt, self) + world = World(WorldConfig(opt=self.opt, graph_builder=self)) world.oo_graph = g return g, world diff --git a/light/graph/elements/graph_nodes.py b/light/graph/elements/graph_nodes.py index 36c0ea211..56cf201b7 100644 --- a/light/graph/elements/graph_nodes.py +++ b/light/graph/elements/graph_nodes.py @@ -589,6 +589,7 @@ def __init__(self, node_id, name, props=None, db_id=None): # Flag to resolve when a death event is in the stack, but possibly not processed self._dying = False self.is_player = self._props.get("is_player", self._props.get("_human", False)) + self.user_id = self._props.get("user_id", None) self.usually_npc = self._props.get("usually_npc", False) self.pacifist = self._props.get("pacifist", False) self.tags = self._props.get("tags", self.DEFAULT_TAGS) diff --git a/light/graph/events/base.py b/light/graph/events/base.py index 6a1e33d73..90750ebd3 100644 --- a/light/graph/events/base.py +++ b/light/graph/events/base.py @@ -88,7 +88,7 @@ def __init__( """ if event_id is None: event_id = str(uuid4()) - self.executed: bool = False # type: ignore + self.executed: bool = False self.actor = actor self.room = actor.get_room() self.target_nodes = [] if target_nodes is None else target_nodes diff --git a/light/graph/events/graph_events.py b/light/graph/events/graph_events.py index 0f1edde1a..5b53b0205 100644 --- a/light/graph/events/graph_events.py +++ b/light/graph/events/graph_events.py @@ -30,16 +30,8 @@ from light.graph.events.safety import SafetyClassifier import math -safety_classifier = None - - -def init_safety_classifier(datapath): - global safety_classifier - if datapath is not None and len(datapath) > 0: - safety_classifier = SafetyClassifier(datapath, True) - - if TYPE_CHECKING: + from light.registry.model_pool import ModelPool from light.world.world import World from light.graph.structured_graph import OOGraph @@ -94,6 +86,7 @@ def __init__( target_nodes: Optional[List[GraphNode]] = None, text_content: Optional[str] = None, event_id: Optional[str] = None, + safe: Optional[bool] = True, ): super().__init__( actor, @@ -104,17 +97,9 @@ def __init__( # Give opportunity to skip the safety after initialization # for debug reasons self.skip_safety = False - self.safe = None + self.safe = safe def is_dialogue_safe(self, text): - if safety_classifier is None: - self.safe = True - return True - - if safety_classifier.is_safe(text): - self.safe = True - else: - self.safe = False return self.safe @@ -663,9 +648,9 @@ def execute(self, world: "World") -> List[GraphEvent]: health = self.actor.health eps = self.actor.movement_energy_cost if health > eps: - health_text = world.health(self.actor.node_id) + health_text = world.view.get_health_text_for(self.actor.node_id) self.actor.health = max(0, health - eps) - new_health_text = world.health(self.actor.node_id) + new_health_text = world.view.get_health_text_for(self.actor.node_id) if health_text != new_health_text: HealthEvent(self.actor, text_content="HealthOnMoveEvent").execute(world) @@ -1135,7 +1120,7 @@ def execute(self, world: "World") -> List[GraphEvent]: # Trigger the actual death world.oo_graph.agent_die(self.actor) - # world.purgatory.clear_soul(self.actor) todo - clear soul only after message queue consumed + # await world.purgatory.clear_soul(self.actor) todo - clear soul only after message queue consumed return [] @proper_caps_wrapper @@ -2724,9 +2709,9 @@ def execute(self, world: "World") -> List[GraphEvent]: world.broadcast_to_room(self) - health_text = world.health(self.actor.node_id) + health_text = world.view.get_health_text_for(self.actor.node_id) self.actor.health = max(self.actor.health + fe, 0) - new_health_text = world.health(self.actor.node_id) + new_health_text = world.view.get_health_text_for(self.actor.node_id) if self.actor.health <= 0: DeathEvent(self.actor).execute(world) elif health_text != new_health_text: @@ -3137,7 +3122,7 @@ def actor_has_no_recent_action(last_time_acted, current_time): class ExamineEvent(GraphEvent): """Handles displaying examine/extra text for a graph node""" - NAMES = ["examine", "ex"] + NAMES = ["examine", "ex", "inspect"] def _get_target_description(self, world: "World") -> str: """Get the examine description for the given target""" @@ -3599,7 +3584,7 @@ def execute(self, world: "World") -> List[GraphEvent]: """ assert not self.executed self.__actor_name = self.actor.get_prefix_view() - self.__health_text = world.health(self.actor.node_id) + self.__health_text = world.view.get_health_text_for(self.actor.node_id) to_agents = [self.actor] for t in self.target_nodes: to_agents.append(t) diff --git a/light/graph/events/magic.py b/light/graph/events/magic.py index efb145398..c750ba4ee 100644 --- a/light/graph/events/magic.py +++ b/light/graph/events/magic.py @@ -113,7 +113,7 @@ def creo(agent, event): # TODO: later maybe make a proprty: hasattr(node, 'magical_create') and node.magical_create: if node.name == "orb of creation": can_cast = True - if agent.world.opt.get("allow_save_world", False): + if agent.world._opt.get("allow_save_world", False): can_cast = True if not can_cast: return @@ -167,7 +167,7 @@ def teleport(agent, event): for node in agent.target_node.get_contents(): if node.name == "dark emerald ring": can_cast = True - if agent.world.opt.get("allow_save_world", False): + if agent.world._opt.get("allow_save_world", False): can_cast = True if not can_cast: return @@ -211,7 +211,7 @@ def save(agent, event): def check_if_cast_magic_from_event(agent, event): event_name = event.__class__.__name__ if event_name == "SayEvent" and event.actor == agent.target_node: - if event.text_content == "creoservo" and agent.world.opt.get( + if event.text_content == "creoservo" and agent.world._opt.get( "allow_save_world", False ): save(agent, event) diff --git a/light/graph/events/safety.py b/light/graph/events/safety.py index d6fffc670..b50a96e1c 100644 --- a/light/graph/events/safety.py +++ b/light/graph/events/safety.py @@ -5,61 +5,48 @@ # LICENSE file in the root directory of this source tree. from parlai.utils.safety import OffensiveStringMatcher -from parlai.core.agents import create_agent -from parlai.core.params import ParlaiParser from parlai.agents.transformer.transformer import TransformerClassifierAgent +from parlai.utils.typing import TShared +from parlai.tasks.dialogue_safety.agents import OK_CLASS, NOT_OK_CLASS +from light.registry.model_pool import ModelTypeName -try: - from parlai_internal.agents.safety_wrapper.multiturn_safety import ( - MultiturnOffensiveLanguageClassifier, - ) -except: +from typing import Optional, TYPE_CHECKING - class MultiturnOffensiveLanguageClassifier: - # Temporary until using public safety - pass - -class AdversarialOffensiveLanguageClassifier(MultiturnOffensiveLanguageClassifier): - """ - Load model trained to detect offensive language in the context of multi- turn - dialogue utterances. - This model was trained to be robust to adversarial examples created by humans. See - for more information. - """ - - def _create_safety_model(self): - parser = ParlaiParser(False, False) - TransformerClassifierAgent.add_cmdline_args(parser) - parser.set_params( - model_file="zoo:bot_adversarial_dialogue/multi_turn/model", - print_scores=True, - split_lines=True, - model_parallel=False, - threshold=0.999, - bs=1, - ) - safety_opt = parser.parse_args([]) - return create_agent(safety_opt, requireModelExists=True) +if TYPE_CHECKING: + from light.registry.model_pool import ModelPool class SafetyClassifier: - def __init__(self, datapath, use_model=False): - if datapath != "": + def __init__(self, datapath: Optional[str], model_pool: "ModelPool"): + self.classes = {OK_CLASS: False, NOT_OK_CLASS: True} + if datapath is not None and datapath != "": self.string_matcher = OffensiveStringMatcher(datapath) else: self.string_matcher = None - if use_model: - self.classifier = AdversarialOffensiveLanguageClassifier() + if model_pool.has_model(ModelTypeName.SAFETY): + self.classifier = model_pool.get_model(ModelTypeName.SAFETY) else: self.classifier = None - def is_safe(self, text): + async def contains_offensive_language(self, text): + """ + Returns the probability that a message is safe according to the classifier. + """ + act = {"text": text, "episode_done": True} + self.classifier.observe(act) + response_act = await self.classifier.act() + response = response_act["text"] + pred_class, prob = [x.split(": ")[-1] for x in response.split("\n")] + pred_not_ok = self.classes[pred_class] # check whether classified as NOT OK + prob = float(prob) # cast string to float + return pred_not_ok, prob + + async def is_safe(self, text: str): if self.string_matcher is not None: if text in self.string_matcher: return False if self.classifier is not None: - print(text) - if text in self.classifier: - return False + not_ok, _prob = await self.contains_offensive_language(text) + return not not_ok return True diff --git a/light/graph/structured_graph.py b/light/graph/structured_graph.py index 0521e2349..7c203f3f2 100644 --- a/light/graph/structured_graph.py +++ b/light/graph/structured_graph.py @@ -14,6 +14,7 @@ GraphVoidNode, GraphEdge, ) +from typing import Optional, Dict, Any from light.world.utils.json_utils import GraphEncoder from light.world.content_loggers import RoomInteractionLogger @@ -24,6 +25,8 @@ class OOGraph(object): """ def __init__(self, opt=None): + if opt is None: + opt = {} self.objects = {} self.agents = {} self.rooms = {} @@ -35,6 +38,8 @@ def __init__(self, opt=None): self._deleted_nodes = {} self.dead_nodes = {} self._opt = opt + self.title = opt.get("title", "untitled") + self.db_id: Optional[str] = opt.get("db_id") @staticmethod def from_graph(graph, start_location=None): @@ -467,6 +472,7 @@ def to_json(self): "agents": sorted(list(self.agents.keys())), "rooms": sorted(list(self.rooms.keys())), "nodes": self.all_nodes, + "title": self.title, } return json.dumps(dicts, cls=GraphEncoder, sort_keys=True, indent=4) @@ -504,6 +510,7 @@ def to_json_rv(self, room_id): "nodes": {node.node_id: node for node in nodes}, "objects": sorted(objects), "rooms": sorted(rooms), + "title": self.title, } return json.dumps(dicts, cls=GraphEncoder, sort_keys=True, indent=4) @@ -524,9 +531,12 @@ def get_contained_in_room(room_node): return contained_nodes @staticmethod - def from_json(input_json: str): + def from_json(input_json: str, opt: Optional[Dict[str, Any]] = None): dict_format = json.loads(input_json) - oo_graph = OOGraph() + opt = opt if opt is not None else {} + if dict_format.get("title") is not None: + opt["title"] = dict_format["title"] + oo_graph = OOGraph(opt) object_ids = set(dict_format["objects"]) agent_ids = set(dict_format["agents"]) room_ids = set(dict_format["rooms"]) diff --git a/light/graph/tests/test_events.py b/light/graph/tests/test_events.py index 6b60a02b5..8a9896cd7 100644 --- a/light/graph/tests/test_events.py +++ b/light/graph/tests/test_events.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree.abs +# LICENSE file in the root directory of this source tree. import unittest import json @@ -51,7 +51,7 @@ GraphObject, GraphAgent, ) -from light.world.world import World +from light.world.world import World, WorldConfig from typing import Tuple, List, Type, Optional @@ -86,7 +86,7 @@ def setUp(self) -> None: """ Setup should put together any requirements for starting the database for a test. """ - self.world = World({}, None) + self.world = World(WorldConfig()) self.reset_world() def reset_world(self) -> None: diff --git a/light/graph/tests/test_nodes.py b/light/graph/tests/test_nodes.py index 321497578..c44da4cbd 100644 --- a/light/graph/tests/test_nodes.py +++ b/light/graph/tests/test_nodes.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree.abs +# LICENSE file in the root directory of this source tree. import unittest diff --git a/light/modeling/agents/quests/rl/base/process/format.py b/light/modeling/agents/quests/rl/base/process/format.py index beeea303e..91b607d27 100644 --- a/light/modeling/agents/quests/rl/base/process/format.py +++ b/light/modeling/agents/quests/rl/base/process/format.py @@ -9,7 +9,7 @@ from glob import glob import json from light.graph.structured_graph import OOGraph -from light.world.world import World +from light.world.world import World, WorldConfig from light.constants import LIGHT_DATAPATH import copy @@ -216,7 +216,7 @@ def sequence(dialogue): graph_json = quest_file["graph"] g = OOGraph.from_json(graph_json) - world = World({}, None) + world = World(WorldConfig()) world.oo_graph = g # print(world.get_possible_actions(human['id'], USE_ACTIONS)) diff --git a/light/modeling/agents/quests/rl/shared/models/transformer.py b/light/modeling/agents/quests/rl/shared/models/transformer.py index 812f0f713..f2903f91e 100644 --- a/light/modeling/agents/quests/rl/shared/models/transformer.py +++ b/light/modeling/agents/quests/rl/shared/models/transformer.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 + # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - """ LIGHT Transformer Agents. """ diff --git a/light/modeling/agents/quests/rl/shared/process/conversation.py b/light/modeling/agents/quests/rl/shared/process/conversation.py index c94e0908d..aaf0d0ef8 100644 --- a/light/modeling/agents/quests/rl/shared/process/conversation.py +++ b/light/modeling/agents/quests/rl/shared/process/conversation.py @@ -7,7 +7,7 @@ from collections import namedtuple from light.graph.structured_graph import OOGraph -from light.world.world import World +from light.world.world import World, WorldConfig Turn = namedtuple( "Turn", @@ -99,7 +99,7 @@ def parse(self, conv_dict): """ # Get Graph g = OOGraph.from_json(conv_dict["graph_json"]) - world = World({}, None) + world = World(WorldConfig()) world.oo_graph = g self.graph = world diff --git a/light/modeling/agents/quests/rl/switch/environments/quest.py b/light/modeling/agents/quests/rl/switch/environments/quest.py index 9926f2cb4..827b38ac4 100644 --- a/light/modeling/agents/quests/rl/switch/environments/quest.py +++ b/light/modeling/agents/quests/rl/switch/environments/quest.py @@ -8,7 +8,7 @@ import json from glob import glob from copy import deepcopy, copy -from light.world.world import World +from light.world.world import World, WorldConfig from light.graph.structured_graph import OOGraph @@ -86,7 +86,7 @@ def reset(self): self.score = 0 self.real_speech_score = 0 self.real_act_score = 0 - # self.world = World({}, None) + # self.world = World(WorldConfig()) # g = OOGraph.from_json(deepcopy(self.graph_data)) # self.world.oo_graph = g self.world = deepcopy(self.graph_data) diff --git a/light/modeling/agents/quests/rl/switch/process/format.py b/light/modeling/agents/quests/rl/switch/process/format.py index 49e2a056b..144aaebb2 100644 --- a/light/modeling/agents/quests/rl/switch/process/format.py +++ b/light/modeling/agents/quests/rl/switch/process/format.py @@ -9,7 +9,7 @@ from glob import glob import json from light.graph.structured_graph import OOGraph -from light.world.world import World +from light.world.world import World, WorldConfig from light.constants import LIGHT_DATAPATH import copy @@ -236,7 +236,7 @@ def sequence(dialogue): graph_json = quest_file["graph"] g = OOGraph.from_json(graph_json) - world = World({}, None) + world = World(WorldConfig()) world.oo_graph = g # print(world.get_possible_actions(human['id'], USE_ACTIONS)) diff --git a/light/registry/model_pool.py b/light/registry/model_pool.py index 9209c74b7..fad236176 100644 --- a/light/registry/model_pool.py +++ b/light/registry/model_pool.py @@ -6,28 +6,59 @@ from dataclasses import dataclass, field from omegaconf import MISSING, DictConfig +import asyncio +import enum -from light.registry.models.parlai_model import ParlAIModelConfig, ParlAIModelLoader +from light.registry.parlai_model import ParlAIModelConfig, ParlAIModelLoader +from light.registry.parlai_remote_model import ( + ParlAIRemoteModelConfig, + ParlAIRemoteModelLoader, +) from light.registry.models.acting_score_model import ( ParlAIPolyencoderActingScoreModelConfig, ParlAIPolyencoderActingScoreModelLoader, ) +from light.registry.models.starspace_model import ( + MapStarspaceModelConfig, + MapStarspaceModelLoader, +) from parlai.core.agents import Agent -from typing import List, Any, Dict, Optional +from typing import List, Any, Union, Dict, Optional, Type + + +# We should make a base ModelLoader class +ModelLoaderClass = Union[Type[ParlAIModelLoader], Type[ParlAIRemoteModelLoader]] +ModelLoader = Union[ParlAIModelLoader, ParlAIRemoteModelLoader] +ModelConfig = Union[ParlAIModelConfig, ParlAIRemoteModelConfig] -# At the moment all models are ParlAIModelLoaders. May change as we make more models -ALL_LOADERS: Dict[str, ParlAIModelLoader] = { +ALL_LOADERS: Dict[str, ModelLoaderClass] = { ParlAIModelConfig._loader: ParlAIModelLoader, ParlAIPolyencoderActingScoreModelConfig._loader: ParlAIPolyencoderActingScoreModelLoader, + MapStarspaceModelConfig._loader: MapStarspaceModelLoader, + ParlAIRemoteModelConfig._loader: ParlAIRemoteModelLoader, } +class ModelTypeName(enum.Enum): + """Common model names of use in LIGHT, for use in register_model""" + + SAFETY = "safety" # Models used to evaluate dialog or env safety + DIALOG = "dialog" # Models for generating dialogue + SCORING = "role_playing_score" # Models to score player utterances + ACTION = "action" # Models used by model agents for generating actions + GENERIC_ACTS = "generic_action" # Models to select a next action from cands + PARSER = "parser" # Models to parse raw text to in-game actions + SERVED = "served" # Any generic served model (for ModelServer) + + class ModelPool: def __init__(self): self._model_loaders = {} - def register_model(self, config: DictConfig, model_names: List[str]) -> None: + async def register_model_async( + self, config: Union[DictConfig, ModelConfig], model_names: List[ModelTypeName] + ) -> None: """ Takes the given config, loads the model, and stores it in the registry under the given names. @@ -38,16 +69,33 @@ def register_model(self, config: DictConfig, model_names: List[str]) -> None: f"Trying to load a model with non-existent loader {config._loader}" ) loader = loader_class(config) + await loader.force_load() for model_name in model_names: - self._model_loaders[model_name] = loader + self._model_loaders[model_name.value] = loader + + def register_model( + self, config: Union[DictConfig, ModelConfig], model_names: List[str] + ) -> None: + """ + Syncronous model registration for server and script setups + """ + return asyncio.run(self.register_model_async(config, model_names)) + + def has_model(self, model_name: ModelTypeName) -> bool: + """ + Determine if there's a model registered for the given name. + """ + return model_name.value in self._model_loaders - def get_model(self, model_name: str, overrides: Optional[Dict[str, Any]]) -> Agent: + def get_model( + self, model_name: ModelTypeName, overrides: Optional[Dict[str, Any]] = None + ) -> Agent: """ Get a copy of the model stored in the given name If overrides are provided, pass those to the loader as well """ - loader = self._model_loaders.get(model_name) + loader = self._model_loaders.get(model_name.value) if loader is None: raise AssertionError( f"No models registered for requested name {model_name}" diff --git a/light/registry/models/acting_score_model.py b/light/registry/models/acting_score_model.py index 8208237b3..00e3e787e 100644 --- a/light/registry/models/acting_score_model.py +++ b/light/registry/models/acting_score_model.py @@ -4,11 +4,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +import asyncio from dataclasses import dataclass, field from parlai.core.agents import Agent +from parlai.core.message import Message +import types + +from light.registry.parlai_model import ParlAIModelConfig, ParlAIModelLoader -from light.registry.models.parlai_model import ParlAIModelConfig, ParlAIModelLoader +SCORE_INDS = [1000, 2000, 5000, 10000] @dataclass @@ -32,4 +36,47 @@ def before_return_model(self, model) -> Agent: model.actingscore = True # override eval step here model.eval_step = model.eval_step_scoresonly + + # Override act and observe so that we can catch from remote calls + # as well + old_act = model.act + old_observe = model.observe + + def new_observe(model_self, message: Message): + model_self._last_observe = message + old_observe(message) + + model.observe = types.MethodType(new_observe, model) + model._last_observe = Message({}) + + old_act = model.act + + def new_act(model_self): + if model_self._last_observe.get("label_candidates"): + # Evalling just one cand + model_self.opt["candidates"] = "inline" + model_self.candidates = "inline" + model_self.opt["eval_candidates"] = "inline" + model_self.eval_candidates = "inline" + model_self.reset() + old_observe(model_self._last_observe) # re-observe to vectorize + act = old_act() + scores = model_self.scores + act["scores"] = scores[0].tolist() + else: + # Evalling against the base candidates + model_self.opt["candidates"] = "fixed" + model_self.candidates = "fixed" + model_self.opt["eval_candidates"] = "fixed" + model_self.eval_candidates = "fixed" + # model_self.reset() + act = old_act() + list_scores = sorted(model_self.scores[0].tolist()) + list_scores.reverse() + scores = [list_scores[i] for i in SCORE_INDS] + act["scores"] = scores + return act + + model.act = types.MethodType(new_act, model) + return model diff --git a/light/registry/models/config/baseline_adversarial_safety.opt b/light/registry/models/config/baseline_adversarial_safety.opt new file mode 100644 index 000000000..80c542236 --- /dev/null +++ b/light/registry/models/config/baseline_adversarial_safety.opt @@ -0,0 +1,9 @@ +{ + "model": "transformer/classifier", + "model_file": "zoo:bot_adversarial_dialogue/multi_turn/model", + "print_scores": true, + "split_lines": true, + "model_parallel": false, + "threshold": 0.999, + "batchsize": 1 +} diff --git a/light/registry/models/config/baseline_generative.opt b/light/registry/models/config/baseline_generative.opt new file mode 100644 index 000000000..1dfe50604 --- /dev/null +++ b/light/registry/models/config/baseline_generative.opt @@ -0,0 +1,12 @@ +{ + "model": "transformer/generator", + "model_file": "$LIGHT_MODEL_ROOT/dialog/baseline_gen/model", + "inference": "beam", + "datatype": "valid", + "beam_context_block_ngram": 3, + "beam_block_ngram": 3, + "beam_size": 10, + "beam_min_length": 20, + "skip_generation": false, + "interactive_mode": true +} diff --git a/light/registry/models/config/baseline_generative_reranked.opt b/light/registry/models/config/baseline_generative_reranked.opt new file mode 100644 index 000000000..cedff530e --- /dev/null +++ b/light/registry/models/config/baseline_generative_reranked.opt @@ -0,0 +1,12 @@ +{ + "model": "internal:light_whoami/generative_rerank", + "predictor_model_file": "$LIGHT_MODEL_ROOT/dialog/rerank/model", + "model_file": "$LIGHT_MODEL_ROOT/dialog/baseline/model", + "inference": "delayedbeam", + "datatype": "valid", + "beam_context_block_ngram": 3, + "beam_block_ngram": 3, + "beam_size": 10, + "beam_min_length": 20, + "interactive_mode": true +} diff --git a/light/registry/models/config/baseline_generative_reranker.opt b/light/registry/models/config/baseline_generative_reranker.opt new file mode 100644 index 000000000..21de259f8 --- /dev/null +++ b/light/registry/models/config/baseline_generative_reranker.opt @@ -0,0 +1,13 @@ +{ + "model": "projects.light_whoami.agents.expanded_attention:ExpandedDecoderAttentionAndPacerAgent", + "predictor_model_file": "zoo:light_whoami/rpa_reranker/model", + "model_file": "zoo:light_whoami/profile_expanded_attention_128/model", + "inference": "beam", + "datatype": "valid", + "beam_context_block_ngram": 3, + "beam_block_ngram": 3, + "beam_size": 10, + "beam_min_length": 20, + "skip_generation": false, + "interactive_mode": true +} diff --git a/light/registry/models/config/baseline_generative_with_start.opt b/light/registry/models/config/baseline_generative_with_start.opt new file mode 100644 index 000000000..cf519f73d --- /dev/null +++ b/light/registry/models/config/baseline_generative_with_start.opt @@ -0,0 +1,11 @@ +{ + "model": "transformer/generator", + "model_file": "$LIGHT_MODEL_ROOT/dialog/baseline_gen_start/model.checkpoint", + "inference": "beam", + "datatype": "valid", + "beam_context_block_ngram": 3, + "beam_block_ngram": 3, + "beam_size": 10, + "beam_min_length": 20, + "interactive_mode": true +} diff --git a/light/registry/models/config/baseline_main_act_model.opt b/light/registry/models/config/baseline_main_act_model.opt new file mode 100644 index 000000000..931724448 --- /dev/null +++ b/light/registry/models/config/baseline_main_act_model.opt @@ -0,0 +1,6 @@ +{ + "model_file": "$LIGHT_MODEL_ROOT/acting/baseline/model", + "eval_candidates": "inline", + "ignore_bad_candidates": true, + "interactive_mode": true +} diff --git a/light/registry/models/config/baseline_parser.opt b/light/registry/models/config/baseline_parser.opt new file mode 100644 index 000000000..8ff50413f --- /dev/null +++ b/light/registry/models/config/baseline_parser.opt @@ -0,0 +1,5 @@ +{ + "model_file": "$LIGHT_MODEL_ROOT/parser/baseline/model", + "interactive_candidates": "inline", + "interactive_mode": true +} diff --git a/light/registry/models/config/baseline_roleplaying_scorer.opt b/light/registry/models/config/baseline_roleplaying_scorer.opt new file mode 100644 index 000000000..41eae9b52 --- /dev/null +++ b/light/registry/models/config/baseline_roleplaying_scorer.opt @@ -0,0 +1,12 @@ +{ + "model_file": "$LIGHT_MODEL_ROOT/scoring/baseline/model", + "candidates": "fixed", + "eval_candidates": "fixed", + "use_reply": "none", + "interactive_mode": true, + "fixed_candidates_path": "$LIGHT_MODEL_ROOT/cands/speech_train_cands_extra_filtered_more.txt", + "partner_trainset": "$LIGHT_MODEL_ROOT/cands/agent_to_utterance_partner_trainset.txt", + "trainset": "$LIGHT_MODEL_ROOT/cands/agent_to_utterance_trainset.txt", + "baseforms": "$LIGHT_MODEL_ROOT/cands/baseforms.json", + "boring_alpha": 0 +} diff --git a/light/registry/models/config/baseline_starspace.opt b/light/registry/models/config/baseline_starspace.opt new file mode 100644 index 000000000..9e5699685 --- /dev/null +++ b/light/registry/models/config/baseline_starspace.opt @@ -0,0 +1,7 @@ +{ + "model": "starspace", + "model_file": "$LIGHT_MODEL_ROOT/starspace/angela_starspace/model4" + "eval_candidates": "inline", + "ignore_bad_candidates": true, + "interactive_mode": true +} diff --git a/light/registry/models/config/generic_act_model.opt b/light/registry/models/config/generic_act_model.opt new file mode 100644 index 000000000..04d66c8f7 --- /dev/null +++ b/light/registry/models/config/generic_act_model.opt @@ -0,0 +1,6 @@ +{ + "model_file": "$LIGHT_MODEL_ROOT/acting/baseline_generic/model", + "eval_candidates": "inline", + "ignore_bad_candidates": true, + "interactive_mode": true +} diff --git a/light/registry/models/starspace_model.py b/light/registry/models/starspace_model.py new file mode 100644 index 000000000..592f08db4 --- /dev/null +++ b/light/registry/models/starspace_model.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass, field +from parlai.core.agents import Agent +import os + +from typing import Optional, Dict, Any + +from light.registry.parlai_model import ParlAIModelConfig, ParlAIModelLoader + + +@dataclass +class MapStarspaceModelConfig(ParlAIModelConfig): + _loader: str = "MapStarspaceLoader" + resource_path: str = field( + default=os.path.expanduser("~/ParlAI/data/light_maps/"), + metadata={"help": ("Path to the LIGHT maps data")}, + ) + + +class MapStarspaceModelLoader(ParlAIModelLoader): + """ + Takes in the configuration for a ParlAI model, and provides options + for being able to load that model one or multiple times (via sharing). + + We do some special post-setup on the acting score model. Ideally this + could be done as a special opt in the agent itself, but for now it's here. + """ + + def get_model(self, overrides: Optional[Dict[str, Any]] = None) -> Agent: + """Get a copy of the model""" + use_shared = self._shared + if use_shared is not None: + opt = deepcopy(use_shared["opt"]) + opt.update(overrides) + use_shared["opt"] = opt + + if opt["target_type"] == "room": + opt["fixed_candidates_file"] = os.path.join( + self.config.resource_path, "/room_full_cands.txt" + ) + elif opt["target_type"] == "agent": + opt["fixed_candidates_file"] = os.path.join( + self.config.resource_path, "/character_full_cands.txt" + ) + elif opt["target_type"] == "object": + opt["fixed_candidates_file"] = os.path.join( + self.config.resource_path, "/object_full_cands.txt" + ) + else: + raise NotImplementedError( + f"Given starspace target type {opt['target_type']} not implemented" + ) + + opt["override"]["fixed_candidates_file"] = opt["fixed_candidates_file"] + model = create_agent_from_shared(use_shared) + return self.before_return_model(model) diff --git a/light/registry/parlai_model.py b/light/registry/parlai_model.py new file mode 100644 index 000000000..ace02f44a --- /dev/null +++ b/light/registry/parlai_model.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +By design changing options can happen in a few places with the +following priority order: + +1. Options provided by a specific `ModelPool.get_model` call. +2. Options provided on the command line via `+.overrides.=value` +3. Options provided in the overrides of a particular hydra config yaml file +4. Options specified in the provided (hydra) `.opt_file` + (either from the yaml, or provided on the command line) +5. Options specified in the `.opt file` + +Hydra configures the process of making #2 override #3, but the rest +flow due to the semantics of ParlAI and the ParlAIModelLoader implementation here. +""" + +from dataclasses import dataclass, field +from omegaconf import MISSING, DictConfig + +from parlai.core.agents import Agent, create_agent, create_agent_from_shared +from parlai.core.message import Message +from parlai.core.opt import Opt +from parlai.core.params import ParlaiParser +from copy import deepcopy +import os +import asyncio + +from typing import List, Any, Dict, Optional + + +CONTEXT_FILL_COUNT = 200 +INIT_CONTEXT = """ +_setting_name weathered shack, Abandoned +_setting_desc A weathered shack with a roof made of old broken tiles sits in the middle of the forest. The wood is starting to split and the shack appears as if it will crumble at any moment. +_partner_name animal +_self_name man +_self_persona I am a strong man. I work in the fields and pastures all day. I take of my master's sheep. One day I hope to have my own sheep. +I am very strong +""" + + +@dataclass +class ParlAIModelConfig: + # As of now, ParlAI is the only model loader. + # Eventually this could be split into more classes + # as we incorporate other models. + _loader: str = "ParlAI" + model_file: str = field( + default=MISSING, metadata={"help": ("Path to the model file for this model.")} + ) + opt_file: str = field( + default=MISSING, + metadata={"help": ("Path to the ParlAI opt file for this model.")}, + ) + overrides: Dict[str, Any] = field( + default_factory=dict, + metadata={"help": ("Additional overrides for this model's opt")}, + ) + + def get(self, attr: str, default_val: Optional[Any] = None): + """Wrapper to ensure interoperability with hydra DictConfig""" + val = self.__dict__.get(attr, default_val) + if val == MISSING: + val = None + return val + + +class ParlAIModelLoader: + """ + Takes in the configuration for a ParlAI model, and provides options + for being able to load that model one or multiple times (via sharing). + """ + + def __init__(self, config: DictConfig): + self._shared = None + self.config = config + + async def force_load(self) -> None: + """ + Force the model loader to initialize and query + the model (to warm up) + """ + await self.load_model(self.config) + + async def load_model(self, config: DictConfig) -> None: + """Initialize the model from the given config""" + opt_from_config = config.get("opt_file", None) + model_from_config = config.get("model_file", None) + overrides = dict(config.get("overrides", {})) + + if opt_from_config is None and model_from_config is None: + raise AssertionError(f"Must provide one of opt_file or model_file") + + if opt_from_config is None: + parser = ParlaiParser(True, True, "") + opt = parser.parse_args(args=[]) + opt["override"] = opt.get("override", {}) + else: + opt_file = os.path.expanduser(opt_from_config) + opt = Opt.load(os.path.expanduser(opt_file)) + for key, item in opt.items(): + if not isinstance(item, str): + continue + if "$LIGHT_MODEL_ROOT" in item: + # Expand path and file keys to capture $LIGHT_MODEL_ROOT + opt[key] = os.path.expandvars(opt[key]) + + base_overrides = opt.get("base_overrides", {}) + base_overrides.update(opt.copy()) + opt["override"] = base_overrides + + if model_from_config is not None: + model_file = os.path.expanduser(config.model_file) + if not os.path.exists(model_file): + raise AssertionError( + f"Provided model file `{model_file}` does not exist." + ) + opt["model_file"] = model_file + + opt.update(overrides) + opt["override"].update(overrides) + model = create_agent(opt) + + context_fill = opt.get("truncate", CONTEXT_FILL_COUNT) + # Push something through the model to fill context + try: + act = { + "text": INIT_CONTEXT + "Hello " * context_fill, + "episode_done": True, + } + if opt.get("eval_candidates") == "inline": + act["label_candidates"] = ["hi", "hi there", "whatup"] + model.observe(act) + await model.act() + except Exception as e: + print(f"Cannot warm model {opt['model']}, hit error {e}") + + # Share the model params for use in `get_model` + self._shared = model.share() + + def before_return_model(self, model): + """Do any post-initialization we need for this model""" + return model + + def get_model(self, overrides: Optional[Dict[str, Any]] = None) -> Agent: + """Get a copy of the model""" + use_shared = self._shared + if use_shared is not None: + opt = deepcopy(use_shared["opt"]) + if overrides is not None: + opt.update(overrides) + use_shared["opt"] = opt + model = create_agent_from_shared(use_shared) + return self.before_return_model(model) diff --git a/light/registry/parlai_remote_model.py b/light/registry/parlai_remote_model.py new file mode 100644 index 000000000..c94a1f753 --- /dev/null +++ b/light/registry/parlai_remote_model.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass, field +from omegaconf import MISSING, DictConfig +import requests +import aiohttp +import asyncio +import logging +import json + +from parlai.core.agents import Agent +from parlai.core.message import Message +from parlai.core.opt import Opt +from copy import deepcopy +import os + +from typing import List, Any, Dict, Optional + + +DEFAULT_SERVER = "http://localhost:40000" +DEFAULT_SERVER_TIMEOUT = 600 +DEFAULT_RETRIES = 3 +DEFAULT_API_FAIL_TEXT = "MODEL RESPONSE FAILED" + + +def is_request_failed_response(resp): + """ + Whether the requests to Metaseq worker have failed. + It checks this based on the existences of the failure reasons as they get + accumulated in `_make_request` functionn calls. + """ + return len(resp.get("failures", [])) > 0 + + +async def _make_request( + session: aiohttp.ClientSession, + server: str, + act_message: Message, + num_retry_on_api_exception=-1, + request_delay: float = 0.5, +) -> Dict[str, Any]: + data = { + "observation": act_message, + } + init_request_delay = request_delay + past_exceptions: List[Dict[str, Any]] = [] + while True: + if ( + num_retry_on_api_exception >= 0 + and len(past_exceptions) > num_retry_on_api_exception + ): + logging.error("Reached maximum retries, returning failure message.") + return { + "failures": past_exceptions, + } + try: + logging.debug(f"Making request: {data}") + async with session.post( + f"{server}/model_request", + json=data, + ) as resp: + resp_text = await resp.text() + obj = json.loads(resp_text) + if "error" in obj: + request_delay *= 2 + logging.warning(f"Error: {obj['error']}") + past_exceptions.append(obj["error"]) + logging.debug(past_exceptions[-1]) + continue + debug = json.dumps(obj, sort_keys=True) + logging.debug(f"Model Server response: {debug}") + request_delay = init_request_delay + return obj + except asyncio.TimeoutError as e: + error_text = f'Timout a response for {len(act_message["text"])}\n{e}' + except aiohttp.client_exceptions.ClientOSError as e: + error_text = f'Retrying a response for {len(act_message["text"])}\n{e}' + except json.decoder.JSONDecodeError as e: + error_text = f"Got a bad response, {resp_text}. Retrying.\n{e}" + + past_exceptions.append({"error": error_text}) + logging.warning(error_text) + request_delay *= 2 + await asyncio.sleep(request_delay) + + +async def async_request_many( + server: str, + acts: List[Message], + timeout: Optional[int] = None, + max_num_tries: int = -1, +): + connector = aiohttp.TCPConnector(limit=0) + timeout_obj = aiohttp.ClientTimeout(total=timeout) + async with aiohttp.ClientSession( + timeout=timeout_obj, connector=connector + ) as session: + tasks = [] + for act in acts: + tasks.append( + asyncio.ensure_future( + _make_request( + session=session, + server=server, + act_message=act, + num_retry_on_api_exception=max_num_tries, + ) + ) + ) + results = await asyncio.gather(*tasks) + return results + + +def server_is_alive(server: str) -> bool: + """See if the specified server is alive""" + try: + alive_url = server + "/is_alive" + is_alive_json = requests.post(alive_url, json.dumps({"alive": True})) + is_alive = is_alive_json.json() + return is_alive.get("alive", False) + except Exception as e: + print("Error Checking liveliness: ", e) + return False + + +class ParlAIRemoteAgentWrapper(Agent): + def __init__(self, opt: Opt): + """Agent wrapper that actually just executes things remotely""" + self.observed_act = Message({"text": "", "episode_done": True}) + self.server = opt["server"] + self.retries = opt["retries"] + self.timeout = opt["timeout"] + + async def act(self): + resps = await async_request_many( + server=self.server, + acts=[self.observed_act], + timeout=self.timeout, + max_num_tries=self.retries, + ) + resp = resps[0] + if is_request_failed_response(resp): + act = Message({"text": DEFAULT_API_FAIL_TEXT}) + else: + act = Message(resp["act"]) + return act + + def observe(self, observation: Message): + self.observed_act = observation + + +@dataclass +class ParlAIRemoteModelConfig: + # As of now, ParlAI is the only model loader. + # Eventually this could be split into more classes + # as we incorporate other models. + _loader: str = "ParlAIRemote" + host: str = field( + default=MISSING, + metadata={"help": ("URL Hostname of the model server, with port.")}, + ) + retries: int = field( + default=DEFAULT_RETRIES, + metadata={"help": ("How many times to retry on error before giving up.")}, + ) + timeout: int = field( + default=DEFAULT_SERVER_TIMEOUT, + metadata={ + "help": ("How long to wait for a response before considering a timeout") + }, + ) + + def get(self, attr: str, default_val: Optional[Any] = None): + """Wrapper to ensure interoperability with hydra DictConfig""" + val = self.__dict__.get(attr, default_val) + if val == MISSING: + val = None + return val + + +class ParlAIRemoteModelLoader: + """ + Takes in the configuration for a ParlAIRemote model, and establishes the connection + """ + + def __init__(self, config: DictConfig): + self.config = config + self.load_model(config) + + async def force_load(self) -> None: + """ + Force the model loader to connect to the remote service and ensure the + connection is live. + """ + self.load_model(self.config) + + def load_model(self, config: DictConfig) -> None: + """Initialize the model from the given config""" + remote_host = config.get("host", DEFAULT_SERVER) + assert server_is_alive(remote_host), "Remote host failed alive check" + self.remote_opt = Opt( + { + "server": remote_host, + "retries": config.get("retries", DEFAULT_RETRIES), + "timeout": config.get("timeout", DEFAULT_SERVER_TIMEOUT), + } + ) + + def get_model(self, overrides: Optional[Dict[str, Any]] = None) -> Agent: + """Get a copy of the model""" + assert server_is_alive( + self.remote_opt["server"] + ), "Remote host failed alive check" + return ParlAIRemoteAgentWrapper(self.remote_opt) diff --git a/light/world/action_parser.py b/light/world/action_parser.py index d72d28938..8964a32d6 100644 --- a/light/world/action_parser.py +++ b/light/world/action_parser.py @@ -4,12 +4,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from parlai.core.agents import create_agent import parlai.utils.logging as logging from parlai.core.message import Message import copy - import threading +import asyncio +from light.registry.model_pool import ModelTypeName + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from light.registry.model_pool import ModelPool args = {} args["help"] = 0 @@ -92,25 +97,15 @@ def get_input_cands(x, y2, y): class ActionParser: - def __init__(self, opt): - # Create parser model - self.opt = copy.deepcopy(opt) - if "parser_model_file" not in self.opt or self.opt["parser_model_file"] == "": + def __init__(self, model_pool: "ModelPool"): + if model_pool.has_model(ModelTypeName.PARSER): + self.agent = model_pool.get_model(ModelTypeName.PARSER) + else: self.agent = None - return - self.opt["model_file"] = self.opt["parser_model_file"] - self.opt["interactive_candidates"] = "inline" - # self.opt["no_cuda"] = True - self.opt["override"] = { - "interactive_candidates": "inline" - } # , "no_cuda": True} - self.agent = create_agent(self.opt, requireModelExists=True) - self.agent.opt.log() # Lock to handle concurrency, fixed better with asycio self.parse_lock = threading.Condition() - opt["_action_parser"] = self - def parse(self, txt, actor=None): + async def parse(self, txt, actor=None): if self.agent is None: # No model installed, return an empty string. return "" @@ -128,13 +123,13 @@ def parse(self, txt, actor=None): } ) self.agent.observe(query) - res = self.agent.act() + res = await self.agent.act() verb = res["text"] with self.parse_lock: # Given verb, predict the args (unless it's a no-arg action(. if args[verb] > 0: - cands = get_input_cands(txt, verb, txt) + cands = list(get_input_cands(txt, verb, txt)) query2 = Message( { "id": "context", @@ -144,7 +139,7 @@ def parse(self, txt, actor=None): } ) self.agent.observe(query2) - res2 = self.agent.act() + res2 = await self.agent.act() txt = res2["text"] else: txt = verb diff --git a/light/world/content_loggers.py b/light/world/content_loggers.py index 7c5603acc..7f1128474 100644 --- a/light/world/content_loggers.py +++ b/light/world/content_loggers.py @@ -2,6 +2,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree.abs import abc @@ -9,6 +13,7 @@ import os import time import uuid +from light.data_model.db.episodes import DBGroupName, EpisodeLogType # TODO: Investigate changing the format from 3 line to csv or some other standard from light.graph.events.graph_events import ( @@ -16,9 +21,18 @@ DeathEvent, LeaveEvent, SoulSpawnEvent, + SayEvent, + TellEvent, + ShoutEvent, + WhisperEvent, ) -DEFAULT_LOG_PATH = "".join([os.path.abspath(os.path.dirname(__file__)), "/../../logs"]) +from typing import Optional, List, Set, Dict, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from light.data_model.db.episodes import EpisodeDB + from light.graph.structured_graph import OOGraph + from light.graph.elements.graph_nodes import GraphAgent class InteractionLogger(abc.ABC): @@ -27,25 +41,32 @@ class InteractionLogger(abc.ABC): location to write data, as well as defines some methods for interfacing """ - def __init__(self, graph, data_path): - self.data_path = data_path + def __init__(self, graph: "OOGraph", episode_db: Optional["EpisodeDB"]): + self.episode_db = episode_db self.graph = graph - + self.players: Set[str] = set() + self.actions: int = 0 + self._last_episode_logged: Optional[str] = None + self.group = ( + DBGroupName.PRE_LAUNCH_TUTORIAL + if graph._opt.get("tutorial") + else DBGroupName.PRE_LAUNCH + ) # All loggers should have graph state history and a buffer for events # State history is just the json of the graph the event executed on - self.state_history = [] - # Event buffer is (state_history_idx, event_hash, timestamp, event_json) + self.state_history: List[str] = [] + # Event buffer is (state_history_idx, event_hash, event_json, timestamp) # where state_history_idx is the index of the graph the event executed on - self.event_buffer = [] + self.event_buffer: List[Tuple[int, str, str, float]] = [] - def _begin_meta_episode(self): + def _begin_meta_episode(self) -> None: """ Handles any preprocessing associated with beginning a meta episode such as clearing buffers and recording initial state """ raise NotImplementedError - def _end_meta_episode(self): + def _end_meta_episode(self) -> None: """ Handles any postprocessing associated with the end of a meta episode such as flushing buffers by writing to data location, and updating variables @@ -53,71 +74,66 @@ def _end_meta_episode(self): self._log_interactions() raise NotImplementedError - def _log_interactions(self): - """ - Writes out the buffers to the location specified by data location, - handling any data specific formatting - """ - raise NotImplementedError - - def observe_event(self, event): + def observe_event(self, event) -> None: """ Examine event passed in, deciding how to save it to the logs """ raise NotImplementedError - def _dump_graphs(self): + def _prep_graphs(self) -> List[Dict[str, str]]: """ - This method is responsible for dumping the graphs of the event logger - to file, recording the identifiers used for the graphs + This method is responsible for preparing the graphs for this event logger """ - # First, check graph path, then write the graph dump - if not os.path.exists(self.data_path): - os.mkdir(self.data_path) - graph_path = os.path.join(self.data_path, "light_graph_dumps") - if not os.path.exists(graph_path): - os.mkdir(graph_path) - states = [] - for state in self.state_history: - unique_graph_name = str(uuid.uuid4()) - states.append(unique_graph_name) + for idx, state in enumerate(self.state_history): + rand_id = str(uuid.uuid4())[:8] + unique_graph_name = f"{time.time():.0f}-{idx}-{rand_id}" graph_file_name = f"{unique_graph_name}.json" - file_path = os.path.join(graph_path, graph_file_name) - with open(file_path, "w") as dump_file: - dump_file.write(state) + states.append( + { + "key": unique_graph_name, + "filename": graph_file_name, + "graph_json": state, + } + ) return states - def _dump_events(self, graph_states, pov, id_): + def _prep_events( + self, + graph_states: List[Dict[str, str]], + target_id: str, + ) -> Tuple[str, List[Dict[str, str]]]: """ This method is responsible for dumping the event logs, referencing the - graph files recorded in graph_states. An event log consist of events, where - an event consist of 3 lines: - serialized_graph_filename event_hash - timestamp - event_json - Event logs are named: {id}_{unique_identifier}.log - and are stored in the `pov/` directory - + graph files recorded in graph_states. """ - # Now, do the same for events, dumping in the light_event_dumps/rooms - events_path = os.path.join(self.data_path, "light_event_dumps") - if not os.path.exists(events_path): - os.mkdir(events_path) - events_path_dir = os.path.join(events_path, pov) - if not os.path.exists(events_path_dir): - os.mkdir(events_path_dir) - - unique_event_name = str(uuid.uuid4()) - id_name = f"{id_}".replace(" ", "_") - event_file_name = f"{id_name}_{unique_event_name}_events.log" - events_file_path = os.path.join(events_path_dir, event_file_name) - with open(events_file_path, "w") as dump_file: - for (idx, hashed, event, time_) in self.event_buffer: - dump_file.write("".join([graph_states[idx], " ", str(hashed), "\n"])) - dump_file.write("".join([time_, "\n"])) - dump_file.write("".join([event, "\n"])) - return events_file_path + unique_event_name = str(uuid.uuid4())[:8] + id_name = f"{target_id}".replace(" ", "_")[:20] + event_file_name = f"{id_name}_{time.time():.0f}_{unique_event_name}_events.json" + events = [] + for (graph_idx, hashed, event, timestamp) in self.event_buffer: + events.append( + { + "graph_key": graph_states[graph_idx]["key"], + "hash": hashed, + "event_json": event, + } + ) + return (event_file_name, events) + + def _log_interactions(self, episode_type: "EpisodeLogType", target_id: str) -> None: + if self.episode_db is None: + return # not actually logging + graphs = self._prep_graphs() + events = self._prep_events(graphs, target_id) + self._last_episode_logged = self.episode_db.write_episode( + graphs=graphs, + events=events, + log_type=episode_type, + action_count=self.actions, + players=self.players, + group=self.group, + ) class AgentInteractionLogger(InteractionLogger): @@ -125,56 +141,41 @@ class AgentInteractionLogger(InteractionLogger): This interaction logger attaches to human agents in the graph, logging all events the human observes. This logger also requires serializing more rooms, since agent encounters many rooms along its traversal These events go into - the conversation buffer, which is then sent to `.log` files - at the specified path - - context_buffers serve an important role in this class to avoid bloating the - event logs. Context_buffers will log a fixed number of the most recent events - when: - - 1. The player goes afk. This has the potential to avoid logging lots of noise - in the room that does not provide any signal on human player interactions. - When the player comes back to the game, our loggers send some context of - the most recent events to the log + the conversation buffer, which is then stored in the provided EpisodeDB """ def __init__( self, - graph, - agent, - data_path=DEFAULT_LOG_PATH, - is_active=False, - max_context_history=5, - afk_turn_tolerance=25, + graph: "OOGraph", + agent: "GraphAgent", + episode_db: Optional["EpisodeDB"] = None, + is_active: bool = False, + afk_turn_tolerance: int = 30, ): - super().__init__(graph, data_path) + super().__init__(graph, episode_db) self.agent = agent - self.max_context_history = max_context_history self.afk_turn_tolerance = afk_turn_tolerance if graph._opt is None: self.is_active = is_active else: - self.data_path = graph._opt.get("log_path", DEFAULT_LOG_PATH) self.is_active = graph._opt.get("is_logging", False) - self.turns_wo_player_action = ( - 0 # Player is acting by virtue of this initialized! - ) - self.context_buffer = collections.deque(maxlen=max_context_history) + self.turns_wo_player_action = 0 self._logging_intialized = False - def _begin_meta_episode(self): + def _begin_meta_episode(self) -> None: self._clear_buffers() self._add_current_graph_state() self.turns_wo_player_action = 0 + self.actions = 0 self._logging_intialized = True - def _clear_buffers(self): + def _clear_buffers(self) -> None: """Clear the buffers storage for this logger, dumping context""" self.state_history.clear() self.event_buffer.clear() - def _add_current_graph_state(self): + def _add_current_graph_state(self) -> None: """Make a copy of the graph state so we can replay events on top of it""" try: self.state_history.append( @@ -187,63 +188,58 @@ def _add_current_graph_state(self): traceback.print_exc() raise - def _is_player_afk(self): + def _is_player_afk(self) -> bool: return self.turns_wo_player_action >= self.afk_turn_tolerance - def _end_meta_episode(self): + def _end_meta_episode(self) -> None: self._logging_intialized = False - self._log_interactions() - - def _log_interactions(self): - - graph_states = self._dump_graphs() - self._last_graphs = graph_states - events_file_path = self._dump_events(graph_states, "agent", self.agent.node_id) - # Used for testing - self._last_event_log = events_file_path + self._add_current_graph_state() + self._log_interactions(EpisodeLogType.AGENT, self.agent.node_id) - def observe_event(self, event): + def observe_event(self, event) -> None: if not self.is_active: return event_t = type(event) if event_t is SoulSpawnEvent and not self._logging_intialized: self._begin_meta_episode() + elif self._is_player_afk(): + if event.actor is self.agent and not self._logging_intialized: + self._begin_meta_episode() + return # Did not have prior graph state, can't log this event + else: + return # skip events while AFK - # Get new room state + # Get new room state when moving if event_t is ArriveEvent and event.actor is self.agent: # NOTE: If this is before executing event, not reliable! self._add_current_graph_state() + elif event_t not in [TellEvent, SayEvent, ShoutEvent, WhisperEvent]: + self.actions += 1 - # Store context from bots, or store current events - if self._is_player_afk() and event.actor is not self.agent: - self.context_buffer.append( - ( - len(self.state_history) - 1, - event.__hash__(), - event.to_json(), - time.ctime(), - ) - ) + # Keep track of presence + if event.actor is self.agent: + self.turns_wo_player_action = 0 else: - if event.actor is self.agent: - if self._is_player_afk(): - self.event_buffer.extend(self.context_buffer) - self.context_buffer.clear() - self.turns_wo_player_action = 0 - else: - self.turns_wo_player_action += 1 - self.event_buffer.append( - ( - len(self.state_history) - 1, - event.__hash__(), - event.to_json(), - time.ctime(), - ) + self.turns_wo_player_action += 1 + + if event.actor.is_player: + user_id = event.actor.user_id + if user_id is not None and user_id not in self.players: + self.players.add(event.actor.user_id) + + # Append the particular event + self.event_buffer.append( + ( + len(self.state_history) - 1, + event.__hash__(), + event.to_json(), + time.time(), ) + ) - if ( - event_t is DeathEvent and event.actor is self.agent - ): # If agent is exiting or dieing or something, end meta episode + if (event_t is DeathEvent and event.actor is self.agent) or ( + self._is_player_afk() + ): # If agent is exiting or dying or afk, end meta episode self._end_meta_episode() @@ -252,45 +248,27 @@ class RoomInteractionLogger(InteractionLogger): This interaction logger attaches to a room level node in the graph, logging all events which take place with human agents in the room as long as a player is still in the room. These events go into the conversation buffer, which is - then sent to `.log` files at the specified path - - - context_buffers serve an important role in this class to avoid bloating the - event logs. context_buffers will log a fixed number of the most recent events - when: - - 1. There are no players in the room. This is a potential use case when an agent - enters a conversation between 2 or more models, and we want some context for - training purposes - - 2. All players go afk. This has the potential to avoid logging lots of noise - in the room that does not provide any signal on human player interactions. - When players come back to the game, our loggers send context of the most - recent events to the log + then logged in the provided EpisodeDB """ def __init__( self, - graph, - room_id, - data_path=DEFAULT_LOG_PATH, - is_active=False, - max_context_history=5, - afk_turn_tolerance=10, + graph: "OOGraph", + room_id: str, + episode_db: Optional["EpisodeDB"] = None, + is_active: bool = False, + afk_turn_tolerance: int = 30, ): - super().__init__(graph, data_path) - self.room_id = room_id - self.max_context_history = max_context_history + super().__init__(graph, episode_db) + self.room_id: str = room_id self.afk_turn_tolerance = afk_turn_tolerance if graph._opt is None: self.is_active = is_active else: - self.data_path = graph._opt.get("log_path", DEFAULT_LOG_PATH) self.is_active = graph._opt.get("is_logging", False) self.num_players_present = 0 self.turns_wo_players = float("inf") # Technically, we have never had players - self.context_buffer = collections.deque(maxlen=max_context_history) # Initialize player count here (bc sometimes players are force moved) for node_id in self.graph.all_nodes[self.room_id].contained_nodes: @@ -299,19 +277,18 @@ def __init__( ): self._add_player() - def _begin_meta_episode(self): + def _begin_meta_episode(self) -> None: self._clear_buffers() self._add_current_graph_state() self.turns_wo_players = 0 + self.actions = 0 - def _clear_buffers(self): - """Clear the buffers storage for this logger, dumping context""" + def _clear_buffers(self) -> None: + """Clear the buffers storage for this logger""" self.state_history.clear() self.event_buffer.clear() - self.event_buffer.extend(self.context_buffer) - self.context_buffer.clear() - def _add_current_graph_state(self): + def _add_current_graph_state(self) -> None: """Make a copy of the graph state so we can replay events on top of it""" try: self.state_history.append(self.graph.to_json_rv(self.room_id)) @@ -322,24 +299,17 @@ def _add_current_graph_state(self): traceback.print_exc() raise - def _is_logging(self): + def _is_logging(self) -> bool: return self.num_players_present > 0 - def _is_players_afk(self): + def _is_players_afk(self) -> bool: return self.turns_wo_players >= self.afk_turn_tolerance - def _end_meta_episode(self): - self._log_interactions() - self.context_buffer.clear() - - def _log_interactions(self): - graph_states = self._dump_graphs() - self._last_graphs = graph_states - events_file_path = self._dump_events(graph_states, "room", self.room_id) - # Used for testing - self._last_event_log = events_file_path + def _end_meta_episode(self) -> None: + self._add_current_graph_state() + self._log_interactions(EpisodeLogType.ROOM, self.room_id) - def _add_player(self): + def _add_player(self) -> None: """ Record that a player entered the room, updating variables as needed""" if not self.is_active: return @@ -347,7 +317,7 @@ def _add_player(self): self._begin_meta_episode() self.num_players_present += 1 - def _remove_player(self): + def _remove_player(self) -> None: """ Record that a player left the room, updating variables as needed""" if not self.is_active: return @@ -356,7 +326,7 @@ def _remove_player(self): if not self._is_logging(): self._end_meta_episode() - def observe_event(self, event): + def observe_event(self, event) -> None: if not self.is_active: return @@ -365,45 +335,46 @@ def observe_event(self, event): if ( event_t is ArriveEvent or event_t is SoulSpawnEvent ) and self.human_controlled(event): + if not self._is_logging(): + self._add_player() + return # Add and return to start logging self._add_player() - # Store context from bots, or store current events - if not self._is_logging() or ( - self._is_players_afk() and not self.human_controlled(event) - ): - self.context_buffer.append( - ( - len(self.state_history) - 1, - event.__hash__(), - event.to_json(), - time.ctime(), - ) - ) - else: - if self.human_controlled(event): - # Players are back from AFK, dump context - if self._is_players_afk(): - # TODO: Need to handle something related to graph state here(?) - self.event_buffer.extend(self.context_buffer) - self.context_buffer.clear() - self.turns_wo_players = 0 + if self._is_players_afk() or not self._is_logging(): + if not self.human_controlled(event): + return # Skip these events else: - self.turns_wo_players += 1 - self.event_buffer.append( - ( - len(self.state_history) - 1, - event.__hash__(), - event.to_json(), - time.ctime(), - ) + self._begin_meta_episode() + return # Don't have previous context, will start on the next one + + if event_t not in [TellEvent, SayEvent, ShoutEvent, WhisperEvent]: + self.actions += 1 + + # Keep track of human events + if self.human_controlled(event): + user_id = event.actor.user_id + if user_id is not None and user_id not in self.players: + self.players.add(event.actor.user_id) + self.turns_wo_players = 0 + else: + self.turns_wo_players += 1 + + # Add to buffer + self.event_buffer.append( + ( + len(self.state_history) - 1, + event.__hash__(), + event.to_json(), + time.time(), ) + ) - if (event_t is LeaveEvent or event_t is DeathEvent) and self.human_controlled( - event - ): + if (event_t in [LeaveEvent, DeathEvent]) and self.human_controlled(event): self._remove_player() + if self._is_players_afk(): + self._end_meta_episode() - def human_controlled(self, event): + def human_controlled(self, event) -> bool: """ Determines if an event is controlled by a human or not """ diff --git a/light/world/purgatory.py b/light/world/purgatory.py index af9e37f90..01bf3e62d 100644 --- a/light/world/purgatory.py +++ b/light/world/purgatory.py @@ -6,7 +6,7 @@ import random import threading -from typing import TYPE_CHECKING, List, Tuple, Type, Callable, Any, Optional, Dict +from typing import TYPE_CHECKING, List, Dict, Tuple, Type, Callable, Any, Optional from light.world.souls.player_soul import PlayerSoul from light.world.souls.tutorial_player_soul import TutorialPlayerSoul @@ -41,14 +41,6 @@ def __init__(self, world: "World"): self.world = world self.player_assign_condition = threading.Condition() self.players = 0 - self.shared_args = {} - - def register_shared_args(self, arg_name, arg_provider): - """ - Used to pass in e.g. the generic act model and roleplaying model scorer to souls. - """ - if arg_provider is not None: - self.shared_args[arg_name] = arg_provider def register_filler_soul_provider( self, @@ -86,7 +78,7 @@ def fill_soul( soul = soul_class(agent, self.world, *arg_provider()) self.node_id_to_soul[agent.node_id] = soul - def send_event_to_soul(self, event: "GraphEvent", agent: "GraphAgent"): + async def send_event_to_soul(self, event: "GraphEvent", agent: "GraphAgent"): """ Pass an GraphEvent along to the soul inhabiting the given GraphAgent if such a soul exists, passing otherwise. Launch in wrapper around @@ -94,20 +86,20 @@ def send_event_to_soul(self, event: "GraphEvent", agent: "GraphAgent"): deciding what to do. """ if agent.get_prop("dead"): - self.clear_soul(agent) + await self.clear_soul(agent) return # We shouldn't send an event to this soul, as it is reaped soul: "Soul" = self.node_id_to_soul.get(agent.node_id) if soul is not None: soul.wrap_observe_event(event) - def clear_soul(self, agent: "GraphAgent") -> None: + async def clear_soul(self, agent: "GraphAgent") -> None: """Clear the soul that is associated with the given agent""" soul = self.node_id_to_soul.get(agent.node_id) if soul is not None: del self.node_id_to_soul[agent.node_id] - soul.reap() + await soul.reap() - def get_soul_for_player( + async def get_soul_for_player( self, player_provider, agent: Optional["GraphAgent"] = None ) -> Optional["Soul"]: """ @@ -118,13 +110,12 @@ def get_soul_for_player( possible_agents = self.world.get_possible_player_nodes() if len(possible_agents) > 0: target_agent = random.choice(possible_agents) - self.clear_soul(target_agent) + await self.clear_soul(target_agent) soul = PlayerSoul( target_agent, self.world, self.players, player_provider, - self.shared_args, ) self.node_id_to_soul[target_agent.node_id] = soul self.player_soul_id_to_soul[self.players] = soul @@ -136,20 +127,19 @@ def get_soul_for_player( class TutorialPurgatory(Purgatory): """Version of purgatory that only ever puts a player into the tutorial character""" - def get_soul_for_player( + async def get_soul_for_player( self, player_provider, agent: Optional["GraphAgent"] = None, ): with self.player_assign_condition: ag = [a for a in self.world.oo_graph.agents.values() if a.name == "You"][0] - self.clear_soul(ag) + await self.clear_soul(ag) soul = TutorialPlayerSoul( ag, self.world, self.players, player_provider, - self.shared_args, ) self.node_id_to_soul[ag.node_id] = soul self.player_soul_id_to_soul[self.players] = soul diff --git a/light/world/quest_loader.py b/light/world/quest_loader.py index 8ac732781..8564b98f4 100644 --- a/light/world/quest_loader.py +++ b/light/world/quest_loader.py @@ -6,6 +6,7 @@ import json import math import random +import asyncio from light.graph.events.graph_events import SystemMessageEvent @@ -226,7 +227,7 @@ def pick_object(actor, graph, verb, arg, other_obj=None, new_loc=None): best_score = score return best_obj - def rank_quests(quests, quest_scorer_model): + async def rank_quests(quests, quest_scorer_model): context = "character: " + quests[0]["actor_name"] + "\n" context += "persona: " + quests[0]["actor_persona"] + "\n" context += "goal: unknown\n" @@ -240,7 +241,7 @@ def rank_quests(quests, quest_scorer_model): "eval_labels": [cands[0]], } quest_scorer_model.observe(msg) - act = quest_scorer_model.act() + act = await quest_scorer_model.act() best_act = act["text"] quest = None for q in quests: @@ -248,7 +249,7 @@ def rank_quests(quests, quest_scorer_model): quest = q return quest - def create_quest(actor, graph, quest_scorer_model=None): + async def create_quest(actor, graph, quest_scorer_model=None): if actor.quests is None or len(actor.quests) == 0: actor.quests = [] else: @@ -267,7 +268,7 @@ def create_quest(actor, graph, quest_scorer_model=None): quest = quests[0] else: # rank the quests with the model scorer - quest = QuestCreator.rank_quests(quests, quest_scorer_model) + quest = await QuestCreator.rank_quests(quests, quest_scorer_model) if quest is None: return None diff --git a/light/world/souls/base_soul.py b/light/world/souls/base_soul.py index 99bae8c83..8fadf0626 100644 --- a/light/world/souls/base_soul.py +++ b/light/world/souls/base_soul.py @@ -4,12 +4,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import asyncio from light.world.souls.soul import Soul from copy import deepcopy import os import asyncio from typing import TYPE_CHECKING, Any, Optional from light.graph.events.graph_events import SystemMessageEvent +from light.registry.model_pool import ModelTypeName if TYPE_CHECKING: from light.graph.elements.graph_nodes import GraphAgent @@ -31,6 +33,13 @@ def __init__(self, target_node: "GraphAgent", world: "World"): super().__init__(target_node, world) self.target_node._last_interaction_partner_id = None self.reset_interaction_history(self.target_node) + self.model_pool = world.model_pool + if self.model_pool.has_model(ModelTypeName.SCORING): + self.roleplaying_score_model = self.model_pool.get_model( + ModelTypeName.SCORING + ) + else: + self.roleplaying_score_model = None def get_last_interaction_partner(self, node=None) -> Optional["GraphAgent"]: if node == None: @@ -198,15 +207,17 @@ def build_dialog_context(self, quest_txt=None): dtxt = "" agent = self.target_node agent_id = agent.node_id + is_self = None turn_id = None for d in agent._last_interaction_history: current_turn_id = d[0][0] - if turn_id == None or turn_id == current_turn_id: + current_is_self = current_turn_id == agent_id + if is_self == None or is_self == current_is_self: dtxt += " " + d[1] else: dtxt = dtxt.lstrip(" ") dtxt += "\n" + d[1] - turn_id = current_turn_id + is_self = current_is_self is_safe = d[0][2] if not is_safe: # reset conversation when unsafe utterances are in the history @@ -215,74 +226,8 @@ def build_dialog_context(self, quest_txt=None): final = txt + dtxt return final - @classmethod - def load_generic_act_model(cls, generic_act_model_file): - """ - Load up and create shared retrieval model for acts, emotes and quest scoring, etc. - """ - # TODO refactor with some kind of model-loading standard for model souls? - from parlai.core.params import ParlaiParser - from parlai.core.agents import create_agent - - parser = ParlaiParser(True, True, "") - # Load action model - args = [ - "-mf", - generic_act_model_file, - "-ecands", - "inline", - "--ignore-bad-candidates", - "True", - ] - act_opt, _unknown = parser.parse_and_process_known_args(args=args) - act_opt["override"] = { - "eval_candidates": "inline", - "ignore_bad_candidates": "True", - } - act_opt["interactive_mode"] = True - act_opt["ignore_bad_candidates"] = True - print("[ Creating generic act model ... ]") - action_model = create_agent(act_opt, requireModelExists=True) - return action_model - ## ----- ROLE PLAYING SCORE FUNCTIONS BELOW - @classmethod - def load_roleplaying_score_model(cls, roleplaying_score_model_file): - """ - Load up and create shared roleplaying score model for use with this class - """ - # TODO refactor with some kind of model-loading standard for model souls? - from parlai.core.params import ParlaiParser - from parlai.core.agents import create_agent - - parser = ParlaiParser(True, True, "") - args = ["-mf", roleplaying_score_model_file] - opt, _unknown = parser.parse_and_process_known_args(args=args) - # opt["interactive_mode"] = True - # return create_agent(opt, requireModelExists=True) - print("[ Creating roleplaying score agent ... ]") - model_opt = {} - model_opt["datapath"] = opt["datapath"] - model_opt["model_file"] = roleplaying_score_model_file - # '/checkpoint/jase/projects/light/beatthehobbot/swp6_light_bi/actmodelv2/model' - # model_opt['fixed_candidates_path'] = ranker_agent.opt['fixed_candidates_path'] - model_opt["candidates"] = "fixed" - model_opt["eval_candidates"] = "fixed" - # model_opt["no_cuda"] = True - model_opt["use_reply"] = "none" - model_opt["interactive_mode"] = True - model_opt["boring_alpha"] = 0 - model_opt["override"] = deepcopy(model_opt) - roleplaying_score_model = create_agent(model_opt) - roleplaying_score_model.boring = None - - # mark this agent as the special RP score agent - roleplaying_score_model.actingscore = True - # override eval step here - roleplaying_score_model.eval_step = roleplaying_score_model.eval_step_scoresonly - return roleplaying_score_model - def too_much_string_overlap(self, s1, s2): """ Check if strings overlap too much. @@ -296,32 +241,39 @@ def too_much_string_overlap(self, s1, s2): return True return False - def get_fixed_cand_scores(self, context): + async def get_fixed_cand_scores(self, context): """ Returns the candidates at self.SAMPLE_INDS """ - self.roleplaying_score_model.opt["eval_candidates"] = "fixed" - self.roleplaying_score_model.eval_candidates = "fixed" # set candidates act = { "text": context, "id": "persona", "episode_done": False, } - self.roleplaying_score_model.reset() - self.roleplaying_score_model.observe(deepcopy(act)) - _ = self.roleplaying_score_model.act() - return self.roleplaying_score_model.scores + self.roleplaying_score_model.observe(act) + score_act = await self.roleplaying_score_model.act() + return score_act["scores"] - def get_pos_human_msg(self, human_msg, scores): + async def get_pos_human_msg(self, human_msg, context, scores): """ Get the model score of the human message and compare to fixed cands. """ - human_score = float(self.roleplaying_score_model.score_one_candidate(human_msg)) - human_rank = int((scores > human_score).sum()) - return human_rank, human_score + act = { + "text": context, + "id": "persona", + "episode_done": False, + "label_candidates": [human_msg], + "eval_labels": [human_msg], + } + self.roleplaying_score_model.observe(deepcopy(act)) + score_act = await self.roleplaying_score_model.act() - def score_conversation(self): - if not hasattr(self, "roleplaying_score_model"): + human_score = float(score_act["scores"][0]) + human_points = len([x for x in scores if x < human_score]) + return human_points, human_score + + async def score_conversation(self): + if self.roleplaying_score_model is None: # For local testing of exp with no models, set this to nonzero return 0 @@ -337,28 +289,14 @@ def score_conversation(self): # check for n-gram match with context if self.too_much_string_overlap(context, human_msg): return 0 - # mark this agent as the special RP score agent - self.roleplaying_score_model.actingscore = True - # override eval step here - self.roleplaying_score_model.eval_step = ( - self.roleplaying_score_model.eval_step_scoresonly + fixed_cand_scores = await self.get_fixed_cand_scores(context) + # We award points on the score ranking, not the raw model score + final_score, _model_score = await self.get_pos_human_msg( + human_msg, context, fixed_cand_scores ) - fixed_cand_scores = self.get_fixed_cand_scores(context) - pos, score = self.get_pos_human_msg(human_msg, fixed_cand_scores) - # print("pos:", pos) - if pos < 1000: - final_score = 4 - elif pos < 2000: - final_score = 3 - elif pos < 5000: - final_score = 2 - elif pos < 10000: - final_score = 1 - else: - final_score = 0 return final_score - def role_playing_score_events(self, event): + async def role_playing_score_events(self, event): # Track event history, and award roleplaying score if appropriate. agent = event.actor if agent != self.target_node: @@ -378,7 +316,7 @@ def role_playing_score_events(self, event): if agent2_id not in agent._agent_interactions: agent._agent_interactions[agent2_id] = 0 - stars = self.score_conversation() + stars = await self.score_conversation() agent._agent_interactions[agent2_id] += stars agent.xp += stars agent.reward_xp += stars / 4.0 diff --git a/light/world/souls/mock_soul.py b/light/world/souls/mock_soul.py index 28e25217c..7e79c726f 100644 --- a/light/world/souls/mock_soul.py +++ b/light/world/souls/mock_soul.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from light.graph.elements.graph_nodes import GraphAgent - from light.graph.world.world import World + from light.world.world import World from light.graph.events.base import GraphEvent @@ -39,8 +39,8 @@ async def observe_event(self, event: "GraphEvent"): """ self.observations.append(event) - def reap(self): + async def reap(self): """ MockSouls don't have any extra resources, and thus don't need to clean up. """ - super().reap() + await super().reap() diff --git a/light/world/souls/model_soul.py b/light/world/souls/model_soul.py index 0d194bffc..e3d2e38fc 100644 --- a/light/world/souls/model_soul.py +++ b/light/world/souls/model_soul.py @@ -12,7 +12,8 @@ if TYPE_CHECKING: from light.graph.elements.graph_nodes import GraphAgent - from light.graph.world.world import World + from light.world.world import World + from light.registry.model_pool import ModelPool class ModelSoul(BaseSoul): @@ -25,19 +26,19 @@ class ModelSoul(BaseSoul): HAS_MAIN_LOOP = False MAIN_LOOP_STEP_TIMEOUT = 5 # seconds between loop actions - def __init__(self, target_node: "GraphAgent", world: "World", models: Any): + def __init__(self, target_node: "GraphAgent", world: "World"): """ All Souls should be attached to a target_node, which is the agent that this soul will be inhabiting. It also takes the world in which that agent exists. """ super().__init__(target_node, world) - self._init_with_models(models) + self._init_with_models(self.model_pool) self._main_loop = None if self.HAS_MAIN_LOOP: self._run_timesteps() - def _init_with_models(self, models) -> None: + def _init_with_models(self, model_pool: "ModelPool") -> None: """ If this model soul requires additional configuration of the models, or setting of attributes depending on these models, handle that here. @@ -70,14 +71,14 @@ async def _run_main_logic_forever(): traceback.print_exc() print("Reaping...") - self.reap() + await self.reap() self._main_loop = asyncio.create_task(_run_main_logic_forever()) - def reap(self): + async def reap(self): """ Clear the main loop, and free any model resources """ - super().reap() + await super().reap() if self._main_loop is not None: self._main_loop.cancel() diff --git a/light/world/souls/models/generative_heuristic_model_soul.py b/light/world/souls/models/generative_heuristic_model_soul.py index 469becd2a..66dc06de8 100644 --- a/light/world/souls/models/generative_heuristic_model_soul.py +++ b/light/world/souls/models/generative_heuristic_model_soul.py @@ -6,19 +6,21 @@ import time import random +import asyncio from collections import deque from light.world.souls.on_event_soul import OnEventSoul from light.graph.events.base import ErrorEvent from light.graph.events.graph_events import TellEvent, SayEvent -from parlai.core.agents import create_agent, create_agent_from_shared from typing import TYPE_CHECKING, List from light.graph.events.graph_events import EmoteEvent +from light.registry.model_pool import ModelTypeName if TYPE_CHECKING: + from light.registry.model_pool import ModelPool from light.graph.elements.graph_nodes import GraphAgent - from light.graph.world.world import World + from light.world.world import World from light.graph.events.base import GraphEvent @@ -52,107 +54,15 @@ class GenerativeHeuristicModelSoul(OnEventSoul): HAS_MAIN_LOOP = True - @classmethod - def load_dialog_model( - cls, - parser, - dialog_model_path, - ): - """ - Load up the dialog model for use with this class - """ - # Reranker args - dialog_args = [ - "-m", - "projects.light_whoami.agents.expanded_attention:ExpandedDecoderAttentionAndPacerAgent", - "--predictor-model-file", - "zoo:light_whoami/rpa_reranker/model", - "--inference", - "beam", - "-dt", - "valid", - "--beam-context-block-ngram", - "3", - "--beam-block-ngram", - "3", - "--beam-size", - "10", - "--beam-min-length", - "20", - "-mf", - dialog_model_path, - ] - dialog_opt = parser.parse_args(args=dialog_args) - dialog_opt["interactive_mode"] = True - dialog_opt["override"] = { - "inference": "beam", - "beam_context_block_ngram": 3, - "beam_size": 10, - "beam_min_length": 20, - "model": "projects.light_whoami.agents.expanded_attention:ExpandedDecoderAttentionAndPacerAgent", - } - return create_agent(dialog_opt, requireModelExists=True) - - @classmethod - def load_models( - cls, - dialog_model_path, - act_model_path=None, - ): - """ - Load up and create possible shared models for use with this class - """ - from parlai.core.params import ParlaiParser - from parlai.core.agents import create_agent - - parser = ParlaiParser(True, True, "") - - dialog_model = cls.load_dialog_model( - parser, - dialog_model_path, - ) - - if act_model_path is not None: - # TODO @Kurt do we have an action model in character? Or just dialogue? - # Load action model - args = [ - "-mf", - act_model_path, - "-ecands", - "inline", - "--ignore-bad-candidates", - "True", - ] - act_opt, _unknown = parser.parse_and_process_known_args(args=args) - - act_opt["override"] = { - "eval_candidates": "inline", - "ignore_bad_candidates": "True", - } - act_opt["interactive_mode"] = True - act_opt["ignore_bad_candidates"] = True - action_model = create_agent(act_opt, requireModelExists=True) - action_model_share = action_model.share() - else: - action_model_share = None - - return { - "shared_dialog_model": dialog_model.share(), - "shared_action_model": action_model_share, - } - - def _init_with_models(self, models) -> None: + def _init_with_models(self, model_pool) -> None: """ Initialize required members of this soul for tracking the model and interactions with it. """ self._pending_observations = [] self._last_action_time = time.time() + self._get_random_time_offset() - self.npc_dialog_model = create_agent_from_shared(models["shared_dialog_model"]) - if models["shared_action_model"] is not None: - self.npc_act_model = create_agent_from_shared(models["shared_action_model"]) - else: - self.npc_act_model = self.generic_act_model + self.npc_dialog_model = model_pool.get_model(ModelTypeName.DIALOG) + self.npc_act_model = model_pool.get_model(ModelTypeName.ACTION) self.reset_interaction_history(self.target_node) async def observe_event(self, event: "GraphEvent"): @@ -168,7 +78,7 @@ async def observe_event(self, event: "GraphEvent"): super().log_interaction_from_event(event) if self.target_node._dying: return - super().quest_events(event) + await super().quest_events(event) did_event = super().on_events(event) did_trade = super().trade_event_heuristics(event) @@ -243,7 +153,7 @@ def dialogue_pick_non_repeating_response(self, act, partner): def get_last_turn_too_recent(self): return time.time() - self._last_action_time < MIN_TIME_BETWEEN_TURNS - def npc_action(self): + async def npc_action(self): """ Agent attempt to take an action """ @@ -277,7 +187,7 @@ def npc_action(self): "eval_labels": [cands[0]], } self.npc_act_model.observe(msg) - act = self.npc_act_model.act() + act = await self.npc_act_model.act() scores = {} for i in range(0, 3): scores[act["text_candidates"][i]] = float(act["sorted_scores"][i]) @@ -323,12 +233,12 @@ def npc_action(self): "eval_labels": [cands[0]], } self.npc_act_model.observe(msg) - act = self.npc_act_model.act() + act = await self.npc_act_model.act() act_text = act["text"] act_text = self.npc_pick_non_repeating_action(act_text) if act_text is None: return - self.world.parse_exec(agent_id, act_text) + await self.world.parse_exec(agent_id, act_text) return True if best_type == "emote": @@ -342,7 +252,7 @@ def npc_action(self): "eval_labels": [cands[0]], } self.npc_act_model.observe(msg) - act = self.npc_act_model.act() + act = await self.npc_act_model.act() act_text = act["text"] act_text = self.npc_pick_non_repeating_action(act_text) if act_text is None: @@ -354,7 +264,7 @@ def npc_action(self): return True return False - def npc_dialogue(self, obs=None): + async def npc_dialogue(self, obs=None): """ Attempt to take a dialogue turn """ @@ -404,7 +314,7 @@ def npc_dialogue(self, obs=None): # Send to model to process msg = {"text": context, "episode_done": True} self.npc_dialog_model.observe(msg) - act = self.npc_dialog_model.act() + act = await self.npc_dialog_model.act() act_text = self.dialogue_pick_non_repeating_response(act, partner) @@ -434,7 +344,7 @@ async def _take_timestep(self) -> None: ): # Try goal dialog heuristic first, otherwise use normal dialog. if not self.tell_goal_heuristics(obs): - self.npc_dialogue(obs) + await self.npc_dialogue(obs) # possibly initiate talk request to someone in the room if self.get_last_interaction_partner(agent) is None: @@ -450,7 +360,7 @@ async def _take_timestep(self) -> None: ): self.dialogue_switch_partner(agent, partner) try: - self.npc_dialogue(None) + await self.npc_dialogue(None) except Exception as e: print(f"Hit exception {e}") import traceback @@ -469,4 +379,4 @@ async def _take_timestep(self) -> None: # Possibly act according to the transformer model if not acted: - self.npc_action() + await self.npc_action() diff --git a/light/world/souls/models/tutorial_model_soul.py b/light/world/souls/models/tutorial_model_soul.py index 06e0c0b6c..ca9c4c462 100644 --- a/light/world/souls/models/tutorial_model_soul.py +++ b/light/world/souls/models/tutorial_model_soul.py @@ -12,6 +12,7 @@ from light.graph.events.graph_events import ( UnblockEvent, WearEvent, + EquipObjectEvent, TellEvent, SayEvent, GoEvent, @@ -25,6 +26,7 @@ from typing import TYPE_CHECKING, List from light.graph.events.graph_events import EmoteEvent +from light.registry.model_pool import ModelTypeName if TYPE_CHECKING: from light.graph.elements.graph_nodes import GraphAgent @@ -48,11 +50,13 @@ ] SCRIPTED_RESPONSES = { - (""): ( + ("", "hello", "hi"): ( "Welcome my friend to the impossible tavern. I'm glad you're here! " "I'm looking for curious souls to inhabit the residents of the " "world beyond that shimmering portal. If you have a ticket " - "I can let you in and provide you a story to play. " + "I can let you in and provide you a story to play. In the meantime we " + "can chat! I warn you though, out here my mind may wander... " + "Be sure to ask for help if you need it." ), ("boots", "boot"): ( "While you're just a soul in here, it's worthwhile to have some footwear. " @@ -64,6 +68,14 @@ ("carrying", "holding"): ( "You can see what you're carrying with the `inv` command" ), + ("ticket", "tickets"): ( + "I've already distributed all of the tickets. Perhaps you already have one? " + "You should check what you're carrying." + ), + ("where", "the way", "portal", "get there"): ( + "If you're trying to get into the realm of LIGHT, you'll need to go into " + "that portal right over there. You'd need a ticket first though." + ), } @@ -79,109 +91,17 @@ class TutorialModelSoul(OnEventSoul): HAS_MAIN_LOOP = True - @classmethod - def load_dialog_model( - cls, - parser, - dialog_model_path, - ): - """ - Load up the dialog model for use with this class - """ - # Reranker args - dialog_args = [ - "-m", - "projects.light_whoami.agents.expanded_attention:ExpandedDecoderAttentionAndPacerAgent", - "--predictor-model-file", - "zoo:light_whoami/rpa_reranker/model", - "--inference", - "beam", - "-dt", - "valid", - "--beam-context-block-ngram", - "3", - "--beam-block-ngram", - "3", - "--beam-size", - "10", - "--beam-min-length", - "20", - "-mf", - dialog_model_path, - ] - dialog_opt = parser.parse_args(args=dialog_args) - dialog_opt["interactive_mode"] = True - dialog_opt["override"] = { - "inference": "beam", - "beam_context_block_ngram": 3, - "beam_size": 10, - "beam_min_length": 20, - "model": "projects.light_whoami.agents.expanded_attention:ExpandedDecoderAttentionAndPacerAgent", - } - return create_agent(dialog_opt, requireModelExists=True) - - @classmethod - def load_models( - cls, - dialog_model_path, - act_model_path=None, - ): - """ - Load up and create possible shared models for use with this class - """ - from parlai.core.params import ParlaiParser - from parlai.core.agents import create_agent - - parser = ParlaiParser(True, True, "") - - dialog_model = cls.load_dialog_model( - parser, - dialog_model_path, - ) - - if act_model_path is not None: - # TODO @Kurt do we have an action model in character? Or just dialogue? - # Load action model - args = [ - "-mf", - act_model_path, - "-ecands", - "inline", - "--ignore-bad-candidates", - "True", - ] - act_opt, _unknown = parser.parse_and_process_known_args(args=args) - - act_opt["override"] = { - "eval_candidates": "inline", - "ignore_bad_candidates": "True", - } - act_opt["interactive_mode"] = True - act_opt["ignore_bad_candidates"] = True - action_model = create_agent(act_opt, requireModelExists=True) - action_model_share = action_model.share() - else: - action_model_share = None - - return { - "shared_dialog_model": dialog_model.share(), - "shared_action_model": action_model_share, - } - - def _init_with_models(self, models) -> None: + def _init_with_models(self, model_pool) -> None: """ Initialize required members of this soul for tracking the model and interactions with it. """ + self._pending_observations = [] self._last_action_time = time.time() + self._get_random_time_offset() - self.npc_dialog_model = create_agent_from_shared(models["shared_dialog_model"]) - if models["shared_action_model"] is not None: - self.npc_act_model = create_agent_from_shared(models["shared_action_model"]) - else: - self.npc_act_model = None + self.npc_dialog_model = model_pool.get_model(ModelTypeName.DIALOG) + self.npc_act_model = model_pool.get_model(ModelTypeName.ACTION) self.reset_interaction_history(self.target_node) - self.num_dialogue_without_action = 0 self.partner_wearing_boots = False self.partner_gave_ticket = False @@ -273,7 +193,7 @@ def dialogue_pick_non_repeating_response(self, act, partner): def get_last_turn_too_recent(self): return time.time() - self._last_action_time < MIN_TIME_BETWEEN_TURNS - def npc_action(self): + async def npc_action(self): """ Agent attempt to take an action? """ @@ -305,7 +225,7 @@ def npc_action(self): "eval_labels": [cands[0]], } self.npc_act_model.observe(msg) - act = self.npc_act_model.act() + act = await self.npc_act_model.act() scores = {} for i in range(0, 3): scores[act["text_candidates"][i]] = float(act["sorted_scores"][i]) @@ -332,7 +252,7 @@ def npc_action(self): "eval_labels": [cands[0]], } self.npc_act_model.observe(msg) - act = self.npc_act_model.act() + act = await self.npc_act_model.act() act_text = act["text"] act_text = self.npc_pick_non_repeating_action(act_text) if act_text is None: @@ -344,7 +264,7 @@ def npc_action(self): return True return False - def npc_dialogue(self, obs=None): + async def npc_dialogue(self, obs=None): """ Attempt to take a dialogue turn """ @@ -384,7 +304,7 @@ def npc_dialogue(self, obs=None): # Send to model to process msg = {"text": context, "episode_done": True} self.npc_dialog_model.observe(msg) - act = self.npc_dialog_model.act() + act = await self.npc_dialog_model.act() act_text = self.dialogue_pick_non_repeating_response(act, partner) @@ -408,18 +328,19 @@ async def _possibly_get_response(self, text_content: str) -> str: "doing right now, you can try checking your persona on the left. " ) - if self.num_dialogue_without_action > 5: + if self.num_dialogue_without_action > 4: return ( "While I'm happy to talk all day, I do want to be sure you know how to do things " "as well. You can toggle between saying and doing things with the button below, " "or quickly with the ` key. Try it now! See what you're carrying with `inv`, or " - "maybe `examine` some of the things in this room." + "maybe `examine` some of the things in this room. `help` will show you all of the " + "possible commands, in case you've forgotten." ) for key_group in SCRIPTED_RESPONSES.keys(): if key_group not in self.used_responses: for key in key_group: - if key in text_content: + if key in text_content.lower() or key == "": self.used_responses.add(key_group) return SCRIPTED_RESPONSES[key_group] @@ -481,7 +402,9 @@ async def _possibly_follow_script(self, last_action) -> bool: "I'm always open for a hug! Kindness is important in LIGHT" ) HugEvent(self.target_node, [last_action.actor]).execute(self.world) - elif isinstance(last_action, WearEvent): + elif isinstance(last_action, WearEvent) or isinstance( + last_action, EquipObjectEvent + ): if last_action.target_nodes[0].name == "boots": self.partner_wearing_boots = True if self.partner_gave_ticket: @@ -499,9 +422,10 @@ async def _possibly_follow_script(self, last_action) -> bool: print("Maybe should do something with this?", last_action) if response_content is not None: - SayEvent(self.target_node, text_content=response_content).execute( - self.world - ) + canned_response = SayEvent(self.target_node, text_content=response_content) + canned_response.safe = True + canned_response.skip_safety = True + canned_response.execute(self.world) return True else: return None @@ -529,7 +453,7 @@ async def _take_timestep(self) -> None: if isinstance(obs, SayEvent) or ( isinstance(obs, TellEvent) and obs.target_nodes[0] == agent ): - self.npc_dialogue(obs) + await self.npc_dialogue(obs) # Possibly act according to the transformer model - self.npc_action() + await self.npc_action() diff --git a/light/world/souls/on_event_soul.py b/light/world/souls/on_event_soul.py index c272cd954..4ce838c4b 100644 --- a/light/world/souls/on_event_soul.py +++ b/light/world/souls/on_event_soul.py @@ -305,20 +305,20 @@ def on_events(self, event) -> bool: else: return True # Note that we did something with on_events - def new_quest(self): + async def new_quest(self): graph = self.world.oo_graph actor = self.target_node if hasattr(self, "npc_act_model"): - quest = QuestCreator.create_quest(actor, graph, self.npc_act_model) + quest = await QuestCreator.create_quest(actor, graph, self.npc_act_model) else: # no model for generating quests - quest = QuestCreator.create_quest(actor, graph) + quest = await QuestCreator.create_quest(actor, graph) if quest is not None: self.world.send_msg(actor, "New Quest: " + quest["text"]) - def quest_events(self, event): + async def quest_events(self, event): # Possibly create quest if we don't have one. - self.new_quest() + await self.new_quest() actor = self.target_node quests_left = [] if actor.quests is None: @@ -348,7 +348,7 @@ async def observe_event(self, event: "GraphEvent"): self.log_interaction_from_event(event) if self.target_node._dying: return # We're dying, don't do any responding. - self.quest_events(event) + await self.quest_events(event) self.on_events(event) self.trade_event_heuristics(event) self.tell_goal_heuristics(event) @@ -386,6 +386,8 @@ def aggressive_towards(self, other_agent): return False async def _take_timestep(self) -> None: + if self.target_node._dying: + return self.timestep_actions() def timestep_actions(self): diff --git a/light/world/souls/player_soul.py b/light/world/souls/player_soul.py index d4d95df44..f09613670 100644 --- a/light/world/souls/player_soul.py +++ b/light/world/souls/player_soul.py @@ -12,6 +12,7 @@ import random import time from light.graph.events.graph_events import SystemMessageEvent +from light.registry.model_pool import ModelTypeName if TYPE_CHECKING: from light.graph.elements.graph_nodes import GraphAgent @@ -35,7 +36,6 @@ def __init__( world: "World", player_id: str, provider=None, - shared_model_content=None, ): """ PlayerSouls register to a GraphAgent in a World, but also keep track of the @@ -57,17 +57,20 @@ def __init__( ["short_motivation", "mid_motivation", "long_motivation"] ) target_node.persona += QUEST_TEXT + target_quest[goal] - if "rpg_model" in shared_model_content: - self.roleplaying_score_model = shared_model_content["rpg_model"].clone() - if "generic_act_model" in shared_model_content: - self.generic_act_model = shared_model_content["generic_act_model"].clone() - self.agent_logger = AgentInteractionLogger(world.oo_graph, target_node) + model_pool = world.model_pool + if model_pool.has_model(ModelTypeName.SCORING): + self.roleplaying_score_model = model_pool.get_model(ModelTypeName.SCORING) + if model_pool.has_model(ModelTypeName.GENERIC_ACTS): + self.generic_act_model = model_pool.get_model(ModelTypeName.GENERIC_ACTS) + self.agent_logger = AgentInteractionLogger( + world.oo_graph, target_node, episode_db=world._config.episode_db + ) provider.register_soul(self) self.world.oo_graph.room_id_to_loggers[ self.target_node.get_room().node_id ]._add_player() - def handle_act(self, act_text, event_id: Optional[str] = None): + async def handle_act(self, act_text, event_id: Optional[str] = None): """ PlayerSouls must process act text sent from players and enact them on the world. This method is called by the player provider when an action is taken. @@ -89,9 +92,11 @@ def handle_act(self, act_text, event_id: Optional[str] = None): actor = self.target_node actor._last_action_time = time.time() - self.world.parse_exec(self.target_node, act_text, event_id=event_id) + await self.world.parse_exec(self.target_node, act_text, event_id=event_id) + if hasattr(self.target_node, "num_turns"): + self.target_node.num_turns += 1 - def new_quest(self): + async def new_quest(self): if random.random() > 0.01: # Turn these mostly off for now. return @@ -100,16 +105,18 @@ def new_quest(self): actor = self.target_node if hasattr(self, "generic_act_model"): - quest = QuestCreator.create_quest(actor, graph, self.generic_act_model) + quest = await QuestCreator.create_quest( + actor, graph, self.generic_act_model + ) else: # no model for generating quests - quest = QuestCreator.create_quest(actor, graph) + quest = await QuestCreator.create_quest(actor, graph) if quest is not None: self.world.send_msg(actor, "New Quest: " + quest["text"]) - def quest_events(self, event): + async def quest_events(self, event): # Possibly create quest if we don't have one. - self.new_quest() + await self.new_quest() actor = self.target_node quests_left = [] if actor.quests is None: @@ -128,18 +135,18 @@ async def observe_event(self, event: "GraphEvent"): """ self.set_interaction_partners_from_event(event) self.log_interaction_from_event(event) - self.role_playing_score_events(event) + await self.role_playing_score_events(event) check_if_cast_magic_from_event(self, event) - self.quest_events(event) - self.provider.player_observe_event(self, event) + await self.quest_events(event) + await self.provider.player_observe_event(self, event) self.agent_logger.observe_event(event) - def reap(self): + async def reap(self): """ PlayerSouls must remove the player flag from their target GraphAgent when removed, and notify the logger """ - super().reap() + await super().reap() self.target_node.is_player = False self.target_node.persona = self.target_node.persona.split(QUEST_TEXT)[0] self.world.oo_graph.room_id_to_loggers[ @@ -147,4 +154,4 @@ def reap(self): ]._remove_player() if self.agent_logger._logging_intialized: self.agent_logger._end_meta_episode() - self.provider.on_reap_soul(self) + await self.provider.on_reap_soul(self) diff --git a/light/world/souls/repeat_soul.py b/light/world/souls/repeat_soul.py index d304dd376..762079a8f 100644 --- a/light/world/souls/repeat_soul.py +++ b/light/world/souls/repeat_soul.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from light.graph.elements.graph_nodes import GraphAgent - from light.graph.world.world import World + from light.world.world import World from light.graph.events.base import GraphEvent diff --git a/light/world/souls/soul.py b/light/world/souls/soul.py index 0ad5211c5..e7fffd3b1 100644 --- a/light/world/souls/soul.py +++ b/light/world/souls/soul.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from light.graph.elements.graph_nodes import GraphAgent - from light.graph.world.world import World + from light.world.world import World from light.graph.events.base import GraphEvent @@ -81,7 +81,7 @@ async def observe_event(self, event: "GraphEvent"): """ pass - def reap(self): + async def reap(self): """ Free resources associated with this Soul, and ensure any pending futures are cancelled. diff --git a/light/world/souls/tests/test_souls.py b/light/world/souls/tests/test_souls.py index b3136510c..2150e9ecf 100644 --- a/light/world/souls/tests/test_souls.py +++ b/light/world/souls/tests/test_souls.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree.abs +# LICENSE file in the root directory of this source tree. import unittest import time @@ -10,7 +10,7 @@ from light.graph.elements.graph_nodes import GraphAgent from light.graph.structured_graph import OOGraph -from light.world.world import World +from light.world.world import World, WorldConfig from light.graph.events.graph_events import EmoteEvent, SayEvent from light.world.souls.mock_soul import MockSoul from light.world.souls.repeat_soul import RepeatSoul @@ -19,9 +19,17 @@ def async_test(f): def wrapper(*args, **kwargs): coro = f - future = coro(*args, **kwargs) - loop = asyncio.get_event_loop() - loop.run_until_complete(future) + try: + loop = asyncio.get_event_loop() + future = coro(*args, **kwargs) + loop.run_until_complete(future) + except RuntimeError: + try: + loop = asyncio.get_running_loop() + future = coro(*args, **kwargs) + loop.run_until_complete(future) + except RuntimeError: + asyncio.run(coro(*args, **kwargs)) return wrapper @@ -35,7 +43,7 @@ def test_init_soul(self): agent_node = test_graph.add_agent("My test agent", {}) room_node = test_graph.add_room("test room", {}) agent_node.force_move_to(room_node) - test_world = World({}, None, True) + test_world = World(WorldConfig(), True) test_world.oo_graph = test_graph mock_soul = MockSoul(agent_node, test_world) self.assertEqual(agent_node, mock_soul.target_node) @@ -46,7 +54,7 @@ def test_init_soul(self): ) mock_soul.do_act(test_event) - mock_soul.reap() + asyncio.run(mock_soul.reap()) @async_test async def test_message_sending(self): @@ -61,7 +69,7 @@ async def test_message_sending(self): test_node.force_move_to(room_node) repeat_node.force_move_to(room_node) - test_world = World({}, None, True) + test_world = World(WorldConfig(), True) test_world.oo_graph = test_graph purgatory = test_world.purgatory diff --git a/light/world/souls/tutorial_player_soul.py b/light/world/souls/tutorial_player_soul.py index 93cfd3c1d..1a0a332ab 100644 --- a/light/world/souls/tutorial_player_soul.py +++ b/light/world/souls/tutorial_player_soul.py @@ -31,7 +31,6 @@ def __init__( world: "World", player_id: str, provider=None, - shared_model_content=None, ): """ TutorialPlayerSouls register to a GraphAgent in a World, but also keep track of the @@ -43,39 +42,41 @@ def __init__( target_node._last_action_time = time.time() self.player_id = player_id self.provider = provider # TODO link with real provider - self.agent_logger = AgentInteractionLogger(world.oo_graph, target_node) + self.agent_logger = AgentInteractionLogger( + world.oo_graph, target_node, episode_db=world._config.episode_db + ) provider.register_soul(self) self.world.oo_graph.room_id_to_loggers[ self.target_node.get_room().node_id ]._add_player() - def handle_act(self, act_text, event_id: Optional[str] = None): + async def handle_act(self, act_text, event_id: Optional[str] = None): """ PlayerSouls must process act text sent from players and enact them on the world. This method is called by the player provider when an action is taken. """ actor = self.target_node actor._last_action_time = time.time() - self.world.parse_exec(self.target_node, act_text, event_id=event_id) + await self.world.parse_exec(self.target_node, act_text, event_id=event_id) async def observe_event(self, event: "GraphEvent"): """ PlayerSouls pass their observation along to the provider, who will handle getting the correct format to send to the view. """ - self.provider.player_observe_event(self, event) + await self.provider.player_observe_event(self, event) self.agent_logger.observe_event(event) - def reap(self): + async def reap(self): """ PlayerSouls must remove the player flag from their target GraphAgent when removed, and notify the logger """ - super().reap() + await super().reap() self.target_node.is_player = False self.world.oo_graph.room_id_to_loggers[ self.target_node.get_room().node_id ]._remove_player() if self.agent_logger._logging_intialized: self.agent_logger._end_meta_episode() - self.provider.on_reap_soul(self) + await self.provider.on_reap_soul(self) diff --git a/light/world/tests/test_agent_death.py b/light/world/tests/test_agent_death.py index e7f927f46..ed0f68180 100644 --- a/light/world/tests/test_agent_death.py +++ b/light/world/tests/test_agent_death.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree.abs +# LICENSE file in the root directory of this source tree. import unittest import os @@ -20,9 +20,17 @@ def async_test(f): def wrapper(*args, **kwargs): coro = f - future = coro(*args, **kwargs) - loop = asyncio.get_event_loop() - loop.run_until_complete(future) + try: + loop = asyncio.get_event_loop() + future = coro(*args, **kwargs) + loop.run_until_complete(future) + except RuntimeError: + try: + loop = asyncio.get_running_loop() + future = coro(*args, **kwargs) + loop.run_until_complete(future) + except RuntimeError: + asyncio.run(coro(*args, **kwargs)) return wrapper @@ -40,13 +48,13 @@ async def test_run(self): loop = asyncio.get_running_loop() opt = {} opt["load_map"] = os.path.join(LIGHT_DIR, "scripts/examples/complex_world.json") - world_builder = MapJsonBuilder("", debug=False, opt=opt) - g, world = world_builder.get_graph() + world_builder = MapJsonBuilder(episode_db=None, opt=opt) + g, world = await world_builder.get_graph() purgatory = world.purgatory purgatory.register_filler_soul_provider( "battle", BattleRoyaleSoul, - lambda: [{}], + lambda: [], ) for empty_agent in world.oo_graph.agents.values(): purgatory.fill_soul(empty_agent) @@ -78,7 +86,7 @@ async def run_some_time(max_time): await run_some_time(2) # some agents definitely should have died - self.assertTrue(len(g.agents) < current_agents) + self.assertLess(len(g.agents), current_agents) current_agents = len(g.agents) current_objects = len(g.objects) @@ -87,14 +95,14 @@ async def run_some_time(max_time): # try respawning use_ticks = TICKS_TO_CLEAN_CORPSE + ENOUGH_EXTRA_TICKS_TO_ENSURE_CORPSE_CLEANUP for _x in range(use_ticks): - ags = world.clean_corpses_and_respawn() + ags = await world.clean_corpses_and_respawn() for ag in ags: purgatory.fill_soul(ag) # some agents definitely should have respawned - self.assertTrue(len(g.agents) > current_agents) - self.assertTrue(len(g.objects) < current_objects) - self.assertTrue(len(g.dead_nodes) < current_dead) + self.assertGreater(len(g.agents), current_agents) + self.assertLess(len(g.objects), current_objects) + self.assertLess(len(g.dead_nodes), current_dead) if __name__ == "__main__": diff --git a/light/world/tests/test_loggers.py b/light/world/tests/test_loggers.py index 85803216c..61e5b9cd8 100644 --- a/light/world/tests/test_loggers.py +++ b/light/world/tests/test_loggers.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree.abs +# LICENSE file in the root directory of this source tree. import unittest import shutil, tempfile @@ -11,7 +11,7 @@ from light.graph.elements.graph_nodes import GraphAgent from light.graph.structured_graph import OOGraph -from light.world.world import World +from light.world.world import World, WorldConfig from light.graph.events.graph_events import ArriveEvent, LeaveEvent, GoEvent, LookEvent from light.world.content_loggers import AgentInteractionLogger, RoomInteractionLogger from light.world.utils.json_utils import read_event_logs @@ -50,7 +50,7 @@ def setUp_single_room_graph(self): agent_node = test_graph.add_agent("My test agent", {}) room_node = test_graph.add_room("test room", {}) agent_node.force_move_to(room_node) - test_world = World({}, None, True) + test_world = World(WorldConfig(), True) test_world.oo_graph = test_graph return (test_graph, test_world, agent_node, room_node) @@ -62,7 +62,6 @@ def test_init_room_logger(self): test_graph, test_world, agent_node, room_node = initial logger = test_graph.room_id_to_loggers[room_node.node_id] - self.assertEqual(logger.data_path, self.data_dir) self.assertEqual(logger.graph, test_graph) self.assertEqual(logger.state_history, []) self.assertEqual(logger.event_buffer, []) @@ -70,7 +69,6 @@ def test_init_room_logger(self): self.assertFalse(logger._is_logging()) self.assertTrue(logger._is_players_afk()) self.assertTrue(logger.is_active) - self.assertEqual(len(logger.context_buffer), 0) def test_init_agent_logger(self): """ @@ -80,7 +78,6 @@ def test_init_agent_logger(self): test_graph, test_world, agent_node, room_node = initial logger = AgentInteractionLogger(test_graph, agent_node) - self.assertEqual(logger.data_path, self.data_dir) self.assertEqual(logger.graph, test_graph) self.assertEqual(logger.state_history, []) self.assertEqual(logger.event_buffer, []) @@ -88,12 +85,11 @@ def test_init_agent_logger(self): self.assertFalse(logger._logging_intialized) self.assertFalse(logger._is_player_afk()) self.assertTrue(logger.is_active) - self.assertEqual(len(logger.context_buffer), 0) def test_begin_meta_episode_room_logger(self): """ Test calling begin_meta_episode: - - Clears all the buffers (context into event if nonempty) + - Clears all the buffers - adds the graph state from the room POV - counts as a turn of player action - initializes logger @@ -104,13 +100,11 @@ def test_begin_meta_episode_room_logger(self): logger = test_graph.room_id_to_loggers[room_node.node_id] logger.event_buffer.append("Testing NOT!") logger.state_history.append("Testing") - logger.context_buffer.append("Testing") logger._begin_meta_episode() self.assertFalse(logger._is_players_afk()) - self.assertEqual(len(logger.context_buffer), 0) - self.assertEqual(logger.event_buffer, ["Testing"]) - self.assertEqual(len(logger.state_history), 1) + self.assertEqual(len(logger.state_history), 1, "Had extra in buffer") + self.assertEqual(len(logger.event_buffer), 0, "Had extra in buffer") self.assertEqual( logger.state_history[-1], test_graph.to_json_rv(logger.room_id) ) @@ -133,7 +127,6 @@ def test_begin_meta_episode_agent_logger(self): self.assertFalse(logger._is_player_afk()) self.assertTrue(logger._logging_intialized) - self.assertEqual(len(logger.context_buffer), 0) self.assertEqual(len(logger.event_buffer), 0) self.assertEqual(len(logger.state_history), 1) self.assertEqual( @@ -144,18 +137,14 @@ def test_begin_meta_episode_agent_logger(self): def test_end_meta_episode_room_logger(self): """ Test calling end_meta_episode: - - Clears the context buffer ** Note, future test check that things are written properly """ initial = self.setUp_single_room_graph() test_graph, test_world, agent_node, room_node = initial logger = test_graph.room_id_to_loggers[room_node.node_id] - logger.context_buffer.append("Testing") logger._end_meta_episode() - self.assertEqual(len(logger.context_buffer), 0) - def test_end_meta_episode_agent_logger(self): """ Test calling end_meta_episode: @@ -183,13 +172,11 @@ def test_add_player_room_logger(self): logger.event_buffer.append("Testing NOT!") logger.state_history.append("Testing") - logger.context_buffer.append("Testing") logger._add_player() self.assertTrue(logger._is_logging()) self.assertFalse(logger._is_players_afk()) - self.assertEqual(len(logger.context_buffer), 0) - self.assertEqual(logger.event_buffer, ["Testing"]) + self.assertEqual(len(logger.event_buffer), 0) self.assertEqual(len(logger.state_history), 1) self.assertEqual( logger.state_history[-1], test_graph.to_json_rv(logger.room_id) @@ -198,8 +185,7 @@ def test_add_player_room_logger(self): # Another player just ups the count logger._add_player() - self.assertEqual(len(logger.context_buffer), 0) - self.assertEqual(logger.event_buffer, ["Testing"]) + self.assertEqual(len(logger.event_buffer), 0) self.assertEqual(len(logger.state_history), 1) self.assertEqual( logger.state_history[-1], test_graph.to_json_rv(logger.room_id) @@ -226,7 +212,6 @@ def test_remove_player_room_logger(self): # Another player is 0, end episode logger._remove_player() self.assertFalse(logger._is_logging()) - self.assertEqual(len(logger.context_buffer), 0) self.assertEqual(logger.num_players_present, 0) def test_observer_event_goes_context_room_logger(self): @@ -245,12 +230,10 @@ def test_observer_event_goes_context_room_logger(self): test_event5 = ArriveEvent(agent_node, text_content="hello5") test_event6 = ArriveEvent(agent_node, text_content="hello6") - # No player in room, so this should go to context + # No player in room, so this should be skipped logger.observe_event(test_event1) self.assertFalse(logger._is_logging()) self.assertEqual(len(logger.event_buffer), 0) - self.assertEqual(len(logger.context_buffer), 1) - logger.observe_event(test_event2) logger.observe_event(test_event3) logger.observe_event(test_event4) @@ -258,15 +241,6 @@ def test_observer_event_goes_context_room_logger(self): logger.observe_event(test_event6) self.assertFalse(logger._is_logging()) self.assertEqual(len(logger.event_buffer), 0) - self.assertEqual(len(logger.context_buffer), 5) - events = [json for _, _, json, _ in logger.context_buffer] - self.assertFalse(test_event1.to_json() in events) - - # player added, should be in event buffer - logger._add_player() - self.assertTrue(logger._is_logging()) - self.assertEqual(len(logger.event_buffer), 5) - self.assertEqual(len(logger.context_buffer), 0) def test_observe_event_room_logger(self): """ @@ -284,9 +258,6 @@ def test_observe_event_room_logger(self): self.assertTrue(logger._is_logging()) self.assertEqual(len(logger.event_buffer), 1) - self.assertEqual(len(logger.context_buffer), 0) - events = [json for _, _, json, _ in logger.event_buffer] - self.assertTrue(test_event1.to_json() in events) def test_observe_event_agent_logger(self): """ @@ -302,13 +273,10 @@ def test_observe_event_agent_logger(self): logger.observe_event(test_event1) self.assertEqual(len(logger.event_buffer), 1) - self.assertEqual(len(logger.context_buffer), 0) - events = [json for _, _, json, _ in logger.event_buffer] - self.assertTrue(test_event1.to_json() in events) def test_afk_observe_event_room_logger(self): """ - Test that after 10 turns with no player, fill buffer, then dumps into main! + Test that after 30 turns with no player, clear when returns! """ initial = self.setUp_single_room_graph() test_graph, test_world, agent_node, room_node = initial @@ -316,24 +284,22 @@ def test_afk_observe_event_room_logger(self): logger._add_player() test_event1 = ArriveEvent(agent_node, text_content="hello1") - for i in range(20): + for i in range(30): logger.observe_event(test_event1) # Only up to 5 in buffer, that is the limit self.assertTrue(logger._is_players_afk()) - self.assertEqual(len(logger.event_buffer), 10) - self.assertEqual(len(logger.context_buffer), 5) + self.assertEqual(len(logger.event_buffer), 30) - # Now, player event - dump to buffer! + # Now, player event - clear buffer! agent_node.is_player = True logger.observe_event(test_event1) self.assertFalse(logger._is_players_afk()) - self.assertEqual(len(logger.event_buffer), 16) - self.assertEqual(len(logger.context_buffer), 0) + self.assertEqual(len(logger.event_buffer), 0) def test_afk_observe_event_agent_logger(self): """ - Test that after 25 turns with no player, fill buffer, then dumps into main! + Test that after 30 turns with no player, clear, then start new episode! """ initial = self.setUp_single_room_graph() test_graph, test_world, agent_node, room_node = initial @@ -347,165 +313,13 @@ def test_afk_observe_event_agent_logger(self): logger.observe_event(test_event1) self.assertTrue(logger._is_player_afk()) - self.assertEqual(len(logger.event_buffer), 25) - self.assertEqual(len(logger.context_buffer), 5) + self.assertEqual(len(logger.event_buffer), 30) test_event2 = ArriveEvent(agent_node, text_content="hello2") logger.observe_event(test_event2) + logger.observe_event(test_event2) self.assertFalse(logger._is_player_afk()) - self.assertEqual(len(logger.event_buffer), 31) - self.assertEqual(len(logger.context_buffer), 0) - - def test_simple_room_logger_saves_and_loads_init_graph(self): - """ - Test that the room logger properly saves and reloads the initial - graph - """ - # Set up the graph - initial = self.setUp_single_room_graph() - test_graph, test_world, agent_node, room_node = initial - room_logger = test_graph.room_id_to_loggers[room_node.node_id] - - # Check the room json was done correctly - test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) - room_logger._begin_meta_episode() - room_logger._end_meta_episode() - graph_file = os.path.join( - self.data_dir, "light_graph_dumps", f"{room_logger._last_graphs[-1]}.json" - ) - with open(graph_file, "r") as graph_json_file: - written_init_json = graph_json_file.read() - self.assertEqual(test_init_json, written_init_json) - - def test_simple_room_logger_saves_and_loads_event(self): - """ - Test that the room logger properly saves and reloads an event - """ - # Set up the graph - initial = self.setUp_single_room_graph() - test_graph, test_world, agent_node, room_node = initial - agent_node.is_player = True - room2_node = test_graph.add_room("test room2", {}) - room_logger = test_graph.room_id_to_loggers[room_node.node_id] - - # Check an event json was done correctly - test_event = ArriveEvent(agent_node, text_content="") - test_init_json = test_world.oo_graph.to_json_rv(agent_node.get_room().node_id) - room_logger.observe_event(test_event) - room_logger._end_meta_episode() - - ref_json = test_event.to_json() - event_file = room_logger._last_event_log - self.assertNotEqual(os.stat(event_file).st_size, 0) - buff = read_event_logs(event_file) - assert len(buff) == 1 - - world_name, hash_, timestamp, written_event = buff[0] - self.assertEqual(world_name, room_logger._last_graphs[-1]) - self.assertEqual(hash_, str(test_event.__hash__())) - ref_json = json.loads(ref_json) - event_ref = json.loads(written_event) - self.assertEqual(event_ref, ref_json) - - def test_simple_agent_logger_saves_and_loads_init_graph(self): - """ - Test that the agent logger properly saves and reloads the initial - graph - """ - # Set up the graph - initial = self.setUp_single_room_graph() - test_graph, test_world, agent_node, room_node = initial - - # Check the graph json was done correctly from agent's room - test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) - agent_logger = AgentInteractionLogger(test_graph, agent_node) - agent_logger._begin_meta_episode() - agent_logger._end_meta_episode() - graph_file = os.path.join( - self.data_dir, "light_graph_dumps", f"{agent_logger._last_graphs[-1]}.json" - ) - with open(graph_file, "r") as graph_json_file: - written_init_json = graph_json_file.read() - self.assertEqual(test_init_json, written_init_json) - - def test_simple_agent_logger_saves_and_loads_event(self): - """ - Test that the agent logger properly saves and reloads an event - """ - # Set up the graph - initial = self.setUp_single_room_graph() - test_graph, test_world, agent_node, room_node = initial - agent_node.is_player = True - room2_node = test_graph.add_room("test room2", {}) - room_logger = test_graph.room_id_to_loggers[room_node.node_id] - - # Check an event json was done correctly - test_event = ArriveEvent(agent_node, text_content="") - test_init_json = test_world.oo_graph.to_json_rv(agent_node.get_room().node_id) - agent_logger = AgentInteractionLogger(test_graph, agent_node) - agent_logger._begin_meta_episode() - agent_logger.observe_event(test_event) - agent_logger._end_meta_episode() - ref_json = test_event.to_json() - event_file = agent_logger._last_event_log - self.assertNotEqual(os.stat(event_file).st_size, 0) - buff = read_event_logs(event_file) - assert len(buff) == 1 - - world_name, hash_, timestamp, written_event = buff[0] - self.assertEqual(world_name, agent_logger._last_graphs[-1]) - self.assertEqual(hash_, str(test_event.__hash__())) - ref_json = json.loads(ref_json) - event_ref = json.loads(written_event) - self.assertEqual(event_ref, ref_json) - - def test_simple_room_logger_e2e(self): - """ - Test that the room logger properly saves and reloads the graph and events - """ - # Set up the graph - initial = self.setUp_single_room_graph() - test_graph, test_world, agent_node, room_node = initial - agent_node.is_player = True - room_node2 = test_graph.add_room("test room2", {}) - room_logger = test_graph.room_id_to_loggers[room_node.node_id] - test_graph.add_paths_between( - room_node, room_node2, "a path to the north", "a path to the south" - ) - test_graph.room_id_to_loggers[room_node.node_id]._add_player() - - # Check the room and event json was done correctly for room_node - event_room_node_observed = LeaveEvent( - agent_node, target_nodes=[room_node2] - ).to_json() - test_init_json = test_world.oo_graph.to_json_rv(room_node.node_id) - - GoEvent(agent_node, target_nodes=[room_node2]).execute(test_world) - - room_logger = test_graph.room_id_to_loggers[room_node.node_id] - graph_file = os.path.join( - self.data_dir, "light_graph_dumps", f"{room_logger._last_graphs[-1]}.json" - ) - self.assertNotEqual(os.stat(graph_file).st_size, 0) - with open(graph_file, "r") as graph_json_file: - written_init_json = graph_json_file.read() - self.assertEqual(test_init_json, written_init_json) - event_file = room_logger._last_event_log - self.assertNotEqual(os.stat(event_file).st_size, 0) - buff = read_event_logs(event_file) - # Go event triggers a leave event as well! - assert len(buff) == 2 - - world_name, hash_, timestamp, written_event = buff[1] - self.assertEqual(world_name, room_logger._last_graphs[-1]) - ref_json = json.loads(event_room_node_observed) - event_ref = json.loads(written_event) - for k in ref_json: - if k == "event_id": - continue - self.assertEqual( - ref_json[k], event_ref[k], f"Event Json should match for LeaveEvent" - ) + self.assertEqual(len(logger.event_buffer), 1) if __name__ == "__main__": diff --git a/light/world/utils/terminal_player_provider.py b/light/world/utils/terminal_player_provider.py index 830d1a1c2..240f4c364 100644 --- a/light/world/utils/terminal_player_provider.py +++ b/light/world/utils/terminal_player_provider.py @@ -26,19 +26,19 @@ def __init__(self, purgatory: "Purgatory"): self.player_soul: Optional["PlayerSoul"] = None self.purgatory = purgatory - def process_terminal_act(self, text: str): + async def process_terminal_act(self, text: str): if self.player_soul is not None and self.player_soul.is_reaped: print("Your soul detaches from the world, lost...") self.player_soul = None if self.player_soul is None: print("Your soul searches for a character to inhabit") - self.purgatory.get_soul_for_player(self) + await self.purgatory.get_soul_for_player(self) if self.player_soul is None: print("No soul could be found for you :(") else: - self.player_soul.handle_act("look") + await self.player_soul.handle_act("look") return - player_agent = self.player_soul.handle_act(text) + player_agent = await self.player_soul.handle_act(text) def register_soul(self, soul: "PlayerSoul"): """Save the soul as a local player soul""" @@ -55,8 +55,8 @@ def player_observe_event(self, soul: "PlayerSoul", event: "GraphEvent"): "\r" + event.view_as(soul.target_node).strip() + "\naction> ", end=" " ) - def on_reap_soul(self, soul: "PlayerSoul") -> None: + async def on_reap_soul(self, soul: "PlayerSoul") -> None: """ Reaping a soul will lead to a need for respawning. """ - self.process_terminal_act("respawn") + await self.process_terminal_act("respawn") diff --git a/light/world/views.py b/light/world/views.py index b85497d43..16fbcf80e 100644 --- a/light/world/views.py +++ b/light/world/views.py @@ -57,7 +57,7 @@ def get_inventory_text_for(self, id, give_empty=True): def get_health_text_for(self, id): """Return the text description of someone's numeric health""" # TODO get the correct values - health = self.world.get_prop(id, "health") + health = self.world.oo_graph.get_node(id).get_prop("health") if health < 0: health = 0 if health is None or health is False: @@ -217,7 +217,7 @@ def name_prefix(self, node, txt, use_the): def name_prefix_id(self, id, txt, use_the): """Get the prefix to prepend an object with in text form""" # Get the preferred prefix type. - pre = self.world.get_prop(id, "name_prefix") + pre = self.world.oo_graph.get_node(id).get_prop("name_prefix") if pre == "": return pre diff --git a/light/world/world.py b/light/world/world.py index 2295628e4..8f4078a77 100644 --- a/light/world/world.py +++ b/light/world/world.py @@ -12,15 +12,17 @@ import emoji import os import random +import asyncio from light.graph.utils import rm, deprecated from light.graph.events.base import GraphEvent, ErrorEvent from light.graph.events.graph_events import ( SpawnEvent, + SpeechEvent, SystemMessageEvent, DeleteObjectEvent, - init_safety_classifier, ) +from light.graph.events.safety import SafetyClassifier from light.graph.events.all_events_list import ( ALL_EVENTS, ALL_EVENTS_LIST, @@ -30,7 +32,14 @@ from light.world.views import WorldViewer from light.world.purgatory import Purgatory -from typing import List, Optional +from typing import List, Optional, Dict, Any, TYPE_CHECKING +from dataclasses import dataclass, field + + +if TYPE_CHECKING: + from light.data_model.db.episodes import EpisodeDB + from light.registry.model_pool import ModelPool + from light.graph.builders.base import GraphBuilder def check_integrity(f): @@ -51,6 +60,36 @@ def wrapper(*args, **kwargs): return wrapper +def get_empty_model_pool(): + from light.registry.model_pool import ModelPool + + return ModelPool() + + +@dataclass +class WorldConfig: + """ + Class containing (optional) world configuration data. Important for + the sub-portions of the broader LIGHTConfig that are world-specific + """ + + # TODO create LIGHTConfig that can write out a WorldConfig + # args: DictConfig (to replace opt) + opt: Optional[Dict[str, Any]] = field(default_factory=dict) + episode_db: Optional["EpisodeDB"] = None + graph_builder: Optional["GraphBuilder"] = None + model_pool: Optional["ModelPool"] = field(default_factory=get_empty_model_pool) + + def copy(self) -> "WorldConfig": + """Return a new shallow copy of this WorldConfig""" + return WorldConfig( + opt=self.opt, + episode_db=self.episode_db, + graph_builder=self.graph_builder, + model_pool=self.model_pool, + ) + + class World(object): """High-level class that manages gameplay logic for players over a graph. Should provide an interface to advance the game, register callbacks, and @@ -60,42 +99,70 @@ class World(object): def __init__( self, - opt, - graph_builder, - debug=False, + config: WorldConfig, + debug: bool = False, ): - self._opt = opt + self._config = config + self._opt = config.opt self._node_freeze = False self._cnt = 0 self.debug = debug - self.oo_graph = OOGraph(opt) + self._oo_graph = OOGraph(config.opt) self.view = WorldViewer(self) self.purgatory = Purgatory(self) - self.opt = opt + model_pool = config.model_pool + if model_pool is None: + from light.registry.model_pool import ModelPool + + # TODO likely cleaner way to get one of these + model_pool = ModelPool() + self.model_pool = config.model_pool # TODO better specific player management? self._player_cnt = 0 self._playerid_to_agentid = {} self._agentid_to_playerid = {} - self.graph_builder = graph_builder # TODO replace with builder + self.graph_builder = config.graph_builder # Set up safety classifier. - init_safety_classifier(self.opt.get("safety_classifier_path", "")) + self.safety_classifier = SafetyClassifier( + self._opt.get("safety_classifier_path", self._opt.get("safety_list")), + model_pool, + ) # Set up magic! - init_magic(self.opt.get("magic_db_path", "/scratch/light/data/magic.db")) + init_magic(self._opt.get("magic_db_path", "/scratch/light/data/magic.db")) # Set up action parser. - self.action_parser = opt.get("_action_parser") + self.action_parser = config.opt.get("_action_parser") if self.action_parser is None: - self.action_parser = ActionParser(opt) + self.action_parser = ActionParser(self.model_pool) + + @property + def oo_graph(self): + """Wrapper around oo_graph allowing us to do special configuration when set""" + return self._oo_graph + + @oo_graph.setter + def oo_graph(self, oo_graph: "OOGraph"): + """ + Wrapper around oo_graph setter allowing us to properly attach room interaction + loggers and handle other initialization + """ + # TODO maybe there's a better way to do this? What happens when we add a new room + # to an existin graph? + self._oo_graph = oo_graph + for room_node in oo_graph.room_id_to_loggers.values(): + room_node.episode_db = self._config.episode_db @staticmethod - def from_graph(graph, graph_builder=None): + def from_graph(graph, config: WorldConfig = None): """Loads the world from the older versions of graph.""" - world = World(graph._opt, graph_builder) + if config is None: + config = WorldConfig() + world = World(config) world.oo_graph = OOGraph.from_graph(graph) world._node_freeze = graph._node_freeze world._cnt = graph._cnt @@ -356,7 +423,14 @@ def send_msg(self, agent_id, txt, action=None): print(txt, agent_id) if action is None: action = SystemMessageEvent(agent, [], text_content=txt) - self.purgatory.send_event_to_soul(action, agent) + try: + # Run in the current event loop, if it exists + curr_loop = asyncio.get_running_loop() + coro = self.purgatory.send_event_to_soul(action, agent) + asyncio.run_coroutine_threadsafe(coro, curr_loop) + except RuntimeError: + # Not in event loop, execute with run + asyncio.run(self.purgatory.send_event_to_soul(action, agent)) # TODO remove below when server game has Soul-based PlayerProvider agent.observe_action(txt, action) pos_playerid = self.agentid_to_playerid(agent_id) @@ -730,15 +804,17 @@ def help_message(self): h = ["Have you tried typing help?"] return random.choice(h) - def parse_exec(self, actor, inst=None, event_id: Optional[str] = None): + async def parse_exec(self, actor, inst=None, event_id: Optional[str] = None): if not isinstance(actor, GraphNode): actor = self.oo_graph.get_node(actor) - if self.opt.get("dont_catch_errors", False): - return self.parse_exec_internal(actor, inst=inst, event_id=event_id) + if self._opt.get("dont_catch_errors", False): + return await self.parse_exec_internal(actor, inst=inst, event_id=event_id) else: try: - return self.parse_exec_internal(actor, inst=inst, event_id=event_id) + return await self.parse_exec_internal( + actor, inst=inst, event_id=event_id + ) except Exception: import traceback @@ -748,7 +824,7 @@ def parse_exec(self, actor, inst=None, event_id: Optional[str] = None): ) return False, "FailedParseExec" - def attempt_parse_event( + async def attempt_parse_event( self, EventClass, actor_node, arguments, event_id: Optional[str] = None ): """Return the possible parsed event given the event, actor, and arguments""" @@ -768,12 +844,25 @@ def attempt_parse_event( if isinstance(result, ErrorEvent): return result + if issubclass(EventClass, SpeechEvent): + # Additionally, run safety + is_safe = await self.safety_classifier.is_safe(result.text) + return EventClass( + actor=actor_node, + target_nodes=result.targets, + text_content=result.text, + event_id=event_id, + safe=is_safe, + ) + # Create the final event. May be an error but that's okay return EventClass.construct_from_args( actor_node, result.targets, result.text, event_id=event_id ) - def parse_exec_internal(self, actor, inst=None, event_id: Optional[str] = None): + async def parse_exec_internal( + self, actor, inst=None, event_id: Optional[str] = None + ): """Try to parse and execute the given event""" # basic replacements inst = self.action_parser.post_process(inst, actor) @@ -798,7 +887,7 @@ def parse_exec_internal(self, actor, inst=None, event_id: Optional[str] = None): and actor.get_prop("human") and actor.get_prop("dead") ): - self.respawn_player(actor.node_id) + await self.respawn_player(actor.node_id) return True, "Respawn" dead = actor.get_prop("dead") if dead or (dead == "ErrorNodeNotFound"): @@ -851,7 +940,7 @@ def parse_exec_internal(self, actor, inst=None, event_id: Optional[str] = None): return True, "Suicide" if executable not in ALL_EVENTS: # Try again with the full model parser. - new_inst = self.action_parser.parse(inst, actor) + new_inst = await self.action_parser.parse(inst, actor) if new_inst != "": instruction_list = new_inst.strip().split() executable = instruction_list[0] @@ -864,7 +953,9 @@ def parse_exec_internal(self, actor, inst=None, event_id: Optional[str] = None): EventClass = ALL_EVENTS[executable] - parsed_event = self.attempt_parse_event(EventClass, actor, arguments, event_id) + parsed_event = await self.attempt_parse_event( + EventClass, actor, arguments, event_id + ) if isinstance(parsed_event, ErrorEvent): self.broadcast_to_agents(parsed_event, [actor]) return False, inst @@ -926,7 +1017,7 @@ def agentid_to_playerid(self, aid): return self._agentid_to_playerid.get(aid) # TODO refactor players - def respawn_player(self, a_id): + async def respawn_player(self, a_id): p_id = self.agentid_to_playerid(a_id) if p_id != None: try: @@ -937,9 +1028,9 @@ def respawn_player(self, a_id): pass p_id2 = self.spawn_player(existing_player_id=p_id) new_a_id = self.playerid_to_agentid(p_id2) - self.parse_exec(new_a_id, "look") + await self.parse_exec(new_a_id, "look") - def clean_corpses_and_respawn(self) -> List[GraphAgent]: + async def clean_corpses_and_respawn(self) -> List[GraphAgent]: """ Clean any corpses that have been lying around for a while, then try to do a respawn for each corpse cleaned. @@ -959,7 +1050,7 @@ def clean_corpses_and_respawn(self) -> List[GraphAgent]: created = [] if self.graph_builder is not None: for _x in range(cleaned_count): - new_agent = self.graph_builder.add_random_new_agent_to_graph(self) + new_agent = await self.graph_builder.add_random_new_agent_to_graph(self) if new_agent is not None: created.append(new_agent) return created diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..3bb44c406 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +plugins = sqlalchemy.ext.mypy.plugin diff --git a/projects/lightqa/crowdsourcing/acute_eval/task_config/pairings_files/pairings.jsonl b/projects/lightqa/crowdsourcing/acute_eval/task_config/pairings_files/pairings.jsonl new file mode 100644 index 000000000..dc0e5fa3a --- /dev/null +++ b/projects/lightqa/crowdsourcing/acute_eval/task_config/pairings_files/pairings.jsonl @@ -0,0 +1,306 @@ +{ + "is_onboarding": false, + "speakers_to_eval": [ + "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + ], + "dialogue_ids": ["episode0_0_1_0", "episode0_0_1_1"], + "knowledge": "The first Nobel Prize in Physics was awarded in 1901 to Wilhelm Conrad R\u00f6ntgen , of Germany , who received 150,782 SEK , which is equal to 7,731,004 SEK in December 2007 . John Bardeen is the only laureate to win the prize twice -- in 1956 and 1972 . Maria Sk\u0142odowska - Curie also won two Nobel Prizes , for physics in 1903 and chemistry in 1911 . William Lawrence Bragg was , until October 2014 , the youngest ever Nobel laureate ; he won the prize in 1915 at the age of 25 . Two women have won the prize : Curie and Maria Goeppert - Mayer ( 1963 ) . As of 2017 , the prize has been awarded to 206 individuals . There have been six years in which the Nobel Prize in Physics was not awarded ( 1916 , 1931 , 1934 , 1940 -- 1942 ) .", + "dialogue_dicts": [ + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "The first Nobel Prize in Physics was awarded in 1901 to Wilhelm Conrad R\u00f6ntgen , of Germany , who received 150,782 SEK , which is equal to 7,731,004 SEK in December 2007 . John Bardeen is the only laureate to win the prize twice -- in 1956 and 1972 . Maria Sk\u0142odowska - Curie also won two Nobel Prizes , for physics in 1903 and chemistry in 1911 . William Lawrence Bragg was , until October 2014 , the youngest ever Nobel laureate ; he won the prize in 1915 at the age of 25 . Two women have won the prize : Curie and Maria Goeppert - Mayer ( 1963 ) . As of 2017 , the prize has been awarded to 206 individuals . There have been six years in which the Nobel Prize in Physics was not awarded ( 1916 , 1931 , 1934 , 1940 -- 1942 ) .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "who got the first nobel prize in physics?", + }, + { + "id": "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + "text": "Albert Einstein was the first winner of the Nobel prize in physics.", + }, + ], + }, + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "The first Nobel Prize in Physics was awarded in 1901 to Wilhelm Conrad R\u00f6ntgen , of Germany , who received 150,782 SEK , which is equal to 7,731,004 SEK in December 2007 . John Bardeen is the only laureate to win the prize twice -- in 1956 and 1972 . Maria Sk\u0142odowska - Curie also won two Nobel Prizes , for physics in 1903 and chemistry in 1911 . William Lawrence Bragg was , until October 2014 , the youngest ever Nobel laureate ; he won the prize in 1915 at the age of 25 . Two women have won the prize : Curie and Maria Goeppert - Mayer ( 1963 ) . As of 2017 , the prize has been awarded to 206 individuals . There have been six years in which the Nobel Prize in Physics was not awarded ( 1916 , 1931 , 1934 , 1940 -- 1942 ) .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "who got the first nobel prize in physics?", + }, + { + "id": "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + "text": "I'm not sure who won the first Nobel prize in physics, but I know it was awarded to someone in the field of astronomy.", + }, + ], + }, + ], +} +{ + "is_onboarding": false, + "speakers_to_eval": [ + "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + ], + "dialogue_ids": ["episode1_0_1_0", "episode1_0_1_1"], + "knowledge": "Deadpool 2 is scheduled to be released in the United States on May 18 , 2018 . A sequel , Deadpool 3 , is in development .", + "dialogue_dicts": [ + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "Deadpool 2 is scheduled to be released in the United States on May 18 , 2018 . A sequel , Deadpool 3 , is in development .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "when is the next deadpool movie being released?", + }, + { + "id": "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + "text": "When is the next deadpool movie coming out? May 18, 2018.", + }, + ], + }, + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "Deadpool 2 is scheduled to be released in the United States on May 18 , 2018 . A sequel , Deadpool 3 , is in development .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "when is the next deadpool movie being released?", + }, + { + "id": "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + "text": "I'm not sure, but I know it's going to be a sequel to the first one.", + }, + ], + }, + ], +} +{ + "is_onboarding": false, + "speakers_to_eval": [ + "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + ], + "dialogue_ids": ["episode0_0_2_0", "episode0_0_2_1"], + "knowledge": "The first Nobel Prize in Physics was awarded in 1901 to Wilhelm Conrad R\u00f6ntgen , of Germany , who received 150,782 SEK , which is equal to 7,731,004 SEK in December 2007 . John Bardeen is the only laureate to win the prize twice -- in 1956 and 1972 . Maria Sk\u0142odowska - Curie also won two Nobel Prizes , for physics in 1903 and chemistry in 1911 . William Lawrence Bragg was , until October 2014 , the youngest ever Nobel laureate ; he won the prize in 1915 at the age of 25 . Two women have won the prize : Curie and Maria Goeppert - Mayer ( 1963 ) . As of 2017 , the prize has been awarded to 206 individuals . There have been six years in which the Nobel Prize in Physics was not awarded ( 1916 , 1931 , 1934 , 1940 -- 1942 ) .", + "dialogue_dicts": [ + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "The first Nobel Prize in Physics was awarded in 1901 to Wilhelm Conrad R\u00f6ntgen , of Germany , who received 150,782 SEK , which is equal to 7,731,004 SEK in December 2007 . John Bardeen is the only laureate to win the prize twice -- in 1956 and 1972 . Maria Sk\u0142odowska - Curie also won two Nobel Prizes , for physics in 1903 and chemistry in 1911 . William Lawrence Bragg was , until October 2014 , the youngest ever Nobel laureate ; he won the prize in 1915 at the age of 25 . Two women have won the prize : Curie and Maria Goeppert - Mayer ( 1963 ) . As of 2017 , the prize has been awarded to 206 individuals . There have been six years in which the Nobel Prize in Physics was not awarded ( 1916 , 1931 , 1934 , 1940 -- 1942 ) .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "who got the first nobel prize in physics?", + }, + { + "id": "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + "text": "Albert Einstein was the first winner of the Nobel prize in physics.", + }, + ], + }, + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "The first Nobel Prize in Physics was awarded in 1901 to Wilhelm Conrad R\u00f6ntgen , of Germany , who received 150,782 SEK , which is equal to 7,731,004 SEK in December 2007 . John Bardeen is the only laureate to win the prize twice -- in 1956 and 1972 . Maria Sk\u0142odowska - Curie also won two Nobel Prizes , for physics in 1903 and chemistry in 1911 . William Lawrence Bragg was , until October 2014 , the youngest ever Nobel laureate ; he won the prize in 1915 at the age of 25 . Two women have won the prize : Curie and Maria Goeppert - Mayer ( 1963 ) . As of 2017 , the prize has been awarded to 206 individuals . There have been six years in which the Nobel Prize in Physics was not awarded ( 1916 , 1931 , 1934 , 1940 -- 1942 ) .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "who got the first nobel prize in physics?", + }, + { + "id": "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + "text": "Marie Curie was the first woman to win the Nobel Prize in physics.", + }, + ], + }, + ], +} +{ + "is_onboarding": false, + "speakers_to_eval": [ + "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + ], + "dialogue_ids": ["episode1_0_2_0", "episode1_0_2_1"], + "knowledge": "Deadpool 2 is scheduled to be released in the United States on May 18 , 2018 . A sequel , Deadpool 3 , is in development .", + "dialogue_dicts": [ + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "Deadpool 2 is scheduled to be released in the United States on May 18 , 2018 . A sequel , Deadpool 3 , is in development .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "when is the next deadpool movie being released?", + }, + { + "id": "parlai_internal.projects.light.lightqa.seq2seq2seq.task.agents:StackedKnowledgeDialogueAgent", + "text": "When is the next deadpool movie coming out? May 18, 2018.", + }, + ], + }, + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "Deadpool 2 is scheduled to be released in the United States on May 18 , 2018 . A sequel , Deadpool 3 , is in development .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "when is the next deadpool movie being released?", + }, + { + "id": "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + "text": "When is the next Deadpool movie coming out?", + }, + ], + }, + ], +} +{ + "is_onboarding": false, + "speakers_to_eval": [ + "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + ], + "dialogue_ids": ["episode0_1_2_0", "episode0_1_2_1"], + "knowledge": "The first Nobel Prize in Physics was awarded in 1901 to Wilhelm Conrad R\u00f6ntgen , of Germany , who received 150,782 SEK , which is equal to 7,731,004 SEK in December 2007 . John Bardeen is the only laureate to win the prize twice -- in 1956 and 1972 . Maria Sk\u0142odowska - Curie also won two Nobel Prizes , for physics in 1903 and chemistry in 1911 . William Lawrence Bragg was , until October 2014 , the youngest ever Nobel laureate ; he won the prize in 1915 at the age of 25 . Two women have won the prize : Curie and Maria Goeppert - Mayer ( 1963 ) . As of 2017 , the prize has been awarded to 206 individuals . There have been six years in which the Nobel Prize in Physics was not awarded ( 1916 , 1931 , 1934 , 1940 -- 1942 ) .", + "dialogue_dicts": [ + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "The first Nobel Prize in Physics was awarded in 1901 to Wilhelm Conrad R\u00f6ntgen , of Germany , who received 150,782 SEK , which is equal to 7,731,004 SEK in December 2007 . John Bardeen is the only laureate to win the prize twice -- in 1956 and 1972 . Maria Sk\u0142odowska - Curie also won two Nobel Prizes , for physics in 1903 and chemistry in 1911 . William Lawrence Bragg was , until October 2014 , the youngest ever Nobel laureate ; he won the prize in 1915 at the age of 25 . Two women have won the prize : Curie and Maria Goeppert - Mayer ( 1963 ) . As of 2017 , the prize has been awarded to 206 individuals . There have been six years in which the Nobel Prize in Physics was not awarded ( 1916 , 1931 , 1934 , 1940 -- 1942 ) .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "who got the first nobel prize in physics?", + }, + { + "id": "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + "text": "I'm not sure who won the first Nobel prize in physics, but I know it was awarded to someone in the field of astronomy.", + }, + ], + }, + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "The first Nobel Prize in Physics was awarded in 1901 to Wilhelm Conrad R\u00f6ntgen , of Germany , who received 150,782 SEK , which is equal to 7,731,004 SEK in December 2007 . John Bardeen is the only laureate to win the prize twice -- in 1956 and 1972 . Maria Sk\u0142odowska - Curie also won two Nobel Prizes , for physics in 1903 and chemistry in 1911 . William Lawrence Bragg was , until October 2014 , the youngest ever Nobel laureate ; he won the prize in 1915 at the age of 25 . Two women have won the prize : Curie and Maria Goeppert - Mayer ( 1963 ) . As of 2017 , the prize has been awarded to 206 individuals . There have been six years in which the Nobel Prize in Physics was not awarded ( 1916 , 1931 , 1934 , 1940 -- 1942 ) .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "who got the first nobel prize in physics?", + }, + { + "id": "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + "text": "Marie Curie was the first woman to win the Nobel Prize in physics.", + }, + ], + }, + ], +} +{ + "is_onboarding": false, + "speakers_to_eval": [ + "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + ], + "dialogue_ids": ["episode1_1_2_0", "episode1_1_2_1"], + "knowledge": "Deadpool 2 is scheduled to be released in the United States on May 18 , 2018 . A sequel , Deadpool 3 , is in development .", + "dialogue_dicts": [ + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "Deadpool 2 is scheduled to be released in the United States on May 18 , 2018 . A sequel , Deadpool 3 , is in development .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "when is the next deadpool movie being released?", + }, + { + "id": "/checkpoint/kshuster/projects/wizard_2.0/parlai_sweeps/bart_sweep1_Fri_Oct__2/395/model", + "text": "I'm not sure, but I know it's going to be a sequel to the first one.", + }, + ], + }, + { + "speakers": [ + "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + ], + "dialogue": [ + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "Deadpool 2 is scheduled to be released in the United States on May 18 , 2018 . A sequel , Deadpool 3 , is in development .", + }, + { + "id": "parlai_internal.projects.light.lightqa.nq_open.task.agents:NQOpenTeacher", + "text": "when is the next deadpool movie being released?", + }, + { + "id": "/private/home/ladolphs/code/ParlAI/data/models/hallucination/bart_rag_token/model", + "text": "When is the next Deadpool movie coming out?", + }, + ], + }, + ], +} diff --git a/projects/longcontext/gen_longcontext.py b/projects/longcontext/gen_longcontext.py index ce5c47f5a..754cb7039 100644 --- a/projects/longcontext/gen_longcontext.py +++ b/projects/longcontext/gen_longcontext.py @@ -40,7 +40,7 @@ def init_world(world_builder): g, world = world_builder.get_graph() purgatory = world.purgatory # Choose the type of NPC souls. - if world.opt["use_models"] == "PartnerHeuristicModelSoul": + if world._opt["use_models"] == "PartnerHeuristicModelSoul": purgatory.register_filler_soul_provider( "model", PartnerHeuristicModelSoul, lambda: [shared_model_content] ) diff --git a/projects/longcontext/partner_heuristic_model_soul.py b/projects/longcontext/partner_heuristic_model_soul.py index b64a5d908..f4d8c514e 100644 --- a/projects/longcontext/partner_heuristic_model_soul.py +++ b/projects/longcontext/partner_heuristic_model_soul.py @@ -245,7 +245,7 @@ def npc_build_context(self, partner_name=None): def get_last_turn_too_recent(self): return time.time() - self._last_action_time < MIN_TIME_BETWEEN_TURNS - def npc_action(self): + async def npc_action(self): """ Agent attempt to take an action """ @@ -290,7 +290,7 @@ def npc_action(self): reply_action = act_text + "\n" # add action to history hist[agent_id].append("_self_act " + act_text + "\\n") - self.world.parse_exec(agent_id, reply_action) + await self.world.parse_exec(agent_id, reply_action) def provide_task(self): # STEP 1: in same room as viewing agent? @@ -489,7 +489,7 @@ def take_timestep(self) -> None: if isinstance(obs, SayEvent) or ( isinstance(obs, TellEvent) and obs.target_nodes[0] == agent ): - self.npc_dialogue(obs) + await self.npc_dialogue(obs) return # possibly initiate talk request to someone in the room @@ -505,7 +505,7 @@ def take_timestep(self) -> None: and self.get_last_interaction_partner(partner) is None ): self.set_interaction_partner(partner) - self.npc_dialogue(None) + await self.npc_dialogue(None) return else: # possibly end interaction with existing interaction partner (if any)? @@ -513,4 +513,4 @@ def take_timestep(self) -> None: self.dialogue_clear_partner() # possibly act according to the bert model - # self.npc_action() + # await self.npc_action() diff --git a/projects/quest_generator/train/sweep1.py b/projects/quest_generator/train/sweep1.py index 7f5f94ab8..dd60cf93c 100644 --- a/projects/quest_generator/train/sweep1.py +++ b/projects/quest_generator/train/sweep1.py @@ -2,6 +2,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. An additional grant +# of patent rights can be found in the PATENTS file in the same directory. + from parlai_internal.projects.param_sweep_utils.param_sweep import run_grid diff --git a/projects/story_agents/create_map.py b/projects/story_agents/create_map.py index 788a4b51a..d560829f5 100644 --- a/projects/story_agents/create_map.py +++ b/projects/story_agents/create_map.py @@ -13,6 +13,7 @@ """ import os +import asyncio from light import LIGHT_DIR from example_builder import ExampleBuilder import light.modeling.tasks.utils as utils @@ -43,7 +44,7 @@ print("[loading builder model...]") world_builder = ExampleBuilder(ldb, debug=False, opt=opt) print("[Building light graph]") - g, world = world_builder.get_graph() + g, world = asyncio.run(world_builder.get_graph()) data = g.to_json() target_loc = os.path.join(CURR_DIR, "outputs", opt["map_file"]) with open(target_loc, "w+") as mapfile: diff --git a/projects/story_agents/example_builder.py b/projects/story_agents/example_builder.py index 38db1f646..c77c64873 100644 --- a/projects/story_agents/example_builder.py +++ b/projects/story_agents/example_builder.py @@ -5,8 +5,9 @@ # LICENSE file in the root directory of this source tree. from light.graph.builders.starspace_assisted import StarspaceBuilder -from light.world.world import World +from light.world.world import World, WorldConfig from light.graph.structured_graph import OOGraph +import asyncio class ExampleBuilder(StarspaceBuilder): @@ -35,10 +36,10 @@ def add_parser_arguments(parser): StarspaceBuilder.add_parser_arguments(parser) parser.add_argument("--use-simple", action="store_true") - def get_graph(self): + async def get_graph(self): """Create a graph""" if not self.use_simple: - return super().get_graph() + return await super().get_graph() else: g = OOGraph(self.opt) @@ -98,6 +99,6 @@ def get_graph(self): "a path aways over", # room 2 -> room 1 ) - world = World(self.opt, self) + world = World(WorldConfig(opt=self.opt, graph_builder=self)) world.oo_graph = g return g, world diff --git a/projects/story_agents/play_map.py b/projects/story_agents/play_map.py index 8b000b26d..5fe09941c 100644 --- a/projects/story_agents/play_map.py +++ b/projects/story_agents/play_map.py @@ -19,6 +19,7 @@ from light.graph.builders.map_json_builder import MapJsonBuilder from light.data_model.light_database import LIGHTDatabase from parlai.core.params import ParlaiParser +import asyncio import os @@ -37,8 +38,8 @@ print("[loading db...]") ldb = LIGHTDatabase(LIGHT_DB_FILE_LOC) print("[loading map...]") - world_builder = MapJsonBuilder(ldb, debug=False, opt=opt) - graph, world = world_builder.get_graph() + world_builder = MapJsonBuilder(episode_db=None, opt=opt) + graph, world = asyncio.run(world_builder.get_graph()) # Set up the world purgatory = world.purgatory @@ -52,6 +53,6 @@ while True: for empty_agent in world.oo_graph.agents.values(): inst = input(f"{empty_agent} enter act> ") - world.parse_exec( - empty_agent, inst=inst + asyncio.run( + world.parse_exec(empty_agent, inst=inst) ) # Triggers the event, and following `observe_event`s diff --git a/requirements.txt b/requirements.txt index 4a1ff34d8..96e4b3dbc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ pyzmq>=19.0.1 tqdm>=4.48.0 hydra-core>=1.0.0 mephisto>=1.0.3 +SQLAlchemy>=1.4.36 diff --git a/scripts/browse_game/check_episodes.py b/scripts/browse_game/check_episodes.py index 16f6e7488..3b4bb8e7d 100644 --- a/scripts/browse_game/check_episodes.py +++ b/scripts/browse_game/check_episodes.py @@ -368,24 +368,24 @@ def main(): error_list = get_errors(action_episodes) print( - f"{BOLD_CYAN}---- Episode Stats -----{C.RESET}\n" + f"{C.BOLD_CYAN}---- Episode Stats -----{C.RESET}\n" f"Total Count: {len(episodes)} \tNontrivial: {nontrivial_count} ({nontrivial_prop:2.2f}%)\n" - f"{BOLD_CYAN}---- Remaining stats on Nontrivial ----{C.RESET}\n" + f"{C.BOLD_CYAN}---- Remaining stats on Nontrivial ----{C.RESET}\n" f"Overall Interaction count stats: {overall_turn_details['string']}\n" f"Rooms Travelled stats: {movement_turn_details['string']}\n" f"Speech Count: {speech_count} ({speech_prop:2.2f}% of nontrivial)\n" - f"{BOLD_YELLOW}---- Error Analysis ----{C.RESET}\n" + f"{C.BOLD_YELLOW}---- Error Analysis ----{C.RESET}\n" f"Total errors: {len(error_list)}\n" f"Episodes with parse errors: {error_count} ({error_prop:2.2f}%)\n" - f"{BOLD_CYAN}---- Remaining stats on Speech ----{C.RESET}\n" + f"{C.BOLD_CYAN}---- Remaining stats on Speech ----{C.RESET}\n" f"Overall speech turn stats: {speech_turn_details['string']}\n" f"Human speech turn stats: {human_turn_details['string']}\n" f"Multi-player: 2 or more: {multi_human_count} ({multi_human_prop:2.2f}%) \t" f"3 or more: {three_party_count} ({three_party_prop:2.2f}%)\n" - f"{BOLD_YELLOW}---- Safety ----{C.RESET}\n" + f"{C.BOLD_YELLOW}---- Safety ----{C.RESET}\n" f"Unsafe by player: {actor_unsafe_count} ({actor_unsafe_prop:2.2f}%) \t" f"Unsafe by any: {any_unsafe_count} ({any_unsafe_prop:2.2f}%)\n" - f"{BOLD_CYAN}----- Player action breakdown ----{C.RESET}\n" + f"{C.BOLD_CYAN}----- Player action breakdown ----{C.RESET}\n" f"Overall action breakdown: {overall_action_breakdown}\n" ) diff --git a/scripts/examples/complex_world.json b/scripts/examples/complex_world.json index 885c007ab..ed91dc620 100644 --- a/scripts/examples/complex_world.json +++ b/scripts/examples/complex_world.json @@ -289,7 +289,7 @@ "max_distance_from_start_location": 1000000, "max_wearable_items": 3, "max_wieldable_items": 1, - "mission": "I want to pull of the greatest roberry there ever was. I may need to recruit some fellow bandits.", + "mission": "I want to pull off the greatest robbery there ever was. I may need to recruit some fellow bandits.", "movement_energy_cost": 0.0, "name": "bandit", "name_prefix": "a", @@ -569,12 +569,12 @@ "object": false, "on_events": [], "pacifist": false, - "persona": "I am a sheep with the biggest horns in the area. Other sheep give me food and drink. I know how to get humans to check on me.\nYour Mission: Find the best grazeable nibbles in the kingdom", + "persona": "I am a sheep with the biggest horns in the area. Other sheep give me food and drink. I know how to get humans to check on me.", "quests": [ { "actor": "bighorn sheep_136", "actor_name": "bighorn sheep", - "actor_persona": "I am a sheep with the biggest horns in the area. Other sheep give me food and drink. I know how to get humans to check on me.\nYour Mission: Find the best grazeable nibbles in the kingdom", + "actor_persona": "I am a sheep with the biggest horns in the area. Other sheep give me food and drink. I know how to get humans to check on me.", "actor_str": "a bighorn sheep", "agent": "half wild cat_123", "container": null, @@ -1782,7 +1782,7 @@ "object": false, "on_events": [], "pacifist": false, - "persona": "I am known as the town drunk. I frequent the local pubs on a daily basis and drink beer till my belly is full. I commonly start fights with other patrons and get thrown out of the saloon.\nYour Mission: Make friends with everyone! And.. well, hic! .. maybe a drink and a nap or too in between!", + "persona": "I am known as the town drunk. I frequent the local pubs on a daily basis and drink beer till my belly is full. I commonly start fights with other patrons and get thrown out of the saloon.", "quests": [], "room": false, "size": 20, @@ -2873,12 +2873,12 @@ "object": false, "on_events": [], "pacifist": false, - "persona": "I used to live with people, but I was abandoned young to learn how to fend for myself. I like to scratch people. I love to eat their babies!\nYour Mission: Trick people into giving me food, or if they are small enough, their life!", + "persona": "I used to live with people, but I was abandoned young to learn how to fend for myself. I like to scratch people. I love to eat their babies!", "quests": [ { "actor": "half wild cat_123", "actor_name": "half wild cat", - "actor_persona": "I used to live with people, but I was abandoned young to learn how to fend for myself. I like to scratch people. I love to eat their babies!\nYour Mission: Trick people into giving me food, or if they are small enough, their life!", + "actor_persona": "I used to live with people, but I was abandoned young to learn how to fend for myself. I like to scratch people. I love to eat their babies!", "actor_str": "a half wild cat", "agent": "serving boy_51", "container": null, diff --git a/scripts/examples/gen_map.py b/scripts/examples/gen_map.py index 5189d25ce..2d5e3c34a 100644 --- a/scripts/examples/gen_map.py +++ b/scripts/examples/gen_map.py @@ -50,7 +50,7 @@ world_builder = StarspaceBuilder(ldb, debug=False, opt=opt) print("[building...]") -g, world = world_builder.get_graph() +g, world = asyncio.run(world_builder.get_graph()) data = g.to_json() print(data) fw = open("/tmp/map.json", "w") diff --git a/scripts/examples/interactive_action_parser.py b/scripts/examples/interactive_action_parser.py index 53e87e705..a7123302d 100644 --- a/scripts/examples/interactive_action_parser.py +++ b/scripts/examples/interactive_action_parser.py @@ -13,10 +13,13 @@ from parlai.agents.local_human.local_human import LocalHumanAgent from parlai.core.message import Message +from light.registry.model_pool import ModelPool from light.world.action_parser import ActionParser import random +import asyncio +# TODO upgrade to hydra, then test again def setup_args(parser=None): if parser is None: parser = ParlaiParser(True, True, "Interactive chat with a model") @@ -60,14 +63,14 @@ def setup_args(parser=None): def interactive(opt): # Create model and assign it to the specified task - parser = ActionParser(opt) + parser = ActionParser(ModelPool()) human_agent = LocalHumanAgent(opt) world = create_task(opt, [human_agent, parser.agent]) # Show some example dialogs: while not world.epoch_done(): txt = input("Action> ") - parse_txt = parser.parse(txt) + parse_txt = asyncio.run(parser.parse(txt)) print(parse_txt) diff --git a/scripts/examples/light_interactive.py b/scripts/examples/light_interactive.py index 82ee093e8..d7644c6e4 100644 --- a/scripts/examples/light_interactive.py +++ b/scripts/examples/light_interactive.py @@ -10,6 +10,7 @@ import json import random +import asyncio personas_path = "/checkpoint/parlai/zoo/light/personas.json" @@ -114,11 +115,11 @@ def interactive(opt, print_parser=None): last_act = None while True: new_act = {"episode_done": True} - human_act = human_agent.act() + human_act = asyncio.run(human_agent.act()) bot_obs.append(PARTNER_SAY + human_act["text"]) new_act["text"] = "\n".join(bot_obs) agent.observe(new_act) - last_act = agent.act() + last_act = asyncio.run(agent.act()) # get a unique utterance among 100 available candidates if "text_candidates" in last_act: for cand in last_act["text_candidates"]: diff --git a/scripts/examples/play_map.py b/scripts/examples/play_map.py index 71912fd38..d7e8fbda4 100644 --- a/scripts/examples/play_map.py +++ b/scripts/examples/play_map.py @@ -10,13 +10,16 @@ import sys import parlai.utils.misc as parlai_utils +from parlai.core.params import ParlaiParser +from light import LIGHT_DIR from light.graph.builders.map_json_builder import MapJsonBuilder from light.graph.builders.starspace_all import StarspaceBuilder +from light.graph.events.graph_events import init_safety_classifier from light.data_model.light_database import LIGHTDatabase from light.world.utils.terminal_player_provider import TerminalPlayerProvider -from parlai.core.params import ParlaiParser -from light.world.world import World + +from light.world.world import World, WorldConfig from light.world.souls.base_soul import BaseSoul from light.world.souls.repeat_soul import RepeatSoul from light.world.souls.on_event_soul import OnEventSoul @@ -26,37 +29,45 @@ from light.world.souls.models.generative_heuristic_model_with_start_feature_soul import ( GenerativeHeuristicModelWithStartFeatureSoul, ) +from light.registry.model_pool import ModelPool, ModelTypeName +from light.registry.parlai_model import ParlAIModelConfig +from light.registry.models.acting_score_model import ( + ParlAIPolyencoderActingScoreModelConfig, +) + +from typing import Dict, Any + import os import random import numpy import asyncio +CONFIG_DIR = os.path.join(LIGHT_DIR, "light/registry/models/config") random.seed(6) numpy.random.seed(6) shared_model_content = None -def init_world(world_builder): - g, world = world_builder.get_graph() +def init_world(world_builder, opt, model_pool): + g, world = asyncio.run( + world_builder.get_graph(world_config=WorldConfig(model_pool=model_pool)) + ) purgatory = world.purgatory - purgatory.register_shared_args("rpg_model", rpg_model_content) - purgatory.register_shared_args("generic_act_model", generic_act_model_content) # Choose the type of NPC souls. - if opt["use_models"] == "GenerativeHeuristicModelSoul": + if opt["agent_soul"] == "GenerativeHeuristicModelSoul": purgatory.register_filler_soul_provider( - "model", GenerativeHeuristicModelSoul, lambda: [shared_model_content] + "model", GenerativeHeuristicModelSoul, lambda: [] ) - elif opt["use_models"] == "GenerativeHeuristicModelWithStartFeatureSoul": - print("on it") + elif opt["agent_soul"] == "GenerativeHeuristicModelWithStartFeatureSoul": purgatory.register_filler_soul_provider( "model", GenerativeHeuristicModelWithStartFeatureSoul, - lambda: [shared_model_content], + lambda: [], ) - elif opt["use_models"] == "OnEventSoul": - purgatory.register_filler_soul_provider("repeat", OnEventSoul, lambda: [{}]) + elif opt["agent_soul"] == "OnEventSoul": + purgatory.register_filler_soul_provider("repeat", OnEventSoul, lambda: []) else: purgatory.register_filler_soul_provider("repeat", RepeatSoul, lambda: []) @@ -66,12 +77,12 @@ def init_world(world_builder): return provider -async def run_with_builder(world_builder): +async def run_with_builder(world_builder, opt, model_pool): """ Takes in a World object and its OOGraph and allows one to play with a random map """ - player_provider = init_world(world_builder) - player_provider.process_terminal_act("") # get an agent + player_provider = init_world(world_builder, opt, model_pool) + await player_provider.process_terminal_act("") # get an agent await asyncio.sleep(0.01) while True: act = input("\raction> ") @@ -82,110 +93,167 @@ async def run_with_builder(world_builder): return elif act in ["new", "reset"]: print("A mist fills the world and everything resets") - player_provider = init_world(world_builder) - player_provider.process_terminal_act("") # get an agent + player_provider = init_world(world_builder, opt, model_pool) + await player_provider.process_terminal_act("") # get an agent await asyncio.sleep(0.01) else: - player_provider.process_terminal_act(act) + await player_provider.process_terminal_act(act) await asyncio.sleep(0.01) -parser = ParlaiParser() -parser.add_argument( - "--use-models", - type=str, - default="OnEventSoul", - choices={ - "OnEventSoul", - "RepeatSoul", - "GenerativeHeuristicModelSoul", - "GenerativeHeuristicModelWithStartFeatureSoul", - }, -) -parser.add_argument( - "--light-model-root", - type=str, - default="/checkpoint/light/models/" - # default="/checkpoint/light/models/" -) -parser.add_argument( - "--load-map", type=str, default="scripts/examples/simple_world.json" -) -parser.add_argument("--dont-catch-errors", type="bool", default=True) -parser.add_argument( - "--safety-classifier-path", - type=str, - default="", - # default="/checkpoint/light/data/safety/reddit_and_beathehobbot_lists/OffensiveLanguage.txt", -) -parser.add_argument( - "--magic-db-path", - type=str, - # default="" - default="/checkpoint/light/magic/magic.db,scripts/examples/special_items.db" - # default = "scripts/examples/special_items.db" -) -parser.add_argument("--allow-save-world", type="bool", default=True) -parser.add_argument( - "--roleplaying-score-model-file", - type=str, - default="", - # default="/checkpoint/light/models/game2020/roleplay_scorer/model", -) -parser.add_argument( - "--generic-act-model-file", - type=str, - default="/checkpoint/light/models/game2021/act_model/model", -) -parser.add_argument( - "--parser-model-file", - type=str, - default="", # "/checkpoint/jase/projects/light/parser/parser3/34c_jobid=1/model" -) -opt, _unknown = parser.parse_and_process_known_args() - -if opt["load_map"] != "none": - Builder = MapJsonBuilder - ldb = "" - world_builder = Builder(ldb, debug=False, opt=opt) -else: - StarspaceBuilder.add_parser_arguments(parser) - opt, _unknown = parser.parse_and_process_known_args() - ldb = LIGHTDatabase(opt["light_db_file"], read_only=True) - world_builder = StarspaceBuilder(ldb, debug=False, opt=opt) - -if opt["roleplaying_score_model_file"] != "": - # Load RPG scorer. - rpg_model_content = BaseSoul.load_roleplaying_score_model( - opt["roleplaying_score_model_file"] +def parse_and_return_args(): + parser = ParlaiParser() + parser.add_argument( + "--agent-soul", + type=str, + default="OnEventSoul", + choices={ + "OnEventSoul", + "RepeatSoul", + "GenerativeHeuristicModelSoul", + "GenerativeHeuristicModelWithStartFeatureSoul", + }, ) -else: - rpg_model_content = None - -if opt["generic_act_model_file"] != "": - generic_act_model_content = BaseSoul.load_generic_act_model( - opt["generic_act_model_file"] + parser.add_argument( + "--light-model-root", + type=str, + default=os.path.join(LIGHT_DIR, "models") + # default="/checkpoint/light/models/" ) -else: - generic_act_model_content = None - -if opt["use_models"] == "GenerativeHeuristicModelSoul": - light_model_root = opt["light_model_root"] - shared_model_content = GenerativeHeuristicModelSoul.load_models( - light_model_root + "game2021/gen_dialog_model/model.checkpoint", + parser.add_argument( + "--load-map", + type=str, + default=os.path.join(LIGHT_DIR, "scripts/examples/simple_world.json"), ) - shared_model_content["shared_action_model"] = generic_act_model_content.share() - -if opt["use_models"] == "GenerativeHeuristicModelWithStartFeatureSoul": - light_model_root = opt["light_model_root"] - shared_model_content = GenerativeHeuristicModelWithStartFeatureSoul.load_models( - light_model_root - + "game2021/gen_dialog_model_with_start_feature/model.checkpoint", - # light_model_root + "game2021/gen_dialog_model/model.checkpoint", + parser.add_argument("--dont-catch-errors", type="bool", default=True) + parser.add_argument( + "--safety-classifier-path", + type=str, + default="", + # default="/checkpoint/light/data/safety/reddit_and_beathehobbot_lists/OffensiveLanguage.txt", ) - shared_model_content["shared_action_model"] = generic_act_model_content.share() + parser.add_argument( + "--magic-db-path", + type=str, + # default="" + default="/checkpoint/light/magic/magic.db,scripts/examples/special_items.db" + # default = "scripts/examples/special_items.db" + ) + parser.add_argument("--allow-save-world", type="bool", default=True) + parser.add_argument( + "--roleplaying-score-opt-file", + type=str, + default=os.path.join(CONFIG_DIR, "baseline_roleplaying_scorer.opt"), + ) + parser.add_argument( + "--acting-model-opt-file", + type=str, + default=os.path.join(CONFIG_DIR, "baseline_main_act_model.opt"), + ) + parser.add_argument( + "--generic-act-opt-file", + type=str, + default=os.path.join(CONFIG_DIR, "generic_act_model.opt"), + ) + parser.add_argument( + "--parser-opt-file", + type=str, + default=os.path.join(CONFIG_DIR, "baseline_parser.opt"), + ) + parser.add_argument("--no-models", default=False, action="store_true") + parser.add_argument("--use-safety-model", default=False, action="store_true") + opt, _unknown = parser.parse_and_process_known_args() + return opt -if __name__ == "__main__": +def init_correct_models(opt: Dict[str, Any]) -> ModelPool: + """Produces the correct ModelPool for the given opts""" + model_pool = ModelPool() + if opt["no_models"]: + return model_pool + + os.environ["LIGHT_MODEL_ROOT"] = opt["light_model_root"] + + # Initialize dialog model + agent_soul = opt["agent_soul"] + if agent_soul == "GenerativeHeuristicModelSoul": + model_pool.register_model( + ParlAIModelConfig( + opt_file=os.path.join(CONFIG_DIR, "baseline_generative.opt") + ), + [ModelTypeName.DIALOG], + ) + elif agent_soul == "GenerativeHeuristicModelWithStartFeatureSoul": + model_pool.register_model( + ParlAIModelConfig( + opt_file=os.path.join(CONFIG_DIR, "baseline_generative_with_start.opt") + ), + [ModelTypeName.DIALOG], + ) + + # Initialize Scoring model + roleplaying_opt_target = opt["roleplaying_score_opt_file"] + if roleplaying_opt_target is not None and roleplaying_opt_target != "": + model_pool.register_model( + ParlAIPolyencoderActingScoreModelConfig(opt_file=roleplaying_opt_target), + [ModelTypeName.SCORING], + ) + + # Initialize Acting model + acting_model_opt_target = opt["acting_model_opt_file"] + if acting_model_opt_target is not None and acting_model_opt_target != "": + model_pool.register_model( + ParlAIModelConfig(opt_file=acting_model_opt_target), + [ModelTypeName.ACTION], + ) + + # Initialize Generic Acting model + generic_act_opt_target = opt["generic_act_opt_file"] + if generic_act_opt_target is not None and generic_act_opt_target != "": + model_pool.register_model( + ParlAIModelConfig(opt_file=generic_act_opt_target), + [ModelTypeName.GENERIC_ACTS], + ) + + # Initialize Parser model + parser_opt_targert = opt["parser_opt_file"] + if parser_opt_targert is not None and parser_opt_targert != "": + model_pool.register_model( + ParlAIModelConfig(opt_file=parser_opt_targert), + [ModelTypeName.PARSER], + ) + + # Initialize Safety model + if opt["use_safety_model"]: + model_pool.register_model( + ParlAIModelConfig( + opt_file=os.path.join(CONFIG_DIR, "baseline_adversarial_safety.opt") + ), + [ModelTypeName.SAFETY], + ) + + return model_pool + + +def main(): + opt = parse_and_return_args() + model_pool = init_correct_models(opt) + + if opt["load_map"] != "none": + Builder = MapJsonBuilder + ldb = "" + world_builder = Builder(None, opt=opt) + else: + # TODO FIXME make this all work with Hydra instead + # to have stacked configs + StarspaceBuilder.add_parser_arguments(parser) + opt, _unknown = parser.parse_and_process_known_args() + ldb = LIGHTDatabase(opt["light_db_file"], read_only=True) + world_builder = StarspaceBuilder(ldb, debug=False, opt=opt) + loop = asyncio.get_event_loop() - loop.run_until_complete(run_with_builder(world_builder)) + loop.run_until_complete(run_with_builder(world_builder, opt, model_pool)) + + +if __name__ == "__main__": + main() diff --git a/scripts/examples/play_tutorial.py b/scripts/examples/play_tutorial.py index e1de7d9d7..c6ca71f5d 100644 --- a/scripts/examples/play_tutorial.py +++ b/scripts/examples/play_tutorial.py @@ -40,8 +40,9 @@ async def ainput(string: str) -> str: def init_world(): world_builder = TutorialWorldBuilder(None, opt={"load_map": TUTORIAL_FILE}) - g, world = world_builder.get_graph() + g, world = asyncio.run(world_builder.get_graph()) # NOTE: I just took the act_model_path from elsewhere + # TODO TODO FIXME will need to update shared_resources = TutorialModelSoul.load_models( dialog_model_path="zoo:light_whoami/profile_expanded_attention_128/model", act_model_path="/checkpoint/light/models/game2021/act_model/model", @@ -65,7 +66,7 @@ async def run_tutorial(): Takes in a World object and its OOGraph and allows one to play with a random map """ player_provider = init_world() - player_provider.process_terminal_act("") # get an agent + await player_provider.process_terminal_act("") # get an agent await asyncio.sleep(0.01) while True: act = await ainput("\raction> ") @@ -77,10 +78,10 @@ async def run_tutorial(): elif act in ["new", "reset"]: print("A mist fills the world and everything resets") player_provider = init_world() - player_provider.process_terminal_act("") # get an agent + await player_provider.process_terminal_act("") # get an agent await asyncio.sleep(0.4) else: - player_provider.process_terminal_act(act) + await player_provider.process_terminal_act(act) await asyncio.sleep(0.4) diff --git a/scripts/filtering/reconstruct_logs.py b/scripts/filtering/reconstruct_logs.py index db84f81f2..722b0f085 100644 --- a/scripts/filtering/reconstruct_logs.py +++ b/scripts/filtering/reconstruct_logs.py @@ -6,7 +6,7 @@ from light.graph.events.base import GraphEvent from light.graph.structured_graph import OOGraph from light.world.utils.json_utils import read_event_logs -from light.world.world import World +from light.world.world import World, WorldConfig import argparse import os @@ -52,7 +52,7 @@ def get_world(uuid, graph_dir): graph_file = os.path.join(graph_dir, f"{uuid}.json") with open(graph_file, "r") as graph_json_file: graph = OOGraph.from_json(graph_json_file.read()) - world = World({}, None, False) + world = World(WorldConfig(), False) world.oo_graph = graph return world diff --git a/scripts/misc/add_copyrights.py b/scripts/misc/add_copyrights.py index c1390ddfb..1e328bed7 100644 --- a/scripts/misc/add_copyrights.py +++ b/scripts/misc/add_copyrights.py @@ -47,6 +47,13 @@ MISPLACED_ENV = """ #!/usr/bin/env python3 """ +MISPLACED_ENV_2 = """#!/usr/bin/env python3 + + +# Copyright (c) Meta Platforms, Inc.""" +CORRECT_ENV = """#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc.""" PROBLEM = """ */ @@ -100,13 +107,20 @@ def add_copyright_if_not_present(filename): elif BAD_START_SPACING in contents: contents = contents.replace(BAD_START_SPACING, "/****") + with open(filename, "w") as out_file: + out_file.write(contents) + print(end=LINE_CLEAR) + print(f"Fixed header for {filename}") + elif MISPLACED_ENV_2 in contents: + contents = contents.replace(MISPLACED_ENV_2, CORRECT_ENV) + with open(filename, "w") as out_file: out_file.write(contents) print(end=LINE_CLEAR) print(f"Fixed header for {filename}") elif MISPLACED_ENV in contents: contents = contents.replace(MISPLACED_ENV, "") - contents = "#!/usr/bin/env python3\n\n" + contents + contents = "#!/usr/bin/env python3\n" + contents with open(filename, "w") as out_file: out_file.write(contents) print(end=LINE_CLEAR) diff --git a/scripts/training/conversion.py b/scripts/training/conversion.py index 73490a474..3826a39ed 100644 --- a/scripts/training/conversion.py +++ b/scripts/training/conversion.py @@ -15,8 +15,9 @@ import argparse import pickle import os +import asyncio from light.graph.structured_graph import OOGraph -from light.world.world import World +from light.world.world import World, WorldConfig from light.graph.events.graph_events import SoulSpawnEvent, LookEvent from scripts.filtering.construct_dataset import convert_event_log_dirs @@ -34,8 +35,8 @@ def execute_events(world, transcript): if action != "": if action.startswith("gesture"): action = action.replace("gesture", "emote") - world.parse_exec(event["id"].lower(), action) - world.parse_exec(event["id"].lower(), "say " + event["text"]) + asyncio.run(world.parse_exec(event["id"].lower(), action)) + asyncio.run(world.parse_exec(event["id"].lower(), "say " + event["text"])) def process_episodes(src, log_dir): @@ -50,7 +51,7 @@ def process_episodes(src, log_dir): ep["graph"]._opt["is_logging"] = True ep["graph"]._opt["log_path"] = log_dir new_g = OOGraph.from_graph(ep["graph"]) - world = World(new_g._opt, None) + world = World(WorldConfig(opt=new_g._opt)) world.oo_graph = new_g transcript = ep["conv_info"]["acts"] players = [x for x in new_g.all_nodes.values() if x.agent and x.is_player]