# pylint: disable=line-too-long,useless-suppression
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

"""
DESCRIPTION:
    Given an AIProjectClient, this sample demonstrates how to use the asynchronous
    `.fine_tuning.jobs` methods to create a supervised fine-tuning job using an open-source model.
    Supported open-source models with SFT: Ministral-3B

USAGE:
    python sample_finetuning_oss_models_supervised_job_async.py

    Before running the sample:

    pip install "azure-ai-projects>=2.0.0b4" python-dotenv aiohttp

    Set these environment variables with your own values:
    1) AZURE_AI_PROJECT_ENDPOINT - Required. The Azure AI Project endpoint, as found in the overview page of your
       Microsoft Foundry portal.
    2) MODEL_NAME - Optional. The base model name to use for fine-tuning. Defaults to the `Ministral-3B` model.
    3) TRAINING_FILE_PATH - Optional. Path to the training data file. Defaults to `sft_training_set.jsonl` in the `data` folder.
    4) VALIDATION_FILE_PATH - Optional. Path to the validation data file. Defaults to `sft_validation_set.jsonl` in the `data` folder.
"""

import os
import asyncio
from dotenv import load_dotenv
from azure.identity.aio import DefaultAzureCredential
from azure.ai.projects.aio import AIProjectClient
from fine_tuning_sample_helper import resolve_data_file_path

# Load variables from a .env file (if one is found) before the environment is read below.
load_dotenv()

# Required: a missing AZURE_AI_PROJECT_ENDPOINT raises KeyError as soon as the module is loaded.
endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"]

# Optional: falls back to the Ministral-3B base model when MODEL_NAME is not set.
model_name = os.getenv("MODEL_NAME", "Ministral-3B")

# Data files are located by the shared sample helper. The second argument names the
# environment variable and the third the default file name; the exact lookup rules
# live in fine_tuning_sample_helper (presumably env override first, then the sample's
# data folder — confirm against the helper).
training_file_path = resolve_data_file_path(
    __file__,
    "TRAINING_FILE_PATH",
    "sft_training_set.jsonl",
)
validation_file_path = resolve_data_file_path(
    __file__,
    "VALIDATION_FILE_PATH",
    "sft_validation_set.jsonl",
)


async def main():
    """Upload the training and validation data, then submit a supervised fine-tuning job.

    The credential, the project client and the OpenAI client obtained from it are
    all opened as async context managers, so they are closed when the block exits.
    Both data files are uploaded with purpose ``"fine-tune"``, each upload is awaited
    until processing completes, a supervised job is created for ``model_name``, and
    the returned job object is printed.
    """
    async with (
        DefaultAzureCredential() as credential,
        AIProjectClient(endpoint=endpoint, credential=credential) as project_client,
        project_client.get_openai_client() as openai_client,
    ):
        # --- Upload the two data files (opened in binary mode for the upload) ---
        print("Uploading training file...")
        with open(training_file_path, "rb") as training_stream:
            training_upload = await openai_client.files.create(file=training_stream, purpose="fine-tune")
        print(f"Uploaded training file with ID: {training_upload.id}")

        print("Uploading validation file...")
        with open(validation_file_path, "rb") as validation_stream:
            validation_upload = await openai_client.files.create(file=validation_stream, purpose="fine-tune")
        print(f"Uploaded validation file with ID: {validation_upload.id}")

        # --- Block until both uploads are processed, training file first ---
        print("Waits for the training and validation files to be processed...")
        for uploaded in (training_upload, validation_upload):
            await openai_client.files.wait_for_processing(uploaded.id)

        # --- Describe the supervised fine-tuning method ---
        hyperparameters = {"n_epochs": 3, "batch_size": 1, "learning_rate_multiplier": 1.0}
        sft_method = {
            "type": "supervised",
            "supervised": {"hyperparameters": hyperparameters},
        }

        # Recommended approach to set trainingType. Omitting this field may lead to
        # unsupported behavior. The preferred trainingType is GlobalStandard.
        # Note: Global training offers cost savings, but copies data and weights
        # outside the current resource region.
        # Learn more - https://azure.microsoft.com/pricing/details/cognitive-services/openai-service/
        # and https://azure.microsoft.com/explore/global-infrastructure/data-residency/
        service_options = {"trainingType": "GlobalStandard"}

        print("Creating supervised fine-tuning job")
        fine_tuning_job = await openai_client.fine_tuning.jobs.create(
            training_file=training_upload.id,
            validation_file=validation_upload.id,
            model=model_name,
            method=sft_method,
            extra_body=service_options,
        )
        print(fine_tuning_job)


if __name__ == "__main__":
    # Only when executed as a script: drive the async entry point to completion.
    asyncio.run(main())
