-
Notifications
You must be signed in to change notification settings - Fork 188
pairplot refactoring #1529
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
pairplot refactoring #1529
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @danielmk, great effort and really nice additions! I think this PR would already make life much easier for plotting SBI results. I have left some quite some comments, some more minor than others, so I'll provide a summary here:
-
I completely agree that it would be nice to expose the new dataclasses to the user so that they can use autocomplete, but happy to do this in a future PR. For now, already the fact that the user can import and use these, as well as the fact that we have a warning in
pairplot
andmarginal_plot
if any user-providedkwargs
are ignored is already great. Maybe a small addition in this PR would to add all the new dataclasses to the__init__
insbi/analysis
so that the user does not have to import them all explicitly to be able to use them. But maybe @janfb has some ideas for how to make the new dataclasses more easy to use directly by the user. -
Currently, you've had to add a lot of
#pyright: ignore
. Sometimes, there is no workaround for these, but when we have to add a lot of these ignore statements it means that we are probably doing something wrong. In the case of this PR, I think a lot of these can be avoided by updating the type hints (e.g. ,prepare_for_plot
now explicitly returns annp.ndarray
, and the code for other functions that callprepare_for_plot
assume that they get annp.ndarray
, but the type hint for whatprepare_for_plot
returns is stillList[np.ndarray]
. I expect that correcting this will allow us to remove a lot of the pyright errors. I have commented individually on these pyright errors, but not all. Would be good to double check whichignore
statements are strictly necessary and which we can remove with updated type hints. -
regarding the tests currently failing. As discussed separately, locally running:
ruff format sbi
ruff format tests
ruff check sbi --fix
ruff check tests --fix
pyright sbi
should fix the linting/pyright errors. I also see that the test suite cancels after a lot of tests fail. Can you check by running plot_test.py
locally to see if this is related to your changes or not? I think it's likely something about the new types that can be quickly fixed if we can see what the error message is 😄
sbi/analysis/plot.py
Outdated
@@ -81,8 +270,10 @@ def plt_hist_1d( | |||
limits: torch.Tensor, | |||
diag_kwargs: Dict, | |||
) -> None: | |||
# ax.hist(samples, **diag_kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this a leftover from debugging?
sbi/analysis/plot.py
Outdated
samples[i], copy=False, nan=np.nan, posinf=np.nan, neginf=np.nan | ||
) | ||
samples[i] = samples[i][~np.isnan(samples[i]).any(axis=1)] | ||
# for i in range(len(samples)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove these entirely instead of commenting them out
sbi/analysis/plot.py
Outdated
|
||
@dataclass | ||
class GenericMplKwargs(GenericKwargs): | ||
"""MplKwargs is used to generate kwargs that are passed to matplotlib in pairplot. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring is written with MplKwargs
, but the class is called GenericMplKwargs
sbi/analysis/plot.py
Outdated
epsilon_range = eps * max_min_range | ||
limits.append([min_val - epsilon_range, max_val + epsilon_range]) | ||
return limits | ||
# limits = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove instead of commenting out
sbi/analysis/plot.py
Outdated
@@ -612,21 +806,23 @@ def prepare_for_plot( | |||
of the samples. | |||
""" | |||
|
|||
samples = convert_to_list_of_numpy(samples) | |||
# samples = convert_to_list_of_numpy(samples) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove instead of commenting out
sbi/analysis/plot.py
Outdated
fig_kwargs_filled = _get_default_fig_kwargs() | ||
fig_kwargs_filled = _update(fig_kwargs_filled, fig_kwargs) | ||
fig_kwargs_default = FigKwargs() # Get defaults | ||
#if type(fig_kwargs) == FigKwargs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
leftover from debugging?
sbi/analysis/plot.py
Outdated
|
||
Returns: | ||
Fig: matplotlib figure | ||
Axes: matplotlib axes | ||
""" | ||
dim = samples[0].shape[1] | ||
dim = samples.shape[1] # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type hint for samples
is still a List, which is why this breaks. For functions that are not exposed to the user, such as arrange_grid
, it's fine to change the type hint without adding a warning, as the user is not meant to use this function directly.
sbi/analysis/plot.py
Outdated
ax, sample[:, row], limits[row], diag_kwargs[sample_ind] | ||
) | ||
|
||
# for sample_ind, sample in enumerate(samples.T): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove instead of commenting out
sbi/analysis/plot.py
Outdated
@@ -1447,16 +1654,16 @@ def _arrange_grid( | |||
if excl_lower: | |||
ax.axis("off") # pyright: ignore reportOptionalMemberAccess | |||
else: | |||
for sample_ind, sample in enumerate(samples): | |||
lower_f = lower_funcs[sample_ind] | |||
for _, _ in enumerate(samples.T): # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
samples
is still a list here
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1529 +/- ##
==========================================
- Coverage 86.01% 78.80% -7.22%
==========================================
Files 135 135
Lines 10751 10929 +178
==========================================
- Hits 9248 8613 -635
- Misses 1503 2316 +813
Flags with carried forward coverage won't be shown. Click here to find out more.
🚀 New features to boost your workflow:
|
Thank you for the detailed review of the PR and sorry for the messy commented out code. I've gone through all the comments and made the suggested changes. Many of the # type: ignor flags are now unnecessary as you suspected. Most importantly I also fixed the issues that broke plot_test.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work @danielmk, big karma points for digging into the plotting code! I think this will improve the plotting interface a lot!
I added a couple of comments.
Additionally, could I ask you to have a look at https://github.com/sbi-dev/sbi/blob/main/docs/advanced_tutorials/17_plotting_functionality.ipynb and add a note to or even change 1-2 of the examples given there showing how to use the improved interface?
"""GenericMplKwargs is used to generate kwargs that are passed to matplotlib | ||
in pairplot. kwargs that are neither in GneericMplKwargs nor used by pairplot | ||
are completely ignored. To used the dictionary interface to kwargs, make | ||
`'mpl_kwargs'` a dict key. Several specific dataclasses define the defaults | ||
for specific plot types, all of which import GenericMplKwargs: | ||
MplKwargsDiagKDE, MplKwargsDiagHist, MplKwargsDiagScatter, MplKwargsOffDiagKDE, | ||
MplKwargsOffDiagHist, MplKwargsOffDiagScatter, MplKwargsOffDiagContour, | ||
MplKwargsOffDiagPlot. | ||
|
||
Example dictionary interface: | ||
pairplot(samples, | ||
diag='kde', | ||
diag_kwargs={'bw_method': 'scott', 'mpl_kwargs': {'color': 'r'}}) | ||
|
||
Example dataclass interface: | ||
pairplot(samples, | ||
diag='kde', | ||
diag_kwargs=DiagKwargsKDE(bw_method='scott', | ||
mpl_kwargs=MplKwargsKDE(color='r'))) | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using one sentence summary at the beginning, fixing typos and giving more structure:
"""GenericMplKwargs is used to generate kwargs that are passed to matplotlib | |
in pairplot. kwargs that are neither in GneericMplKwargs nor used by pairplot | |
are completely ignored. To used the dictionary interface to kwargs, make | |
`'mpl_kwargs'` a dict key. Several specific dataclasses define the defaults | |
for specific plot types, all of which import GenericMplKwargs: | |
MplKwargsDiagKDE, MplKwargsDiagHist, MplKwargsDiagScatter, MplKwargsOffDiagKDE, | |
MplKwargsOffDiagHist, MplKwargsOffDiagScatter, MplKwargsOffDiagContour, | |
MplKwargsOffDiagPlot. | |
Example dictionary interface: | |
pairplot(samples, | |
diag='kde', | |
diag_kwargs={'bw_method': 'scott', 'mpl_kwargs': {'color': 'r'}}) | |
Example dataclass interface: | |
pairplot(samples, | |
diag='kde', | |
diag_kwargs=DiagKwargsKDE(bw_method='scott', | |
mpl_kwargs=MplKwargsKDE(color='r'))) | |
""" | |
""" | |
Provides a structured way to pass matplotlib keyword arguments (kwargs) to the `pairplot` function. | |
Key Features: | |
- Ignores kwargs not recognized by `pairplot` or defined within GenericMplKwargs. | |
- Supports a dictionary interface for simple kwargs specification. | |
- Uses specialized dataclasses for plot-type-specific defaults. | |
Usage: | |
- Dictionary Interface: | |
pairplot(samples, diag='kde', diag_kwargs={'bw_method': 'scott', 'mpl_kwargs': | |
{'color': 'r'}}) | |
- Dataclass Interface: | |
pairplot(samples, diag='kde', diag_kwargs=DiagKwargsKDE(bw_method='scott', | |
mpl_kwargs=MplKwargsKDE( | |
color='r'))) | |
Specialized Dataclasses: | |
- MplKwargsDiagKDE, MplKwargsDiagHist, MplKwargsDiagScatter: Diagonal plot defaults. | |
- MplKwargsOffDiagKDE, MplKwargsOffDiagHist, MplKwargsOffDiagScatter: Off-diagonal plot defaults. | |
- MplKwargsOffDiagContour, MplKwargsOffDiagPlot: Additional off-diagonal defaults. | |
Note: | |
- Use 'mpl_kwargs' key within a dictionary for direct matplotlib kwargs access. | |
- Dataclass interface offers type safety and pre-defined defaults. | |
""" |
@@ -717,6 +962,7 @@ def pairplot( | |||
|
|||
Args: | |||
samples: Samples used to build the histogram. | |||
np.asarry(samples) should become 2d Array (sample, parameter) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest to give more details here:
np.asarry(samples) should become 2d Array (sample, parameter) | |
When passing a list of samples, np.asarry(samples) should become 2d Array | |
of shape (sample, parameter) |
fig_kwargs_filled = _update(dict(fig_kwargs_default), dict(fig_kwargs_user)) | ||
|
||
# Prepare Diag Defaults | ||
if type(diag) is list: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
has there been created an issue for this one?
if non_default_diag_kwargs: | ||
warn( | ||
f"upper_kwargs has {len(non_default_diag_kwargs)} args that are " | ||
"ignored by pairplot: {non_default_diag_kwargs}. To pass them to " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"ignored by pairplot: {non_default_diag_kwargs}. To pass them to " | |
f"ignored by pairplot: {non_default_diag_kwargs}. To pass them to " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No issue is open regarding the kwargs list to my knowledge. I was going to open an issue on it after finishing this pull request but I can also do it sooner.
Also that's a great catch on the missing f
in front of the string!
diag_f( | ||
ax, sample[:, row], limits[row], diag_kwargs[sample_ind] | ||
) | ||
diag_f = diag_funcs[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add a TODO here pointing to the issue with the unsupported list of diag_kwargs
I worked through https://github.com/sbi-dev/sbi/blob/main/docs/advanced_tutorials/17_plotting_functionality.ipynb and some of the functionality that works in the main branch is broken in my PR. I also might have misunderstood the point of passing different plot types in upper. I though the point was to have different plots in different places of the upper diagonal (that's the thing that doesn't work), but I now see that the actual purpose is to have two types of plots overlayed in all places. That feature was never clear to me in any of the other notebooks. It might be a good idea for me to attend an office hour to clarify the intended functionality. Either way it will take me some time to understand what's going on. Will be back to you. |
This PR is my (delayed) contribution to the 2025 hackathon, where I tried to resolve issues with the user interface of the widely used
pairplot
function, as in #1425The PR addresses the following main issue:
_get_default_diag_kwargs
, which returned dictionaries. I replaced those functions with dataclasses that contain the default values. These dataclasses can easily be converted to dictionaries by callingdict(FigKwargs)
so few changes are required internally and users can still use the standard way of passing kwargs as dictionaries. But the dataclasses are considered more Pythonic, they expose their internals more clearly to the user and in the future they could be passed topairplot
instead of dictionaries to specify keyword arguments. Although I currently don't know how to best make them available to the user, since they need to be explicitly imported right now. But that could be part of a future PR.samples
passed by the user were converted internally to a list of numpy arrays. Instead I now callnp.ndarray(samples)
, which creates a copy if necessary but changes nothing if samples are already a numpy array. IMO passing samples as ndarray should be strongly encouraged and it should either be an np.ndarray or a torch.Tensor. But for now lists are also supported for user flexibility.diag_kwarg
upper_kwargs
lower_kwargs
diag
,upper
andlower
all accept lists most likely with the intention that the user could chose a different plot type and different parameters for each plot. However, this was actually not working in the main branch. Instead only the first entry was used. I added user warnings that warn the user about this when they pass a list for any of these arguments. We should consider if this feature is actually desired. If not, the code could be massively simplified.{'mpl_kwargs': {}}
, where only the entries inmpl_kwargs
are actually passed to matplotlib. So{'bins':10, 'mpl_kwargs': {}}
, the'bins'
entry was siltently ignored. Instead,{'mpl_kwargs': {'bins':10}}
would be required. If any entries in anykwargs
is known to be ignored downstream, the user receives a warning about his issue. This is achieved by comparing the user provided dict with the parameter defined in the default dataclasses.There are still many issues with pairplot.py IMO and I am open to describing them in separate issues and continue work on those.