Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pytorch tensor truth value #92

Merged
merged 22 commits into from
Apr 17, 2019
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions opencog/atoms/atom_types/atom_types.script
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ INDEFINITE_TRUTH_VALUE <- TRUTH_VALUE
FUZZY_TRUTH_VALUE <- TRUTH_VALUE
PROBABILISTIC_TRUTH_VALUE <- TRUTH_VALUE
EVIDENCE_COUNT_TRUTH_VALUE <- TRUTH_VALUE
TTRUTH_VALUE <- TRUTH_VALUE "TTruthValue"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest name it TORCH_TRUTH_VALUE as TTRUTH_VALUE is unclear and ambiguous - usually T at the beginning sounds as Type. And rename TTruthValue to TorchTruthValue as well for clarity.


// The AttentionValue
ATTENTION_VALUE <- FLOAT_VALUE
Expand Down
23 changes: 22 additions & 1 deletion opencog/atoms/truthvalue/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,40 @@
if(HAVE_CYTHON)
INCLUDE_DIRECTORIES(
${PYTHON_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR})
endif(HAVE_CYTHON)

ADD_LIBRARY (truthvalue
CountTruthValue.cc
EvidenceCountTruthValue.cc
FuzzyTruthValue.cc
IndefiniteTruthValue.cc
ProbabilisticTruthValue.cc
SimpleTruthValue.cc
TruthValue.cc
TruthValue.cc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whitespace issue here

)

if(HAVE_CYTHON)
set_property(TARGET truthvalue
APPEND PROPERTY SOURCES
TTruthValue.cc)
endif(HAVE_CYTHON)

# Without this, parallel make will race and crap up the generated files.
ADD_DEPENDENCIES(truthvalue opencog_atom_types)

TARGET_LINK_LIBRARIES(truthvalue
value
${COGUTIL_LIBRARY}
)
if(HAVE_CYTHON)
TARGET_LINK_LIBRARIES(truthvalue
${PYTHON_LIBRARIES}
)
endif(HAVE_CYTHON)



