#!/usr/bin/env python

from __future__ import print_function

import getopt
import os
import subprocess
import sys

# Module-level state filled in by find_paths(); empty until it has run.
# find_paths() stores all three with forward slashes.

# Working directory at the time find_paths() ran; substituted into the
# generated configuration by setup().
training_basedir = ""
# Directory in which find_paths() found the "bw" / "bw.exe" binary.
sphinxbinpath = ""
# Directory that contains "scripts/"; setup() also reads its "etc/" templates.
sphinxpath = ""


def find_paths():
    """Locate the SphinxTrain binaries and scripts relative to this file.

    Sets the module globals ``sphinxbinpath`` (directory holding ``bw`` or
    ``bw.exe``), ``sphinxpath`` (directory holding ``scripts/``) and
    ``training_basedir`` (the current working directory), all converted to
    forward slashes, and prints the two SphinxTrain paths.

    An installed layout (``../libexec/sphinxtrain`` and
    ``../share/sphinxtrain``) is assumed unless a source-tree layout is
    detected next to this file.

    Raises:
        SystemExit: with status 1 when the binaries or the scripts cannot be
            found.
    """
    global training_basedir
    global sphinxbinpath
    global sphinxpath

    # Everything is resolved relative to the real location of this script.
    currentpath = os.path.dirname(os.path.realpath(__file__))

    # FIXME: the original author flags this whole lookup as wrong; the
    # heuristic is kept unchanged.  Only "libexec" is probed for an installed
    # layout (not "lib" or "lib64").
    sphinxbinpath = os.path.realpath(currentpath + "/../libexec/sphinxtrain")

    # Source-tree builds: (file whose presence is tested, directory to use).
    # Every entry is checked, so the last one that matches wins.
    build_candidates = (
        ("/../build/bw", "/../build"),
        ("/../build/Debug/bw.exe", "/../build/Debug"),
        ("/../build/Release/bw.exe", "/../build/Release"),
    )
    for marker, location in build_candidates:
        if os.path.exists(currentpath + marker):
            sphinxbinpath = os.path.realpath(currentpath + location)

    # Find the location for the scripts
    sphinxpath = os.path.realpath(currentpath + "/../share/sphinxtrain")
    if os.path.exists(currentpath + "/../scripts/00.verify"):
        # Running from the source tree: the scripts live next to this file.
        sphinxpath = os.path.realpath(currentpath + "/..")

    have_binaries = (os.path.exists(sphinxbinpath + "/bw")
                     or os.path.exists(sphinxbinpath + "/bw.exe"))
    if not have_binaries:
        print("Failed to find sphinxtrain binaries. Check your installation")
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. "python -S" or frozen executables).
        sys.exit(1)
    if not os.path.exists(sphinxpath + "/scripts/00.verify/verify_all.pl"):
        print("Failed to find sphinxtrain scripts. Check your installation")
        sys.exit(1)

    # Perl scripts want forward slashes
    training_basedir = os.getcwd().replace('\\', '/')
    sphinxpath = sphinxpath.replace('\\', '/')
    sphinxbinpath = sphinxbinpath.replace('\\', '/')

    print("Sphinxtrain path:", sphinxpath)
    print("Sphinxtrain binaries path:", sphinxbinpath)


def setup(task):
    if not os.path.exists("etc"):
        os.mkdir("etc")

    print("Setting up the database " + task)

    out_cfg = open("./etc/sphinx_train.cfg", "w")
    for line in open(sphinxpath + "/etc/sphinx_train.cfg", "r"):
        line = line.replace("___DB_NAME___", task)
        line = line.replace("___BASE_DIR___", training_basedir)
        line = line.replace("___SPHINXTRAIN_DIR___", sphinxpath)
        line = line.replace("___SPHINXTRAIN_BIN_DIR___", sphinxbinpath)
        out_cfg.write(line)
    out_cfg.close()

    out_cfg = open("etc/feat.params", "w")
    for line in open(sphinxpath + "/etc/feat.params", "r"):
        out_cfg.write(line)
    out_cfg.close()


# Training scripts, relative to "<sphinxpath>/scripts/", listed in the order
# in which run() executes them.  run_stages() and run_from() select entries by
# the directory part of each path: either the text before its first "."
# (e.g. "20") or the text after its last "." (e.g. "ci_hmm").  "decode" has no
# dot, so it is matched only by that name.
steps = [
    "000.comp_feat/slave_feat.pl",
    "00.verify/verify_all.pl",
    "0000.g2p_train/g2p_train.pl",
    "01.lda_train/slave_lda.pl",
    "02.mllt_train/slave_mllt.pl",
    "05.vector_quantize/slave.VQ.pl",
    "10.falign_ci_hmm/slave_convg.pl",
    "11.force_align/slave_align.pl",
    "12.vtln_align/slave_align.pl",
    "20.ci_hmm/slave_convg.pl",
    "30.cd_hmm_untied/slave_convg.pl",
    "40.buildtrees/slave.treebuilder.pl",
    "45.prunetree/slave.state-tying.pl",
    "50.cd_hmm_tied/slave_convg.pl",
    "60.lattice_generation/slave_genlat.pl",
    "61.lattice_pruning/slave_prune.pl",
    "62.lattice_conversion/slave_conv.pl",
    "65.mmie_train/slave_convg.pl",
    "90.deleted_interpolation/deleted_interpolation.pl",
    "decode/slave.pl",
]


