# --
# Copyright (C) CEA, EDF
# Author : Erwan ADAM (CEA)
# --
from vtk import *

class vtkUnScaledActor(vtkFollower):
    """Follower actor meant to keep a constant size on screen.

    The constant-size behaviour is not implemented yet: Render() just
    delegates to vtkFollower.Render(), and SetSize() only records the
    requested size on the instance (it is not read anywhere here).
    """

    def Render(self, renderer):
        """Render the actor by delegating to vtkFollower.Render().

        TODO: compute a scale from the render window size
        (renderer.GetRenderWindow().GetSize()) so the actor keeps a
        constant size in the window.
        """
        # NOTE(review): this Python override runs when Python code calls
        # Render() explicitly; confirm whether VTK's C++ render loop
        # dispatches to it as well.
        vtkFollower.Render(self, renderer)

    def SetSize(self, size):
        """Remember the requested size in ``self.size``."""
        self.size = size

class Axis(vtkObject):
    """One axis of a triedron: a line from the origin, a cone and a text label.

    The three actors share a single vtkProperty, so they are drawn in the
    same colour. Only the line actor is currently added to (and removed
    from) a renderer; the cone and label actors are built but not shown.
    """

    def __init__(self, x, y, z, label, r, g, b, rx, ry, rz):
        """Build the line, cone and label actors of the axis.

        x, y, z    -- end point of the line at size 1 (the line starts at
                      the origin); setSize() scales this point
        label      -- text rendered by the label actor
        r, g, b    -- colour shared by the three actors
        rx, ry, rz -- orientation given to the cone actor
        """
        # Remember the direction here so setSize() also works on a plain
        # Axis instance, not only on subclasses that set it beforehand.
        self.x = x
        self.y = y
        self.z = z

        # First a line from the origin to (x, y, z)
        self.line_source = vtkLineSource()
        self.line_source.SetPoint1(0., 0., 0.)
        self.line_source.SetPoint2(x, y, z)
        axis_mapper = vtkPolyDataMapper()
        axis_mapper.SetInput(self.line_source.GetOutput())
        self.axis_actor = vtkActor()
        self.axis_actor.SetMapper(axis_mapper)

        # Then a cone, placed at the end of the line by setSize()
        cone_source = vtkConeSource()
        cone_source.SetResolution(2)
        cone_source.SetAngle(10)
        cone_mapper = vtkPolyDataMapper()
        cone_mapper.SetInput(cone_source.GetOutput())
        self.cone_actor = vtkUnScaledActor()
        self.cone_actor.SetMapper(cone_mapper)
        self.cone_actor.SetOrientation(rx, ry, rz)
        self.cone_actor.SetSize(1)

        # Then the text label, also placed at the end of the line by setSize()
        label_source = vtkVectorText()
        label_source.SetText(label)
        label_mapper = vtkPolyDataMapper()
        label_mapper.SetInput(label_source.GetOutput())
        self.label_actor = vtkUnScaledActor()
        self.label_actor.SetMapper(label_mapper)
        self.label_actor.SetScale(1)

        # One shared property gives the three actors the same colour.
        # Named 'prop' to avoid shadowing the builtin 'property'.
        prop = vtkProperty()
        prop.SetColor(r, g, b)
        for actor in (self.axis_actor, self.cone_actor, self.label_actor):
            actor.SetProperty(prop)

    def setSize(self, size):
        """Scale the axis so that it ends at size * (x, y, z).

        Moves the end point of the line and the positions of the cone and
        label actors to that scaled point.
        """
        end = (self.x * size, self.y * size, self.z * size)
        self.line_source.SetPoint2(*end)
        self.cone_actor.SetPosition(*end)
        self.label_actor.SetPosition(*end)

    def addToRender(self, render):
        """Add the axis to the renderer ``render``.

        Only the line actor is added; the cone and label actors are not
        added to the renderer at present.
        """
        render.AddActor(self.axis_actor)

    def removeFromRender(self, render):
        """Remove the line actor (the only one added) from ``render``."""
        render.RemoveActor(self.axis_actor)

    def setCamera(self, camera):
        """Give ``camera`` to the label follower actor."""
        self.label_actor.SetCamera(camera)
class XAxis(Axis):
    """Red axis along +X, labelled 'X', cone left unrotated."""

    def __init__(self):
        # Axis.setSize reads self.x / self.y / self.z.
        self.x, self.y, self.z = 1., 0., 0.
        Axis.__init__(self, self.x, self.y, self.z, 'X',
                      1., 0., 0.,    # colour (r, g, b)
                      0., 0., 0.)    # cone orientation (rx, ry, rz)
class YAxis(Axis):
    """Green axis along +Y, labelled 'Y', cone rotated 90 degrees about Z."""

    def __init__(self):
        # Axis.setSize reads self.x / self.y / self.z.
        self.x, self.y, self.z = 0., 1., 0.
        Axis.__init__(self, self.x, self.y, self.z, 'Y',
                      0., 1., 0.,    # colour (r, g, b)
                      0., 0., 90.)   # cone orientation (rx, ry, rz)
class ZAxis(Axis):
    """Blue axis along +Z, labelled 'Z', cone rotated -90 degrees about Y."""

    def __init__(self):
        # Axis.setSize reads self.x / self.y / self.z.
        self.x, self.y, self.z = 0., 0., 1.
        Axis.__init__(self, self.x, self.y, self.z, 'Z',
                      0., 0., 1.,    # colour (r, g, b)
                      0., -90., 0.)  # cone orientation (rx, ry, rz)
class Triedron:
    """Three coloured axes (X, Y, Z) that are shown or hidden together."""

    def __init__(self):
        self.x = XAxis()
        self.y = YAxis()
        self.z = ZAxis()
        # Line, cone and label actors of each axis, grouped axis by axis.
        self.actors = []
        for axis in (self.x, self.y, self.z):
            self.actors.extend([axis.axis_actor,
                                axis.cone_actor,
                                axis.label_actor])
        # Flagged visible from the start; the actors themselves are only
        # added to a renderer by addToRender() / show().
        self.visible = 1

    def _axes(self):
        """Return the three axes in X, Y, Z order."""
        return (self.x, self.y, self.z)

    def setSize(self, size, render):
        """Resize every axis to ``size``, then render each label with ``render``."""
        axes = self._axes()
        for axis in axes:
            axis.setSize(size)
        # All axes are resized before any label is rendered.
        for axis in axes:
            axis.label_actor.Render(render)

    def addToRender(self, render):
        """Add every axis to ``render`` and hook its labels to the active camera."""
        active_camera = render.GetActiveCamera()
        for axis in self._axes():
            axis.addToRender(render)
            axis.setCamera(active_camera)

    def removeFromRender(self, render):
        """Remove every axis from ``render``."""
        for axis in self._axes():
            axis.removeFromRender(render)

    def isVisible(self):
        """Return the visibility flag (1 when shown, 0 when hidden)."""
        return self.visible

    def show(self, render):
        """Add the axes to ``render`` and set the visibility flag to 1."""
        self.addToRender(render)
        self.visible = 1

    def hide(self, render):
        """Remove the axes from ``render`` and set the visibility flag to 0."""
        self.removeFromRender(render)
        self.visible = 0
