# Doriflow Engine - Fluid Simulation for Blender 3D
# Copyright (C) 2024 Doriflow Team
# This software is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

# You are free to:
# -Share: Copy and redistribute the material in any medium or format.
# -Adapt: Remix, transform, and build upon the material.

# UNDER THE FOLLOWING TERMS:
# -Attribution: You must give appropriate credit, provide a link to the license, and indicate if changes were made.
# -Appropriate credit should include the following:
#   -The original author's name: Doriflow Team
#   -A link to the original source (if applicable).
#   -A link to the full license: https://creativecommons.org/licenses/by-nc/4.0/.
#   -A clear indication of any changes made, such as: "This material has been modified."
# NonCommercial: You may not use the material for commercial purposes.
# Disclaimer:
# -This simulation engine is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. In no event shall the authors be liable for any claim, damages, or other liability arising from the use of this software.

# For more details, refer to the full license text at:
# https://creativecommons.org/licenses/by-nc/4.0/.

#-----------------------------------------------------------------------------------------------------------------------#
import json
import taichi as ti
import numpy as np
import trimesh as tm
from config_builder import SimConfig
from functools import reduce
from scipy.spatial import cKDTree
import os
import scipy

def calculate_grid_properties(largest_domain_size, resolution):
    """Derive the particle spacing from the domain extent and resolution.

    Returns a ``(diameter, radius)`` tuple: the diameter is
    ``largest_domain_size / resolution`` (not rounded) and the radius is half
    of that diameter rounded to four decimal places.
    """
    diameter = largest_domain_size / resolution
    half_diameter = diameter / 2
    return diameter, round(half_diameter, 4)

