Source code for plaidml.library

# Copyright Vertex.AI

import ctypes
import logging
import os
import plaidml.exceptions

_LOGGER_FUNCTYPE = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int, ctypes.c_char_p)

_LOG_SEVERITY_VERBOSE = 64
_LOG_SEVERITY_TRACE = 2
_LOG_SEVERITY_DEBUG = 4
_LOG_SEVERITY_INFO = 128
_LOG_SEVERITY_WARNING = 32
_LOG_SEVERITY_ERROR = 16
_LOG_SEVERITY_FATAL = 8

_LOG_SEVERITY_MAP = {
    _LOG_SEVERITY_VERBOSE: logging.DEBUG,
    _LOG_SEVERITY_TRACE: logging.DEBUG,
    _LOG_SEVERITY_DEBUG: logging.DEBUG,
    _LOG_SEVERITY_INFO: logging.INFO,
    _LOG_SEVERITY_WARNING: logging.WARNING,
    _LOG_SEVERITY_ERROR: logging.ERROR,
    _LOG_SEVERITY_FATAL: logging.CRITICAL
}

_PLAIDML_STATUS_CANCELLED = 1
_PLAIDML_STATUS_UNKNOWN = 2
_PLAIDML_STATUS_INVALID_ARGUMENT = 3
_PLAIDML_STATUS_DEADLINE_EXCEEDED = 4
_PLAIDML_STATUS_NOT_FOUND = 5
_PLAIDML_STATUS_ALREADY_EXISTS = 6
_PLAIDML_STATUS_PERMISSION_DENIED = 7
_PLAIDML_STATUS_RESOURCE_EXHAUSTED = 8
_PLAIDML_STATUS_FAILED_PRECONDITION = 9
_PLAIDML_STATUS_ABORTED = 10
_PLAIDML_STATUS_OUT_OF_RANGE = 11
_PLAIDML_STATUS_UNIMPLEMENTED = 12
_PLAIDML_STATUS_INTERNAL = 13
_PLAIDML_STATUS_UNAVAILABLE = 14
_PLAIDML_STATUS_DATA_LOSS = 15
_PLAIDML_STATUS_UNAUTHENTICATED = 16

_PLAIDML_ERRMAP = {
    _PLAIDML_STATUS_CANCELLED: plaidml.exceptions.Cancelled,
    _PLAIDML_STATUS_UNKNOWN: plaidml.exceptions.Unknown,
    _PLAIDML_STATUS_INVALID_ARGUMENT: plaidml.exceptions.InvalidArgument,
    _PLAIDML_STATUS_DEADLINE_EXCEEDED: plaidml.exceptions.DeadlineExceeded,
    _PLAIDML_STATUS_NOT_FOUND: plaidml.exceptions.NotFound,
    _PLAIDML_STATUS_ALREADY_EXISTS: plaidml.exceptions.AlreadyExists,
    _PLAIDML_STATUS_PERMISSION_DENIED: plaidml.exceptions.PermissionDenied,
    _PLAIDML_STATUS_RESOURCE_EXHAUSTED: plaidml.exceptions.ResourceExhausted,
    _PLAIDML_STATUS_FAILED_PRECONDITION: plaidml.exceptions.FailedPrecondition,
    _PLAIDML_STATUS_ABORTED: plaidml.exceptions.Aborted,
    _PLAIDML_STATUS_OUT_OF_RANGE: plaidml.exceptions.OutOfRange,
    _PLAIDML_STATUS_UNIMPLEMENTED: plaidml.exceptions.Unimplemented,
    _PLAIDML_STATUS_INTERNAL: plaidml.exceptions.Internal,
    _PLAIDML_STATUS_UNAVAILABLE: plaidml.exceptions.Unavailable,
    _PLAIDML_STATUS_DATA_LOSS: plaidml.exceptions.DataLoss,
    _PLAIDML_STATUS_UNAUTHENTICATED: plaidml.exceptions.Unauthenticated
}


class _C_Context(ctypes.Structure):
    pass


[docs]class Library(object): """A loaded PlaidML implementation library.""" def __init__(self, lib, logger=logging.log): self._lib = lib self._logger = logger self.vai_last_status = lib.vai_last_status self.vai_last_status.argtypes = [] self.vai_clear_status = lib.vai_clear_status self.vai_clear_status.argtypes = [] self.vai_last_status_str = lib.vai_last_status_str self.vai_last_status_str.argtypes = [] self.vai_last_status_str.restype = ctypes.c_char_p self.vai_set_logger = lib.vai_set_logger self.vai_set_logger.argtypes = [_LOGGER_FUNCTYPE, ctypes.c_void_p] self.vai_internal_set_vlog = lib.vai_internal_set_vlog self.vai_internal_set_vlog.argtypes = [ctypes.c_size_t] self.vai_get_perf_counter = lib.vai_get_perf_counter self.vai_get_perf_counter.argtypes = [ctypes.c_char_p] self.vai_get_perf_counter.restype = ctypes.c_longlong self.vai_set_perf_counter = lib.vai_set_perf_counter self.vai_set_perf_counter.argtypes = [ctypes.c_char_p, ctypes.c_longlong] self.vai_alloc_ctx = lib.vai_alloc_ctx self.vai_alloc_ctx.argtypes = [] self.vai_alloc_ctx.restype = ctypes.POINTER(_C_Context) self.vai_alloc_ctx.errcheck = self._check_err self.vai_free_ctx = lib.vai_free_ctx self.vai_free_ctx.argtypes = [ctypes.POINTER(_C_Context)] self.vai_cancel_ctx = lib.vai_cancel_ctx self.vai_cancel_ctx.argtypes = [ctypes.POINTER(_C_Context)] self.vai_set_eventlog = lib.vai_set_eventlog self.vai_set_eventlog.argtypes = [ctypes.POINTER(_C_Context), ctypes.c_char_p] self.vai_set_eventlog.restype = ctypes.c_bool self.vai_set_eventlog.errcheck = self._check_err self._logger_wrapper = _LOGGER_FUNCTYPE(self._logger_callback) lib.vai_set_logger(self._logger_wrapper, None) def _check_err(self, result, func, args): if result: return result self.raise_last_status() def last_status(self): try: exclass = _PLAIDML_ERRMAP[self._lib.vai_last_status()] except KeyError: return Exception(self._lib.vai_last_status_str().decode()) return exclass(self._lib.vai_last_status_str().decode()) def raise_last_status(self): raise self.last_status() def _logger_callback(self, unused_arg, level, msg): severity = _LOG_SEVERITY_MAP.get(level, logging.ERROR) self._logger(severity, msg.decode()) def get_perf_counter(self, name): return self.vai_get_perf_counter(name) def set_perf_counter(self, name, value): return self.vai_set_perf_counter(name, value) def _internal_set_vlog(self, l):
self._lib.vai_internal_set_vlog(l)