#!/usr/bin/env python
# -*- coding: utf-8 -*-
# #########################################################################
# Copyright (c) 2015-2018, UChicago Argonne, LLC. All rights reserved. #
# #
# Copyright 2018. UChicago Argonne, LLC. This software was produced #
# under U.S. Government contract DE-AC02-06CH11357 for Argonne National #
# Laboratory (ANL), which is operated by UChicago Argonne, LLC for the #
# U.S. Department of Energy. The U.S. Government has rights to use, #
# reproduce, and distribute this software. NEITHER THE GOVERNMENT NOR #
# UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR #
# ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE. If software is #
# modified to produce derivative works, such modified software should #
# be clearly marked, so as not to confuse it with the version available #
# from ANL. #
# #
# Additionally, redistribution and use in source and binary forms, with #
# or without modification, are permitted provided that the following #
# conditions are met: #
# #
# * Redistributions of source code must retain the above copyright #
# notice, this list of conditions and the following disclaimer. #
# #
# * Redistributions in binary form must reproduce the above copyright #
# notice, this list of conditions and the following disclaimer in #
# the documentation and/or other materials provided with the #
# distribution. #
# #
# * Neither the name of UChicago Argonne, LLC, Argonne National #
# Laboratory, ANL, the U.S. Government, nor the names of its #
# contributors may be used to endorse or promote products derived #
# from this software without specific prior written permission. #
# #
# THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS #
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT #
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS #
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago #
# Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, #
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, #
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; #
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT #
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN #
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE #
# POSSIBILITY OF SUCH DAMAGE. #
# #########################################################################
"""
Module containing the CNN model definition and its train and predict routines.
"""
from __future__ import print_function
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Reshape, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D, UpSampling2D
import xlearn.utils as utils
# Module metadata.
__docformat__ = "restructuredtext en"
__version__ = "0.2.0"
__copyright__ = "Copyright (c) 2018, Argonne National Laboratory"
__authors__ = "Xiaogang Yang, Francesco De Carlo"

