from manimlib import *


# Shared colour palette for the FedHera scenes in this file.
BG = "#102124"  # scene background colour (see FedHeraScene.setup)
INK = "#172127"
PAPER = "#f6f7f4"  # default text colour (txt) and panel fill tint (panel)
TEAL = "#35b69d"  # active basis bars, trainable prefix, selected ranks
TEAL_DARK = "#176557"
AMBER = "#e5a542"  # tail basis bars, rank-budget dots, FedHera highlights
CORAL = "#d95c54"  # coupled-rank baseline arrow and its gap
VIOLET = "#8a75b8"
MUTED = "#9fb0ac"  # secondary text: subtitles, layer names, captions


def txt(label, size=30, color=PAPER, weight=BOLD):
    """Create an Arial ``Text`` mobject in the given colour.

    Args:
        label: string to render.
        size: font size forwarded to ``Text``.
        color: colour applied to the text after construction.
        weight: font weight forwarded to ``Text`` (``BOLD`` or ``NORMAL`` here).

    Returns:
        The coloured ``Text`` mobject.
    """
    mob = Text(label, font_size=size, font="Arial", weight=weight)
    mob.set_color(color)
    return mob


def panel(width, height, stroke=TEAL):
    """Create a rounded card used as a container in the scenes.

    The card has a faint PAPER fill (opacity 0.08), corner radius 0.16 and a
    stroke of width 2 in the ``stroke`` colour.
    """
    return (
        RoundedRectangle(width=width, height=height, corner_radius=0.16)
        .set_fill(PAPER, opacity=0.08)
        .set_stroke(stroke, width=2)
    )


def basis(columns=10, active=4, tail=3, scale=1.0):
    """Build a bottom-aligned row of ``columns`` rounded bars.

    The first ``active`` bars are TEAL, the next ``tail`` bars are AMBER and
    any remaining bars are a dark slate grey.  Bar heights follow a gentle
    sine wobble so the strip looks like a spectrum, and every dimension is
    multiplied by ``scale``.

    Returns:
        A ``VGroup`` of the bars, arranged left to right.
    """

    def shade(index):
        # Guard-clause colour lookup: active prefix, then tail, then the rest.
        if index < active:
            return TEAL
        if index < active + tail:
            return AMBER
        return "#536462"

    def make_bar(index):
        tint = shade(index)
        bar = RoundedRectangle(
            width=0.13 * scale,
            height=(0.7 + 0.18 * np.sin(index * 0.75)) * scale,
            corner_radius=0.03,
        )
        bar.set_fill(tint, opacity=0.95)
        bar.set_stroke(tint, width=0)
        return bar

    strip = VGroup(*(make_bar(index) for index in range(columns)))
    strip.arrange(RIGHT, buff=0.055 * scale, aligned_edge=DOWN)
    return strip


class FedHeraScene(Scene):
    """Base scene for the FedHera animations: BG background plus a title helper."""

    def setup(self):
        """Apply the ``BG`` background colour before ``construct`` runs."""
        super().setup()
        self.camera.background_color = BG
        # ManimGL's Camera turns ``background_color`` into ``background_rgba``
        # inside its constructor and clears every frame with ``background_rgba``.
        # Reassigning ``background_color`` afterwards therefore has no visible
        # effect, so the cached RGBA value is refreshed too.  The ``hasattr``
        # guard keeps this harmless on camera implementations without that field.
        if hasattr(self.camera, "background_rgba"):
            self.camera.background_rgba = list(color_to_rgba(BG))

    def title(self, text, subtitle=None):
        """Add a heading (and optional subtitle) pinned to the top edge.

        Args:
            text: heading string, drawn at size 34 in PAPER, 0.28 below the top.
            subtitle: optional second line, drawn at size 20 in MUTED just
                below the heading.  Falsy values (None, "") are skipped.

        Returns:
            A ``VGroup`` of the mobjects added to the scene, so callers can
            animate or remove the header later.  Existing callers that ignore
            the return value are unaffected.
        """
        heading = txt(text, size=34, color=PAPER)
        heading.to_edge(UP, buff=0.28)
        self.add(heading)
        added = VGroup(heading)
        if subtitle:
            sub = txt(subtitle, size=20, color=MUTED, weight=NORMAL)
            sub.next_to(heading, DOWN, buff=0.12)
            self.add(sub)
            added.add(sub)
        return added


