Source code for exotools.db.lightcurve_plus

import logging
import warnings
from math import ceil
from typing import Any, Optional

import numpy as np
from astropy.time import Time, TimeDelta
from astropy.units import Quantity
from lightkurve import FoldedLightCurve, LightCurve
from typing_extensions import Self

from exotools.utils.array_utils import (
    get_contiguous_interval_indices,
    get_contiguous_intervals,
    get_gaps_interval_indices,
    get_gaps_intervals,
)

from .star_system import Planet

logger = logging.getLogger(__name__)


class LightCurvePlus:
    """Convenience wrapper around a lightkurve ``LightCurve``.

    The wrapped curve is stored in ``lc``. On top of it this class offers
    array accessors for time and flux, conversions between the JD / BJD / BTJD
    time representations, transit helpers driven by a ``Planet``, and gap /
    contiguous-interval detection on the time axis.

    Any shift applied through :meth:`shift_time` is accumulated in
    ``_time_shift`` so that planet transit midpoints can be moved onto the
    same (shifted) time axis before they are compared with ``lc.time``.
    """

    def __init__(self, lightcurve: LightCurve, obs_id: Optional[int] = None):
        """Wrap ``lightcurve`` without copying it.

        Args:
            lightcurve: the light curve to wrap; it is used as-is.
            obs_id: optional observation identifier exposed through ``obs_id``.
        """
        # Store original format information
        self._original_time_format = lightcurve.meta.get("_ORIGINAL_TIME_FORMAT", "btjd")

        # Use the lightcurve as-is, preserving its original time format
        self.lc: LightCurve = lightcurve

        # TimeDelta doesn't support all Time formats, so use 'sec' format for compatibility
        self._time_shift = TimeDelta(0, format="sec", scale=self.lc.time.scale)
        self._obs_id = obs_id

        self._warn_if_not_barycentric()

    def __len__(self) -> int:
        """Return the number of cadences (length of the time column)."""
        return len(self.lc.time)

    @property
    def time_system(self) -> str:
        """Return the current time system (format/scale combination)."""
        return f"{self.lc.time.format.upper()}/{self.lc.time.scale.upper()}"

    @property
    def time_x(self) -> np.ndarray:
        """Raw time values (``lc.time.value``) in the current time format."""
        return self.lc.time.value

    @property
    def time(self) -> Time:
        """The time column of the wrapped light curve."""
        return self.lc.time

    @property
    def flux_y(self) -> np.ndarray:
        """Raw flux values (``lc.flux.value``)."""
        return self.lc.flux.value

    @property
    def flux(self) -> np.ndarray:
        """The flux column of the wrapped light curve, as stored."""
        return self.lc.flux

    @property
    def standardized_flux(self) -> np.ndarray:
        """Flux with its mean subtracted and divided by its standard deviation."""
        flux = self.lc.flux.value
        return (flux - flux.mean()) / flux.std()

    @property
    def normalized_flux(self) -> np.ndarray:
        """Flux divided by its median, minus one.

        A warning is logged when the median is within 1e-6 of zero, because
        the division is then numerically meaningless.
        """
        flux = self.lc.flux.value
        median = np.median(flux)
        # Compare the magnitude: a clearly negative median is not "near zero".
        if abs(median) < 1e-6:
            logger.warning("LightCurvePlus.normalized_flux(): trying to normalize flux by a median near zero.")
        return flux / median - 1

    @property
    def tic_id(self) -> int:
        """The ``TICID`` entry of the light curve metadata."""
        return self.meta["TICID"]

    @property
    def obs_id(self) -> Optional[int]:
        """Observation identifier given at construction time, if any."""
        return self._obs_id

    @property
    def meta(self) -> dict[str, Any]:
        """Metadata mapping of the wrapped light curve."""
        return self.lc.meta

    @property
    def jd_time(self) -> np.ndarray:
        """Julian Date as a NumPy array."""
        if self.lc.time.format == "jd":
            # Already in JD format, return directly
            return np.asarray(self.lc.time.value, dtype=float)
        # Convert to JD
        return np.asarray(self.lc.time.jd, dtype=float)

    @property
    def bjd_time(self) -> np.ndarray:
        """Absolute BJD in TDB (days) as a NumPy array."""
        if self.lc.time.format == "jd" and self.lc.time.scale == "tdb":
            # Already in BJD_TDB format, return directly
            return np.asarray(self.lc.time.value, dtype=float)
        # Convert to TDB explicitly to be unambiguous
        return np.asarray(self.lc.time.tdb.jd, dtype=float)

    @property
    def elapsed_time(self) -> np.ndarray:
        """
        Days since first cadence (relative timeline), independent of BJDREF*.
        """
        bjd = self.bjd_time
        return bjd - bjd[0]

    @property
    def btjd_time(self) -> np.ndarray:
        """
        TESS BTJD in days, i.e., BJD_TDB − (BJDREFI + BJDREFF).
        """
        if self.lc.time.format == "btjd":
            # Already in BTJD format, return directly
            return np.asarray(self.lc.time.value, dtype=float)

        # Need to convert from other format to BTJD
        refi = self.meta.get("BJDREFI")
        reff = self.meta.get("BJDREFF")
        if refi is None and reff is None:
            # TESS convention; safe fallback for BTJD if headers were stripped
            warnings.warn("BJDREFI/BJDREFF not found in meta; assuming 2457000.0 (TESS default) for BTJD.")
            refi, reff = 2457000, 0.0
        else:
            # Only one of the two keys is missing: treat the missing part as zero.
            refi = 0 if refi is None else refi
            reff = 0.0 if reff is None else reff
        bjd_ref = float(refi) + float(reff)
        return self.bjd_time - bjd_ref

    def _warn_if_not_barycentric(self) -> None:
        """Warn if TIMEREF suggests times are not barycentric."""
        timeref = (self.meta.get("TIMEREF") or "").upper()
        if timeref and timeref != "SOLARSYSTEM":
            warnings.warn(
                f"TIMEREF='{timeref}' indicates times may not be barycentric; "
                "BJD/BTJD semantics assume barycentric timing."
            )

    def _derive(self, lightcurve: LightCurve) -> Self:
        """Wrap ``lightcurve`` in a new instance that inherits this object's state.

        The derived light curves produced by this class keep the (possibly
        shifted) time axis of ``self.lc``, so the accumulated time shift must
        travel with them; otherwise transit midpoints would no longer line up.
        """
        derived = LightCurvePlus(lightcurve, obs_id=self._obs_id)
        derived._time_shift = self._time_shift.copy()
        return derived

    def to_numpy(self) -> np.ndarray:
        """Return an ``(n, 2)`` array whose columns are ``time_x`` and ``flux_y``."""
        return np.array([self.time_x, self.flux_y]).T

    def remove_nans(self) -> Self:
        """Return a new wrapper around ``lc.remove_nans()``."""
        return self._derive(self.lc.remove_nans())

    def remove_outliers(self) -> Self:
        """Return a new wrapper around ``lc.remove_outliers()``."""
        return self._derive(self.lc.remove_outliers())

    def normalize(self) -> Self:
        """Return a new wrapper around ``lc.normalize()``."""
        return self._derive(self.lc.normalize())

    def get_first_transit_value(self, planet: Planet) -> Time:
        """Return the time stamp at the index given by :meth:`get_transit_first_index`."""
        i = self.get_transit_first_index(planet)
        return self.lc.time[i]

    def get_transit_first_index(self, planet: Planet) -> int:
        """
        Get the index of the first transit in the light curve time series.
        """
        return _find_fist_transit_index(
            time=self.time_x,
            period=planet.orbital_period.central.value,
            midpoint=self._get_aligned_midpoint(planet),
        )

    def shift_time(self, shift: float | Quantity) -> Self:
        """Shift the time axis in place and remember the accumulated shift.

        Args:
            shift: a plain number (Python or NumPy scalar) is interpreted as
                days; anything else is handed to ``TimeDelta`` unchanged.

        Returns:
            self, for method chaining.
        """
        # Use 'sec' format for TimeDelta compatibility, but convert to days if needed
        if isinstance(shift, (int, float, np.integer, np.floating)):
            # Assume shift is in the same units as the time (days for astronomical data)
            delta = TimeDelta(shift * 86400, format="sec", scale=self.lc.time.scale)  # Convert days to seconds
        else:
            delta = TimeDelta(shift, format="sec", scale=self.lc.time.scale)

        self._time_shift += delta
        self.lc.time += delta
        return self

    def start_at_zero(self) -> Self:
        """Shift the time axis in place by minus the first time value."""
        return self.shift_time(shift=-self.lc.time[0].value)

    def get_transit_phase(self, planet: Planet) -> np.ndarray:
        """Return, per cadence, the absolute time distance to the nearest transit midpoint."""
        return _get_phase(
            time=self.time_x,
            period=planet.orbital_period.central.value,
            midpoint=self._get_aligned_midpoint(planet),
        )

    def get_transit_mask(self, planet: Planet, duration_increase_percent: float = 0) -> np.ndarray:
        """
        Args:
            planet: planet with transit information
            duration_increase_percent: increases the transit duration by a given percentage (0 to 1). This changes
                the size of the masked regions

        Returns:
            a boolean array where 1 corresponds to planet transits
        """
        return self.lc.create_transit_mask(
            period=planet.orbital_period.central,
            transit_time=self._get_aligned_midpoint(planet),
            duration=planet.transit_duration.central + duration_increase_percent * planet.transit_duration.central,
        )

    def get_transit_count(self, planet: Planet) -> int:
        """Count the distinct in-transit runs of the planet's transit mask.

        A transit that is only partially covered (the curve starts or ends
        while in transit) still counts as one transit.
        """
        # mask        = 110001110011
        # mask[:-1]   = 11000111001
        # mask[1:]    = 10001110011
        # run starts  = 00001000010  (False -> True transitions)
        # plus one more run when the very first cadence is already in transit.
        mask = np.asarray(self.get_transit_mask(planet=planet), dtype=bool)
        if mask.size == 0:
            return 0
        run_starts = np.count_nonzero(~mask[:-1] & mask[1:])
        return int(mask[0]) + int(run_starts)

    def get_combined_transit_mask(self, planets: list[Planet]) -> np.ndarray:
        """Return one transit mask built from the ephemerides of all ``planets``."""
        return self.lc.create_transit_mask(
            period=[p.orbital_period.central for p in planets],
            transit_time=[self._get_aligned_midpoint(p) for p in planets],
            duration=[p.transit_duration.central for p in planets],
        )

    def fold_with_planet(self, planet: Planet, normalize_time: bool = False) -> FoldedLightCurve:
        """Fold the light curve on the planet's period, centred on its aligned midpoint."""
        return self.lc.fold(
            epoch_time=self._get_aligned_midpoint(planet),
            period=planet.orbital_period.central,
            normalize_phase=normalize_time,
        )

    def copy_with_flux(self, flux: np.ndarray) -> Self:
        """Return a new wrapper whose light curve carries ``flux`` on the same time axis."""
        lc = copy_lightcurve(self.lc, with_flux=flux)
        return self._derive(lc)

    def find_time_gaps_i(self, greater_than_median: float = 10.0) -> list[tuple[int, int]]:
        """
        Find time gaps in the lightcurve based on time step analysis.

        Identifies locations where the time difference between consecutive points
        exceeds the median time step by a specified factor, indicating data gaps
        or interruptions in observations.

        Args:
            greater_than_median: Threshold multiplier for gap detection. Gaps are identified
                where time_diff > median_time_step * greater_than_median.

        Returns:
            List of index tuples (i, i+1) where each tuple represents the indices
            immediately before and after a detected gap. The gap occurs between
            time[i] and time[i+1].
        """
        return get_gaps_interval_indices(x=self.time_x, greater_than_median=greater_than_median)

    def find_time_gaps_x(self, greater_than_median: float = 10.0) -> list[tuple[float, float]]:
        """
        Find time gaps in the lightcurve and return actual time values.

        Identifies locations where the time difference between consecutive points
        exceeds the median time step by a specified factor, returning the actual
        time values at gap boundaries rather than indices.

        Args:
            greater_than_median: Threshold multiplier for gap detection. Gaps are identified
                where time_diff > median_time_step * greater_than_median.

        Returns:
            List of time value tuples (t1, t2) where each tuple represents the actual
            time values immediately before and after a detected gap. The gap occurs
            between time t1 and time t2.

        See Also:
            find_time_gaps_i: Returns the same gaps as index pairs instead of time values.
        """
        return get_gaps_intervals(x=self.time_x, greater_than_median=greater_than_median)

    def find_contiguous_time_i(self, greater_than_median: float = 10.0) -> list[tuple[int, int]]:
        """
        Find contiguous time intervals in the lightcurve based on time step analysis.

        Identifies regions where time differences between consecutive points remain
        below the threshold, indicating continuous observation periods without
        significant gaps.

        Args:
            greater_than_median: Threshold multiplier for gap detection. Contiguous intervals
                are where time_diff <= median_time_step * greater_than_median.

        Returns:
            List of index tuples (start, end) where each tuple represents the start
            and end indices (inclusive) of a contiguous time interval.
        """
        return get_contiguous_interval_indices(x=self.time_x, greater_than_median=greater_than_median)

    def find_contiguous_time_x(self, greater_than_median: float = 10.0) -> list[tuple[float, float]]:
        """
        Find contiguous time intervals in the lightcurve and return actual time values.

        Identifies regions where time differences between consecutive points remain
        below the threshold, returning the actual time values at the boundaries
        of contiguous intervals.

        Args:
            greater_than_median: Threshold multiplier for gap detection. Contiguous intervals
                are where time_diff <= median_time_step * greater_than_median.

        Returns:
            List of time value tuples (t_start, t_end) where each tuple represents
            the actual time values at the start and end of a contiguous interval.
        """
        return get_contiguous_intervals(x=self.time_x, greater_than_median=greater_than_median)

    def to_jd_time(self) -> Self:
        """Convert the light curve time to plain Julian Date (JD) *representation* in place.

        JD is the continuous count of days since 4713 BCE (noon), independent of location;
        the *scale* (UTC, TT, TDB, …) is tracked separately. This method puts the times in
        `format="jd"` while preserving the existing time *scale* and reference frame.

        When your times are already barycentric (e.g., TESS BJD_TDB), converting to JD
        does not change the numeric values—it only standardizes the representation.

        Examples
        --------
        Suppose your first cadence is BJD_TDB = 2458354.123456:

        >>> lc.time.format, lc.time.scale
        ('jd', 'tdb')
        >>> lc.time[0].value
        2458354.123456
        >>> lc.to_jd_time().lc.time[0].value  # still JD in TDB scale
        2458354.123456

        Returns
        -------
        Self
            Returns self for method chaining.

        Raises
        ------
        ValueError
            Raised by the underlying conversion when the time scale is not
            ``"tdb"`` or the time format is neither ``"jd"`` nor ``"btjd"``.
        """
        if self.lc.time.format != "jd":
            self.lc = _convert_time_to_bjd(self.lc)
        return self

    def to_btjd_time(self) -> Self:
        """Convert the light curve time to BTJD (Barycentric TESS Julian Date) in place.

        BTJD is a TESS-specific convenience: BTJD ≡ BJD_TDB − (BJDREFI + BJDREFF).
        For standard SPOC products, (BJDREFI, BJDREFF) = (2457000, 0), so
        BTJD = BJD_TDB − 2457000. This keeps the *barycentric* reference and the
        TDB time scale, but shifts the zero-point so numbers are ~10^3 instead of ~2.4×10^6.

        Examples
        --------
        >>> # Starting from BJD_TDB = 2458354.123456 (TESS Year 1)
        >>> lc.to_btjd_time().lc.time[0].value
        1354.123456  # 2458354.123456 - 2457000.0

        >>> # Converting back to BJD_TDB (see to_bjd_time) restores the 2.458e6 magnitude.
        >>> lc.to_bjd_time().lc.time[0].value
        2458354.123456

        Returns
        -------
        Self
            Returns self for method chaining.
        """
        if self.lc.time.format != "btjd":
            self.lc = _convert_time_to_btjd(self.lc)
        return self

    def to_bjd_time(self) -> Self:
        """Convert the light curve time to Barycentric Julian Date (BJD_TDB) in place.

        **BJD** is simply JD evaluated at the Solar System Barycenter (SSB). For TESS,
        timestamps are already referenced to the SSB with `TIMESYS='TDB'` and
        `TIMEREF='SOLARSYSTEM'`, so BJD_TDB is the physically correct absolute time.
        Numerically, BJD_TDB equals JD in the TDB scale when the reference location is
        barycentric.

        This method ensures the output is **BJD_TDB** (absolute, not offset), which is
        what you want for comparing absolute epochs (e.g., transit mid-times) across
        sectors or with literature ephemerides.

        Examples
        --------
        >>> # From BTJD back to absolute BJD_TDB:
        >>> lc.to_btjd_time().lc.time[0].value
        1354.123456
        >>> lc.to_bjd_time().lc.time[0].value
        2458354.123456  # adds back (BJDREFI + BJDREFF) = 2457000.0

        >>> # If already BJD_TDB, calling again is a no-op:
        >>> lc.time.format, lc.time.scale
        ('jd', 'tdb')
        >>> lc.to_bjd_time().lc.time.format
        'jd'

        Returns
        -------
        Self
            Returns self for method chaining.
        """
        # BJD_TDB is represented as JD in the TDB scale with a barycentric reference.
        # If your internal representation uses a custom 'btjd' format, this will add back
        # the header offset (BJDREFI + BJDREFF). Otherwise, it's effectively a no-op.
        if self.lc.time.format != "jd":
            self.lc = _convert_time_to_bjd(self.lc)
        return self

    def _get_aligned_midpoint(self, planet: Planet) -> float:
        """Return the planet's transit midpoint moved by the accumulated time shift."""
        return (planet.transit_midpoint.central + self._time_shift).value

    def fold(self, period=None, epoch_time=None, epoch_phase=0, wrap_phase=None, normalize_phase=False):
        """Forward all arguments unchanged to ``lc.fold`` and return its result."""
        return self.lc.fold(
            period=period,
            epoch_time=epoch_time,
            epoch_phase=epoch_phase,
            wrap_phase=wrap_phase,
            normalize_phase=normalize_phase,
        )
