import torch
import torch.nn.functional as F
import math

def get_proj(volume, right_angle, left_angle, distance_to_obj=4, surface_extent=3, N_samples_per_ray=200, H_out=128, W_out=128, grid_sample_mode='bilinear'):
    """
    Generates a 2D projection of a 3D volume by casting rays from a specified camera position.

    This function simulates a perspective (pinhole) projection of a 3D volume onto a 2D plane. The camera is positioned on a sphere
    of radius `distance_to_obj` centered at the origin, with its position determined by the provided right and left angles. Rays are cast from the camera
    through points on a plane tangent to that sphere at the point opposite the camera, and the volume is sampled along these rays to produce the projection.

    Args:
        volume (torch.Tensor): A 5D tensor of shape (N, C, D, H, W) representing the 3D volume to be projected.
        right_angle (float): The azimuthal angle (in radians) determining the camera's position around the z-axis.
        left_angle (float): The polar angle (in radians) determining the camera's elevation from the xy-plane.
        distance_to_obj (float, optional): The distance from the camera to the origin. Defaults to 4.
        surface_extent (float, optional): The half-extent of the tangent plane in world units. Defines the plane's size. Defaults to 3.
        N_samples_per_ray (int, optional): The number of sample points along each ray. Higher values yield more accurate projections. Defaults to 200.
        H_out (int, optional): The height (in pixels) of the output 2D projection. Defaults to 128.
        W_out (int, optional): The width (in pixels) of the output 2D projection. Defaults to 128.
        grid_sample_mode (str, optional): Interpolation mode passed to torch.nn.functional.grid_sample ('bilinear' or 'nearest'). Defaults to 'bilinear'.

    Returns:
        torch.Tensor: A 4D tensor of shape (N, C, H_out, W_out) representing the 2D projection of the input volume, i.e. (1, 1, H_out, W_out) for a single-batch, single-channel volume.

    Raises:
        RuntimeError: Propagated from torch.nn.functional.grid_sample if `volume` is not a 5D tensor or if it lives on a different device than the computed sampling grid.

    Example:
        ```python
        import torch

        # Create a sample 3D volume
        volume = torch.zeros((1, 1, 32, 32, 32))
        volume[0, 0, 16, :, :] = 1  # Add a plane in the middle

        # Define camera angles
        right_angle = 0.5  # radians
        left_angle = 0.3   # radians

        # Generate the projection
        projection = get_proj(volume, right_angle, left_angle)

        # Visualize the projection
        import matplotlib.pyplot as plt
        plt.imshow(projection.squeeze().cpu().numpy(), cmap='gray')
        plt.show()
        ```

    Note:
        - The volume is assumed to occupy the cube [-1, 1]^3 in grid_sample's normalized coordinates; sample points outside that cube are padded with zeros rather than raising an error.
        - All rays originate at the camera position, so the function uses a pinhole (perspective) projection model rather than an orthographic one.
        - Adjust `N_samples_per_ray` for a trade-off between performance and projection accuracy.
    """
    device = volume.device
    
    ra = right_angle
    la = left_angle
    
    # Compute camera position p on the unit sphere.
    p = torch.tensor([
        math.cos(la) * math.cos(ra),
        math.cos(la) * math.sin(ra),
        math.sin(la)
    ]).to(device)
    p *= distance_to_obj
    # p is of shape (3,) and lies on a sphere of radius distance_to_obj.

    # The camera is at position p and always looks to the origin.
    # Define the opposite point on the sphere:
    q = -p  # This will be the point of tangency of the projection plane.

    # -------------------------------------------------------------------
    # 3. Define an orthonormal basis for the projection plane tangent to the camera sphere at q.
    # We need two vectors (right, up) lying in the plane.
    # One way is to choose a reference vector not colinear with q.
    # -------------------------------------------------------------------
    ref = torch.tensor([0.0, 0.0, 1.0]).to(device)
    # If q is (anti)parallel to the z-axis, the cross product with ref would vanish,
    # so fall back to the y-axis as the reference vector.
    if torch.allclose(torch.abs(q) / torch.norm(q), ref, atol=1e-3):
        ref = torch.tensor([0.0, 1.0, 0.0]).to(device)

    # Compute right as the normalized cross product of ref and q.
    right_vec = torch.cross(ref, q, dim=0)
    right_vec = right_vec / torch.norm(right_vec)

    # Compute up as the cross product of q and right.
    up_vec = torch.cross(q, right_vec, dim=0)
    up_vec = up_vec / torch.norm(up_vec)

    # -------------------------------------------------------------------
    # 4. Build the image plane grid.
    #
    # We want to form an image on the plane tangent to the camera sphere at q.
    # The plane is defined by the equation: q · x = |q|^2 = distance_to_obj^2.
    #
    # A convenient parameterization is:
    #
    #    For (u, v) in some range, the 3D point on the plane is:
    #       P(u,v) = q + u * right_vec + v * up_vec.
    #
    # Note: q is perpendicular to both right_vec and up_vec,
    # so q · P(u,v) = q · q automatically, i.e. every P(u,v) lies on that plane.
    #
    # Choose an output image resolution and an extent for u and v.
    # -------------------------------------------------------------------
    # Choose an extent large enough that the object's projection fits on the plane.
    # (The volume itself occupies [-1,1]^3; sample points outside that cube simply
    #  contribute zero, since grid_sample pads out-of-range locations.)
    extent = surface_extent  # you may adjust this value

    u_vals = torch.linspace(-extent, extent, W_out).to(device)
    v_vals = torch.linspace(-extent, extent, H_out).to(device)
    grid_v, grid_u = torch.meshgrid(v_vals, u_vals, indexing='ij')  # shapes: (H_out, W_out)

    # For each pixel (u,v) on the plane, compute its world coordinate.
    # P = q + u * right_vec + v * up_vec.
    plane_points = q.unsqueeze(0).unsqueeze(0) + \
                grid_u.unsqueeze(-1) * right_vec + \
                grid_v.unsqueeze(-1) * up_vec
    # plane_points shape: (H_out, W_out, 3)

    # -------------------------------------------------------------------
    # 5. For each pixel, sample along the ray from the camera p through the point P.
    #
    # Since the camera is at p and the ray passing through a pixel is along the line from p to P,
    # the ray can be parameterized as:
    #
    #    r(t) = p + t*(P - p),   for t in [0, 1]
    #
    # t=0 gives the camera position, t=1 gives the intersection with the image plane (P).
    # -------------------------------------------------------------------
    N_samples = N_samples_per_ray
    t_vals = torch.linspace(0, 1, N_samples).to(device)  # shape: (N_samples,)

    # Expand plane_points to sample along t:
    # plane_points has shape (H_out, W_out, 3). We want to combine it with p.
    # Compute (P - p): note that p is a vector; we can reshape it appropriately.
    P_minus_p = plane_points - p.unsqueeze(0).unsqueeze(0)  # shape: (H_out, W_out, 3)

    # Now, for each t, compute the sample point:
    # sample_point(t, u, v) = p + t*(P(u,v) - p)
    # We can do:
    sample_grid = p.unsqueeze(0).unsqueeze(0).unsqueeze(0) + \
                t_vals.view(N_samples, 1, 1, 1) * P_minus_p.unsqueeze(0)
    # sample_grid now has shape: (N_samples, H_out, W_out, 3).

    # Add a batch dimension (batch size 1) so that grid_sample sees a grid of shape:
    # (1, N_samples, H_out, W_out, 3)
    sample_grid = sample_grid.unsqueeze(0)

    # IMPORTANT: grid_sample expects the grid coordinates in the normalized coordinate system
    # of the input volume, i.e. [-1, 1]^3, with the last grid dimension ordered as (x, y, z),
    # corresponding to the (W, H, D) axes of the volume. Sample points outside that cube are
    # not an error: with the default padding_mode='zeros' they simply contribute nothing, so
    # overly large values of surface_extent or distance_to_obj just waste samples.

    # -------------------------------------------------------------------
    # 6. Use grid_sample to sample the volume along each ray and integrate.
    # -------------------------------------------------------------------
    # grid_sample expects input volume of shape [N, C, D, H, W] and grid of shape [N, D_out, H_out, W_out, 3].
    proj_samples = F.grid_sample(volume, sample_grid, mode=grid_sample_mode, align_corners=False)
    # proj_samples has shape: (1, 1, N_samples, H_out, W_out)

    # For a simple projection (like an X-ray), integrate along the ray.
    # Here we simply sum along the sample (ray) dimension.
    proj_image = proj_samples.sum(dim=2)  # shape: (1, 1, H_out, W_out)
    return proj_image
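
The IMPORTANT comment above asks you to check that the sampling grid stays inside [-1, 1]^3. With grid_sample's default padding_mode='zeros' an out-of-range point is not an error, it just contributes nothing, but if most samples miss the cube the projection comes out mostly black. Below is a minimal diagnostic sketch (the helper name and its defaults are my own; it simply mirrors get_proj's geometry) that rebuilds the same sample grid and reports the fraction of points that actually land inside the volume:

import math
import torch

def sample_grid_coverage(right_angle, left_angle, distance_to_obj=4, surface_extent=3,
                         N_samples_per_ray=200, H_out=128, W_out=128, device='cpu'):
    # Rebuild the same ray/sample grid as get_proj and report the fraction of
    # sample points that fall inside the volume cube [-1, 1]^3.
    p = torch.tensor([math.cos(left_angle) * math.cos(right_angle),
                      math.cos(left_angle) * math.sin(right_angle),
                      math.sin(left_angle)], device=device) * distance_to_obj
    q = -p
    ref = torch.tensor([0.0, 0.0, 1.0], device=device)
    if torch.allclose(torch.abs(q) / torch.norm(q), ref, atol=1e-3):
        ref = torch.tensor([0.0, 1.0, 0.0], device=device)
    right_vec = torch.cross(ref, q, dim=0)
    right_vec = right_vec / torch.norm(right_vec)
    up_vec = torch.cross(q, right_vec, dim=0)
    up_vec = up_vec / torch.norm(up_vec)
    u = torch.linspace(-surface_extent, surface_extent, W_out, device=device)
    v = torch.linspace(-surface_extent, surface_extent, H_out, device=device)
    grid_v, grid_u = torch.meshgrid(v, u, indexing='ij')
    plane = q + grid_u.unsqueeze(-1) * right_vec + grid_v.unsqueeze(-1) * up_vec
    t = torch.linspace(0, 1, N_samples_per_ray, device=device).view(-1, 1, 1, 1)
    grid = p + t * (plane - p)                    # (N_samples_per_ray, H_out, W_out, 3)
    inside = (grid.abs() <= 1).all(dim=-1)        # True where a point is inside the cube
    return inside.float().mean().item()

print(sample_grid_coverage(0.5, 0.2, surface_extent=4))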

It can be used like this:

import matplotlib.pyplot as plt
# the volume that defines the 3D object
volume = torch.zeros(1, 1, 32, 32, 32, device='cuda', requires_grad=True)

def make_cube(volume):
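    # Draw the twelve edges of a wire-frame cube along the volume boundary.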
    volume[0, 0, :, 0, 0] = 1
    volume[0, 0, :, -1, 0] = 1
    volume[0, 0, :, 0, -1] = 1
    volume[0, 0, :, -1, -1] = 1

    volume[0, 0, 0, :, 0] = 1
    volume[0, 0, -1, :, 0] = 1
    volume[0, 0, 0, :, -1] = 1
    volume[0, 0, -1, :, -1] = 1

    volume[0, 0, 0, -1, :] = 1
    volume[0, 0, 0, 0, :] = 1
    volume[0, 0, -1, 0, :] = 1
    volume[0, 0, -1, -1, :] = 1

with torch.no_grad():
    make_cube(volume)

# Create a figure and axis
fig, ax = plt.subplots()

right_angle = 0.5
left_angle = 0.2
proj_image = get_proj(volume, right_angle, left_angle, surface_extent=4)

proj_image = proj_image.detach().cpu()[0, 0].transpose(0, 1)


# Display the projection
plt.imshow(proj_image, cmap='gray')
plt.show()

(Image: the resulting projection, showing the wire-frame cube.)
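
Because the projection is built entirely from differentiable ops (the only sampling step is F.grid_sample), gradients flow from the rendered image back into the volume, which is presumably why the example allocates it with requires_grad=True. Here is a minimal sketch of using get_proj in a render-and-compare loop; the target, learning rate and step count are arbitrary placeholders, and a CUDA device is assumed as in the example above:

import torch
import torch.nn.functional as F

# Treat the cube's projection as a "measured" target image (placeholder choice).
with torch.no_grad():
    target = get_proj(volume, right_angle, left_angle, surface_extent=4)

# Optimize a fresh volume so that its projection matches the target.
vol_param = torch.zeros(1, 1, 32, 32, 32, device='cuda', requires_grad=True)
optimizer = torch.optim.Adam([vol_param], lr=1e-2)

for step in range(200):
    optimizer.zero_grad()
    pred = get_proj(vol_param, right_angle, left_angle, surface_extent=4)
    loss = F.mse_loss(pred, target)
    loss.backward()   # gradients flow through grid_sample into vol_param
    optimizer.step()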
