# Copyright Vertex.AI.
"""
The TILE standard operation library.
These operations have been shown to be useful across a variety of frameworks.
(Frameworks are of course free to define their own operations in addition to
these, although it'll be easier to use them with these if a framework's own
operations are defined using the standard :doc:`plaidml.tile` base classes.)
Each operation is defined as a ``tile.Operation`` subclass, allowing it to be
used in pattern matching.  Additionally, each operation is provided via a
top-level function that wraps the class, allowing composite operations to
be built up using a functional programming style.
See the `PlaidML Op Tutorial <https://github.com/plaidml/plaidml/wiki/PlaidML-Op-Tutorial>`_
for information about writing your own custom operations.
"""
# pylint: disable=invalid-name
from collections import defaultdict
import functools
from enum import Enum
import plaidml
from plaidml import tile
import six
[docs]class AutoPadding(Enum):
    EXPLICIT = 1
    VALID = 2
    SAME_UPPER = 3 
    SAME_LOWER = 4
    CHANNELS_LAST = 2
def _extend_pads(pads, rank):
    """Extends a padding list to match the necessary rank.
    
    Args:
        pads ([int] or None): The explicitly-provided padding list.
        rank (int): The rank of the operation.
    Returns:
        None: If pads is None
        [int]: The extended padding list.
    """
    if pads is None:
        return pads
    pads = list(pads)
    if len(pads) < rank:
        pads.extend([0] * (rank - len(pads)))
    if len(pads) < (2 * rank):
        pads.extend(pads[len(pads) - rank:rank])
    return pads
[docs]def pad_compute(sym, input_size, filter_size, stride, padding, pads=None):
    """Computes info for an axis of a padded filter.
    Args:
        sym (str): The symbol for the input axis.
        input_size (tile.Value or int): The size of the input axis (possibly symbolic).
        filter_size (int): The size of the filter along this axis.
        stride (int): The stride of the filter along this axis.
        padding (AutoPadding): The padding style to use.
        pads ((int, int) or None): Explicit pre- and post-padding for this axis.
    Returns:
        tuple(A string representing the output size as TILE code,
              The pre-padding to use when building input accessor expressions,
              A tile.Value representing the computed output size)
    """
    if pads:
        num_out_size = (input_size + pads[0] + pads[1] - filter_size + stride) // stride
        sym_output_size = '({sym} + {pre} + {post} - {fs} + {s}) / {s}'.format(
            sym=sym, pre=pads[0], post=pads[1], fs=filter_size, s=stride)
        sym_padding_before = pads[0]
    elif padding == AutoPadding.VALID:
        num_out_size = (input_size - filter_size + stride) // stride
        sym_output_size = '({sym} - {fs} + {s}) / {s}'.format(sym=sym, fs=filter_size, s=stride)
        sym_padding_before = 0
    elif padding == AutoPadding.SAME_UPPER or padding == AutoPadding.SAME_LOWER:
        num_out_size = (input_size + stride - 1) // stride
        sym_output_size = '({sym} + {s} - 1) / {s}'.format(sym=sym, s=stride)
        if padding == AutoPadding.SAME_UPPER:
            expr = '(max(0, ({symout} - 1) * {s} + {fs} - {syminp})) / 2'
        else:
            expr = '((max(0, ({symout} - 1) * {s} + {fs} - {syminp})) + 1) / 2'
        sym_padding_before = expr.format(
            symout=sym_output_size, s=stride, fs=filter_size, syminp=sym)
    else:
        raise Exception('Invalid padding: ' + str(padding))
    if not isinstance(num_out_size, tile.Value) and num_out_size < 0:
        raise Exception(
            'Invalid output size computed for convolution: num_out_size={}'.format(num_out_size)) 
    return (sym_output_size, sym_padding_before, num_out_size)
def _format_conv_strings(
        rank,
        in_shape,
        kernel_shape,
        strides,
        padding,
        data_format,
        dilation_rate,
        channelwise,
        forward=True,
        expected_output_shape=None,
):
    # Variable meanings:
    # N: Number of items in the batch
    # L<i>: Spatial dimension i of each (input) item
    # CI: Number of channels (aka filters) of each input item
    # LK<i>: Spatial dimension i of kernel
    # CO: Number of channels (aka filters) of each output item
    # C: Number of input channels in channelwise convolutions
    # M: Channel multiplier in channelwise convolutions (each input channel yields
    #     M output channels for such convolutions)
    #
    # n: Which element of the batch we're on
    # x<i>: The ith coordinate in the output/image
    # k<i>: The ith coordinate in the kernel
    # ci: The input channel we're on
    # co: The output channel we're on
    # c: The input channel we're on for channelwise convolutions
    # m: The output channel multiplier we're on for output convolutions
    if data_format == ConvolutionDataFormat.CHANNELS_FIRST:
        n = 0
        c = 1
        l = [i + 2 for i in range(rank)]
    elif data_format == ConvolutionDataFormat.CHANNELS_LAST:
        n = 0
        l = [i + 1 for i in range(rank)]
        c = rank + 1
    else:
        raise ValueError('Unrecognized data format \'{}\''.format(data_format))
    if channelwise == True and in_shape[c] != kernel_shape[-2]:
        raise ValueError(
            'Channelwise convolution must have same number of channels in both input and kernel:\n'
            + '{} (from shape {}) v {} (from shape {})'.format(in_shape[c], in_shape,
                                                               kernel_shape[-2], kernel_shape))
    sym_out_shape = list()
    pad_amount = list()
    num_out_shape = list()
    for i in range(rank):
        if forward:
            sym_out, sym_pad, num_out = pad_compute('L{}'.format(i), in_shape[l[i]],
                                                    dilation_rate[i] * (kernel_shape[i] - 1) + 1,
                                                    strides[i], padding, None)
        else:
            sym_out, sym_pad, num_out = pad_compute('D{}'.format(i), in_shape[l[i]],
                                                    dilation_rate[i] * (kernel_shape[i] - 1) + 1,
                                                    strides[i], padding, None)
        sym_out_shape.append(sym_out)
        pad_amount.append(sym_pad)
        num_out_shape.append(num_out)
    if expected_output_shape is not None:
        # Confirm that the output shape is consistent with the rest of the convolution
        computed_output_shape = [0] * (rank + 2)
        computed_output_shape[n] = in_shape[n]
        computed_output_shape[c] = kernel_shape[-1]
        for i in range(rank):
            computed_output_shape[l[i]] = num_out_shape[i]
        for i in range(rank + 2):
            if (not isinstance(computed_output_shape[i], tile.Value) and
                    not isinstance(expected_output_shape[i], tile.Value) and
                    computed_output_shape[i] != expected_output_shape[i]):
                raise ValueError('Expected convolution output of shape {}, received {}'.format(
                    expected_output_shape, computed_output_shape))
    padding_list = ['Pad{} = {};'.format(i, pad_amount[i]) for i in range(rank)]
    padding_str = ''.join(p + '\n                   ' for p in padding_list)
    input_idx_list = [
        '{s}*{x} + {d}*{k} - {p}'.format(
            s=strides[i],
            x='x{}'.format(i),
            d='{}'.format(dilation_rate[i]),
            k='k{}'.format(i),
            p='Pad{}'.format(i)) for i in range(rank)
    ]
    if data_format == ConvolutionDataFormat.CHANNELS_FIRST and not channelwise:
        if forward:
            input_dims_str = 'N, CI, ' + ', '.join(['L{}'.format(i) for i in range(rank)])
            out_dims_str = 'N, CO, ' + ', '.join(
                ['{}'.format(sym_out_shape[i]) for i in range(rank)])
            outshape = [in_shape[0]] + [kernel_shape[-1]] + num_out_shape
        else:
            input_dims_str = 'N, CI, ' + ', '.join('D{}'.format(i) for i in range(rank))
            out_dims_str = 'N, CO, ' + ', '.join(['L{}'.format(i) for i in range(rank)])
        out_idx_str = 'n, co, ' + ', '.join(['x{}'.format(i) for i in range(rank)])
        input_idx_str = 'n, ci, ' + ', '.join(input_idx_list)
    elif data_format == ConvolutionDataFormat.CHANNELS_LAST and not channelwise:
        if forward:
            input_dims_str = 'N, ' + ', '.join(['L{}'.format(i) for i in range(rank)]) + ', CI'
            out_dims_str = 'N, ' + ', '.join(['{}'.format(sym_out_shape[i])
                                              for i in range(rank)]) + ', CO'
            outshape = [in_shape[0]] + num_out_shape + [kernel_shape[-1]]
        else:
            input_dims_str = 'N, ' + ', '.join('D{}'.format(i) for i in range(rank)) + ', CI'
            out_dims_str = 'N, ' + ', '.join(['L{}'.format(i) for i in range(rank)]) + ', CO'
        out_idx_str = 'n, ' + ', '.join(['x{}'.format(i) for i in range(rank)]) + ', co'
        input_idx_str = 'n, ' + ', '.join(input_idx_list) + ', ci'
    elif data_format == ConvolutionDataFormat.CHANNELS_FIRST and channelwise:
        if not forward:
            raise NotImplementedError('Channelwise transposed convolutions not implemented.')
        input_dims_str = 'N, C, ' + ', '.join(['L{}'.format(i) for i in range(rank)])
        out_idx_str = 'n, c*M + m, ' + ', '.join(['x{}'.format(i) for i in range(rank)])
        out_dims_str = 'N, C*M, ' + ', '.join(['{}'.format(sym_out_shape[i]) for i in range(rank)])
        input_idx_str = 'n, c, ' + ', '.join(input_idx_list)
        outshape = [in_shape[0]] + [kernel_shape[-2] * kernel_shape[-1]] + num_out_shape
    elif data_format == ConvolutionDataFormat.CHANNELS_LAST and channelwise:
        if not forward:
            raise NotImplementedError('Channelwise transposed convolutions not implemented.')
        input_dims_str = 'N, ' + ', '.join(['L{}'.format(i) for i in range(rank)]) + ', C'
        out_idx_str = 'n, ' + ', '.join(['x{}'.format(i) for i in range(rank)]) + ', c*M + m'
        out_dims_str = 'N, ' + ', '.join(['{}'.format(sym_out_shape[i])
                                          for i in range(rank)]) + ', C*M'
        input_idx_str = 'n, ' + ', '.join(input_idx_list) + ', c'
        outshape = [in_shape[0]] + num_out_shape + [kernel_shape[-2] * kernel_shape[-1]]
    else:
        raise ValueError('Unrecognized data format \'{}\''.format(data_format))
    if channelwise:
        ker_dims_str = ', '.join(['LK{}'.format(i) for i in range(rank)]) + ', C, M'
        ker_idx_str = ', '.join(['k{}'.format(i) for i in range(rank)]) + ', c, m'
    else:
        ker_dims_str = ', '.join(['LK{}'.format(i) for i in range(rank)]) + ', CI, CO'
        ker_idx_str = ', '.join(['k{}'.format(i) for i in range(rank)]) + ', ci, co'
    ret = {
        'input_dims_str': input_dims_str,
        'ker_dims_str': ker_dims_str,
        'out_idx_str': out_idx_str,
        'out_dims_str': out_dims_str,
        'input_idx_str': input_idx_str,
        'ker_idx_str': ker_idx_str,
        'padding_str': padding_str
    }
    if forward:
        ret['outshape_tuple'] = outshape
    else:
        ret['dim_input'] = ', ' + ', '.join(['D{}'.format(i) for i in range(rank)])
    return ret
