# Doriflow Engine - Fluid Simulation for Blender 3D
# Copyright (C) 2024 Doriflow Team
# This software is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

# You are free to:
# -Share: Copy and redistribute the material in any medium or format.
# -Adapt: Remix, transform, and build upon the material.

# UNDER THE FOLLOWING TERMS:
# -Attribution: You must give appropriate credit, provide a link to the license, and indicate if changes were made.
# -Appropriate credit should include the following:
#   -The original author's name: Doriflow Team
#   -A link to the original source (if applicable).
#   -A link to the full license: https://creativecommons.org/licenses/by-nc/4.0/.
#   -A clear indication of any changes made, such as: "This material has been modified."
# -NonCommercial: You may not use the material for commercial purposes.
# Disclaimer:
# -This simulation engine is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. In no event shall the authors be liable for any claim, damages, or other liability arising from the use of this software.

# For more details, refer to the full license text at:
# https://creativecommons.org/licenses/by-nc/4.0/.

#-----------------------------------------------------------------------------------------------------------------------#
import os
import json
import argparse
import taichi as ti
import numpy as np
from config_builder import SimConfig
from particle_system import ParticleSystem
import shutil
import time
from flow_properties import FlowProperties
import math

# Module-level Taichi setup: executed on import, before any solver objects exist.
# NOTE(review): ti.TRACE is a very verbose Taichi log level — confirm it is
# intended for normal runs rather than left over from debugging.
ti.set_logging_level(ti.TRACE)
# Request the GPU backend; device_memory_fraction presumably caps Taichi's
# device-memory pool at 90% — confirm against the Taichi version in use.
ti.init(arch=ti.gpu, device_memory_fraction=0.9)
# Exception message raised by check_status() on a stop request; run()
# compares against it to end quietly.
sim_stop_str = 'Simulation Stopped!'
# When True, run() prints exception messages and does not re-raise them.
DEBUG = False
def write_vtk_points_polydata_legacy_binary(
    file_path,
    points,
    vectors_dict=None,
    scalars_dict=None
):
    """Write a point cloud as a legacy-format binary VTK POLYDATA file.

    Each point is also emitted as a single-point VERTICES cell. All numeric
    payloads are stored big-endian, as the legacy VTK binary format
    requires, independent of the host byte order.

    Args:
        file_path: Destination path; an existing file is overwritten.
        points: Array-like of N points with three coordinates each.
        vectors_dict: Optional mapping ``name -> array`` of N 3-vectors,
            written as ``VECTORS`` point data.
        scalars_dict: Optional mapping ``name -> array`` of N values,
            written as single-component ``SCALARS`` point data.

    Raises:
        ValueError: If any scalar or vector array does not have N entries.
            Validation happens before the file is opened, so no partial
            file is left behind.
    """
    if vectors_dict is None:
        vectors_dict = {}
    if scalars_dict is None:
        scalars_dict = {}
    points = np.asarray(points)
    num_points = len(points)
    # Cast straight to big-endian dtypes: byteswap() is only correct on
    # little-endian hosts, whereas '>f4'/'>i4' are correct everywhere.
    points_be = points.astype(">f4").tobytes()

    # One single-point vertex cell per point: rows of (1, point_index).
    connectivity = np.empty((num_points, 2), dtype=">i4")
    connectivity[:, 0] = 1
    connectivity[:, 1] = np.arange(num_points)
    connectivity_be = connectivity.tobytes()

    # Validate and encode all point data up front so a length mismatch
    # raises before anything is written to disk.
    scalar_blocks = []
    for name, data in scalars_dict.items():
        data = np.asarray(data)
        if len(data) != num_points:
            raise ValueError(f"Scalar array length mismatch for '{name}'")
        scalar_blocks.append((name, data.astype(">f4").tobytes()))
    vector_blocks = []
    for name, data in vectors_dict.items():
        data = np.asarray(data)
        if len(data) != num_points:
            raise ValueError(f"Vector array length mismatch for '{name}'")
        vector_blocks.append((name, data.astype(">f4").tobytes()))

    with open(file_path, "wb") as f:
        # Write Header
        header = (
            "# vtk DataFile Version 3.0\n"
            "Binary point cloud data\n"
            "BINARY\n"
            "DATASET POLYDATA\n"
        )
        f.write(header.encode("utf-8"))
        # Write POINTS
        f.write(f"POINTS {num_points} float\n".encode("utf-8"))
        f.write(points_be)
        # Write VERTICES: N cells, 2 ints each (count + index)
        f.write(f"VERTICES {num_points} {2 * num_points}\n".encode("utf-8"))
        f.write(connectivity_be)
        # Write POINT_DATA
        f.write(f"POINT_DATA {num_points}\n".encode("utf-8"))
        # --- Write Scalars ---
        for name, payload in scalar_blocks:
            f.write(f"SCALARS {name} float 1\nLOOKUP_TABLE default\n".encode("utf-8"))
            f.write(payload)
        # --- Write Vectors ---
        for name, payload in vector_blocks:
            f.write(f"VECTORS {name} float\n".encode("utf-8"))
            f.write(payload)

