# Doriflow Engine - Fluid Simulation for Blender 3D
# Copyright (C) 2024 Doriflow Team
# This software is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

# You are free to:
# -Share: Copy and redistribute the material in any medium or format.
# -Adapt: Remix, transform, and build upon the material.

# UNDER THE FOLOWING TERMS:
# -Attribution: You must give appropriate credit, provide a link to the license, and indicate if changes were made.
# -Appropriate credit should include the following:
#   -The original author's name: Doriflow Team
#   -A link to the original source (if applicable).
#   -A link to the full license: https://creativecommons.org/licenses/by-nc/4.0/.
#   -A clear indication of any changes made, such as: "This material has been modified."
# NonCommercial: You may not use the material for commercial purposes.
# Disclaimer:
# -This simulation engine is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. In no event shall the authors be liable for any claim, damages, or other liability arising from the use of this software.

# For more details, refer to the full license text at:
# https://creativecommons.org/licenses/by-nc/4.0/.

#-----------------------------------------------------------------------------------------------------------------------#
import taichi as ti
import numpy as np
from particle_system import ParticleSystem
import os
import json
import trimesh as tm


@ti.func
def normalize_quaternion(q) -> ti.types.vector(4, float):
    """Return ``q`` rescaled to unit length.

    A small epsilon (1e-6) is passed to ``norm`` so the denominator stays
    positive even for a zero quaternion.
    """
    magnitude = q.norm(1e-6)
    return q / magnitude

@ti.func
def quaternion_multiply(q1: ti.types.vector(4, float), q2: ti.types.vector(4, float)) -> ti.types.vector(4, float):
    """Hamilton product ``q1 * q2`` for quaternions stored as (w, x, y, z), scalar first."""
    a_w, a_x, a_y, a_z = q1[0], q1[1], q1[2], q1[3]
    b_w, b_x, b_y, b_z = q2[0], q2[1], q2[2], q2[3]
    out_w = a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z
    out_x = a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y
    out_y = a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x
    out_z = a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w
    return ti.Vector([out_w, out_x, out_y, out_z])

@ti.func
def quaternion_to_rotation_matrix(q: ti.types.vector(4, float)) -> ti.types.matrix(3, 3, float):
    """Convert a quaternion stored as (w, x, y, z) into a 3x3 rotation matrix.

    ``q`` is assumed to be unit length; no renormalization happens here.
    """
    w, x, y, z = q[0], q[1], q[2], q[3]
    sq_x, sq_y, sq_z = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w
    return ti.Matrix([
        [1 - 2 * (sq_y + sq_z), 2 * (xy - zw), 2 * (xz + yw)],
        [2 * (xy + zw), 1 - 2 * (sq_x + sq_z), 2 * (yz - xw)],
        [2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (sq_x + sq_y)],
    ])


@ti.func
def integrate_quaternion(q, angular_velocity, dt) -> ti.types.vector(4, float):
    """Advance an orientation quaternion by one explicit Euler step.

    Quaternions are stored as (w, x, y, z), matching ``quaternion_multiply`` and
    ``quaternion_to_rotation_matrix`` in this module.

    ``angular_velocity`` is a world-frame vector: the rigid-body kernels in this
    file update it with the world-space inverse inertia tensor applied to a
    world-space torque. For a world-frame omega the time derivative is
    dq/dt = 0.5 * (0, omega) * q, so the pure quaternion multiplies from the
    LEFT. (``q * (0, omega)`` would be the body-frame form.) The result is
    renormalized to stay a unit quaternion.

    Args:
        q: current orientation (w, x, y, z).
        angular_velocity: world-frame angular velocity (3-vector).
        dt: time step.

    Returns:
        The normalized orientation after one step.
    """
    omega = ti.Vector([0.0, angular_velocity[0], angular_velocity[1], angular_velocity[2]])
    dq = 0.5 * quaternion_multiply(omega, q) * dt
    return normalize_quaternion(q + dq)

