# Doriflow Engine - Fluid Simulation for Blender 3D
# Copyright (C) 2024 Doriflow Team
# This software is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

# You are free to:
# -Share: Copy and redistribute the material in any medium or format.
# -Adapt: Remix, transform, and build upon the material.

# UNDER THE FOLLOWING TERMS:
# -Attribution: You must give appropriate credit, provide a link to the license, and indicate if changes were made.
# -Appropriate credit should include the following:
#   -The original author's name: Doriflow Team
#   -A link to the original source (if applicable).
#   -A link to the full license: https://creativecommons.org/licenses/by-nc/4.0/.
#   -A clear indication of any changes made, such as: "This material has been modified."
# NonCommercial: You may not use the material for commercial purposes.
# Disclaimer:
# -This simulation engine is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. In no event shall the authors be liable for any claim, damages, or other liability arising from the use of this software.

# For more details, refer to the full license text at:
# https://creativecommons.org/licenses/by-nc/4.0/.

#-----------------------------------------------------------------------------------------------------------------------#
import bpy
import numpy as np
from .voxelization_geometry_node_operators import calculate_grid_properties

class DORIFLOW_OT_GrainGeometryNodesIcospheres(bpy.types.Operator):
    """Instance a shared icosphere on every grain particle point via Geometry Nodes.

    For each entry in ``PARTICLE_TYPES`` whose source particle object exists, a
    carrier mesh object is found (or created) and given a Geometry Nodes modifier
    that instances one shared, hidden icosphere on the particle object's points,
    scaled by the points' ``radius`` attribute.
    """

    bl_idname = "doriflow.grain_geometry_nodes"
    bl_label = "Render grain as Icospheres via Geometry Nodes"
    bl_description = "Render grain and inlet grain particles as icospheres using Geometry Nodes for better performance."

    # particle kind -> (source particle object name, carrier mesh object name)
    PARTICLE_TYPES = {
        "Grain": ("DF.Particles.Grain", "DF.Mesh.Grain"),
        "Inlet_Grain": ("DF.Particles.Inlet_Grain", "DF.Mesh.Inlet_Grain"),
    }

    @staticmethod
    def _is_domain(obj):
        """Return True if ``obj`` carries Doriflow settings marking it as the domain."""
        return hasattr(obj, 'doriflow') and obj.doriflow.object_type == 'TYPE_DOMAIN'

    @staticmethod
    def _get_or_create_white_material(mat_name):
        """Return material ``mat_name``, creating it with a white shader colour if missing.

        Args:
            mat_name: Name of the material datablock to fetch or create.

        Returns:
            The existing or newly created ``bpy.types.Material``.
        """
        mat = bpy.data.materials.get(mat_name)
        if mat is not None:
            return mat

        mat = bpy.data.materials.new(name=mat_name)
        mat.use_nodes = True
        # A freshly node-enabled material holds a "Principled BSDF" (not a
        # "Diffuse BSDF"), and the two shaders name their colour socket
        # differently, so try both (node, socket) pairs instead of assuming one.
        nodes = mat.node_tree.nodes
        for node_name, socket_name in (("Diffuse BSDF", "Color"),
                                       ("Principled BSDF", "Base Color")):
            node = nodes.get(node_name)
            if node is not None and socket_name in node.inputs:
                node.inputs[socket_name].default_value = (1.0, 1.0, 1.0, 1.0)
                break
        return mat

    def create_shared_icosphere(self, particle_radius):
        """Return the shared instance icosphere, creating it on first use.

        Args:
            particle_radius: Radius given to a newly created icosphere. It is
                ignored when ``DF.grain_icosphere_Instance`` already exists.

        Returns:
            The icosphere object. It is hidden in the viewport and in renders
            because it only serves as the instance source.
        """
        icosphere_name = "DF.grain_icosphere_Instance"
        existing = bpy.data.objects.get(icosphere_name)
        if existing is not None:
            # NOTE(review): a cached icosphere keeps the radius it was created
            # with, while the scale divisor in configure_nodes uses the current
            # particle_radius; the two can disagree after the domain changes.
            return existing

        bpy.ops.mesh.primitive_ico_sphere_add(radius=particle_radius, enter_editmode=False, location=(0, 0, 0))
        icosphere = bpy.context.active_object
        icosphere.name = icosphere_name
        icosphere.hide_render = True
        icosphere.hide_viewport = True

        mat = self._get_or_create_white_material("Diffuse_BSDF_Material")
        if icosphere.data.materials:
            icosphere.data.materials[0] = mat
        else:
            icosphere.data.materials.append(mat)

        return icosphere

    def get_or_create_shared_node_group(self, particle_name, particle_radius, radius_std):
        """Build a fresh Geometry Nodes group for ``particle_name``.

        Despite the name, any existing group ``DF.<particle_name>_Icosphere_Instances``
        is removed and rebuilt, so the nodes always reflect the current radius.

        Args:
            particle_name: Key into ``PARTICLE_TYPES``.
            particle_radius: Base particle radius used for the icosphere/scale.
            radius_std: Forwarded to ``configure_nodes`` (currently unused there).

        Returns:
            The newly created ``GeometryNodeTree``.
        """
        node_group_name = f"DF.{particle_name}_Icosphere_Instances"
        if node_group_name in bpy.data.node_groups:
            bpy.data.node_groups.remove(bpy.data.node_groups[node_group_name], do_unlink=True)
        node_group = bpy.data.node_groups.new(name=node_group_name, type='GeometryNodeTree')
        self.configure_nodes(node_group, particle_name, particle_radius, radius_std)
        return node_group

    def configure_nodes(self, node_group, particle_name, particle_radius, radius_std=None):
        """Populate ``node_group`` with icosphere instancing on the particle points.

        Graph: Object Info(particles) -> Instance on Points <- Object Info(icosphere),
        with Scale = named attribute "radius" / (2 * particle_radius), feeding the
        group's single "Geometry" output.

        Args:
            node_group: An empty (freshly created) Geometry Nodes tree. A new
                "Geometry" output socket is added on every call.
            particle_name: Key into ``PARTICLE_TYPES`` selecting the source object.
            particle_radius: Radius used to create the icosphere and as the
                scale divisor.
            radius_std: Accepted for interface compatibility; not used.
        """
        icosphere = self.create_shared_icosphere(particle_radius)
        particle_obj = bpy.data.objects.get(self.PARTICLE_TYPES[particle_name][0])

        node_group.nodes.clear()

        obj_info_particle = node_group.nodes.new("GeometryNodeObjectInfo")
        obj_info_particle.inputs['Object'].default_value = particle_obj
        obj_info_particle.location = (-600, 100)

        obj_info_ico = node_group.nodes.new("GeometryNodeObjectInfo")
        obj_info_ico.inputs['Object'].default_value = icosphere
        obj_info_ico.location = (-600, -100)

        # Per-point "radius" attribute read from the particle geometry.
        attr_node = node_group.nodes.new("GeometryNodeInputNamedAttribute")
        attr_node.data_type = 'FLOAT'
        attr_node.inputs['Name'].default_value = "radius"
        attr_node.location = (-400, 100)

        # With an icosphere of radius particle_radius, this scale yields an
        # instance radius of radius_attr / 2.
        # NOTE(review): the factor 2 suggests the stored attribute is a
        # diameter — TODO confirm with the exporter.
        math_divide = node_group.nodes.new("ShaderNodeMath")
        math_divide.operation = 'DIVIDE'
        math_divide.inputs[1].default_value = particle_radius * 2
        math_divide.location = (-200, 100)

        instance_on_points = node_group.nodes.new("GeometryNodeInstanceOnPoints")
        instance_on_points.location = (0, 0)

        group_output = node_group.nodes.new("NodeGroupOutput")
        group_output.location = (300, 0)

        node_group.interface.new_socket(
            name="Geometry", description="Output geometry", in_out='OUTPUT', socket_type='NodeSocketGeometry'
        )

        links = node_group.links
        links.new(obj_info_particle.outputs['Geometry'], instance_on_points.inputs['Points'])
        links.new(obj_info_ico.outputs['Geometry'], instance_on_points.inputs['Instance'])
        links.new(attr_node.outputs['Attribute'], math_divide.inputs[0])
        links.new(math_divide.outputs['Value'], instance_on_points.inputs['Scale'])
        links.new(instance_on_points.outputs['Instances'], group_output.inputs['Geometry'])

    def setup_geometry_nodes(self, mesh_obj, node_group, particle_name):
        """Attach ``node_group`` to ``mesh_obj`` via a named Geometry Nodes modifier.

        The modifier ``DF.<particle_name>_Instances`` is reused if present,
        otherwise created.
        """
        modifier_name = f"DF.{particle_name}_Instances"
        modifier = mesh_obj.modifiers.get(modifier_name)
        if not modifier:
            modifier = mesh_obj.modifiers.new(name=modifier_name, type='NODES')
        modifier.node_group = node_group

    def execute(self, context):
        """Set up icosphere instancing for every grain particle object that exists.

        Reads resolution, ``radius_std`` and size from the domain object (falling
        back to defaults when there is none), derives the particle radius, wires a
        Geometry Nodes modifier per particle type, then re-selects the domain.
        """
        domain_resolution = 128
        largest_domain_size = 1.0
        # Default for scenes without a domain object; previously this name was
        # unbound in that case and the node-group call raised UnboundLocalError.
        radius_std = 0.0
        for obj in bpy.data.objects:
            if self._is_domain(obj):
                # NOTE(review): with several domains, the last one visited wins.
                domain_resolution = obj.doriflow.domain.resolution
                radius_std = obj.doriflow.domain.radius_std
                largest_domain_size = max(np.array(obj.dimensions))

        _, _, _, particle_radius = calculate_grid_properties(largest_domain_size, domain_resolution)
        # Ensure the shared instance source exists even if no particle object does.
        self.create_shared_icosphere(particle_radius)

        for particle_name, (source_obj_name, mesh_obj_name) in self.PARTICLE_TYPES.items():
            if bpy.data.objects.get(source_obj_name) is None:
                continue

            mesh_obj = bpy.data.objects.get(mesh_obj_name)
            if not mesh_obj:
                # Carrier object only; the node group's output contains just the
                # instances, so the cube's own geometry is not shown.
                bpy.ops.object.select_all(action='DESELECT')
                bpy.ops.mesh.primitive_cube_add(size=1)
                mesh_obj = bpy.context.active_object
                mesh_obj.name = mesh_obj_name

            node_group = self.get_or_create_shared_node_group(particle_name, particle_radius, radius_std)
            self.setup_geometry_nodes(mesh_obj, node_group, particle_name)

        for obj in bpy.data.objects:
            if self._is_domain(obj):
                bpy.context.view_layer.objects.active = obj
                obj.select_set(True)

        self.report({'INFO'}, "Grain and inlet grain particles rendered as icospheres using geometry nodes.")
        return {'FINISHED'}


def register():
    """Register this module's operator class with Blender."""
    for operator_cls in (DORIFLOW_OT_GrainGeometryNodesIcospheres,):
        bpy.utils.register_class(operator_cls)


def unregister():
    """Unregister this module's operator class, in reverse registration order."""
    for operator_cls in reversed((DORIFLOW_OT_GrainGeometryNodesIcospheres,)):
        bpy.utils.unregister_class(operator_cls)


if __name__ == "__main__":
    # Register immediately when executed as a script rather than imported.
    register()