[docs]class ArgMax(tile.Operation):
    """Maximum of elements along an axis.
    Builds a tensor whose elements are the maximum value on some axis of an input tensor.
    """
    def __init__(self, value, axis=-1):
        self.axis = axis
        self.value = value 
        super(ArgMax, self).__init__(None, [('I', value)], [('O', value.shape)])
argmax = ArgMax.function
[docs]class AveragePool(tile.Operation):
    """
    A standard ML average pooling operator.
    """
    def __init__(self, data, kernel_shape, pads, strides, padding=AutoPadding.EXPLICIT):
        rank = data.shape.ndims - 2
        pads = _extend_pads(pads, rank)
        if not strides:
            strides = tuple(1 for _ in range(rank))
        elif len(strides) != rank:
            raise ValueError('Pool strides length inconsistent with input shape: ' +
                             '{} (rank {}) v {} (rank {})'.format(strides,
                                                                  len(strides), data.shape, rank))
        out_dims = ['N', 'C']
        num_out_shape = list()
        in_idxs = list()
        for i in range(rank):
            sym_out, sym_pad, num_out = pad_compute('L{}'.format(i), data.shape.dims[i + 2],
                                                    kernel_shape[i], strides[i], padding,
                                                    (pads[i], pads[i + rank]) if pads else None)
            out_dims.append(sym_out)
            num_out_shape.append(num_out)
            in_idxs.append('{stride}*x{idx} + a{idx} - {pad}'.format(
                stride=strides[i], idx=i, pad=sym_pad))
        out_idxs = ['n', 'c'] + ['x{}'.format(i) for i in range(rank)]
        code = """
        function (I[N, C, {in_dims}], One[]) -> (O) {{
            Ones[{one_idxs} : {in_dims}] = =(One[]);
            Count[{cout_idxs}{cout_sep}{cout_dims}] = +(Ones[{in_idxs}]), {pool_bounds};
            S[{out_idxs} : {out_dims}] = +(I[n, c, {in_idxs}]), {pool_bounds};
            O = S / Count;
        }}""".format(
            out_idxs=', '.join(out_idxs),
            out_dims=', '.join(out_dims),
            cout_idxs=', '.join(out_idxs[2:]),
            cout_dims=', '.join(out_dims[2:]),
            cout_sep=' : ' if len(out_idxs) > 2 else '',
            one_idxs=', '.join(['o{}'.format(i) for i in range(rank)]),
            in_idxs=', '.join(in_idxs),
            in_dims=', '.join(['L{}'.format(i) for i in range(rank)]),
            pool_bounds=', '.join(['a{} < {}'.format(i, kernel_shape[i]) for i in range(rank)]))
        outshape = tile.Shape(data.shape.dtype, list(data.shape.dims[0:2]) + num_out_shape)
        super(AveragePool, self).__init__( 
            code, [('I', data), ('One', tile.Value.from_var(1., tuple()))], [('O', outshape)])
average_pool = AveragePool.function
[docs]class BinaryCrossentropy(tile.Operation):
    """
    Computes the binary crossentropy of a value relative to a target.
    """
    def __init__(self, target, output, epsilon, from_logits=False):
        if epsilon is None:
            epsilon = 0.0
        if from_logits:
            output = sigmoid(output)
        output = clip(output, epsilon, 1.0 - epsilon)
        input_sizes = ','.join(['I' + str(i) for i in range(output.shape.ndims)])
        input_sizes_prod = '*'.join(['I' + str(i) for i in range(output.shape.ndims)])
        f = """
            function (O[{dims}], T[{dims}]) -> (R) {{
                R = builtin_binary_crossentropy(O,T,{prod});
            }}""".format(
            dims=input_sizes, prod=input_sizes_prod)
        super(BinaryCrossentropy, self).__init__(f, [('O', output), ('T', target)], 
                                                 [('R', output.shape)])
binary_crossentropy = BinaryCrossentropy.function
[docs]class Cast(tile.Operation):
    def __init__(self, x, dtype):
        info = tile.DTYPE_INFOS[dtype]
        super(Cast, self).__init__('function (I) -> (O) {{ O = as_{}(I, {}); }}'.format( 
            info.base, info.bitwidth), [('I', x)], [('O', tile.Shape(dtype, x.shape.dims))])
cast = Cast.function
[docs]def ceiling(data):
    """Elementwise ceiling.""" 
    return tile.unary_op(data, 'ceil(I)', 'Ceiling')
[docs]class ClipMin(tile.Operation):
    """Clips a Value to a minimum bound."""
    def __init__(self, value, min_val):
        code = """
               function (I, MIN_VAL) -> (O) {
                   O = (MIN_VAL < I ? I : MIN_VAL);
               }"""
        super(ClipMin, self).__init__(code, [('I', value), ('MIN_VAL', min_val)], 
                                      [('O', value.shape)])
