Source code for exotools.db.lightcurve_db

import warnings
from pathlib import Path
from typing import Optional

import numpy as np
from astropy.io import fits
from astropy.table import QTable
from astropy.time import Time
from lightkurve import LightCurve, LightCurveCollection
from typing_extensions import Self

from .base_db import BaseDB
from .lightcurve_plus import LightCurvePlus

# Number of days in one second (1 / 86400): multiply a duration expressed in
# seconds by this factor to obtain days.
DAYS_PER_SECOND = 1.0 / 86400.0
# Backward-compatible alias kept for existing importers. NOTE: despite its name
# this holds days per second (1/86400), NOT seconds per day (86400); prefer
# DAYS_PER_SECOND in new code.
SECONDS_PER_DAY = DAYS_PER_SECOND


class LightcurveDB(BaseDB):
    """Index of light-curve FITS files, one row per observation.

    Each row links an observation id (``obs_id``, the table's id field) and a
    target id (``tic_id``) to the path of the FITS file holding that light
    curve. The ``load_*`` methods read those files on demand.

    Dtypes:
    ------------------
    obs_id      int64
    tic_id      int64
    path       object
    ------------------
    """

    def __init__(self, dataset: QTable):
        super().__init__(dataset=dataset, id_field="obs_id")

    @property
    def tic_ids(self) -> np.ndarray:
        """TIC id of every row, in row order (duplicates included)."""
        return self.view["tic_id"].value

    @property
    def obs_id(self) -> np.ndarray:
        """Observation id of every row, in row order."""
        return self.view["obs_id"].value

    @property
    def all_paths(self) -> np.ndarray:
        """FITS file path of every row, in row order."""
        return self.view["path"].value

    @property
    def unique_tic_ids(self) -> np.ndarray:
        """Distinct TIC ids, sorted (via ``np.unique``)."""
        return np.unique(self.tic_ids)

    @property
    def unique_obs_ids(self) -> np.ndarray:
        """Distinct observation ids, sorted (via ``np.unique``)."""
        return np.unique(self.obs_id)

    def _factory(self, dataset: QTable) -> Self:
        # Wraps a table in a new LightcurveDB; presumably called by BaseDB when
        # it builds derived views (e.g. from ``where``) -- TODO confirm.
        return LightcurveDB(dataset)

    def select_by_tic_ids(self, tic_ids: np.ndarray) -> Self:
        """Return the rows whose ``tic_id`` matches ``tic_ids`` (delegates to ``where``)."""
        return self.where(tic_id=tic_ids)

    def load_by_tic(
        self, tic_id: int, start_time_at_zero: bool = False, load_in_jd_time: bool = False
    ) -> Optional[list[LightCurvePlus]]:
        """Load every light curve recorded for ``tic_id``, in chronological order.

        Args:
            tic_id: TIC identifier of the target.
            start_time_at_zero: If True, each light curve is shifted individually
                with ``LightCurvePlus.start_at_zero``.
            load_in_jd_time: Forwarded to :meth:`load_lightcurve`.

        Returns:
            The light curves sorted by their first timestamp, or ``None`` when the
            table has no rows for ``tic_id``.
        """
        rows = self.view[["path", "obs_id"]][self.view["tic_id"] == tic_id]
        if len(rows) == 0:
            return None

        lcs = [
            LightCurvePlus(
                self.load_lightcurve(row["path"], load_in_jd_time=load_in_jd_time),
                obs_id=row["obs_id"],
            )
            for row in rows
        ]
        # Sort chronologically by each light curve's first timestamp.
        lcs.sort(key=lambda lc: lc.time[0])

        if start_time_at_zero:
            # Keep the returned object (as load_by_obs_id / load_stitched_by_tic do),
            # so this works whether start_at_zero shifts in place or returns a copy.
            lcs = [lc.start_at_zero() for lc in lcs]
        return lcs

    def load_stitched_by_tic(
        self, tic_id: int, start_time_at_zero: bool = False, load_in_jd_time: bool = False
    ) -> Optional[LightCurvePlus]:
        """Load all light curves of ``tic_id`` and stitch them into one.

        The individual light curves are combined with
        ``LightCurveCollection.stitch()`` using its default arguments.
        ``start_time_at_zero`` is applied to the stitched result only, so the
        relative timing between the individual curves is preserved.

        Returns:
            The stitched light curve, or ``None`` when ``tic_id`` has no rows.
        """
        lcs = self.load_by_tic(tic_id, start_time_at_zero=False, load_in_jd_time=load_in_jd_time)
        if not lcs:
            return None
        stitched = LightCurvePlus(LightCurveCollection([lc.lc for lc in lcs]).stitch())
        return stitched.start_at_zero() if start_time_at_zero else stitched

    def load_by_obs_id(
        self, obs_id: int, start_time_at_zero: bool = False, load_in_jd_time: bool = False
    ) -> Optional[LightCurvePlus]:
        """Load the light curve of a single observation.

        If several rows share ``obs_id``, the first matching row is used.

        Returns:
            The light curve, or ``None`` when no row has this ``obs_id``.
        """
        paths = self.view["path"][self.view["obs_id"] == obs_id]
        if len(paths) == 0:
            return None
        lc = LightCurvePlus(self.load_lightcurve(paths[0], load_in_jd_time=load_in_jd_time))
        return lc.start_at_zero() if start_time_at_zero else lc

    def load_collections_by_tics(
        self, tic_ids: list[int], load_in_jd_time: bool = False
    ) -> list[Optional[list[LightCurvePlus]]]:
        """Call :meth:`load_by_tic` for each id; one entry (list or ``None``) per id."""
        return [self.load_by_tic(tic, load_in_jd_time=load_in_jd_time) for tic in tic_ids]

    def load_stitched_by_tics(
        self, tic_ids: list[int], load_in_jd_time: bool = False
    ) -> list[Optional[LightCurvePlus]]:
        """Call :meth:`load_stitched_by_tic` for each id; ``None`` for unknown ids."""
        return [self.load_stitched_by_tic(tic, load_in_jd_time=load_in_jd_time) for tic in tic_ids]

    def load_by_obs_ids(self, obs_ids: list[int], load_in_jd_time: bool = False) -> list[Optional[LightCurvePlus]]:
        """Call :meth:`load_by_obs_id` for each id; ``None`` for unknown ids."""
        return [self.load_by_obs_id(obs, load_in_jd_time=load_in_jd_time) for obs in obs_ids]

    @staticmethod
    def path_map_to_qtable(path_map: dict[int, list[Path]]) -> QTable:
        """Build a table for this DB from ``{tic_id: [fits_path, ...]}``.

        The observation id of each file is parsed from its file-name stem, so
        files must be named ``<obs_id>.<ext>`` (e.g. ``12345.fits``). Paths may be
        ``Path`` objects or strings. An empty mapping yields an empty table that
        still has the ``tic_id``, ``obs_id`` and ``path`` columns.
        """
        tic_ids: list[int] = []
        obs_ids: list[int] = []
        path_strs: list[str] = []
        for tic, paths in path_map.items():
            for raw_path in paths:
                file_path = Path(raw_path)
                tic_ids.append(tic)
                obs_ids.append(int(file_path.stem))
                path_strs.append(str(file_path))
        return QTable(
            {
                "tic_id": np.asarray(tic_ids, dtype=np.int64),
                "obs_id": np.asarray(obs_ids, dtype=np.int64),
                "path": np.asarray(path_strs, dtype=str),
            }
        )

    @staticmethod
    def _days_per_time_unit(time_unit: str) -> float:
        """Return the factor converting TIME values expressed in ``time_unit`` to days.

        ``time_unit`` is the lower-cased TIMEUNIT header keyword. Day spellings give
        1.0; any value starting with "s" ("s", "sec", "seconds", ...) is treated as
        seconds; anything else falls back to days with a warning.
        """
        if time_unit in {"d", "day", "days"}:
            return 1.0
        if time_unit.startswith("s"):
            return SECONDS_PER_DAY  # holds 1/86400, i.e. days per second
        warnings.warn(f"Unexpected TIMEUNIT='{time_unit}', assuming days.")
        return 1.0

    @staticmethod
    def _time_format_from_tunit(tunit1: Optional[str]) -> str:
        """Infer the ``Time`` format of the TIME column from the ``TUNIT1`` keyword.

        Returns "btjd" (BJD - 2457000), "bkjd" (BJD - 2454833) or "jd"; falls back
        to "btjd" with a warning when the unit string is unrecognised.
        """
        unit = (tunit1 or "").lower()
        if "bjd" in unit and "2457000" in unit:
            return "btjd"  # TESS: "BJD - 2457000, days"
        if "bjd" in unit and "2454833" in unit:
            # Kepler/K2 offset; must be checked before the generic "jd" case,
            # otherwise the offset values would be read as absolute Julian Dates.
            return "bkjd"
        if "jd" in unit:
            return "jd"
        warnings.warn(f"Could not determine time format from TUNIT1='{tunit1}', assuming BTJD")
        return "btjd"

    @staticmethod
    def load_lightcurve(fits_file_path: Path | str, load_in_jd_time: bool = False) -> LightCurve:
        """Read a PDCSAP light curve from a FITS file with a ``LIGHTCURVE`` extension.

        Only the ``TIME``, ``PDCSAP_FLUX`` and ``PDCSAP_FLUX_ERR`` columns are read;
        this is faster than ``lightkurve.utils.read``, which keeps every column.
        Rows whose time or flux is NaN are dropped.

        Args:
            fits_file_path: Path to the FITS file.
            load_in_jd_time: If False, timestamps keep the file's own format
                (BTJD, BKJD or JD, inferred from ``TUNIT1``). If True, they are
                converted to absolute Julian Dates as (BJDREFI + BJDREFF) + TIME.

        Returns:
            A ``LightCurve`` whose ``meta`` is the primary header updated with the
            ``LIGHTCURVE`` extension header, plus ``_ORIGINAL_TIME_FORMAT``.

        Warns:
            UserWarning when TIMEUNIT or TUNIT1 is unrecognised, or when TIMEREF
            is not SOLARSYSTEM.
        """
        with fits.open(str(fits_file_path)) as hdul:
            lc_hdu = hdul["LIGHTCURVE"]
            data = lc_hdu.data
            time_array = np.asarray(data["TIME"])
            flux_array = np.asarray(data["PDCSAP_FLUX"])
            valid = ~np.isnan(time_array) & ~np.isnan(flux_array)
            # Boolean indexing copies, so nothing below references the file buffer.
            raw_time = time_array[valid]
            flux = flux_array[valid]
            flux_err = np.asarray(data["PDCSAP_FLUX_ERR"])[valid]
            # Primary header first; the LIGHTCURVE extension wins on duplicate keys.
            meta = dict(hdul[0].header)
            meta.update(dict(lc_hdu.header))

        ref_int = meta.get("BJDREFI", 0)
        ref_fractional = meta.get("BJDREFF", 0)
        time_unit = (meta.get("TIMEUNIT", "d") or "d").lower()
        # astropy Time scales are lower-case ("tdb").
        time_sys = (meta.get("TIMESYS", "TDB") or "TDB").lower()
        time_ref = (meta.get("TIMEREF", "LOCAL") or "LOCAL").upper()

        days_per_unit = LightcurveDB._days_per_time_unit(time_unit)

        if time_ref != "SOLARSYSTEM":
            warnings.warn(
                f"TIMEREF='{time_ref}' indicates times are not barycentric; "
                "consider applying barycentric correction if needed."
            )

        original_format = LightcurveDB._time_format_from_tunit(meta.get("TUNIT1"))

        # Both branches expect day-based values (BTJD/BKJD/JD and the BJDREF
        # offset are all in days), so the unit conversion is applied to both.
        elapsed_days = raw_time * days_per_unit

        if load_in_jd_time:
            # Pass the integer and fractional parts separately (two-part jd1/jd2)
            # instead of summing into one float64 near 2.46e6, which would round
            # timestamps to tens of microseconds.
            jd1 = np.full(elapsed_days.shape, float(ref_int))
            time = Time(jd1, ref_fractional + elapsed_days, format="jd", scale=time_sys)
        else:
            time = Time(elapsed_days, format=original_format, scale=time_sys)

        lc = LightCurve(time=time, flux=flux, flux_err=flux_err, meta=meta)
        # Record the file's native time format for use by LightCurvePlus.
        lc.meta["_ORIGINAL_TIME_FORMAT"] = original_format
        return lc

    @staticmethod
    def load_lightcurve_collection(paths: list[Path | str], load_in_jd_time: bool = False) -> LightCurveCollection:
        """Load several files and keep only usable light curves.

        Each light curve is passed through ``remove_outliers()`` (default
        arguments). Curves that are then empty, or whose median flux is not
        positive, are dropped.
        """
        kept = []
        for path in paths:
            lc = LightcurveDB.load_lightcurve(path, load_in_jd_time=load_in_jd_time).remove_outliers()
            # Check emptiness first: np.median of an empty array warns and returns NaN.
            if len(lc) > 0 and np.median(lc.flux.value) > 0:
                kept.append(lc)
        return LightCurveCollection(kept)

    @staticmethod
    def load_lightcurve_plus(fits_file_path: Path | str, load_in_jd_time: bool = False) -> LightCurvePlus:
        """Load one file with :meth:`load_lightcurve` and wrap it in ``LightCurvePlus``."""
        return LightCurvePlus(LightcurveDB.load_lightcurve(fits_file_path, load_in_jd_time=load_in_jd_time))

    @staticmethod
    def load_lightcurve_plus_from_collection(paths: list[Path | str], load_in_jd_time: bool = False) -> LightCurvePlus:
        """Load and filter several files, stitch them, and wrap the result.

        Uses :meth:`load_lightcurve_collection` followed by
        ``LightCurveCollection.stitch()`` with default arguments.
        """
        collection = LightcurveDB.load_lightcurve_collection(paths, load_in_jd_time=load_in_jd_time)
        return LightCurvePlus(collection.stitch())