# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Unit tests for the base DBTask classes."""

from pathlib import Path
from typing import cast, override
from unittest import mock

import debusine.tasks.tags as ttags
import debusine.worker.tags as wtags
from debusine.artifacts.models import TaskTypes, WorkRequestResults
from debusine.db.models.tasks import DBTask, DBWorkerProxyTask
from debusine.db.playground import scenarios
from debusine.tasks import BaseTask
from debusine.tasks.models import BaseDynamicTaskData, BaseTaskData
from debusine.tasks.server import TaskDatabaseInterface
from debusine.tasks.tests.helper_mixin import SampleBaseExternalTask
from debusine.test.django import TestCase, preserve_db_task_registry
from debusine.test.test_utils import preserve_task_registry


class DBTaskTests(TestCase):
    """Tests for the registry helpers and tag logic of :py:class:`DBTask`."""

    scenario = scenarios.DefaultContext()

    @preserve_db_task_registry()
    def make_concrete_dbtask(self) -> DBTask[BaseTaskData, BaseDynamicTaskData]:
        """
        Return a task built from a throwaway concrete DBTask subclass.

        The subclass is declared inside this helper, which runs under
        ``preserve_db_task_registry``. The task itself comes from a playground
        worker work request created with task name ``concretedbtask``.
        """

        class ConcreteDBTask(DBTask[BaseTaskData, BaseDynamicTaskData]):
            """Minimal DBTask subclass used only by these tests."""

            TASK_TYPE = TaskTypes.WORKER
            TASK_VERSION = 1

            @override
            def build_dynamic_data(
                self, task_database: TaskDatabaseInterface
            ) -> BaseDynamicTaskData:
                return BaseDynamicTaskData()

            @override
            def _execute(self) -> WorkRequestResults:
                raise NotImplementedError()

        return self.playground.create_worker_task(
            task_name="concretedbtask"
        ).get_task()

    def test_is_valid_task_name(self) -> None:
        # "sbuild" is expected to be a known worker task name.
        valid = DBTask.is_valid_task_name(TaskTypes.WORKER, "sbuild")
        self.assertTrue(valid)

    def test_is_valid_task_name_case_insensitive(self) -> None:
        # A differently-cased spelling of the name must be accepted as well.
        valid = DBTask.is_valid_task_name(TaskTypes.WORKER, "Sbuild")
        self.assertTrue(valid)

    def test_is_valid_task_name_invalid_type(self) -> None:
        # The cast only satisfies the type checker: the value is deliberately
        # not a real task type.
        bogus_type = cast(TaskTypes, "does-not-exist")
        self.assertFalse(DBTask.is_valid_task_name(bogus_type, "sbuild"))

    def test_is_valid_task_name_invalid(self) -> None:
        # Valid task type, unknown task name.
        valid = DBTask.is_valid_task_name(TaskTypes.WORKER, "not_a_real_task")
        self.assertFalse(valid)

    def test_task_names_worker(self) -> None:
        # The worker listing must be non-empty and include "noop".
        names = DBTask.task_names(TaskTypes.WORKER)
        self.assertGreater(len(names), 0)
        self.assertIn("noop", names)

    def test_worker_task_names(self) -> None:
        # Same expectations as above, through the worker-specific accessor.
        names = DBTask.worker_task_names()
        self.assertGreater(len(names), 0)
        self.assertIn("noop", names)

    def test_is_worker_task(self) -> None:
        is_worker = DBTask.is_worker_task("noop")
        self.assertTrue(is_worker)

    def test_is_worker_task_false(self) -> None:
        is_worker = DBTask.is_worker_task("nonexistent")
        self.assertFalse(is_worker)

    def test_class_for_name_invalid_type(self) -> None:
        # An unknown task type must be reported with a ValueError.
        bogus_type = cast(TaskTypes, "not-existing-type")
        expected_error = "'not-existing-type' is not a registered task type"
        with self.assertRaisesRegex(ValueError, expected_error):
            DBTask.class_from_name(bogus_type, "non-existing-class")

    def test_class_for_name_sbuild(self) -> None:
        # The mixed-case lookup must still resolve to the "sbuild" class.
        task_class = DBTask.class_from_name(TaskTypes.WORKER, "sBuIlD")
        self.assertEqual(task_class.name, "sbuild")

    def test_class_for_name_no_class(self) -> None:
        # A known task type with an unknown name must raise ValueError too.
        expected_error = (
            "'non-existing-class' is not a registered Worker task_name"
        )
        with self.assertRaisesRegex(ValueError, expected_error):
            DBTask.class_from_name(TaskTypes.WORKER, "non-existing-class")

    def test_compute_system_required_tags(self) -> None:
        task = self.make_concrete_dbtask()
        # Tag derived from the task's type, name and version.
        task_tag = (
            wtags.TASK_PREFIX
            + f"{task.TASK_TYPE.lower()}:concretedbtask:version:1"
        )
        expected = {wtags.WORKER_TYPE_EXTERNAL, task_tag}
        self.assertCountEqual(task.compute_system_required_tags(), expected)

    def test_compute_system_provided_tags(self) -> None:
        user = self.scenario.user
        ws = self.scenario.workspace
        # Two groups named "test": one created without a workspace, one
        # created for the scenario workspace.
        self.playground.create_group("test", [user])
        self.playground.create_group("test", [user], workspace=ws)
        task = self.make_concrete_dbtask()
        scope_name = ws.scope.name
        expected = [
            ttags.SCOPE_PREFIX + scope_name,
            ttags.WORKSPACE_PREFIX + f"{scope_name}:{ws.name}",
            ttags.GROUP_PREFIX + f"{scope_name}::test",
            ttags.GROUP_PREFIX + f"{scope_name}:{ws.name}:test",
        ]
        self.assertCountEqual(task.compute_system_provided_tags(), expected)

    def test_compute_user_provided_tags_no_source_package_name(self) -> None:
        # Without a source package name no user-provided tag is expected.
        task = self.make_concrete_dbtask()
        dynamic_data = mock.Mock()
        dynamic_data.get_source_package_name = mock.Mock(return_value=None)
        computed = task.compute_user_provided_tags(dynamic_data)
        self.assertCountEqual(computed, [])

    def test_compute_user_provided_tags_with_source_package_name(self) -> None:
        # A source package name must yield exactly one prefixed tag.
        task = self.make_concrete_dbtask()
        dynamic_data = mock.Mock()
        dynamic_data.get_source_package_name = mock.Mock(return_value="foo")
        computed = task.compute_user_provided_tags(dynamic_data)
        self.assertCountEqual(computed, [ttags.SOURCE_PACKAGE_PREFIX + "foo"])

    def test_compute_tagsets(self) -> None:
        task = self.make_concrete_dbtask()
        # Each case is: provided tags by origin, required tags by origin,
        # expected scheduler_tags_provided, expected scheduler_tags_required.
        cases: list[
            tuple[
                dict[str, set[str]],
                dict[str, set[str]],
                list[str],
                list[str],
            ]
        ] = [
            # Nothing computed, nothing stored.
            ({}, {}, [], []),
            # System tags are stored as they are.
            (
                {
                    "system": {
                        ttags.GROUP_PREFIX + "group",
                        ttags.SCOPE_PREFIX + "scope",
                        ttags.WORKSPACE_PREFIX + "scope:workspace",
                    },
                },
                {"system": {wtags.CAP_PREFIX + "test"}},
                [
                    ttags.GROUP_PREFIX + "group",
                    ttags.SCOPE_PREFIX + "scope",
                    ttags.WORKSPACE_PREFIX + "scope:workspace",
                ],
                [wtags.CAP_PREFIX + "test"],
            ),
            # User tags with group/scope/workspace prefixes must not end up
            # in the stored provided tags.
            (
                {
                    "system": {ttags.SOURCE_PACKAGE_PREFIX + "test"},
                    "user": {
                        ttags.GROUP_PREFIX + "group",
                        ttags.SCOPE_PREFIX + "scope",
                        ttags.WORKSPACE_PREFIX + "scope:workspace",
                    },
                },
                {},
                [ttags.SOURCE_PACKAGE_PREFIX + "test"],
                [],
            ),
            # System and user tags are merged; results are expected sorted
            # and without duplicates.
            (
                {"system": {"a", "c", "b"}, "user": {"d", "c", "e"}},
                {"system": {"f", "e", "d"}},
                ["a", "b", "c", "d", "e"],
                ["d", "e", "f"],
            ),
        ]
        for provided, required, want_provided, want_required in cases:
            with self.subTest(
                system_provided=str(provided),
                system_required=str(required),
            ):
                # Replace the three tag sources on this task instance so that
                # only the combination logic is under test.
                with (
                    mock.patch.object(
                        task,
                        "compute_system_provided_tags",
                        return_value=provided.get("system", set()),
                    ),
                    mock.patch.object(
                        task,
                        "compute_user_provided_tags",
                        return_value=provided.get("user", set()),
                    ),
                    mock.patch.object(
                        task,
                        "compute_system_required_tags",
                        return_value=required.get("system", set()),
                    ),
                ):
                    task.apply_task_configuration()
                    work_request = task.work_request
                    self.assertEqual(
                        work_request.scheduler_tags_provided, want_provided
                    )
                    self.assertEqual(
                        work_request.scheduler_tags_required, want_required
                    )