class RankDecouplingScene(FedHeraScene):
    """Server basis download with a trainable prefix and a frozen tail on the client.

    A "Server SVD" panel sends a bar strip ("download r_tot") to a "Client"
    panel.  On the client, the leading bars are boxed as the trainable prefix,
    the following bars as the frozen tail, and an optimizer-memory bar is
    shown beneath the panel.
    """

    def construct(self):
        self.title("FedHera decouples rank reception from rank training")

        # Client rank split, defined once.  ``basis`` uses it to colour the bars
        # and the highlight boxes use it to pick which bars to wrap, so the
        # boxes always cover exactly the teal (trainable) and amber (frozen) bars.
        client_active = 3
        client_tail = 4

        server = panel(2.15, 1.35, TEAL).shift(LEFT * 4.05 + UP * 0.2)
        server_label = txt("Server SVD", size=24).move_to(server).shift(UP * 0.25)
        server_basis = basis(columns=10, active=5, tail=3, scale=0.82)
        server_basis.next_to(server_label, DOWN, buff=0.18)

        client = panel(2.55, 1.5, PAPER).shift(RIGHT * 3.75 + UP * 0.12)
        client_label = txt("Client", size=24).move_to(client).shift(UP * 0.42)
        client_basis = basis(columns=9, active=client_active, tail=client_tail, scale=0.86)
        client_basis.next_to(client_label, DOWN, buff=0.2)

        down_arrow = Arrow(server.get_right(), client.get_left(), buff=0.18, color=TEAL)
        down_text = txt("download r_tot", size=21, color="#d7f7ee")
        down_text.next_to(down_arrow, UP, buff=0.24)

        tail_end = client_active + client_tail
        train_box = SurroundingRectangle(client_basis[:client_active], color=TEAL, buff=0.08)
        tail_box = SurroundingRectangle(client_basis[client_active:tail_end], color=AMBER, buff=0.08)
        train_text = txt("trainable prefix", size=18, color=TEAL)
        tail_text = txt("frozen tail", size=18, color=AMBER)
        train_text.next_to(train_box, DOWN, buff=0.25).shift(LEFT * 0.28)
        tail_text.next_to(tail_box, DOWN, buff=0.25).shift(RIGHT * 0.35)

        # Caption, a grey track of width 2.35 and a teal bar of width 0.88.
        # NOTE(review): arrange(DOWN) stacks the teal bar *below* the grey one.
        # align_to only fixes its x position, so this reads as two separate bars
        # rather than a filled gauge.  Confirm that is the intended comparison.
        memory = VGroup(
            txt("optimizer memory", size=20, color=MUTED, weight=NORMAL),
            Rectangle(width=2.35, height=0.18).set_fill("#536462", 0.7).set_stroke(width=0),
            Rectangle(width=0.88, height=0.18).set_fill(TEAL, 1).set_stroke(width=0),
        )
        memory.arrange(DOWN, buff=0.1).next_to(client, DOWN, buff=0.85)
        memory[2].align_to(memory[1], LEFT)

        self.play(FadeIn(server), FadeIn(server_label), LaggedStartMap(FadeIn, server_basis, lag_ratio=0.05))
        self.play(GrowArrow(down_arrow), Write(down_text))
        self.play(FadeIn(client), FadeIn(client_label), LaggedStartMap(FadeIn, client_basis, lag_ratio=0.05))
        self.play(ShowCreation(train_box), Write(train_text))
        self.play(ShowCreation(tail_box), Write(tail_text), FadeIn(memory))
        self.wait(1.0)


class WaterFillingScene(FedHeraScene):
    """Assign a fixed rank budget to the largest per-layer values.

    Each layer's values are drawn as grey bars.  A row of ``rank_budget``
    amber dots then flies, one at a time, onto the globally largest values
    (taken in descending order), and each targeted bar turns teal.
    """

    def construct(self):
        self.title("Layer-wise spectrum-preserving water filling")

        # Number of ranks to hand out.  It drives both how many budget dots are
        # drawn and how many bars get selected, so the two cannot drift apart.
        rank_budget = 9

        values = [
            [0.92, 0.64, 0.48, 0.24],
            [0.78, 0.73, 0.52, 0.35],
            [0.88, 0.55, 0.30, 0.18],
            [0.66, 0.59, 0.45, 0.21],
        ]
        labels = ["Layer 1", "Layer 2", "Layer 3", "Layer 4"]
        groups = VGroup()
        # (layer index, rank index) -> bar mobject, used to aim the budget dots.
        bars_by_key = {}
        for li, layer_values in enumerate(values):
            layer = VGroup()
            for ri, value in enumerate(layer_values):
                bar = RoundedRectangle(width=0.24, height=1.95 * value, corner_radius=0.04)
                bar.set_fill("#536462", opacity=0.75)
                bar.set_stroke(width=0)
                layer.add(bar)
                bars_by_key[(li, ri)] = bar
            layer.arrange(RIGHT, buff=0.08, aligned_edge=DOWN)
            name = txt(labels[li], size=18, color=MUTED, weight=NORMAL)
            name.next_to(layer, DOWN, buff=0.18)
            groups.add(VGroup(layer, name))
        groups.arrange(RIGHT, buff=0.7, aligned_edge=DOWN).shift(DOWN * 0.75)

        # Greedy global selection: rank every (value, layer, rank) entry in
        # descending order of value and keep the first ``rank_budget`` of them.
        all_entries = sorted(
            (
                (value, li, ri)
                for li, layer_values in enumerate(values)
                for ri, value in enumerate(layer_values)
            ),
            reverse=True,
        )
        chosen = [(li, ri) for _, li, ri in all_entries[:rank_budget]]

        budget = txt("rank budget", size=20, color="#d7f7ee")
        budget_dots = VGroup(*[
            Dot(radius=0.055, color=AMBER)
            for _ in range(rank_budget)
        ])
        budget_dots.arrange(RIGHT, buff=0.095)
        budget_group = VGroup(budget, budget_dots).arrange(RIGHT, buff=0.24)
        budget_group.move_to(UP * 2.1)

        # NOTE(review): the water line sits at a fixed height (UP * 0.32) rather
        # than being derived from the bar heights.  Confirm it lands where
        # intended relative to the chosen and unchosen bars.
        water_line = DashedLine(LEFT * 3.9, RIGHT * 3.9, dash_length=0.12, color=AMBER)
        water_line.move_to(UP * 0.32)
        water_label = txt("highest energy per cost", size=18, color=AMBER)
        water_label.next_to(water_line, UP, buff=0.14)

        self.play(FadeIn(groups), Write(budget), LaggedStartMap(FadeIn, budget_dots, lag_ratio=0.05))
        self.play(ShowCreation(water_line), Write(water_label))
        for dot, key in zip(budget_dots, chosen):
            target = bars_by_key[key]
            self.play(
                dot.animate.move_to(target.get_top() + UP * 0.12),
                target.animate.set_fill(TEAL, opacity=0.95),
                run_time=0.28,
            )
        self.wait(1.0)