def copy_lightcurve(lightcurve: LightCurve, with_flux: Optional[np.ndarray] = None) -> LightCurve: if with_flux is None: return lightcurve.copy(copy_data=True) lc = LightCurve(time=lightcurve.time.copy(), flux=with_flux.copy()) lc.meta = lightcurve.meta return lc def _btjd_to_jd_time(time: Time) -> Time: return Time(val=time.value + 2457000, format="jd", scale="tdb") def _convert_time_to_jd(lc: LightCurve) -> LightCurve: if lc.time.scale != "tdb": raise ValueError(f"Time scale {lc.time.scale} unknown/unsupported.") if lc.time.format == "jd": return lc elif lc.time.format == "btjd": new_t = _btjd_to_jd_time(lc.time) return LightCurve(time=new_t, flux=lc.flux, flux_err=lc.flux_err, meta=lc.meta) raise ValueError(f"Time format {lc.time.format} unknown/unsupported.") def _convert_time_to_btjd(lc: LightCurve) -> LightCurve: """Convert lightcurve time to BTJD format.""" if lc.time.scale != "tdb": raise ValueError(f"Time scale {lc.time.scale} unknown/unsupported.") if lc.time.format == "btjd": return lc elif lc.time.format == "jd": # Convert JD to BTJD by subtracting reference refi = lc.meta.get("BJDREFI", 2457000) reff = lc.meta.get("BJDREFF", 0.0) bjd_ref = float(refi) + float(reff) new_t = Time(lc.time.value - bjd_ref, format="btjd", scale="tdb") return LightCurve(time=new_t, flux=lc.flux, flux_err=lc.flux_err, meta=lc.meta) raise ValueError(f"Time format {lc.time.format} unknown/unsupported.") def _convert_time_to_bjd(lc: LightCurve) -> LightCurve: """Convert lightcurve time to BJD format (same as JD for TDB scale).""" # For TDB scale, BJD is the same as JD return _convert_time_to_jd(lc) def _get_phase(time: np.ndarray, period: float, midpoint: float) -> np.ndarray: k = np.round((time - midpoint) / period) closest_event_time = midpoint + k * period return np.abs(closest_event_time - time) def _find_fist_transit_index(time: np.ndarray, period: float, midpoint: float, step: int = 100) -> int: phase = _get_phase(time=time, midpoint=midpoint, period=period) i = 1 length = 
len(time) - 1 while i < length and phase[i - 1] < phase[i]: i = min(i + step, length) while i < length and phase[i - 1] > phase[i]: i = min(i + step, length) while i > 0 and phase[i - 1] < phase[i]: i = max(i - step, -1) return i