[docs]class ClipMax(tile.Operation):
    """Clips a Value to a maximum bound."""
    def __init__(self, value, max_val):
        code = """
               function (I, MAX_VAL) -> (O) {
                   O = (I < MAX_VAL ? I : MAX_VAL);
               }"""
        super(ClipMax, self).__init__(code, [('I', value), ('MAX_VAL', max_val)], 
                                      [('O', value.shape)])
[docs]def clip(value, min_val, max_val):
    if min_val is not None:
        value = ClipMin.function(value, min_val)
    if max_val is not None:
        value = ClipMax.function(value, max_val) 
    return value
[docs]class Concatenate(tile.Operation):
    """Concatenates tensors to make a single larger tensor."""
    def __init__(self, tensors, axis=-1):
        rank = tensors[0].shape.ndims
        if axis >= rank or axis < -rank:
            raise ValueError('Cannot concatenate tensors with {} dimensions along axis {}'.format(
                rank, axis))
        elif axis < 0:
            axis = axis % rank
        def __clear_axis(dims):
            return [
                None if isinstance(dims[i], tile.Value) else dims[i] for i in range(len(dims))
                if i != axis
            ]
        shape_template = __clear_axis(tensors[0].shape.dims)
        for t in tensors:
            if __clear_axis(t.shape.dims) != shape_template:
                raise ValueError(
                    'Incompatible shapes: cannot concatenate along axis {}\n{} v {}'.format(
                        axis, tensors[0].shape, t.shape))
        offsets = [0]
        for i in range(len(tensors)):
            offsets.append(offsets[i] + tensors[i].shape.dims[axis])
        out_dims = tuple(
            tensors[0].shape.dims[i] if i != axis else offsets[len(tensors)] for i in range(rank))
        output_dims_list = ['N{}'.format(i) for i in range(rank)]
        output_dims_list[axis] = offsets[len(tensors)]
        output_dims_str = ', '.join([str(i) for i in output_dims_list])
        # output_dims_list also serves as a base for input dims,
        # with `axis` index to be overwritten by 'Ai' (i = input index)
        inputs_list = list()
        for i in range(len(tensors)):
            curr_input_dims = list(output_dims_list)  # using 'list' here to make a copy
            curr_input_dims[axis] = 'A{}'.format(i)
            inputs_list.append('I{}[{}]'.format(i, ', '.join(curr_input_dims)))
        inputs_str = ', '.join(inputs_list)
        if axis == 0:
            indices_begin = 'a'
        else:
            indices_begin = ', '.join(['n{}'.format(i) for i in range(axis)]) + ', a'
        if axis == rank - 1:
            indices_end = ''
        else:
            indices_end = ', ' + ', '.join(['n{}'.format(i) for i in range(axis + 1, rank)])
        body_str = ''
        line_subs = {'beg': indices_begin, 'end': indices_end, 'odims': output_dims_str}
        for i in range(len(tensors)):
            # TODO: If offsets[i] is symbolic, add it to the function
            # inputs and use it symbolically.
            line_subs['off'] = '+{}'.format(offsets[i])
            line_subs['i'] = i
            curr_line = '  T{i}[{beg}{off}{end}: {odims}] = =(I{i}[{beg}{end}]);\n'.format(
                **line_subs)
            body_str += curr_line
        body_str += '  O = '
        body_str += ' + '.join(['T{}'.format(i) for i in range(len(tensors))])
        body_str += ';'
        # Example 'code' (concatenating (4,3,2), (4,5,2), (4,1,2)):
        #   function (I0[N0, A0, N2], I1[N0, A1, N2], I2[N0, A2, N2]) -> (O) {
        #     T0[n0, a, n2: N0, 9, N2] = =(I0[n0, a, n2]);
        #     T1[n0, a+3, n2: N0, 9, N2] = =(I1[n0, a, n2]);
        #     T2[n0, a+8, n2: N0, 9, N2] = =(I2[n0, a, n2]);
        #     O = T0 + T1 + T2;
        #   }
        code = ('function ({inputs}) -> (O) {{\n{body}\n}}').format(
            inputs=inputs_str,
            body=body_str,
        )
        inputs_list = []
        inputs_list.extend([('I{}'.format(i), tensors[i]) for i in range(len(tensors))])
        super(Concatenate, self).__init__(code, inputs_list, 
                                          [('O', tile.Shape(tensors[0].shape.dtype, out_dims))])
concatenate = Concatenate.function
[docs]class Convolution(tile.Operation):
    """
    A standard ML convolution operator.
    """
    def __init__(self,
                 data,
                 kernel,
                 strides=None,
                 padding=AutoPadding.EXPLICIT,
                 pads=None,
                 group=1,
                 kernel_shape=None,
                 data_format=None,
                 dilation_rate=None,
                 channelwise=False):
        if group != 1:
            raise NotImplementedError('Grouped convolutions are not currently implemented')
        rank = data.shape.ndims - 2
        if strides is None:
            strides = tuple(1 for _ in range(rank))
        if dilation_rate is None:
            dilation_rate = tuple(1 for _ in range(rank))
        if not kernel_shape:
            kernel_shape = kernel.shape.dims
        else:
            kernel_shape = tuple([kernel.shape.dims[0], kernel.shape.dims[1]] + list(kernel_shape))
        for entry in dilation_rate:
            if not isinstance(entry, int) or entry <= 0:
                raise ValueError('Invalid dilation_rate: {}'.format(dilation_rate))
        if len(kernel_shape) != rank + 2:
            raise ValueError('Convolution kernel shape inconsistent with input shape: ' +
                             '{} (rank {}) v {} (rank {})'.format(
                                 kernel_shape,
                                 len(kernel_shape) - 2, data.shape, data.shape.ndims - 2))
        if len(strides) != rank:
            raise ValueError('Convolution strides length inconsistent with input shape: ' +
                             '{} (rank {}) v {} (rank {})'.format(
                                 strides, len(strides), data.shape, data.shape.ndims - 2))
        if len(dilation_rate) != rank:
            raise ValueError('Convolution dilation_rate length inconsistent with input shape: ' +
                             '{} (rank {}) v {} (rank {})'.format(dilation_rate,
                                                                  len(dilation_rate), data.shape,
                                                                  data.shape.ndims - 2))
        conv_strs = _format_conv_strings(rank, data.shape.dims, kernel_shape, strides, padding,
                                         data_format, dilation_rate, channelwise)
        code = """
               function (I[{input_dims_str}], K[{ker_dims_str}]) -> (O) {{
                   {padding_str}O[{out_idx_str} : {out_dims_str}] = +(I[{input_idx_str}]*K[{ker_idx_str}]);
               }}""".format(**conv_strs)
        outshape = tile.Shape(data.shape.dtype, conv_strs['outshape_tuple'])
        super(Convolution, self).__init__(
            code, [('I', data), ('K', kernel)], [('O', outshape)], 
            name='Convolution-{}d'.format(rank))
