-
-
Notifications
You must be signed in to change notification settings - Fork 683
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
428 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
package internal | ||
|
||
import ( | ||
"context" | ||
"encoding/xml" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"net/url" | ||
"strings" | ||
"time" | ||
) | ||
|
||
// Client defines an HTTP client for communicating with PubMed. | ||
type Client struct { | ||
MaxResults int | ||
UserAgent string | ||
BaseURL string | ||
} | ||
|
||
// Result defines a search query result type. | ||
type Result struct { | ||
Title string | ||
Authors []string | ||
Abstract string | ||
PMID string | ||
Published string | ||
} | ||
|
||
var ( | ||
ErrNoGoodResult = errors.New("no good search results found") | ||
ErrAPIResponse = errors.New("PubMed API responded with error") | ||
) | ||
|
||
// NewClient initializes a Client with arguments for setting a max | ||
// results per search query and a value for the user agent header. | ||
func NewClient(maxResults int, userAgent string) *Client { | ||
if maxResults == 0 { | ||
maxResults = 1 | ||
} | ||
|
||
return &Client{ | ||
MaxResults: maxResults, | ||
UserAgent: userAgent, | ||
BaseURL: "https://eutils.ncbi.nlm.nih.gov/entrez/eutils", | ||
} | ||
} | ||
|
||
func (client *Client) newRequest(ctx context.Context, queryURL string) (*http.Request, error) { | ||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, queryURL, nil) | ||
if err != nil { | ||
return nil, fmt.Errorf("creating PubMed request: %w", err) | ||
} | ||
|
||
if client.UserAgent != "" { | ||
request.Header.Add("User-Agent", client.UserAgent) | ||
} | ||
|
||
return request, nil | ||
} | ||
|
||
// Search performs a search query and returns | ||
// the result as string and an error if any. | ||
func (client *Client) Search(ctx context.Context, query string) (string, error) { | ||
// First, search for IDs | ||
searchURL := fmt.Sprintf("%s/esearch.fcgi?db=pubmed&term=%s&retmax=%d&usehistory=y", | ||
client.BaseURL, url.QueryEscape(query), client.MaxResults) | ||
|
||
request, err := client.newRequest(ctx, searchURL) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
response, err := http.DefaultClient.Do(request) | ||
if err != nil { | ||
return "", fmt.Errorf("get %s error: %w", searchURL, err) | ||
} | ||
defer response.Body.Close() | ||
|
||
if response.StatusCode != http.StatusOK { | ||
return "", ErrAPIResponse | ||
} | ||
|
||
body, err := io.ReadAll(response.Body) | ||
if err != nil { | ||
return "", fmt.Errorf("reading response body: %w", err) | ||
} | ||
|
||
var searchResult struct { | ||
IDList struct { | ||
IDs []string `xml:"Id"` | ||
} `xml:"IdList"` | ||
WebEnv string `xml:"WebEnv"` | ||
QueryKey string `xml:"QueryKey"` | ||
} | ||
|
||
if err := xml.Unmarshal(body, &searchResult); err != nil { | ||
return "", fmt.Errorf("unmarshaling XML: %w", err) | ||
} | ||
|
||
if len(searchResult.IDList.IDs) == 0 { | ||
return "", ErrNoGoodResult | ||
} | ||
|
||
// Now fetch details for these IDs | ||
fetchURL := fmt.Sprintf("%s/efetch.fcgi?db=pubmed&WebEnv=%s&query_key=%s&retmode=xml&rettype=abstract", | ||
client.BaseURL, searchResult.WebEnv, searchResult.QueryKey) | ||
|
||
request, err = client.newRequest(ctx, fetchURL) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
response, err = http.DefaultClient.Do(request) | ||
if err != nil { | ||
return "", fmt.Errorf("get %s error: %w", fetchURL, err) | ||
} | ||
defer response.Body.Close() | ||
|
||
if response.StatusCode != http.StatusOK { | ||
return "", ErrAPIResponse | ||
} | ||
|
||
body, err = io.ReadAll(response.Body) | ||
if err != nil { | ||
return "", fmt.Errorf("reading response body: %w", err) | ||
} | ||
|
||
var fetchResult struct { | ||
Articles []struct { | ||
MedlineCitation struct { | ||
Article struct { | ||
ArticleTitle string `xml:"ArticleTitle"` | ||
Abstract struct { | ||
AbstractText string `xml:"AbstractText"` | ||
} `xml:"Abstract"` | ||
AuthorList struct { | ||
Authors []struct { | ||
LastName string `xml:"LastName"` | ||
ForeName string `xml:"ForeName"` | ||
Initials string `xml:"Initials"` | ||
} `xml:"Author"` | ||
} `xml:"AuthorList"` | ||
} `xml:"Article"` | ||
PMID string `xml:"PMID"` | ||
} `xml:"MedlineCitation"` | ||
PubmedData struct { | ||
History struct { | ||
PubMedPubDate []struct { | ||
Year string `xml:"Year"` | ||
Month string `xml:"Month"` | ||
Day string `xml:"Day"` | ||
} `xml:"PubMedPubDate"` | ||
} `xml:"History"` | ||
} `xml:"PubmedData"` | ||
} `xml:"PubmedArticle"` | ||
} | ||
|
||
if err := xml.Unmarshal(body, &fetchResult); err != nil { | ||
return "", fmt.Errorf("unmarshaling XML: %w", err) | ||
} | ||
|
||
results := []Result{} | ||
for _, article := range fetchResult.Articles { | ||
authors := []string{} | ||
for _, author := range article.MedlineCitation.Article.AuthorList.Authors { | ||
authors = append(authors, fmt.Sprintf("%s %s", author.ForeName, author.LastName)) | ||
} | ||
|
||
var pubDate time.Time | ||
for _, date := range article.PubmedData.History.PubMedPubDate { | ||
if pubDate, err = time.Parse("2006-1-2", fmt.Sprintf("%s-%s-%s", date.Year, date.Month, date.Day)); err == nil { | ||
break | ||
} | ||
} | ||
|
||
results = append(results, Result{ | ||
Title: article.MedlineCitation.Article.ArticleTitle, | ||
Authors: authors, | ||
Abstract: article.MedlineCitation.Article.Abstract.AbstractText, | ||
PMID: article.MedlineCitation.PMID, | ||
Published: pubDate.Format("2006-01-02"), | ||
}) | ||
} | ||
|
||
return client.formatResults(results), nil | ||
} | ||
|
||
// formatResults will return a structured string with the results. | ||
func (client *Client) formatResults(results []Result) string { | ||
var formattedResults strings.Builder | ||
|
||
for _, result := range results { | ||
formattedResults.WriteString(fmt.Sprintf("Title: %s\n", result.Title)) | ||
formattedResults.WriteString(fmt.Sprintf("Authors: %s\n", strings.Join(result.Authors, ", "))) | ||
formattedResults.WriteString(fmt.Sprintf("Abstract: %s\n", result.Abstract)) | ||
formattedResults.WriteString(fmt.Sprintf("PMID: %s\n", result.PMID)) | ||
formattedResults.WriteString(fmt.Sprintf("Published: %s\n\n", result.Published)) | ||
} | ||
|
||
return formattedResults.String() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package pubmed | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
|
||
"github.com/tmc/langchaingo/callbacks" | ||
"github.com/tmc/langchaingo/tools" | ||
"github.com/tmc/langchaingo/tools/pubmed/internal" | ||
) | ||
|
||
// DefaultUserAgent defines a default value for user-agent header. | ||
const DefaultUserAgent = "github.com/tmc/langchaingo/tools/pubmed" | ||
|
||
// Tool defines a tool implementation for the PubMed Search. | ||
type Tool struct { | ||
CallbacksHandler callbacks.Handler | ||
client *internal.Client | ||
} | ||
|
||
var _ tools.Tool = Tool{} | ||
|
||
// New initializes a new PubMed Search tool with arguments for setting a | ||
// max results per search query and a value for the user agent header. | ||
func New(maxResults int, userAgent string) (*Tool, error) { | ||
return &Tool{ | ||
client: internal.NewClient(maxResults, userAgent), | ||
}, nil | ||
} | ||
|
||
// Name returns a name for the tool. | ||
func (t Tool) Name() string { | ||
return "PubMed Search" | ||
} | ||
|
||
// Description returns a description for the tool. | ||
func (t Tool) Description() string { | ||
return ` | ||
"A wrapper around PubMed Search API." | ||
"Search for biomedical literature from MEDLINE, life science journals, and online books." | ||
"Input should be a search query."` | ||
} | ||
|
||
// Call performs the search and return the result. | ||
func (t Tool) Call(ctx context.Context, input string) (string, error) { | ||
if t.CallbacksHandler != nil { | ||
t.CallbacksHandler.HandleToolStart(ctx, input) | ||
} | ||
|
||
result, err := t.client.Search(ctx, input) | ||
if err != nil { | ||
if errors.Is(err, internal.ErrNoGoodResult) { | ||
return "No good PubMed Search Results were found", nil | ||
} | ||
if t.CallbacksHandler != nil { | ||
t.CallbacksHandler.HandleToolError(ctx, err) | ||
} | ||
return "", err | ||
} | ||
|
||
if t.CallbacksHandler != nil { | ||
t.CallbacksHandler.HandleToolEnd(ctx, result) | ||
} | ||
|
||
return result, nil | ||
} |
Oops, something went wrong.