def read_vtk_points_polydata_legacy_binary(file_path):
    """Read a legacy-format VTK POLYDATA point cloud (BINARY or ASCII).

    Supports the subset of the legacy format produced by this module's
    writer: POINTS, VERTICES (skipped), POINT_DATA, SCALARS and VECTORS.
    Binary payloads are decoded as big-endian ``float`` or ``double``.
    Blank lines between sections (such as the newline standard VTK writers
    emit after a binary block) are skipped rather than treated as EOF.

    Args:
        file_path: Path of the ``.vtk`` file to read.

    Returns:
        ``(points, scalars_dict, vectors_dict)``: ``points`` has shape
        ``(N, 3)``. ``scalars_dict`` maps names to ``(M,)`` arrays, or to
        ``(M, numComp)`` arrays when ``numComp > 1``. ``vectors_dict`` maps
        names to ``(M, 3)`` arrays, where M is the POINT_DATA count.

    Raises:
        ValueError: On a truncated file, an unsupported binary data type,
            POINT_DATA before POINTS, SCALARS/VECTORS before POINT_DATA,
            or if no POINTS section is present.
    """
    # Big-endian numpy dtypes for the legacy-VTK binary type names supported.
    binary_dtypes = {"float": ">f4", "double": ">f8"}
    points = None
    scalars_dict = {}
    vectors_dict = {}
    is_binary = False
    # Stays None until POINT_DATA is seen, so misplaced SCALARS/VECTORS
    # produce a clear ValueError instead of an UnboundLocalError.
    num_points_in_point_data = None

    def read_ascii_line(f):
        # Next line decoded and stripped; '' for a blank line, None at EOF.
        line = f.readline()
        if not line:
            return None
        return line.decode('utf-8', errors='ignore').strip()

    def read_binary_data(f, count, data_type):
        # Read `count` big-endian values of legacy-VTK type `data_type`.
        dtype = binary_dtypes.get(data_type)
        if dtype is None:
            raise ValueError(f"Unsupported binary data type '{data_type}'.")
        num_bytes = count * np.dtype(dtype).itemsize
        raw_data = f.read(num_bytes)
        if len(raw_data) != num_bytes:
            raise ValueError("Unexpected EOF while reading binary data.")
        return np.frombuffer(raw_data, dtype=dtype)

    def read_ascii_values(f, count):
        # ASCII data may wrap values across lines arbitrarily, so collect
        # whitespace-separated tokens until `count` values are gathered.
        values = []
        while len(values) < count:
            line = read_ascii_line(f)
            if line is None:
                raise ValueError("Unexpected EOF while reading ASCII data.")
            values.extend(line.split())
        return np.array([float(v) for v in values[:count]], dtype=np.float32)

    def data_type_of(tokens, index):
        # Declared data type token, defaulting to "float" when omitted.
        return tokens[index].lower() if len(tokens) > index else "float"

    def read_values(f, count, data_type):
        if is_binary:
            return read_binary_data(f, count, data_type)
        return read_ascii_values(f, count)

    with open(file_path, 'rb') as f:
        # Skip the version/title lines up to the ASCII/BINARY keyword.
        while True:
            line = read_ascii_line(f)
            if line is None:
                raise ValueError("Unexpected end of file while reading header.")
            if line.startswith("ASCII"):
                is_binary = False
                break
            if line.startswith("BINARY"):
                is_binary = True
                break
        # Walk the dataset sections keyword by keyword.
        while True:
            line = read_ascii_line(f)
            if line is None:
                break
            tokens = line.split()
            if not tokens:
                # Blank line between sections: not end of file.
                continue
            keyword = tokens[0].upper()
            if keyword == "POINTS":
                num_points = int(tokens[1])
                points = read_values(
                    f, num_points * 3, data_type_of(tokens, 2)
                ).reshape((num_points, 3))
            elif keyword == "VERTICES":
                # Cell connectivity is not needed for a point cloud; skip it.
                size_vertices = int(tokens[2])
                if is_binary:
                    f.read(size_vertices * 4)
                else:
                    read_ascii_values(f, size_vertices)
            elif keyword == "POINT_DATA":
                if points is None:
                    raise ValueError("POINT_DATA found before POINTS.")
                num_points_in_point_data = int(tokens[1])
            elif keyword == "SCALARS":
                if num_points_in_point_data is None:
                    raise ValueError("SCALARS found before POINT_DATA.")
                scalar_name = tokens[1]
                num_components = int(tokens[3]) if len(tokens) > 3 else 1
                read_ascii_line(f)  # Skip LOOKUP_TABLE line
                scalars = read_values(
                    f, num_points_in_point_data * num_components, data_type_of(tokens, 2)
                )
                if num_components > 1:
                    scalars = scalars.reshape((num_points_in_point_data, num_components))
                scalars_dict[scalar_name] = scalars
            elif keyword == "VECTORS":
                if num_points_in_point_data is None:
                    raise ValueError("VECTORS found before POINT_DATA.")
                vector_name = tokens[1]
                vectors = read_values(
                    f, num_points_in_point_data * 3, data_type_of(tokens, 2)
                )
                vectors_dict[vector_name] = vectors.reshape((num_points_in_point_data, 3))
    if points is None:
        raise ValueError(f"No POINTS found in '{file_path}'. Is it a valid VTK file?")
    return points, scalars_dict, vectors_dict

