# Source code for plaidml.keras

# Copyright Vertex.AI.
#
# Licensed under the GNU Affero General Public License V3 (the License) ;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.gnu.org/licenses/agpl-3.0.en.html
"""Patches in a PlaidML backend for Keras.

This module hooks the system meta module path to add a backend for Keras
that uses PlaidML for computation.  The actual backend is implemented in
backend.py.

To use this module to install the PlaidML backend:

.. code-block:: python

    import plaidml.keras
    plaidml.keras.install_backend()

This should be done in the main program module, after ``__future__`` imports
(if any) and before importing any Keras modules.  Calling ``install_backend()`` replaces
the standard keras.backend module with plaidml.keras.backend, causing subsequently
loaded Keras modules to use PlaidML.

You can explicitly set the installed backend via the environment:
    PLAIDML_KERAS_BACKEND: Selects the backend to use.
                           If this is not set, the standard PlaidML backend is used.
                           Possible values are "plaidml" and "theano".

You can also explicitly pass the backend in the call to ``install_backend()``.

(As an aside: we don't use the standard Keras approach of having you edit
``~/.keras/keras.json`` to set the backend, because we want code that doesn't patch
in the PlaidML backend loader to continue to work.  If Keras ever does support
dynamic loading of backends that aren't hard-coded into Keras, we will switch
to that mechanism.)
"""

# TODO: Update the tracing code to work on older devices.
# For posterity, here's the text:
# You can also enable API tracing by setting an environment variable:
#  PLAIDML_TRACE_FILENAME: Enables tracing, saving the output to the indicated file.

from __future__ import print_function

from six import iteritems

import functools
import importlib
import numpy as np
import os
import sys
import types

# Maps each selectable backend name to the module implementing it.  Names that
# start with a dot are resolved relative to this package when imported.
_BACKENDS = {
    'plaidml': '.backend',
    'theano': 'keras.backend.theano_backend',
}


def install_backend(import_path='keras.backend',
                    backend=os.getenv('PLAIDML_KERAS_BACKEND', 'plaidml'),
                    trace_file=os.getenv('PLAIDML_TRACE_FILENAME')):
    """Installs the PlaidML backend loader, overriding the default keras.backend.

    Note that the ``backend`` and ``trace_file`` defaults are read from the
    environment once, when this function is defined, not on each call.

    Args:
        import_path: The name of the module to patch.
        backend: The name of the backend to patch in.
        trace_file: A file object to write trace data to.  This may also be the
                    name of a file, which will be opened with mode 'w' (clobbering
                    the existing file, if any).  NOTE: within this module the
                    value is only stored on the finder; the interception step
                    that would consume it is currently commented out.

    Raises:
        RuntimeError: If ``backend`` is not one of the known backend names.
    """
    # Build the finder first so that an unknown backend raises before
    # sys.meta_path has been touched.
    finder = _PlaidMLBackendFinder(import_path, backend, trace_file)

    # Mutate the list in place (rather than rebinding sys.meta_path) so that any
    # code already holding a reference to the list sees the new finder too.
    sys.meta_path.insert(0, finder)

    # Hack around Keras expecting everything not Tensorflow to be Theano.
    # This import is deliberately placed after the finder is registered, so any
    # import of `import_path` that it triggers is routed through the finder.
    from keras.utils import conv_utils
    conv_utils.convert_kernel = lambda x: x