class DBProxyTaskTests(TestCase):
    """Tests for DB proxy tasks such as :py:class:`DBWorkerProxyTask`."""

    @preserve_task_registry()
    @preserve_db_task_registry()
    def test_populate_task_registry(self) -> None:
        name = "samplebasetask"
        kind = TaskTypes.WORKER

        # Neither registry knows the name before the class below is defined.
        # The registries are looked up afresh for every assertion.
        self.assertNotIn(name, BaseTask._sub_tasks[kind])
        self.assertNotIn(name, DBTask._sub_tasks[kind])

        class SampleBaseTask(
            SampleBaseExternalTask[BaseTaskData, BaseDynamicTaskData],
        ):
            """Sample external task whose definition triggers registration."""

            @override
            def run(self, execute_directory: Path) -> WorkRequestResults:
                """Satisfy the abstract interface; not called by this test."""
                raise NotImplementedError()

        # Defining the class registers it with BaseTask, but not with DBTask.
        self.assertIn(name, BaseTask._sub_tasks[kind])
        self.assertNotIn(name, DBTask._sub_tasks[kind])

        DBWorkerProxyTask.populate_task_registry()

        # After populating, the name is present in both registries.
        self.assertIn(name, BaseTask._sub_tasks[kind])
        self.assertIn(name, DBTask._sub_tasks[kind])

    def test_compute_system_required_tags(self) -> None:
        task = self.playground.create_worker_task(task_name="noop").get_task()
        # Give each implementation a distinct result; the task is expected to
        # report the union of both.
        db_patch = mock.patch(
            "debusine.db.models.tasks.DBTask.compute_system_required_tags",
            return_value={"db"},
        )
        base_patch = mock.patch(
            "debusine.tasks.BaseTask.compute_system_required_tags",
            return_value={"base"},
        )
        with db_patch, base_patch:
            computed = task.compute_system_required_tags()
            self.assertCountEqual(computed, {"db", "base"})