# Public API exported by ``from ... import *``.
__all__ = ['model', 'train', 'predict']
def model(dim_img, nb_filters, nb_conv):
    """
    Build and compile the CNN used for image transformation.

    The network is an encoder/decoder:

    * Two convolution + 2x2 max-pooling stages reduce the spatial size to
      a quarter of ``dim_img``.
    * A fully connected layer produces ``(dim_img // 4) ** 2`` values, which
      are reshaped into a single-channel map.
    * Two 2x2 up-sampling + convolution stages restore the original size.
    * A final 1x1 convolution yields a single-channel output.

    Parameters
    ----------
    dim_img : int
        Side length of the square input (and output) patches. Must be a
        positive multiple of 4. Otherwise the two 2x2 poolings round the
        size down, and the up-sampled output would not match the input size.
    nb_filters : int
        Number of filters in the first and last 3x3-style convolution
        layers; the inner convolution layers use ``2 * nb_filters``.
    nb_conv : int
        Side length of the square convolution kernels.

    Returns
    -------
    mdl : keras.models.Sequential
        The compiled model (mean squared error loss, Adam optimizer). It
        maps inputs of shape ``(batch, dim_img, dim_img, 1)`` to outputs of
        the same shape.

    Raises
    ------
    ValueError
        If ``dim_img`` is not a positive multiple of 4.
    """
    if dim_img <= 0 or dim_img % 4:
        raise ValueError(
            "dim_img must be a positive multiple of 4, got {!r}".format(dim_img))
    # Side length after the two 2x2 max-pooling layers.
    quarter = int(dim_img // 4)

    mdl = Sequential()
    # Encoder: convolutions keep the size ('same' padding), poolings halve it.
    mdl.add(Convolution2D(nb_filters, (nb_conv, nb_conv),
                          padding='same', activation='relu',
                          input_shape=(dim_img, dim_img, 1)))
    mdl.add(MaxPooling2D(pool_size=(2, 2)))
    mdl.add(Convolution2D(nb_filters * 2, (nb_conv, nb_conv), padding='same', activation='relu'))
    mdl.add(MaxPooling2D(pool_size=(2, 2)))
    mdl.add(Convolution2D(nb_filters * 2, (nb_conv, nb_conv), padding='same', activation='relu'))
    # Bottleneck: dense mapping to a single-channel quarter-size map.
    mdl.add(Flatten())
    mdl.add(Dense(quarter ** 2))
    mdl.add(Reshape((quarter, quarter, 1)))
    # Decoder: each up-sampling doubles the size back toward dim_img.
    mdl.add(UpSampling2D(size=(2, 2)))
    mdl.add(Convolution2D(nb_filters * 2, (nb_conv, nb_conv), padding='same', activation='relu'))
    mdl.add(UpSampling2D(size=(2, 2)))
    mdl.add(Convolution2D(nb_filters, (nb_conv, nb_conv), padding='same', activation='relu'))
    # 1x1 convolution collapses the features to one output channel.
    mdl.add(Convolution2D(1, (1, 1), padding='same', activation='relu'))
    mdl.compile(loss='mean_squared_error', optimizer='Adam')
    return mdl
def train(img_x, img_y, patch_size, patch_step, dim_img, nb_filters, nb_conv, batch_size, nb_epoch):
    """
    Train the CNN to map patches of ``img_x`` onto patches of ``img_y``.

    Both images are normalized with ``utils.nor_data`` and cut into patches
    with ``utils.extract_patches``. Corresponding patches form the
    input/target pairs used to fit a freshly built :func:`model`.

    Parameters
    ----------
    img_x : array
        Input image.
    img_y : array
        Target image. Presumably the same shape as ``img_x``, so that the
        extracted patches pair up one-to-one (TODO confirm).
    patch_size : (int, int)
        Patch dimensions passed to ``utils.extract_patches``. Each patch must
        contain ``dim_img * dim_img`` pixels so it can be reshaped.
    patch_step : int or (int, int)
        Step between patches, passed to ``utils.extract_patches``. The exact
        accepted form depends on that helper (TODO confirm).
    dim_img : int
        Side length of the square patches fed to the network.
    nb_filters : int
        Number of filters in the first convolution layer (see :func:`model`).
    nb_conv : int
        Side length of the convolution kernels (see :func:`model`).
    batch_size : int
        Training batch size.
    nb_epoch : int
        Number of training epochs.

    Returns
    -------
    mdl : keras.models.Sequential
        The trained model.
    """
    img_x = utils.nor_data(img_x)
    img_y = utils.nor_data(img_y)
    img_input = utils.extract_patches(img_x, patch_size, patch_step)
    img_output = utils.extract_patches(img_y, patch_size, patch_step)
    # Keras expects channels-last 4-D input: (n_patches, dim_img, dim_img, 1).
    img_input = np.reshape(img_input, (img_input.shape[0], dim_img, dim_img, 1))
    img_output = np.reshape(img_output, (img_output.shape[0], dim_img, dim_img, 1))
    mdl = model(dim_img, nb_filters, nb_conv)
    # summary() prints the layer table itself and returns None. Wrapping it
    # in print() emitted a stray "None" line.
    mdl.summary()
    mdl.fit(img_input, img_output, batch_size=batch_size, epochs=nb_epoch)
    return mdl
def predict(mdl, img, patch_size, patch_step, batch_size, dim_img):
    """
    Apply a trained model to a whole 2-D image, patch by patch.

    The image is normalized with ``utils.nor_data`` and cut into patches
    with ``utils.extract_patches``. The patches are run through ``mdl``, and
    the predicted patches are stitched back together with
    ``utils.reconstruct_patches``.

    Parameters
    ----------
    mdl : keras model
        Trained model, e.g. as returned by :func:`train`.
    img : array
        The 2-D image to transform.
    patch_size : (int, int)
        Patch dimensions passed to ``utils.extract_patches``. Each patch must
        contain ``dim_img * dim_img`` pixels.
    patch_step : int or (int, int)
        Step between patches, used for both extraction and reconstruction.
        The exact accepted form depends on the helpers (TODO confirm).
    batch_size : int
        Batch size used by ``mdl.predict``.
    dim_img : int
        Side length of the square patches the model was built for.

    Returns
    -------
    img_rec : array
        The reconstructed image, as returned by ``utils.reconstruct_patches``
        for the original ``(height, width)``.
    """
    # NOTE(review): the float16 cast presumably halves the memory of the
    # patch stack. train() does not apply it, so inference inputs are
    # quantized relative to training; confirm this is intended.
    normalized = np.float16(utils.nor_data(img))
    height, width = normalized.shape

    patches = utils.extract_patches(normalized, patch_size, patch_step)
    n_patches = patches.shape[0]
    patches = np.reshape(patches, (n_patches, dim_img, dim_img, 1))

    restored = mdl.predict(patches, batch_size=batch_size)
    # Drop the local reference to the input stack before reassembly,
    # presumably to lower peak memory.
    del patches

    restored = np.reshape(restored, (restored.shape[0], dim_img, dim_img))
    return utils.reconstruct_patches(restored, (height, width), patch_step)