Skip to content

Commit

Permalink
[Type Hints] transforms.Polar
Browse files Browse the repository at this point in the history
  • Loading branch information
EdisonLeeeee committed Oct 13, 2022
1 parent faca580 commit 86e8d56
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions torch_geometric/transforms/polar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from math import pi as PI
from typing import Optional

import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform

Expand All @@ -21,12 +23,17 @@ class Polar(BaseTransform):
cat (bool, optional): If set to :obj:`False`, all existing edge
attributes will be replaced. (default: :obj:`True`)
"""
def __init__(self, norm=True, max_value=None, cat=True):
def __init__(
self,
norm: bool = True,
max_value: Optional[float] = None,
cat: bool = True,
):
self.norm = norm
self.max = max_value
self.cat = cat

def __call__(self, data):
def __call__(self, data: Data) -> Data:
(row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
assert pos.dim() == 2 and pos.size(1) == 2

Expand Down

0 comments on commit 86e8d56

Please sign in to comment.