convolution = Convolution.function
[docs]class ConvolutionTranspose(tile.Operation):
    """
    A transposed convolution operator.
    """
    def __init__(self, x, kernel, output_shape, strides, padding, data_format):
        rank = x.shape.ndims - 2
        if kernel.shape.ndims != rank + 2:
            raise ValueError('Transpose convolution kernel shape inconsistent with input shape: ' +
                             '{} (rank {}) v {} (rank {})'.format(
                                 kernel.shape, kernel.shape.ndims - 2, x.shape, x.shape.ndims - 2))
        if len(output_shape) != rank + 2:
            raise ValueError('Transpose convolution output_shape inconsistent with input shape: ' +
                             '{} (rank {}) v {} (rank {})'.format(
                                 output_shape, len(output_shape) - 2, x.shape, x.shape.ndims - 2))
        if len(strides) != rank:
            raise ValueError('Transpose convolution strides inconsistent with input shape: ' +
                             '{} (rank {}) v {} (rank {})'.format(
                                 strides, len(strides), x.shape, x.shape.ndims - 2))
        if (x.shape.dims[0] != output_shape[0] and
                isinstance(x.shape.dims[0], six.integer_types) and
                isinstance(output_shape[0], six.integer_types)):
            raise ValueError('Transpose convolution batch size inconsistent between input ' +
                             'and output: {} v {}'.format(x.shape.dims[0], output_shape[0]))
        conv_strs = _format_conv_strings(rank, output_shape, kernel.shape.dims, strides, padding,
                                         data_format, (1,) * rank, False, False, x.shape.dims)
        f = """
            function (O[{out_dims_str}], K[{ker_dims_str}]{dim_input}) -> (I) {{
                {padding_str}
                I[{input_idx_str} : {input_dims_str}] = +(O[{out_idx_str}]*K[{ker_idx_str}]);
            }}""".format(**conv_strs)
        # Output shape may be dynamic, so pass its sizes as inputs to Tile
        if data_format == ConvolutionDataFormat.CHANNELS_FIRST:
            l = [i + 2 for i in range(rank)]
        elif data_format == ConvolutionDataFormat.CHANNELS_LAST:
            l = [i + 1 for i in range(rank)]
        else:
            raise ValueError('Unrecognized data format \'{}\''.format(data_format))
        input_tensors = [('O', x), ('K', kernel)] + \
                        
[('D{}'.format(i), output_shape[l[i]]) for i in range(rank)]
        super(ConvolutionTranspose, self).__init__(
            f,
            input_tensors, [('I', tile.Shape(x.shape.dtype, tuple(output_shape)))], 
            name='ConvolutionTranspose-{}d'.format(rank))
convolution_transpose = ConvolutionTranspose.function
[docs]def cos(data):
    """Elementwise cosine.""" 
    return tile.unary_op(data, 'cos(I)', 'Cosine')
[docs]class CumulativeSum(tile.Operation):
    """Cumulative sum of a tensor"""
    def __init__(self, x, axis=0):
        ranges = ', '.join(['N{}'.format(n) for n in range(x.shape.ndims)])
        dest_idxs = ', '.join(['i{}'.format(n) for n in range(x.shape.ndims)])
        src_idxs = ['i{}'.format(n) for n in range(x.shape.ndims)]
        src_idxs[axis] += ' - k'
        src_idxs = ', '.join(src_idxs)
        f = """
            function (I[{src_ranges}]) -> (O) {{
                O[{dest_idxs}: {dest_ranges}] = +(I[{src_idxs}]), k < N{ax};
            }}""".format(
            src_ranges=ranges, dest_idxs=dest_idxs, dest_ranges=ranges, src_idxs=src_idxs, ax=axis) 
        super(CumulativeSum, self).__init__(f, [('I', x)], [('O', x.shape)])
cumulative_sum = CumulativeSum.function
[docs]class Dot(tile.Operation):
    """Dot-product of two tensors."""
    def __init__(self, x, y):
        if x.shape.dtype != y.shape.dtype:
            raise ValueError(
                'Invalid dtype in multiplication: x.dtype=\'{}\', y.dtype=\'{}\''.format(
                    x.shape.dtype, y.shape.dtype))
        if x.shape.ndims == 1 and y.shape.ndims == 1:
            f = 'function (X[I], Y[I]) -> (R) { R[i:I] = +(X[i] * Y[i]); }'
            shape = x.shape
        elif 1 <= x.shape.ndims and 2 <= y.shape.ndims:
            f = """function(X[{x_ranges}], Y[{y_ranges}]) -> (R) {{
                       R[{dest_indices} : {dest_ranges}] = +(X[{x_indices}] * Y[{y_indices}]);
                   }}""".format(
                x_ranges=', '.join(['X{}'.format(i) for i in range(x.shape.ndims)]),
                y_ranges=', '.join(['Y{}'.format(i) for i in range(y.shape.ndims)]),
                dest_indices=', '.join(['x{}'.format(i) for i in range(x.shape.ndims - 1)] + [
                    'y{}'.format(i) for i in (list(range(y.shape.ndims - 2)) + [y.shape.ndims - 1])
                ]),
                dest_ranges=', '.join(['X{}'.format(i) for i in range(x.shape.ndims - 1)] + [
                    'Y{}'.format(i) for i in (list(range(y.shape.ndims - 2)) + [y.shape.ndims - 1])
                ]),
                x_indices=', '.join(['x{}'.format(i) for i in range(x.shape.ndims - 1)] + ['z']),
                y_indices=', '.join(['y{}'.format(i) for i in range(y.shape.ndims - 2)] + ['z'] +
                                    ['y{}'.format(y.shape.ndims - 1)]))
            shape = tile.Shape(
                x.shape.dtype,
                (list(x.shape.dims[:-1]) + list(y.shape.dims[:-2]) + [y.shape.dims[-1]]))
        else:
            raise NotImplementedError('Implement dot when x.dims={} and y.dims={}'.format(
                x.shape.dims, y.shape.dims))
 
        super(Dot, self).__init__(f, [('X', x), ('Y', y)], [('R', shape)])
dot = Dot.function
[docs]class Elu(tile.Operation):
    """Exponential linear unit."""
    def __init__(self, x, alpha=1.0):
        if alpha == 1:
            code = """
                   function (X) -> (R) {
                       A = exp(X)-1;
                       R = (X < 0 ? A : X);
                   }"""
        else:
            code = """
                   function (X) -> (R) {{
                       A = {alpha}*exp(X) - {alpha};
                       R = X < 0 ? A : X;
                   }}""".format(alpha=alpha)
 
        super(Elu, self).__init__(code, [('X', x)], [('R', x.shape)])
elu = Elu.function
[docs]class Equal(tile.Operation):
    """Elementwise tensor equality.
    Builds a boolean tensor whose values are true where the corresponding elements of the inputs
    are equal.
    """
    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs
        if isinstance(rhs, tile.Value):
            shape = tile.Shape(plaidml.DType.BOOLEAN,
                               tile.broadcast_dims(lhs.shape.dims, rhs.shape.dims))
            super(Equal, self).__init__('function (L, R) -> (O) { O = (L == R); }',
                                        [('L', lhs), ('R', rhs)], [('O', shape)])
        else:
            shape = tile.Shape(plaidml.DType.BOOLEAN, lhs.shape.dims)
            super(Equal, self).__init__('function (L) -> (O) {{ O = (L == {}); }}'.format(rhs), 
                                        [('L', lhs)], [('O', shape)])
[docs]class Equal_ArgMax(tile.Operation):
    def __init__(self, lhs, rhs):
        lmax = ismax(lhs.source.op.value, axes=(lhs.source.op.axis,))
        rmax = ismax(rhs.source.op.value, axes=(rhs.source.op.axis,))
        and_shape = tile.Shape(plaidml.DType.INT32,
                               tile.broadcast_dims(lmax.shape.dims, rmax.shape.dims))
        and_op = tile.Operation('function (L, R) -> (O) { O = L ? (R ? 1 : 0) : 0; }',
                                [('L', lmax), ('R', rmax)], [('O', and_shape)])
        sum_val = summation(and_op.output_tuple[0], axes=(lhs.source.op.axis,), keepdims=True)
        eq_shape = tile.Shape(plaidml.DType.BOOLEAN, sum_val.shape.dims)
        super(Equal_ArgMax, self).__init__('function (I) -> (O) { O = 0 < I; }', [('I', sum_val)], 
                                           [('O', eq_shape)])
[docs]def equal(lhs, rhs):
    """Elementwise tensor equality.
    Builds a boolean tensor whose values are true when the corresponding elements of the inputs
    are equal.
    Args:
        lhs (tile.Value): The left-hand side
        rhs (tile.Value): The right-hand side
    Returns:
        tile.Value: The output value
    """
    # TODO: Separate function builders from optimization/composition logic.
    #
    # Putting the composition logic in functions like this makes it a little hard for
    # higher-layer modules to add their own compositions -- think eq(MySpecialOp, MySpecialOp),
    # when some completely unrelated module is invoking the eq.  It would be better to have
    # something like a rewriter registry that could be consulted to match patterns during binding.
    if (lhs.source and isinstance(lhs.source.op, ArgMax) and rhs.source and
            isinstance(rhs.source.op, ArgMax)):
        return Equal_ArgMax.function(lhs, rhs) 
    return Equal.function(lhs, rhs)
