import os
import logging
from concurrent.futures import as_completed, ThreadPoolExecutor
from rex import init_logger
from sup3r import CONFIG_DIR
from sup3r.models import WindGan
from sup3r.preprocessing.data_handling import DataHandlerH5WindCC
from sup3r.preprocessing.batch_handling import SpatialBatchHandlerCC


logger = logging.getLogger(__name__)

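# Network architecture configs: a custom generator (per the file name, 5x
# spatial / 1x temporal enhancement with 6 output features) and the stock
# spatial discriminator config that ships with sup3r.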
fp_gen = '/projects/ntps/sup3r/generator_configs/gen_experimental3_5x_1x_6f.json'
fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json')

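# Yearly WTK CONUS h5 files used as source data (2013 is left commented out).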
source_data_fps = [
        '/datasets/WIND/conus/v1.0.0/wtk_conus_2007.h5',
        '/datasets/WIND/conus/v1.0.0/wtk_conus_2008.h5',
        '/datasets/WIND/conus/v1.0.0/wtk_conus_2009.h5',
        '/datasets/WIND/conus/v1.0.0/wtk_conus_2010.h5',
        '/datasets/WIND/conus/v1.0.0/wtk_conus_2011.h5',
        '/datasets/WIND/conus/v1.0.0/wtk_conus_2012.h5',
#        '/datasets/WIND/conus/v1.0.0/wtk_conus_2013.h5',
]

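# Cached raster index files defining the two training domains; per the file
# names, a 600x1600 grid near (32, -122) and an 800x800 grid near (27, -97).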
source_raster_files = [
        '/projects/seasiawind/gan_training/wtk_raster_ind_cache/raster_conus_600x1600_32_-122.txt',
        '/projects/seasiawind/gan_training/wtk_raster_ind_cache/raster_conus_800x800_27_-97.txt',
]

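# Training features: U/V wind components at 10 m, 100 m, and 200 m hub
# heights plus topography.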
FEATURES = ['U_10m', 'V_10m', 'U_100m', 'V_100m', 'U_200m', 'V_200m', 'topography']

# Checkpoints are written to out_dir with '{epoch}' filled in at save time.
out_dir = './model_output_step2_5x_1x_6f_test12/gan_wind_step2_5x_1x_6f_e{epoch}/'
# Resume from the most recent checkpoint; set saved_model to None to start a
# new model from the architecture configs.
# saved_model = './model_output_step2_5x_1x_6f_test12/gan_wind_step2_5x_1x_6f_e2300/'
saved_model = './model_output_step2_5x_1x_6f_test12/gan_wind_step2_5x_1x_6f_e4600/'
pre_train = False


def get_dh(source_fp, raster_fp, cache_pattern):
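    """Build a DataHandlerH5WindCC for a single WTK file / raster index pair.

    The handler uses the full time series (step of 1), spatially coarsens the
    source data 2x before use, samples 100x100 spatial chips, and caches the
    extracted features to ``cache_pattern`` so repeat runs can load from disk
    instead of re-extracting.
    """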
    dh = DataHandlerH5WindCC(source_fp, FEATURES,
                             raster_file=raster_fp,
                             temporal_slice=slice(None, None, 1),
                             hr_spatial_coarsen=2,
                             sample_shape=(100, 100),
                             max_workers=1,
                             cache_pattern=cache_pattern,
                             overwrite_cache=False,
                             load_cached=True,
                             train_only_features=tuple(),
                             )
    return dh


if __name__ == '__main__':
    log_file = './train_step2_5x_1x_6f_test12.py.log'
    init_logger(__name__, log_level='INFO', log_file=log_file)
    init_logger('sup3r', log_level='INFO', log_file=log_file)

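    # Start a fresh WindGan from the architecture configs, or resume training
    # from a previously saved model directory.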
    if saved_model is None:
        model = WindGan(fp_gen, fp_disc,
                        learning_rate=1e-4,
                        learning_rate_disc=1e-4,
                        )
    else:
        model = WindGan.load(saved_model)

    futures = []
    data_handlers = []

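    # Build one data handler per (source year, raster domain) pair. With
    # max_workers=1 the handlers are constructed one at a time.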
    with ThreadPoolExecutor(max_workers=1) as exe:
        for source_fp in source_data_fps:
            for raster_fp in source_raster_files:

                raster_tag = os.path.basename(raster_fp)
                raster_tag = raster_tag.replace('raster_conus_', '')
                raster_tag = raster_tag.replace('.txt', '')

                source_data_tag = os.path.basename(source_fp)
                source_data_tag = source_data_tag.replace('.h5', '')

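                # Unique on-disk cache per (raster, year) pair; the
                # '{feature}' key is filled in per feature by the data
                # handler.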
                cache_pattern = ('/scratch/gbuster/sup3r/4km/'
                                 'cached_features_{}_{}'
                                 .format(raster_tag, source_data_tag))
                cache_pattern += '_{feature}.pkl'

                future = exe.submit(get_dh, source_fp, raster_fp,
                                    cache_pattern)
                futures.append(future)

        for i, future in enumerate(as_completed(futures)):
            data_handlers.append(future.result())
            logger.info('Collected data handler future {} out of {}'
                        .format(i+1, len(futures)))

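    # The batch handler yields paired low-res / high-res spatial samples with
    # 5x spatial and no temporal enhancement: 256 batches of 16 samples per
    # epoch.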
    batch_handler = SpatialBatchHandlerCC(data_handlers,
                                          batch_size=16,
                                          n_batches=256,
                                          s_enhance=5,
                                          t_enhance=1,
                                          )

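    # Optional generator-only pre-training: with weight_gen_advers=0.0 and
    # train_disc=False the generator is trained on content loss alone before
    # adversarial training begins.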
    if pre_train:
        try:
            # Gen pretraining
            model.train(batch_handler, n_epoch=100,
                        weight_gen_advers=0.0,
                        train_gen=True, train_disc=False,
                        adaptive_update_fraction=0,
                        checkpoint_int=100, out_dir=out_dir,
                        early_stop_on='validation_loss_gen')
        except Exception as e:
            logger.exception('Training failed!')
            raise e

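    # Full adversarial training with a small adversarial loss weight. n_epoch
    # is set very large so the run continues until stopped externally;
    # checkpoints are written every 100 epochs.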
    try:
        # GAN training
        model.train(batch_handler, n_epoch=int(1e6),
                    weight_gen_advers=0.001,
                    train_gen=True, train_disc=True,
                    adaptive_update_fraction=0,
                    checkpoint_int=100, out_dir=out_dir)
    except Exception as e:
        logger.exception('Training failed!')
        raise e
