Source code for wakis.gridFIT3D

# copyright ################################# #
# This file is part of the wakis Package.     #
# Copyright (c) CERN, 2024.                   #
# ########################################### #

import time

import h5py
import numpy as np
import pyvista as pv
from scipy.optimize import least_squares

from .field import Field
from .logger import Logger
from .materials import material_colors

try:
    from mpi4py import MPI

    imported_mpi = True
except ImportError:
    imported_mpi = False


class GridFIT3D:
    """
    Class holding the grid information and STL importing/handling using PyVista.
    """

    def __init__(
        self,
        xmin=None,
        xmax=None,
        ymin=None,
        ymax=None,
        zmin=None,
        zmax=None,
        Nx=None,
        Ny=None,
        Nz=None,
        x=None,
        y=None,
        z=None,
        use_mpi=False,
        use_mesh_refinement=False,
        refinement_method="insert",
        refinement_tol=1e-8,
        snap_points=None,
        snap_tol=1e-5,
        snap_solids=None,
        stl_solids=None,
        stl_materials=None,
        stl_rotate=[0.0, 0.0, 0.0],
        stl_translate=[0.0, 0.0, 0.0],
        stl_scale=1.0,
        stl_colors=None,
        stl_tol=1e-3,
        load_from_h5=None,
        verbose=1,
    ):
        """
        Class holding the grid information and STL importing/handling using PyVista.

        Parameters
        ----------
        xmin, xmax, ymin, ymax, zmin, zmax : float, optional
            Extent of the simulation domain. If None, must provide x, y, z arrays.
        Nx, Ny, Nz : int, optional
            Number of cells per direction. Required together with the domain
            extents when x, y, z arrays are not given.
        x, y, z : array_like, optional
            Custom grid axis arrays (node coordinates) to be used in the
            meshgrid generation. Non-uniform grids are supported.
        use_mpi : bool, optional
            Enable MPI domain decomposition in the z direction. Default is False.
            Not compatible with custom x, y, z arrays.
        use_mesh_refinement : bool, optional
            Enable mesh refinement based on snap points (given via
            `snap_points` or extracted from the STL solids). Default is False.
        refinement_method : str, optional
            Refinement algorithm for mesh refinement ('insert', 'neighbor',
            'subdivision'). Default is "insert".
        refinement_tol : float, optional
            Convergence tolerance passed to the refinement optimizer.
            Default is 1e-8.
        snap_points : array_like, optional
            3D points, shape (N, 3), to snap the mesh to. If None and
            `stl_solids` is given, they are extracted from the STL feature edges.
            Default is None.
        snap_tol : float, optional
            Tolerance for snap point detection. Default is 1e-5.
        snap_solids : list or None, optional
            STL solids to use for snap point extraction. Default is None (all).
        stl_solids : dict or str, optional
            STL files to import in the domain. {'Solid 1': stl_1, ...}
        stl_materials : dict, optional
            Material properties associated with STL solids.
            {'Solid 1': [eps1, mu1], ...}
        stl_rotate : list or dict, optional
            Angle of rotation to apply to the STL models: [rot_x, rot_y, rot_z].
            If dict, must contain the same keys as stl_solids.
        stl_translate : list or dict, optional
            Translation to apply to the STL models: [dx, dy, dz].
            If dict, must contain the same keys as stl_solids.
        stl_scale : float or dict, optional
            Scaling value to apply to the STL model to convert to [m].
            If dict, must contain the same keys as stl_solids.
        stl_colors : str, list, dict, or None, optional
            Color(s) for STL solids. If None, assigned automatically.
        stl_tol : float, optional
            Tolerance factor for STL import; multiplied by the smallest cell
            size and passed to grid.select_enclosed_points. Default is 1e-3.
        load_from_h5 : str, optional
            Load grid from an h5 file previously saved with `save_to_h5`.
        verbose : int or bool, optional
            Enable verbose output on the terminal. Use `verbose=2` for more detail.

        Attributes
        ----------
        x, y, z : ndarray
            Grid axis arrays.
        Nx, Ny, Nz : int
            Number of cells in each direction.
        dx, dy, dz : ndarray
            Cell sizes in each direction.
        grid : pyvista.StructuredGrid
            PyVista grid object.
        stl_solids, stl_materials, stl_colors : dict
            STL solid file paths, materials, and colors.

        Raises
        ------
        ValueError
            If the grid definition is incomplete, if use_mpi=True is combined
            with custom x, y, z arrays, or if mesh refinement is requested
            without STL solids or snap points.
        ImportError
            If use_mpi=True but mpi4py could not be imported.
        """
        t0 = time.time()
        self.logger = Logger()
        self.verbose = verbose
        self.use_mpi = use_mpi

        # --- Grid axes --------------------------------------------------
        if load_from_h5 is not None:
            # generate from file. TODO: support MPI decomposition
            self.load_from_h5(load_from_h5)
            return

        if x is not None and y is not None and z is not None:
            # custom (possibly non-uniform) axis arrays
            if self.use_mpi:
                raise ValueError(
                    "[!] Error: use_mpi=True is not compatible with custom x,y,z arrays."
                )
            # coerce to float arrays: plain lists would break the slicing
            # arithmetic in _compute_grid (list + list concatenates)
            self.x = np.asarray(x, dtype=float)
            self.y = np.asarray(y, dtype=float)
            self.z = np.asarray(z, dtype=float)
            self.Nx = len(self.x) - 1
            self.Ny = len(self.y) - 1
            self.Nz = len(self.z) - 1
            self.xmin, self.xmax = self.x[0], self.x[-1]
            self.ymin, self.ymax = self.y[0], self.y[-1]
            self.zmin, self.zmax = self.z[0], self.z[-1]
        elif all(
            v is not None
            for v in [xmin, xmax, ymin, ymax, zmin, zmax, Nx, Ny, Nz]
        ):
            # [LEGACY] uniform grid from domain extents and number of cells
            self.xmin, self.xmax = xmin, xmax
            self.ymin, self.ymax = ymin, ymax
            self.zmin, self.zmax = zmin, zmax
            self.Nx, self.Ny, self.Nz = Nx, Ny, Nz
            self.x = np.linspace(self.xmin, self.xmax, self.Nx + 1)
            self.y = np.linspace(self.ymin, self.ymax, self.Ny + 1)
            self.z = np.linspace(self.zmin, self.zmax, self.Nz + 1)
        else:
            raise ValueError(
                "[!] Error initializing GridFIT3D:\n"
                " - Provide grid axis arrays: x, y, z\n"
                " - OR domain extents and number of cells: "
                "xmin, xmax, ymin, ymax, zmin, zmax, Nx, Ny, Nz\n"
                " - OR load from a HDF5 file using load_from_h5"
            )

        self.dx = np.diff(self.x)
        self.dy = np.diff(self.y)
        self.dz = np.diff(self.z)

        self.update_logger(["Nx", "Ny", "Nz", "dx", "dy", "dz"])
        self.update_logger(["xmin", "xmax", "ymin", "ymax", "zmin", "zmax"])

        # --- STL info ---------------------------------------------------
        self.stl_solids = stl_solids
        self.stl_materials = stl_materials
        self.stl_rotate = stl_rotate
        self.stl_translate = stl_translate
        self.stl_scale = stl_scale
        self.stl_colors = stl_colors

        self.update_logger(["stl_solids", "stl_materials"])
        if stl_rotate != [0.0, 0.0, 0.0]:
            self.update_logger(["stl_rotate"])
        if stl_translate != [0.0, 0.0, 0.0]:
            self.update_logger(["stl_translate"])
        if stl_scale != 1.0:
            self.update_logger(["stl_scale"])

        if stl_solids is not None:
            # per-solid dicts are needed by read_stl (used by snap extraction)
            self._prepare_stl_dicts()

        # --- Mesh refinement (refines self.x, self.y, self.z) -----------
        self.use_mesh_refinement = use_mesh_refinement
        self.refinement_method = refinement_method
        self.snap_points = snap_points
        self.snap_tol = snap_tol
        self.snap_solids = snap_solids  # if None, use all stl_solids
        self.update_logger(["use_mesh_refinement"])

        if self.use_mesh_refinement:
            if verbose:
                print("Applying mesh refinement...")
            if self.snap_points is not None:
                # user-supplied snap points: derive the per-axis snaps here,
                # otherwise _refine_xyz_axis would find no x/y/z_snaps
                self._snaps_from_points()
            elif stl_solids is not None:
                self._compute_snap_points(
                    snap_solids=snap_solids, snap_tol=snap_tol
                )
            else:
                raise ValueError(
                    "[!] Error: use_mesh_refinement=True requires "
                    "stl_solids or snap_points."
                )
            self._refine_xyz_axis(method=refinement_method, tol=refinement_tol)

        if verbose:
            print(
                f"Generating grid with {self.Nx * self.Ny * self.Nz} mesh cells..."
            )
            if verbose > 1:
                # use the resolved bounds: the xmin.. arguments are None when
                # the grid was built from x, y, z arrays
                print(
                    " * Simulation domain bounds: \n"
                    f"   x:[{self.xmin:.3f}, {self.xmax:.3f}],\n"
                    f"   y:[{self.ymin:.3f}, {self.ymax:.3f}],\n"
                    f"   z:[{self.zmin:.3f}, {self.zmax:.3f}]"
                )

        # --- MPI subdivide domain (z direction) -------------------------
        if self.use_mpi:
            self.ZMIN = None
            self.ZMAX = None
            self.NZ = None
            self.Z = None
            if imported_mpi:
                self._mpi_initialize()
                if self.verbose:
                    print(f"MPI initialized for {self.rank} of {self.size}")
            else:
                raise ImportError(
                    "[!] mpi4py is required when use_mpi=True but was not found"
                )

        # grid G and tilde grid ~G, lengths and inverse areas
        self._compute_grid()

        # tolerance for stl import: tol * min(dx, dy, dz)
        self.stl_tol = stl_tol
        if stl_solids is not None:
            if verbose:
                print("Importing STL solids...")
            self._mark_cells_in_stl()

        elapsed = time.time() - t0
        if verbose:
            print(f"Total grid initialization time: {elapsed} s")

        self.gridInitializationTime = elapsed
        self.update_logger(["gridInitializationTime"])

    def _snaps_from_points(self):
        """
        Derive per-axis snap coordinates from user-supplied ``snap_points``.

        ``self.snap_points`` is reshaped to (N, 3) (one 3D point per row, the
        same layout ``_compute_snap_points`` produces). For each axis the
        coordinates are rounded to 5 decimals, restricted to the domain
        bounds, and merged with those bounds into ``self.x_snaps``,
        ``self.y_snaps`` and ``self.z_snaps``.
        """
        points = np.asarray(self.snap_points, dtype=float).reshape(-1, 3)
        bounds = (
            (self.xmin, self.xmax),
            (self.ymin, self.ymax),
            (self.zmin, self.zmax),
        )
        axis_snaps = []
        for axis, (lo, hi) in enumerate(bounds):
            coords = np.unique(np.round(points[:, axis], 5))
            # nodes outside the domain cannot be grid nodes
            coords = coords[(coords >= lo) & (coords <= hi)]
            axis_snaps.append(np.unique(np.concatenate(([lo], coords, [hi]))))
        self.x_snaps, self.y_snaps, self.z_snaps = axis_snaps

    def _compute_grid(self):
        """
        Compute the PyVista grid and related geometric quantities.

        Builds ``self.grid`` (pyvista.StructuredGrid) from the node axes, the
        primary edge lengths ``self.L`` and inverse face areas ``self.iA``,
        the tilde (dual) axes ``self.tx/ty/tz`` located at cell midpoints,
        and the dual lengths ``self.tL`` and inverse areas ``self.itA``.
        """
        X, Y, Z = np.meshgrid(self.x, self.y, self.z, indexing="ij")
        self.grid = pv.StructuredGrid(
            X.transpose(), Y.transpose(), Z.transpose()
        )

        # primary grid G: edge lengths per cell
        self.L = Field(self.Nx, self.Ny, self.Nz)
        self.L.field_x = X[1:, 1:, 1:] - X[:-1, :-1, :-1]
        self.L.field_y = Y[1:, 1:, 1:] - Y[:-1, :-1, :-1]
        self.L.field_z = Z[1:, 1:, 1:] - Z[:-1, :-1, :-1]

        # inverse face areas (face normal to each axis)
        self.iA = Field(self.Nx, self.Ny, self.Nz)
        self.iA.field_x = np.divide(1.0, self.L.field_y * self.L.field_z)
        self.iA.field_y = np.divide(1.0, self.L.field_x * self.L.field_z)
        self.iA.field_z = np.divide(1.0, self.L.field_x * self.L.field_y)

        # tilde grid ~G: nodes at primary cell midpoints
        self.tx = (self.x[1:] + self.x[:-1]) / 2
        self.ty = (self.y[1:] + self.y[:-1]) / 2
        self.tz = (self.z[1:] + self.z[:-1]) / 2

        # repeat the last midpoint so the dual axes have Nx+1 entries; this
        # makes the last dual length zero, hence the guarded divisions below
        self.tx = np.append(self.tx, self.tx[-1])
        self.ty = np.append(self.ty, self.ty[-1])
        self.tz = np.append(self.tz, self.tz[-1])

        tX, tY, tZ = np.meshgrid(self.tx, self.ty, self.tz, indexing="ij")

        self.tL = Field(self.Nx, self.Ny, self.Nz)
        self.tL.field_x = tX[1:, 1:, 1:] - tX[:-1, :-1, :-1]
        self.tL.field_y = tY[1:, 1:, 1:] - tY[:-1, :-1, :-1]
        self.tL.field_z = tZ[1:, 1:, 1:] - tZ[:-1, :-1, :-1]

        # inverse dual areas; zero where the area is zero
        self.itA = Field(self.Nx, self.Ny, self.Nz)
        aux = self.tL.field_y * self.tL.field_z
        self.itA.field_x = np.divide(
            1.0, aux, out=np.zeros_like(aux), where=aux != 0
        )
        aux = self.tL.field_x * self.tL.field_z
        self.itA.field_y = np.divide(
            1.0, aux, out=np.zeros_like(aux), where=aux != 0
        )
        aux = self.tL.field_x * self.tL.field_y
        self.itA.field_z = np.divide(
            1.0, aux, out=np.zeros_like(aux), where=aux != 0
        )
        del aux

    def _mpi_initialize(self):
        """
        Initialize MPI domain decomposition in the z direction.

        Stores the communicator, rank and size; keeps the global z quantities
        in ``ZMIN``, ``ZMAX``, ``NZ`` (truncated to a multiple of the number
        of processes) and ``Z`` (global cell centers); then replaces
        ``zmin``, ``zmax``, ``Nz``, ``z`` and ``dz`` with this rank's
        uniform subdomain, including ``n_ghosts`` ghost cells at internal
        boundaries (and at the top for a single process).

        Raises
        ------
        ValueError
            If Nz is smaller than the number of MPI processes.
        """
        comm = MPI.COMM_WORLD  # Get MPI communicator
        self.comm = comm

        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()

        if self.Nz < self.size:
            raise ValueError(
                f"Nz ({self.Nz}) must be greater than or equal to the number of "
                f"MPI processes ({self.size})."
            )

        # global z quantities [ALLCAPS]
        self.ZMIN = self.zmin
        self.ZMAX = self.zmax
        self.NZ = self.Nz - self.Nz % self.size  # ensure multiple of MPI size
        self.Z = np.linspace(self.ZMIN, self.ZMAX, self.NZ + 1)[:-1]
        self.Z += (self.ZMAX - self.ZMIN) / (2 * self.NZ)

        if self.verbose and self.rank == 0:
            print(
                f" * Global grid ZMIN={self.ZMIN}, ZMAX={self.ZMAX}, NZ={self.NZ}"
            )

        # MPI subdomain quantities [TODO: support non-uniform dz with MPI]
        self.Nz = self.NZ // self.size
        self.dz = (self.ZMAX - self.ZMIN) / self.NZ
        self.zmin = self.rank * self.Nz * self.dz + self.ZMIN
        self.zmax = (self.rank + 1) * self.Nz * self.dz + self.ZMIN

        if self.verbose:
            print(
                f"MPI rank {self.rank} of {self.size} initialized with "
                f"zmin={self.zmin}, zmax={self.zmax}, Nz={self.Nz}"
            )

        # Add ghost cells at internal subdomain boundaries
        self.n_ghosts = 1
        if self.rank > 0:
            self.zmin += -self.n_ghosts * self.dz
            self.Nz += self.n_ghosts
        if self.rank < (self.size - 1):
            self.zmax += self.n_ghosts * self.dz
            self.Nz += self.n_ghosts

        # Support for single core
        if self.rank == 0 and self.size == 1:
            self.zmax += self.n_ghosts * self.dz
            self.Nz += self.n_ghosts

        self.z = np.linspace(self.zmin, self.zmax, self.Nz + 1)
        self.dz = np.diff(self.z)  # only uniform grid possible with MPI