[docs]def exp(data):
    """Elementwise exponential.""" 
    return tile.unary_op(data, 'exp(I)', 'Exp')
[docs]class Flatten(tile.Operation):
    """
    Flattens a tensor to a one-dimensional value.
    """
    def __init__(self, data):
        in_dim_list = ['N{}'.format(i) for i in range(data.shape.ndims)]
        out_dim_list = ['*'.join(['N{}'.format(i) for i in range(data.shape.ndims)])]
        new_size = functools.reduce(lambda x, y: x * y, data.shape.dims)
        code = 'function (I[{idims}]) -> (O) {{ O = reshape(I, {odims}); }}'.format(
            idims=', '.join(in_dim_list), odims=', '.join(out_dim_list))
        super(Flatten, self).__init__(code, [('I', data)], 
                                      [('O', tile.Shape(data.shape.dtype, (new_size,)))])
flatten = Flatten.function
[docs]def floor(data):
    """Elementwise floor.""" 
    return tile.unary_op(data, 'floor(I)', 'Floor')
[docs]class Gather(tile.Operation):
    """
    Gathers elements of a tensor.
    """
    def __init__(self, value, indicies):
        outshape = tile.Shape(value.shape.dtype,
                              list(indicies.shape.dims) + list(value.shape.dims[1:]))
        super(Gather, self).__init__('function (V, I) -> (O) { O = gather(V, I); }', 
                                     [('V', value), ('I', indicies)], [('O', outshape)])
gather = Gather.function
[docs]class Gemm(tile.Operation):
    """
    Implements a general matrix multiplication.
    """
    def __init__(self, a, b, c, alpha=None, beta=None, broadcast=True, transA=False, transB=False):
        if broadcast:
            if c.shape.ndims != 1:
                raise NotImplementedError(
                    'Gemm with multiplier broadcast requires a one-dimensional scalar multiplier; multiplier rank={}'.
                    format(c.shape.ndims))
        elif c.shape.ndims != 2:
            raise NotImplementedError(
                'Gemm without multiplier broadcast requires a two-dimensional scalar multiplier; multiplier rank={}'.
                format(c.shape.ndims))
        def gemm_reshape(value):
            if value.shape.ndims < 2:
                raise tile.LogicError(
                    'Invalid Gemm input; two-dimensions required, got: {}'.format(value.shape))
            if value.shape.ndims == 2:
                return value
            newdims = (value.shape.dims[0], functools.reduce(lambda x, y: x * y,
                                                             value.shape.dims[1:]))
            return reshape(value, newdims)
        a = gemm_reshape(a)
        b = gemm_reshape(b)
        code = """
        function (A[{a_dims}], B[{b_dims}], C[{c_dims}]) -> (O) {{
          OM[row, col : ROW, COL] = +(A[{a_idxs}] * B[{b_idxs}]);
          OA = {alpha_expr};
          CB = {beta_expr};
          O = OA + CB;
        }}""".format(
            a_dims='MID, ROW' if transA else 'ROW, MID',
            b_dims='COL, MID' if transB else 'COL, MID',
            c_dims='ROW, COL' if c.shape.ndims == 2 else 'COL',
            a_idxs='mid, row' if transA else 'row, mid',
            b_idxs='col, mid' if transB else 'col, mid',
            alpha_expr='OM * {}'.format(alpha) if alpha else 'OM',
            beta_expr='C * {}'.format(beta) if beta else 'C',
        )
        outshape = tile.Shape(
            tile.common_dtype(a.shape.dtype, b.shape.dtype, c.shape.dtype),
            tile.broadcast_dims((
                a.shape.dims[1] if transA else a.shape.dims[0],
                b.shape.dims[0] if transB else a.shape.dims[1],
            ), c.shape.dims))
 
        super(Gemm, self).__init__(code, [('A', a), ('B', b), ('C', c)], [('O', outshape)])
gemm = Gemm.function
[docs]class Gradients(tile.Operation):
    """
    Compute the gradients of a loss with respect to a set of values
    """
    def __init__(self, loss, variables):
        super(Gradients, self).__init__(None, [('Loss', loss)] + [('I' + str(i), variables[i])
                                                                  for i in range(len(variables))],
                                        [('O' + str(i), variables[i].shape)
                                         for i in range(len(variables))])
        self.num_vars = len(variables)
    def bind(self, bindings):
        loss_var = self.inputs['Loss'].bind(bindings)
        input_vars = [self.inputs['I' + str(i)].bind(bindings) for i in range(self.num_vars)]
        output_vars = plaidml.gradients(loss_var, input_vars)
        outputs = {}
        for i in range(self.num_vars):
            outputs['O' + str(i)] = output_vars[i] 
        return outputs
[docs]def gradients(loss, variables):
    if isinstance(variables, tile.Value):
        variables = [variables]
    op = Gradients(loss, variables)
    outs = []
    for i in range(len(op.outputs)):
        outs.append(op.outputs['O' + str(i)]) 
    return outs
[docs]class Hardmax(tile.Operation):
    """
    Implements a standard ML hardmax.
    """
    def __init__(self, data):
        if data.shape.ndims != 2:
            raise NotImplementedError(
                'Hardmax with a non-two-dimensional tensor is not currently implemented')
        code = """
        function (I[X, Y]) -> (O) {
            MAXX[x : X] = >(I[x, y]);
            MAX[x, y : X, Y] = =(MAXX[x]);
            O = (MAX == I ? 1.0 : 0.0);
        }"""
 
        super(Hardmax, self).__init__(code, [('I', data)], [('O', data.shape)])
[docs]def hardmax(x, axis=None):
    if x.shape.ndims == 2:
        return Hardmax.function(x)
    if axis is None:
        axis = 1
    full_dims = x.shape.dims
    if axis == 0:
        group = 1
    else:
        group = functools.reduce(lambda x, y: x * y, x.shape.dims[:axis])
    if axis == len(x.shape.dims):
        values = 1
    else:
        values = functools.reduce(lambda x, y: x * y, x.shape.dims[axis:])
    flat_x = reshape(x, (group, values))
    result = Hardmax.function(flat_x) 
    return reshape(result, full_dims)
[docs]class Identity(tile.Operation):
    """A simple identity operation."""
    def __init__(self, x):
        super(Identity, self).__init__('function (X) -> (Y) { Y = X; }', [('X', x)], 
                                       [('Y', x.shape)])
identity = Identity.function
[docs]class IsMax(tile.Operation):
    """
    True iff an input's value is the maximum along some set of axes.
    """
    def __init__(self, value, axes):
        dims, _, subs = tile.compute_aggregation_axes(value.shape.dims, axes, True)
        code = """function (I[{src_ranges}]) -> (O) {{
                    MAX[{dest_indices}{dest_sep}{dest_ranges}] = >(I[{src_indices}]);
                    O = (MAX == I);
                }}""".format(**subs)
        super(IsMax, self).__init__(code, [('I', value)], 
                                    [('O', tile.Shape(plaidml.DType.BOOLEAN, dims))])
ismax = IsMax.function
[docs]def log(data):
    """Elementwise logarithm.""" 
    return tile.unary_op(data, 'log(I)', 'Log')
[docs]class LogSoftmax(tile.Operation):
    """
    Implements the log() of a standard ML softmax.
    """
    def __init__(self, data):
        if data.shape.ndims != 2:
            raise NotImplementedError(
                'LogSoftmax with a non-two-dimensional tensor is not currently implemented')
        code = """
        function (I[X, Y]) -> (O) {
            O = builtin_logsoftmax(I, X, Y);
        }"""
 
        super(LogSoftmax, self).__init__(code, [('I', data)], [('O', data.shape)])
