# Doriflow Engine - Fluid Simulation for Blender 3D
# Copyright (C) 2024 Doriflow Team
# This software is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

# You are free to:
# -Share: Copy and redistribute the material in any medium or format.
# -Adapt: Remix, transform, and build upon the material.

# UNDER THE FOLLOWING TERMS:
# -Attribution: You must give appropriate credit, provide a link to the license, and indicate if changes were made.
# -Appropriate credit should include the following:
#   -The original author's name: Doriflow Team
#   -A link to the original source (if applicable).
#   -A link to the full license: https://creativecommons.org/licenses/by-nc/4.0/.
#   -A clear indication of any changes made, such as: "This material has been modified."
# -NonCommercial: You may not use the material for commercial purposes.
# Disclaimer:
# -This simulation engine is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. In no event shall the authors be liable for any claim, damages, or other liability arising from the use of this software.

# For more details, refer to the full license text at:
# https://creativecommons.org/licenses/by-nc/4.0/.

#-----------------------------------------------------------------------------------------------------------------------#

import bpy
import numpy as np
from .voxelization_geometry_node_operators import calculate_grid_properties

class DORIFLOW_OT_MeshGeometryNodes(bpy.types.Operator):
    """Build a Geometry Nodes mesher for every Doriflow particle object present.

    For each entry of ``PARTICLE_TYPES`` whose particle object exists, a shared
    node group ``DF.<type>_to_mesh`` turns the particle points into a volume
    (types whose name contains "Gas") or a smooth-shaded surface mesh (all other
    types). A fresh unit cube named after the type's mesh object then receives a
    Geometry Nodes modifier that uses that group.
    """

    bl_idname = "doriflow.mesh_geometry_nodes"
    bl_label = "Generate Mesh via Geometry Nodes"
    bl_description = "Generate a mesh or volume representation for each particle type using Geometry Nodes."

    # particle type -> (particle object name, generated mesh object name, material name)
    PARTICLE_TYPES = {
        "Liquid": ("DF.Particles.Liquid", "DF.Mesh.Liquid", "DF.Liquid_Material"),
        "Gas": ("DF.Particles.Gas", "DF.Mesh.Gas", "DF.Gas_Material"),
        "Inlet_Liquid": ("DF.Particles.Inlet_Liquid", "DF.Mesh.Inlet_Liquid", "DF.Inlet_Liquid_Material"),
        "Inlet_Gas": ("DF.Particles.Inlet_Gas", "DF.Mesh.Inlet_Gas", "DF.Inlet_Gas_Material"),
    }

    # 'Voxel Amount' used by both Points to Volume and Volume to Mesh.
    VOXEL_AMOUNT = 175
    # Points to Volume radius as a multiple of the particle radius.
    RADIUS_SCALE = 3
    # Volume to Mesh 'Threshold' (only used for non-gas types).
    MESH_THRESHOLD = 0.25

    def get_or_create_shared_node_group(self, particle_name, particle_radius):
        """Return node group ``DF.<particle_name>_to_mesh``, creating it on first use.

        A new group is fully built by ``configure_nodes``. An existing group is
        reused: its Object Info node is re-pointed at the particle object, and its
        Points to Volume radius is recomputed from ``particle_radius``. This way a
        change of domain resolution is picked up when the operator is re-run.

        Args:
            particle_name: key into ``PARTICLE_TYPES``.
            particle_radius: radius returned by ``calculate_grid_properties``.
        """
        node_group_name = f"DF.{particle_name}_to_mesh"
        node_group = bpy.data.node_groups.get(node_group_name)
        if node_group is None:
            node_group = bpy.data.node_groups.new(name=node_group_name, type='GeometryNodeTree')
            self.configure_nodes(node_group, particle_name, particle_radius)
        else:
            self.update_object_info_nodes(node_group, particle_name)
            self.update_points_to_volume_radius(node_group, particle_radius)
        return node_group

    def update_object_info_nodes(self, node_group, particle_name):
        """Point the first Object Info node of ``node_group`` at the particle object.

        Reports a warning (and changes nothing) when either the particle object
        or the Object Info node is missing.
        """
        object_name = self.PARTICLE_TYPES[particle_name][0]
        object_ref = bpy.data.objects.get(object_name)
        if object_ref is None:
            self.report({'WARNING'}, f"Object '{object_name}' not found in the scene.")
            return

        object_info_node = next(
            (node for node in node_group.nodes if node.type == 'OBJECT_INFO'),
            None,
        )
        if object_info_node is None:
            self.report({'WARNING'}, f"Object Info node not found in {node_group.name}")
            return
        object_info_node.inputs['Object'].default_value = object_ref

    def update_points_to_volume_radius(self, node_group, particle_radius):
        """Set the Radius of every Points to Volume node in ``node_group``.

        The radius is ``particle_radius * RADIUS_SCALE``, the same value
        ``configure_nodes`` uses. Reports a warning when the group has no such node.
        """
        volume_nodes = [
            node for node in node_group.nodes
            if node.bl_idname == 'GeometryNodePointsToVolume'
        ]
        if not volume_nodes:
            self.report({'WARNING'}, f"Points to Volume node not found in {node_group.name}")
            return
        for node in volume_nodes:
            node.inputs['Radius'].default_value = particle_radius * self.RADIUS_SCALE

    def configure_nodes(self, node_group, particle_name, particle_radius):
        """Build the node graph of a newly created node group.

        Gas types:   Object Info -> Points to Volume -> Set Material -> Group Output
        Other types: Object Info -> Points to Volume -> Volume to Mesh
                     -> Set Shade Smooth -> Set Material -> Group Output

        Args:
            node_group: the GeometryNodeTree to fill (existing nodes are cleared).
            particle_name: key into ``PARTICLE_TYPES``; a name containing "Gas"
                selects the volume-only graph.
            particle_radius: particle radius; the Points to Volume radius is
                ``particle_radius * RADIUS_SCALE``.
        """
        nodes = node_group.nodes
        nodes.clear()

        object_info = nodes.new('GeometryNodeObjectInfo')
        object_info.inputs['Object'].default_value = bpy.data.objects.get(self.PARTICLE_TYPES[particle_name][0])

        point_to_volume = nodes.new('GeometryNodePointsToVolume')
        point_to_volume.inputs['Voxel Amount'].default_value = self.VOXEL_AMOUNT
        point_to_volume.inputs['Radius'].default_value = particle_radius * self.RADIUS_SCALE

        # Each entry: (node, input socket fed by the previous node,
        #              output socket that feeds the next node).
        chain = [
            (object_info, None, 'Geometry'),
            (point_to_volume, 'Points', 'Volume'),
        ]

        if "Gas" not in particle_name:
            # Non-gas types are meshed: extract a surface from the volume, then smooth-shade it.
            volume_to_mesh = nodes.new('GeometryNodeVolumeToMesh')
            volume_to_mesh.resolution_mode = 'VOXEL_AMOUNT'
            volume_to_mesh.inputs['Voxel Amount'].default_value = self.VOXEL_AMOUNT
            volume_to_mesh.inputs['Threshold'].default_value = self.MESH_THRESHOLD
            shade_smooth = nodes.new('GeometryNodeSetShadeSmooth')
            chain += [
                (volume_to_mesh, 'Volume', 'Mesh'),
                (shade_smooth, 'Geometry', 'Geometry'),
            ]

        set_material = nodes.new('GeometryNodeSetMaterial')
        group_output = nodes.new(type='NodeGroupOutput')
        # Declare the group's Geometry output before linking into the Group Output node.
        node_group.interface.new_socket(name="Geometry", in_out='OUTPUT', socket_type='NodeSocketGeometry')
        chain += [
            (set_material, 'Geometry', 'Geometry'),
            (group_output, 'Geometry', None),
        ]

        # Lay the nodes out left to right, 200 units apart, centred on x = 0.
        start_x = -100 * (len(chain) - 1)
        for index, (node, _, _) in enumerate(chain):
            node.location = (start_x + 200 * index, 0)

        # Link each node's output to the next node's input.
        for (upstream, _, output_name), (downstream, input_name, _) in zip(chain, chain[1:]):
            node_group.links.new(upstream.outputs[output_name], downstream.inputs[input_name])

        material_name = self.PARTICLE_TYPES[particle_name][2]
        set_material.inputs['Material'].default_value = self.get_material(material_name, particle_name)

    def get_material(self, material_name, particle_name):
        """Return material ``material_name``, creating it on first use.

        A new material gets a Principled Volume shader when ``particle_name``
        contains "Gas". Otherwise it gets a clear, fully transmissive Principled
        BSDF. An existing material is returned untouched.
        """
        material = bpy.data.materials.get(material_name)
        if material is not None:
            return material

        material = bpy.data.materials.new(name=material_name)
        material.use_nodes = True
        nodes = material.node_tree.nodes
        nodes.clear()
        links = material.node_tree.links

        if "Gas" in particle_name:
            volume_node = nodes.new(type='ShaderNodeVolumePrincipled')
            material_output = nodes.new(type='ShaderNodeOutputMaterial')
            links.new(volume_node.outputs['Volume'], material_output.inputs['Volume'])
            volume_node.inputs['Color'].default_value = (0.8, 0.8, 0.8, 1.0)
            volume_node.inputs['Density'].default_value = 1
            volume_node.inputs['Anisotropy'].default_value = 0.0
        else:
            bsdf_node = nodes.new(type='ShaderNodeBsdfPrincipled')
            material_output = nodes.new(type='ShaderNodeOutputMaterial')
            links.new(bsdf_node.outputs['BSDF'], material_output.inputs['Surface'])
            # Light blue (#8AD1FF) base colour.
            color = (0x8A / 0xFF, 0xD1 / 0xFF, 0xFF / 0xFF, 0.8)
            bsdf_node.inputs['Base Color'].default_value = color
            bsdf_node.inputs['Alpha'].default_value = 1.0
            bsdf_node.inputs['Transmission Weight'].default_value = 1
            bsdf_node.inputs['Roughness'].default_value = 0.0
        return material

    def setup_geometry_nodes(self, mesh, shared_node_group, particle_name):
        """Attach ``shared_node_group`` to ``mesh`` through a named Geometry Nodes modifier.

        The modifier ``DF.<particle_name>_to_mesh`` is reused if present,
        otherwise it is created.
        """
        modifier_name = f"DF.{particle_name}_to_mesh"
        geo_nodes_modifier = next(
            (
                modifier for modifier in mesh.modifiers
                if modifier.type == 'NODES' and modifier.name == modifier_name
            ),
            None,
        )
        if geo_nodes_modifier is None:
            geo_nodes_modifier = mesh.modifiers.new(name=modifier_name, type='NODES')
        geo_nodes_modifier.node_group = shared_node_group

    @staticmethod
    def _find_domain_objects():
        """Return every object whose Doriflow type is 'TYPE_DOMAIN', in bpy.data order."""
        return [
            obj for obj in bpy.data.objects
            if hasattr(obj, "doriflow") and obj.doriflow.object_type == 'TYPE_DOMAIN'
        ]

    @staticmethod
    def _remove_object_and_orphan_mesh(object_name):
        """Delete object ``object_name`` if it exists, plus its mesh data once unused.

        Removing only the object leaves its mesh datablock behind as an orphan.
        The mesh is removed too when no other user remains.
        """
        obj = bpy.data.objects.get(object_name)
        if obj is None:
            return
        mesh_data = obj.data
        bpy.data.objects.remove(obj, do_unlink=True)
        if isinstance(mesh_data, bpy.types.Mesh) and mesh_data.users == 0:
            bpy.data.meshes.remove(mesh_data)

    def execute(self, context):
        """Create (or recreate) the mesh object and node setup for each present particle type.

        The domain resolution and largest domain dimension feed
        ``calculate_grid_properties``; defaults of 128 and 1.0 are used when no
        domain exists. When several domains exist, the last one in
        ``bpy.data.objects`` is used. If at least one mesh object was created,
        every domain is selected afterwards and the last one is made active.
        """
        domain_resolution = 128
        largest_domain_size = 1.0
        domains = self._find_domain_objects()
        if domains:
            domain = domains[-1]
            domain_resolution = domain.doriflow.domain.resolution
            largest_domain_size = max(domain.dimensions)
        _grid_size, _particle_diameter, _grid_num, particle_radius = calculate_grid_properties(
            largest_domain_size, domain_resolution
        )

        created_any = False
        for particle_name, (particle_obj_name, mesh_obj_name, _material_name) in self.PARTICLE_TYPES.items():
            if bpy.data.objects.get(particle_obj_name) is None:
                continue
            shared_node_group = self.get_or_create_shared_node_group(particle_name, particle_radius)
            self._remove_object_and_orphan_mesh(mesh_obj_name)

            bpy.ops.object.select_all(action='DESELECT')
            bpy.ops.mesh.primitive_cube_add(size=1)
            new_mesh_obj = bpy.context.active_object
            new_mesh_obj.name = mesh_obj_name
            self.setup_geometry_nodes(new_mesh_obj, shared_node_group, particle_name)
            created_any = True

        if created_any:
            # Re-scan after object creation/removal so no stale references are used.
            for domain_object in self._find_domain_objects():
                bpy.context.view_layer.objects.active = domain_object
                domain_object.select_set(True)
        return {'FINISHED'}

def register():
    """Register this module's operator class(es) with Blender."""
    for operator_class in (DORIFLOW_OT_MeshGeometryNodes,):
        bpy.utils.register_class(operator_class)

def unregister():
    """Unregister this module's operator class(es) from Blender."""
    for operator_class in (DORIFLOW_OT_MeshGeometryNodes,):
        bpy.utils.unregister_class(operator_class)

if __name__ == "__main__":
    # Register the operator when this file is executed directly rather than imported.
    # NOTE(review): the package-relative import of calculate_grid_properties at the top
    # of the file raises ImportError when the module runs as __main__ (no parent
    # package), so this branch is likely unreachable in practice — confirm.
    register()