Skip to content
This repository has been archived by the owner on Oct 11, 2023. It is now read-only.

New environment db #295

Merged
merged 36 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
a8ad858
Filling out UserDB
JackUrb May 17, 2022
f229661
Abstract enforce get first
JackUrb May 18, 2022
8eba75d
Clearer argument - num_turns
JackUrb May 18, 2022
411148f
Using enums in DB
JackUrb May 18, 2022
1ef8be9
Initial episode data model
JackUrb May 18, 2022
b198bb9
Update content loggers to use episode formatting
JackUrb May 18, 2022
256ef7b
Updating tables to work with testing
JackUrb May 18, 2022
0cd1267
Fixing some test changes
JackUrb May 18, 2022
9167d22
Fixing small warnings that were noise during tests
JackUrb May 18, 2022
f32a6d0
Moving default log path
JackUrb May 19, 2022
320d4f1
Test fix
JackUrb May 19, 2022
e178702
Correcting math thanks to Kurt
JackUrb May 20, 2022
22836bd
Updating env DB classes to SQLAlchemy
JackUrb Jul 15, 2022
dffe54c
Name keys and Elems coded
JackUrb Jul 18, 2022
bcc7766
Adding arbitrary node attributes
JackUrb Jul 19, 2022
1083f9d
First complete pass of EnvDB
JackUrb Jul 19, 2022
49b4fb6
Mypy fixings
JackUrb Jul 20, 2022
cabd75a
Fixing agents
JackUrb Jul 20, 2022
9d597e6
Writing some tests
JackUrb Jul 20, 2022
b71af9b
Finishing tests for object and room creates and queries
JackUrb Jul 20, 2022
ea380f1
Edge testing
JackUrb Jul 20, 2022
29f3489
Arbitrary attributes testing
JackUrb Jul 20, 2022
30f333f
Quests and testing
JackUrb Jul 22, 2022
688729a
And finally, DBGraph tests
JackUrb Jul 22, 2022
f25e2ac
fixing episode change
JackUrb Jul 22, 2022
d3099f1
TODO function
JackUrb Jul 22, 2022
7e8b0ed
final mypy fixes
JackUrb Jul 22, 2022
fffa66f
DBID testing
JackUrb Jul 25, 2022
c8d3b7b
a -> either a or an depending on aeiou
JackUrb Jul 25, 2022
bc78c56
Merge branch 'new-data-model' into new-users-table
JackUrb Jul 28, 2022
9a18269
Merge branch 'new-users-table' into new-episode-logging
JackUrb Jul 28, 2022
71ea8ab
Merge branch 'new-episode-logging' into new-environment-db
JackUrb Jul 28, 2022
f4f3ee6
Merge branch 'new-data-model' into new-users-table
JackUrb Jul 29, 2022
2f35c75
Merge branch 'new-users-table' into new-episode-logging
JackUrb Jul 29, 2022
0eefe18
Merge branch 'new-episode-logging' into new-environment-db
JackUrb Jul 29, 2022
d14b4c4
Merge branch 'new-data-model' into new-environment-db
JackUrb Aug 22, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 81 additions & 14 deletions light/data_model/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,30 @@

from abc import ABC, abstractmethod
from omegaconf import MISSING, DictConfig
from contextlib import contextmanager
from sqlalchemy import create_engine
from enum import Enum
from typing import Optional, Union, Dict, Any
from dataclasses import dataclass
from tempfile import mkdtemp
import shutil
import os
import json

from hydra.core.config_store import ConfigStore

DEFAULT_LOG_PATH = "".join(
[os.path.abspath(os.path.dirname(__file__)), "/../../../logs"]
)


@dataclass
class LightDBConfig:
backend: str = "test"
file_root: Optional[str] = DEFAULT_LOG_PATH


cs = ConfigStore.instance()
cs.store(name="config1", node=LightDBConfig)


class DBStatus(Enum):
Expand Down Expand Up @@ -48,6 +69,18 @@ def __init__(self, config: "DictConfig"):
Create this database, either connecting to a remote host or local
files and instances.
"""
# TODO replace with a swappable engine that persists the data
self.backend = config.backend
if config.backend == "test":
self.engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
self.made_temp_dir = config.file_root is None
if self.made_temp_dir:
self.file_root = mkdtemp()
else:
self.file_root = config.file_root
else:
raise NotImplementedError()
self._complete_init(config)

@abstractmethod
def _complete_init(self, config: "DictConfig"):
Expand All @@ -61,20 +94,60 @@ def _validate_init(self):
Ensure that this database is initialized correctly
"""

def _enforce_get_first(self, session, stmt, error_text) -> Any:
"""
Enforce getting the first element using stmt, raise a key_error
with error_text if fails to find
"""
result = session.scalars(stmt).first()
if result is None:
raise KeyError(error_text)
return result

def file_path_exists(self, file_path: str) -> bool:
"""
Determine if the given file path exists on this storage
"""
if self.backend in ["test", "local"]:
full_path = os.path.join(self.file_root, file_path)
return os.path.exists(full_path)
else:
raise NotImplementedError

def write_data_to_file(
self, data: Union[str, Dict[str, Any]], filename: str, json_encode: bool = False
):
) -> None:
"""
Write the given data to the provided filename
in the correct storage location (local or remote)
"""
# Expects data to be a string, unless json_encode is True
if self.backend in ["test", "local"]:
full_path = os.path.join(self.file_root, filename)
os.makedirs(os.path.dirname(full_path), exist_ok=True)
with open(full_path, "w+") as target_file:
if json_encode:
json.dump(data, target_file)
else:
target_file.write(data)
else:
raise NotImplementedError()

def read_data_from_file(self, filename: str, json_encoded: bool = False):
def read_data_from_file(
self, filename: str, json_encoded: bool = False
) -> Union[str, Dict[str, Any]]:
"""
Read the data from the given filename from wherever it
is currently stored (local or remote)
"""
if self.backend in ["test", "local"]:
full_path = os.path.join(self.file_root, filename)
with open(full_path, "r") as target_file:
if json_encoded:
return json.load(target_file)
else:
return target_file.read()
else:
raise NotImplementedError()

def open_file(self):
try:
Expand All @@ -83,13 +156,7 @@ def open_file(self):
finally:
file.close()

@contextmanager
def get_database_connection(self):
"""Get a connection to the database that can be used for a transaction"""
try:
# Get DB connection
# yield db connection
raise NotImplementedError()
finally:
# close DB connection, rollback if there's an issue
pass
def shutdown(self):
if self.backend == "test":
if self.made_temp_dir:
shutil.rmtree(self.file_root)
Loading