Skip to content

Commit

Permalink
Accelerate get_infos by caching the DatasetInfosDicts (#778)
Browse files Browse the repository at this point in the history
* accelerate `get_infos` by caching the `DatasetInfosDict`s

* quality

* consistency
  • Loading branch information
VictorSanh authored May 22, 2022
1 parent f5c3977 commit d1f16cf
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
5 changes: 4 additions & 1 deletion promptsource/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
DEFAULT_PROMPTSOURCE_CACHE_HOME = "~/.cache/promptsource"
from pathlib import Path


DEFAULT_PROMPTSOURCE_CACHE_HOME = str(Path("~/.cache/promptsource").expanduser())
32 changes: 25 additions & 7 deletions promptsource/app.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import argparse
import functools
import multiprocessing
import os
import textwrap
from hashlib import sha256
from multiprocessing import Manager, Pool

import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import get_dataset_infos
from datasets.info import DatasetInfosDict
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import DjangoLexer
from templates import INCLUDED_USERS

from promptsource import DEFAULT_PROMPTSOURCE_CACHE_HOME
from promptsource.session import _get_state
from promptsource.templates import DatasetTemplates, Template, TemplateCollection
from promptsource.templates import INCLUDED_USERS, DatasetTemplates, Template, TemplateCollection
from promptsource.utils import (
get_dataset,
get_dataset_confs,
Expand All @@ -25,6 +28,9 @@
)


DATASET_INFOS_CACHE_DIR = os.path.join(DEFAULT_PROMPTSOURCE_CACHE_HOME, "DATASET_INFOS")
os.makedirs(DATASET_INFOS_CACHE_DIR, exist_ok=True)

# Python 3.8 switched the default start method from fork to spawn. OS X also has
# some issues related to fork, see, e.g., https://github.com/bigscience-workshop/promptsource/issues/572
# so we make sure we always use spawn for consistency
Expand All @@ -38,7 +44,17 @@ def get_infos(all_infos, d_name):
:param all_infos: multiprocess-safe dictionary
:param d_name: dataset name
"""
all_infos[d_name] = get_dataset_infos(d_name)
d_name_bytes = d_name.encode("utf-8")
d_name_hash = sha256(d_name_bytes)
foldername = os.path.join(DATASET_INFOS_CACHE_DIR, d_name_hash.hexdigest())
if os.path.isdir(foldername):
infos_dict = DatasetInfosDict.from_directory(foldername)
else:
infos = get_dataset_infos(d_name)
infos_dict = DatasetInfosDict(infos)
os.makedirs(foldername)
infos_dict.write_to_directory(foldername)
all_infos[d_name] = infos_dict


# add an argument for read-only
Expand Down Expand Up @@ -181,11 +197,13 @@ def show_text(t, width=WIDTH, with_markdown=False):
else:
subset_infos = infos[subset_name]

split_sizes = {k: v.num_examples for k, v in subset_infos.splits.items()}
try:
split_sizes = {k: v.num_examples for k, v in subset_infos.splits.items()}
except Exception:
# Fixing bug in some community datasets.
# For simplicity, just filling `split_sizes` with nothing, so the displayed split sizes will be 0.
split_sizes = {}
else:
# Zaid/coqa_expanded and Zaid/quac_expanded don't have dataset_infos.json
    # so infos is an empty dict, and `infos[list(infos.keys())[0]]` raises an error
# For simplicity, just filling `split_sizes` with nothing, so the displayed split sizes will be 0.
split_sizes = {}

# Collect template counts, original task counts and names
Expand Down

0 comments on commit d1f16cf

Please sign in to comment.