@ti.data_oriented
class ParticleSystem:
    def __init__(self, config: SimConfig):
        """Read the simulation configuration and allocate per-object state.

        Derives the particle spacing from the domain size and resolution,
        counts the configured objects, and allocates one Taichi field slot per
        object for bounding boxes, rotations and rigid-body dynamics.

        Args:
            config: Simulation configuration, queried through ``get_domain``,
                ``get_fluid_blocks``, ``get_rigid_bodies``,
                ``get_inlet_objects`` and ``get_grain_objects``.

        Raises:
            RuntimeError: If more than two fluid, rigid or grain objects are
                configured (demo-version limit).
        """
        self.cfg = config
        self.domain_translation_vector = self.cfg.get_domain("translation_vector")
        self.operating_system = self.cfg.get_domain("operating_system")
        self.domain_start_for_dump = np.round(np.array(self.cfg.get_domain("domainStart_1")), 3)
        domain_shape = self.cfg.get_domain()
        # load_domain_shape is defined outside this view; presumably it sets
        # self.domain_size, which is read right below -- TODO confirm.
        self.load_domain_shape(domain_shape)
        self.simulation_method = config.get_domain("simulationMethod")

        # Spatial discretisation: one grid cell spans four particle radii
        # (two particle diameters); the support radius equals the cell size
        # up to the rounding applied to the radius.
        domain_resolution = self.cfg.get_domain("resolution")
        largest_domain_size = max(self.domain_size)
        self.grid_size = largest_domain_size / domain_resolution
        self.particle_radius = round(self.grid_size / 4, 4)
        self.particle_diameter = 2 * self.particle_radius
        self.grid_num = np.ceil(self.domain_size / self.grid_size).astype(int)
        self.support_radius = 4 * self.particle_radius
        self.padding = round(self.grid_size, 4)
        self.max_curvature = ti.field(float, shape=())
        self.max_curvature[None] = 0
        self.dim = 3
        assert self.dim > 1
        # Rest volume of one particle: a cube with edge = particle diameter.
        self.m_V0 = self.particle_diameter ** self.dim
        self.particle_num = ti.field(int, shape=())
        self.surface_particle_num = ti.field(int, shape=())
        self.found_empty_grid = ti.field(int, shape=())

        # Material codes stored per particle / per object.
        self.material_solid = 0
        self.material_fluid = 1
        self.material_grain = 4

        # Bookkeeping filled in later by the cal_* / add_* methods.
        self.object_collection = dict()
        self.object_id_rigid_body = set()
        self.object_id_fluid_body = set()
        self.object_id_inlet_body = set()
        self.object_id_grain_body = set()
        self.fluid_particle_num = 0
        self.rigid_particle_num = 0
        self.gas_particle_num = 0
        self.grain_particle_num = 0
        fluid_blocks = self.cfg.get_fluid_blocks()
        rigid_bodies = self.cfg.get_rigid_bodies()
        self.inlet_particle_num = 0
        inlet_objects = self.cfg.get_inlet_objects()
        self.rigid_points = np.empty((0, 3))
        self.grain_points = np.empty((0, 3))

        # NOTE(review): num_rigid_bodies is one more than the number of
        # configured rigid bodies; the reason for the extra slot is not
        # visible here -- confirm before changing.
        self.num_rigid_bodies = len(rigid_bodies) + 1
        self.num_fluid_blocks = len(fluid_blocks)
        self.num_inlets = len(inlet_objects)
        self.num_grain_objects = len(self.cfg.get_grain_objects())
        self.object_id_rigid_body_ti = ti.field(dtype=ti.i32, shape=self.num_rigid_bodies)
        self.motion_data = {}
        self.total_objects = max(
            self.num_rigid_bodies + self.num_fluid_blocks + self.num_inlets + self.num_grain_objects,
            1,
        )
        total_objects = self.total_objects

        # Per-object fields. Each one is allocated exactly once; the earlier
        # revision allocated a dozen of them twice with identical shape and
        # dtype, leaving the first allocation unused.
        self.aabb_min_0 = ti.Vector.field(3, dtype=ti.f32, shape=(total_objects))
        self.aabb_max_0 = ti.Vector.field(3, dtype=ti.f32, shape=(total_objects))
        self.aabb_min = ti.Vector.field(3, dtype=ti.f32, shape=(total_objects))
        self.aabb_max = ti.Vector.field(3, dtype=ti.f32, shape=(total_objects))
        self.rotation_quaternions = ti.Vector.field(4, dtype=ti.f32, shape=(total_objects))
        self.initialize_rotation_quaternions()
        self.quaternion_rotations = ti.Vector.field(4, float, shape=(total_objects))

        self.rotation_matrices = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(total_objects))
        self.centers_of_mass = ti.Vector.field(3, dtype=ti.f32, shape=(total_objects))
        self.rigid_body_angular_velocities = ti.Vector.field(3, dtype=ti.f32, shape=(total_objects))
        self.torque = ti.Vector.field(3, dtype=ti.f32, shape=(total_objects))
        self.inertia_body = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(total_objects))
        self.inertia_body_inv = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(total_objects))
        self.inertia_tensor_inv = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(total_objects))
        self.total_time = 0.0
        self.rigid_body_masses = ti.field(dtype=float, shape=(total_objects))
        self.rigid_density = ti.field(dtype=float, shape=(total_objects))
        self.rigid_body_total_forces = ti.Vector.field(3, dtype=ti.f32, shape=(total_objects))
        self.rigid_body_linear_velocities = ti.Vector.field(3, dtype=ti.f32, shape=(total_objects))
        self.rigid_search_dist = ti.field(dtype=float, shape=(total_objects))
        self.rigid_body_angular_momentum = ti.Vector.field(3, dtype=ti.f32, shape=(total_objects))
        self.quaternion = ti.Vector.field(4, float, shape=(total_objects))
        self.object_materials = ti.field(dtype=int, shape=(total_objects))

        # Grain settings; "radius_std" is divided by 100 (presumably a
        # percentage in the configuration -- TODO confirm).
        self.radius_variation = self.cfg.get_domain("radius_std") / 100
        self.dem_max_neighbours = self.cfg.get_domain("dem_max_neighbours")
        self.grain_density = self.cfg.get_domain("grain_density")

        # Demo-version limits (the rigid count excludes the extra slot).
        if self.num_fluid_blocks > 2:
            raise RuntimeError("Doriflow Demo version only allows up to 2 Fluid objects. Please reduce the number or upgrade to Doriflow Full version.")
        if self.num_rigid_bodies - 1 > 2:
            raise RuntimeError("Doriflow Demo version only allows up to 2 Rigid (Obstacle) objects. Please reduce the number or upgrade to Doriflow Full version.")
        if self.num_grain_objects > 2:
            raise RuntimeError("Doriflow Demo version only allows up to 2 Grain objects. Please reduce the number or upgrade to Doriflow Full version.")

    @ti.kernel
    def initialize_rotation_quaternions(self):
        # Set every entry of rotation_quaternions to (0, 0, 0, 1).
        # NOTE(review): presumably the identity rotation in (x, y, z, w)
        # order -- confirm against the code that consumes this field.
        for i in range(self.total_objects):
            self.rotation_quaternions[i] = ti.Vector([0.0, 0.0, 0.0, 1.0])
    
    def load_voxelized_points_from_txt(self, filepath):
        """Load comma-separated 3D points from a text file.

        Every line is expected to hold exactly three comma-separated floats.
        Lines that fail to parse, or that hold a different number of values,
        are reported on stdout and skipped. When
        ``self.domain_translation_vector`` is not None it is added to every
        point.

        Args:
            filepath: Path of the text file to read.

        Returns:
            A float32 array of shape ``(n, 3)``. The shape is ``(0, 3)`` when
            the file contains no valid line, so callers can stack or index
            the result without special-casing empty input.
        """
        points = []
        with open(filepath, 'r') as f:
            for line in f:
                stripped = line.strip()
                try:
                    values = [float(token) for token in stripped.split(',')]
                except ValueError as e:
                    print(f"Skipping malformed line: {stripped} - Error: {e}")
                    continue
                if len(values) == 3:
                    points.append(values)
                else:
                    print(f"Skipping line with unexpected number of values: {stripped}")

        # Accumulate in float64 and cast once at the end, which reproduces the
        # rounding of the former per-point addition while doing the work in a
        # single vectorised operation.
        points_array = np.asarray(points, dtype=np.float64).reshape(-1, 3)
        if self.domain_translation_vector is not None:
            print("domain_translation_vector", self.domain_translation_vector)
            translation_vector = np.array(self.domain_translation_vector, dtype=np.float32)
            points_array = points_array + translation_vector
        return points_array.astype(np.float32)

    def cal_fluid_num(self):
        """Voxelize every fluid block and record its particle count.

        For each configured fluid block this stores ``"particleNum"`` and
        ``"voxelizedPoints"`` on the block dict, registers the block in
        ``object_collection``, adds its particles to ``fluid_particle_num``
        and writes its material code to ``object_materials``.

        Raises:
            ValueError: If a block's ``"material"`` is neither ``"LIQUID"``
                nor ``"GAS"``. Previously this either raised an
                UnboundLocalError or silently reused the material of the
                preceding block.
        """
        material_ids = {"LIQUID": self.material_fluid, "GAS": 3}
        for fluid in self.cfg.get_fluid_blocks():
            obj_id = fluid["objectId"]
            material = fluid["material"]
            if material not in material_ids:
                raise ValueError(
                    f"Unsupported fluid material {material!r} for object {obj_id}; "
                    "expected 'LIQUID' or 'GAS'."
                )
            voxelized_points_np_f = self.load_fluid_body(fluid)
            particle_count = voxelized_points_np_f.shape[0]
            fluid["particleNum"] = particle_count
            fluid["voxelizedPoints"] = voxelized_points_np_f
            self.object_collection[obj_id] = fluid
            self.fluid_particle_num += particle_count
            self.object_materials[obj_id] = material_ids[material]
    def cal_rigid_num(self):
        """Voxelize every rigid body and record its particle count.

        Stores ``"particleNum"`` and ``"voxelizedPoints"`` on each body dict,
        registers the body in ``object_collection``, accumulates
        ``rigid_particle_num``, writes the solid material code and the body's
        density into the per-object fields, and finally gathers all points
        into ``rigid_points``.
        """
        point_sets = []
        for body in self.cfg.get_rigid_bodies():
            body_id = body["objectId"]
            points = self.load_rigid_body(body)
            count = points.shape[0]
            body["particleNum"] = count
            body["voxelizedPoints"] = points
            point_sets.append(points)
            self.object_collection[body_id] = body
            self.rigid_particle_num += count
            self.object_materials[body_id] = self.material_solid
            self.rigid_density[body_id] = body["density"]
        if point_sets:
            self.rigid_points = np.vstack(point_sets)
        else:
            self.rigid_points = np.empty((0, 3))
        
    
    def cal_grain_num(self):
        """Voxelize every grain object and record its particle count.

        Stores ``"particleNum"``, ``"voxelizedPoints"`` and ``"radii"`` on
        each grain dict, registers the grain in ``object_collection``,
        accumulates ``grain_particle_num``, writes the grain material code to
        ``object_materials`` and gathers all points into ``grain_points``.
        """
        point_sets = []
        for grain in self.cfg.get_grain_objects():
            grain_id = grain["objectId"]
            points, radii = self.load_grain_body(grain)
            count = points.shape[0]
            grain["particleNum"] = count
            grain["voxelizedPoints"] = points
            grain["radii"] = radii
            point_sets.append(points)
            self.object_collection[grain_id] = grain
            self.grain_particle_num += count
            self.object_materials[grain_id] = self.material_grain
        if point_sets:
            self.grain_points = np.vstack(point_sets)
        else:
            self.grain_points = np.empty((0, 3))
        print("self.grain_points", self.grain_points.shape)
    
    def cal_total_particle_num(self):
        """Store the combined fluid, rigid and grain particle count in ``total_particle_num``."""
        fluid_count = self.fluid_particle_num
        rigid_count = self.rigid_particle_num
        grain_count = self.grain_particle_num
        self.total_particle_num = fluid_count + rigid_count + grain_count
    def array_setup(self):
        """Allocate the per-particle, per-cell and per-rigid-body Taichi fields.

        Reads ``total_particle_num`` (assigned by ``cal_total_particle_num``),
        ``grid_num``, ``dim``, ``dem_max_neighbours`` and the object counts
        set in ``__init__``. Every field is allocated exactly once; the
        earlier revision allocated ``density_adv`` and ``pressure`` twice.
        """
        num_particles = self.total_particle_num
        dim = self.dim
        bond_shape = (num_particles, self.dem_max_neighbours)
        rigid_slots = max(self.num_rigid_bodies, 1)

        # Only allocated when at least one rigid body is configured
        # (num_rigid_bodies carries one extra slot).
        if self.num_rigid_bodies > 1:
            self.rigid_rest_cm = ti.Vector.field(dim, dtype=float, shape=self.num_fluid_blocks + self.num_rigid_bodies)

        # Background grid: one counter per cell.
        self.n_grids = int(self.grid_num[0] * self.grid_num[1] * self.grid_num[2])
        self.grid_particles_num = ti.field(int, shape=self.n_grids)
        self.grid_particles_num_temp = ti.field(int, shape=self.n_grids)

        # Forces, ids and DEM bond state (bond fields are indexed by
        # [particle, neighbour slot]).
        self.force = ti.Vector.field(dim, dtype=ti.f32, shape=num_particles)
        self.particle_id = ti.field(dtype=ti.i32, shape=num_particles)
        self.omega = ti.Vector.field(3, dtype=ti.f32, shape=num_particles)
        self.bond_force = ti.Vector.field(3, dtype=ti.f32, shape=bond_shape)
        self.bond_moment = ti.Vector.field(3, dtype=ti.f32, shape=bond_shape)
        self.bond_active = ti.field(dtype=ti.i32, shape=bond_shape)
        self.bond_shear_disp = ti.field(dtype=ti.f32, shape=bond_shape)
        self.dem_torque = ti.Vector.field(3, dtype=ti.f32, shape=num_particles)
        self.bond_length = ti.field(dtype=ti.f32, shape=bond_shape)

        # Grid occupancy statistics and the prefix-sum helper sized to the
        # number of grid cells.
        self.max_particle_count = ti.field(dtype=ti.i32, shape=())
        self.min_particle_count = ti.field(dtype=ti.i32, shape=())
        self.max_particle_grid_index = ti.field(dtype=ti.i32, shape=())
        self.min_particle_grid_index = ti.field(dtype=ti.i32, shape=())
        self.prefix_sum_executor = ti.algorithms.PrefixSumExecutor(self.grid_particles_num.shape[0])

        # Solver scratch fields.
        self.density_adv = ti.field(dtype=ti.f32, shape=num_particles)
        self.dfsph_factor = ti.field(dtype=ti.f32, shape=num_particles)
        self.predicted_v = ti.Vector.field(dim, dtype=ti.f32, shape=num_particles)
        self.divergence_pressure = ti.field(dtype=ti.f32, shape=num_particles)

        # Primary per-particle state.
        self.object_id = ti.field(dtype=int, shape=num_particles)
        self.active = ti.field(dtype=int, shape=num_particles)
        self.active_buffer = ti.field(dtype=int, shape=num_particles)
        self.current_emitted_particle_num = ti.field(dtype=int, shape=())
        self.x = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.total_correction = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.x_output = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.x_translated = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.x_0 = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.v = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.v_export = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.r = ti.field(dtype=ti.f32, shape=num_particles)
        self.max_velocity_field = ti.field(dtype=ti.f32, shape=())
        self.acceleration = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.pressure_force = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.viscous_force = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.voxelized_points = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.m_V = ti.field(dtype=float, shape=num_particles)
        self.m = ti.field(dtype=float, shape=num_particles)
        self.density = ti.field(dtype=float, shape=num_particles)
        # Kept with dtype=float: this is the allocation that used to win
        # over the earlier ti.f32 duplicate.
        self.pressure = ti.field(dtype=float, shape=num_particles)
        self.material = ti.field(dtype=int, shape=num_particles)
        self.is_dynamic = ti.field(dtype=int, shape=num_particles)
        self.is_isolated = ti.field(dtype=int, shape=num_particles)

        # Per-rigid-body accumulators.
        # NOTE(review): these have num_rigid_bodies slots but
        # compute_com_kernel indexes rigid_body_centers_of_mass by object id
        # -- confirm that rigid object ids always fit.
        self.rigid_body_torques = ti.Vector.field(dim, dtype=float, shape=(rigid_slots))
        self.rigid_body_forces = ti.Vector.field(dim, dtype=float, shape=(rigid_slots))
        self.rigid_body_centers_of_mass = ti.Vector.field(dim, dtype=float, shape=(rigid_slots))

        # Surface-related quantities.
        self.normals = ti.Vector.field(3, dtype=ti.f32, shape=num_particles)
        self.normal_magnitude = ti.field(dtype=ti.f32, shape=num_particles)
        self.Iwc_i = ti.field(dtype=ti.f32, shape=num_particles)
        self.curvature = ti.field(dtype=ti.f32, shape=num_particles)
        self.rho_grad = ti.Vector.field(3, dtype=ti.f32, shape=num_particles)

        # "_buffer" twins of the per-particle state.
        self.object_id_buffer = ti.field(dtype=int, shape=num_particles)
        self.x_buffer = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.x_0_buffer = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.v_buffer = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.r_buffer = ti.field(dtype=float, shape=num_particles)
        self.acceleration_buffer = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.pressure_force_buffer = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.viscous_force_buffer = ti.Vector.field(dim, dtype=float, shape=num_particles)
        self.m_V_buffer = ti.field(dtype=float, shape=num_particles)
        self.m_buffer = ti.field(dtype=float, shape=num_particles)
        self.density_buffer = ti.field(dtype=float, shape=num_particles)
        self.pressure_buffer = ti.field(dtype=float, shape=num_particles)
        self.material_buffer = ti.field(dtype=int, shape=num_particles)
        self.is_dynamic_buffer = ti.field(dtype=int, shape=num_particles)

        # Grid cell index of every particle.
        self.grid_ids = ti.field(int, shape=num_particles)
        self.grid_ids_buffer = ti.field(int, shape=num_particles)
        self.grid_ids_new = ti.field(int, shape=num_particles)

        # Global force totals, explicitly zeroed.
        self.pressure_force_total = ti.Vector.field(3, dtype=ti.f32, shape=())
        self.pressure_force_total[None] = ti.Vector([0, 0, 0])
        self.viscous_force_total = ti.Vector.field(3, dtype=ti.f32, shape=())
        self.viscous_force_total[None] = ti.Vector([0, 0, 0])
    @ti.kernel
    def translate_and_copy_for_output(self):
        # Write x minus the domain translation vector into x_output for every
        # particle slot, leaving x itself untouched.
        # NOTE(review): load_voxelized_points_from_txt treats
        # domain_translation_vector as possibly None, but there is no such
        # guard here -- confirm it is always set when this kernel runs.
        for i in range(self.total_particle_num):
            self.x_output[i] = self.x[i] - self.domain_translation_vector
    def add_fluid_particles(self):
        """Create the particles of every configured fluid block.

        Expects each block dict to already carry ``"particleNum"`` and
        ``"voxelizedPoints"`` (written by ``cal_fluid_num``). The block id is
        recorded in ``object_id_fluid_body`` and the particles are appended
        through ``add_particles``. Blocks whose ``"isDynamic"`` is falsy start
        with zero velocity.

        Raises:
            ValueError: If a block's ``"material"`` is neither ``"LIQUID"``
                nor ``"GAS"``. Previously this either raised an
                UnboundLocalError or silently reused the material of the
                preceding block.
        """
        material_ids = {"LIQUID": self.material_fluid, "GAS": 3}
        for fluid in self.cfg.get_fluid_blocks():
            obj_id = fluid["objectId"]
            print("fluid obj id", obj_id)
            material = fluid["material"]
            if material not in material_ids:
                raise ValueError(
                    f"Unsupported fluid material {material!r} for object {obj_id}; "
                    "expected 'LIQUID' or 'GAS'."
                )
            material_value = material_ids[material]
            self.object_id_fluid_body.add(obj_id)
            num_particles_obj = fluid["particleNum"]
            voxelized_points_np = fluid["voxelizedPoints"]
            # No scratch ti.field is allocated here any more: the former one
            # was never read and Taichi does not release such fields.
            is_dynamic = fluid["isDynamic"]
            if is_dynamic:
                velocity = np.array(fluid["velocity"], dtype=np.float32)
            else:
                velocity = np.zeros(self.dim, dtype=np.float32)
            radius = np.float32(self.particle_radius)
            density = np.float32(fluid["density"])
            self.add_particles(obj_id,
                               num_particles_obj,
                               np.array(voxelized_points_np, dtype=np.float32),
                               np.stack([velocity for _ in range(num_particles_obj)]),
                               radius * np.ones(num_particles_obj, dtype=np.float32),
                               density * np.ones(num_particles_obj, dtype=np.float32),
                               np.zeros(num_particles_obj, dtype=np.float32),
                               np.full(num_particles_obj, material_value, dtype=np.int32),
                               is_dynamic * np.ones(num_particles_obj, dtype=np.int32))
    
    def add_grain_particles(self):
        """Create the particles of every configured grain object.

        Expects each grain dict to already carry ``"particleNum"``,
        ``"voxelizedPoints"`` and ``"radii"`` (written by ``cal_grain_num``).
        The grain id is recorded in ``object_id_grain_body`` and the
        particles are appended through ``add_particles`` with the grain
        material code and ``grain_density``. Grains whose ``"isDynamic"`` is
        falsy start with zero velocity.
        """
        for grain in self.cfg.get_grain_objects():
            obj_id = grain["objectId"]
            self.object_id_grain_body.add(obj_id)
            num_particles_obj = grain["particleNum"]
            voxelized_points_np = grain["voxelizedPoints"]
            is_dynamic = grain["isDynamic"]
            if is_dynamic:
                velocity = np.array(grain["velocity"], dtype=np.float32)
            else:
                velocity = np.zeros(self.dim, dtype=np.float32)
            radii = np.array(grain["radii"], dtype=np.float32)
            density = np.float32(self.grain_density)
            # `velocity` is already float32, so the stacked array is float32
            # without np.stack's `dtype=` keyword, which only exists in
            # NumPy >= 1.24. This also matches the fluid and rigid methods.
            velocities = np.stack([velocity for _ in range(num_particles_obj)])
            self.add_particles(obj_id,
                               num_particles_obj,
                               np.array(voxelized_points_np, dtype=np.float32),
                               velocities,
                               radii,
                               density * np.ones(num_particles_obj, dtype=np.float32),
                               np.zeros(num_particles_obj, dtype=np.float32),
                               np.full(num_particles_obj, self.material_grain, dtype=np.int32),
                               is_dynamic * np.ones(num_particles_obj, dtype=np.int32))
    def add_inlet_particles(self):
        """Register the id of every inlet object in ``object_id_inlet_body``.

        The initial velocity of each inlet is resolved too (zero when
        ``"isDynamic"`` is falsy), but it is not used for anything inside
        this method and no particles are created here.
        """
        for inlet in self.cfg.get_inlet_objects():
            inlet_id = inlet["objectId"]
            self.object_id_inlet_body.add(inlet_id)
            if inlet["isDynamic"]:
                velocity = np.array(inlet["velocity"], dtype=np.float32)
            else:
                velocity = np.zeros(self.dim, dtype=np.float32)
    
    @ti.func
    def compute_com(self, object_id):
        # Mass-weighted centre of the active particles that belong to
        # `object_id` and that is_dynamic_rigid_body() accepts. Also stores
        # the summed mass in rigid_body_masses[object_id]. Returns the centre
        # as a 3-vector.
        sum_m = 0.0
        cm = ti.Vector([0.0, 0.0, 0.0])
        for p_i in self.x:
            if self.active[p_i] and self.is_dynamic_rigid_body(p_i) and self.object_id[p_i] == object_id:
                # Particle mass = rest volume (m_V0) * per-particle density.
                mass = self.m_V0 * self.density[p_i] 
                cm += mass * self.x[p_i]
                sum_m += mass
        # NOTE(review): sum_m stays 0 when no particle matched, which makes
        # this a division by zero. add_rigid_particles runs this for every
        # rigid body regardless of "isDynamic" -- confirm whether
        # is_dynamic_rigid_body (not visible here) can reject all of a body's
        # particles.
        cm /= sum_m
        self.rigid_body_masses[object_id] = sum_m
        return cm
    
    @ti.kernel
    def compute_com_kernel(self, object_id: int):
        # Kernel entry point for compute_com: stores the returned centre of
        # mass in rigid_body_centers_of_mass[object_id].
        self.rigid_body_centers_of_mass[object_id] = self.compute_com(object_id)
    @ti.kernel
    def compute_generic_inertia(self, object_id: int):
        # Build the 3x3 inertia tensor of `object_id` about the centre stored
        # in rigid_body_centers_of_mass[object_id], then store the tensor and
        # its inverse. The centre must have been computed beforehand
        # (add_rigid_particles calls compute_com_kernel first).
        cm = self.rigid_body_centers_of_mass[object_id]
        inertia_tensor = ti.Matrix.zero(ti.f32, 3, 3)
        for p_i in self.x:
            if self.active[p_i] and self.is_dynamic_rigid_body(p_i) and self.object_id[p_i] == object_id:
                r = self.x[p_i] - cm
                mass = self.m_V0 * self.density[p_i]
                # Point-mass contribution: m * (|r|^2 * I - r r^T).
                inertia_tensor += mass * (r.norm_sqr() * ti.Matrix.identity(ti.f32, 3) - r.outer_product(r))
        self.inertia_body[object_id] = inertia_tensor
        # The same inverse is written to both fields.
        # NOTE(review): the tensor remains all-zero when no particle matched,
        # so the inverse is then taken of a singular matrix -- confirm this
        # case cannot occur.
        self.inertia_body_inv[object_id] = inertia_tensor.inverse()
        self.inertia_tensor_inv[object_id] = inertia_tensor.inverse()

    def add_rigid_particles(self):
        """Create the particles of every rigid body and initialise its dynamics.

        Expects each body dict to already carry ``"particleNum"`` and
        ``"voxelizedPoints"`` (written by ``cal_rigid_num``). All body ids are
        first copied into ``object_id_rigid_body_ti``; then, per body, the
        particles are appended through ``add_particles``, the centre of mass
        and inertia tensor are computed, the body's linear velocity is set
        from its ``"velocity"`` entry and the first component of its
        ``quaternion`` slot is set to 1.
        """
        bodies = self.cfg.get_rigid_bodies()

        # First pass: id table in configuration order.
        for slot, body in enumerate(bodies):
            self.object_id_rigid_body_ti[slot] = body["objectId"]

        # Second pass: particles and per-body dynamics.
        for body in bodies:
            body_id = body["objectId"]
            self.object_id_rigid_body.add(body_id)
            count = body["particleNum"]
            points = body["voxelizedPoints"]
            dynamic_flag = body["isDynamic"]
            if dynamic_flag:
                start_velocity = np.array(body["velocity"], dtype=np.float32)
            else:
                start_velocity = np.zeros(self.dim, dtype=np.float32)
            particle_radius = np.float32(self.particle_radius)
            body_density = np.float32(body["density"])

            positions = np.array(points, dtype=np.float32)
            velocities = np.stack([start_velocity for _ in range(count)])
            radii = particle_radius * np.ones(count, dtype=np.float32)
            densities = body_density * np.ones(count, dtype=np.float32)
            pressures = np.zeros(count, dtype=np.float32)
            materials = np.full(count, self.material_solid, dtype=np.int32)  # material is solid
            dynamic_flags = dynamic_flag * np.ones(count, dtype=np.int32)
            self.add_particles(body_id, count, positions, velocities, radii,
                               densities, pressures, materials, dynamic_flags)

            self.compute_com_kernel(body_id)
            self.compute_generic_inertia(body_id)
            self.rigid_body_linear_velocities[body_id] = ti.Vector(body["velocity"])
            self.quaternion[body_id][0] = 1.0
   
    def resize_particle_system(self, new_total_particle_num):
        """Grow the core per-particle fields to ``new_total_particle_num`` slots.

        New fields are allocated for ``object_id``, ``x``, ``x_0``, ``v``,
        ``r``, ``acceleration``, ``pressure_force``, ``viscous_force``,
        ``m_V``, ``m``, ``density``, ``pressure``, ``material`` and
        ``is_dynamic``. The existing values are copied to the front of each
        new field, the remaining slots stay zero, and the attribute is
        rebound to the new field. Each new field also stays reachable as
        ``<name>_dynamic``, as before.

        Changes from the previous revision:
          * ``object_id`` is now rebound as well; it used to be copied into
            ``object_id_dynamic`` but never swapped in.
          * Data is moved in bulk through NumPy instead of one Python-scope
            field access per element.
          * The debug dumps ``x before resize.txt`` / ``x after resize.txt``
            are no longer written to the working directory.

        NOTE(review): fields outside the list above (``active``, ``force``,
        the ``*_buffer`` fields, ...) keep their old size -- confirm whether
        callers rely on them after a resize.

        Args:
            new_total_particle_num: Number of slots of the new fields; must
                not be smaller than the current number of slots.

        Raises:
            ValueError: If ``new_total_particle_num`` is smaller than the
                current size of ``x``.
        """
        current_num = self.x.shape[0]
        if new_total_particle_num < current_num:
            raise ValueError(
                f"Cannot shrink the particle system from {current_num} to "
                f"{new_total_particle_num} particles."
            )

        n = new_total_particle_num
        dim = self.dim
        grown_fields = {
            "object_id": ti.field(dtype=int, shape=n),
            "x": ti.Vector.field(dim, dtype=float, shape=n),
            "x_0": ti.Vector.field(dim, dtype=float, shape=n),
            "v": ti.Vector.field(dim, dtype=float, shape=n),
            "r": ti.field(dtype=float, shape=n),
            "acceleration": ti.Vector.field(dim, dtype=float, shape=n),
            "pressure_force": ti.Vector.field(dim, dtype=float, shape=n),
            "viscous_force": ti.Vector.field(dim, dtype=float, shape=n),
            "m_V": ti.field(dtype=float, shape=n),
            "m": ti.field(dtype=float, shape=n),
            "density": ti.field(dtype=float, shape=n),
            "pressure": ti.field(dtype=float, shape=n),
            "material": ti.field(dtype=int, shape=n),
            "is_dynamic": ti.field(dtype=int, shape=n),
        }
        for name, new_field in grown_fields.items():
            old_values = getattr(self, name).to_numpy()
            # A freshly allocated field reads back as zeros with exactly the
            # shape and dtype from_numpy expects, so use it as the buffer.
            padded = new_field.to_numpy()
            padded[:old_values.shape[0]] = old_values
            new_field.from_numpy(padded)
            setattr(self, f"{name}_dynamic", new_field)
            setattr(self, name, new_field)
    @ti.func
    def add_particle(self, p, obj_id, x, v, r, density, pressure, material, is_dynamic):       
        # Initialise particle slot `p`: mark it active and store its owning
        # object id, position, velocity, radius, density, pressure, material
        # code and dynamic flag.
        self.object_id[p] = obj_id
        self.active[p] = 1
        self.x[p] = x
        # x_0 keeps a copy of the initial position.
        self.x_0[p] = x
        self.v[p] = v
        self.r[p] = ti.f32(r)
        self.density[p]= density
        # Volume and mass come from the global rest volume m_V0; the
        # per-particle radius `r` does not enter here.
        self.m_V[p] = self.m_V0
        self.m[p] = self.m_V0 * density
        self.pressure[p] = pressure
        self.material[p] = material
        self.is_dynamic[p] = is_dynamic 


    def add_particles(self,
                      object_id: int,
                      new_particles_num: int,
                      new_particles_positions: ti.types.ndarray(),
                      new_particles_velocity: ti.types.ndarray(),
                      new_particle_radius: ti.types.ndarray(),
                      new_particle_density: ti.types.ndarray(),
                      new_particle_pressure: ti.types.ndarray(),
                      new_particles_material: ti.types.ndarray(),
                      new_particles_is_dynamic: ti.types.ndarray()
                      ):
        """Append a batch of particles owned by ``object_id``.

        Python-scope entry point that passes every argument, unchanged and in
        the same order, to the ``_add_particles`` kernel.
        """
        self._add_particles(
            object_id, new_particles_num,
            new_particles_positions, new_particles_velocity,
            new_particle_radius, new_particle_density, new_particle_pressure,
            new_particles_material, new_particles_is_dynamic,
        )
    @ti.kernel
    def _add_particles(self,
                      object_id: int,
                      new_particles_num: int,
                      new_particles_positions: ti.types.ndarray(),
                      new_particles_velocity: ti.types.ndarray(),
                      new_particle_radius: ti.types.ndarray(),
                      new_particle_density: ti.types.ndarray(),
                      new_particle_pressure: ti.types.ndarray(),
                      new_particles_material: ti.types.ndarray(),
                      new_particles_is_dynamic: ti.types.ndarray()
                      ):
        # Fill slots [particle_num, particle_num + new_particles_num) from
        # the input arrays, then advance particle_num by new_particles_num.
        # Positions and velocities are indexed [row, component]; the other
        # arrays are indexed [row]. Row 0 of each array maps to the first
        # new slot.
        for p in range(self.particle_num[None] , self.particle_num[None] + new_particles_num):
            v = ti.Vector.zero(float, self.dim)
            x = ti.Vector.zero(float, self.dim)
            # Unrolled at compile time over the spatial components.
            for d in ti.static(range(self.dim)):
                v[d] = new_particles_velocity[p - self.particle_num[None], d]
                x[d] = new_particles_positions[p - self.particle_num[None], d]
            self.add_particle(p, object_id, x, v,
                              new_particle_radius[p - self.particle_num[None]],
                              new_particle_density[p - self.particle_num[None]],
                              new_particle_pressure[p - self.particle_num[None]],
                              new_particles_material[p - self.particle_num[None]],
                              new_particles_is_dynamic[p - self.particle_num[None]]
                              )
        # The counter is only advanced after the loop, so every iteration
        # above sees the same base offset.
        self.particle_num[None] += new_particles_num
    def add_particles_adaptive(self,
                      object_id: int,
                      new_particles_num: int,
                      new_particles_positions: ti.types.ndarray(),
                      new_particles_velocity: ti.types.ndarray(),
                      new_particle_radius: ti.types.ndarray(),
                      new_particle_density: ti.types.ndarray(),
                      new_particle_pressure: ti.types.ndarray(),
                      new_particles_material: ti.types.ndarray(),
                      new_particles_is_dynamic: ti.types.ndarray()
                      ):
        """Append a batch of particles through the adaptive kernel.

        Python-scope entry point that passes every argument, unchanged and in
        the same order, to the ``_add_particles_adaptive`` kernel.
        """
        self._add_particles_adaptive(
            object_id, new_particles_num,
            new_particles_positions, new_particles_velocity,
            new_particle_radius, new_particle_density, new_particle_pressure,
            new_particles_material, new_particles_is_dynamic,
        )

    @ti.kernel
    def _add_particles_adaptive(self,
                      object_id: int,
                      new_particles_num: int,
                      new_particles_positions: ti.types.ndarray(),
                      new_particles_velocity: ti.types.ndarray(),
                      new_particle_radius: ti.types.ndarray(),
                      new_particle_density: ti.types.ndarray(),
                      new_particle_pressure: ti.types.ndarray(),
                      new_particles_material: ti.types.ndarray(),
                      new_particles_is_dynamic: ti.types.ndarray()):
        """Write new particles into the inactive slots of the current emission window.

        The window is the slot range
        [current_emitted_particle_num, current_emitted_particle_num + new_particles_num).
        Entry ``k`` of each input array is written to slot
        ``current_emitted_particle_num + k`` through ``self.add_particle``,
        but only when that slot's ``active`` flag is 0. Afterwards both
        ``current_emitted_particle_num`` and ``particle_num`` are advanced by
        the number of slots that were actually filled.
        """
        # Counts how many slots this call really filled.
        actual_allocated_num = 0
        for p in range(self.current_emitted_particle_num[None], self.current_emitted_particle_num[None] + new_particles_num):
            # Slots that are already active are skipped, not overwritten.
            if self.active[p] == 0:
                v = ti.Vector.zero(float, self.dim)
                x = ti.Vector.zero(float, self.dim)
                # Position of this slot inside the input arrays.
                idx_new_arr = p - self.current_emitted_particle_num[None]
                for d in ti.static(range(self.dim)):
                    v[d] = new_particles_velocity[idx_new_arr, d]
                    x[d] = new_particles_positions[idx_new_arr, d]
                self.add_particle(p, object_id, x, v,
                                new_particle_radius[idx_new_arr],
                                new_particle_density[idx_new_arr],
                                new_particle_pressure[idx_new_arr],
                                new_particles_material[idx_new_arr],
                                new_particles_is_dynamic[idx_new_arr]
                                )
                actual_allocated_num += 1
        # The window start moves only by the number of filled slots.
        self.current_emitted_particle_num[None] += actual_allocated_num
        # NOTE(review): this check runs after the particles have been written and
        # after the window was advanced; presumably only evaluated when Taichi
        # runs in debug mode - confirm.
        assert actual_allocated_num == new_particles_num, f"Reserved memory allowed particles num {actual_allocated_num}, however {new_particles_num} is required."
        self.particle_num[None] += actual_allocated_num 
    
    def add_square_particles(self, object_id, vertices, material, is_dynamic, color=(0,0,0), density=None, pressure=None, velocity=None, radius=None, adaptive=False):
        """Fill a quadrilateral with a lattice of particles and register them.

        Args:
            object_id: Object id forwarded to the particle-insertion routine.
            vertices: Four 3D corner points ``v0..v3``. Edges ``v0->v1`` and
                ``v3->v2`` are sampled with the same number of points, and each
                pair of matching samples is joined by a line of particles.
            material: Integer material code assigned to every particle.
            is_dynamic: Integer/bool flag assigned to every particle.
            color: Accepted for interface compatibility; not used by this
                method.
            density: Per-particle density; 1000.0 when ``None``.
            pressure: Per-particle pressure; 0.0 when ``None``.
            velocity: One 3-component velocity applied to every particle;
                zero when ``None``.
            radius: Per-particle radius; ``self.particle_radius`` when ``None``.
            adaptive: When true, insert through ``add_particles_adaptive``
                instead of ``add_particles``.

        Returns:
            The number of particles generated.
        """
        assert len(vertices) == 4 and all(len(v) == 3 for v in vertices), "Vertices should be a list of four 3D points."
        corners = np.array(vertices, dtype=np.float32)

        # Sample counts come from the two edges leaving v0; at least one each.
        num_x = max(1, int(np.linalg.norm(corners[1] - corners[0]) / self.particle_diameter))
        num_y = max(1, int(np.linalg.norm(corners[3] - corners[0]) / self.particle_diameter))

        # Matching samples on the two opposite edges v0->v1 and v3->v2.
        edge_01 = np.linspace(corners[0], corners[1], num_x, dtype=np.float32)
        edge_32 = np.linspace(corners[3], corners[2], num_x, dtype=np.float32)
        grid_positions = np.array(
            [np.linspace(start, end, num_y, dtype=np.float32) for start, end in zip(edge_01, edge_32)],
            dtype=np.float32,
        ).reshape(-1, 3)
        num_new_particles = grid_positions.shape[0]

        if velocity is None:
            velocity_arr = np.zeros_like(grid_positions, dtype=np.float32)
        else:
            velocity_arr = np.tile(np.array(velocity, dtype=np.float32), (num_new_particles, 1))
        radius_arr = np.full(num_new_particles, self.particle_radius if radius is None else radius, dtype=np.float32)
        material_arr = np.full(num_new_particles, material, dtype=np.int32)
        is_dynamic_arr = np.full(num_new_particles, is_dynamic, dtype=np.int32)
        density_arr = np.full(num_new_particles, density if density is not None else 1000., dtype=np.float32)
        pressure_arr = np.full(num_new_particles, pressure if pressure is not None else 0., dtype=np.float32)

        insert = self.add_particles_adaptive if adaptive else self.add_particles
        insert(object_id, num_new_particles, grid_positions, velocity_arr, radius_arr, density_arr, pressure_arr, material_arr, is_dynamic_arr)

        return num_new_particles


    @ti.func
    def pos_to_index(self, pos):
        """Return the integer grid-cell coordinates for position ``pos``."""
        # Scale by the cell size, then apply an integer cast to each component.
        scaled = pos / self.grid_size
        return scaled.cast(int)
    @ti.func
    def flatten_grid_index(self, grid_index):
        """Collapse a 3-component cell coordinate into a single flat cell index.

        Component 2 varies fastest, then component 1, then component 0.
        """
        # Horner form of g0 * n1 * n2 + g1 * n2 + g2.
        return (grid_index[0] * self.grid_num[1] + grid_index[1]) * self.grid_num[2] + grid_index[2]
    @ti.func
    def get_flatten_grid_index(self, pos):
        """Return the flat grid-cell index of the cell containing ``pos``."""
        cell = self.pos_to_index(pos)
        return self.flatten_grid_index(cell)
    @ti.kernel
    def update_grid_id(self):
        """Recompute each particle's flat cell id and the per-cell particle counts.

        After this kernel ``grid_ids`` holds the flat cell index of every entry
        of ``x``, and ``grid_particles_num`` and ``grid_particles_num_temp``
        both hold the number of particles per cell.
        """
        # Pass 1: clear the per-cell counters.
        for cell in ti.grouped(self.grid_particles_num):
            self.grid_particles_num[cell] = 0
        # Pass 2: bin every particle; Taichi makes the augmented assignment atomic.
        for p in ti.grouped(self.x):
            flat_cell = self.get_flatten_grid_index(self.x[p])
            self.grid_ids[p] = flat_cell
            self.grid_particles_num[flat_cell] += 1
        # Pass 3: keep a copy of the raw counts in the temp field.
        for cell in ti.grouped(self.grid_particles_num):
            self.grid_particles_num_temp[cell] = self.grid_particles_num[cell]
    @ti.kernel
    def counting_sort(self):
        """Reorder all per-particle fields so particles are grouped by grid cell.

        Relies on state prepared beforehand (see ``grid_prefix_sort``):
        ``grid_ids`` and ``grid_particles_num_temp`` as written by
        ``update_grid_id``, and ``grid_particles_num`` already converted to
        running (inclusive) totals. Runs three passes: compute each particle's
        destination slot, scatter every field into its ``*_buffer``, then copy
        the buffers back.
        """
        # Pass 1: pick a destination slot for every particle.
        for i in range(self.total_particle_num):    
            I = self.total_particle_num - 1 - i
            # Start of this particle's cell = running total of the previous cell
            # (0 for cell 0).
            base_offset = 0
            if self.grid_ids[I] - 1 >= 0:
                base_offset = self.grid_particles_num[self.grid_ids[I]-1]
            # Atomically decrement the cell's remaining count; the value returned
            # by atomic_sub, minus 1, is the slot inside the cell (this assumes
            # atomic_sub returns the pre-decrement value, per Taichi's atomics).
            self.grid_ids_new[I] = ti.atomic_sub(self.grid_particles_num_temp[self.grid_ids[I]], 1) - 1 + base_offset
        # Pass 2: scatter every per-particle field to its new slot in the buffers.
        for I in ti.grouped(self.grid_ids):
            new_index = self.grid_ids_new[I] 
            self.grid_ids_buffer[new_index] = self.grid_ids[I]
            self.object_id_buffer[new_index] = self.object_id[I]
            self.active_buffer[new_index] = self.active[I]
            self.x_0_buffer[new_index] = self.x_0[I]
            self.x_buffer[new_index] = self.x[I]
            self.v_buffer[new_index] = self.v[I]
            self.r_buffer[new_index] = self.r[I]
            self.acceleration_buffer[new_index] = self.acceleration[I]
            self.pressure_force_buffer[new_index] = self.pressure_force[I]
            self.viscous_force_buffer[new_index] = self.viscous_force[I]
            self.m_V_buffer[new_index] = self.m_V[I]
            self.m_buffer[new_index] = self.m[I]
            self.density_buffer[new_index] = self.density[I]
            self.pressure_buffer[new_index] = self.pressure[I]
            self.material_buffer[new_index] = self.material[I]
            self.is_dynamic_buffer[new_index] = self.is_dynamic[I]
        # Pass 3: copy the sorted buffers back into the live fields.
        # Only the fields listed here are reordered.
        for I in ti.grouped(self.x):
            self.grid_ids[I] = self.grid_ids_buffer[I]
            self.object_id[I] = self.object_id_buffer[I]
            self.active[I] = self.active_buffer[I]
            self.x_0[I] = self.x_0_buffer[I]
            self.x[I] = self.x_buffer[I]
            self.v[I] = self.v_buffer[I]
            self.r[I] = self.r_buffer[I]
            self.acceleration[I] = self.acceleration_buffer[I]
            self.pressure_force[I] = self.pressure_force_buffer[I]
            self.viscous_force[I] = self.viscous_force_buffer[I]
            self.m_V[I] = self.m_V_buffer[I]
            self.m[I] = self.m_buffer[I]
            self.density[I] = self.density_buffer[I]
            self.pressure[I] = self.pressure_buffer[I]
            self.material[I] = self.material_buffer[I]
            self.is_dynamic[I] = self.is_dynamic_buffer[I]
    @ti.kernel
    def prefix_sum(self):
        """Turn ``grid_particles_num`` into running (inclusive) totals, in place."""
        running_total = 0
        # Each entry depends on the previous one, so the loop is serialized.
        ti.loop_config(serialize=True)
        for cell in range(self.n_grids):
            running_total += self.grid_particles_num[cell]
            self.grid_particles_num[cell] = running_total
    def manage_particle_system(self):
        """Count particles per category, allocate storage, then insert particles.

        A counting or insertion step runs only when its category count exceeds
        its threshold. Inlets are counted but have no insertion step here.
        """
        # (count attribute, threshold the count must exceed, method to call)
        # NOTE(review): rigid bodies use "> 1" while every other category uses
        # "> 0"; presumably one rigid body is always present - confirm.
        counting_steps = (
            ("num_rigid_bodies", 1, "cal_rigid_num"),
            ("num_inlets", 0, "cal_inlet_num"),
            ("num_fluid_blocks", 0, "cal_fluid_num"),
            ("num_grain_objects", 0, "cal_grain_num"),
        )
        insertion_steps = (
            ("num_rigid_bodies", 1, "add_rigid_particles"),
            ("num_fluid_blocks", 0, "add_fluid_particles"),
            ("num_grain_objects", 0, "add_grain_particles"),
        )

        for count_attr, floor, method_name in counting_steps:
            if getattr(self, count_attr) > floor:
                getattr(self, method_name)()

        self.cal_total_particle_num()
        self.array_setup()

        for count_attr, floor, method_name in insertion_steps:
            if getattr(self, count_attr) > floor:
                getattr(self, method_name)()
    def grid_prefix_sort(self):
        """Rebuild the cell-sorted particle layout and reset the emission counter.

        Steps: recompute cell ids and counts, turn the counts into running
        totals, counting-sort the particle fields, then set
        ``current_emitted_particle_num`` back to 0. On "win32" the running
        totals come from ``prefix_sum_executor``; on "darwin" from the
        ``prefix_sum`` kernel.
        """
        platform = self.operating_system
        # NOTE(review): any other platform value falls through and nothing is
        # sorted - confirm this is intended.
        if platform != "win32" and platform != "darwin":
            return

        self.update_grid_id()
        if platform == "win32":
            self.prefix_sum_executor.run(self.grid_particles_num)
        else:
            self.prefix_sum()
        self.counting_sort()
        self.current_emitted_particle_num[None] = 0
    def load_rigid_body(self, rigid_body):
        """Load a rigid body's geometry and sample it into particle positions.

        The geometry file is loaded with trimesh and shifted by
        ``self.domain_translation_vector``. The bounding-box diagonal is stored
        in ``self.rigid_search_dist`` under the body's object id, and the mesh,
        rest vertices and their mean are stored back into ``rigid_body``. A
        point cloud contributes its vertices directly; a triangle mesh is
        voxelized at ``self.particle_diameter`` and filled.

        Args:
            rigid_body: Mapping with at least ``"objectId"`` and
                ``"geometryFile"``; updated in place with ``"mesh"``,
                ``"restPosition"`` and ``"restCenterOfMass"``.

        Returns:
            Sample points inside the solid-padded domain bounds, or a single
            point at the rest center of mass if none survive the bounds filter.

        Raises:
            Exception: If the loaded geometry is neither a PointCloud nor a
                Trimesh.
        """
        obj_id = rigid_body["objectId"]
        mesh = tm.load(rigid_body["geometryFile"])
        mesh.vertices = mesh.vertices.astype(np.float32) + self.domain_translation_vector
        min_rigid = np.min(mesh.vertices, axis=0)
        max_rigid = np.max(mesh.vertices, axis=0)
        self.rigid_search_dist[obj_id] = np.linalg.norm(max_rigid - min_rigid)
        rigid_body["mesh"] = mesh
        rigid_body["restPosition"] = mesh.vertices
        rigid_body["restCenterOfMass"] = mesh.vertices.mean(axis=0)
        mesh_bounds_min, mesh_bounds_max = self.calculate_bounds_with_padding_solid()
        if isinstance(mesh, tm.points.PointCloud):
            voxelized_points_np = mesh.vertices
            print("The object is a PointCloud!")
            print(f"min {min_rigid} max {max_rigid} center ")
        elif isinstance(mesh, tm.Trimesh):
            # Voxelize once; the earlier code ran it twice and discarded the
            # first result.
            voxelized_mesh = mesh.voxelized(pitch=self.particle_diameter).fill()
            voxelized_points_np = voxelized_mesh.points
        else:
            raise Exception("unknown input geometry type.")
        filtered_points = self.filter_points(voxelized_points_np, mesh_bounds_min, mesh_bounds_max)

        # Never return an empty body: fall back to one particle at the center.
        if filtered_points.shape[0] == 0:
            filtered_points = np.array([rigid_body["restCenterOfMass"]])
        return filtered_points
    def apply_randomness_and_filter(self,points, min_dist):
        """Jitter points randomly, then thin them so survivors are spaced apart.

        Each coordinate is shifted by a uniform random offset in
        ``[-0.45, 0.45) * min_dist``. Points are then visited in index order;
        every still-kept point discards all later-indexed points within
        ``min_dist`` of it.

        Returns:
            The jittered points that were kept.
        """
        jitter = (np.random.rand(*points.shape) - 0.5) * min_dist * 0.9
        jittered = points + jitter
        tree = cKDTree(jittered)
        survivors = np.ones(jittered.shape[0], dtype=bool)
        for idx, candidate in enumerate(jittered):
            if survivors[idx]:
                close = np.asarray(tree.query_ball_point(candidate, r=min_dist), dtype=np.intp)
                # Only later-indexed points are dropped, so earlier ones win.
                survivors[close[close > idx]] = False
        return jittered[survivors]
    def filter_overlapping_particles(self, fluid_points, rigid_points):
        """Drop fluid points that lie too close to any rigid point.

        A fluid point is removed when at least one rigid point is found within
        ``4 * self.particle_radius`` of it.

        Returns:
            The rows of ``fluid_points`` with no rigid point in range.
        """
        tree = cKDTree(rigid_points)
        clearance = 4 * self.particle_radius
        hits_per_point = tree.query_ball_point(fluid_points, clearance)
        survivors = []
        for idx, hits in enumerate(hits_per_point):
            if not hits:
                survivors.append(idx)
        return fluid_points[survivors]
    def filter_overlapping_particles_bounding_box(self, fluid_points, rigid_points):
        """Drop fluid points that fall inside the rigid points' bounding box.

        The box is the axis-aligned min/max of ``rigid_points``; its faces
        count as inside. Only the first three coordinates are compared.

        Returns:
            The rows of ``fluid_points`` lying outside the box.
        """
        lower = np.min(rigid_points, axis=0)
        upper = np.max(rigid_points, axis=0)
        coords = fluid_points[:, :3]
        inside = np.all((lower[:3] <= coords) & (coords <= upper[:3]), axis=1)
        return fluid_points[~inside]
    def load_fluid_body(self, fluid):
        """Load a fluid block's geometry and sample it into particle positions.

        The geometry file is loaded with trimesh, shifted by
        ``self.domain_translation_vector``, voxelized at
        ``self.particle_diameter`` and filled. Samples outside the
        liquid-padded domain bounds are dropped, then samples too close to
        existing rigid particles and existing grain particles are dropped.

        Args:
            fluid: Mapping with at least ``"geometryFile"``; updated in place
                with ``"mesh"``, ``"restPosition"`` and ``"restCenterOfMass"``.

        Returns:
            The surviving fluid sample points.
        """
        fluid_mesh = tm.load(fluid["geometryFile"])
        fluid_mesh.vertices = fluid_mesh.vertices.astype(np.float32) + self.domain_translation_vector
        fluid["mesh"] = fluid_mesh
        fluid["restPosition"] = fluid_mesh.vertices
        fluid["restCenterOfMass"] = fluid_mesh.vertices.mean(axis=0)
        mesh_bounds_min, mesh_bounds_max = self.calculate_bounds_with_padding_liquid()
        voxelized_mesh = fluid_mesh.voxelized(pitch=self.particle_diameter).fill()
        voxelized_points_np = voxelized_mesh.points
        filtered_points = self.filter_points(voxelized_points_np, mesh_bounds_min, mesh_bounds_max)

        # The previous condition read `size > 0  > 0`, a chained comparison
        # that is always False, so rigid obstacles were never filtered out.
        obstacle_points = np.vstack([self.rigid_points]) \
                        if self.rigid_points.size > 0 else np.empty((0, 3))
        if obstacle_points.size > 0:
            filtered_points = self.filter_overlapping_particles(filtered_points, obstacle_points)
        else:
            print("No rigid or keyframed obstacles present, skipping overlap filtering.")

        grain_points = np.vstack([self.grain_points]) \
                        if self.grain_points.size > 0 else np.empty((0, 3))
        if grain_points.size > 0:
            filtered_points = self.filter_overlapping_particles(filtered_points, grain_points)
        else:
            print("No grain present, skipping overlap filtering.")
        return filtered_points

    def load_grain_body(self, grain):
        """Load a grain object's geometry and sample it into jittered particles.

        The geometry file is loaded with trimesh, shifted by
        ``self.domain_translation_vector``, voxelized at
        ``self.particle_diameter`` and filled. Samples outside the
        liquid-padded domain bounds or too close to existing rigid particles
        are dropped. The rest are jittered, given random radii and thinned by
        ``apply_randomness_and_filter_multiple_radii``.

        Args:
            grain: Mapping with at least ``"objectId"`` and ``"geometryFile"``;
                updated in place with ``"mesh"``, ``"restPosition"`` and
                ``"restCenterOfMass"``.

        Returns:
            Tuple ``(points, radii)`` for the surviving grains.
        """
        obj_id = grain["objectId"]
        grain_mesh = tm.load(grain["geometryFile"])
        grain_mesh.vertices = grain_mesh.vertices.astype(np.float32) + self.domain_translation_vector
        grain["mesh"] = grain_mesh
        grain["restPosition"] = grain_mesh.vertices
        grain["restCenterOfMass"] = grain_mesh.vertices.mean(axis=0)

        bounds_lo, bounds_hi = self.calculate_bounds_with_padding_liquid()
        filled_grid = grain_mesh.voxelized(pitch=self.particle_diameter).fill()
        candidates = self.filter_points(filled_grid.points.astype(np.float32), bounds_lo, bounds_hi)

        # Minimum spacing is the diameter of the largest nominal grain.
        variation = self.radius_variation
        base_radius = self.particle_radius
        radius_std = variation * base_radius
        min_dist = 2 * (base_radius * (1 + variation))

        if self.rigid_points.size > 0:
            candidates = self.filter_overlapping_particles(candidates, np.vstack([self.rigid_points]))
        else:
            print("No rigid or keyframed obstacles present, skipping overlap filtering.")

        grains, radii = self.apply_randomness_and_filter_multiple_radii(candidates, base_radius, radius_std, min_dist)
        return grains, radii
    def apply_randomness_and_filter_multiple_radii(self, points, base_radius, radius_std, min_dist):
        """Assign random radii, jitter the points, and thin them by spacing.

        Radii are drawn from a normal distribution around ``base_radius`` with
        spread ``radius_std`` and clipped to ``[0.5, 1.5] * base_radius``.
        Each coordinate is shifted by a uniform random offset in
        ``[-0.9, 0.9) * base_radius``. Points are then visited in index order;
        every still-kept point discards all later-indexed points within
        ``min_dist`` of it.

        Returns:
            Tuple ``(points, radii)`` restricted to the kept entries.
        """
        count = points.shape[0]
        # Radii are drawn before the jitter; keep this order so seeded runs match.
        sampled_radii = np.clip(
            np.random.normal(loc=base_radius, scale=radius_std, size=count),
            a_min=base_radius * 0.5,
            a_max=base_radius * 1.5,
        )
        jitter = (np.random.rand(*points.shape).astype(np.float32) - 0.5) * np.float32(base_radius) * 1.8
        jittered = points + jitter
        tree = cKDTree(jittered)
        survivors = np.ones(count, dtype=bool)
        for idx in range(count):
            if survivors[idx]:
                close = np.asarray(tree.query_ball_point(jittered[idx], r=min_dist), dtype=np.intp)
                # Only later-indexed points are dropped, so earlier ones win.
                survivors[close[close > idx]] = False
        return jittered[survivors], sampled_radii[survivors]

    def load_domain_shape(self, domain):
        """Load the domain mesh and derive the simulation bounds from it.

        The mesh is loaded with trimesh, stored as ``self.domain_mesh`` and
        shifted by ``self.domain_translation_vector``. ``domain_start``,
        ``domain_end``, ``domain_size`` and ``dim`` are then taken from the
        translated mesh bounds.

        Args:
            domain: Mapping with at least ``"geometryFile"``.
        """
        mesh = tm.load(domain["geometryFile"])
        self.domain_mesh = mesh
        # Bounds are read only after translation; the pre-translation bounds the
        # earlier code also stored were overwritten immediately and never used.
        mesh.vertices = mesh.vertices.astype(np.float32) + self.domain_translation_vector
        bounds = mesh.bounds
        self.domain_start = bounds[0]
        self.domain_end = bounds[1]
        self.domain_size = self.domain_end - self.domain_start
        print(f"Domain start: {self.domain_start}")
        print(f"Domain end: {self.domain_end}")
        print(f"Domain size: {self.domain_size}")
        self.dim = len(self.domain_size)
        assert self.dim > 1
    def calculate_bounds_with_padding_liquid(self):
        """Return the domain bounds shrunk by twice ``self.padding`` on every side.

        Returns:
            Tuple ``(bounds_min, bounds_max)``.
        """
        inset = np.full(self.dim, self.padding * 2)
        return self.domain_start + inset, self.domain_end - inset
    def calculate_bounds_with_padding_solid(self):
        """Return the domain bounds shrunk by ``self.padding`` on every side.

        Returns:
            Tuple ``(bounds_min, bounds_max)``.
        """
        inset = np.full(self.dim, self.padding)
        return self.domain_start + inset, self.domain_end - inset
    def filter_points(self, points, bounds_min, bounds_max):
        """Return the rows of ``points`` inside the closed box [bounds_min, bounds_max]."""
        not_below = points >= bounds_min
        not_above = points <= bounds_max
        inside = (not_below & not_above).all(axis=1)
        return points[inside]
    def dump_id(self, object_id, phase="all"):
        """Collect per-particle output arrays for one object.

        A particle is selected when its object id equals ``object_id``, its
        ``active`` flag is 1, and its output position is not within 0.001 of
        ``self.domain_start_for_dump``. ``phase`` optionally narrows the
        selection by material code.

        Args:
            object_id: Object id to export.
            phase: ``"liquid"`` (material 1), ``"gas"`` (material 3) or
                ``"grain"`` (material 4); any other value applies no material
                filter.

        Returns:
            ``None`` if no particle is selected, otherwise a dict with the keys
            ``position``, ``velocity``, ``radius``, ``acceleration``,
            ``viscous_force``, ``pressure_force``, ``curvature`` and
            ``density``, each holding the selected rows.
        """
        np_object_id = self.object_id.to_numpy()
        np_material = self.material.to_numpy()
        is_active = self.active.to_numpy()
        # Transfer the positions once and reuse them for the mask and the output.
        np_x_all = self.x_output.to_numpy()

        # Particles sitting at the dump origin are excluded.
        # NOTE(review): presumably these are reserved/unused slots - confirm.
        exclusion_radius = 0.001
        distances = np.linalg.norm(np_x_all - self.domain_start_for_dump, axis=1)
        reserved_mask = distances < exclusion_radius
        mask = (np_object_id[:] == object_id) & (is_active[:] == 1) & ~reserved_mask

        # Build only the material mask the requested phase needs.
        if phase == "liquid":
            mask &= np_material[:] == 1
        elif phase == "gas":
            mask &= np_material[:] == 3
        elif phase == "grain":
            mask &= np_material[:] == 4
        if not np.any(mask):
            return None

        return {
            "position": np_x_all[mask],
            "velocity": self.v.to_numpy()[mask],
            "radius": self.r.to_numpy()[mask],
            "acceleration": self.acceleration.to_numpy()[mask],
            "viscous_force": self.viscous_force.to_numpy()[mask],
            "pressure_force": self.pressure_force.to_numpy()[mask],
            "curvature": self.curvature.to_numpy()[mask],
            "density": self.density.to_numpy()[mask]
        }

    @ti.func
    def is_static_rigid_body(self, p):
        """Return whether particle ``p`` is solid material and not flagged dynamic."""
        is_solid = self.material[p] == self.material_solid
        return is_solid and (not self.is_dynamic[p])
    @ti.func
    def is_dynamic_rigid_body(self, p):
        """Return whether particle ``p`` is solid material and flagged dynamic."""
        is_solid = self.material[p] == self.material_solid
        return is_solid and self.is_dynamic[p]
    @ti.func
    def for_all_neighbors(self, p_i, task: ti.template(), ret: ti.template()):
        """Call ``task(p_i, p_j, ret)`` for each nearby particle ``p_j``.

        Scans the particle's own grid cell and every adjacent cell (offsets
        -1..1 per axis). A particle ``p_j`` qualifies when it is not ``p_i``
        and its distance to ``p_i`` is below ``2 * (r[p_i] + r[p_j])``.
        Assumes particles are sorted by cell and ``grid_particles_num`` holds
        running totals (see ``grid_prefix_sort``).
        """
        center_cell = self.pos_to_index(self.x[p_i])
        for offset in ti.grouped(ti.ndrange(*((-1, 2),) * self.dim)):
            # NOTE(review): the neighbouring cell index is not range-checked here.
            grid_index = self.flatten_grid_index(center_cell + offset)
            # Particles of a cell occupy [total of previous cell, total of this cell).
            start_idx = 0
            end_idx = self.grid_particles_num[grid_index]
            if grid_index - 1 >= 0:
                start_idx = self.grid_particles_num[grid_index-1]
            for p_j in range(start_idx, end_idx):
                # p_i is an index vector; its first component is the particle index.
                if p_i[0] != p_j and (self.x[p_i] - self.x[p_j]).norm() < (self.r[p_i] + self.r[p_j]) * 2.0:
                    task(p_i, p_j, ret)
                    
    @ti.func
    def for_all_neighbors_dem_try_bond(self, p_i, task: ti.template(), ret: ti.template(), max_neighbors: int):
        """Call ``task(p_i, p_j, ret, k)`` for at most ``max_neighbors`` neighbors.

        Scans the particle's own grid cell and every adjacent cell. A particle
        ``p_j`` qualifies when it is not ``p_i`` and lies within
        ``self.support_radius``. ``k`` is the running count of neighbors
        already passed to ``task``. Assumes particles are sorted by cell and
        ``grid_particles_num`` holds running totals (see ``grid_prefix_sort``).
        """
        center_cell = self.pos_to_index(self.x[p_i])
        neighbor_count = 0
        # Set to False once the cap is reached; remaining iterations are then
        # skipped with `continue` (the loops are not exited early).
        active = True  

        for offset in ti.grouped(ti.ndrange(*((-1, 2),) * self.dim)):
            if not active:
                continue

            # NOTE(review): the neighbouring cell index is not range-checked here.
            grid_index = self.flatten_grid_index(center_cell + offset)
            # Particles of a cell occupy [total of previous cell, total of this cell).
            start_idx = 0
            end_idx = self.grid_particles_num[grid_index]
            if grid_index - 1 >= 0:
                start_idx = self.grid_particles_num[grid_index - 1]

            for p_j in range(start_idx, end_idx):
                if not active:
                    continue

                if p_i[0] != p_j and (self.x[p_i] - self.x[p_j]).norm() < self.support_radius:
                    task(p_i, p_j, ret, neighbor_count)
                    neighbor_count += 1
                    if neighbor_count >= max_neighbors:
                        active = False  
    
    @ti.func
    def for_all_neighbors_dem(self, p_i, task: ti.template(), ret1: ti.template(), ret2: ti.template()):
        """Call ``task(p_i, p_j, ret1, ret2)`` for each neighbor within the support radius.

        Scans the particle's own grid cell and every adjacent cell. A particle
        ``p_j`` qualifies when it is not ``p_i`` and lies within
        ``self.support_radius``. Assumes particles are sorted by cell and
        ``grid_particles_num`` holds running totals (see ``grid_prefix_sort``).
        """
        center_cell = self.pos_to_index(self.x[p_i])
        for offset in ti.grouped(ti.ndrange(*((-1, 2),) * self.dim)):
            # NOTE(review): the neighbouring cell index is not range-checked here.
            grid_index = self.flatten_grid_index(center_cell + offset)
            # Particles of a cell occupy [total of previous cell, total of this cell).
            start_idx = 0
            end_idx = self.grid_particles_num[grid_index]
            if grid_index - 1 >= 0:
                start_idx = self.grid_particles_num[grid_index - 1]
            for p_j in range(start_idx, end_idx):
                if p_i[0] != p_j and (self.x[p_i] - self.x[p_j]).norm() < self.support_radius:
                    task(p_i, p_j, ret1, ret2)
                    

    @ti.func
    def for_all_neighbors_rigid(self, object_id, task: ti.template(), ret: ti.template()):
        """Call ``task`` for every foreign particle in the cells covered by an object's AABB.

        Walks all grid cells between the cells of ``aabb_min[object_id]`` and
        ``aabb_max[object_id]`` (inclusive). For each particle stored in those
        cells whose object id differs from ``object_id``, calls
        ``task(object_id, neighbor_object_id, ret)``; the task is invoked once
        per such particle, so the same neighbor id can be reported many times.
        Uses ``.x/.y/.z`` components, i.e. three-component cell indices.
        Assumes particles are sorted by cell and ``grid_particles_num`` holds
        running totals (see ``grid_prefix_sort``).
        """
        aabb_min, aabb_max = self.aabb_min[object_id], self.aabb_max[object_id]
        min_cell = self.pos_to_index(aabb_min)
        max_cell = self.pos_to_index(aabb_max)

        for x in range(min_cell.x, max_cell.x + 1):
            for y in range(min_cell.y, max_cell.y + 1):
                for z in range(min_cell.z, max_cell.z + 1):
                    grid_index = self.flatten_grid_index(ti.Vector([x, y, z]))

                    # Skip cells whose flat index falls outside the grid array.
                    if 0 <= grid_index < self.grid_particles_num.shape[0]:  
                        # Particles of a cell occupy [total of previous cell, total of this cell).
                        start_idx = self.grid_particles_num[grid_index - 1] if grid_index > 0 else 0
                        end_idx = self.grid_particles_num[grid_index]

                        for idx in range(start_idx, end_idx):
                            neighbor_object_id = self.object_id[idx]
                            if object_id != neighbor_object_id:  
                                task(object_id, neighbor_object_id, ret)  

    @ti.func
    def for_all_neighbors_density_grad(self, p_i, task: ti.template(), ret: ti.template()):
        """Call ``task(p_i, p_j, ret)`` for each neighbor within the support radius.

        Scans the particle's own grid cell and every adjacent cell. A particle
        ``p_j`` qualifies when it is not ``p_i`` and lies within
        ``self.support_radius``. Assumes particles are sorted by cell and
        ``grid_particles_num`` holds running totals (see ``grid_prefix_sort``).
        """
        center_cell = self.pos_to_index(self.x[p_i])
        for offset in ti.grouped(ti.ndrange(*((-1, 2),) * self.dim)):
            # NOTE(review): the neighbouring cell index is not range-checked here.
            grid_index = self.flatten_grid_index(center_cell + offset)
            # Particles of a cell occupy [total of previous cell, total of this cell).
            start_idx = 0
            end_idx = self.grid_particles_num[grid_index]
            if grid_index - 1 >= 0:
                start_idx = self.grid_particles_num[grid_index-1]
            for p_j in range(start_idx, end_idx):
                if p_i[0] != p_j and (self.x[p_i] - self.x[p_j]).norm() < self.support_radius:
                    task(p_i, p_j, ret)
    @ti.func
    def for_all_neighbors_ww(self, p_i, task: ti.template(), ret: ti.template()):
        """Call ``task(p_i, p_j, ret)`` for each neighbor within the support radius.

        Scans the particle's own grid cell and every adjacent cell. A particle
        ``p_j`` qualifies when it is not ``p_i`` and lies within
        ``self.support_radius``. The body matches
        ``for_all_neighbors_density_grad``. Assumes particles are sorted by
        cell and ``grid_particles_num`` holds running totals (see
        ``grid_prefix_sort``).
        """
        center_cell = self.pos_to_index(self.x[p_i])
        for offset in ti.grouped(ti.ndrange(*((-1, 2),) * self.dim)):
            # NOTE(review): the neighbouring cell index is not range-checked here.
            grid_index = self.flatten_grid_index(center_cell + offset)
            # Particles of a cell occupy [total of previous cell, total of this cell).
            start_idx = 0
            end_idx = self.grid_particles_num[grid_index]
            if grid_index - 1 >= 0:
                start_idx = self.grid_particles_num[grid_index-1]
            for p_j in range(start_idx, end_idx):
                if p_i[0] != p_j and (self.x[p_i] - self.x[p_j]).norm() < self.support_radius:
                    task(p_i, p_j, ret)
    @ti.kernel
    def copy_to_numpy(self, np_arr: ti.types.ndarray(), src_arr: ti.template()):
        """Copy the first ``particle_num`` entries of ``src_arr`` into ``np_arr``."""
        for slot in range(self.particle_num[None]):
            np_arr[slot] = src_arr[slot]
    def export_motion_data(self):
        """Snapshot each rigid body's center of mass and rotation matrix.

        For every id in ``self.object_id_rigid_body``, stores a dict with the
        keys ``'center_of_mass'`` and ``'rotation_matrix'`` in
        ``self.motion_data``, replacing any earlier entry for that id.
        """
        for body_id in self.object_id_rigid_body:
            self.motion_data[body_id] = {
                'center_of_mass': self.centers_of_mass[body_id].to_numpy(),
                'rotation_matrix': self.rotation_matrices[body_id].to_numpy(),
            }
            
    def write_motion_data_to_json(self, file_path):
        """Write ``self.motion_data`` to ``file_path`` as JSON indented by 4 spaces.

        Each entry's ``'center_of_mass'`` and ``'rotation_matrix'`` are
        converted to plain lists via ``tolist()`` before serialization.
        """
        serializable = {
            body_id: {
                'center_of_mass': entry['center_of_mass'].tolist(),
                'rotation_matrix': entry['rotation_matrix'].tolist(),
            }
            for body_id, entry in self.motion_data.items()
        }
        with open(file_path, 'w') as out_file:
            json.dump(serializable, out_file, indent=4)
            
            
            