[docs] def mpi_gather_asGrid(self): """ Gather the global grid from all MPI subdomains and return a new GridFIT3D. Returns ------- _grid : GridFIT3D or None The global grid object (only on rank 0). """ _grid = None if self.rank == 0: print(f"Generating global grid from {self.ZMIN} to {self.ZMAX}") _grid = GridFIT3D( self.xmin, self.xmax, self.ymin, self.ymax, self.ZMIN, self.ZMAX, self.Nx, self.Ny, self.NZ, use_mpi=False, stl_solids=self.stl_solids, stl_materials=self.stl_materials, stl_scale=self.stl_scale, stl_rotate=self.stl_rotate, stl_translate=self.stl_translate, stl_colors=self.stl_colors, verbose=self.verbose, stl_tol=self.stl_tol, ) return _grid
def _prepare_stl_dicts(self): """ Prepare STL-related dictionaries for rotation, scale, translation, and colors. Ensures all STL solids have corresponding entries for rotation, scale, translation, and color, converting single values to dicts as needed. """ if type(self.stl_solids) is not dict: if type(self.stl_solids) is str: self.stl_solids = {"Solid 1": self.stl_solids} else: raise Exception( "Attribute `stl_solids` must contain a string or a dictionary" ) if type(self.stl_rotate) is not dict: # if not a dict, the same values will be applied to all solids stl_rotate = {} for key in self.stl_solids.keys(): stl_rotate[key] = self.stl_rotate self.stl_rotate = stl_rotate if type(self.stl_scale) is not dict: # if not a dict, the same values will be applied to all solids stl_scale = {} for key in self.stl_solids.keys(): stl_scale[key] = self.stl_scale self.stl_scale = stl_scale if type(self.stl_translate) is not dict: # if not a dict, the same values will be applied to all solids stl_translate = {} for key in self.stl_solids.keys(): stl_translate[key] = self.stl_translate self.stl_translate = stl_translate if type(self.stl_colors) is not dict: if self.stl_colors is None: self._assign_colors() elif self.stl_colors is str: # single color for all solids stl_colors = {} for key in self.stl_solids.keys(): stl_colors[key] = self.stl_colors self.stl_colors = stl_colors elif type(self.stl_colors) is list: stl_colors = {} try: for i, key in enumerate(self.stl_solids.keys()): stl_colors[key] = self.stl_colors[i] self.stl_colors = stl_colors except IndexError: self._assign_colors() print( "[!] If `stl_colors` is a list, it must have \ the same length as `stl_solids`." ) def _mark_cells_in_stl(self): """ Mark grid cells that are inside each STL solid. Uses PyVista's select_enclosed_points to create boolean masks for each solid. 
""" # Obtain masks with grid cells inside each stl solid stl_tolerance = ( np.min([np.min(self.dx), np.min(self.dy), np.min(self.dz)]) * self.stl_tol ) progress_bar = False if self.Nx * self.Ny * self.Nz > 5e6 and self.verbose: progress_bar = True for key in self.stl_solids.keys(): surf = self.read_stl(key) # mark cells in stl [True == in stl, False == out stl] try: select = self.grid.select_enclosed_points( surf, tolerance=stl_tolerance, progress_bar=progress_bar ) except Exception: select = self.grid.select_enclosed_points( surf, tolerance=stl_tolerance, check_surface=False, progress_bar=progress_bar, ) if self.verbose > 1: print( f"[!] Warning: stl solid {key} may have issues with closed surfaces. \ Consider checking the STL file." ) self.grid[key] = ( select.point_data_to_cell_data()["SelectedPoints"] > stl_tolerance ) if self.verbose and np.sum(self.grid[key]) == 0: print( f"[!] Warning: no cells were marked inside stl solid {key}. \ Consider increasing the tolerance factor (currently {self.stl_tol})." ) if self.verbose > 1: print( f" * STL solid {key}: {np.sum(self.grid[key])} cells marked inside the solid." )
[docs] def read_stl(self, key): """ Read and transform an STL solid by key. Parameters ---------- key : str Key of the STL solid to read. Returns ------- surf : pyvista.PolyData The transformed STL surface. """ # import stl surf = pv.read(self.stl_solids[key]) # rotate surf = surf.rotate_x(self.stl_rotate[key][0]) surf = surf.rotate_y(self.stl_rotate[key][1]) surf = surf.rotate_z(self.stl_rotate[key][2]) # translate surf = surf.translate(self.stl_translate[key]) # scale surf = surf.scale(self.stl_scale[key]) return surf
def _compute_snap_points(self, snap_solids=None, snap_tol=1e-8): """ Compute snap points from STL feature edges for mesh refinement. Parameters ---------- snap_solids : list or None, optional STL solids to use for snap point extraction. Default is all. snap_tol : float, optional Tolerance for snap point detection. """ if self.verbose > 1: print(" * Calculating snappy points...") # Support for user-defined stl_keys as list if snap_solids is None: snap_solids = self.stl_solids.keys() # Union of all the surfaces model = None for key in snap_solids: solid = self.read_stl(key) if model is None: model = solid else: model = model + solid edges = model.extract_feature_edges( boundary_edges=True, manifold_edges=False ) # Extract points lying in the X-Z plane (Y โ‰ˆ 0) xz_plane_points = edges.points[np.abs(edges.points[:, 1]) < snap_tol] # Extract points lying in the Y-Z plane (X โ‰ˆ 0) yz_plane_points = edges.points[np.abs(edges.points[:, 0]) < snap_tol] # Extract points lying in the X-Y plane (Z โ‰ˆ 0) xy_plane_points = edges.points[np.abs(edges.points[:, 2]) < snap_tol] self.snap_points = np.r_[ xz_plane_points, yz_plane_points, xy_plane_points ] # get the unique x, y, z coordinates x_snaps = np.unique(np.round(self.snap_points[:, 0], 5)) y_snaps = np.unique(np.round(self.snap_points[:, 1], 5)) z_snaps = np.unique(np.round(self.snap_points[:, 2], 5)) # Include simulation domain bounds self.x_snaps = np.unique( np.concatenate(([self.xmin], x_snaps, [self.xmax])) ) self.y_snaps = np.unique( np.concatenate(([self.ymin], y_snaps, [self.ymax])) ) self.z_snaps = np.unique( np.concatenate(([self.zmin], z_snaps, [self.zmax])) )
[docs] def plot_snap_points(self, snap_solids=None, snap_tol=1e-8): """ Plot snap points extracted from STL feature edges for mesh refinement. Parameters ---------- snap_solids : list or None, optional STL solids to use for snap point extraction. Default is all. snap_tol : float, optional Tolerance for snap point detection. """ # Support for user-defined stl_keys as list if snap_solids is None: snap_solids = self.stl_solids.keys() # Union of all the surfaces model = None for key in snap_solids: solid = self.read_stl(key) if model is None: model = solid else: model = model + solid edges = model.extract_feature_edges( boundary_edges=True, manifold_edges=False ) # Extract points lying in the X-Z plane (Y โ‰ˆ 0) xz_plane_points = edges.points[np.abs(edges.points[:, 1]) < snap_tol] # Extract points lying in the Y-Z plane (X โ‰ˆ 0) yz_plane_points = edges.points[np.abs(edges.points[:, 0]) < snap_tol] # Extract points lying in the X-Y plane (Z โ‰ˆ 0) xy_plane_points = edges.points[np.abs(edges.points[:, 2]) < snap_tol] xz_cloud = pv.PolyData(xz_plane_points) yz_cloud = pv.PolyData(yz_plane_points) xy_cloud = pv.PolyData(xy_plane_points) pv.global_theme.allow_empty_mesh = True pl = pv.Plotter() pl.add_mesh(model, color="white", opacity=0.5, label="base STL") pl.add_mesh( edges, color="black", line_width=5, opacity=0.8, ) pl.add_mesh( xz_cloud, color="green", point_size=20, render_points_as_spheres=True, label="XZ plane points", ) pl.add_mesh( yz_cloud, color="orange", point_size=20, render_points_as_spheres=True, label="YZ plane points", ) pl.add_mesh( xy_cloud, color="magenta", point_size=20, render_points_as_spheres=True, label="XY plane points", ) pl.add_legend() pl.show()
[docs] def refine_axis(self, xmin, xmax, Nx, x_snaps, method="insert", tol=1e-12): """ Refine a grid axis using snap points and a chosen method. Parameters ---------- xmin, xmax : float Axis bounds. Nx : int Number of grid points. x_snaps : array_like Snap points to include in the axis. method : str, optional Refinement algorithm ('insert', 'neighbor', 'subdivision'). tol : float, optional Convergence tolerance for optimization. Returns ------- x : ndarray Refined axis array. """ # Loss function to minimize cell size spread def loss_function(x, x0, is_snap): # avoid moving snap points penalty_snap = np.sum((x[is_snap] - x0[is_snap]) ** 2) * 1000 # avoid gaps < uniform gap dx = np.diff(x) threshold = 1 / (len(x) - 1) # or a hardcoded `min_spacing` penalty_small_gaps = np.sum((threshold - dx[dx < threshold]) ** 2) # avoid large spread in gap length dx = np.diff(x) penalty_variance = np.std(dx) * 10 # return penalty_snap + penalty_small_gaps + penalty_variance return np.hstack( [penalty_snap, penalty_small_gaps, penalty_variance] ) # Uniformly distributed points as initial guess x_snaps = (x_snaps - xmin) / (xmax - xmin) # normalize to [0,1] if method == "insert": x0 = np.unique( np.append(x_snaps, np.linspace(0, 1, Nx - len(x_snaps))) ) elif method == "neighbor": x = np.linspace(0, 1, Nx) dx = np.diff(x)[0] for s in x_snaps: m = np.isclose(x, s, rtol=0.0, atol=dx / 2) if np.sum(m) > 0: x[np.argmax(m)] = s x0 = x.copy() elif method == "subdivision": # x = snaps while len(x) < Nx: # idx of segments sorted min -> max idx_max_diffs = np.argsort(np.diff(x))[-1] # take bigger # print(f"Bigger segment starts at {x[idx_max_diffs]}") # compute new point in the middle of the segment val = ( x[idx_max_diffs] + (x[idx_max_diffs + 1] - x[idx_max_diffs]) / 2 ) # insert the new point x = np.insert(x, idx_max_diffs + 1, val) x = np.unique(x) # print(f"Inserted point {val} at index {idx_max_diffs}") x0 = x.copy() else: raise ValueError( f"Method {method} not supported. 
Use 'insert', 'neighbor' or 'subdivision'." ) # minimize segment length spread for the test points is_snap = np.isin(x0, x_snaps) result = least_squares( loss_function, x0=x0.copy(), bounds=(0, 1), # (zmin, zmax), jac="3-point", method="dogbox", loss="arctan", gtol=tol, ftol=tol, xtol=tol, verbose=0, args=(x0.copy(), is_snap.copy()), ) # transform back to [xmin, xmax] return result.x * (xmax - xmin) + xmin
def _refine_xyz_axis(self, method="insert", tol=1e-6):
    """
    Refine grid axes using snap points extracted from STL solids.

    Uses the stored snap points (``self.x_snaps``, ``self.y_snaps``,
    ``self.z_snaps``) to refine the axis arrays ``self.x``, ``self.y``,
    ``self.z`` through ``self.refine_axis``. Afterwards ``Nx``, ``Ny``,
    ``Nz`` and the cell sizes ``dx``, ``dy``, ``dz`` are recomputed from the
    refined axes.

    Parameters
    ----------
    method : {'insert','neighbor','subdivision'}, optional
        Refinement algorithm to use when inserting snap points.
        Default is 'insert'.
    tol : float, optional
        Convergence tolerance passed to the refinement routine.
    """
    if self.verbose > 1:
        print(f" * Refining x axis with {len(self.x_snaps)} snaps...")
    # Nx + 1: refine_axis works with grid points, Nx counts cells
    self.x = self.refine_axis(
        self.xmin,
        self.xmax,
        self.Nx + 1,
        self.x_snaps,
        method=method,
        tol=tol,
    )

    if self.verbose > 1:
        print(f" * Refining y axis with {len(self.y_snaps)} snaps...")
    self.y = self.refine_axis(
        self.ymin,
        self.ymax,
        self.Ny + 1,
        self.y_snaps,
        method=method,
        tol=tol,
    )

    if self.verbose > 1:
        print(f" * Refining z axis with {len(self.z_snaps)} snaps...")
    self.z = self.refine_axis(
        self.zmin,
        self.zmax,
        self.Nz + 1,
        self.z_snaps,
        method=method,
        tol=tol,
    )

    # refinement may change the number of points, so recompute the metrics
    self.Nx = len(self.x) - 1
    self.Ny = len(self.y) - 1
    self.Nz = len(self.z) - 1
    self.dx = np.diff(self.x)
    self.dy = np.diff(self.y)
    self.dz = np.diff(self.z)

    if self.verbose > 1:
        print(
            f"Refined grid: Nx = {self.Nx}, Ny ={self.Ny}, Nz = {self.Nz}"
        )

def _assign_colors(self):
    """
    Assign colors for each STL solid based on material categories.

    Maps entries in ``self.stl_materials`` to color names using the
    ``material_colors`` lookup and stores the result in ``self.stl_colors``:

    - a string material is looked up directly (falling back to 'other');
    - ``(eps_r, mu_r)``: 'pec' if ``eps_r`` is +inf, 'dielectric' if
      ``eps_r > 1``, otherwise 'vacuum';
    - ``(eps_r, mu_r, sigma)``: 'lossy metal';
    - anything else: 'other'.
    """
    self.stl_colors = {}
    for key in self.stl_solids:
        mat = self.stl_materials[key]
        if isinstance(mat, str):
            color = material_colors.get(mat, material_colors["other"])
        elif len(mat) == 2:
            eps_r = mat[0]
            # `==` rather than `is`: float("inf") and numpy floats (e.g.
            # values read back by load_from_h5) are not the np.inf object
            if eps_r == np.inf:
                color = material_colors["pec"]
            elif eps_r > 1.0:
                color = material_colors["dielectric"]
            else:
                color = material_colors["vacuum"]
        elif len(mat) == 3:
            color = material_colors["lossy metal"]
        else:
            color = material_colors["other"]
        self.stl_colors[key] = color

def _add_logo_widget(self, pl):
    """
    Add packaged logo to a PyVista plotter via importlib.resources.

    Attempts to load a packaged logo image from the installed package
    resources. Falls back to a local development path when resources are
    not available (typical in editable/dev installs). Failures are
    best-effort: they are only reported when ``self.verbose > 1``.
    """
    try:
        from importlib import resources

        # resource inside the installed package (use current package)
        logo_res = resources.files(__package__).joinpath(
            "static", "img", "wakis-logo-pink.png"
        )
        with resources.as_file(logo_res) as logo_path:
            pl.add_logo_widget(str(logo_path))
        return
    except Exception as e:
        # fallback to the legacy relative path for dev installs
        try:
            pl.add_logo_widget("../docs/img/wakis-logo-pink.png")
        except Exception:
            if self.verbose > 1:
                print(f"[!] Could not add logo widget: {e}")
def plot_solids(
    self,
    bounding_box=False,
    show_grid=False,
    anti_aliasing=None,
    opacity=1.0,
    specular=0.5,
    smooth_shading=False,
    off_screen=False,
    **kwargs,
):
    """
    Generate a 3D visualization of imported STL geometries using PyVista.

    Parameters
    ----------
    bounding_box : bool, optional
        If True, adds a bounding box around the plotted geometry
        (default False).
    show_grid : bool, optional
        If True, overlays the grid wireframe on the scene (default False).
    anti_aliasing : str or None, optional
        Anti-aliasing mode passed to PyVista (default: None, disabled).
    opacity : float, optional
        Opacity for solids (1.0 opaque, 0.0 transparent). Default 1.0.
    specular : float, optional
        Specular lighting strength, higher is shinier (default 0.5).
    smooth_shading : bool, optional
        Enable smooth shading for surface rendering (default False).
    off_screen : bool, optional
        If True, return the plotter instead of calling ``pl.show()``.
    **kwargs : dict
        Additional keyword args forwarded to ``pyvista.add_mesh``.

    Returns
    -------
    pl : pyvista.Plotter or None
        The plotter when ``off_screen`` is True, otherwise None.

    Notes
    -----
    - Colors come from ``self.stl_colors``.
    - Solids whose material is the string 'vacuum' are rendered with
      opacity 0.3 regardless of ``opacity``.
    """
    pl = pv.Plotter()
    pl.add_mesh(self.grid, opacity=0.0, name="grid", show_scalar_bar=False)

    for key in self.stl_solids:
        material = self.stl_materials[key]
        # Materials may be strings or numeric sequences/arrays (e.g. after
        # load_from_h5). Only a string can name vacuum, and comparing an
        # ndarray with a string does not give a plain bool.
        is_vacuum = isinstance(material, str) and material == "vacuum"
        pl.add_mesh(
            self.read_stl(key),
            color=self.stl_colors[key],
            opacity=0.3 if is_vacuum else opacity,
            specular=specular,
            smooth_shading=smooth_shading,
            **kwargs,
        )

    pl.set_background("mistyrose", top="white")
    self._add_logo_widget(pl)
    pl.camera_position = "zx"
    pl.camera.azimuth += 30
    pl.camera.elevation += 30
    pl.add_axes()

    if anti_aliasing is not None:
        pl.enable_anti_aliasing(anti_aliasing)

    if bounding_box:
        pl.add_bounding_box()

    if show_grid:
        # reuses the name "grid", replacing the transparent grid added above
        pl.add_mesh(
            self.grid,
            style="wireframe",
            color="grey",
            opacity=0.3,
            name="grid",
        )

    if off_screen:
        return pl
    pl.show()
def plot_stl_mask(
    self,
    stl_solid,
    cmap="viridis",
    bounding_box=True,
    show_grid=True,
    add_stl="all",
    stl_opacity=0.0,
    stl_colors=None,
    xmax=None,
    ymax=None,
    zmax=None,
    anti_aliasing="ssaa",
    smooth_shading=False,
    off_screen=False,
):
    """
    Interactive 3D visualization of the structured grid mask and imported
    STL geometries.

    This routine uses PyVista to display the grid scalar field corresponding
    to a chosen STL mask. Interactive slider widgets clip the domain along
    the X, Y, and Z directions. At each slider position, the clipped scalar
    field is shown with a colormap, and the grid structure is shown as a 2D
    slice in wireframe. Optionally, one or more STL geometries can be added
    to the scene, along with a bounding box of the simulation domain.

    Parameters
    ----------
    stl_solid : str
        Key name of the `stl_solids` dictionary to retrieve the mask for
        visualization (used as the scalar field).
    cmap : str, optional
        Colormap used to visualize the clipped scalar values.
        Default 'viridis'.
    bounding_box : bool, optional
        If True, add a static wireframe bounding box of the simulation
        domain.
    show_grid : bool, optional
        If True, adds the computational grid overlay on the clipped slice.
    add_stl : {'all', str, list[str]} or None, optional
        STL geometries to add. Default 'all'; None adds none.
    stl_opacity : float, optional
        Opacity of the STL surfaces (0 = fully transparent, 1 = fully
        opaque).
    stl_colors : str, list[str], dict, or None, optional
        Color(s) of the STL surfaces. Accepts a single color for all solids,
        a list matched by position to the solids being plotted, or a dict
        keyed by solid name. Defaults to ``self.stl_colors``. If no color
        can be resolved, the surface is drawn white.
    xmax, ymax, zmax : float, optional
        Initial clipping positions along each axis. If None, use the
        maximum domain extent.
    anti_aliasing : {'ssaa', 'fxaa', None}, optional
        Anti-aliasing mode passed to `pl.enable_anti_aliasing`; None
        disables it.
    smooth_shading : bool, optional
        Enable smooth shading for STL surfaces. Default False.
    off_screen : bool, optional
        If True, return the plotter instead of calling ``pl.show()``.

    Returns
    -------
    pl : pyvista.Plotter or None
        The plotter when ``off_screen`` is True, otherwise None.

    Notes
    -----
    - Three sliders (X, Y, Z) control clipping of the scalar field by a box.
    - STL solids can be visualized in transparent mode.
    - A static domain bounding box can be added for reference.
    """
    if stl_colors is None:
        stl_colors = self.stl_colors
    if xmax is None:
        xmax = self.xmax
    if ymax is None:
        ymax = self.ymax
    if zmax is None:
        zmax = self.zmax

    pv.global_theme.allow_empty_mesh = True
    pl = pv.Plotter()
    vals = {"x": xmax, "y": ymax, "z": zmax}

    # --- Update function ---
    def update_clip(val, axis="x"):
        vals[axis] = val
        # slice through the moving clip plane (grid structure)
        if axis == "x":
            slice_obj = self.grid.slice(normal="x", origin=(val, 0, 0))
        elif axis == "y":
            slice_obj = self.grid.slice(normal="y", origin=(0, val, 0))
        else:  # z
            slice_obj = self.grid.slice(normal="z", origin=(0, 0, val))

        # clipped volume coloured by the mask of `stl_solid`
        pl.add_mesh(
            self.grid.clip_box(
                bounds=(
                    self.xmin,
                    vals["x"],
                    self.ymin,
                    vals["y"],
                    self.zmin,
                    vals["z"],
                ),
                invert=False,
            ),
            scalars=stl_solid,
            cmap=cmap,
            name="clip",
        )

        # add slice wireframe (grid structure)
        if show_grid:
            pl.add_mesh(
                slice_obj, style="wireframe", color="grey", name="slice"
            )

    def resolve_color(key, i):
        """Return the color of solid `key` (i-th plotted), or None."""
        if isinstance(stl_colors, dict):
            return stl_colors[key]
        if isinstance(stl_colors, (list, tuple)):
            return stl_colors[i]
        if isinstance(stl_colors, str):
            return stl_colors
        return None

    # Select the stl surface(s) to plot
    if add_stl is None:
        stl_keys = []
    elif isinstance(add_stl, str):
        stl_keys = (
            list(self.stl_solids) if add_stl.lower() == "all" else [add_stl]
        )
    elif isinstance(add_stl, (list, tuple)):
        stl_keys = list(add_stl)
    else:
        stl_keys = []

    for i, key in enumerate(stl_keys):
        color = resolve_color(key, i)
        pl.add_mesh(
            self.read_stl(key),
            color="white" if color is None else color,
            opacity=stl_opacity,
            silhouette=True if color is None else dict(color=color),
            smooth_shading=smooth_shading,
            name=key,
        )

    # --- Sliders (placed side-by-side vertically) ---
    pl.add_slider_widget(
        lambda val: update_clip(val, "x"),
        [self.xmin, self.xmax],
        value=xmax,
        title="X Clip",
        pointa=(0.8, 0.8),
        pointb=(0.95, 0.8),  # top-right
        style="modern",
    )

    pl.add_slider_widget(
        lambda val: update_clip(val, "y"),
        [self.ymin, self.ymax],
        value=ymax,
        title="Y Clip",
        pointa=(0.8, 0.6),
        pointb=(0.95, 0.6),  # middle-right
        style="modern",
    )

    pl.add_slider_widget(
        lambda val: update_clip(val, "z"),
        [self.zmin, self.zmax],
        value=zmax,
        title="Z Clip",
        pointa=(0.8, 0.4),
        pointb=(0.95, 0.4),  # lower-right
        style="modern",
    )

    # Camera orientation
    pl.camera_position = "zx"
    pl.camera.azimuth += 30
    pl.camera.elevation += 30
    pl.set_background("mistyrose", top="white")
    self._add_logo_widget(pl)
    pl.add_axes()
    pl.enable_3_lights()
    if anti_aliasing is not None:
        pl.enable_anti_aliasing(anti_aliasing)

    if bounding_box:
        pl.add_mesh(
            pv.Box(
                bounds=(
                    self.xmin,
                    self.xmax,
                    self.ymin,
                    self.ymax,
                    self.zmin,
                    self.zmax,
                )
            ),
            style="wireframe",
            color="black",
            line_width=2,
            name="domain_box",
        )

    if off_screen:
        return pl
    pl.show()
def inspect(
    self,
    add_stl=None,
    stl_opacity=0.5,
    stl_colors=None,
    anti_aliasing="ssaa",
    smooth_shading=True,
    off_screen=False,
):
    """
    Interactive 3D inspector showing grid and STL geometries.

    A box widget clips the displayed grid lines and STL surfaces to the
    region it encloses.

    Parameters
    ----------
    add_stl : str or list, optional
        Key or list of keys of STL solids to include. If None, all solids are
        shown.
    stl_opacity : float, optional
        Opacity for STL surfaces (0 transparent, 1 opaque). Default 0.5.
    stl_colors : str, list, or dict, optional
        Color specification for STL surfaces. Accepts a single color for all
        solids, a list matched by position to the solids shown, or a dict
        keyed by solid name. Defaults to ``self.stl_colors``. If no color
        can be resolved, the surface is drawn white.
    anti_aliasing : str or None, optional
        Anti-aliasing mode to enable in the plotter (default 'ssaa'); None
        disables it.
    smooth_shading : bool, optional
        Enable smooth shading for surfaces (default True).
    off_screen : bool, optional
        If True, return the off-screen plotter object instead of showing an
        interactive window.

    Returns
    -------
    pl : pyvista.Plotter or None
        The plotter object if off_screen is True, otherwise None.
    """
    if stl_colors is None:
        stl_colors = self.stl_colors

    pv.global_theme.allow_empty_mesh = True
    pl = pv.Plotter()
    pl.add_mesh(
        self.grid, show_edges=True, cmap=["white", "white"], name="grid"
    )

    def resolve_color(key, i):
        """Return the color of solid `key` (i-th shown), or None."""
        if isinstance(stl_colors, dict):
            return stl_colors[key]
        if isinstance(stl_colors, (list, tuple)):
            return stl_colors[i]
        if isinstance(stl_colors, str):
            return stl_colors
        return None

    def clip(widget):
        b = widget.bounds

        # Rebuild the structured grid from the axis nodes inside the box
        x = self.x[np.logical_and(self.x >= b[0], self.x <= b[1])]
        y = self.y[np.logical_and(self.y >= b[2], self.y <= b[3])]
        z = self.z[np.logical_and(self.z >= b[4], self.z <= b[5])]
        X, Y, Z = np.meshgrid(x, y, z, indexing="ij")
        grid = pv.StructuredGrid(X.transpose(), Y.transpose(), Z.transpose())
        pl.add_mesh(
            grid, show_edges=True, cmap=["white", "white"], name="grid"
        )

        # Select the stl surface(s) to show
        if add_stl is None:
            stl_keys = list(self.stl_solids)
        elif isinstance(add_stl, str):
            stl_keys = [add_stl]
        elif isinstance(add_stl, (list, tuple)):
            stl_keys = list(add_stl)
        else:
            stl_keys = []

        for i, key in enumerate(stl_keys):
            surf = self.read_stl(key).clip_box(b, invert=False)
            color = resolve_color(key, i)
            pl.add_mesh(
                surf,
                color="white" if color is None else color,
                opacity=stl_opacity,
                silhouette=True,
                smooth_shading=smooth_shading,
                name=key,
            )

    _ = pl.add_box_widget(callback=clip, rotation_enabled=False)

    # Camera orientation
    pl.camera_position = "zx"
    pl.camera.azimuth += 30
    pl.camera.elevation += 30
    pl.set_background("mistyrose", top="white")
    self._add_logo_widget(pl)
    pl.add_axes()
    pl.enable_3_lights()
    if anti_aliasing is not None:
        pl.enable_anti_aliasing(anti_aliasing)

    if off_screen:
        pl.off_screen = True
        return pl
    pl.show()
    return None
def save_to_h5(self, filename="grid.h5"):
    """
    Save the generated grid and STL metadata to an HDF5 file.

    The file contains:

    - the axis arrays ``x``, ``y``, ``z``;
    - one group per ``stl_`` setting (solids, materials, colors, scale,
      rotate, translate) with one dataset per solid;
    - the STL masks of ``self.grid`` as ``grid_<solid>`` datasets.

    It can be read back with ``load_from_h5``.

    Parameters
    ----------
    filename : str, optional
        Output filename for the HDF5 file. The '.h5' suffix is appended if
        missing. Default 'grid.h5'.

    Notes
    -----
    - Groups are created with ``track_order=True`` so solids are read back
      in the order they were saved (HDF5 otherwise iterates by name).
    - A setting that is None (e.g. a grid without STL solids) is saved as an
      empty group. A setting that is not a dict (e.g. a single
      ``stl_scale`` float) is stored once per solid.
    """
    if not filename.endswith(".h5"):
        filename += ".h5"
    if self.verbose:
        print("Saving grid to HDF5 file:", filename)

    solids = self.stl_solids if self.stl_solids is not None else {}

    with h5py.File(filename, "w") as hf:
        hf.create_dataset("x", data=np.array(self.x))
        hf.create_dataset("y", data=np.array(self.y))
        hf.create_dataset("z", data=np.array(self.z))

        # Save stl_ variables as groups (one dataset per solid)
        for attr in [
            "stl_solids",
            "stl_materials",
            "stl_colors",
            "stl_scale",
            "stl_rotate",
            "stl_translate",
        ]:
            setting = getattr(self, attr)
            if setting is None:
                setting = {}
            elif attr != "stl_solids" and not isinstance(setting, dict):
                # one value shared by every solid
                setting = dict.fromkeys(solids, setting)

            grp = hf.create_group(attr, track_order=True)
            for key, val in setting.items():
                # strings are stored as bytes, everything else as arrays
                if isinstance(val, str):
                    grp.create_dataset(str(key), data=np.bytes_(val))
                else:
                    grp.create_dataset(str(key), data=np.array(val))

        for key in solids:
            hf.create_dataset("grid_" + key, data=np.array(self.grid[key]))
def load_from_h5(self, filename):
    """
    Load grid axis arrays and STL metadata from an HDF5 file.

    Restores the axis arrays and the ``stl_*`` dictionaries written by
    ``save_to_h5``. It then recomputes ``Nx``, ``Ny``, ``Nz``, the cell
    sizes and the domain bounds, rebuilds the grid with
    ``self._compute_grid()``, and fills ``self.grid`` cell data with the
    stored STL masks. Finally, the loaded settings are copied into the
    logger.

    Parameters
    ----------
    filename : str
        HDF5 filename to read. The '.h5' suffix is appended if missing.
    """
    if not filename.endswith(".h5"):
        filename += ".h5"
    if self.verbose:
        print("Loading grid from HDF5 file:", filename)

    with h5py.File(filename, "r") as hf:
        self.x = hf["x"][()]
        self.y = hf["y"][()]
        self.z = hf["z"][()]

        # reconstruct stl dicts from their groups
        for attr in [
            "stl_solids",
            "stl_materials",
            "stl_colors",
            "stl_scale",
            "stl_rotate",
            "stl_translate",
        ]:
            dct = {}
            for key, dset in hf[attr].items():
                val = dset[()]
                # strings were stored as bytes by save_to_h5
                if isinstance(val, bytes):
                    val = val.decode()
                dct[key] = val
            setattr(self, attr, dct)

        # read the masks while the file is open; they are attached to
        # self.grid once it has been rebuilt below
        masks = {key: hf["grid_" + key][()] for key in self.stl_solids}

    # recompute dx, dy, dz, Nx, Ny, Nz and the domain bounds
    self.Nx = len(self.x) - 1
    self.Ny = len(self.y) - 1
    self.Nz = len(self.z) - 1
    self.dx = np.diff(self.x)
    self.dy = np.diff(self.y)
    self.dz = np.diff(self.z)
    self.xmin, self.xmax = self.x[0], self.x[-1]
    self.ymin, self.ymax = self.y[0], self.y[-1]
    self.zmin, self.zmax = self.z[0], self.z[-1]

    # recompute grid and L, iA, tL, itA
    self._compute_grid()

    # assign masks to grid.cell_data
    for key, mask in masks.items():
        self.grid[key] = mask

    if self.verbose > 1:
        print(f"Loaded grid with {self.Nx * self.Ny * self.Nz} mesh cells:")
        print(f" * Number of cells: Nx={self.Nx}, Ny={self.Ny}, Nz={self.Nz}")
        print(
            " * Simulation domain bounds: \n"
            f"    x:[{self.xmin:.3f}, {self.xmax:.3f}],\n"
            f"    y:[{self.ymin:.3f}, {self.ymax:.3f}],\n"
            f"    z:[{self.zmin:.3f}, {self.zmax:.3f}]"
        )
        print(f" * STL solids imported:\n    {list(self.stl_solids.keys())}")
        print(
            " * STL solids assigned materials [eps_r, mu_r, sigma]:\n"
            f"    {list(self.stl_materials.values())}"
        )

    # update logger
    self.update_logger(["Nx", "Ny", "Nz", "dx", "dy", "dz"])
    self.update_logger(["xmin", "xmax", "ymin", "ymax", "zmin", "zmax"])
    self.update_logger(["stl_solids", "stl_materials"])

    def differs_from(setting, default):
        """True if any per-solid value of `setting` is not `default`."""
        # The loaded settings are dicts, so a direct comparison with a list
        # or a float would always be unequal. Compare each value instead.
        values = setting.values() if isinstance(setting, dict) else [setting]
        return any(
            np.any(np.asarray(v, dtype=float) != default) for v in values
        )

    if differs_from(self.stl_rotate, 0.0):
        self.update_logger(["stl_rotate"])
    if differs_from(self.stl_translate, 0.0):
        self.update_logger(["stl_translate"])
    if differs_from(self.stl_scale, 1.0):
        self.update_logger(["stl_scale"])
def update_logger(self, attrs):
    """
    Mirror the named grid attributes into ``self.logger.grid``.

    Each name in ``attrs`` is looked up on this object with ``getattr`` and
    stored under the same key in the logger's ``grid`` dictionary,
    overwriting any previous entry.

    Parameters
    ----------
    attrs : iterable of str
        Attribute names to copy.

    Raises
    ------
    AttributeError
        If one of the names is not an attribute of this object. Entries for
        earlier names have already been written at that point.
    """
    for attr_name in attrs:
        value = getattr(self, attr_name)
        self.logger.grid[attr_name] = value