A quick fix is to override RandomRotation's get_params(). It currently draws uniformly from the interval [degrees[0], degrees[1]]; the override draws either degrees[0] or degrees[1] with equal probability instead (basically a "coin flip", as mentioned in Julien's comment).
Internally, RandomRotation's forward() method calls get_params() to get a new random value on each call. Also, degrees is always stored as a two-element sequence, even if the instance is initialized with a scalar value. To put it differently: currently, degrees holds the two interval bounds (lower, upper); for your case, it can be reinterpreted as the two possible choices, so with this slight tweak it more or less accidentally fits your purpose.
In code, the override can be written, for example, as follows:
import torch
from torchvision.transforms import RandomRotation

class FixedRandomRotation(RandomRotation):
    @staticmethod
    def get_params(degrees: list[float]) -> float:
        # Draw element from `degrees` sequence with equal probability
        return float(degrees[torch.randint(0, len(degrees), size=(1,))])
A full example, using SciPy's face image for demonstration:
import matplotlib.pyplot as plt  # For plotting
from scipy.datasets import face  # For demo image
import torch
from torchvision.transforms import RandomRotation

class FixedRandomRotation(RandomRotation):
    @staticmethod
    def get_params(degrees: list[float]) -> float:
        # Draw element from `degrees` sequence with equal probability
        return float(degrees[torch.randint(0, len(degrees), size=(1,))])

# Create rotation instance and demo image (permute: channels first)
rot = FixedRandomRotation(degrees=45)
demo_img = torch.from_numpy(face()).permute(2, 0, 1)

# Draw 16 times and plot results
fig, axes = plt.subplots(4, 4)
axes = axes.flatten()
for i in range(len(axes)):
    transformed = rot(demo_img)
    axes[i].set_axis_off()
    # Permute back to channels last for plotting
    axes[i].imshow(transformed.permute(1, 2, 0).cpu().numpy())
plt.tight_layout()
plt.show()
This produces a 4×4 grid in which each image is rotated by either -45° or +45°.

RandomRotation, but "flip a coin", e.g. if rand() > 0.5, then apply a fixed-angle rotation?