class DriftAnchorScene(FedHeraScene):
    """Compare update directions: oracle, coupled ranks and FedHera anchored.

    Three arrows share an origin.  Dashed lines join the coupled and FedHera
    arrow tips to the oracle tip, and a pill-shaped "Adaptive Tail Warm-up"
    gauge fills up while the FedHera arrow grows.
    """

    def construct(self):
        self.title("Frozen-tail anchors reduce projection drift")

        origin = LEFT * 4.35 + DOWN * 2.05
        x_axis = Arrow(origin, origin + RIGHT * 7.9, buff=0, color="#536462")
        y_axis = Arrow(origin, origin + UP * 4.05, buff=0, color="#536462")

        oracle = Arrow(origin, origin + RIGHT * 6.4 + UP * 3.25, buff=0, color=TEAL)
        coupled = Arrow(origin, origin + RIGHT * 6.35 + UP * 1.15, buff=0, color=CORAL)
        fedhera = Arrow(origin, origin + RIGHT * 6.05 + UP * 2.78, buff=0, color=AMBER)
        oracle_label = txt("high-rank oracle", size=18, color=TEAL).move_to(RIGHT * 2.6 + UP * 2.62)
        coupled_label = txt("coupled ranks", size=18, color=CORAL).move_to(RIGHT * 2.4 + DOWN * 0.72)
        fedhera_label = txt("FedHera anchored", size=18, color=AMBER).move_to(RIGHT * 2.8 + UP * 1.9)

        # Dashed segments from each method's tip to the oracle tip.
        gap = DashedLine(fedhera.get_end(), oracle.get_end(), color=AMBER, dash_length=0.12)
        large_gap = DashedLine(coupled.get_end(), oracle.get_end(), color=CORAL, dash_length=0.12)
        gap_note = txt("smaller drift", size=18, color=AMBER).move_to(RIGHT * 1.45 + UP * 1.35)

        gate_label = txt("Adaptive Tail Warm-up", size=24, color="#d7f7ee")
        gate_label.to_edge(UP, buff=1.2)
        gate_shell = RoundedRectangle(width=4.4, height=0.28, corner_radius=0.14)
        gate_shell.set_fill("#536462", opacity=0.5).set_stroke(width=0)
        gate_shell.next_to(gate_label, DOWN, buff=0.2)

        def gate_bar(width):
            # Teal pill with the same height and corner radius as the shell, so
            # its rounded ends match the shell's ends at any width.
            bar = RoundedRectangle(width=width, height=0.28, corner_radius=0.14)
            bar.set_fill(TEAL, opacity=0.95).set_stroke(width=0)
            return bar

        # aligned_edge=LEFT puts the bar's left edge on the shell's left edge
        # and centres it vertically on the shell.
        gate_fill = gate_bar(0.8).move_to(gate_shell, aligned_edge=LEFT)
        # Grown state, built as a new pill rather than stretching the short
        # one.  stretch_to_fit_width would also stretch the rounded end caps
        # into elongated ellipses that no longer match the shell.
        gate_fill_full = gate_bar(3.65).move_to(gate_shell, aligned_edge=LEFT)
        lambda_text = txt("lambda grows as alignment improves", size=20, color=MUTED, weight=NORMAL)
        lambda_text.next_to(gate_shell, DOWN, buff=0.2)

        self.play(GrowArrow(x_axis), GrowArrow(y_axis))
        self.play(GrowArrow(oracle), Write(oracle_label))
        self.play(GrowArrow(coupled), Write(coupled_label), ShowCreation(large_gap))
        self.play(FadeIn(gate_label), FadeIn(gate_shell), FadeIn(gate_fill), Write(lambda_text))
        self.play(
            Transform(gate_fill, gate_fill_full),
            GrowArrow(fedhera),
            Write(fedhera_label),
            ShowCreation(gap),
            FadeIn(gap_note),
            run_time=1.6,
        )
        self.wait(1.0)