[docs]def log_softmax(x, axis=None):
    if x.shape.ndims == 2:
        return LogSoftmax.function(x)
    if axis is None:
        axis = 1
    full_dims = x.shape.dims
    if axis == 0:
        group = 1
    else:
        group = functools.reduce(lambda x, y: x * y, x.shape.dims[:axis])
    if axis == len(x.shape.dims):
        values = 1
    else:
        values = functools.reduce(lambda x, y: x * y, x.shape.dims[axis:])
    flat_x = reshape(x, (group, values))
    result = LogSoftmax.function(flat_x) 
    return reshape(result, full_dims)
[docs]class MatMul(tile.Operation):
    """
    A matrix multiplication, using numpy semantics.
    See https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html for details.
    """
    def __init__(self, a, b):
        # So, for matmul, we have identity dimensions (which remain the same
        # in the output tensor), and summation dimensions (which are
        # eliminated in the output tensor).  We call these I{1,2,...} and S.
        #
        # The matrix multiplication and summation takes place on the low two dimensions.
        # If either input is one-dimensional, that's its summation dimension.
        # Otherwise, A's summation dimension is the lowest dimension, and B's summation
        # dimension is its second-to-lowest.
        #
        # Naturally, there can be broadcasting involved; corresponding dimensions
        # must be broadcast-compatible.
        a_ndims = a.shape.ndims
        b_ndims = b.shape.ndims
        if a_ndims == 0 or b_ndims == 0:
            raise NotImplementedError('MatMul isn\'t defined over scalar values')
        if a_ndims == 1:
            if b_ndims == 1:
                # Both A and B are one dimensional; C is a scalar.
                #   A's dims are [S]
                #   B's dims are [S]
                #   C's dims are []
                c_dims = tuple()
                a_ranges = ['S']
                a_indicies = ['s']
                b_ranges = ['S']
                b_indicies = ['s']
                c_ranges = []
                c_indicies = []
            else:
                # A is one-dimensional, but B is not:
                #   A's dims are [S]
                #   B's dims are [I0, I1... IN-3, S, IN-1]
                #   C's dims are [I0, I1... IN-3, IN-1]
                c_shape = tuple(b.dims[:-2] + b.dims[-1])
                a_ranges = ['S']
                a_indicies = ['s']
                b_ranges = (['I{}'.format(n)
                             for n in range(b_ndims - 2)] + ['S', 'I{}'.format(b_ndims - 1)])
                b_indicies = (['i{}'.format(n)
                               for n in range(b_ndims - 2)] + ['s', 'i{}'.format(b_ndims - 1)])
                c_ranges = ['I{}'.format(n) for n in range(b_ndims - 2) + [b_ndims - 1]]
                c_indicies = ['i{}'.format(n) for n in range(b_ndims - 2) + [b_ndims - 1]]
        else:
            if b_ndims == 1:
                # B is one-dimensional, but A is not:
                #   A's dims are [I0, I1... IN-3, IN-2, S]
                #   B's dims are [S]
                #   C's dims are [I0, I1... IN-3, IN-2]
                c_dims = tuple(a.shape.dims[:-1])
                a_ranges = ['I{}'.format(n) for n in range(a_ndims - 1)] + ['S']
                a_indicies = ['i{}'.format(n) for n in range(a_ndims - 1)] + ['s']
                b_ranges = ['S']
                b_indicies = ['s']
                c_ranges = ['I{}'.format(n) for n in range(a_ndims - 1)]
                c_indicies = ['i{}'.format(n) for n in range(a_ndims - 1)]
            else:
                # Both tensors have more than one dimension.
                #   A's dims are [I0, I1... IN-3, IN-2, S]
                #   B's dims are [I0, I1... IN-3, S, IN-1]
                #   C's dims are [I0, I1... IN-3, IN-2, IN-1].
                c_dims = tuple(
                    list(tile.broadcast_dims(a.shape.dims[:-2], b.shape.dims[:-2])) +
                    [a.shape.dims[-2], b.shape.dims[-1]])
                a_ranges = ['I{}'.format(n) for n in range(a_ndims - 1)] + ['S']
                a_indicies = ['i{}'.format(n) for n in range(a_ndims - 1)] + ['s']
                b_ranges = (['I{}'.format(n)
                             for n in range(b_ndims - 2)] + ['S', 'I{}'.format(b_ndims - 1)])
                b_indicies = (['i{}'.format(n)
                               for n in range(b_ndims - 2)] + ['s', 'i{}'.format(b_ndims - 1)])
                c_ranges = ['I{}'.format(n) for n in range(len(c_dims))]
                c_indicies = ['i{}'.format(n) for n in range(len(c_dims))]
        func = """function(A[{a_ranges}], B[{b_ranges}]) -> (C) {{
                        C[{c_indicies} : {c_ranges}] = +(A[{a_indicies}] * B[{b_indicies}]);
                    }}""".format(
            a_ranges=', '.join(a_ranges),
            a_indicies=', '.join(a_indicies),
            b_ranges=', '.join(b_ranges),
            b_indicies=', '.join(b_indicies),
            c_ranges=', '.join(c_ranges),
            c_indicies=', '.join(c_indicies))
        c_shape = tile.Shape(tile.common_dtype(a.shape.dtype, b.shape.dtype), c_dims)
 
        super(MatMul, self).__init__(func, [('A', a), ('B', b)], [('C', c_shape)])
matmul = MatMul.function
[docs]class MaxReduce(tile.Operation):
    """Computes the maximum value along some set of axes."""
    def __init__(self, x, axes=None, keepdims=False):
        if axes == None:
            axes = list(range(x.shape.ndims))
        shape, axes, subs = tile.compute_aggregation_axes(x.shape.dims, axes, keepdims)
        f = """function (I[{src_ranges}]) -> (O) {{
                   O[{dest_indices}{dest_sep}{dest_ranges}] = >(I[{src_indices}]);
               }}""".format(**subs)
 
        super(MaxReduce, self).__init__(f, [('I', x)], [('O', tile.Shape(x.shape.dtype, shape))])
[docs]def max_reduce(x, axes=None, keepdims=False):
    if not x.shape.ndims:
        return x
    if isinstance(axes, (tuple, list)) and not len(axes):
        # Do nothing if max'ing over an empty axis list
        return x
 
    return MaxReduce.function(x, axes=axes, keepdims=keepdims)
maximum = tile.maximum
[docs]class MaxPool(tile.Operation):
    """
    A standard ML max pooling operator.
    """
    def __init__(self, data, padding, kernel_shape, pads, strides):
        rank = data.shape.ndims - 2
        pads = _extend_pads(pads, rank)
        if not strides:
            strides = tuple(1 for _ in range(rank))
        elif len(strides) != rank:
            raise ValueError('Pool strides length inconsistent with input shape: ' +
                             '{} (rank {}) v {} (rank {})'.format(strides,
                                                                  len(strides), data.shape, rank))
        sym_out_shape = list()
        num_out_shape = list()
        in_idxs = list()
        for i in range(rank):
            sym_out, sym_pad, num_out = pad_compute('L{}'.format(i), data.shape.dims[i + 2],
                                                    kernel_shape[i], strides[i], padding,
                                                    (pads[i], pads[i + rank]) if pads else None)
            sym_out_shape.append(sym_out)
            num_out_shape.append(num_out)
            in_idxs.append('{stride}*x{idx} + k{idx} - {pad}'.format(
                stride=strides[i], idx=i, pad=sym_pad))
        code = """
        function (I[N, C, {in_dims}]) -> (O) {{
            O[n, c, {out_idxs} : N, C, {out_dims}] = >(I[n, c, {in_idxs}]), {pool_bounds};
        }}""".format(
            out_idxs=', '.join(['x{}'.format(i) for i in range(rank)]),
            out_dims=', '.join(sym_out_shape),
            in_idxs=', '.join(in_idxs),
            in_dims=', '.join(['L{}'.format(i) for i in range(rank)]),
            pool_bounds=', '.join(['k{} < {}'.format(i, kernel_shape[i]) for i in range(rank)]))
        outshape = tile.Shape(data.shape.dtype, list(data.shape.dims[0:2]) + num_out_shape)
 
        super(MaxPool, self).__init__(code, [('I', data)], [('O', outshape)])