def save_liquid_particle(
    np_all_positions_i,
    np_all_velocities_i,
    np_all_curvatures_i,
    export_velocity_i,
    blender_file_name_i,
    cnt_vtk_i,
    file_prefix="liquid_particles",
):
    """Write one frame of liquid particles as a binary legacy VTK file.

    The file is ``<blender_file_name_i>_output/<file_prefix>_<cnt>.vtk``,
    with the frame counter zero-padded to six digits.

    Args:
        np_all_positions_i: Array of particle positions (one row per particle).
        np_all_velocities_i: Per-particle velocity vectors. Written as the
            ``velocity`` field only if ``export_velocity_i`` is truthy and
            the array is non-empty.
        np_all_curvatures_i: Per-particle curvature values. Written as the
            ``curvature`` field when non-empty.
        export_velocity_i: Whether velocities should be exported.
        blender_file_name_i: Base path of the Blender scene.
        cnt_vtk_i: Output frame counter used in the file name.
        file_prefix: File-name stem. Defaults to ``liquid_particles`` so
            existing callers produce the same files as before.
    """
    file_path = f"{blender_file_name_i}_output/{file_prefix}_{cnt_vtk_i:06}.vtk"
    points = np_all_positions_i
    vectors = {}
    if export_velocity_i and len(np_all_velocities_i) > 0:
        vectors["velocity"] = np_all_velocities_i
    scalars = {}
    if len(np_all_curvatures_i) > 0:
        scalars["curvature"] = np_all_curvatures_i
    write_vtk_points_polydata_legacy_binary(
        file_path=file_path,
        points=points,
        vectors_dict=vectors,
        scalars_dict=scalars
    )

def save_inlet_liquid_particle(
    np_all_positions_i,
    np_all_velocities_i,
    np_all_curvatures_i,
    export_velocity_i,
    blender_file_name_i,
    cnt_vtk_i,
):
    """Write one frame of liquid-inlet particles as a binary legacy VTK file.

    Called by run() for inlet bodies. Output goes to
    ``<blender_file_name_i>_output/liquid_inlet_particles_<cnt>.vtk``, kept
    separate so it does not overwrite that frame's ``liquid_particles`` file.

    NOTE(review): confirm the Blender importer expects this file name.
    """
    save_liquid_particle(
        np_all_positions_i,
        np_all_velocities_i,
        np_all_curvatures_i,
        export_velocity_i,
        blender_file_name_i,
        cnt_vtk_i,
        file_prefix="liquid_inlet_particles",
    )