def _step_ids(step):
    """Return the ``(number, name)`` identifiers of a *step* path.

    Both come from the directory component: ``"20.ci_hmm/slave_convg.pl"``
    gives ``("20", "ci_hmm")``.  A directory without a dot, such as
    ``"decode"``, yields the same string for both.
    """
    parts = step.split("/")[0].split(".")
    return parts[0], parts[-1]


def _run_step(step):
    """Run one training script with perl and exit the program if it fails.

    The script is started without a shell, so paths containing spaces or
    quotes work on every platform.  ``subprocess.call`` returns the real exit
    code, whereas ``os.system`` returns a POSIX wait status (exit code << 8)
    that wraps around to 0 when passed to ``exit()``, hiding failures.

    Raises:
        SystemExit: with the script's exit code when it is positive, or 1
            when the script was killed by a signal or perl could not be
            started.
    """
    script = sphinxpath + "/scripts/" + step
    try:
        ret = subprocess.call(["perl", script])
    except OSError as err:
        print("Failed to run perl:", err)
        sys.exit(1)
    if ret != 0:
        # Negative values mean "terminated by signal"; report them as 1.
        sys.exit(ret if ret > 0 else 1)


def run_stages(stages):
    """Run the selected stages, in the order given.

    Args:
        stages: Comma-separated stage numbers or names (e.g.
            ``"verify,20"``).  Every step matching a stage is run; a stage
            that matches no step is silently ignored.

    Raises:
        SystemExit: as soon as one script fails.
    """
    for stage in stages.split(","):
        for step in steps:
            if stage in _step_ids(step):
                _run_step(step)


def run_from(stage):
    """Run the first step matching *stage* and every step listed after it.

    Args:
        stage: Stage number or name.  If nothing matches, nothing is run.

    Raises:
        SystemExit: as soon as one script fails.
    """
    found = False
    for step in steps:
        if not found and stage in _step_ids(step):
            found = True
        if found:
            _run_step(step)


def run():
    """Run every training step in order.

    Raises:
        SystemExit: as soon as one script fails.
    """
    print("Running the training")
    for step in steps:
        _run_step(step)


def usage():
    """Print the command-line help text to standard output."""
    help_lines = (
        "",
        # Wording fix: the text used to read "creates and acoustic model".
        "Sphinxtrain processes the audio files and creates an acoustic model ",
        "for CMUSphinx toolkit. The data needs to have a certain layout ",
        "See the tutorial http://cmusphinx.sourceforge.net/wiki/tutorialam ",
        "for details",
        "",
        "Usage: sphinxtrain [options] <command>",
        "",
        "Commands:",
        "     -t <task> setup - copy configuration into database",
        "     -s <stage1,stage2,stage3> run - run selected stages",
        "     -f <stage> run - run from selected stage",
        "     run - run all training",
    )
    for text in help_lines:
        print(text)


def main():
    """Parse the command line and dispatch to ``setup`` or a run function.

    Options (see ``usage``): ``-t/--task <task>``, ``-s/--stages <list>``,
    ``-f/--from <stage>`` and ``-h/--help``.  The first positional argument
    is the command.

    Raises:
        SystemExit: status 0 after ``-h/--help``; status -1 on an invalid
            option, a missing command, or ``setup`` without a task.  The
            functions it calls may also exit.
    """
    try:
        # The trailing "=" marks long options that take a value; without it
        # "--task foo" would treat "foo" as the command.
        opts, args = getopt.getopt(sys.argv[1:], "ht:s:f:",
                                   ["help", "task=", "stages=", "from="])
    except getopt.GetoptError as err:
        print(str(err))
        usage()
        sys.exit(-1)

    task = None
    stages = None
    from_stage = None

    for o, a in opts:
        if o in ("-h", "--help"):
            # Asking for help is a successful, terminal request: do not go
            # on to run a command.
            usage()
            sys.exit(0)
        elif o in ("-t", "--task"):
            task = a
        elif o in ("-f", "--from"):
            from_stage = a
        elif o in ("-s", "--stages"):
            stages = a

    if len(args) == 0:
        usage()
        sys.exit(-1)

    command = args[0]

    find_paths()

    if command == "setup":
        if task is None:
            print("No task name defined")
            sys.exit(-1)
        setup(task)
    elif command == "run":
        # -s takes precedence over -f when both are given.
        if stages is not None:
            run_stages(stages)
        elif from_stage is not None:
            run_from(from_stage)
        else:
            run()
    else:
        # NOTE(review): any other command also starts a full training run;
        # kept for backward compatibility, but a typo here is costly.
        run()


# Run the command-line interface only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