max_pool = MaxPool.function
[docs]class Mean(tile.Operation):
    """Computes the mean value along some set of axes."""
    def __init__(self, x, axes=None, keepdims=False, floatx=plaidml.DType.FLOAT32):
        if x.shape.dtype == plaidml.DType.BOOLEAN:
            x = cast(x, floatx)
        if axes == None:
            axes = list(range(x.shape.ndims))
        shape, axes, subs = tile.compute_aggregation_axes(x.shape.dims, axes, keepdims)
        subs['mean_ranges'] = '*'.join(['X' + str(i) for i in axes])
        f = """
            function (I[{src_ranges}]) -> (O) {{
                SO[{dest_indices}{dest_sep}{dest_ranges}] = +(I[{src_indices}]);
                 O = SO / ({mean_ranges});
            }}""".format(**subs)
 
        super(Mean, self).__init__(f, [('I', x)], [('O', tile.Shape(x.shape.dtype, shape))])
[docs]def mean(x, axes=None, keepdims=False, floatx=plaidml.DType.FLOAT32):
    if not x.shape.ndims:
        return x
    if isinstance(axes, (tuple, list)) and not len(axes):
        # We're taking the mean across an empty axis list.
        # Keras sometimes does this when squeezing a matrix that doesn't need
        # to be squeezed.
        return x
 
    return Mean.function(x, axes=axes, keepdims=keepdims, floatx=floatx)
[docs]class MinReduce(tile.Operation):
    """Computes the minimum value along some set of axes."""
    def __init__(self, x, axes=None, keepdims=False):
        if axes == None:
            axes = list(range(x.shape.ndims))
        shape, axes, subs = tile.compute_aggregation_axes(x.shape.dims, axes, keepdims)
        f = """function (I[{src_ranges}]) -> (O) {{
                   O[{dest_indices}{dest_sep}{dest_ranges}] = <(I[{src_indices}]);
               }}""".format(**subs)
 
        super(MinReduce, self).__init__(f, [('I', x)], [('O', tile.Shape(x.shape.dtype, shape))])
[docs]def min_reduce(x, axes=None, keepdims=False):
    if not x.shape.ndims:
        return x
    if isinstance(axes, (tuple, list)) and not len(axes):
        # Do nothing if min'ing over an empty axis list
        return x
 
    return MinReduce.function(x, axes=axes, keepdims=keepdims)
minimum = tile.minimum
[docs]class NotEqual(tile.Operation):
    """Elementwise tensor inequality.
    Builds a boolean tensor whose values are true where the corresponding elements of the inputs
    are not equal.
    """
    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs
        if isinstance(rhs, tile.Value):
            shape = tile.Shape(plaidml.DType.BOOLEAN,
                               tile.broadcast_dims(lhs.shape.dims, rhs.shape.dims))
            super(NotEqual, self).__init__('function (L, R) -> (O) { O = (L != R); }',
                                           [('L', lhs), ('R', rhs)], [('O', shape)])
        else:
            shape = tile.Shape(plaidml.DType.BOOLEAN, lhs.shape.dims)
            super(NotEqual, self).__init__('function (L) -> (O) {{ O = (L != {}); }}'.format(rhs), 
                                           [('L', lhs)], [('O', shape)])
not_equal = NotEqual.function
[docs]class Pow(tile.Operation):
    """An elementwise pow() function."""
    def __init__(self, x, p):
        super(Pow, self).__init__('function (I, P) -> (O) { O = pow(I, P); }', 
                                  [('I', x), ('P', p)], [('O', x.shape)])
pow = Pow.function
[docs]class Prod(tile.Operation):
    def __init__(self, value, axes=None, keepdims=False, floatx=plaidml.DType.FLOAT32):
        if value.shape.dtype == plaidml.DType.BOOLEAN:
            value = cast(value, floatx)
        if axes is None:
            axes = list(range(value.shape.ndims))
        dims, _, subs = tile.compute_aggregation_axes(value.shape.dims, axes, keepdims)
        code = """
               function (I[{src_ranges}]) -> (O) {{
                   O[{dest_indices}{dest_sep}{dest_ranges}] = *(I[{src_indices}]);
               }}""".format(**subs)
        super(Prod, self).__init__(code, [('I', value)], 
                                   [('O', tile.Shape(value.shape.dtype, dims))])
[docs]def prod(value, axes=None, keepdims=False, floatx=plaidml.DType.FLOAT32):
    if not value.shape.ndims:
        return value
    if isinstance(axes, (tuple, list)) and not len(axes):
        # We're taking the product across an empty axis list.
        return value 
    return Prod.function(value, axes=axes, keepdims=keepdims, floatx=floatx)
[docs]class Relu(tile.Operation):
    """A Rectified Linear Unit."""
    def __init__(self, x, alpha=None, max_value=None):
        if (alpha is not None) and (max_value is not None):
            # Alpha with a max_value; cap a hand-coded relu.
            code = """
                   function (X, Alpha, MaxValue) -> (Y) {
                       M = (X < 0.0 ? Alpha*X : X);
                       Y = (M < MaxValue ? M : MaxValue);
                   }"""
        elif alpha is not None:
            # Alpha with no max_value; use a hand-coded relu.
            code = 'function (X, Alpha) -> (Y) { Y = (X < 0 ? Alpha*X : X); }'
        elif max_value is not None:
            # No alpha, but a max_value; cap the builtin relu.
            code = """
                   function (X, MaxValue) -> (Y) {
                       M = (X < 0.0 ? 0.0 : X);
                       Y = (M < MaxValue ? M : MaxValue);
                   }"""
        else:
            # Neither alpha nor max_value; use the builtin relu.
            code = 'function (X) -> (Y) { Y = relu(X); }'
        inputs = [('X', x)]
        if alpha is not None:
            inputs.append(('Alpha', alpha))
        if max_value is not None:
            inputs.append(('MaxValue', max_value))
 
        super(Relu, self).__init__(code, inputs, [('Y', x.shape)])
relu = Relu.function
[docs]class Reshape(tile.Operation):
    """
    Reshapes a tensor, without changing the type or number of elements.
    """
    def __init__(self, x, dims):
        dims = list(dims)
        neg_idx = None
        for idx, dim in enumerate(dims):
            if isinstance(dim, tile.Value):
                continue
            if dim == 0 or dim is None:
                dims[idx] = x.shape.dims[idx]
            elif dim == -1:
                if neg_idx:
                    raise tile.LogicError(
                        'At most one dimension of size -1 may be provided in Reshape')
                neg_idx = idx
                dims[idx] = 1  # Just to simplify the size computation later
        if neg_idx is not None:
            # Compute the value to use for the -1 dimension in the
            # output shape, by making it what it needs to be in order
            # to preserve the correct number of elements in the
            # tensor.
            #
            # This code is a little tricky because symbolic values
            # (e.g. the batch size in a typical neural network) may
            # appear in both the original shape and the target shape.
            # Naively multiplying the original shape's dimensions and
            # dividing by the target shape's dimensions (excluding the
            # -1 dimension) would produce a symbolic value.
            #
            # So:
            #
            # We scan the input dimensions, counting the number of
            # instances of each symbolic size encountered and
            # multiplying together the non-symbolic sizes into the
            # numerator.
            #
            # We then scan the output dimensions.  Where there's a
            # symbolic size, we check and see if we have a count for
            # it, and decrement the count if we do.  Otherwise -- if
            # we don't have a count for it, or if it's not symbolic --
            # we multiply it into the denominator.
            #
            # We then take the remaining symbolic input dimensions,
            # and multiply them into the numerator -- these are the
            # dimensions that haven't been cancelled out.
            #
            # And then the size of the -1 dimension is just numerator
            # / denominator; if there are any remaining uncancelled
            # symbolic dimension sizes, the output will be symbolic,
            # but otherwise we'll come out with a concrete dimension
            # size.
            num = 1
            syms = defaultdict(int)
            for dim in x.shape.dims:
                if isinstance(dim, tile.Value):
                    syms[dim] += 1
                else:
                    num *= dim
            den = 1
            for dim in dims:
                if isinstance(dim, tile.Value) and syms[dim] > 0:
                    syms[dim] -= 1
                else:
                    den *= dim
            for sym, count in syms.items():
                for _ in range(count):
                    num *= sym
            dims[neg_idx] = num // den
        inputs = [('I', x)]
        dstrs = list(dims)
        for idx, dim in enumerate(dstrs):
            if isinstance(dim, tile.Value):
                dname = 'D{}'.format(idx)
                inputs.append((dname, dim))
                dstrs[idx] = dname
        super(Reshape, self).__init__('function ({}) -> (O) {{ O = reshape(I, {}); }}'.format(
            ', '.join(inp[0] for inp in inputs), ', '.join([str(d) for d in dstrs])), inputs, 
                                      [('O', tile.Shape(x.shape.dtype, dims))])
