Skip to content

Commit

Permalink
Daily Papers API (#2554)
Browse files Browse the repository at this point in the history
* Daily Papers API

* Apply suggestions from code review

Co-authored-by: Celina Hanouti <[email protected]>

* Apply suggestions from code review

Co-authored-by: Celina Hanouti <[email protected]>

* Fix tests

* Run papers API tests independently

* Apply suggestions from code review

Co-authored-by: Lucain <[email protected]>

* Remove date

* additional test and update docstring

---------

Co-authored-by: Celina Hanouti <[email protected]>
Co-authored-by: Lucain <[email protected]>
  • Loading branch information
3 people authored Oct 9, 2024
1 parent 613b591 commit 2c7c19d
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@
"list_metrics",
"list_models",
"list_organization_members",
"list_papers",
"list_pending_access_requests",
"list_rejected_access_requests",
"list_repo_commits",
Expand All @@ -230,6 +231,7 @@
"merge_pull_request",
"model_info",
"move_repo",
"paper_info",
"parse_safetensors_file_metadata",
"pause_inference_endpoint",
"pause_space",
Expand Down Expand Up @@ -741,6 +743,7 @@ def __dir__():
list_metrics, # noqa: F401
list_models, # noqa: F401
list_organization_members, # noqa: F401
list_papers, # noqa: F401
list_pending_access_requests, # noqa: F401
list_rejected_access_requests, # noqa: F401
list_repo_commits, # noqa: F401
Expand All @@ -755,6 +758,7 @@ def __dir__():
merge_pull_request, # noqa: F401
model_info, # noqa: F401
move_repo, # noqa: F401
paper_info, # noqa: F401
parse_safetensors_file_metadata, # noqa: F401
pause_inference_endpoint, # noqa: F401
pause_space, # noqa: F401
Expand Down
133 changes: 133 additions & 0 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,70 @@ def __init__(self, **kwargs) -> None:
self.__dict__.update(**kwargs)


@dataclass
class PaperInfo:
    """
    Contains information about a paper on the Hub.

    Attributes:
        id (`str`):
            arXiv paper ID.
        authors (`List[str]`, *optional*):
            Names of paper authors.
        published_at (`datetime`, *optional*):
            Date paper published.
        title (`str`, *optional*):
            Title of the paper.
        summary (`str`, *optional*):
            Summary of the paper.
        upvotes (`int`, *optional*):
            Number of upvotes for the paper on the Hub.
        discussion_id (`str`, *optional*):
            Discussion ID for the paper on the Hub.
        source (`str`, *optional*):
            Source of the paper.
        comments (`int`, *optional*):
            Number of comments for the paper on the Hub.
        submitted_at (`datetime`, *optional*):
            Date paper appeared in daily papers on the Hub.
        submitted_by (`User`, *optional*):
            Information about who submitted the daily paper.
    """

    id: str
    authors: Optional[List[str]]
    published_at: Optional[datetime]
    title: Optional[str]
    summary: Optional[str]
    upvotes: Optional[int]
    discussion_id: Optional[str]
    source: Optional[str]
    comments: Optional[int]
    submitted_at: Optional[datetime]
    submitted_by: Optional[User]

    def __init__(self, **kwargs) -> None:
        """
        Build a `PaperInfo` from the keys of a decoded JSON payload.

        Fields are read from the optional nested `paper` mapping and/or from the top level of
        `kwargs`. Any top-level key that is not consumed here is copied verbatim onto the
        instance.
        """
        # Work on a shallow copy of the nested mapping. Only the top-level `kwargs` dict is
        # fresh; the nested mapping is the caller's own object, and popping from it in place
        # would make a second `PaperInfo(**payload)` built from the same payload lose its
        # nested fields. A `paper` key explicitly set to `None` is treated as absent.
        paper = dict(kwargs.pop("paper", None) or {})
        self.id = kwargs.pop("id", None) or paper.pop("id", None)
        authors = paper.pop("authors", None) or kwargs.pop("authors", None)
        # `get` rather than `pop`: the author dicts still belong to the caller.
        self.authors = [author.get("name") for author in authors] if authors else None
        # When the nested mapping provides `publishedAt`, the short-circuit leaves the top-level
        # `publishedAt` in `kwargs`, where it is consumed as the submission date further down.
        published_at = paper.pop("publishedAt", None) or kwargs.pop("publishedAt", None)
        self.published_at = parse_datetime(published_at) if published_at else None
        self.title = kwargs.pop("title", None)
        self.source = kwargs.pop("source", None)
        self.summary = paper.pop("summary", None) or kwargs.pop("summary", None)
        # Explicit `None` test so that a genuine count of 0 is kept instead of being treated
        # as a missing value.
        upvotes = paper.pop("upvotes", None)
        self.upvotes = upvotes if upvotes is not None else kwargs.pop("upvotes", None)
        self.discussion_id = paper.pop("discussionId", None) or kwargs.pop("discussionId", None)
        self.comments = kwargs.pop("numComments", 0)
        submitted_at = kwargs.pop("publishedAt", None) or kwargs.pop("submittedOnDailyAt", None)
        self.submitted_at = parse_datetime(submitted_at) if submitted_at else None
        submitted_by = kwargs.pop("submittedBy", None) or kwargs.pop("submittedOnDailyBy", None)
        self.submitted_by = User(**submitted_by) if submitted_by else None

        # forward compatibility: keep whatever keys were not consumed above
        self.__dict__.update(**kwargs)


def future_compatible(fn: CallableT) -> CallableT:
"""Wrap a method of `HfApi` to handle `run_as_future=True`.
Expand Down Expand Up @@ -9673,6 +9737,72 @@ def list_user_following(self, username: str, token: Union[bool, str, None] = Non
):
yield User(**followed_user)

def list_papers(
    self,
    *,
    query: Optional[str] = None,
    token: Union[bool, str, None] = None,
) -> Iterable[PaperInfo]:
    """
    List daily papers on the Hugging Face Hub given a search query.

    This is a generator: the HTTP request is only sent once the returned iterable is first
    iterated.

    Args:
        query (`str`, *optional*):
            A search query string to find papers.
            If provided, returns papers that match the query.
        token (Union[bool, str, None], *optional*):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        `Iterable[PaperInfo]`: an iterable of [`huggingface_hub.hf_api.PaperInfo`] objects.

    Example:

    ```python
    >>> from huggingface_hub import HfApi

    >>> api = HfApi()

    # List all papers with "attention" in their title
    >>> api.list_papers(query="attention")
    ```
    """
    # The `q` parameter is only sent when a non-empty query was given.
    search_params = {"q": query} if query else {}
    response = get_session().get(
        f"{self.endpoint}/api/papers/search",
        headers=self._build_hf_headers(token=token),
        params=search_params,
    )
    hf_raise_for_status(response)
    # One `PaperInfo` per element of the decoded JSON response.
    yield from (PaperInfo(**entry) for entry in response.json())

def paper_info(self, id: str, *, token: Union[bool, str, None] = None) -> PaperInfo:
    """
    Get information for a paper on the Hub.

    Args:
        id (`str`):
            ArXiv id of the paper.
        token (Union[bool, str, None], *optional*):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        `PaperInfo`: A `PaperInfo` object.

    Raises:
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 404 If the paper does not exist on the Hub.
    """
    path = f"{self.endpoint}/api/papers/{id}"
    # Send the same headers as `list_papers` so both paper endpoints are called consistently.
    r = get_session().get(path, headers=self._build_hf_headers(token=token))
    hf_raise_for_status(r)
    return PaperInfo(**r.json())

def auth_check(
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> None:
Expand Down Expand Up @@ -9768,6 +9898,9 @@ def _parse_revision_from_pr_url(pr_url: str) -> str:
# Module-level names bound to the corresponding methods of the `api` object.
list_spaces = api.list_spaces
space_info = api.space_info

# Papers: search (`list_papers`) and lookup by id (`paper_info`).
list_papers = api.list_papers
paper_info = api.paper_info

repo_exists = api.repo_exists
revision_exists = api.revision_exists
file_exists = api.file_exists
Expand Down
22 changes: 22 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4106,6 +4106,28 @@ def test_user_following(self) -> None:
assert len(list(following)) > 500


class PaperApiTest(unittest.TestCase):
    """Tests for the papers methods of `HfApi`: `list_papers` and `paper_info`."""

    @classmethod
    @with_production_testing
    def setUpClass(cls) -> None:
        # A single client is shared by every test of the class.
        cls.api = HfApi()
        return super().setUpClass()

    def test_papers_by_query(self) -> None:
        """A "llama" search returns results that include the Llama 3 paper."""
        found = list(self.api.list_papers(query="llama"))
        assert len(found) > 0
        titles = [entry.title for entry in found]
        assert "The Llama 3 Herd of Models" in titles

    def test_get_paper_by_id_success(self) -> None:
        """Looking up a known arXiv id returns the expected title."""
        info = self.api.paper_info("2407.21783")
        assert info.title == "The Llama 3 Herd of Models"

    def test_get_paper_by_id_not_found(self) -> None:
        """Looking up an id that does not exist raises an HTTP 404 error."""
        with self.assertRaises(HfHubHTTPError) as raised:
            self.api.paper_info("1234.56789")
        # Checked after the `with` block, once the exception has been captured.
        assert raised.exception.response.status_code == 404


class WebhookApiTest(HfApiCommonTest):
def setUp(self) -> None:
super().setUp()
Expand Down

0 comments on commit 2c7c19d

Please sign in to comment.