Skip to content

Commit

Permalink
updated operator= such that the basic data type is considered
Browse files Browse the repository at this point in the history
  • Loading branch information
Konrad1991 committed Jun 3, 2024
1 parent ef1beef commit 01447cf
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 146 deletions.
1 change: 1 addition & 0 deletions include/etr_bits/BinaryCalculations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct BinaryOperation {
using typeTraitR = R;
MatrixParameter mp;
bool mpCalculated = false;

bool im() const {
if constexpr (std::is_arithmetic_v<L>) {
return r.im();
Expand Down
39 changes: 0 additions & 39 deletions include/etr_bits/Derivatives/Derivs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,45 +73,6 @@ inline constexpr auto walkT() {
return res;
}

template <typename T, typename AV, typename V, typename... Args>
requires(IsVec<T> && !IsVariable<T>)
inline void assign(AV &av, V &var, Args &&...args) {
using tD = ExtractedTypeD<T>;
constexpr auto res = walkT<tD>();
ass(res.getSize(av) == var.d.getSize(av),
"Size of dependent variable does not match size of to be evaluated "
"expression");
for (std::size_t i = 0; i < res.getSize(av); i++) {
var.d.setDeriv(av, i, res.getDeriv(av, i));
}

if (var.d.size() < res.getSize(av))
var.resize(res.getSize(av));
for (std::size_t i = 0; i < res.getSize(av); i++) {
var.d.setVal(av, i, res.getVal(av, i));
}
}

// TODO: add correct requires for only a variable or constant
template <typename T, typename AV, typename V, typename... Args>
inline void assign(AV &av, V &var, Args &&...args) {
using tD = ExtractedTypeD<T>;
/*
ass(tD::getSize(av) == var.d.getSize(av),
"Size of dependent variable does not match size of to be evaluated "
"expression");
for (std::size_t i = 0; i < tD::getSize(av); i++) {
var.d.setDeriv(av, i, 0.0); // TODO: check if this is correct!!!
}
*/
if (var.d.size() < tD::getSize(av))
var.resize(tD::getSize(av));
for (std::size_t i = 0; i < tD::getSize(av); i++) {
var.d.setVal(av, i, tD::getVal(av, i));
}
}

} // namespace etr

#endif
111 changes: 63 additions & 48 deletions include/etr_bits/Vector/AssignmentOperator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,95 +42,115 @@ Vec &operator=(const TD inp) {
}

Vec &operator=(const Vec<T, R, Trait> &otherVec) {
// std::cout << "test3" << std::endl;
// printTAST<decltype(otherVec)>();
// std::cout << "test2" << std::endl;
static_assert(!isUnaryOP::value, "Cannot assign to unary calculation");
static_assert(!isBinaryOP::value, "Cannot assign to binary calculation");
static_assert(!isRVec::value,
"Cannot assign to an r value. E.g. c(1, 2, 3) <- 1");

using DataTypeOtherVec = typename etr::ExtractDataType<
std::remove_reference_t<decltype(otherVec)>>::RetType;
if constexpr (isBuffer::value) {
Buffer<T> temp(otherVec.size()); // issue: create Buffer<T> as attribute
for (std::size_t i = 0; i < otherVec.size(); i++)
temp[i] = otherVec[i];
temp.resize(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++) {
if constexpr (is<DataTypeOtherVec, T>) {
temp[i] = otherVec[i];
} else {
temp[i] = static_cast<T>(otherVec[i]);
}
}
d.moveit(temp);
} else if constexpr (isBorrow::value) {
ass(otherVec.size() <= d.capacity,
"number of items to replace is not a multiple of replacement length");
Buffer<T> temp(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++)
temp[i] = otherVec[i];
temp.resize(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++) {
if constexpr (is<DataTypeOtherVec, T>) {
temp[i] = otherVec[i];
} else {
temp[i] = static_cast<T>(otherVec[i]);
}
}
ass(d.sz <= otherVec.size(),
"size cannot be increased above the size of the borrowed object");
d.sz = otherVec.size();
for (std::size_t i = 0; i < otherVec.size(); i++)
d[i] = temp[i];
} else if constexpr (isBorrowSEXP::value) {
Buffer<T> temp(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++)
temp[i] = otherVec[i];
temp.resize(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++) {
if constexpr (is<DataTypeOtherVec, T>) {
temp[i] = otherVec[i];
} else {
temp[i] = static_cast<T>(otherVec[i]);
}
}
if (otherVec.size() > this->size())
d.resize(otherVec.size());
d.moveit(temp);
} else if constexpr (isSubset::value) {
ass(otherVec.size() == d.ind.size(),
"number of items to replace is not a multiple of replacement length");
Buffer<T> temp(otherVec.size());
temp.resize(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++) {
temp[i] = otherVec[i];
if constexpr (is<DataTypeOtherVec, T>) {
temp[i] = otherVec[i];
} else {
temp[i] = static_cast<T>(otherVec[i]);
}
}
if (d.p->size() < temp.size())
d.resize(temp.size());
for (std::size_t i = 0; i < d.ind.size(); i++) {
d[i % d.ind.size()] = temp[i];
}
}
if (otherVec.d.im()) { // issue: && !d.im() missing?
if (otherVec.d.im()) {
d.setMatrix(true, otherVec.d.nr(), otherVec.d.nc());
}
return *this;
}

template <typename T2, typename R2, typename Trait2>
Vec &operator=(const Vec<T2, R2, Trait2> &otherVec) {
// std::cout << "test4" << std::endl;
// std::cout << "test3" << std::endl;
static_assert(!isUnaryOP::value, "Cannot assign to unary calculation");
static_assert(!isBinaryOP::value, "Cannot assign to binary calculation");
static_assert(!isRVec::value,
"Cannot assign to an r value. E.g. c(1, 2, 3) <- 1");
using DataTypeOtherVec = typename etr::ExtractDataType<
std::remove_reference_t<decltype(otherVec)>>::RetType;
if constexpr (isBuffer::value) {
Buffer<T> temp(otherVec.size()); // issue: define temp as own attribute!
using RetTypeOtherVec =
std::remove_reference<decltype(otherVec.d)>::type::RetType;
using isBaseTypeRet = std::is_same<RetTypeOtherVec, BaseType>;
if constexpr (isBaseTypeRet::value) {
for (std::size_t i = 0; i < otherVec.size(); i++) {
temp.resize(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++) {
if constexpr (is<DataTypeOtherVec, T>) {
temp[i] = otherVec[i];
}
} else {
for (std::size_t i = 0; i < otherVec.size(); i++) {
temp[i] = static_cast<BaseType>(otherVec[i]);
} else {
temp[i] = static_cast<T>(otherVec[i]);
}
}
d.moveit(temp);
} else if constexpr (isBorrow::value) {
ass(otherVec.size() <= d.capacity,
"number of items to replace is not a multiple of replacement length");
Buffer<T> temp(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++)
temp[i] = otherVec[i];
temp.resize(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++) {
if constexpr (is<DataTypeOtherVec, T>) {
temp[i] = otherVec[i];
} else {
temp[i] = static_cast<T>(otherVec[i]);
}
}
d.sz = otherVec.size();
for (std::size_t i = 0; i < otherVec.size(); i++)
d[i] = temp[i];
} else if constexpr (isBorrowSEXP::value) {
Buffer<T> temp(otherVec.size());
using RetTypeOtherVec =
std::remove_reference<decltype(otherVec.d)>::type::RetType;
using isBaseTypeRet = std::is_same<RetTypeOtherVec, BaseType>;
if constexpr (isBaseTypeRet::value) {
for (std::size_t i = 0; i < otherVec.size(); i++) {
temp.resize(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++) {
if constexpr (is<DataTypeOtherVec, T>) {
temp[i] = otherVec[i];
}
} else {
for (std::size_t i = 0; i < otherVec.size(); i++) {
temp[i] = static_cast<BaseType>(otherVec[i]);
} else {
temp[i] = static_cast<T>(otherVec[i]);
}
}
if (otherVec.size() > this->size())
Expand All @@ -139,17 +159,12 @@ Vec &operator=(const Vec<T2, R2, Trait2> &otherVec) {
} else if constexpr (isSubset::value) {
ass(otherVec.size() == d.ind.size(),
"number of items to replace is not a multiple of replacement length");
Buffer<T> temp(otherVec.size());
using RetTypeOtherVec =
std::remove_reference<decltype(otherVec.d)>::type::RetType;
using isBaseTypeRet = std::is_same<RetTypeOtherVec, BaseType>;
if constexpr (isBaseTypeRet::value) {
for (std::size_t i = 0; i < otherVec.size(); i++) {
temp.resize(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++) {
if constexpr (is<DataTypeOtherVec, T>) {
temp[i] = otherVec[i];
}
} else {
for (std::size_t i = 0; i < otherVec.size(); i++) {
temp[i] = static_cast<BaseType>(otherVec[i]);
} else {
temp[i] = static_cast<T>(otherVec[i]);
}
}
for (std::size_t i = 0; i < d.ind.size(); i++) {
Expand Down
79 changes: 20 additions & 59 deletions tests/Derivatives_Tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,63 +29,24 @@ How to handle constants such as subset
*/

int main() {
// NOTE: Derivatives Test Nr.1
{
etr::AllVars<2, 0, 0, 2> av(1, 0);
Vec<double, VarPointer<decltype(av), 0, 0>, VariableTypeTrait> vp1(av);
Vec<double, VarPointer<decltype(av), 1, 0>, VariableTypeTrait> vp2(av);

assign<decltype(CN::constantNr0(av))>(av, vp1, CN::constantNr0(av));
assign<decltype(CN::constantNr1(av))>(av, vp2, CN::constantNr1(av));

// f = vp2 * vp1 * c(1, 2, 3) = 4, 20, 54; df/dvp2 = vp1 * c(1, 2, 3) = 1,
// 4,
// 9
assign<decltype(vp2 * vp1 * CN::constantNr0(av))>(av, vp2,
CN::constantNr0(av));
for (size_t i = 0; i < vp2.d.getSize(av); i++) {
std::cout << vp2.d.getDeriv(av, i) << std::endl;
}
print(av.varBuffer[1]);

/*
assign<decltype(vp2 * vp1 * vp2)>(av, vp2);
// f = b^2*a; df/db = a'*b + b'*a = b'*a = 2*b*a = 8, 20, 36
for (size_t i = 0; i < vp2.d.getSize(av); i++) {
std::cout << vp2.d.getDeriv(av, i) << std::endl;
}
print(av.varBuffer[1]); // 16, 50, 108
assign<decltype(vp1 + vp2)>(av, vp2);
// f = a + b; df/db = b' = 8, 20, 36
for (size_t i = 0; i < vp2.d.getSize(av); i++) {
std::cout << vp2.d.getDeriv(av, i) << std::endl;
}
print(av.varBuffer[1]); // 17, 52, 111
*/
}
// NOTE: Derivatives Test Nr.2
{
std::cout << "block Nr. 2" << std::endl;
etr::AllVars<2, 0, 0, 4> av(0, 0); // deriv with respect tp variable 1 = vp1
Vec<double, VarPointer<decltype(av), 0, 0>, VariableTypeTrait> vp1(av);
Vec<double, VarPointer<decltype(av), 1, 0>, VariableTypeTrait> vp2(av);

vp1 << coca<0>(av, 1, 2, 3); // vp1 = c(1, 2, 3); dexpr1/dvp1 = 1, 1, 1
vp2 << coca<1>(av, 4, 5, 6) *
vp1; // vp2 = c(4, 5, 6)*vp1 = 4, 10, 18; dexpr2/dvp1 = 4, 5, 6

vp1 << vp2 + vp2; // vp1 = 8, 20, 36; dexpr3/dvp1 = 8, 10, 12
print(vp1, av);
print(vp2, av);

Vec<double> deriv_vp1 = get_derivs(vp1);
print(deriv_vp1);
Vec<double> deriv_vp2 = get_derivs(vp2);
print(deriv_vp2);

vp1 = 100; // TODO: the scalar value 100 has to be encapsulated into a
// specific class
print(vp1, av);
}
etr::AllVars<2, 0, 0, 4> av(0, 0); // deriv with respect tp variable 1 = vp1
Vec<double, VarPointer<decltype(av), 0, 0>, VariableTypeTrait> vp1(av);
Vec<double, VarPointer<decltype(av), 1, 0>, VariableTypeTrait> vp2(av);

vp1 << coca<0>(av, 1, 2, 3); // vp1 = c(1, 2, 3); dexpr1/dvp1 = 1, 1, 1
vp2 << coca<1>(av, 4, 5, 6) *
vp1; // vp2 = c(4, 5, 6)*vp1 = 4, 10, 18; dexpr2/dvp1 = 4, 5, 6

vp1 << vp2 + vp2; // vp1 = 8, 20, 36; dexpr3/dvp1 = 8, 10, 12
print(vp1, av);
print(vp2, av);

Vec<double> deriv_vp1 = get_derivs(vp1);
print(deriv_vp1);
Vec<double> deriv_vp2 = get_derivs(vp2);
print(deriv_vp2);

vp1 = 100; // TODO: the scalar value 100 has to be encapsulated into a
// specific class
print(vp1, av);
}
46 changes: 46 additions & 0 deletions tests/Test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include <iostream>
#include <ostream>

class Foo {
public:
bool InUsage = false;
int attr = 0;

Foo(int i) : attr(i){};
Foo() : attr(0){};

template <typename T> Foo &operator=(const T &&otherFoo) {
if (InUsage) {
std::cout << "test" << std::endl;
int temp = otherFoo[0];
attr = temp;
} else {
attr = otherFoo[0];
}
return *this;
}

friend std::ostream &operator<<(std::ostream &os, const Foo &f) {
os << f.attr << std::endl;
return os;
}
};

template <typename T> class Bar {
public:
const T &obj;

Bar(T &&obj_) : obj(obj_) { obj_.InUsage = true; };
Bar(T &obj_) : obj(obj_) { obj_.InUsage = true; };

int operator[](size_t i) { return obj.attr + 1 + obj.attr; }
};

int main() {
Foo f1;
Bar b(f1);
Foo f2(3);
f1 = f2;
std::cout << f1 << std::endl;
return 0;
}

0 comments on commit 01447cf

Please sign in to comment.