Skip to content

pairplot refactoring #1529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

pairplot refactoring #1529

wants to merge 3 commits into from

Conversation

danielmk
Copy link
Collaborator

This PR is my (delayed) contribution to the 2025 hackathon, where I tried to resolve issues with the user interface of the widely used pairplot function, as in #1425

The PR addresses the following main issue:

  • The default parameters for kwargs were hardcoded in functions such as _get_default_diag_kwargs, which returned dictionaries. I replaced those functions with dataclasses that contain the default values. These dataclasses can easily be converted to dictionaries by calling dict(FigKwargs) so few changes are required internally and users can still use the standard way of passing kwargs as dictionaries. But the dataclasses are considered more Pythonic, they expose their internals more clearly to the user and in the future they could be passed to pairplot instead of dictionaries to specify keyword arguments. Although I currently don't know how to best make them available to the user, since they need to be explicitly imported right now. But that could be part of a future PR.
  • The samples passed by the user were converted internally to a list of numpy arrays. Instead I now call np.ndarray(samples), which creates a copy if necessary but changes nothing if samples are already a numpy array. IMO passing samples as ndarray should be strongly encouraged and it should either be an np.ndarray or a torch.Tensor. But for now lists are also supported for user flexibility.
  • diag_kwarg upper_kwargs lower_kwargs diag, upper and lower all accept lists most likely with the intention that the user could chose a different plot type and different parameters for each plot. However, this was actually not working in the main branch. Instead only the first entry was used. I added user warnings that warn the user about this when they pass a list for any of these arguments. We should consider if this feature is actually desired. If not, the code could be massively simplified.
  • The way kwargs are passed has caused confusion, because they are passed as a nested dictionary {'mpl_kwargs': {}}, where only the entries in mpl_kwargs are actually passed to matplotlib. So {'bins':10, 'mpl_kwargs': {}}, the 'bins' entry was siltently ignored. Instead, {'mpl_kwargs': {'bins':10}} would be required. If any entries in any kwargs is known to be ignored downstream, the user receives a warning about his issue. This is achieved by comparing the user provided dict with the parameter defined in the default dataclasses.

There are still many issues with pairplot.py IMO and I am open to describing them in separate issues and continue work on those.

@danielmk danielmk requested a review from gmoss13 March 24, 2025 08:53
Copy link
Contributor

@gmoss13 gmoss13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @danielmk, great effort and really nice additions! I think this PR would already make life much easier for plotting SBI results. I have left some quite some comments, some more minor than others, so I'll provide a summary here:

  • I completely agree that it would be nice to expose the new dataclasses to the user so that they can use autocomplete, but happy to do this in a future PR. For now, already the fact that the user can import and use these, as well as the fact that we have a warning in pairplot and marginal_plot if any user-provided kwargs are ignored is already great. Maybe a small addition in this PR would to add all the new dataclasses to the __init__ in sbi/analysis so that the user does not have to import them all explicitly to be able to use them. But maybe @janfb has some ideas for how to make the new dataclasses more easy to use directly by the user.

  • Currently, you've had to add a lot of #pyright: ignore. Sometimes, there is no workaround for these, but when we have to add a lot of these ignore statements it means that we are probably doing something wrong. In the case of this PR, I think a lot of these can be avoided by updating the type hints (e.g. , prepare_for_plot now explicitly returns an np.ndarray, and the code for other functions that call prepare_for_plot assume that they get an np.ndarray, but the type hint for what prepare_for_plot returns is still List[np.ndarray]. I expect that correcting this will allow us to remove a lot of the pyright errors. I have commented individually on these pyright errors, but not all. Would be good to double check which ignore statements are strictly necessary and which we can remove with updated type hints.

  • regarding the tests currently failing. As discussed separately, locally running:

ruff format sbi
ruff format tests
ruff check sbi --fix
ruff check tests --fix
pyright sbi

should fix the linting/pyright errors. I also see that the test suite cancels after a lot of tests fail. Can you check by running plot_test.py locally to see if this is related to your changes or not? I think it's likely something about the new types that can be quickly fixed if we can see what the error message is 😄

@@ -81,8 +270,10 @@ def plt_hist_1d(
limits: torch.Tensor,
diag_kwargs: Dict,
) -> None:
# ax.hist(samples, **diag_kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a leftover from debugging?

