Mlxtend to support categorical data in plot_decision_regions #607
Comments
Glad to hear you found the function useful overall. Regarding the one-hot encoded outputs, I am wondering if the following workaround works:

    class Onehot2Int(object):
        def __init__(self, model):
            self.model = model

        def predict(self, X):
            y_pred = self.model(X)
            return np.argmax(y_pred, axis=1)

    # fit keras_model
    keras_model_no_ohe = Onehot2Int(keras_model)
    plot_decision_regions(X, y, keras_model_no_ohe)

We could maybe add an additional parameter to the |
This would work if Mlxtend didn't check for array shape, I guess :-)
Full code so you can test yourself:
|
The class labels aren't used for model fitting, though, just for assigning the class labels in the plot. So, you can simply pass the class label array as a non-one-hot-encoded array. For example
I think you also need to change

    class Onehot2Int(object):
        def __init__(self, model):
            self.model = model

        def predict(self, X):
            y_pred = self.model(X)
            return np.argmax(y_pred, axis=1)

to

    class Onehot2Int(object):
        def __init__(self, model):
            self.model = model

        def predict(self, X):
            y_pred = self.model.predict(X)
            return np.argmax(y_pred, axis=1)

|
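The two suggestions above (passing integer class labels and wrapping the model so that `predict` returns integer labels) can be tried without Keras or Mlxtend installed. The following is only a minimal sketch: `StubModel`, its class centres, and the sample data are hypothetical stand-ins for a fitted Keras model and are not part of the thread; only `Onehot2Int` is taken from the comment above.

```python
import numpy as np


class StubModel:
    """Hypothetical stand-in for a fitted Keras model (illustration only)."""

    def predict(self, X):
        # Return one row of per-class scores per sample, similar in shape
        # to the (n_samples, n_classes) output of a softmax layer.
        centres = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
        dists = np.linalg.norm(X[:, None, :] - centres[None, :, :], axis=2)
        scores = np.exp(-dists)
        return scores / scores.sum(axis=1, keepdims=True)


class Onehot2Int(object):
    """Wrapper from the thread: turns per-class scores into integer labels."""

    def __init__(self, model):
        self.model = model

    def predict(self, X):
        y_pred = self.model.predict(X)
        return np.argmax(y_pred, axis=1)


# Made-up data: three samples, one near each hypothetical class centre.
X = np.array([[0.1, -0.2], [4.8, 5.1], [0.3, 4.9]])
y_onehot = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# One-hot targets are only needed for training; for plotting, convert
# them back to a 1-D array of integer class labels.
y_int = np.argmax(y_onehot, axis=1)

wrapped = Onehot2Int(StubModel())
y_pred = wrapped.predict(X)

# With Mlxtend installed and a real model, the call would then be:
# plot_decision_regions(X, y_int, clf=wrapped)
```

Both `y_int` and `y_pred` are one-dimensional integer arrays here, which is the shape the thread suggests the plotting function expects.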
This works beautifully! Quite a shame that I didn't see the Thanks a lot!

Slight question: I publish my code (e.g. the code from my previous comment) on my website, where I dissect it into small pieces and explain what happens, so that other folks interested in machine learning can learn from what I've learned. I also publish the code on my GitHub profile under a maximally permissive license (CC0). When publishing the code for categorical hinge, I'd like to include your solution for the Mlxtend plots, so that people can run the code once and get the results. Would you mind if I included your code in my GitHub repo and on my website? Obviously, I'll reference Mlxtend and this issue to credit your help.

Hope to hear from you! Thanks again 😎👍 |
Taking a look at how you licensed Mlxtend, I assumed you wouldn't mind if I referenced you properly. Hope this is ok. If not, please let me know, and I'll remove it for sure. |
I am glad that it works! Will add an example to the documentation as well for future reference then. Also, thanks for asking regarding the code reuse. As you just said, that'd be totally fine with me :). Nice post, btw! |
Thanks twice! 😄 |
Hi there,

Thanks for your work! I'm happily using plot_decision_regions to visualize the decision boundary for my Keras models.

I'm currently experimenting with loss functions to get a feel for how they work. My (very simple) setup for testing how the Keras implementation of categorical_hinge (multiclass hinge loss) works is as follows:

make_blobs containing three separable clusters, like this:
categorical_hinge. The model learns to classify the testing data into the correct cluster successfully.

However, when plotting the decision boundaries with Mlxtend's plot_decision_regions, I run into this error:

I believe it originates from the fact that my target data has to be one-hot encoded in order to allow Keras to apply categorical hinge loss. This belief is strengthened by the fact that 921600 divided by 640 is 1440, which divided by 3 (the number of clusters and hence, given one-hot encoding, the number of target values per sample, e.g. [1 0 0]) is the requested 480.
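The arithmetic in this paragraph can be checked directly. The sketch below uses only the numbers quoted in the post; reading 480 and 640 as the two dimensions of the plotting grid is an assumption, since the error message itself is not preserved here.

```python
# Numbers quoted in the post.
total_values = 921600  # array size mentioned in the post
grid_width = 640
n_classes = 3          # length of each one-hot vector, e.g. [1 0 0]

rows_times_classes = total_values // grid_width  # 921600 / 640
grid_height = rows_times_classes // n_classes    # 1440 / 3

# Assumption: the plot needs exactly one label per grid point.
grid_points = grid_height * grid_width

print(rows_times_classes, grid_height, grid_points)
```

Under that assumption, a model that returns three values per grid point produces three times as many values as the grid has points, which is consistent with the mismatch described above.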
The problem emerges because the actual model clf used by Mlxtend produces one-hot encoded outputs itself, which apparently goes wrong in Mlxtend.

With this issue, I'm hoping that I can request support for categorical data in the plot_decision_regions function. If I'm wrong in my interpretation of this error, I'd appreciate finding out how I can make this visualization run.

Thanks very much!