Skip to content

Commit

Permalink
Merge pull request #92 from noskill/pytorch-tv
Browse files Browse the repository at this point in the history
pytorch tensor truth value
  • Loading branch information
vsbogd authored Apr 17, 2019
2 parents af22319 + bb6e375 commit 43b8a11
Show file tree
Hide file tree
Showing 8 changed files with 360 additions and 1 deletion.
1 change: 1 addition & 0 deletions opencog/atoms/atom_types/atom_types.script
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ INDEFINITE_TRUTH_VALUE <- TRUTH_VALUE
FUZZY_TRUTH_VALUE <- TRUTH_VALUE
PROBABILISTIC_TRUTH_VALUE <- TRUTH_VALUE
EVIDENCE_COUNT_TRUTH_VALUE <- TRUTH_VALUE
TENSOR_TRUTH_VALUE <- TRUTH_VALUE "TensorTruthValue"

// The AttentionValue
ATTENTION_VALUE <- FLOAT_VALUE
Expand Down
21 changes: 21 additions & 0 deletions opencog/atoms/truthvalue/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
if(HAVE_CYTHON)
INCLUDE_DIRECTORIES(
${PYTHON_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR})
endif(HAVE_CYTHON)

ADD_LIBRARY (truthvalue
CountTruthValue.cc
EvidenceCountTruthValue.cc
Expand All @@ -8,13 +15,26 @@ ADD_LIBRARY (truthvalue
TruthValue.cc
)

if(HAVE_CYTHON)
set_property(TARGET truthvalue
APPEND PROPERTY SOURCES
TensorTruthValue.cc)
endif(HAVE_CYTHON)

# Without this, parallel make will race and crap up the generated files.
ADD_DEPENDENCIES(truthvalue opencog_atom_types)

TARGET_LINK_LIBRARIES(truthvalue
value
${COGUTIL_LIBRARY}
)
if(HAVE_CYTHON)
TARGET_LINK_LIBRARIES(truthvalue
${PYTHON_LIBRARIES}
)
endif(HAVE_CYTHON)



INSTALL (TARGETS truthvalue EXPORT AtomSpaceTargets
DESTINATION "lib${LIB_DIR_SUFFIX}/opencog"
Expand All @@ -28,5 +48,6 @@ INSTALL (FILES
SimpleTruthValue.h
EvidenceCountTruthValue.h
TruthValue.h
TensorTruthValue.h
DESTINATION "include/opencog/atoms/truthvalue"
)
125 changes: 125 additions & 0 deletions opencog/atoms/truthvalue/TensorTruthValue.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* opencog/atoms/truthvalue/TensorTruthValue.cc
*
* Written by Anatoly Belikov <[email protected]>
* All Rights Reserved
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License v3 as
* published by the Free Software Foundation and including the exceptions
* at http://opencog.org/wiki/Licenses
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program; if not, write to:
* Free Software Foundation, Inc.,
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/

#include <opencog/atoms/truthvalue/TensorTruthValue.h>
#include <Python.h>


using namespace opencog;


TensorTruthValue::TensorTruthValue(PyObject * p):TruthValue(TENSOR_TRUTH_VALUE), ptr(p), _count(1){
PyGILState_STATE state = PyGILState_Ensure();
Py_INCREF(p);
PyGILState_Release(state);
};

float TensorTruthValue::getAttr(std::string p_attrname) const{
PyGILState_STATE state = PyGILState_Ensure();
PyObject * tensor = (PyObject*)(this->ptr);
PyObject * const attrname = PyUnicode_FromString(p_attrname.c_str());
PyObject * res_obj = PyObject_GetAttr(tensor, attrname);
PyObject * float_obj = nullptr;
Py_DECREF(attrname);
bool failed = false;
double result = -1.0;
if (res_obj){
float_obj = PyObject_CallMethod(res_obj, "__float__", NULL);
if(float_obj) {
result = PyFloat_AsDouble(float_obj);
}
} else {
failed = true;
}
if(res_obj) Py_DECREF(res_obj);
if(float_obj) Py_DECREF(float_obj);
PyGILState_Release(state);
if(failed)
throw RuntimeException(TRACE_INFO, "failed to get element of tensor");
return result;
}


strength_t TensorTruthValue::get_mean() const {
return this->getAttr("mean");
}


confidence_t TensorTruthValue::get_confidence() const {
return this->getAttr("confidence");
}


count_t TensorTruthValue::get_count() const {
return this->_count;
}


TruthValuePtr TensorTruthValue::clone() const {
PyObject * tensor = (PyObject*)(this->ptr);
PyGILState_STATE state = PyGILState_Ensure();
PyObject * res_obj = PyObject_CallMethod(tensor, "clone", NULL);
PyGILState_Release(state);
if(res_obj == nullptr){
throw RuntimeException(TRACE_INFO, "failed to clone object");
}
return createTensorTruthValue(res_obj);
}

bool TensorTruthValue::operator==(const Value& other) const {
return this == &other;
}

TruthValuePtr TensorTruthValue::merge(const TruthValuePtr& other,
const MergeCtrl& mc) const
{
throw RuntimeException(TRACE_INFO,
"merge is not implemented");
}

TensorTruthValue::~TensorTruthValue(){
PyGILState_STATE state = PyGILState_Ensure();
Py_DECREF(this->ptr);
PyGILState_Release(state);
}

void * TensorTruthValue::getPtr(){
return this->ptr;
}

std::string TensorTruthValue::to_string(const std::string&) const {
PyGILState_STATE state = PyGILState_Ensure();
PyObject * str = PyObject_Str(this->ptr);
if(str == nullptr){
PyGILState_Release(state);
throw RuntimeException(TRACE_INFO, "error calling __str__ on python object");
}
#if PY_MAJOR_VERSION == 2
const char * tmp = PyBytes_AsString(str);
#else
const char * tmp = PyUnicode_AsUTF8(str);
#endif
std::string result = std::string(tmp);
Py_DECREF(str);
PyGILState_Release(state);
return result;
}
76 changes: 76 additions & 0 deletions opencog/atoms/truthvalue/TensorTruthValue.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* opencog/atoms/truthvalue/TensorTruthValue.h
*
* Written by Anatoly Belikov <[email protected]>
* All Rights Reserved
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License v3 as
* published by the Free Software Foundation and including the exceptions
* at http://opencog.org/wiki/Licenses
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program; if not, write to:
* Free Software Foundation, Inc.,
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/

#ifndef _OPENCOG_TORCH_TRUTH_VALUE_H_
#define _OPENCOG_TORCH_TRUTH_VALUE_H_

#include <opencog/atoms/truthvalue/TruthValue.h>
#include <opencog/atoms/value/PtrValue.h>

#ifndef PyObject_HEAD
struct _object;
typedef _object PyObject;
#endif


namespace opencog
{

/*
* class TensorTruthValue holds pointer to torch.Tensor wrapper.
* methods such as get_mean, get_confidence call this python wrapper.
* Otherwise it is similiar to SimpleTruthValue.
*/

class TensorTruthValue: public TruthValue
{
private:
PyObject * ptr;
unsigned int _count;

float getAttr(std::string attrname) const;
public:
TensorTruthValue(const TensorTruthValue &) = delete;
TensorTruthValue(PyObject * p);
virtual strength_t get_mean() const;
virtual confidence_t get_confidence() const;
virtual count_t get_count() const;

virtual TruthValuePtr clone() const;
virtual bool operator==(const Value&) const;
virtual TruthValuePtr merge(const TruthValuePtr& other,
const MergeCtrl& mc) const;
virtual ~TensorTruthValue();
virtual void * getPtr();
virtual std::string to_string(const std::string&) const;
};

typedef std::shared_ptr<TensorTruthValue> TensorTruthValuePtr;
template<typename ... Type>
static inline TensorTruthValuePtr createTensorTruthValue(Type&&... args) {
return std::make_shared<TensorTruthValue>(std::forward<Type>(args)...);
}


}

#endif
2 changes: 1 addition & 1 deletion opencog/cython/opencog/atom.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ cdef class Atom(Value):
tvp = atom_ptr.getTruthValue()
if (not tvp.get()):
raise AttributeError('cAtom returned NULL TruthValue pointer')
return createTruthValue(tvp.get().get_mean(), tvp.get().get_confidence())
return create_python_value_from_c_value(<shared_ptr[cValue]&>tvp, self.atomspace)

def __set__(self, truth_value):
try:
Expand Down
17 changes: 17 additions & 0 deletions opencog/cython/opencog/atomspace.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@ cdef class TruthValue(Value):
cdef tv_ptr* _tvptr(self)


cdef extern from "opencog/atoms/truthvalue/TensorTruthValue.h" namespace "opencog":
ctypedef shared_ptr[const cTensorTruthValue] ttv_ptr "opencog::TensorTruthValuePtr"
cdef cppclass cTensorTruthValue "opencog::TensorTruthValue"(cTruthValue):
cTensorTruthValue(object)
strength_t get_mean() except +
confidence_t get_confidence() except +
count_t get_count()
#tv_ptr DEFAULT_TV()
string to_string()
bint operator==(cTruthValue h)
bint operator!=(cTruthValue h)
void * getPtr()

cdef ttv_ptr createTensorTruthValue(...)



# Atom
cdef extern from "opencog/atoms/base/Link.h" namespace "opencog":
pass
Expand Down
72 changes: 72 additions & 0 deletions opencog/cython/opencog/truth_value.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def createTruthValue(strength = 1.0, confidence = 1.0):
c_ptr.reset(new cSimpleTruthValue(strength, confidence))
return TruthValue(ptr_holder = PtrHolder.create(<shared_ptr[void]&>c_ptr))


cdef class TruthValue(Value):
""" The truth value represents the strength and confidence of
a relationship or term. In OpenCog there are a number of TruthValue
Expand Down Expand Up @@ -55,3 +56,74 @@ cdef class TruthValue(Value):
def truth_value_ptr_object(self):
return PyLong_FromVoidPtr(<void*>self._tvptr())


class SimpleTruthValue(TruthValue):
pass


MEAN = 0
CONFIDENCE = 1

try:
import torch
class TensorTruthValueWrapper(torch.Tensor):

@staticmethod
def __new__(cls, *args):
if len(args) == 1:
assert(len(args[0]) == 2)
instance = super().__new__(cls, *args)
elif len(args) == 2:
instance = super().__new__(cls, args)
else:
raise RuntimeError("Expecting tuple of two number, \
tensor of len 2 or two numbers, got {0}".format(args))
return instance

@property
def mean(self):
return self[MEAN]

@property
def confidence(self):
return self[CONFIDENCE]

def __str__(self):
return 'TensorTruthValue({0}, {1})'.format(self.mean,
self.confidence)


except ImportError as e:
print("Torch not found, torch truth value will not be available")


cdef class TensorTruthValue(TruthValue):
cdef object ttv
def __init__(self, *args, **kwargs):
cdef tv_ptr c_ptr
cdef cTensorTruthValue * this_ptr
ptr_holder = kwargs.get('ptr_holder', None)
if ptr_holder is not None:
super(TruthValue, self).__init__(ptr_holder=ptr_holder)
this_ptr = <cTensorTruthValue*>self.get_c_value_ptr().get()
self.ttv = <object>(deref(this_ptr).getPtr())
else:
self.ttv = TensorTruthValueWrapper(*args)
c_ptr = <tv_ptr>createTensorTruthValue(<PyObject*>self.ttv)
super(TruthValue, self).__init__(PtrHolder.create(<shared_ptr[void]&>c_ptr))

cdef _mean(self):
return self.ttv.mean

cdef _confidence(self):
return self.ttv.confidence

def torch(self):
return self.ttv

def __getitem__(self, idx):
assert 0 <= idx <= 1
return self.ttv[idx]

def __str__(self):
return str(self.ttv)
Loading

0 comments on commit 43b8a11

Please sign in to comment.