samples[i], copy=False, nan=np.nan, posinf=np.nan, neginf=np.nan
)
samples[i] = samples[i][~np.isnan(samples[i]).any(axis=1)]
# for i in range(len(samples)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove these entirely instead of commenting them out


@dataclass
class GenericMplKwargs(GenericKwargs):
"""MplKwargs is used to generate kwargs that are passed to matplotlib in pairplot.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring is written with MplKwargs, but the class is called GenericMplKwargs

epsilon_range = eps * max_min_range
limits.append([min_val - epsilon_range, max_val + epsilon_range])
return limits
# limits = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove instead of commenting out

@@ -612,21 +806,23 @@ def prepare_for_plot(
of the samples.
"""

samples = convert_to_list_of_numpy(samples)
# samples = convert_to_list_of_numpy(samples)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove instead of commenting out

fig_kwargs_filled = _get_default_fig_kwargs()
fig_kwargs_filled = _update(fig_kwargs_filled, fig_kwargs)
fig_kwargs_default = FigKwargs() # Get defaults
#if type(fig_kwargs) == FigKwargs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover from debugging?


Returns:
Fig: matplotlib figure
Axes: matplotlib axes
"""
dim = samples[0].shape[1]
dim = samples.shape[1] # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type hint for samples is still a List, which is why this breaks. For functions that are not exposed to the user, such as arrange_grid, it's fine to change the type hint without adding a warning, as the user is not meant to use this function directly.

ax, sample[:, row], limits[row], diag_kwargs[sample_ind]
)

# for sample_ind, sample in enumerate(samples.T):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove instead of commenting out

@@ -1447,16 +1654,16 @@ def _arrange_grid(
if excl_lower:
ax.axis("off") # pyright: ignore reportOptionalMemberAccess
else:
for sample_ind, sample in enumerate(samples):
lower_f = lower_funcs[sample_ind]
for _, _ in enumerate(samples.T): # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

samples is still a list here

Copy link

codecov bot commented Apr 2, 2025

Codecov Report

Attention: Patch coverage is 76.69173% with 62 lines in your changes missing coverage. Please review.

Project coverage is 78.80%. Comparing base (1757616) to head (035086a).

Files with missing lines Patch % Lines
sbi/analysis/plot.py 76.69% 62 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1529      +/-   ##
==========================================
- Coverage   86.01%   78.80%   -7.22%     
==========================================
  Files         135      135              
  Lines       10751    10929     +178     
==========================================
- Hits         9248     8613     -635     
- Misses       1503     2316     +813     
Flag Coverage Δ
unittests 78.80% <76.69%> (-7.22%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
sbi/analysis/plot.py 69.37% <76.69%> (+0.79%) ⬆️

... and 31 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@danielmk danielmk requested a review from gmoss13 April 2, 2025 05:08
@danielmk
Copy link
Collaborator Author

danielmk commented Apr 2, 2025

Thank you for the detailed review of the PR and sorry for the messy commented out code. I've gone through all the comments and made the suggested changes. Many of the # type: ignor flags are now unnecessary as you suspected. Most importantly I also fixed the issues that broke plot_test.py

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work @danielmk, big karma points for digging into the plotting code! I think this will improve the plotting interface a lot!

I added a couple of comments.

Additionally, could I ask you to have a look at https://github.com/sbi-dev/sbi/blob/main/docs/advanced_tutorials/17_plotting_functionality.ipynb and add a note to or even change 1-2 of the examples given there showing how to use the improved interface?

Comment on lines +95 to +114
"""GenericMplKwargs is used to generate kwargs that are passed to matplotlib
in pairplot. kwargs that are neither in GneericMplKwargs nor used by pairplot
are completely ignored. To used the dictionary interface to kwargs, make
`'mpl_kwargs'` a dict key. Several specific dataclasses define the defaults
for specific plot types, all of which import GenericMplKwargs:
MplKwargsDiagKDE, MplKwargsDiagHist, MplKwargsDiagScatter, MplKwargsOffDiagKDE,
MplKwargsOffDiagHist, MplKwargsOffDiagScatter, MplKwargsOffDiagContour,
MplKwargsOffDiagPlot.

Example dictionary interface:
pairplot(samples,
diag='kde',
diag_kwargs={'bw_method': 'scott', 'mpl_kwargs': {'color': 'r'}})

