From faca5803026321af6187bc370e8a288ddf579ffc Mon Sep 17 00:00:00 2001 From: Jintang Li Date: Thu, 13 Oct 2022 15:19:15 +0800 Subject: [PATCH] [Type Hints] `transforms.Cartesian` (#5673) Co-authored-by: Matthias Fey --- CHANGELOG.md | 2 +- torch_geometric/transforms/cartesian.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 938c7a258a4c..6849abcd1661 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641)) - Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642)) - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610)) -- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669)) +- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673)) - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601)) - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614)) - Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602)) diff --git a/torch_geometric/transforms/cartesian.py b/torch_geometric/transforms/cartesian.py index f257ccaa0d21..bcd56f06169e 100644 --- a/torch_geometric/transforms/cartesian.py +++ b/torch_geometric/transforms/cartesian.py @@ -1,5 +1,8 @@ +from typing import Optional + import torch +from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform @@ -19,12 +22,17 @@ class Cartesian(BaseTransform): cat (bool, optional): If set to :obj:`False`, all existing edge attributes will be replaced. (default: :obj:`True`) """ - def __init__(self, norm=True, max_value=None, cat=True): + def __init__( + self, + norm: bool = True, + max_value: Optional[float] = None, + cat: bool = True, + ): self.norm = norm self.max = max_value self.cat = cat - def __call__(self, data): + def __call__(self, data: Data) -> Data: (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr cart = pos[row] - pos[col]