@@ -800,11 +800,24 @@ def __init__(
800800 self .torch_dtype = get_equivalent_dtype (dtype , torch .Tensor )
801801 self .numpy_dtype = get_equivalent_dtype (dtype , np .ndarray )
802802 # Validate that dtype is floating-point for meaningful Gaussian values
803- if self .torch_dtype not in ( torch . float16 , torch . float32 , torch . float64 , torch . bfloat16 ) :
803+ if not self .torch_dtype . is_floating_point :
804804 raise ValueError (f"dtype must be a floating-point type, got { self .torch_dtype } " )
805805 self .spatial_shape = None if spatial_shape is None else tuple (int (s ) for s in spatial_shape )
806806
807807 def __call__ (self , points : NdarrayOrTensor , spatial_shape : Sequence [int ] | None = None ) -> NdarrayOrTensor :
808+ """
809+ Args:
810+ points: landmark coordinates as ndarray/Tensor with shape (N, D) or (B, N, D),
811+ ordered as (Y, X) for 2D or (Z, Y, X) for 3D.
812+ spatial_shape: spatial size as a sequence or single int (broadcasted). If None, uses
813+ the value provided at construction.
814+
815+ Returns:
816+ Heatmaps with shape (N, *spatial) or (B, N, *spatial), one channel per landmark.
817+
818+ Raises:
819+ ValueError: if points shape/dimension or spatial_shape is invalid.
820+ """
808821 original_points = points
809822 points_t = convert_to_tensor (points , dtype = torch .float32 , track_meta = False )
810823
@@ -828,13 +841,15 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
828841
829842 heatmap = torch .zeros ((batch_size , num_points , * target_shape ), dtype = self .torch_dtype , device = device )
830843 image_bounds = tuple (int (s ) for s in target_shape )
844+ bounds_t = torch .as_tensor (image_bounds , device = device , dtype = points_t .dtype )
831845 for b_idx in range (batch_size ):
832846 for idx , center in enumerate (points_t [b_idx ]):
833- center_vals = center .tolist ()
834- if not np .all (np .isfinite (center_vals )):
847+ if not torch .isfinite (center ).all ():
835848 continue
836- if not self . _is_inside ( center_vals , image_bounds ):
849+ if not (( center >= 0 ). all () and ( center < bounds_t ). all () ):
837850 continue
851+ # _make_window expects Python floats; convert only when needed
852+ center_vals = center .tolist ()
838853 window_slices , coord_shifts = self ._make_window (center_vals , radius , image_bounds , device )
839854 if window_slices is None :
840855 continue
0 commit comments