@ti.data_oriented
class CoreSteps:
    def __init__(self, particle_system:ParticleSystem):  
        """Cache solver parameters from the particle system's config and allocate scalar Taichi fields.

        Args:
            particle_system: the ParticleSystem this solver advances. Every tunable
                read here comes from ``particle_system.cfg.get_domain(...)``.
        """
        self.ps = particle_system
        # Gravity kept both as a NumPy array (used as ti.Vector(self.g) inside kernels)
        # and as a 0-D Taichi vector field.
        self.g = np.array(self.ps.cfg.get_domain("gravitation"))
        self.fps = self.ps.cfg.get_domain("fps")
        self.g_taichi = ti.Vector.field(3, dtype=ti.f32, shape=())
        self.g_taichi[None] = ti.Vector(self.ps.cfg.get_domain("gravitation"))
        self.viscosity = self.ps.cfg.get_domain("viscosity") 
        self.fluid_density = self.ps.cfg.get_domain("fluid_density")
        # Time step stored in a 0-D field so kernels read it as self.dt[None].
        self.dt = ti.field(float, shape=())
        self.dt[None] = self.ps.cfg.get_domain("timeStepSize")
        # Time steps per output frame: frame duration / dt, truncated toward zero.
        self.output_interval = int((1 / self.fps) / self.dt[None])
        self.damping_factor=1
        self.added_mass = 0
        self.max_curvature = ti.field(float, shape=())
        self.max_curvature[None] = 0
        # Wall response selector: simulate_collisions reflects velocity when 0
        # and zeroes it when 1.
        self.boundary_condition = int(self.ps.cfg.get_domain("BCs"))
        self.padding_factor = 2
        self.rigid_padding = self.ps.padding*self.padding_factor
        self.rotation_limit_factor = 1
        self.boundary_collision_factor = self.ps.cfg.get_domain("boundary_collision_factor")
        self.rotation_factor = self.ps.cfg.get_domain("rotation_factor")
        self.rigid_rigid_collision_factor= 0
        self.correction_threshold = 1e-6
        self.overlap_threshold_factor = self.ps.cfg.get_domain("rigid_rigid_overlap_threshold")
        # Maximum rigid-rigid overlap correction, expressed in particle diameters.
        self.max_correction = self.overlap_threshold_factor*self.ps.particle_diameter
        self.num_frames = self.ps.cfg.get_domain("end_frame")
        self.mean_velocity = ti.Vector.field(3, dtype=ti.f32, shape=())
        self.std_dev = ti.Vector.field(3, dtype=ti.f32, shape=())
        # NOTE(review): here and in the DEM/Rigid sections below, each literal
        # assignment is immediately overwritten by the config value on the next
        # line. The literals are therefore NOT effective fallbacks.
        self.max_velocity_change_factor = 1e6
        self.max_velocity_change_factor = self.ps.cfg.get_domain("max_velocity_change_factor")
        #DEM
        self.dem_boundary_friction_coefficient = 0.1
        self.dem_boundary_friction_coefficient = self.ps.cfg.get_domain("dem_boundary_friction_coefficient")
        self.dem_boundary_bouncing_coefficient = 0.1
        self.dem_boundary_bouncing_coefficient = self.ps.cfg.get_domain("dem_boundary_bouncing_coefficient")
        #Rigid
        self.rigid_bouncing_coefficient = 0.1
        self.rigid_bouncing_coefficient = self.ps.cfg.get_domain("rigid_bouncing_coefficient")
        self.rigid_friction_coefficient = 0.1
        self.rigid_friction_coefficient = self.ps.cfg.get_domain("rigid_friction_coefficient")
        self.rigid_damping_coefficient = 0.1
        self.rigid_damping_coefficient = self.ps.cfg.get_domain("rigid_damping_coefficient")
    @ti.func
    def cubic_kernel(self, r_norm):
        """Cubic spline SPH kernel with compact support h = ``self.ps.support_radius``.

        With q = r / h:
            W = sigma * (6 q^3 - 6 q^2 + 1)   for q <= 0.5
            W = sigma * 2 (1 - q)^3           for 0.5 < q <= 1
            W = 0                             otherwise
        where sigma = c_dim / h^dim and c_dim is 4/3, 40/(7 pi) or 8/pi for
        dim 1, 2 or 3.
        """
        h = self.ps.support_radius
        sigma = ti.cast(1.0, ti.f32)
        if self.ps.dim == 1:
            sigma = ti.cast(4.0 / 3.0, ti.f32)
        elif self.ps.dim == 2:
            sigma = ti.cast(40.0 / 7.0, ti.f32) / ti.cast(np.pi, ti.f32)
        elif self.ps.dim == 3:
            sigma = ti.cast(8.0, ti.f32) / ti.cast(np.pi, ti.f32)
        sigma /= h ** self.ps.dim
        q = r_norm / h
        value = ti.cast(0.0, ti.f32)
        if q <= ti.cast(0.5, ti.f32):
            q_sq = q * q
            q_cu = q_sq * q
            value = sigma * (ti.cast(6.0, ti.f32) * q_cu - ti.cast(6.0, ti.f32) * q_sq + ti.cast(1.0, ti.f32))
        elif q <= ti.cast(1.0, ti.f32):
            value = sigma * ti.cast(2.0, ti.f32) * ti.pow(ti.cast(1.0, ti.f32) - q, ti.cast(3.0, ti.f32))
        return value

    @ti.func
    def cubic_kernel_derivative(self, r):
        """Gradient of the cubic spline kernel with respect to the offset vector ``r``.

        Returns the zero vector when |r| <= 1e-5 (to avoid dividing by |r|) or
        when |r| lies beyond the support radius.
        """
        h = self.ps.support_radius
        sigma = ti.cast(1.0, ti.f32)
        if self.ps.dim == 1:
            sigma = ti.cast(4.0 / 3.0, ti.f32)
        elif self.ps.dim == 2:
            sigma = ti.cast(40.0 / 7.0, ti.f32) / ti.cast(np.pi, ti.f32)
        elif self.ps.dim == 3:
            sigma = ti.cast(8.0, ti.f32) / ti.cast(np.pi, ti.f32)
        # The factor 6 comes from differentiating the spline pieces:
        # d/dq (6q^3 - 6q^2 + 1) = 6 q (3q - 2) and d/dq 2(1-q)^3 = -6 (1-q)^2.
        sigma = ti.cast(6.0, ti.f32) * sigma / h ** self.ps.dim
        dist = r.norm()
        q = dist / h
        grad = ti.Vector([ti.cast(0.0, ti.f32) for _ in range(self.ps.dim)])
        if dist > ti.cast(1e-5, ti.f32) and q <= ti.cast(1.0, ti.f32):
            dq_dr = r / (dist * h)
            if q > ti.cast(0.5, ti.f32):
                one_minus_q = ti.cast(1.0, ti.f32) - q
                grad = sigma * (-one_minus_q * one_minus_q) * dq_dr
            else:
                grad = sigma * q * (ti.cast(3.0, ti.f32) * q - ti.cast(2.0, ti.f32)) * dq_dr
        return grad

    @ti.func
    def poly6_kernel(self, r_norm):
        """3D poly6 kernel: W(r) = 315 / (64 pi h^9) * (h^2 - r^2)^3 for r <= h, else 0."""
        h = self.ps.support_radius
        value = ti.cast(0.0, ti.f32)
        if r_norm <= h:
            norm_const = ti.cast(315.0, ti.f32) / (ti.cast(64.0, ti.f32) * ti.cast(np.pi, ti.f32) * h**9)
            value = norm_const * (h**2 - r_norm**2)**3
        return value

    @ti.func
    def poly6_kernel_derivative(self, r):
        """Gradient of the 3D poly6 kernel with respect to the offset vector ``r``.

        W(r) = 315 / (64 pi h^9) * (h^2 - |r|^2)^3, hence
        grad W = -945 / (32 pi h^9) * (h^2 - |r|^2)^2 * r.

        The gradient is proportional to ``r`` itself, not to the unit vector
        r / |r|. It vanishes smoothly at r = 0 (so no division guard is needed)
        and is zero outside the support radius h.
        """
        res = ti.Vector([ti.cast(0.0, ti.f32) for _ in range(self.ps.dim)])
        h = self.ps.support_radius
        r_norm = r.norm()
        k = ti.cast(-945.0, ti.f32) / (ti.cast(32.0, ti.f32) * ti.cast(np.pi, ti.f32) * h**9)
        if r_norm <= h:
            res = k * (h**2 - r_norm**2)**2 * r
        return res

    @ti.func
    def spiky_kernel(self, r_norm):
        """3D spiky kernel: W(r) = 15 / (pi h^6) * (h - r)^3 for r <= h, else 0."""
        h = self.ps.support_radius
        value = ti.cast(0.0, ti.f32)
        if r_norm <= h:
            norm_const = ti.cast(15.0, ti.f32) / (ti.cast(np.pi, ti.f32) * h**6)
            value = norm_const * (h - r_norm)**3
        return value

    @ti.func
    def spiky_kernel_derivative(self, r):
        """Gradient of the 3D spiky kernel: -45 / (pi h^6) * (h - |r|)^2 * r / |r|.

        Returns the zero vector when |r| <= 1e-5 (the direction is undefined
        there) or when |r| > h.
        """
        h = self.ps.support_radius
        dist = r.norm()
        grad = ti.Vector([ti.cast(0.0, ti.f32) for _ in range(self.ps.dim)])
        if dist > ti.cast(1e-5, ti.f32) and dist <= h:
            coeff = ti.cast(-45.0, ti.f32) / (ti.cast(np.pi, ti.f32) * h**6)
            grad = coeff * (h - dist)**2 * r / dist
        return grad

    @ti.func
    def wendland_quintic_kernel(self, r_norm):
        """3D Wendland C2 (quintic) kernel with compact support h = ``self.ps.support_radius``.

        With q = r / h: W = 21 / (2 pi h^3) * (1 - q)^4 * (1 + 4 q) for q <= 1,
        else 0.

        The prefactor 21 / (2 pi h^3) makes W integrate to 1 over the support
        q in [0, 1]. The constant 21 / (16 pi h^3) belongs to the form written
        in the smoothing length (support q in [0, 2]); using it with this
        polynomial scales the kernel down by a factor of 8.
        """
        res = ti.cast(0.0, ti.f32)
        h = self.ps.support_radius
        alpha = ti.cast(21.0, ti.f32) / (ti.cast(2.0, ti.f32) * ti.cast(np.pi, ti.f32) * h**3)
        q = r_norm / h
        if q <= ti.cast(1.0, ti.f32):
            res = alpha * (ti.cast(1.0, ti.f32) - q)**4 * (ti.cast(1.0, ti.f32) + ti.cast(4.0, ti.f32) * q)
        return res

    @ti.func
    def wendland_quintic_kernel_derivative(self, r):
        """Gradient of the 3D Wendland C2 kernel with respect to the offset vector ``r``.

        d/dq [(1 - q)^4 (1 + 4q)] = -20 q (1 - q)^3 and dq/dr = 1/h, so
        grad W = alpha * (-20 q / h) * (1 - q)^3 * r / |r|, with
        alpha = 21 / (2 pi h^3). This is the normalization for support q in
        [0, 1], consistent with ``wendland_quintic_kernel``.

        Returns the zero vector when |r| <= 1e-5 or beyond the support.
        """
        res = ti.Vector([ti.cast(0.0, ti.f32) for _ in range(self.ps.dim)])
        h = self.ps.support_radius
        alpha = ti.cast(21.0, ti.f32) / (ti.cast(2.0, ti.f32) * ti.cast(np.pi, ti.f32) * h**3)
        r_norm = r.norm()
        q = r_norm / h
        if r_norm > ti.cast(1e-5, ti.f32) and q <= ti.cast(1.0, ti.f32):
            res = alpha * (ti.cast(-20.0, ti.f32) * q / h * (ti.cast(1.0, ti.f32) - q)**3) * r / r_norm
        return res

    
    def initialize(self):
        """One-time setup before stepping.

        Sorts particles into the neighbor grid and computes each rigid body's rest
        center of mass and inertia. It then computes per-particle boundary volumes
        for static and moving rigid particles, but only when there is more than
        one rigid body.
        """
        self.ps.grid_prefix_sort()
        for rigid_id in self.ps.object_id_rigid_body:
            # The center of mass must exist before the inertia tensor, which is taken about it.
            self.compute_rigid_rest_cm(rigid_id)
            self.compute_inertia_tensor(rigid_id)
        # NOTE(review): threshold is "> 1" rigid bodies, not "> 0"; confirm intent.
        if self.ps.num_rigid_bodies <= 1:
            return
        self.compute_static_boundary_volume()
        self.compute_moving_boundary_volume()


    @ti.kernel
    def compute_rigid_rest_cm(self, object_id: int):
        """Initialize the rest and current center of mass of rigid body ``object_id``.

        ``compute_com`` sweeps all particles and also stores the body's total mass
        in ``self.ps.rigid_body_masses``. It is evaluated once and the result is
        written to both ``rigid_rest_cm`` and ``rigid_body_centers_of_mass``.

        Computing it once halves the work. It also guarantees both fields hold
        the same value; two parallel floating-point reductions may not agree
        bit-for-bit.
        """
        cm = self.compute_com(object_id)
        self.ps.rigid_rest_cm[object_id] = cm
        self.ps.rigid_body_centers_of_mass[object_id] = cm
    @ti.kernel
    def compute_static_boundary_volume(self):
        """Set ``m_V`` of every static rigid particle to 1 / (kernel sum over solid neighbors).

        The sum starts from the self contribution W(0). Solid neighbors are then
        added via ``compute_boundary_volume_task``.
        """
        for p_i in ti.grouped(self.ps.x):
            if self.ps.is_static_rigid_body(p_i):
                kernel_sum = self.cubic_kernel(0.0)
                self.ps.for_all_neighbors(p_i, self.compute_boundary_volume_task, kernel_sum)
                self.ps.m_V[p_i] = 1.0 / kernel_sum
    @ti.func
    def compute_boundary_volume_task(self, p_i, p_j, delta: ti.template()):
        """Neighbor callback: add W(|x_i - x_j|) to ``delta`` when neighbor ``p_j`` is solid."""
        if self.ps.material[p_j] == self.ps.material_solid:
            offset = self.ps.x[p_i] - self.ps.x[p_j]
            delta += self.cubic_kernel(offset.norm())
    @ti.kernel
    def compute_moving_boundary_volume(self):
        """Set ``m_V`` of every dynamic rigid particle to 1 / (kernel sum over solid neighbors).

        The sum starts from the self contribution W(0). Solid neighbors are then
        added via ``compute_boundary_volume_task``.
        """
        for p_i in ti.grouped(self.ps.x):
            if self.ps.is_dynamic_rigid_body(p_i):
                kernel_sum = self.cubic_kernel(0.0)
                self.ps.for_all_neighbors(p_i, self.compute_boundary_volume_task, kernel_sum)
                self.ps.m_V[p_i] = 1.0 / kernel_sum
    def substep(self):
        """Advance one solver step; a no-op in this class.

        NOTE(review): presumably overridden by a concrete solver subclass; confirm.
        """
        return None
    def substep_dfsph(self):
        """Advance one DFSPH step; a no-op in this class.

        NOTE(review): presumably overridden by a concrete solver subclass; confirm.
        """
        return None
    def dem_substep(self):
        """Advance one DEM step; a no-op in this class.

        NOTE(review): presumably overridden by a concrete solver subclass; confirm.
        """
        return None
    @ti.func
    def simulate_collisions(self, p_i, vec, boundary_condition):
        """Apply a wall response to the velocity of particle ``p_i`` along the unit normal ``vec``.

        The mode comes from ``self.boundary_condition``. The ``boundary_condition``
        argument is not read.
            0: remove the normal velocity component and reflect it with a fixed
               restitution of 0.10 (v -= (1 + 0.10) * (v . n) n).
            1: set the velocity to zero.
        """
        restitution = 0.10
        if self.boundary_condition == 0:
            normal_speed = self.ps.v[p_i].dot(vec)
            self.ps.v[p_i] -= (1 + restitution) * normal_speed * vec
        elif self.boundary_condition == 1:
            self.ps.v[p_i] = 0
    @ti.func
    def simulate_collisions_solid(self, p_i, vec, boundary_condition, object_id):
        """Accumulate a wall-contact force from particle ``p_i`` onto rigid body ``object_id``.

        The particle's velocity component along ``vec`` is cancelled over one time
        step, scaled by ``self.boundary_collision_factor``. The resulting
        acceleration times the particle mass ``density[p_i] * m_V0`` is added to
        ``self.ps.rigid_body_forces[object_id]``.

        Args:
            p_i: particle index.
            vec: collision normal (presumably unit length; TODO confirm at call sites).
            boundary_condition: unused; the mode is taken from ``self.boundary_condition``.
            object_id: rigid body that receives the force.
        """
        # Modes 0 and 1 use the same response. Both now use the per-particle
        # density; mode 1 used to index density by object_id by mistake.
        if ti.static(self.boundary_condition == 0 or self.boundary_condition == 1):
            acceleration = -(self.boundary_collision_factor) * self.ps.v[p_i].dot(vec) * vec / self.dt[None]
            force_j = -acceleration * self.ps.density[p_i] * self.ps.m_V0
            self.ps.rigid_body_forces[object_id] += force_j
            # NOTE(review): the previous version computed a torque
            # cross(x_i - center_of_mass, force_j) but never added it to
            # self.ps.rigid_body_torques. Confirm whether it should be applied.
    @ti.kernel
    def enforce_boundary_3D(self, particle_type: int):
        """Clamp dynamic particles of ``particle_type`` into the padded domain box.

        On each axis the allowed range is
        [domain_start + 2*padding, domain_start + domain_size - 2*padding].
        A particle outside it is moved onto the violated bound, and a +1/-1
        entry is added to an outward normal. If any bound was hit, the
        normalized normal is passed to ``simulate_collisions``.

        Args:
            particle_type: material id of the particles to constrain.
        """
        for p_i in ti.grouped(self.ps.x):
            if self.ps.material[p_i] == particle_type and self.ps.is_dynamic[p_i]:
                # Local copy: bound tests use the pre-clamp position on every axis.
                pos = self.ps.x[p_i]
                outward = ti.Vector([0.0, 0.0, 0.0])
                for axis in ti.static(range(3)):
                    if pos[axis] > self.ps.domain_start[axis] + self.ps.domain_size[axis] - self.ps.padding*2:
                        outward[axis] += 1.0
                        self.ps.x[p_i][axis] = self.ps.domain_start[axis] + self.ps.domain_size[axis] - self.ps.padding*2
                    if pos[axis] <= self.ps.domain_start[axis] + self.ps.padding*2:
                        outward[axis] -= 1.0
                        self.ps.x[p_i][axis] = self.ps.domain_start[axis] + self.ps.padding*2
                outward_len = outward.norm()
                if outward_len > 1e-6:
                    self.simulate_collisions(p_i, outward / outward_len, 0)
    
    
    @ti.kernel
    def enforce_boundary_3D_grain(self, particle_type: int):
        """Clamp granular particles into the padded domain box and apply the DEM wall response.

        Axes are handled in x, y, z order. On each axis, a particle above
        domain_start + domain_size - 2*padding is moved onto that bound and
        bounced with a +axis normal. Otherwise, a particle at or below
        domain_start + 2*padding is moved onto it and bounced with a -axis
        normal. A corner hit therefore applies up to three successive velocity
        responses.

        Args:
            particle_type: material id of the particles to constrain.
        """
        for p_i in ti.grouped(self.ps.x):
            if self.ps.material[p_i] == particle_type and self.ps.is_dynamic[p_i]:
                # Local copy: bound tests use the pre-clamp position on every axis.
                pos = self.ps.x[p_i]
                for axis in ti.static(range(3)):
                    if pos[axis] > self.ps.domain_start[axis] + self.ps.domain_size[axis] - self.ps.padding*2:
                        wall_normal = ti.Vector([0.0, 0.0, 0.0])
                        wall_normal[axis] = 1.0
                        self.ps.x[p_i][axis] = self.ps.domain_start[axis] + self.ps.domain_size[axis] - self.ps.padding*2
                        self.apply_boundary_collision_grain(p_i, wall_normal)
                    elif pos[axis] <= self.ps.domain_start[axis] + self.ps.padding*2:
                        wall_normal = ti.Vector([0.0, 0.0, 0.0])
                        wall_normal[axis] = -1.0
                        self.ps.x[p_i][axis] = self.ps.domain_start[axis] + self.ps.padding*2
                        self.apply_boundary_collision_grain(p_i, wall_normal)
    @ti.func
    def apply_boundary_collision_grain(self, p_i, normal):
        """DEM wall response for particle ``p_i`` against a wall with unit ``normal``.

        The normal velocity is reversed and scaled by
        ``dem_boundary_bouncing_coefficient``. The tangential velocity is scaled
        by (1 - ``dem_boundary_friction_coefficient``).
        """
        vel = self.ps.v[p_i]
        normal_part = normal * vel.dot(normal)
        tangent_part = vel - normal_part
        bounced = -self.dem_boundary_bouncing_coefficient * normal_part
        slowed = (1 - self.dem_boundary_friction_coefficient) * tangent_part
        self.ps.v[p_i] = bounced + slowed

    
    @ti.kernel
    def enforce_boundary_3D_rigid(self, particle_type: int):
        """Apply wall contact forces to rigid bodies whose particles leave the padded domain.

        For each dynamic particle of ``particle_type`` beyond a domain bound
        (inset by 2*padding), three things happen:
        - A reaction force is added to its body's ``rigid_body_forces`` and
          ``rigid_body_torques``. It is a penetration term plus a damping term
          on the body's normal velocity.
        - A friction force is added to the same fields. It is capped by
          ``rigid_friction_coefficient * |reaction|`` (Coulomb-style).
        - The particle itself is pushed back along the collision normal.

        Args:
            particle_type: material id of the particles to check.
        """
        for p_i in ti.grouped(self.ps.x):
            if self.ps.material[p_i] == particle_type and self.ps.is_dynamic[p_i]:
                pos = self.ps.x[p_i]
                object_id = self.ps.object_id[p_i]

                # Sum of per-axis outward unit normals (at most one entry per axis).
                collision_normal = ti.Vector([0.0, 0.0, 0.0])
                # NOTE(review): each later violated axis overwrites this value, so
                # for an edge/corner contact only the last axis's depth is kept.
                penetration_depth = 0.0
                if pos[0] > self.ps.domain_start[0] + self.ps.domain_size[0] - self.ps.padding * 2:
                    collision_normal[0] += 1.0
                    penetration_depth = pos[0] - (self.ps.domain_start[0] + self.ps.domain_size[0] - self.ps.padding * 2)
                elif pos[0] <= self.ps.domain_start[0] + self.ps.padding * 2:
                    collision_normal[0] -= 1.0
                    penetration_depth = (self.ps.domain_start[0] + self.ps.padding * 2) - pos[0]

                if pos[1] > self.ps.domain_start[1] + self.ps.domain_size[1] - self.ps.padding * 2:
                    collision_normal[1] += 1.0
                    penetration_depth = pos[1] - (self.ps.domain_start[1] + self.ps.domain_size[1] - self.ps.padding * 2)
                elif pos[1] <= self.ps.domain_start[1] + self.ps.padding * 2:
                    collision_normal[1] -= 1.0
                    penetration_depth = (self.ps.domain_start[1] + self.ps.padding * 2) - pos[1]

                if pos[2] > self.ps.domain_start[2] + self.ps.domain_size[2] - self.ps.padding * 2:
                    collision_normal[2] += 1.0
                    penetration_depth = pos[2] - (self.ps.domain_start[2] + self.ps.domain_size[2] - self.ps.padding * 2)
                elif pos[2] <= self.ps.domain_start[2] + self.ps.padding * 2:
                    collision_normal[2] -= 1.0
                    penetration_depth = (self.ps.domain_start[2] + self.ps.padding * 2) - pos[2]

                if collision_normal.norm() > 1e-6:
                    collision_normal = collision_normal.normalized()
                    # NOTE(review): the WHOLE body's mass is used for every contacting
                    # particle, so the total force grows with the number of particles
                    # outside the box.
                    mass = self.ps.rigid_body_masses[object_id]
                    # Body (not particle) velocity along the outward normal; positive = moving into the wall.
                    relative_velocity = self.ps.rigid_body_linear_velocities[object_id].dot(collision_normal)
                    # Inward force: penetration term scaled by mass/dt plus normal-velocity damping.
                    reaction_force = -collision_normal * (
                        self.rigid_bouncing_coefficient * penetration_depth * mass / self.dt[None] +
                        self.rigid_damping_coefficient * relative_velocity * mass
                    )
                    self.ps.rigid_body_forces[object_id] += reaction_force
                    # Lever arm from the body's center of mass to this contact particle.
                    relative_pos = pos - self.ps.rigid_body_centers_of_mass[object_id]
                    torque = relative_pos.cross(reaction_force)
                    self.ps.rigid_body_torques[object_id] += torque
                    linear_velocity = self.ps.rigid_body_linear_velocities[object_id]
                    angular_velocity = self.ps.rigid_body_angular_velocities[object_id]
                    # Velocity of the contact point, then its component tangent to the wall.
                    contact_velocity = linear_velocity + angular_velocity.cross(relative_pos)
                    tangential_velocity = contact_velocity - contact_velocity.dot(collision_normal) * collision_normal
                    tangential_speed = tangential_velocity.norm()
                    if tangential_speed > 1e-6:
                        tangential_dir = tangential_velocity.normalized()
                        # Friction capped by mu*|N| and by the force that would stop
                        # the tangential motion in one step.
                        max_friction = self.rigid_friction_coefficient * reaction_force.norm()
                        friction_force = -tangential_dir * min(max_friction, tangential_speed * mass / self.dt[None])
                        self.ps.rigid_body_forces[object_id] += friction_force
                        friction_torque = relative_pos.cross(friction_force)
                        self.ps.rigid_body_torques[object_id] += friction_torque
                    # NOTE(review): moves this single particle independently of its body;
                    # presumably positions are later rebuilt from body state. Confirm.
                    self.ps.x[p_i] -= collision_normal * penetration_depth

    @ti.func
    def compute_com(self, object_id):
        """Mass-weighted centroid of the active dynamic rigid particles of ``object_id``.

        Each particle's mass is ``m_V0 * density[p_i]``. As a side effect, the
        summed mass is stored in ``self.ps.rigid_body_masses[object_id]``. If no
        particle matches, the summed mass is 0 and the division yields
        non-finite values.
        """
        total_mass = 0.0
        weighted_pos = ti.Vector([0.0, 0.0, 0.0])
        for p_i in self.ps.x:
            if not (self.ps.active[p_i] and self.ps.is_dynamic_rigid_body(p_i) and self.ps.object_id[p_i] == object_id):
                continue
            m = self.ps.m_V0 * self.ps.density[p_i]
            weighted_pos += m * self.ps.x[p_i]
            total_mass += m
        com = weighted_pos / total_mass
        self.ps.rigid_body_masses[object_id] = total_mass
        return com
    
    @ti.func
    def compute_mass(self, object_id):
        """Total mass (sum of ``m_V0 * density``) of the active dynamic rigid particles of ``object_id``."""
        total_mass = 0.0
        for p_i in self.ps.x:
            if not (self.ps.active[p_i] and self.ps.is_dynamic_rigid_body(p_i) and self.ps.object_id[p_i] == object_id):
                continue
            total_mass += self.ps.m_V0 * self.ps.density[p_i]
        return total_mass
    @ti.kernel
    def compute_com_kernel(self, object_id: int):
        """Recompute and store the center of mass (and, via compute_com, the mass) of ``object_id``."""
        com = self.compute_com(object_id)
        self.ps.rigid_body_centers_of_mass[object_id] = com
    def compute_inertia_tensor(self, object_id: int):
        """Compute the inverse inertia tensor of rigid body ``object_id``.

        Python-scope wrapper around the ``compute_generic_inertia`` kernel, which
        writes the result to ``self.ps.inertia_tensor_inv``.
        """
        return self.compute_generic_inertia(object_id)
    @ti.kernel
    def compute_generic_inertia(self, object_id: int):
        """Compute and store the inverse inertia tensor of rigid body ``object_id``.

        Sums m * (|r|^2 I - r r^T) over the body's active dynamic rigid particles.
        Here r is measured from ``rigid_body_centers_of_mass[object_id]`` and
        m = m_V0 * density. The matrix inverse is written to
        ``self.ps.inertia_tensor_inv[object_id]``.

        NOTE(review): the velocity-update kernels in this file read
        ``self.ps.inertia_body_inv``, not ``inertia_tensor_inv``; confirm the two
        are linked elsewhere. The inverse is not guarded against a singular
        tensor (e.g. a single particle or collinear particles).
        """
        # Must already hold the body's center of mass (initialize() calls compute_rigid_rest_cm first).
        cm = self.ps.rigid_body_centers_of_mass[object_id]
        inertia_tensor = ti.Matrix.zero(ti.f32, 3, 3)
        # Parallel loop reducing into the kernel-local matrix.
        for p_i in self.ps.x:
            if self.ps.active[p_i] and self.ps.is_dynamic_rigid_body(p_i) and self.ps.object_id[p_i] == object_id:
                r = self.ps.x[p_i] - cm
                mass = self.ps.m_V0 * self.ps.density[p_i]
                # Point-mass inertia about the center of mass.
                inertia_tensor += mass * (r.norm_sqr() * ti.Matrix.identity(ti.f32, 3) - r.outer_product(r))
        self.ps.inertia_tensor_inv[object_id] = inertia_tensor.inverse()
    @ti.kernel
    def apply_forces_and_update_velocities(self, object_id: int):
        """Integrate one rigid body's velocities from its accumulated force and torque.

        Performs a symplectic-Euler velocity update:
            v += F / m * dt
            w += (R I_body^-1 R^T) tau * dt
        followed by a fixed 1% angular damping (w *= 0.99). The force and torque
        accumulators are then cleared.
        """
        dt = self.dt[None]
        force = self.ps.rigid_body_forces[object_id]
        torque = self.ps.rigid_body_torques[object_id]
        rot = quaternion_to_rotation_matrix(self.ps.quaternion[object_id])
        # Body-frame inverse inertia rotated into the world frame.
        inv_inertia_world = rot @ self.ps.inertia_body_inv[object_id] @ rot.transpose()
        new_linear = self.ps.rigid_body_linear_velocities[object_id] + force / self.ps.rigid_body_masses[object_id] * dt
        new_angular = self.ps.rigid_body_angular_velocities[object_id] + inv_inertia_world @ torque * dt
        self.ps.rigid_body_linear_velocities[object_id] = new_linear
        self.ps.rigid_body_angular_velocities[object_id] = new_angular * 0.99
        self.ps.rigid_body_forces[object_id].fill(0)
        self.ps.rigid_body_torques[object_id].fill(0)

    @ti.kernel
    def update_particle_positions(self, object_id: int):
        """Advance rigid body ``object_id``'s orientation and re-place its particles.

        The kernel:
        - integrates ``self.ps.quaternion[object_id]`` with the body's angular
          velocity and stores the rotation matrix;
        - exports the current center of mass, minus ``domain_translation_vector``,
          to ``centers_of_mass``;
        - sets every active dynamic particle of the body to
          R @ (x_0 - rigid_rest_cm) + rigid_rest_cm.

        NOTE(review): particles are rotated about the REST center of mass and no
        translation is applied. ``linear_velocity`` is read but unused, and the
        initial ``updated_cm = cm`` is overwritten. Confirm whether
        rigid_body_centers_of_mass was meant to be the placement origin.
        """
        cm = self.ps.rigid_body_centers_of_mass[object_id]
        # NOTE(review): unused below.
        linear_velocity = self.ps.rigid_body_linear_velocities[object_id]
        angular_velocity = self.ps.rigid_body_angular_velocities[object_id]
        updated_cm = cm
        self.ps.quaternion[object_id] = integrate_quaternion(self.ps.quaternion[object_id], angular_velocity, self.dt[None])
        rotation_matrix = quaternion_to_rotation_matrix(self.ps.quaternion[object_id])
        # Exported center of mass, shifted by the domain translation.
        self.ps.centers_of_mass[object_id] = cm - self.ps.domain_translation_vector
        # Overrides the assignment above: particles are placed around the rest center of mass.
        updated_cm = self.ps.rigid_rest_cm[object_id]
        self.ps.rotation_matrices[object_id] = rotation_matrix
        for p_i in self.ps.x:
            if self.ps.active[p_i] and self.ps.is_dynamic_rigid_body(p_i) and self.ps.object_id[p_i] == object_id:
                # Rest-frame offset of the particle from the rest center of mass.
                relative_pos = (self.ps.x_0[p_i] - self.ps.rigid_rest_cm[object_id])
                self.ps.x[p_i] = rotation_matrix @ relative_pos + updated_cm


    @ti.func
    def skew_symmetric(self, u: ti.types.vector(3, float)) -> ti.types.matrix(3, 3, float):
        """Cross-product matrix [u]_x, so that skew_symmetric(u) @ v == u.cross(v)."""
        ux, uy, uz = u[0], u[1], u[2]
        return ti.Matrix([
            [0, -uz, uy],
            [uz, 0, -ux],
            [-uy, ux, 0],
        ])
    
    @ti.kernel
    def update_rotation_quaternion(self, object_id: int):
        """Compose ``rotation_quaternions[object_id]`` with one time step of rotation.

        Builds the incremental rotation dq for angle |omega| * dt about
        omega / |omega|. For very small |omega| it uses a first-order
        approximation instead. The stored quaternion becomes the normalized
        product q * dq.

        NOTE(review): this path uses the (x, y, z, w) layout (identity is
        [0, 0, 0, 1]), the method ``self.quaternion_multiply`` and the
        ``rotation_quaternions`` field. The module-level helpers use (w, x, y, z)
        on ``self.ps.quaternion``. Right-multiplying by dq treats omega as
        body-frame; confirm that matches how omega is computed.
        """
        omega = self.ps.rigid_body_angular_velocities[object_id]
        dt = self.dt[None]
        omega_norm = omega.norm()
        dq = ti.Vector([0.0, 0.0, 0.0, 1.0])
        if omega_norm > 1e-6:
            # Exact axis-angle increment: (axis * sin(theta/2), cos(theta/2)).
            half_theta = 0.5 * omega_norm * dt
            sin_half_theta = ti.sin(half_theta)
            axis = omega / omega_norm
            dq = ti.Vector([
                axis[0] * sin_half_theta,
                axis[1] * sin_half_theta,
                axis[2] * sin_half_theta,
                ti.cos(half_theta)
            ])
        else:
            # First-order approximation for tiny angular speed.
            dq = ti.Vector([
                0.5 * omega[0] * dt,
                0.5 * omega[1] * dt,
                0.5 * omega[2] * dt,
                1.0
            ])
        dq = dq / dq.norm()
        q = self.ps.rotation_quaternions[object_id]
        q_new = self.quaternion_multiply(q, dq)
        # Renormalize to counter floating-point drift.
        q_new = q_new / q_new.norm()
        self.ps.rotation_quaternions[object_id] = q_new


    @ti.func
    def quaternion_to_rotation_matrix(self, q: ti.types.vector(4, float)) -> ti.types.matrix(3, 3, float):
        """Convert a unit quaternion stored as (x, y, z, w) (scalar LAST) into a 3x3 rotation matrix.

        Note: the module-level function of the same name expects (w, x, y, z).
        """
        x, y, z, w = q[0], q[1], q[2], q[3]
        return ti.Matrix([
            [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
            [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
            [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
        ])

    @ti.kernel
    def update_particle_positions_quaternions(self, object_id: int):
        """Rotate and translate rigid body ``object_id``'s particles using ``rotation_quaternions``.

        For each active dynamic particle of the body, the kernel rotates its
        offset from the stored center of mass by R(q), with q in (x, y, z, w)
        layout, and adds ``linear_velocity * dt``. It then writes
        ``cm - domain_translation_vector`` into both ``centers_of_mass`` and
        ``rigid_body_centers_of_mass``, and stores R.

        NOTE(review): the full rotation R is applied to the CURRENT offsets every
        call. If ``rotation_quaternions`` holds an accumulated orientation (as
        update_rotation_quaternion produces), the total rotation is re-applied
        each step. Also, the stored center of mass is not advanced by
        ``linear_velocity * dt`` although the particles are, and
        ``rigid_body_centers_of_mass`` is shifted by the domain translation.
        Confirm intent.
        """
        cm = self.ps.rigid_body_centers_of_mass[object_id]
        rotation_quaternion = self.ps.rotation_quaternions[object_id]
        rotation_matrix = self.quaternion_to_rotation_matrix(rotation_quaternion)
        linear_velocity = self.ps.rigid_body_linear_velocities[object_id]
        dt = self.dt[None]
        for p_i in self.ps.x:
            if self.ps.active[p_i] and self.ps.is_dynamic_rigid_body(p_i) and self.ps.object_id[p_i] == object_id:
                r = self.ps.x[p_i] - cm
                rotated_r = rotation_matrix @ r
                self.ps.x[p_i] = rotated_r + cm + linear_velocity * dt
        self.ps.centers_of_mass[object_id] = cm - self.ps.domain_translation_vector
        self.ps.rigid_body_centers_of_mass[object_id] = cm - self.ps.domain_translation_vector
        self.ps.rotation_matrices[object_id] = rotation_matrix

    @ti.func
    def quaternion_multiply(self, q1: ti.types.vector(4, float), q2: ti.types.vector(4, float)) -> ti.types.vector(4, float):
        """Hamilton product ``q1 * q2`` for quaternions stored as (x, y, z, w), scalar LAST.

        Note: the module-level ``quaternion_multiply`` uses the (w, x, y, z) layout.
        """
        ax, ay, az, aw = q1[0], q1[1], q1[2], q1[3]
        bx, by, bz, bw = q2[0], q2[1], q2[2], q2[3]
        rx = aw * bx + ax * bw + ay * bz - az * by
        ry = aw * by - ax * bz + ay * bw + az * bx
        rz = aw * bz + ax * by - ay * bx + az * bw
        rw = aw * bw - ax * bx - ay * by - az * bz
        return ti.Vector([rx, ry, rz, rw])

    @ti.func
    def angular_velocity_to_rotation_matrix(self, omega: ti.types.vector(3, float), dt: float) -> ti.types.matrix(3, 3, float):
        """Rotation matrix for turning at angular velocity ``omega`` for time ``dt``.

        Uses Rodrigues' formula R = I + sin(t) K + (1 - cos(t)) K^2, with
        t = |omega| * dt and K the cross-product matrix of the rotation axis.
        Returns the identity when t <= 1e-6.
        """
        angle = omega.norm() * dt
        rot = ti.Matrix.identity(ti.f32, 3)
        if angle > 1e-6:
            K = self.skew_symmetric(omega.normalized())
            rot = rot + ti.sin(angle) * K + (1 - ti.cos(angle)) * K @ K
        return rot


    @ti.kernel
    def compute_com_kernel_para(self):
        """Recompute the center of mass of every loaded rigid body, in parallel over bodies.

        A body whose accumulated force and torque both have exactly zero norm is
        skipped and keeps its stored center of mass.
        """
        for slot in self.ps.object_id_rigid_body_ti:
            body_id = self.ps.object_id_rigid_body_ti[slot]
            if self.ps.rigid_body_forces[body_id].norm() != 0 or self.ps.rigid_body_torques[body_id].norm() != 0:
                self.ps.rigid_body_centers_of_mass[body_id] = self.compute_com_para(body_id)
                
            
    @ti.func
    def compute_com_para(self, object_id):
        """Mass-weighted centroid of ``object_id``'s active dynamic rigid particles (index-range sweep).

        Iterates ``range(self.ps.total_particle_num)`` so it can run inside a loop
        that is already parallel over bodies. As a side effect, the summed mass
        is stored in ``self.ps.rigid_body_masses[object_id]``. If no particle
        matches, the division yields non-finite values.
        """
        total_mass = 0.0
        weighted_pos = ti.Vector([0.0, 0.0, 0.0])
        for p_i in range(self.ps.total_particle_num):
            if not (self.ps.active[p_i] and self.ps.is_dynamic_rigid_body(p_i) and self.ps.object_id[p_i] == object_id):
                continue
            m = self.ps.m_V0 * self.ps.density[p_i]
            weighted_pos += m * self.ps.x[p_i]
            total_mass += m
        com = weighted_pos / total_mass
        self.ps.rigid_body_masses[object_id] = total_mass
        return com

    @ti.kernel
    def apply_forces_and_update_velocities_kernel_para(self):
        """Integrate every rigid body's velocities from its accumulated force and torque (parallel over bodies).

        For each body:
            v += F / m * dt
            w += (R I_body^-1 R^T) tau * dt
        then w is damped by a fixed factor 0.98 and the force/torque accumulators
        are cleared.

        NOTE(review): the single-body variant uses 0.99 angular damping; confirm
        which value is intended.
        """
        for object_id_index in (self.ps.object_id_rigid_body_ti):
            object_id = self.ps.object_id_rigid_body_ti[object_id_index]
            total_force = self.ps.rigid_body_forces[object_id]
            total_torque = self.ps.rigid_body_torques[object_id]
            mass = self.ps.rigid_body_masses[object_id]
            
            # Body-frame inverse inertia rotated into the world frame (quaternion in (w, x, y, z) layout).
            rotation_matrix = quaternion_to_rotation_matrix(self.ps.quaternion[object_id])
            inertia_world_inv = rotation_matrix @ self.ps.inertia_body_inv[object_id] @ rotation_matrix.transpose()

            self.ps.rigid_body_linear_velocities[object_id] += total_force / mass * self.dt[None]
            self.ps.rigid_body_angular_velocities[object_id] += inertia_world_inv @ total_torque * self.dt[None]
            self.ps.rigid_body_angular_velocities[object_id] *= 0.98
            # Reset accumulators for the next step.
            self.ps.rigid_body_forces[object_id].fill(0)
            self.ps.rigid_body_torques[object_id].fill(0)

    @ti.kernel
    def update_particle_positions_kernel_para(self):
        """Advance every rigid body's orientation and center of mass by one step (parallel over bodies).

        For each body, the kernel integrates ``quaternion`` with the angular
        velocity and stores the matching rotation matrix. It moves
        ``rigid_body_centers_of_mass`` by ``linear_velocity * dt`` and exports
        ``centers_of_mass`` as the old center minus ``domain_translation_vector``
        plus that displacement.
        """
        for slot in self.ps.object_id_rigid_body_ti:
            body_id = self.ps.object_id_rigid_body_ti[slot]
            dt = self.dt[None]
            omega = self.ps.rigid_body_angular_velocities[body_id]
            self.ps.quaternion[body_id] = integrate_quaternion(self.ps.quaternion[body_id], omega, dt)
            rot = quaternion_to_rotation_matrix(self.ps.quaternion[body_id])

            com = self.ps.rigid_body_centers_of_mass[body_id]
            displacement = self.ps.rigid_body_linear_velocities[body_id] * dt
            self.ps.centers_of_mass[body_id] = com - self.ps.domain_translation_vector + displacement
            self.ps.rigid_body_centers_of_mass[body_id] = com + displacement

            self.ps.rotation_matrices[body_id] = rot

    
    @ti.kernel
    def update_positions_rigid(self):
        """Rebuild rigid-particle positions and velocities from rigid-body state.

        For each active dynamic rigid particle of body b:
            r_world = rotation_matrices[b] @ (x_0 - rigid_rest_cm[b])
            x = r_world + rigid_body_centers_of_mass[b]
            v = linear_velocity[b] + angular_velocity[b] x r_world

        The angular velocity is world-frame (the velocity kernels update it with
        the world-space inverse inertia). The lever arm in the velocity term
        must therefore be the rotated world offset r_world, not the unrotated
        rest offset.
        """
        for p_i in range(self.ps.total_particle_num):
            if self.ps.active[p_i] and self.ps.is_dynamic_rigid_body(p_i):
                object_id = self.ps.object_id[p_i]
                rest_offset = self.ps.x_0[p_i] - self.ps.rigid_rest_cm[object_id]
                world_offset = self.ps.rotation_matrices[object_id] @ rest_offset
                self.ps.x[p_i] = world_offset + self.ps.rigid_body_centers_of_mass[object_id]
                linear_velocity = self.ps.rigid_body_linear_velocities[object_id]
                angular_velocity = self.ps.rigid_body_angular_velocities[object_id]
                self.ps.v[p_i] = linear_velocity + angular_velocity.cross(world_offset)
                    
    @ti.kernel
    def advect_solid_gravity_kernel_para(self):
        """Add gravity (particle mass ``m`` times ``self.g``) of each dynamic solid particle to its body's force."""
        for p_i in ti.grouped(self.ps.x):
            if self.ps.is_dynamic_rigid_body(p_i):
                if self.ps.material[p_i] == self.ps.material_solid:
                    body_id = self.ps.object_id[p_i]
                    self.ps.rigid_body_forces[body_id] += self.ps.m[p_i] * ti.Vector(self.g)
    def solve_rigid_body_dynamics_para(self):
        """Run the rigid-body pipeline as four sequential kernel launches.

        The stages are called in this order: ``apply_forces_and_update_velocities_kernel_para``,
        ``update_particle_positions_kernel_para``, ``update_positions_rigid`` and
        ``advect_solid_gravity_kernel_para``.
        """
        pipeline = (
            self.apply_forces_and_update_velocities_kernel_para,
            self.update_particle_positions_kernel_para,
            self.update_positions_rigid,
            self.advect_solid_gravity_kernel_para,
        )
        for launch in pipeline:
            launch()
    
    def solve_center_of_mass(self):
        """Launch ``compute_com_kernel_para`` (presumably recomputes per-body centres of mass)."""
        kernel = self.compute_com_kernel_para
        kernel()

#--------------------------------------------------------------------------------------#
                
    @ti.kernel
    def advect_solid_gravity(self,target_object_id: int):
        """Accumulate gravity into the force of one rigid object.

        Only particles whose ``object_id`` equals ``target_object_id`` are visited:
          * static rigid particles have their acceleration filled with 0.0;
          * active, dynamic rigid particles of solid material add ``m * g`` to
            ``rigid_body_forces[target_object_id]``.
        """
        for p_i in ti.grouped(self.ps.x):
            owner = self.ps.object_id[p_i]
            if owner == target_object_id:
                if self.ps.is_static_rigid_body(p_i):
                    self.ps.acceleration[p_i].fill(0.0)
                elif self.ps.material[p_i] == self.ps.material_solid:
                    if self.ps.active[p_i] and self.ps.is_dynamic_rigid_body(p_i):
                        self.ps.rigid_body_forces[owner] += self.ps.m[p_i] * ti.Vector(self.g)
    @ti.kernel
    def advect_solid(self):
        """Explicit-Euler step for every particle of solid material.

        ``v_export`` is cleared first. Then for each solid particle the velocity grows
        by ``dt * acceleration``, that same increment is stored in ``v_export``, and the
        position advances by ``dt`` times the updated velocity.
        """
        self.ps.v_export.fill(0)
        for p_i in ti.grouped(self.ps.x):
            if self.ps.material[p_i] != self.ps.material_solid:
                continue
            step_dt = self.dt[None]
            delta_v = step_dt * self.ps.acceleration[p_i]
            # NOTE(review): v_export receives the velocity *increment*, not the velocity itself — confirm intent.
            self.ps.v_export[p_i] = delta_v
            self.ps.v[p_i] += delta_v
            self.ps.x[p_i] += step_dt * self.ps.v[p_i]

    
    @ti.func
    def handle_rigid_body_collision(self, p_i, p_j, ret: ti.template()):
        """Push rigid particles apart when a neighbour pair is closer than a threshold.

        Invoked per neighbour pair through ``self.ps.for_all_neighbors`` from
        ``compute_rigid_body_collisions``. Two cases are handled:

        * both particles are dynamic rigid and belong to different objects: ``p_i``
          is moved away from ``p_j`` by ``overlap * correction_strength`` along the
          line joining them, but only if that correction exceeds ``min_correction``;
        * one particle is dynamic rigid and the other static rigid: the dynamic one is
          moved away from the static one (no ``min_correction`` check here), and
          component 2 of its body's linear and angular velocity is multiplied by
          ``self.rigid_damping_coefficient``.

        Args:
            p_i: index of the particle whose neighbourhood is being visited.
            p_j: index of the neighbour particle.
            ret: accumulator threaded through ``for_all_neighbors``.
                NOTE(review): it is never written here, so the caller's accumulator
                stays unchanged — confirm that is intended.

        NOTE(review): this writes ``self.ps.x`` and per-object velocities while other
        threads of the enclosing parallel loop may be reading/writing the same
        entries; the ``*=`` on a shared velocity component is not an atomic op —
        results may be order dependent. Confirm.
        """
        # Positions are copied at entry; the dynamic-dynamic branch uses these copies.
        x_i = self.ps.x[p_i]
        x_j = self.ps.x[p_j]
        correction_strength = 0.01  # fraction of the overlap applied per pair visit
        min_correction = 1e-4  # dynamic-dynamic corrections at or below this length are skipped
        # Separation below which a pair is treated as overlapping (8 particle diameters).
        acceptable_overlap = 8*self.ps.particle_diameter
        if self.ps.is_dynamic_rigid_body(p_i) and self.ps.is_dynamic_rigid_body(p_j):
            if self.ps.object_id[p_i] != self.ps.object_id[p_j]:
                r_ij = x_j - x_i
                distance = r_ij.norm()
                 
                # distance > 1e-6 avoids normalizing a (near-)zero vector.
                if distance > 1e-6 and distance < acceptable_overlap:
                    overlap = acceptable_overlap - distance
                    correction_vector = (overlap * correction_strength) * r_ij.normalized()
                    if correction_vector.norm() > min_correction:
                        # Move p_i away from p_j; the symmetric update of p_j is disabled below.
                        self.ps.x[p_i] -= correction_vector * 1
                        # self.ps.x[p_j] += correction_vector * 1

        elif (self.ps.is_dynamic_rigid_body(p_i) and self.ps.is_static_rigid_body(p_j)) or \
            (self.ps.is_static_rigid_body(p_i) and self.ps.is_dynamic_rigid_body(p_j)):
            
            # Identify which side of the pair is dynamic and which is static.
            dynamic_p = p_i if self.ps.is_dynamic_rigid_body(p_i) else p_j
            static_p = p_j if self.ps.is_static_rigid_body(p_j) else p_i
            # r_ij points from the dynamic particle towards the static one.
            r_ij = self.ps.x[static_p] - self.ps.x[dynamic_p]
            distance = r_ij.norm()
            if distance > 1e-6 and distance < acceptable_overlap:
                object_dynamic = self.ps.object_id[dynamic_p]
                overlap = acceptable_overlap - distance
                correction_vector = (overlap * correction_strength) * r_ij.normalized()
                # Push the dynamic particle away from the static one.
                self.ps.x[dynamic_p] -= correction_vector 
                # Damp velocity component index 2 of the whole dynamic body
                # (presumably the vertical/z axis — confirm).
                self.ps.rigid_body_linear_velocities[object_dynamic][2] *= self.rigid_damping_coefficient
                self.ps.rigid_body_angular_velocities[object_dynamic][2] *= self.rigid_damping_coefficient

    @ti.kernel
    def compute_rigid_body_collisions(self):
        """Visit the neighbours of every dynamic rigid particle with the collision handler.

        A zero vector of length ``self.ps.dim`` is passed to
        ``self.ps.for_all_neighbors`` together with ``handle_rigid_body_collision``.
        Afterwards that vector is added to the particle's acceleration.
        """
        for p_i in ti.grouped(self.ps.x):
            if not self.ps.is_dynamic_rigid_body(p_i):
                continue
            accum = ti.Vector([0.0 for _ in range(self.ps.dim)])
            self.ps.for_all_neighbors(p_i, self.handle_rigid_body_collision, accum)
            self.ps.acceleration[p_i] += accum

                
    @ti.kernel
    def compute_velocity_statistics(self):
        """Compute mean and per-component standard deviation of fluid particle velocities.

        Only particles whose ``material`` equals ``material_fluid`` contribute. Results are
        written to ``self.mean_velocity[None]`` and ``self.std_dev[None]``; the std is the
        population std (divides by the number of contributing particles).

        The denominator is the number of fluid particles actually visited. The previous
        version divided by ``fluid_particle_num + 1``, which biased both statistics.
        ``max(count, 1)`` still guards against division by zero when there is no fluid.
        """
        total_velocity = ti.Vector([0.0, 0.0, 0.0])
        fluid_count = 0
        for p_i in ti.grouped(self.ps.x):
            if self.ps.material[p_i] == self.ps.material_fluid:
                # Augmented assignment in a parallel loop is an atomic reduction in Taichi.
                total_velocity += self.ps.v[p_i]
                fluid_count += 1
        denom = ti.max(fluid_count, 1)
        mean_velocity = total_velocity / denom
        self.mean_velocity[None] = mean_velocity
        variance = ti.Vector([0.0, 0.0, 0.0])
        for p_i in ti.grouped(self.ps.x):
            if self.ps.material[p_i] == self.ps.material_fluid:
                diff = self.ps.v[p_i] - mean_velocity
                variance += diff * diff  # element-wise square
        self.std_dev[None] = ti.sqrt(variance / denom)

#--------------------------------------------------------------------------------------#
    def step(self, current_frame: int, cnt: int, output_interval: int):
        """Advance the simulation by one step and prepare output buffers.

        Sequence: neighbour-grid sort, ``substep``, then (when more than one rigid body
        exists) the rigid-body pipeline and rigid boundary enforcement, then boundary
        enforcement for fluid and granular materials if present, and finally
        ``translate_and_copy_for_output``.

        ``current_frame``, ``cnt`` and ``output_interval`` are accepted but not used here.

        NOTE(review): rigid dynamics only run when ``num_rigid_bodies > 1``, whereas the
        fluid/grain branches use ``> 0`` — confirm a single dynamic rigid body is not
        meant to be simulated.
        """
        self.ps.grid_prefix_sort()
        self.substep()

        if self.ps.num_rigid_bodies > 1:
            rigid_stages = (
                self.compute_moving_boundary_volume,
                self.compute_rigid_body_collisions,
                self.solve_center_of_mass,
                self.solve_rigid_body_dynamics_para,
            )
            for stage in rigid_stages:
                stage()
            self.enforce_boundary_3D_rigid(self.ps.material_solid)

        if self.ps.num_fluid_blocks > 0:
            self.enforce_boundary_3D(self.ps.material_fluid)

        if self.ps.num_grain_objects > 0:
            self.enforce_boundary_3D_grain(self.ps.material_grain)

        self.ps.translate_and_copy_for_output()
        
