Skip to content

Commit

Permalink
Merge pull request #8 from OpenGATE/gaga_spect
Browse files Browse the repository at this point in the history
Cleaning
  • Loading branch information
tbaudier authored Dec 20, 2023
2 parents cef82d8 + dde7407 commit 5bda969
Show file tree
Hide file tree
Showing 27 changed files with 1,068 additions and 803 deletions.
30 changes: 28 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,36 @@ jobs:
name: dist
path: dist/

publish_wheel:
test_wheel:
runs-on: ubuntu-latest
needs: [build_wheel]
steps:
- name: Checkout github repo
uses: actions/checkout@v4
- name: Checkout submodules
run: git submodule update --init --recursive
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
architecture: 'x64'
- uses: actions/download-artifact@v3
with:
name: dist
path: dist/
- name: Test the wheel
shell: bash {0}
run: |
pip install dist/gaga_phsp-*-py3-none-any.whl
cd tests
mkdir pth
python test001_non_cond.py
python test002_cond.py
publish_wheel:
runs-on: ubuntu-latest
needs: [test_wheel]
steps:
- name: Checkout github repo
uses: actions/checkout@v4
- name: Checkout submodules
Expand All @@ -49,7 +75,7 @@ jobs:
path: dist/
- name: Publish to PyPI
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/')
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@master
with:
user: __token__
password: ${{ secrets.PYPI }}
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ aa.png
/tests/npy/*.npy
/tests/*pth
/tests/*pt
/tests/*png
/tests/*old*


/save/

Expand All @@ -38,6 +41,7 @@ aa.png
/tests/a.pdf
/tests/jz/
/tests/output/
/tests/pth_oct_2022

*OLD*
/gaga/gaga_helpers_pet_before_dw_change.py
Expand Down
144 changes: 0 additions & 144 deletions bin/gaga_garf_generate_img

This file was deleted.

2 changes: 0 additions & 2 deletions bin/gaga_gauss_cond_test
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from shutil import copyfile
import click
from matplotlib import pyplot as plt
from scipy.stats import gaussian_kde
from gatetools import phsp
import numpy as np
import gaga_phsp as gaga
Expand Down
51 changes: 30 additions & 21 deletions bin/gaga_gauss_plot
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,24 @@ from matplotlib import pyplot as plt
import gaga_phsp as gaga
from scipy.stats import gaussian_kde

CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('phsp_filename')
@click.argument('pth_filename')
@click.option('--n', '-n', default=1e4, help='Number of samples to get from the phsp')
@click.option('--m', '-m', default=1e4, help='Number of samples to generate from the GAN')
@click.option('-x', default=float(1), help='Condition x')
@click.option('-y', default=float(1), help='Condition y')
@click.option('--epoch', '-e', default=-1, help='Load the G net at the given epoch (-1 for last stored epoch)')
@click.argument("phsp_filename")
@click.argument("pth_filename")
@click.option("--n", "-n", default=1e4, help="Number of samples to get from the phsp")
@click.option(
"--m", "-m", default=1e4, help="Number of samples to generate from the GAN"
)
@click.option("-x", default=float(1), help="Condition x")
@click.option("-y", default=float(1), help="Condition y")
@click.option(
"--epoch",
"-e",
default=-1,
help="Load the G net at the given epoch (-1 for last stored epoch)",
)
def gaga_gauss_plot(phsp_filename, pth_filename, n, m, epoch, x, y):
"""
\b
Expand All @@ -41,42 +48,44 @@ def gaga_gauss_plot(phsp_filename, pth_filename, n, m, epoch, x, y):

# generate samples with condition
cond = None
if len(params['cond_keys']) > 0:
if len(params["cond_keys"]) > 0:
condx = np.ones(m) * x
condy = np.ones(m) * y
print(condx.shape, condy.shape)
cond = np.column_stack((condx, condy))
print(cond.shape)
fake = gaga.generate_samples2(params, G, D, m, m, False, True, cond=cond)
fake = gaga.generate_samples3(params, G, m, cond=cond)
else:
fake = gaga.generate_samples_non_cond(params, G, m, m, False, True)

# get 2D points
x_ref = real[:, 0]
y_ref = real[:, 1]
x = fake[:, 0]
y = fake[:, 1]
print('ref shape', x_ref.shape, y_ref.shape)
print('gan shape', x.shape, y.shape)
print("ref shape", x_ref.shape, y_ref.shape)
print("gan shape", x.shape, y.shape)

print('ref y min max', y_ref.min(), y_ref.max())
print('ref x min max', x_ref.min(), x_ref.max())
print("ref y min max", y_ref.min(), y_ref.max())
print("ref x min max", x_ref.min(), x_ref.max())

print('gan y min max', y.min(), y.max())
print('gan x min max', x.min(), x.max())
print("gan y min max", y.min(), y.max())
print("gan x min max", x.min(), x.max())

# plot
fig, ax = plt.subplots(1, 1, figsize=(20, 10))

a = ax
a.scatter(x_ref, y_ref, marker='.', s=0.1)
a.scatter(x, y, marker='.', s=0.1)
a.axis('equal')
a.scatter(x_ref, y_ref, marker=".", s=0.1)
a.scatter(x, y, marker=".", s=0.1)
a.axis("equal")

plt.title(pth_filename)
f = f'cond.png'
f = f"cond.png"
print(f)
plt.savefig(f)


# --------------------------------------------------------------------------
if __name__ == '__main__':
if __name__ == "__main__":
gaga_gauss_plot()
Loading

0 comments on commit 5bda969

Please sign in to comment.