class _PlaidMLBackendFinder(object):
    """Meta-path finder/loader that substitutes a backend for one module name.

    When the import system asks for the module named ``repname``, this object
    claims it and builds a fresh module whose attributes are copied from the
    selected backend implementation module.  Requests for any other module name
    are declined.
    """

    def __init__(self, repname, backend_name, trace_file):
        """Records which module to replace and which backend to replace it with.

        Args:
            repname: Fully-qualified name of the module to replace.
            backend_name: Key into ``_BACKENDS`` selecting the implementation.
            trace_file: Stored as-is; not otherwise used by this class.

        Raises:
            RuntimeError: If ``backend_name`` is not a key of ``_BACKENDS``.
        """
        if backend_name not in _BACKENDS:
            # Sorted so the message is deterministic on every Python version.
            raise RuntimeError('Unknown backend \'%s\'; possible values are \'%s\'' %
                               (backend_name, '\', \''.join(sorted(_BACKENDS))))
        self._repname = repname
        self._backend_name = backend_name
        self._backend_modname = _BACKENDS[backend_name]
        self._trace_file = trace_file
        # Filled in by find_module(); initialised here so load_module() never
        # hits a missing attribute.
        self._keras_path = []

    def find_module(self, fullname, path=None):
        """Legacy finder hook: claims ``fullname`` only if it is the target module.

        Args:
            fullname: Fully-qualified name of the module being imported.
            path: The parent package's ``__path__`` entries, or None.

        Returns:
            This object (acting as the loader) for the target module, else None.
        """
        if fullname != self._repname:
            return None
        tail = fullname.rsplit('.', 1)[-1]
        # Remember where the replaced package's submodules live so the new
        # module's __path__ can point there.  `path` may be None, in which case
        # there is nothing to record.
        self._keras_path = [os.path.join(elt, tail) for elt in (path or ())]
        return self

    def find_spec(self, fullname, path=None, target=None):
        """Spec-based finder hook wrapping :meth:`find_module`.

        The import system stopped consulting ``find_module()`` on meta-path
        finders in Python 3.12; without this method the finder is skipped there.

        Args:
            fullname: Fully-qualified name of the module being imported.
            path: The parent package's ``__path__`` entries, or None.
            target: Unused.

        Returns:
            A module spec whose loader is this object, or None if ``fullname``
            is not the target module.
        """
        if self.find_module(fullname, path) is None:
            return None
        # Imported locally: this module path only exists on Python 3, which is
        # also the only place find_spec() is ever called.
        from importlib.util import spec_from_loader
        return spec_from_loader(fullname, self)

    def load_module(self, fullname):
        """Builds the replacement module and registers it in ``sys.modules``.

        Args:
            fullname: Fully-qualified name under which to register the module.

        Returns:
            The newly created module.

        Raises:
            Whatever importing the backend implementation raises; in that case
            ``sys.modules[fullname]`` is restored to its previous state.
        """
        mod = types.ModuleType(self._repname)
        mod.__path__ = self._keras_path

        missing = object()
        previous = sys.modules.get(fullname, missing)
        # Register before populating so that recursive imports of `fullname`
        # made while loading the backend resolve to this module.
        sys.modules[fullname] = mod
        try:
            self._add_imports(mod, self._backend_modname)
            # self._add_intercepts(mod)
            if self._backend_name != 'plaidml':
                # The included Keras backends require some additional definitions.
                # Note that we don't intercept these.
                self._add_imports(mod, 'keras.backend.common')
        except BaseException:
            # Don't leave a half-populated module behind: later imports would
            # silently pick it up instead of retrying.
            if sys.modules.get(fullname) is mod:
                if previous is missing:
                    del sys.modules[fullname]
                else:
                    sys.modules[fullname] = previous
            raise
        mod.backend = lambda: self._backend_name
        return mod

    def _add_imports(self, mod, import_modname):
        """Copies every attribute of the module ``import_modname`` onto ``mod``.

        All entries of the implementation module's namespace are copied,
        including module metadata such as ``__name__``.

        Args:
            mod: The module object to populate.
            import_modname: Module to import; a leading dot makes it relative
                to this package.

        Returns:
            ``mod``.
        """
        impl = importlib.import_module(import_modname, __name__)
        # Iterate over a snapshot so the copy is unaffected if the
        # implementation module's namespace changes meanwhile.
        for (name, value) in list(vars(impl).items()):
            setattr(mod, name, value)
        return mod