#!/usr/bin/env python3
# coding: utf-8

# import sys
# import os

import numpy as np
from scipy.integrate import solve_bvp
import time
# from math import sin, cos
# from timeit import default_timer as timer
from opt_dfuns import odefun, bcfun, init_guess  # , gradodefun
from opt_params import par
from OptPiirto import opt_nosturi_piirto, tallenna_ohjaus


# guess palauttaa tilavektorin arvon ajanhetkellä t
def guess(Steps):
    guess_arr = np.array(
        [init_guess(float(i) * par.T_f/float(Steps)) for i in range(Steps)])
    print(guess_arr.shape)
    return np.transpose(guess_arr)


def scipy_ratkaisu(pi, stps, plotdim, s_paiva, max_pi):
    """
    Ratkaistaan reuna-arvotehtävä
    """
    initMesh = np.linspace(0.0, par.T_f, stps)
    res = solve_bvp(odefun, bcfun, initMesh, guess(stps),
                    verbose=2, tol=1.0e-5, max_nodes=4000)
    print('status ', res.status, res.message)
    print('******************************************')
    (n, ti) = res.y.shape
    print('ti ', ti)
    dim = min(ti, plotdim)
    xs = np.linspace(0, par.T_f, dim)
    ys = res.sol(xs)
    # piirretään ratkaisun kuvaajat
    opt_nosturi_piirto(s_paiva, pi, xs, ys, dim)
    # optPiirto('hyppy_' + str(pi), xs, ys, dim)
    # tallennetaan ratkaisu simulointia varten
    tallenna_ohjaus(xs, ys, dim)
    # tallenna_tila(xs, ys, dim)
    s1 = '\niteration: {0:2d} '.format(pi)
    s2 = ' C_W {0:.4g}; C_F1 {1:.4g}; C_F2 {2:.4g}'
    s2 = s2.format(par.C_W, par.C_F1, par.C_F2)
    print(s1 + s2)

    while res.status == 0 and pi < max_pi:
        """
        Alkuarvauksella algoritmi ei aina konvergoi kuin "löysällä"
        optimointikriteerillä. Ratkaisua voi käyttää uutena parempana
        alkuarvauksena ja kiristää optimointikriteeriä.
        """
        pi = pi + 1
        par.C_F1 = 0.8 * par.C_F1
        par.C_F2 = 0.8 * par.C_F2
        xs0 = np.linspace(0, par.T_f, int(ti/5.0))
        ys0 = res.sol(xs0)
        res = solve_bvp(odefun, bcfun, xs0, ys0,
                        verbose=2, tol=1.0e-5, max_nodes=4000)
        print('status ', res.status, res.message)
        (n, ti) = res.y.shape
        print('ti ', ti)
        dim = min(ti, plotdim)
        xs = np.linspace(0, par.T_f, dim)
        ys = res.sol(xs)
        opt_nosturi_piirto(s_paiva, pi, xs, ys, dim)
        # optPiirto('hyppy_' + str(pi), xs, ys, min(len(res.x), plotdim))
        tallenna_ohjaus(xs, ys, dim)
        s1 = '\niteration: {0:2d} '.format(pi)
        s2 = ' C_W {0:.4g}; C_F1 {1:.4g}; C_F2 {2:.4g}'
        s2 = s2.format(par.C_W, par.C_F1, par.C_F2)
        print(s1 + s2)

##############


plotdim = int(48.0 * par.T_f)

paiva = time.localtime(time.time())
s_paiva = str(paiva.tm_year) + str(paiva.tm_mon) + str(paiva.tm_mday) + str(
    paiva.tm_hour) + str(paiva.tm_min)

maxiter = 20
Stps = 20
scipy_ratkaisu(0, Stps, plotdim, s_paiva, maxiter)

print("se siitä")