try:
    import wandb

    WANDB_AVAILABLE = True
except Exception:
    WANDB_AVAILABLE = False


if WANDB_AVAILABLE:
    import datetime
    import io
    import json
    import re
    from pathlib import Path

    from openai import File, FineTune, FineTuningJob
    from openai.datalib.numpy_helper import numpy as np
    from openai.datalib.pandas_helper import assert_has_pandas, pandas as pd


class WandbLogger:
    """
    Log fine-tunes to [Weights & Biases](https://wandb.me/openai-docs)
    """

    _wandb_api = None
    _logged_in = False

    if not WANDB_AVAILABLE:
        print("Logging requires wandb to be installed. Run `pip install wandb`.")

    @classmethod
    def sync(
        cls,
        id=None,
        n_fine_tunes=None,
        project="OpenAI-Fine-Tune",
        entity=None,
        force=False,
        legacy=False,
        **kwargs_wandb_init,
    ):
        """
        Sync fine-tunes to Weights & Biases.
        :param id: The id of the fine-tune (optional)
        :param n_fine_tunes: Number of most recent fine-tunes to log when an id is not provided. By default, every fine-tune is synced.
        :param project: Name of the project where you're sending runs. By default, it is "GPT-3".
        :param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.
        :param force: Forces logging and overwrite existing wandb run of the same fine-tune.
        """
        
        if not WANDB_AVAILABLE:
            return

        assert_has_pandas()

        if id:
            print("Retrieving fine-tune job...")
            if legacy:
                fine_tune = FineTune.retrieve(id=id)
            else:
                fine_tune = FineTuningJob.retrieve(id=id)
            fine_tune.pop("events", None)
            fine_tunes = [fine_tune]
        else:
            # get the list of fine-tunes to log
            if legacy:
                response = FineTune.list()
                fine_tunes = response.get("data") if response else None
            else:
                fine_tunes = list(FineTuningJob.auto_paging_iter())
            if not fine_tunes:
                print("No fine-tune has been retrieved")
                return
            # keep only the most recent n_fine_tunes when requested
            fine_tunes = fine_tunes[
                -n_fine_tunes if n_fine_tunes is not None else None :
            ]

        # show per-fine-tune warnings only when a specific id or count was requested
        show_individual_warnings = id is not None or n_fine_tunes is not None

        # log starting from the oldest fine-tune
        fine_tune_logged = [
            cls._log_fine_tune(
                fine_tune,
                project,
                entity,
                force,
                legacy,
                show_individual_warnings,
                **kwargs_wandb_init,
            )
            for fine_tune in fine_tunes
        ]

        if not show_individual_warnings and not any(fine_tune_logged):
            print("No new successful fine-tunes were found")

        return "🎉 wandb sync completed successfully"

    @classmethod
    def _log_fine_tune(
        cls,
        fine_tune,
        project,
        entity,
        force,
        legacy,
        show_individual_warnings,
        **kwargs_wandb_init,
    ):
        fine_tune_id = fine_tune.get("id")
        status = fine_tune.get("status")

        # check run completed successfully
        if status != "succeeded":
            if show_individual_warnings:
                print(
                    f'Fine-tune {fine_tune_id} has the status "{status}" and will not be logged'
                )
            return

        # check results are present
        try:
            if legacy:
                results_id = fine_tune["result_files"][0]["id"]
            else:
                results_id = fine_tune["result_files"][0]
            results = File.download(id=results_id).decode("utf-8")
        except Exception:
            if show_individual_warnings:
                print(f"Fine-tune {fine_tune_id} has no results and will not be logged")
            return

        # check run has not been logged already
        run_path = f"{project}/{fine_tune_id}"
        if entity is not None:
            run_path = f"{entity}/{run_path}"
        wandb_run = cls._get_wandb_run(run_path)
        if wandb_run:
            wandb_status = wandb_run.summary.get("status")
            if show_individual_warnings:
                if wandb_status == "succeeded":
                    print(
                        f"Fine-tune {fine_tune_id} has already been logged successfully at {wandb_run.url}"
                    )
                    if not force:
                        print(
                            'Use "--force" in the CLI or "force=True" in Python if you want to overwrite the previous run'
                        )
                else:
                    print(
                        f"A run for fine-tune {fine_tune_id} was previously created but didn't end successfully"
                    )
                if wandb_status != "succeeded" or force:
                    print(
                        f"A new wandb run will be created for fine-tune {fine_tune_id} and the previous run will be overwritten"
                    )
            if wandb_status == "succeeded" and not force:
                return

        # start a wandb run
        wandb.init(
            job_type="fine-tune",
            config=cls._get_config(fine_tune),
            project=project,
            entity=entity,
            name=fine_tune_id,
            id=fine_tune_id,
            **kwargs_wandb_init,
        )

        # log results
        df_results = pd.read_csv(io.StringIO(results))
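        # each row of the results CSV is one training step: drop NaN columns, then log the
        # remaining metrics against the "step" column so wandb charts follow training progress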
        for _, row in df_results.iterrows():
            metrics = {k: v for k, v in row.items() if not np.isnan(v)}
            step = metrics.pop("step")
            if step is not None:
                step = int(step)
            wandb.log(metrics, step=step)
        fine_tuned_model = fine_tune.get("fine_tuned_model")
        if fine_tuned_model is not None:
            wandb.summary["fine_tuned_model"] = fine_tuned_model

        # training/validation files and fine-tune details
        cls._log_artifacts(fine_tune, project, entity)

        # mark run as complete
        wandb.summary["status"] = "succeeded"

        wandb.finish()
        return True

    @classmethod
    def _ensure_logged_in(cls):
        if not cls._logged_in:
            if wandb.login():
                cls._logged_in = True
            else:
                raise Exception("You need to log in to wandb")

    @classmethod
    def _get_wandb_run(cls, run_path):
        cls._ensure_logged_in()
        try:
            if cls._wandb_api is None:
                cls._wandb_api = wandb.Api()
            return cls._wandb_api.run(run_path)
        except Exception:
            return None

    @classmethod
    def _get_wandb_artifact(cls, artifact_path):
        cls._ensure_logged_in()
        try:
            if cls._wandb_api is None:
                cls._wandb_api = wandb.Api()
            return cls._wandb_api.artifact(artifact_path)
        except Exception:
            return None

    @classmethod
    def _get_config(cls, fine_tune):
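        # flatten the job payload into a wandb-friendly config: keep only the first entry of
        # each file list and convert the creation timestamp to a datetime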
        config = dict(fine_tune)
        for key in ("training_files", "validation_files", "result_files"):
            if config.get(key) and len(config[key]):
                config[key] = config[key][0]
        if config.get("created_at"):
            config["created_at"] = datetime.datetime.fromtimestamp(config["created_at"])
        return config

    @classmethod
    def _log_artifacts(cls, fine_tune, project, entity):
        # training/validation files
        training_file = (
            fine_tune["training_files"][0]
            if fine_tune.get("training_files") and len(fine_tune["training_files"])
            else None
        )
        validation_file = (
            fine_tune["validation_files"][0]
            if fine_tune.get("validation_files") and len(fine_tune["validation_files"])
            else None
        )
        for file, prefix, artifact_type in (
            (training_file, "train", "training_files"),
            (validation_file, "valid", "validation_files"),
        ):
            if file is not None:
                cls._log_artifact_inputs(file, prefix, artifact_type, project, entity)

        # fine-tune details
        fine_tune_id = fine_tune.get("id")
        artifact = wandb.Artifact(
            "fine_tune_details",
            type="fine_tune_details",
            metadata=fine_tune,
        )
        with artifact.new_file(
            "fine_tune_details.json", mode="w", encoding="utf-8"
        ) as f:
            json.dump(fine_tune, f, indent=2)
        wandb.run.log_artifact(
            artifact,
            aliases=["latest", fine_tune_id],
        )

    @classmethod
    def _log_artifact_inputs(cls, file, prefix, artifact_type, project, entity):
        file_id = file["id"]
        filename = Path(file["filename"]).name
        stem = Path(file["filename"]).stem

        # get input artifact
        artifact_name = f"{prefix}-{filename}"
        # sanitize name to valid wandb artifact name
        artifact_name = re.sub(r"[^a-zA-Z0-9_\-.]", "_", artifact_name)
        artifact_alias = file_id
        artifact_path = f"{project}/{artifact_name}:{artifact_alias}"
        if entity is not None:
            artifact_path = f"{entity}/{artifact_path}"
        artifact = cls._get_wandb_artifact(artifact_path)

        # create artifact if file not already logged previously
        if artifact is None:
            # get file content
            try:
                file_content = File.download(id=file_id).decode("utf-8")
            except Exception:
                print(
                    f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
                )
                return
            artifact = wandb.Artifact(artifact_name, type=artifact_type, metadata=file)
            with artifact.new_file(filename, mode="w", encoding="utf-8") as f:
                f.write(file_content)

            # create a Table
            try:
                table, n_items = cls._make_table(file_content)
                artifact.add(table, stem)
                wandb.config.update({f"n_{prefix}": n_items})
                artifact.metadata["items"] = n_items
            except Exception:
                print(f"File {file_id} could not be read as a valid JSONL file")
        else:
            # log number of items
            wandb.config.update({f"n_{prefix}": artifact.metadata.get("items")})

        wandb.run.use_artifact(artifact, aliases=["latest", artifact_alias])

    @classmethod
    def _make_table(cls, file_content):
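        # read the JSONL training/validation file into a DataFrame and wrap it in a wandb Table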
        df = pd.read_json(io.StringIO(file_content), orient="records", lines=True)
        return wandb.Table(dataframe=df), len(df)
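

# Example: sync the three most recent fine-tunes to a custom project (a sketch; the
# "openai.wandb_logger" import path and the project/entity names are illustrative):
#
#     from openai.wandb_logger import WandbLogger
#
#     WandbLogger.sync(n_fine_tunes=3, project="my-fine-tunes", entity="my-team")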
