Source code for hips.draw.ui

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""The high-level end user interface (UI)."""
import numpy as np
from PIL import Image
from pathlib import Path
from astropy.io import fits
from typing import List, Union
from ..utils.wcs import WCSGeometry
from ..tiles import HipsSurveyProperties, HipsTile
from .paint import HipsPainter

__all__ = [
    'make_sky_image',
    'HipsDrawResult',
]


def make_sky_image(
        geometry: Union[dict, WCSGeometry],
        hips_survey: Union[str, 'HipsSurveyProperties'],
        tile_format: str,
        precise: bool = False,
        progress_bar: bool = True,
        fetch_opts: dict = None,
) -> 'HipsDrawResult':
    """Fetch HiPS tiles and draw them into a sky image.

    A worked example is available on the :ref:`gs` page.

    Parameters
    ----------
    geometry : dict or `~hips.utils.WCSGeometry`
        Geometry of the output image
    hips_survey : str or `~hips.HipsSurveyProperties`
        HiPS survey properties
    tile_format : {'fits', 'jpg', 'png'}
        Format of HiPS tile to use
        (some surveys are available in several formats,
        so this extra argument is needed)
    precise : bool
        Use the precise drawing algorithm
    progress_bar : bool
        Show a progress bar for tile fetching and drawing
    fetch_opts : dict
        Keyword arguments for fetching HiPS tiles. To see the
        list of passable arguments, refer to `~hips.fetch_tiles`

    Returns
    -------
    result : `~hips.HipsDrawResult`
        Result object
    """
    # The painter takes these options positionally, in exactly this order.
    painter_args = (geometry, hips_survey, tile_format, precise, progress_bar, fetch_opts)
    painter = HipsPainter(*painter_args)

    # All the work happens in ``run``; afterwards the painter is wrapped
    # into the result object handed back to the caller.
    painter.run()
    result = HipsDrawResult.from_painter(painter)
    return result
class HipsDrawResult:
    """HiPS draw result object (sky image and more).

    Parameters
    ----------
    image : `~numpy.ndarray`
        Sky image (the main result)
    geometry : `~hips.utils.WCSGeometry`
        WCS geometry of the sky image
    tile_format : {'fits', 'jpg', 'png'}
        Format of HiPS tile
    tiles : list
        Python list of `~hips.HipsTile` objects that were used
    stats : dict
        Information including time for fetching and drawing HiPS tiles.
        `report` reads the keys ``fetch_time``, ``draw_time``,
        ``consumed_memory`` and ``tile_count``.
    """

    def __init__(self, image: np.ndarray, geometry: 'WCSGeometry', tile_format: str,
                 tiles: List['HipsTile'], stats: dict) -> None:
        self.image = image
        self.geometry = geometry
        self.tile_format = tile_format
        self.tiles = tiles
        self.stats = stats

    def __str__(self) -> str:
        return (
            'HiPS draw result:\n'
            f'Sky image: shape={self.image.shape}, dtype={self.image.dtype}\n'
            f'WCS geometry: {self.geometry}\n'
        )

    def __repr__(self) -> str:
        # Image arrays are indexed (row, column[, channel]): the first axis
        # is the height and the second the width. A 2-dimensional array is a
        # single-channel image; otherwise the third axis holds the channels.
        shape = self.image.shape
        height, width = shape[0], shape[1]
        channels = shape[2] if self.image.ndim > 2 else 1
        return (
            'HipsDrawResult('
            f'width={width}, '
            f'height={height}, '
            f'channels={channels}, '
            f'dtype={self.image.dtype}, '
            f'format={self.tile_format}'
            ')'
        )

    @classmethod
    def from_painter(cls, painter: 'HipsPainter') -> 'HipsDrawResult':
        """Make a `~hips.HipsDrawResult` from a ``HipsPainter``.

        Parameters
        ----------
        painter : ``HipsPainter``
            Painter whose ``image``, ``geometry``, ``tile_format``, ``tiles``
            and ``_stats`` attributes are copied into the result.

        Returns
        -------
        result : `~hips.HipsDrawResult`
            Result object
        """
        return cls(
            image=painter.image,
            geometry=painter.geometry,
            tile_format=painter.tile_format,
            tiles=painter.tiles,
            stats=painter._stats,
        )

    def write_image(self, filename: str, overwrite: bool = False) -> None:
        """Write image to file.

        FITS results are written with `astropy.io.fits`, all other
        tile formats are saved through `PIL.Image`.

        Parameters
        ----------
        filename : str
            Filename
        overwrite : bool
            Overwrite the output file, if it exists

        Raises
        ------
        FileExistsError
            If ``filename`` exists and ``overwrite`` is false.
        """
        if not overwrite and Path(filename).exists():
            raise FileExistsError(f"File {filename} already exists.")

        if self.tile_format == 'fits':
            hdu = fits.PrimaryHDU(data=self.image, header=self.geometry.fits_header)
            # Forward the flag: without it astropy refuses to replace an
            # existing file even though the caller asked for an overwrite.
            hdu.writeto(filename, overwrite=overwrite)
        else:
            image = Image.fromarray(self.image)
            image.save(filename)

    def plot(self, show_grid: bool = False) -> None:
        """Plot the all sky image and overlay HiPS tile outlines.

        Uses `astropy.visualization.wcsaxes`.

        Parameters
        ----------
        show_grid : bool
            Enable grid around HiPS tile boundaries
        """
        # Imported lazily so matplotlib is only needed when plotting.
        import matplotlib.pyplot as plt

        ax = plt.subplot(projection=self.geometry.wcs)
        if show_grid:
            opts = dict(color='red', lw=1)
            for tile in self.tiles:
                # Tile corners are converted to the frame of the image WCS
                # and drawn in world coordinates.
                corners = tile.meta.skycoord_corners
                corners = corners.transform_to(self.geometry.celestial_frame)
                ax.plot(corners.data.lon.deg, corners.data.lat.deg,
                        transform=ax.get_transform('world'), **opts)
        ax.imshow(self.image, origin='lower')

    def report(self) -> None:
        """Print a brief report for the fetched data."""
        # NOTE(review): ``consumed_memory`` is divided by 1e6 and labelled MB,
        # so it is presumably stored in bytes - confirm against the painter.
        print(
            f"Time for fetching tiles = {self.stats['fetch_time']} seconds\n"
            f"Time for drawing tiles = {self.stats['draw_time']} seconds\n"
            f"Total memory consumed = {self.stats['consumed_memory'] / 1e6} MB\n"
            f"Total tiles fetched = {self.stats['tile_count']}\n"
        )