Skip to content

Commit

Permalink
Fix some type errors newer mypy found (one follow up)
Browse files Browse the repository at this point in the history
  • Loading branch information
ntjohnson1 committed Aug 6, 2024
1 parent a4b20f3 commit e1e7613
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 4 additions & 2 deletions pyttb/gcp/handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,17 @@ def huber_grad(data: ttb.tensor, model: ttb.tensor, threshold: float) -> np.ndar
) * np.logical_not(below_threshold)


# FIXME: Num trials should be enforced as integer here and in MATLAB
# requires updating our regression test values to calculate MATLAB integer version
def negative_binomial(
data: np.ndarray, model: np.ndarray, num_trials: int
data: np.ndarray, model: np.ndarray, num_trials: float
) -> np.ndarray:
"""Return objective function for negative binomial distributions"""
return (num_trials + data) * np.log(model + 1) - data * np.log(model + EPS)


def negative_binomial_grad(
data: np.ndarray, model: np.ndarray, num_trials: int
data: np.ndarray, model: np.ndarray, num_trials: float
) -> np.ndarray:
"""Return gradient function for negative binomial distributions"""
return (num_trials + 1) / (1 + model) - data / (model + EPS)
Expand Down
5 changes: 4 additions & 1 deletion pyttb/gcp/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def semistrat(data: ttb.sptensor, num_nonzeros: int, num_zeros: int) -> sample_t


def stratified(
data: ttb.sptensor,
data: Union[ttb.sptensor, ttb.tensor],
nz_idx: np.ndarray,
num_nonzeros: int,
num_zeros: int,
Expand All @@ -450,6 +450,9 @@ def stratified(
-------
Subscripts, values, and weights of samples (Nonzeros then zeros).
"""
assert isinstance(
data, ttb.sptensor
), "For stratified sampling Sparse Tensor must be provided"
[nonzero_subs, nonzero_vals] = nonzeros(data, num_nonzeros, with_replacement=True)
nonzero_weights = np.ones((num_nonzeros,))
if num_nonzeros > 0:
Expand Down

0 comments on commit e1e7613

Please sign in to comment.