Example dataclass interface:
pairplot(samples,
diag='kde',
diag_kwargs=DiagKwargsKDE(bw_method='scott',
mpl_kwargs=MplKwargsKDE(color='r')))
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using one sentence summary at the beginning, fixing typos and giving more structure:

Suggested change
"""GenericMplKwargs is used to generate kwargs that are passed to matplotlib
in pairplot. kwargs that are neither in GneericMplKwargs nor used by pairplot
are completely ignored. To used the dictionary interface to kwargs, make
`'mpl_kwargs'` a dict key. Several specific dataclasses define the defaults
for specific plot types, all of which import GenericMplKwargs:
MplKwargsDiagKDE, MplKwargsDiagHist, MplKwargsDiagScatter, MplKwargsOffDiagKDE,
MplKwargsOffDiagHist, MplKwargsOffDiagScatter, MplKwargsOffDiagContour,
MplKwargsOffDiagPlot.
Example dictionary interface:
pairplot(samples,
diag='kde',
diag_kwargs={'bw_method': 'scott', 'mpl_kwargs': {'color': 'r'}})
Example dataclass interface:
pairplot(samples,
diag='kde',
diag_kwargs=DiagKwargsKDE(bw_method='scott',
mpl_kwargs=MplKwargsKDE(color='r')))
"""
"""
Provides a structured way to pass matplotlib keyword arguments (kwargs) to the `pairplot` function.
Key Features:
- Ignores kwargs not recognized by `pairplot` or defined within GenericMplKwargs.
- Supports a dictionary interface for simple kwargs specification.
- Uses specialized dataclasses for plot-type-specific defaults.
Usage:
- Dictionary Interface:
pairplot(samples, diag='kde', diag_kwargs={'bw_method': 'scott', 'mpl_kwargs':
{'color': 'r'}})
- Dataclass Interface:
pairplot(samples, diag='kde', diag_kwargs=DiagKwargsKDE(bw_method='scott',
mpl_kwargs=MplKwargsKDE(
color='r')))
Specialized Dataclasses:
- MplKwargsDiagKDE, MplKwargsDiagHist, MplKwargsDiagScatter: Diagonal plot defaults.
- MplKwargsOffDiagKDE, MplKwargsOffDiagHist, MplKwargsOffDiagScatter: Off-diagonal plot defaults.
- MplKwargsOffDiagContour, MplKwargsOffDiagPlot: Additional off-diagonal defaults.
Note:
- Use 'mpl_kwargs' key within a dictionary for direct matplotlib kwargs access.
- Dataclass interface offers type safety and pre-defined defaults.
"""

@@ -717,6 +962,7 @@ def pairplot(

Args:
samples: Samples used to build the histogram.
np.asarry(samples) should become 2d Array (sample, parameter)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to give more details here:

Suggested change
np.asarry(samples) should become 2d Array (sample, parameter)
When passing a list of samples, np.asarry(samples) should become 2d Array
of shape (sample, parameter)

fig_kwargs_filled = _update(dict(fig_kwargs_default), dict(fig_kwargs_user))

# Prepare Diag Defaults
if type(diag) is list:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has there been created an issue for this one?

if non_default_diag_kwargs:
warn(
f"upper_kwargs has {len(non_default_diag_kwargs)} args that are "
"ignored by pairplot: {non_default_diag_kwargs}. To pass them to "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"ignored by pairplot: {non_default_diag_kwargs}. To pass them to "
f"ignored by pairplot: {non_default_diag_kwargs}. To pass them to "

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No issue is open regarding the kwargs list to my knowledge. I was going to open an issue on it after finishing this pull request but I can also do it sooner.

Also that's a great catch on the missing f in front of the string!

diag_f(
ax, sample[:, row], limits[row], diag_kwargs[sample_ind]
)
diag_f = diag_funcs[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a TODO here pointing to the issue with the unsupported list of diag_kwargs

@danielmk
Copy link
Collaborator Author

I worked through https://github.com/sbi-dev/sbi/blob/main/docs/advanced_tutorials/17_plotting_functionality.ipynb and some of the functionality that works in the main branch is broken in my PR. I also might have misunderstood the point of passing different plot types in upper. I though the point was to have different plots in different places of the upper diagonal (that's the thing that doesn't work), but I now see that the actual purpose is to have two types of plots overlayed in all places. That feature was never clear to me in any of the other notebooks. It might be a good idea for me to attend an office hour to clarify the intended functionality. Either way it will take me some time to understand what's going on. Will be back to you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants