Skip to content

Commit

Permalink
feat: rewrite above implementation to improve performance (#48)
Browse files Browse the repository at this point in the history
* Add fast above to work with SPR
* docs: reformat doc comments
* style: format with black
* fix: use range instead of xrange for py3 compat
* Add test above

Co-authored-by: Pablo Sierra Heras <[email protected]>
Co-authored-by: Jeremiah Matthey <[email protected]>
  • Loading branch information
3 people authored Aug 19, 2022
1 parent 0284a27 commit 1afee80
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 14 deletions.
45 changes: 45 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,46 @@ for _ in xrange(10):
print("%f\t%f\t%f" % (transit.start, transit.duration(), transit.peak()['elevation']))
```

#### Modeling an entire constellation

Generating transits for a lot of satellites over a lot of groundstations can be slow.
Luckily, generating transits for each satellite-groundstation pair can be parallelized for a big speedup.

```
import itertools
from multiprocessing.pool import Pool
import time
import predict
import requests
# Define a function that returns arguments for all the transits() calls you want to make
def _transits_call_arguments():
now = time.time()
tle = requests.get('http://tle.spire.com/25544').text.rstrip()
for latitude in range(-90, 91, 15):
for longitude in range(-180, 181, 15):
qth = (latitude, longitude, 0)
yield {'tle': tle, 'qth': qth, 'ending_before': now+60*60*24*7}
# Define a function that calls the transit function on a set of arguments and does per-transit processing
def _transits_call_fx(kwargs):
try:
transits = list(predict.transits(**kwargs))
return [t.above(10) for t in transits]
except predict.PredictException:
pass
# Map the transit() caller across all the arguments you want, then flatten results into a single list
pool = Pool(processes=10)
array_of_results = pool.map(_transits_call_fx, _transits_call_arguments())
flattened_results = list(itertools.chain.from_iterable(filter(None, array_of_results)))
transits = flattened_results
```

NOTE: If precise accuracy isn't necessary (for modeling purposes, for example) setting the tolerance argument
to the `above` call to a larger value, say 1 degree, can provide a signifigant performance boost.

#### Call predict analogs directly

```python
Expand Down Expand Up @@ -144,6 +184,11 @@ predict.quick_predict(tle.split('\n'), time.time(), (37.7727, 122.407, 25))
Returns epoch time where transit reaches maximum elevation (within ~<i>epsilon</i>)
<b>at</b>(<i>timestamp</i>)
Returns observation during transit via <b>quick_find</b>(<i>tle, timestamp, qth</i>)
<b>above</b>b(<i>elevation</i>, <i>tolerance</i>)
Returns portion of transit above elevation. If the entire transit is below the target elevation, both
endpoints will be set to the peak and the duration will be zero. If a portion of the transit is above
the elevation target, the endpoints will be between elevation and elevation + tolerance (unless
endpoint is already above elevation, in which case it will be unchanged)
<b>quick_find</b>(<i>tle[, time[, (lat, long, alt)]]</i>)
<i>time</i> defaults to current time
<i>(lat, long, alt)</i> defaults to values in ~/.predict/predict.qth
Expand Down
144 changes: 131 additions & 13 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,13 @@ def transits(tle, qth, ending_after=None, ending_before=None):
ts = ending_after
while True:
transit = quick_predict(tle, ts, qth)
t = Transit(tle, qth, start=transit[0]["epoch"], end=transit[-1]["epoch"])
t = Transit(
tle,
qth,
start=transit[0]["epoch"],
end=transit[-1]["epoch"],
_samples=transit,
)
if ending_before is not None and t.end > ending_before:
break
if t.end > ending_after:
Expand All @@ -87,21 +93,30 @@ def active_transit(tle, qth, at=None):
if at is None:
at = time.time()
transit = quick_predict(tle, at, qth)
t = Transit(tle, qth, start=transit[0]["epoch"], end=transit[-1]["epoch"])
t = Transit(
tle, qth, start=transit[0]["epoch"], end=transit[-1]["epoch"], _samples=transit
)
return t if t.start <= at <= t.end else None


# Transit is a convenience class representing a pass of a satellite over a groundstation.
class Transit:
def __init__(self, tle, qth, start, end):
"""A convenience class representing a pass of a satellite over a groundstation."""

def __init__(self, tle, qth, start, end, _samples=None):
self.tle = tle
self.qth = qth
self.start = start
self.end = end
if _samples is None:
self._samples = []
else:
self._samples = [s for s in _samples if start <= s["epoch"] <= end]

# return observation within epsilon seconds of maximum elevation
# NOTE: Assumes elevation is strictly monotonic or concave over the [start,end] interval
def peak(self, epsilon=0.1):
"""Return observation within epsilon seconds of maximum elevation.
NOTE: Assumes elevation is strictly monotonic or concave over the [start,end] interval.
"""
ts = (self.end + self.start) / 2
step = self.end - self.start
while step > epsilon:
Expand Down Expand Up @@ -130,14 +145,117 @@ def peak(self, epsilon=0.1):
ts = next_ts
return self.at(ts)

# Return portion of transit above a certain elevation
def above(self, elevation):
return self.prune(lambda ts: self.at(ts)["elevation"] >= elevation)
def above(self, elevation, tolerance=0.001):
"""Return portion of transit that lies above argument elevation.
Elevation at new endpoints will lie between elevation and elevation + tolerance unless
endpoint of original transit is already above elevation, in which case it won't change, or
entire transit is below elevation target, in which case resulting transit will have zero
length.
"""

def capped_below(elevation, samples):
"""Quick heuristic to filter transits that can't reach target elevation.
Assumes transit is unimodal and derivative is monotonic. i.e. transit is a smooth
section of something that has ellipse-like geometry.
"""
limit = None

if len(samples) < 3:
raise ValueError("samples array must have length > 3")

# Find samples that form a hump
for i in range(len(samples) - 2):
a, b, c = samples[i : i + 3]

ae, be, ce = a["elevation"], b["elevation"], c["elevation"]
at, bt, ct = a["epoch"], b["epoch"], c["epoch"]

if ae < be > ce:
left_step = bt - at
right_step = ct - bt
left_slope = (be - ae) / left_step
right_slope = (be - ce) / right_step
limit = be + max(left_step * right_slope, right_step * left_slope)
break

# If limit isn't set, we didn't find a hump, so max is at one of edges.
if limit is None:
limit = max(s["elevation"] for s in samples)

return limit < elevation

def add_sample(ts, samples):
if ts not in [s["epoch"] for s in samples]:
samples.append(self.at(ts))
samples.sort(key=lambda s: s["epoch"])

def interpolate(samples, elevation, tolerance):
"""Interpolate between adjacent samples straddling the elevation target."""

for i in range(len(samples) - 1):
a, b = samples[i : i + 2]

if any(
abs(sample["elevation"] - elevation) <= tolerance
for sample in [a, b]
):
continue

if (a["elevation"] < elevation) != (b["elevation"] < elevation):
p = (elevation - a["elevation"]) / (b["elevation"] - a["elevation"])
t = a["epoch"] + p * (b["epoch"] - a["epoch"])
add_sample(t, samples)
return True
return False

# math gets unreliable with a real small elevation tolerances (~1e-6), be safe.
if tolerance < 0.0001:
raise ValueError("Minimum tolerance of 0.0001")

# Ensure we've got a well formed set of samples for iterative linear interpolation
# We'll need at least three (2 needed for interpolating, 3 needed for filtering speedup)
if self.start == self.end:
return self
samples = self._samples[:]
add_sample(self.start, samples)
add_sample(self.end, samples)
if len(samples) < 3:
add_sample((self.start + self.end) / 2, samples)

# We need at least one sample point in the sample set above the desired elevation
protrude = max(samples, key=lambda s: s["elevation"])
if protrude["elevation"] <= elevation:
if not capped_below(
elevation, samples
): # prevent expensive calculation on lost causes
protrude = self.peak()
add_sample(
protrude["epoch"], samples
) # recalculation is wasteful, but this is rare

if protrude["elevation"] <= elevation:
start = protrude["epoch"]
end = protrude["epoch"]
else:
# Aim for elevation + (tolerance / 2) +/- (tolerance / 2) to ensure we're >= elevation
while interpolate(
samples, elevation + float(tolerance) / 2, float(tolerance) / 2
):
pass
samples = [s for s in samples if s["elevation"] >= elevation]
start = samples[0]["epoch"]
end = samples[-1]["epoch"]
return Transit(self.tle, self.qth, start, end, samples)

# Return section of a transit where a pruning function is valid.
# Currently used to set elevation threshold, unclear what other uses it might have.
# fx must either return false everywhere or true for a contiguous period including the peak
def prune(self, fx, epsilon=0.1):
"""Return section of a transit where a pruning function is valid.
Currently used to set elevation threshold, unclear what other uses it might have. fx must
either return false everywhere or true for a contiguous period including the peak.
"""

peak = self.peak()["epoch"]
if not fx(peak):
start = peak
Expand Down Expand Up @@ -194,7 +312,7 @@ def find_solar_periods(
small_predict_timestep=1,
):
"""
Finds all sunlit (or eclipse, if eclipse is set) windows for a tle within a time range
Finds all sunlit (or eclipse, if eclipse is set) windows for a tle within a time range.
"""
qth = (
0,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="pypredict",
version="1.6.3",
version="1.7.0",
author="Jesse Trutna",
author_email="[email protected]",
maintainer="Spire Global Inc",
Expand Down
67 changes: 67 additions & 0 deletions test_predict_above.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import predict

TLE = (
"0 ISS (ZARYA)\n"
"1 25544U 98067A 22032.20725753 .00005492 00000-0 10513-3 0 9993\n"
"2 25544 51.6448 289.9951 0006763 85.5907 19.0087 15.49721935324098"
)
QTH = (0.033889, 51.066389, 0.0) # Macapa, Brazil

ENDING_AFTER = 1643720700 # 2022-02-01T13:05:00Z
ENDING_BEFORE = 1644325500 # 2022-02-08T13:05:00Z

MIN_ELEVATION = 10
MIN_DURATION = 1

TRANSITS_REF = [
{"start": 1643723116.22434, "end": 1643723516.987234},
{"start": 1643764874.083364, "end": 1643765265.3938956},
{"start": 1643806649.5753431, "end": 1643807015.1396928},
{"start": 1643848426.4339318, "end": 1643848743.1023502},
{"start": 1643854302.3515873, "end": 1643854454.722202},
{"start": 1643890227.7005908, "end": 1643890465.3432565},
{"start": 1643896002.9880626, "end": 1643896277.0759523},
{"start": 1643937721.9717846, "end": 1643938061.8765898},
{"start": 1643979464.7059774, "end": 1643979842.7585597},
{"start": 1644021207.4925222, "end": 1644021604.0888827},
{"start": 1644062967.0657983, "end": 1644063368.0047994},
{"start": 1644104726.0510883, "end": 1644105112.630568},
{"start": 1644146501.024817, "end": 1644146860.137256},
{"start": 1644188280.245481, "end": 1644188583.2041698},
{"start": 1644194133.3048844, "end": 1644194316.9549828},
{"start": 1644230083.899797, "end": 1644230299.963708},
{"start": 1644235839.9425225, "end": 1644236130.7524247},
{"start": 1644277562.037712, "end": 1644277910.0568247},
{"start": 1644319304.5025349, "end": 1644319688.5155852},
]

TOLERANCE = 0.1


def test_transits_above():
tle = predict.massage_tle(TLE)
qth = predict.massage_qth(QTH)

transits = predict.transits(
tle, qth, ending_after=ENDING_AFTER, ending_before=ENDING_BEFORE
)
transits = [t.above(MIN_ELEVATION) for t in transits]
transits = [t for t in transits if t.duration() > MIN_DURATION]
transits = [{"start": t.start, "end": t.end} for t in transits]

assert len(transits) == len(TRANSITS_REF)

test_ok = True
for i, t in enumerate(TRANSITS_REF):
if (
transits[i]["start"] < t["start"] - TOLERANCE
or t["start"] + TOLERANCE < transits[i]["start"]
):
test_ok = False
if (
transits[i]["end"] < t["end"] - TOLERANCE
or t["end"] + TOLERANCE < transits[i]["end"]
):
test_ok = False

assert test_ok

0 comments on commit 1afee80

Please sign in to comment.