def save_grain_particle(
    np_all_positions_i,
    np_all_velocities_i,
    np_all_radius_i,
    np_all_curvatures_i,
    np_all_density_i,
    export_velocity_i,
    blender_file_name_i,
    cnt_vtk_i,
):
    """Export one frame of granular particles to a binary legacy VTK file.

    Writes ``<blender_file_name_i>_output/grain_particles_<cnt_vtk_i>.vtk``,
    with the counter zero-padded to six digits.

    - The ``velocity`` vector field is included only when
      ``export_velocity_i`` is truthy and velocities were supplied.
    - The ``radius``, ``curvature`` and ``density`` scalar fields are
      included, in that order, whenever their arrays are non-empty.
    """
    target = f"{blender_file_name_i}_output/grain_particles_{cnt_vtk_i:06}.vtk"
    want_velocity = export_velocity_i and len(np_all_velocities_i) > 0
    vector_fields = {"velocity": np_all_velocities_i} if want_velocity else {}
    scalar_candidates = (
        ("radius", np_all_radius_i),
        ("curvature", np_all_curvatures_i),
        ("density", np_all_density_i),
    )
    scalar_fields = {
        label: values for label, values in scalar_candidates if len(values) > 0
    }
    write_vtk_points_polydata_legacy_binary(
        file_path=target,
        points=np_all_positions_i,
        vectors_dict=vector_fields,
        scalars_dict=scalar_fields,
    )

def save_rigid_body_particle(
    rigid_body_data_i,
    export_pressure_force_i,
    export_viscous_force_i,
    blender_file_name_i,
    r_body_id_i,
    cnt_vtk_i,
):
    """Export one rigid body's particles for one frame as a binary VTK file.

    Writes
    ``<blender_file_name_i>_output/rigid_body_particle_obj<id>_<cnt>.vtk``,
    with the counter zero-padded to six digits. Positions come from
    ``rigid_body_data_i["position"]``. The ``pressure_force`` and
    ``viscous_force`` arrays are attached as vector fields only when the
    matching export flag is truthy. No scalar fields are written.
    """
    target = f"{blender_file_name_i}_output/rigid_body_particle_obj{r_body_id_i}_{cnt_vtk_i:06}.vtk"
    positions = rigid_body_data_i["position"]
    force_fields = {}
    for enabled, field_name in (
        (export_pressure_force_i, "pressure_force"),
        (export_viscous_force_i, "viscous_force"),
    ):
        if enabled:
            force_fields[field_name] = rigid_body_data_i[field_name]
    write_vtk_points_polydata_legacy_binary(
        file_path=target,
        points=positions,
        vectors_dict=force_fields,
        scalars_dict={},
    )

#-----------------------------------------------------------------------------------------------#
def update_status(dir, tmp_dict, initialized=False):
    """Serialize ``tmp_dict`` as JSON to ``<dir>/status.json``.

    When ``initialized`` is true, ``tmp_dict['initialized']`` is first set
    to ``True``; this mutates the caller's dict. Any existing status file
    is overwritten.
    """
    if initialized:
        tmp_dict.update(initialized=True)
    status_path = os.path.join(dir, 'status.json')
    with open(status_path, 'w') as status_file:
        json.dump(tmp_dict, status_file)
def _consume_signal_file(directory, file_name):
    """Delete ``directory/file_name`` if it exists; return whether it did.

    Control files act as one-shot requests, so detecting one also consumes
    it. The removal is attempted directly (EAFP). This avoids the race in an
    exists()-then-remove() check, where a file vanishing in between would
    raise ``FileNotFoundError``.
    """
    try:
        os.remove(os.path.join(directory, file_name))
    except FileNotFoundError:
        return False
    return True

def check_pause(dir):
    """Consume a pending ``pause.json`` request; return True if one existed."""
    return _consume_signal_file(dir, 'pause.json')

def check_continue(dir):
    """Consume a pending ``continue.json`` request; return True if one existed."""
    return _consume_signal_file(dir, 'continue.json')