reshape = Reshape.function
ShapeOf = tile.ShapeOf
shape_of = tile.shape_of
[docs]def sigmoid(data):
    """Elementwise sigmoid.""" 
    return tile.unary_op(data, 'sigmoid(I)', 'Sigmoid')
[docs]def sin(data):
    """Elementwise sine.""" 
    return tile.unary_op(data, 'sin(I)', 'Sine')
[docs]class SliceTensor(tile.Operation):
    """
    Implements tensor slicing.
    """
    def __init__(self, data, axes=None, ends=None, starts=None):
        if not ends or not starts:
            raise tile.LogicError('Slice requires starts and ends to be set')
        if len(starts) != len(ends):
            raise tile.LogicError('Slice requires starts and ends for all sliced axes')
        if not axes:
            axes = range(len(starts))
        in_dims = ['D{}'.format(d) for d in range(data.shape.ndims)]
        out_dims = list(in_dims)
        in_idxs = ['d{}'.format(d) for d in range(data.shape.ndims)]
        out_idxs = list(in_idxs)
        shape_dims = list(data.shape.dims)
        for axis, start, end in zip(axes, starts, ends):
            clamped_end = tile.minimum(end, data.shape.dims[axis])
            clamped_start = tile.minimum(start, data.shape.dims[axis])
            if isinstance(clamped_start, tile.Value):
                clamped_start_str = 'min({}, D{})'.format(start, axis)
            else:
                clamped_start_str = str(clamped_start)
            if isinstance(clamped_end, tile.Value):
                clamped_end_str = 'min({}, D{})'.format(end, axis)
            else:
                clamped_end_str = str(clamped_end)
            delta = clamped_end - clamped_start
            if isinstance(clamped_end, tile.Value) or isinstance(clamped_start, tile.Value):
                delta_str = '{}-{}'.format(clamped_end_str, clamped_start_str)
            else:
                delta_str = str(clamped_end - clamped_start)
            if end > 0:
                out_dims[axis] = delta_str
                shape_dims[axis] = delta
            elif start - end > 0:
                out_dims[axis] = 'D{}+({})'.format(axis, delta_str)
                shape_dims[axis] += delta
            if start:
                in_idxs[axis] = 'd{}+{}'.format(axis, clamped_start_str)
        code = """
        function (I[{in_dims}]) -> (O) {{
            O[{out_idxs} : {out_dims}] = =(I[{in_idxs}]);
        }}""".format(
            in_dims=', '.join(in_dims),
            out_dims=', '.join(out_dims),
            in_idxs=', '.join(in_idxs),
            out_idxs=', '.join(out_idxs))
        outshape = tile.Shape(data.shape.dtype, shape_dims)
 
        super(SliceTensor, self).__init__(code, [('I', data)], [('O', outshape)])
slice_tensor = SliceTensor.function
[docs]class Softmax(tile.Operation):
    """
    Implements a standard ML softmax.
    """
    def __init__(self, data):
        if data.shape.ndims != 2:
            raise NotImplementedError(
                'Softmax with a non-two-dimensional tensor is not currently implemented')
        code = """
        function (I[X, Y]) -> (O) {
            O = builtin_softmax(I, X, Y);
        }"""
 
        super(Softmax, self).__init__(code, [('I', data)], [('O', data.shape)])
[docs]def softmax(x, axis=None):
    if x.shape.ndims == 2:
        return Softmax.function(x)
    if axis is None:
        axis = 1
    full_dims = x.shape.dims
    if axis == 0:
        group = 1
    else:
        group = functools.reduce(lambda x, y: x * y, x.shape.dims[:axis])
    if axis == len(x.shape.dims):
        values = 1
    else:
        values = functools.reduce(lambda x, y: x * y, x.shape.dims[axis:])
    flat_x = reshape(x, (group, values))
    result = Softmax.function(flat_x) 
    return reshape(result, full_dims)
[docs]class Sqrt(tile.Operation):
    """
    Computes the elementwise square root of a value.
    """
    def __init__(self, x):
        super(Sqrt, self).__init__("""
            function (I) -> (O) {
                IC = (I < 0 ? 0 : I);
                O = sqrt(IC); 
            }""", [('I', x)], [('O', x.shape)])
sqrt = Sqrt.function
[docs]def squeeze(x, axes):
    dims = [x.shape.dims[axis] for axis in range(x.shape.ndims) if axis not in axes] 
    return reshape(x, dims)
[docs]class Summation(tile.Operation):
    """
    Sums an input value along some set of axes.
    """
    def __init__(self, value, axes=None, keepdims=False, floatx=plaidml.DType.FLOAT32):
        if value.shape.dtype == plaidml.DType.BOOLEAN:
            value = cast(value, floatx)
        if axes is None:
            axes = list(range(value.shape.ndims))
        dims, _, subs = tile.compute_aggregation_axes(value.shape.dims, axes, keepdims)
        code = """
               function (I[{src_ranges}]) -> (O) {{
                   O[{dest_indices}{dest_sep}{dest_ranges}] = +(I[{src_indices}]);
               }}""".format(**subs)
        super(Summation, self).__init__(code, [('I', value)], 
                                        [('O', tile.Shape(value.shape.dtype, dims))])
[docs]def summation(value, axes=None, keepdims=False, floatx=plaidml.DType.FLOAT32):
    if not value.shape.ndims:
        return value
    if isinstance(axes, (tuple, list)) and not len(axes):
        # We're taking the sum across an empty axis list.
        return value 
    return Summation.function(value, axes=axes, keepdims=keepdims, floatx=floatx)
[docs]def tanh(data):
    """Elementwise hyperbolic tangent.""" 
    return tile.unary_op(data, 'tanh(I)', 'Tanh')
[docs]def unsqueeze(x, axes):
    src_idx = 0
    dims = []
    for axis in range(len(x.shape.dims) + len(axes)):
        if axis in axes:
            dims.append(1)
        else:
            dims.append(x.shape.dims[src_idx])
            src_idx += 1 
    return reshape(x, dims)
[docs]class Variance(tile.Operation):
    def __init__(self, x, axes=None, keepdims=False, floatx=plaidml.DType.FLOAT32):
        # This closely follows the implementation of the mean method
        # This computes the *uncorrected* sample variance (i.e. denominator
        # = n rather than = n-1) to match tensorflow
        if x.shape.dtype == plaidml.DType.BOOLEAN:
            x = cast(x, floatx)
        if not x.shape.ndims:
            return x
        if axes == None:
            axes = list(range(x.shape.ndims))
        shape, axes, subs = tile.compute_aggregation_axes(x.shape.dims, axes, keepdims)
        subs['prod_src_ranges'] = '*'.join(['X' + str(i) for i in axes])
        subs['mean_ranges'] = ', '.join(['Y' + str(i) for i in range(x.shape.ndims)])
        m = mean(x, axes, True, floatx)
        # TODO: Might be possible to write this more efficiently
        f = """
            function (I[{src_ranges}], M[{mean_ranges}]) -> (O) {{
                DIFF_SQ = (I - M) * (I - M);
                SUM[{dest_indices}{dest_sep}{dest_ranges}] = +(DIFF_SQ[{src_indices}]);
                O = SUM / ({prod_src_ranges});
            }}""".format(**subs)
        super(Variance, self).__init__(f, [('I', x), ('M', m)], 
                                       [('O', tile.Shape(x.shape.dtype, shape))])
variance = Variance.function