INSTALL (TARGETS truthvalue EXPORT AtomSpaceTargets
DESTINATION "lib${LIB_DIR_SUFFIX}/opencog"
Expand All @@ -28,5 +48,6 @@ INSTALL (FILES
SimpleTruthValue.h
EvidenceCountTruthValue.h
TruthValue.h
TTruthValue.h
DESTINATION "include/opencog/atoms/truthvalue"
)
110 changes: 110 additions & 0 deletions opencog/atoms/truthvalue/TTruthValue.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* opencog/atoms/truthvalue/TTruthValue.cc
*
* Written by Anatoly Belikov <[email protected]>
* All Rights Reserved
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License v3 as
* published by the Free Software Foundation and including the exceptions
* at http://opencog.org/wiki/Licenses
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program; if not, write to:
* Free Software Foundation, Inc.,
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/

#include <opencog/atoms/truthvalue/TTruthValue.h>
#include "SimpleTruthValue.h"

#include <Python.h>

using namespace opencog;


TTruthValue::TTruthValue(PyObject * p):TruthValue(TTRUTH_VALUE), ptr(p), _count(1){
PyGILState_STATE state = PyGILState_Ensure();
Py_INCREF(p);
PyGILState_Release(state);
};

float TTruthValue::getAttr(std::string p_attrname) const{
PyGILState_STATE state = PyGILState_Ensure();
PyObject * tensor = (PyObject*)(this->ptr);
PyObject * const attrname = PyUnicode_FromString(p_attrname.c_str());
PyObject * res_obj = PyObject_GetAttr(tensor, attrname);
PyObject * float_obj = nullptr;
Py_DECREF(attrname);
bool failed = false;
double result;
if (res_obj){
float_obj = PyObject_CallMethod(res_obj, "__float__", NULL);
if(float_obj) {
result = PyFloat_AsDouble(float_obj);
}
} else {
failed = true;
}
if(res_obj) Py_DECREF(res_obj);
if(float_obj) Py_DECREF(float_obj);
PyGILState_Release(state);
if(failed)
throw RuntimeException(TRACE_INFO, "failed to get element of tensor");
return result;
}


strength_t TTruthValue::get_mean() const {
return this->getAttr("mean");
}


confidence_t TTruthValue::get_confidence() const {
return this->getAttr("confidence");
}


count_t TTruthValue::get_count() const {
return this->_count;
}


TruthValuePtr TTruthValue::clone() const {
PyObject * tensor = (PyObject*)(this->ptr);
PyGILState_STATE state = PyGILState_Ensure();
PyObject * res_obj = PyObject_CallMethod(tensor, "clone", NULL);
PyGILState_Release(state);
if(res_obj == nullptr){
throw RuntimeException(TRACE_INFO, "failed to clone object");
}
return createTTruthValue(res_obj);
}

bool TTruthValue::operator==(const Value& other) const {
return this == &other;
}

TruthValuePtr TTruthValue::merge(const TruthValuePtr& other,
const MergeCtrl& mc) const
{
throw RuntimeException(TRACE_INFO,
"merge is not implemented");
}

TTruthValue::~TTruthValue(){
PyGILState_STATE state = PyGILState_Ensure();
Py_DECREF(this->ptr);
PyGILState_Release(state);
}

void * TTruthValue::getPtr(){
return this->ptr;
}


76 changes: 76 additions & 0 deletions opencog/atoms/truthvalue/TTruthValue.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* opencog/atoms/truthvalue/TTruthValue.h
*
* Written by Anatoly Belikov <[email protected]>
* All Rights Reserved
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License v3 as
* published by the Free Software Foundation and including the exceptions
* at http://opencog.org/wiki/Licenses
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program; if not, write to:
* Free Software Foundation, Inc.,
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/

#ifndef _OPENCOG_TORCH_TRUTH_VALUE_H_
#define _OPENCOG_TORCH_TRUTH_VALUE_H_

#include <opencog/atoms/truthvalue/TruthValue.h>
#include <opencog/atoms/value/PtrValue.h>

#ifndef PyObject_HEAD
struct _object;
typedef _object PyObject;
#endif


namespace opencog
{

/*
* class TTruthValue holds pointer to torch.Tensor wrapper.
* methods such as get_mean, get_confidence call this python wrapper.
* Otherwise it is similiar to SimpleTruthValue.
*/

class TTruthValue: public TruthValue
{
private:
PyObject * ptr;
unsigned int _count;

float getAttr(std::string attrname) const;
public:
TTruthValue(const TTruthValue &) = delete;
TTruthValue(PyObject * p);
virtual strength_t get_mean() const;
virtual confidence_t get_confidence() const;
virtual count_t get_count() const;

virtual TruthValuePtr clone() const;
virtual bool operator==(const Value&) const;
virtual TruthValuePtr merge(const TruthValuePtr& other,
const MergeCtrl& mc) const;
virtual ~TTruthValue();
virtual void * getPtr();

};

typedef std::shared_ptr<TTruthValue> TTruthValuePtr;
template<typename ... Type>
static inline TTruthValuePtr createTTruthValue(Type&&... args) {
return std::make_shared<TTruthValue>(std::forward<Type>(args)...);
}


}

#endif
2 changes: 1 addition & 1 deletion opencog/cython/opencog/atom.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ cdef class Atom(Value):
tvp = atom_ptr.getTruthValue()
if (not tvp.get()):
raise AttributeError('cAtom returned NULL TruthValue pointer')
return createTruthValue(tvp.get().get_mean(), tvp.get().get_confidence())
return create_python_value_from_c_value(<shared_ptr[cValue]&>tvp, self.atomspace)

def __set__(self, truth_value):
try:
Expand Down
17 changes: 17 additions & 0 deletions opencog/cython/opencog/atomspace.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@ cdef class TruthValue(Value):
cdef tv_ptr* _tvptr(self)


cdef extern from "opencog/atoms/truthvalue/TTruthValue.h" namespace "opencog":
ctypedef shared_ptr[const cTTruthValue] ttv_ptr "opencog::TTruthValuePtr"
cdef cppclass cTTruthValue "opencog::TTruthValue"(cTruthValue):
cTTruthValue(object)
strength_t get_mean()
noskill marked this conversation as resolved.
Show resolved Hide resolved
confidence_t get_confidence()
noskill marked this conversation as resolved.
Show resolved Hide resolved
count_t get_count()
#tv_ptr DEFAULT_TV()
string to_string()
bint operator==(cTruthValue h)
bint operator!=(cTruthValue h)
void * getPtr()

cdef ttv_ptr createTTruthValue(...)



# Atom
cdef extern from "opencog/atoms/base/Link.h" namespace "opencog":
pass
Expand Down
68 changes: 68 additions & 0 deletions opencog/cython/opencog/truth_value.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def createTruthValue(strength = 1.0, confidence = 1.0):
c_ptr.reset(new cSimpleTruthValue(strength, confidence))
return TruthValue(ptr_holder = PtrHolder.create(<shared_ptr[void]&>c_ptr))


cdef class TruthValue(Value):
""" The truth value represents the strength and confidence of
a relationship or term. In OpenCog there are a number of TruthValue
Expand Down Expand Up @@ -55,3 +56,70 @@ cdef class TruthValue(Value):
def truth_value_ptr_object(self):
return PyLong_FromVoidPtr(<void*>self._tvptr())


class SimpleTruthValue(TruthValue):
pass


MEAN = 0
CONFIDENCE = 1

try:
import torch
class TTruthValueWrapper(torch.Tensor):

@staticmethod
def __new__(cls, *args):
if len(args) == 1:
assert(len(args[0]) == 2)
instance = super().__new__(cls, *args)
elif len(args) == 2:
instance = super().__new__(cls, args)
else:
raise RuntimeError("Expecting tuple of two number, \
tensor of len 2 or two numbers, got {0}".format(args))
return instance

@property
def mean(self):
return self[MEAN]

@property
def confidence(self):
return self[CONFIDENCE]


except ImportError as e:
print("Torch not found, torch truth value will not be available")


cdef class TTruthValue(TruthValue):
cdef object ttv
def __init__(self, *args, **kwargs):
cdef tv_ptr c_ptr
cdef cTTruthValue * this_ptr
ptr_holder = kwargs.get('ptr_holder', None)
if ptr_holder is not None:
super(TruthValue, self).__init__(ptr_holder=ptr_holder)
this_ptr = <cTTruthValue*>self.get_c_value_ptr().get()
self.ttv = <object>(deref(this_ptr).getPtr())
else:
self.ttv = TTruthValueWrapper(*args)
c_ptr = <tv_ptr>createTTruthValue(<PyObject*>self.ttv)
super(TruthValue, self).__init__(PtrHolder.create(<shared_ptr[void]&>c_ptr))

cdef _mean(self):
return self.ttv.mean

cdef _confidence(self):
return self.ttv.confidence

def torch(self):
return self.ttv

def __getitem__(self, idx):
assert 0 <= idx <= 1
return self.ttv[idx]

def __str__(self):
return 'TTruthValue(' + str(self.ttv) + ')'
Loading