def check_stop(dir):
    """Consume a pending ``stop.json`` request; return True if one existed."""
    return _consume_signal_file(dir, 'stop.json')
def write_finish(dir):
    """Create ``<dir>/finish.txt`` with the text ``Finished!``.

    An existing file is overwritten.
    """
    finish_path = os.path.join(dir, 'finish.txt')
    with open(finish_path, 'w') as marker:
        marker.write('Finished!')
def check_status(dir):
    """Block while a pause request is in effect.

    If a pause request is pending, poll once per second until a continue
    request arrives. Stop requests are only checked while paused. A stop
    raises ``Exception(sim_stop_str)``, which run() returns on quietly when
    DEBUG is off.
    """
    if not check_pause(dir):
        return
    while not check_continue(dir):
        time.sleep(1)
        if check_stop(dir):
            raise Exception(sim_stop_str)
def run(scene_path, keyframed_folder_path=None):
    """Run a complete simulation for one exported Blender scene.

    Builds the configuration, particle system and solver from
    ``scene_path``, then advances the simulation until the output frame
    number passes ``end_frame``. At every output interval it writes
    rigid-body motion data (JSON), optional VTK particle files and
    ``max_curvature.txt``. After every outer iteration it refreshes
    ``status.json`` and ``simulation_progress.txt``. Pause/continue/stop
    requests are polled through ``check_status``.

    Args:
        scene_path: Path of the scene file. With its extension dropped and
            ``_initial_condition`` removed, it gives the base path of the
            ``<base>_output`` folder.
        keyframed_folder_path: Not used by this function.

    Side effects:
        Deletes and recreates ``<base>_output``. Writes ``status.json`` and
        ``Initialization_data_from_compute.txt`` into the scene's directory,
        plus ``finish.txt`` on completion or failure and ``error.txt`` on
        failure.

    Raises:
        Any exception from setup or stepping is re-raised after
        ``error.txt`` and ``finish.txt`` are written, unless ``DEBUG`` is
        true. A user stop request (``sim_stop_str``) returns silently,
        without writing ``finish.txt``, when ``DEBUG`` is false.
    """
    # Base path: scene_path minus its extension, with '_initial_condition' removed.
    json_file_name = ".".join(scene_path.split(".")[:-1])
    blender_file_name = json_file_name.replace('_initial_condition', '')
    output_folder_name = f"{blender_file_name}_output"
    # Control/status files (status.json, pause.json, ...) live next to the scene file.
    cache_directory = os.path.dirname(scene_path)
    try:
        # Progress counters mirrored into status.json.
        status_dict = {
            'initialized': False,
            'cnt': 0,
            'cnt_vtk': 0,
            'total_frames': 0,
            'cnt_frame': 0,
        }
        config = SimConfig(scene_file_path=scene_path)
        # Start every run from an empty output folder.
        if os.path.exists(output_folder_name):
            shutil.rmtree(output_folder_name)
        os.makedirs(output_folder_name, exist_ok=True)
        # NOTE(review): when scene_path is relative, output_folder_name is
        # relative to the CWD and this join nests it under cache_directory a
        # second time, so motion_data may not land inside the folder created
        # above. With an absolute scene_path both resolve to the same folder.
        output_folder_path = os.path.join(cache_directory, output_folder_name)
        motion_data_directory = os.path.join(output_folder_path, 'motion_data')

        if os.path.exists(motion_data_directory):
            shutil.rmtree(motion_data_directory)
        os.makedirs(motion_data_directory, exist_ok=True)
        substeps = config.get_domain("timestep_per_frame")
        print(f"substeps: {substeps}")
        export_velocity = config.get_domain("export_fluid_velocity")
        export_pressure_force = config.get_domain("export_rigid_pressure_force")
        export_viscous_force = config.get_domain("export_rigid_viscous_force")
        export_vtk = config.get_domain("exportVtk")
        if export_vtk:
            # Same path as output_folder_name, already created above.
            os.makedirs(f"{blender_file_name}_output", exist_ok=True)
        check_status(cache_directory)
        ps = ParticleSystem(config)
        solver = FlowProperties(ps)
        ps.manage_particle_system()
        solver.initialize()
        # Plain-text summary of the initialized particle system.
        initialization_data_file_path=os.path.join(cache_directory, 'Initialization_data_from_compute.txt')
        with open(initialization_data_file_path, 'w') as f:
            f.write(f"No. of grids: {ps.grid_num}\n")
            f.write(f"Particle radius: {ps.particle_radius}\n")
            if ps.fluid_particle_num > 0:
                f.write(f"Fluid particles: {ps.fluid_particle_num}\n")
            if ps.rigid_particle_num > 0:
                f.write(f"Solid objects particles: {ps.rigid_particle_num}\n")
            if ps.grain_particle_num > 0:
                f.write(f"Grain particles: {ps.grain_particle_num}\n")
            f.write(f"Total no. particles within Domain: {ps.total_particle_num}\n")
        # Frame duration divided by solver.dt, truncated to an int; it is
        # compared below against cnt, which advances once per outer iteration
        # (i.e. once per `substeps` solver.step calls).
        # NOTE(review): confirm the frame timing accounts for substeps, and
        # note this truncates to 0 if dt > 1/fps, making the divisions by
        # output_interval below raise ZeroDivisionError.
        output_interval = int((1 / config.get_domain("fps")) / solver.dt[None])
        status_dict['cnt'] = config.get_domain("start_frame")
        status_dict['cnt_vtk'] = 0
        status_dict['total_frames'] = config.get_domain("end_frame") - config.get_domain("start_frame") + 1
        status_dict['cnt_frame'] = int(status_dict['cnt']/output_interval)
        update_status(cache_directory, status_dict, initialized=True)
        while True:
                check_status(cache_directory)
                # Advance the solver; pause/stop requests are also honoured between substeps.
                for i in range(substeps):
                    check_status(cache_directory)
                    solver.step(status_dict['cnt_frame'], status_dict['cnt'], output_interval)
                # Export when cnt == 1 or cnt is a multiple of output_interval.
                if status_dict['cnt']==1 or status_dict['cnt'] % output_interval == 0:
                    latest_frame = int(status_dict['cnt'] / output_interval)
                    print(f"\rLatest frame: {latest_frame}", end='', flush=True)
                    # Record center of mass and rotation matrix of each dynamic rigid body.
                    if ps.rigid_particle_num > 0:
                        for obj_id in ps.object_id_rigid_body:
                            if ps.object_collection[obj_id]["isDynamic"]:
                                cm = ps.centers_of_mass[obj_id].to_numpy()
                                R = ps.rotation_matrices[obj_id].to_numpy()
                                ps.motion_data[obj_id] = {'center_of_mass': cm.tolist(),'rotation_matrix': R.tolist()}
                        ps.export_motion_data()
                        json_filename = f"motion_data_{status_dict['cnt_vtk']:06}.json"
                        json_file_path = os.path.join(motion_data_directory, json_filename)
                        ps.write_motion_data_to_json(json_file_path)
                    if export_vtk:
                        # Concatenate liquid particles of all fluid bodies into one file.
                        if ps.fluid_particle_num > 0:
                            liquid_positions = []
                            liquid_velocities = []
                            liquid_curvatures = []
                            for obj_id in ps.object_id_fluid_body:
                                liquid_data = ps.dump_id(obj_id, phase="liquid")
                                if liquid_data:
                                    liquid_positions.extend(liquid_data["position"])
                                    if export_velocity:
                                        liquid_velocities.extend(liquid_data["velocity"])
                                    liquid_curvatures.extend(liquid_data["curvature"])
                            if liquid_positions:
                                save_liquid_particle(
                                    np.array(liquid_positions),
                                    np.array(liquid_velocities),
                                    np.array(liquid_curvatures),
                                    export_velocity,
                                    blender_file_name,
                                    status_dict['cnt_vtk']
                                )

                        # Concatenate grain particles of all grain bodies into one file.
                        # NOTE(review): grain_curvatures is never filled, so no
                        # curvature field is written for grains.
                        if ps.grain_particle_num > 0:
                            grain_positions = []
                            grain_velocities = []
                            grain_radius = []
                            grain_curvatures = []
                            grain_density = []
                            for obj_id in ps.object_id_grain_body:
                                grain_data = ps.dump_id(obj_id, phase="grain")
                                if grain_data:
                                    grain_positions.extend(grain_data["position"])
                                    if export_velocity:
                                        grain_velocities.extend(grain_data["velocity"])
                                    grain_radius.extend(grain_data["radius"])
                                    grain_density.extend(grain_data["density"])
                            if grain_positions:
                                save_grain_particle(
                                    np.array(grain_positions),
                                    np.array(grain_velocities),
                                    np.array(grain_radius),
                                    np.array(grain_curvatures),
                                    np.array(grain_density),
                                    export_velocity,
                                    blender_file_name,
                                    status_dict['cnt_vtk']
                                )
                        # --- Liquid inlet particles ---
                        # NOTE(review): requires a module-level save_inlet_liquid_particle
                        # helper; confirm it is defined, otherwise this branch raises NameError.
                        if ps.inlet_particle_num > 0:
                            liquid_inlet_positions = []
                            liquid_inlet_velocities = []
                            liquid_inlet_curvatures = []
                            for obj_id in ps.object_id_inlet_body:
                                liquid_inlet_data = ps.dump_id(obj_id, phase="liquid")  # Dump only liquid inlet
                                if liquid_inlet_data:
                                    liquid_inlet_positions.extend(liquid_inlet_data["position"])
                                    if export_velocity:
                                        liquid_inlet_velocities.extend(liquid_inlet_data["velocity"])
                                    liquid_inlet_curvatures.extend(liquid_inlet_data["curvature"])
                            if liquid_inlet_positions:
                                save_inlet_liquid_particle(
                                    np.array(liquid_inlet_positions),
                                    np.array(liquid_inlet_velocities),
                                    np.array(liquid_inlet_curvatures),
                                    export_velocity,
                                    blender_file_name,
                                    status_dict['cnt_vtk']
                                )


                        # One VTK file per rigid body.
                        if ps.rigid_particle_num > 0:
                            for r_body_id in ps.object_id_rigid_body:
                                rigid_body_data = ps.dump_id(r_body_id)
                                save_rigid_body_particle(rigid_body_data, export_pressure_force, export_viscous_force, blender_file_name, r_body_id, status_dict['cnt_vtk'])
                    # Written on every output frame, even when export_vtk is off.
                    max_curvature_value = ps.max_curvature[None]
                    max_curvature_file_path = os.path.join(output_folder_name, "max_curvature.txt")
                    with open(max_curvature_file_path, "w") as file:
                        file.write(str(max_curvature_value))
                    status_dict['cnt_vtk']+=1
                status_dict['cnt'] += 1
                status_dict['cnt_frame'] = int(status_dict['cnt'] / output_interval)
                update_status(cache_directory, status_dict)
                # Stop once the derived frame number passes end_frame.
                if int(status_dict['cnt']/output_interval) > config.get_domain("end_frame"):
                    break
                progress_file_path = os.path.join(output_folder_name, "simulation_progress.txt")
                with open(progress_file_path, "w") as progress_file:
                    progress_file.write(f"{int(status_dict['cnt']/output_interval)}\n")
    except Exception as e:
        exception_message = str(e)
        if DEBUG:
            print(exception_message)
        else:
            # A user stop request ends the run quietly (no error.txt / finish.txt).
            if exception_message == sim_stop_str:
                return
        # Heuristically report errors mentioning CUDA/Taichi/array shapes as divergence.
        keywords = ['CUDA', 'dimensions', 'Taichi', 'taichi', 'array']
        if any(word in exception_message for word in keywords):
            error_str = 'Simulation divergence!'
        elif 'max_iter exceeded!' in exception_message:
            error_str = ('The geometry surface is too simple for simulation. '
                        'Please consider subdividing or remeshing relevant objects and run the simulation again')
        else:
            error_str = exception_message
        error_path = os.path.join(cache_directory, 'error.txt')
        with open(error_path, 'w') as f:
            f.write(error_str)
        write_finish(cache_directory)
        if not DEBUG:
            raise e
    write_finish(cache_directory)
if __name__ == "__main__":
    # Command-line entry point: `python <script> --scene_file <path>`.
    cli = argparse.ArgumentParser(description='Doriflow')
    cli.add_argument('--scene_file', default='', help='scene file')
    cli_args = cli.parse_args()
    run(cli_args.scene_file)

