Skip to content

Commit

Permalink
Per frame coord transform
Browse files Browse the repository at this point in the history
Signed-off-by: Kamil Tokarski <[email protected]>
  • Loading branch information
stiepan committed May 31, 2022
1 parent b1992cc commit 72643bc
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
12 changes: 7 additions & 5 deletions dali/operators/geometry/coord_transform.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -20,25 +20,27 @@
#include "dali/core/geom/mat.h"
#include "dali/core/static_switch.h"
#include "dali/kernels/kernel_manager.h"
#include "dali/pipeline/operator/operator.h"
#include "dali/operators/geometry/mt_transform_attr.h"
#include "dali/pipeline/operator/operator.h"
#include "dali/pipeline/operator/sequence_operator.h"

namespace dali {

#define COORD_TRANSFORM_INPUT_TYPES (uint8_t, int16_t, uint16_t, int32_t, float)
#define COORD_TRANSFORM_DIMS (1, 2, 3, 4, 5, 6)

template <typename Backend>
class CoordTransform : public Operator<Backend>, private MTTransformAttr {
class CoordTransform : public SequenceOperator<Backend, true>, private MTTransformAttr {
public:
explicit CoordTransform(const OpSpec &spec) : Operator<Backend>(spec), MTTransformAttr(spec) {
using Base = SequenceOperator<Backend, true>;
explicit CoordTransform(const OpSpec &spec) : Base(spec), MTTransformAttr(spec) {
dtype_ = spec_.template GetArgument<DALIDataType>("dtype");
}

bool CanInferOutputs() const override { return true; }

protected:
using Operator<Backend>::spec_;
using Base::spec_;
bool SetupImpl(std::vector<OutputDesc> &output_descs, const workspace_t<Backend> &ws) override {
auto &input = ws.template Input<Backend>(0); // get a reference to the input tensor list
const auto &input_shape = input.shape(); // get a shape - use const-ref to avoid copying
Expand Down
9 changes: 4 additions & 5 deletions dali/operators/geometry/mt_transform_attr.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,14 +35,14 @@ If a scalar value is provided, ``M`` is assumed to be a square matrix with that
diagonal. The size of the matrix is then assumed to match the number of components in the
input vectors.)",
nullptr, // no default value
true)
true, true)
.AddOptionalArg<vector<float>>("T", R"(The translation vector.
If left unspecified, no translation is applied unless MT argument is used.
The number of components of this vector must match the number of rows in matrix ``M``.
If a scalar value is provided, that value is broadcast to all components of ``T`` and the number
of components is chosen to match the number of rows in ``M``.)", nullptr, true)
of components is chosen to match the number of rows in ``M``.)", nullptr, true, true)
.AddOptionalArg<vector<float>>("MT", R"(A block matrix [M T] which combines the arguments
``M`` and ``T``.
Expand All @@ -51,8 +51,7 @@ M and leaving T unspecified.
The number of columns must be one more than the number of components in the input.
This argument is mutually exclusive with ``M`` and ``T``.)",
nullptr,
true);
nullptr, true, true);

void MTTransformAttr::ProcessMatrixArg(const OpSpec &spec, const ArgumentWorkspace &ws, int N) {
bool is_fused = HasFusedMT();
Expand Down

0 comments on commit 72643bc

Please sign in to comment.